1
0
mirror of https://github.com/bitwarden/server.git synced 2025-02-16 01:51:21 +01:00

[AC-1508] Stripe changes for the EU datacenter (#3092)

* Added region to customer metadata

* Updated webhook to filter out events for other DCs

* Flipped ternary to be positive, fixed indentation

* Updated to allow for unit testing andupdated tests

---------

Co-authored-by: cyprain-okeke <108260115+cyprain-okeke@users.noreply.github.com>
This commit is contained in:
Conner Turnbull 2023-07-20 17:00:40 -04:00 committed by GitHub
parent 1fe2f0fb57
commit a61290a3c8
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 128 additions and 25 deletions

View File

@ -2,12 +2,12 @@
public static class HandledStripeWebhook public static class HandledStripeWebhook
{ {
public static string SubscriptionDeleted => "customer.subscription.deleted"; public const string SubscriptionDeleted = "customer.subscription.deleted";
public static string SubscriptionUpdated => "customer.subscription.updated"; public const string SubscriptionUpdated = "customer.subscription.updated";
public static string UpcomingInvoice => "invoice.upcoming"; public const string UpcomingInvoice = "invoice.upcoming";
public static string ChargeSucceeded => "charge.succeeded"; public const string ChargeSucceeded = "charge.succeeded";
public static string ChargeRefunded => "charge.refunded"; public const string ChargeRefunded = "charge.refunded";
public static string PaymentSucceeded => "invoice.payment_succeeded"; public const string PaymentSucceeded = "invoice.payment_succeeded";
public static string PaymentFailed => "invoice.payment_failed"; public const string PaymentFailed = "invoice.payment_failed";
public static string InvoiceCreated => "invoice.created"; public const string InvoiceCreated = "invoice.created";
} }

View File

@ -14,6 +14,7 @@ using Microsoft.AspNetCore.Mvc;
using Microsoft.Data.SqlClient; using Microsoft.Data.SqlClient;
using Microsoft.Extensions.Options; using Microsoft.Extensions.Options;
using Stripe; using Stripe;
using Event = Stripe.Event;
using TaxRate = Bit.Core.Entities.TaxRate; using TaxRate = Bit.Core.Entities.TaxRate;
namespace Bit.Billing.Controllers; namespace Bit.Billing.Controllers;
@ -41,6 +42,7 @@ public class StripeController : Controller
private readonly ITaxRateRepository _taxRateRepository; private readonly ITaxRateRepository _taxRateRepository;
private readonly IUserRepository _userRepository; private readonly IUserRepository _userRepository;
private readonly ICurrentContext _currentContext; private readonly ICurrentContext _currentContext;
private readonly GlobalSettings _globalSettings;
public StripeController( public StripeController(
GlobalSettings globalSettings, GlobalSettings globalSettings,
@ -83,6 +85,7 @@ public class StripeController : Controller
PrivateKey = globalSettings.Braintree.PrivateKey PrivateKey = globalSettings.Braintree.PrivateKey
}; };
_currentContext = currentContext; _currentContext = currentContext;
_globalSettings = globalSettings;
} }
[HttpPost("webhook")] [HttpPost("webhook")]
@ -114,6 +117,12 @@ public class StripeController : Controller
return new BadRequestResult(); return new BadRequestResult();
} }
// If the customer and server cloud regions don't match, early return 200 to avoid unnecessary errors
if (!await ValidateCloudRegionAsync(parsedEvent))
{
return new OkResult();
}
var subDeleted = parsedEvent.Type.Equals(HandledStripeWebhook.SubscriptionDeleted); var subDeleted = parsedEvent.Type.Equals(HandledStripeWebhook.SubscriptionDeleted);
var subUpdated = parsedEvent.Type.Equals(HandledStripeWebhook.SubscriptionUpdated); var subUpdated = parsedEvent.Type.Equals(HandledStripeWebhook.SubscriptionUpdated);
@ -471,6 +480,68 @@ public class StripeController : Controller
return new OkResult(); return new OkResult();
} }
/// <summary>
/// Ensures that the customer associated with the parsed event's data is in the correct region for this server.
/// We use the customer instead of the subscription given that all subscriptions have customers, but not all
/// customers have subscriptions
/// </summary>
/// <param name="parsedEvent"></param>
/// <returns>true if the customer's region and the server's region match, otherwise false</returns>
/// <exception cref="Exception"></exception>
private async Task<bool> ValidateCloudRegionAsync(Event parsedEvent)
{
string customerRegion;
var serverRegion = _globalSettings.BaseServiceUri.CloudRegion;
var eventType = parsedEvent.Type;
switch (eventType)
{
case HandledStripeWebhook.SubscriptionDeleted:
case HandledStripeWebhook.SubscriptionUpdated:
{
var subscription = await GetSubscriptionAsync(parsedEvent, true, new List<string> { "customer" });
customerRegion = GetCustomerRegionFromMetadata(subscription.Customer.Metadata);
break;
}
case HandledStripeWebhook.ChargeSucceeded:
case HandledStripeWebhook.ChargeRefunded:
{
var charge = await GetChargeAsync(parsedEvent, true, new List<string> { "customer" });
customerRegion = GetCustomerRegionFromMetadata(charge.Customer.Metadata);
break;
}
case HandledStripeWebhook.UpcomingInvoice:
case HandledStripeWebhook.PaymentSucceeded:
case HandledStripeWebhook.PaymentFailed:
case HandledStripeWebhook.InvoiceCreated:
{
var invoice = await GetInvoiceAsync(parsedEvent, true, new List<string> { "customer" });
customerRegion = GetCustomerRegionFromMetadata(invoice.Customer.Metadata);
break;
}
default:
{
// For all Stripe events that we're not listening to, just return 200
return false;
}
}
return customerRegion == serverRegion;
}
/// <summary>
/// Gets the region from the customer metadata. If no region is present, defaults to "US"
/// </summary>
/// <param name="customerMetadata"></param>
/// <returns></returns>
private static string GetCustomerRegionFromMetadata(Dictionary<string, string> customerMetadata)
{
return customerMetadata.TryGetValue("region", out var value)
? value
: "US";
}
private Tuple<Guid?, Guid?> GetIdsFromMetaData(IDictionary<string, string> metaData) private Tuple<Guid?, Guid?> GetIdsFromMetaData(IDictionary<string, string> metaData)
{ {
if (metaData == null || !metaData.Any()) if (metaData == null || !metaData.Any())
@ -732,7 +803,7 @@ public class StripeController : Controller
invoice.BillingReason == "subscription_cycle" && invoice.SubscriptionId != null; invoice.BillingReason == "subscription_cycle" && invoice.SubscriptionId != null;
} }
private async Task<Charge> GetChargeAsync(Stripe.Event parsedEvent, bool fresh = false) private async Task<Charge> GetChargeAsync(Event parsedEvent, bool fresh = false, List<string> expandOptions = null)
{ {
if (!(parsedEvent.Data.Object is Charge eventCharge)) if (!(parsedEvent.Data.Object is Charge eventCharge))
{ {
@ -743,7 +814,8 @@ public class StripeController : Controller
return eventCharge; return eventCharge;
} }
var chargeService = new ChargeService(); var chargeService = new ChargeService();
var charge = await chargeService.GetAsync(eventCharge.Id); var chargeGetOptions = new ChargeGetOptions { Expand = expandOptions };
var charge = await chargeService.GetAsync(eventCharge.Id, chargeGetOptions);
if (charge == null) if (charge == null)
{ {
throw new Exception("Charge is null. " + eventCharge.Id); throw new Exception("Charge is null. " + eventCharge.Id);
@ -751,7 +823,7 @@ public class StripeController : Controller
return charge; return charge;
} }
private async Task<Invoice> GetInvoiceAsync(Stripe.Event parsedEvent, bool fresh = false) private async Task<Invoice> GetInvoiceAsync(Stripe.Event parsedEvent, bool fresh = false, List<string> expandOptions = null)
{ {
if (!(parsedEvent.Data.Object is Invoice eventInvoice)) if (!(parsedEvent.Data.Object is Invoice eventInvoice))
{ {
@ -762,7 +834,8 @@ public class StripeController : Controller
return eventInvoice; return eventInvoice;
} }
var invoiceService = new InvoiceService(); var invoiceService = new InvoiceService();
var invoice = await invoiceService.GetAsync(eventInvoice.Id); var invoiceGetOptions = new InvoiceGetOptions { Expand = expandOptions };
var invoice = await invoiceService.GetAsync(eventInvoice.Id, invoiceGetOptions);
if (invoice == null) if (invoice == null)
{ {
throw new Exception("Invoice is null. " + eventInvoice.Id); throw new Exception("Invoice is null. " + eventInvoice.Id);
@ -770,9 +843,10 @@ public class StripeController : Controller
return invoice; return invoice;
} }
private async Task<Subscription> GetSubscriptionAsync(Stripe.Event parsedEvent, bool fresh = false) private async Task<Subscription> GetSubscriptionAsync(Stripe.Event parsedEvent, bool fresh = false,
List<string> expandOptions = null)
{ {
if (!(parsedEvent.Data.Object is Subscription eventSubscription)) if (parsedEvent.Data.Object is not Subscription eventSubscription)
{ {
throw new Exception("Subscription is null (from parsed event). " + parsedEvent.Id); throw new Exception("Subscription is null (from parsed event). " + parsedEvent.Id);
} }
@ -781,7 +855,8 @@ public class StripeController : Controller
return eventSubscription; return eventSubscription;
} }
var subscriptionService = new SubscriptionService(); var subscriptionService = new SubscriptionService();
var subscription = await subscriptionService.GetAsync(eventSubscription.Id); var subscriptionGetOptions = new SubscriptionGetOptions { Expand = expandOptions };
var subscription = await subscriptionService.GetAsync(eventSubscription.Id, subscriptionGetOptions);
if (subscription == null) if (subscription == null)
{ {
throw new Exception("Subscription is null. " + eventSubscription.Id); throw new Exception("Subscription is null. " + eventSubscription.Id);

View File

@ -4,6 +4,7 @@ using Bit.Core.Enums;
using Bit.Core.Exceptions; using Bit.Core.Exceptions;
using Bit.Core.Models.Business; using Bit.Core.Models.Business;
using Bit.Core.Repositories; using Bit.Core.Repositories;
using Bit.Core.Settings;
using Microsoft.Extensions.Logging; using Microsoft.Extensions.Logging;
using StaticStore = Bit.Core.Models.StaticStore; using StaticStore = Bit.Core.Models.StaticStore;
using TaxRate = Bit.Core.Entities.TaxRate; using TaxRate = Bit.Core.Entities.TaxRate;
@ -25,6 +26,7 @@ public class StripePaymentService : IPaymentService
private readonly Braintree.IBraintreeGateway _btGateway; private readonly Braintree.IBraintreeGateway _btGateway;
private readonly ITaxRateRepository _taxRateRepository; private readonly ITaxRateRepository _taxRateRepository;
private readonly IStripeAdapter _stripeAdapter; private readonly IStripeAdapter _stripeAdapter;
private readonly IGlobalSettings _globalSettings;
public StripePaymentService( public StripePaymentService(
ITransactionRepository transactionRepository, ITransactionRepository transactionRepository,
@ -33,7 +35,8 @@ public class StripePaymentService : IPaymentService
ILogger<StripePaymentService> logger, ILogger<StripePaymentService> logger,
ITaxRateRepository taxRateRepository, ITaxRateRepository taxRateRepository,
IStripeAdapter stripeAdapter, IStripeAdapter stripeAdapter,
Braintree.IBraintreeGateway braintreeGateway) Braintree.IBraintreeGateway braintreeGateway,
IGlobalSettings globalSettings)
{ {
_transactionRepository = transactionRepository; _transactionRepository = transactionRepository;
_userRepository = userRepository; _userRepository = userRepository;
@ -42,6 +45,7 @@ public class StripePaymentService : IPaymentService
_taxRateRepository = taxRateRepository; _taxRateRepository = taxRateRepository;
_stripeAdapter = stripeAdapter; _stripeAdapter = stripeAdapter;
_btGateway = braintreeGateway; _btGateway = braintreeGateway;
_globalSettings = globalSettings;
} }
public async Task<string> PurchaseOrganizationAsync(Organization org, PaymentMethodType paymentMethodType, public async Task<string> PurchaseOrganizationAsync(Organization org, PaymentMethodType paymentMethodType,
@ -51,9 +55,12 @@ public class StripePaymentService : IPaymentService
Braintree.Customer braintreeCustomer = null; Braintree.Customer braintreeCustomer = null;
string stipeCustomerSourceToken = null; string stipeCustomerSourceToken = null;
string stipeCustomerPaymentMethodId = null; string stipeCustomerPaymentMethodId = null;
var stripeCustomerMetadata = new Dictionary<string, string>(); var stripeCustomerMetadata = new Dictionary<string, string>
{
{ "region", _globalSettings.BaseServiceUri.CloudRegion }
};
var stripePaymentMethod = paymentMethodType == PaymentMethodType.Card || var stripePaymentMethod = paymentMethodType == PaymentMethodType.Card ||
paymentMethodType == PaymentMethodType.BankAccount; paymentMethodType == PaymentMethodType.BankAccount;
if (stripePaymentMethod && !string.IsNullOrWhiteSpace(paymentToken)) if (stripePaymentMethod && !string.IsNullOrWhiteSpace(paymentToken))
{ {
@ -388,7 +395,7 @@ public class StripePaymentService : IPaymentService
if (customer == null && !string.IsNullOrWhiteSpace(paymentToken)) if (customer == null && !string.IsNullOrWhiteSpace(paymentToken))
{ {
var stripeCustomerMetadata = new Dictionary<string, string>(); var stripeCustomerMetadata = new Dictionary<string, string> { { "region", _globalSettings.BaseServiceUri.CloudRegion } };
if (paymentMethodType == PaymentMethodType.PayPal) if (paymentMethodType == PaymentMethodType.PayPal)
{ {
var randomSuffix = Utilities.CoreHelpers.RandomString(3, upper: false, numeric: false); var randomSuffix = Utilities.CoreHelpers.RandomString(3, upper: false, numeric: false);
@ -1185,9 +1192,12 @@ public class StripePaymentService : IPaymentService
Braintree.Customer braintreeCustomer = null; Braintree.Customer braintreeCustomer = null;
string stipeCustomerSourceToken = null; string stipeCustomerSourceToken = null;
string stipeCustomerPaymentMethodId = null; string stipeCustomerPaymentMethodId = null;
var stripeCustomerMetadata = new Dictionary<string, string>(); var stripeCustomerMetadata = new Dictionary<string, string>
{
{ "region", _globalSettings.BaseServiceUri.CloudRegion }
};
var stripePaymentMethod = paymentMethodType == PaymentMethodType.Card || var stripePaymentMethod = paymentMethodType == PaymentMethodType.Card ||
paymentMethodType == PaymentMethodType.BankAccount; paymentMethodType == PaymentMethodType.BankAccount;
var inAppPurchase = paymentMethodType == PaymentMethodType.AppleInApp || var inAppPurchase = paymentMethodType == PaymentMethodType.AppleInApp ||
paymentMethodType == PaymentMethodType.GoogleInApp; paymentMethodType == PaymentMethodType.GoogleInApp;

View File

@ -4,6 +4,7 @@ using Bit.Core.Exceptions;
using Bit.Core.Models.Business; using Bit.Core.Models.Business;
using Bit.Core.Repositories; using Bit.Core.Repositories;
using Bit.Core.Services; using Bit.Core.Services;
using Bit.Core.Settings;
using Bit.Core.Utilities; using Bit.Core.Utilities;
using Bit.Test.Common.AutoFixture; using Bit.Test.Common.AutoFixture;
using Bit.Test.Common.AutoFixture.Attributes; using Bit.Test.Common.AutoFixture.Attributes;
@ -51,6 +52,9 @@ public class StripePaymentServiceTests
Id = "S-1", Id = "S-1",
CurrentPeriodEnd = DateTime.Today.AddDays(10), CurrentPeriodEnd = DateTime.Today.AddDays(10),
}); });
sutProvider.GetDependency<IGlobalSettings>()
.BaseServiceUri.CloudRegion
.Returns("US");
var result = await sutProvider.Sut.PurchaseOrganizationAsync(organization, PaymentMethodType.Card, paymentToken, plan, 0, 0, false, taxInfo, provider); var result = await sutProvider.Sut.PurchaseOrganizationAsync(organization, PaymentMethodType.Card, paymentToken, plan, 0, 0, false, taxInfo, provider);
@ -67,7 +71,8 @@ public class StripePaymentServiceTests
c.Source == paymentToken && c.Source == paymentToken &&
c.PaymentMethod == null && c.PaymentMethod == null &&
c.Coupon == "msp-discount-35" && c.Coupon == "msp-discount-35" &&
!c.Metadata.Any() && c.Metadata.Count == 1 &&
c.Metadata["region"] == "US" &&
c.InvoiceSettings.DefaultPaymentMethod == null && c.InvoiceSettings.DefaultPaymentMethod == null &&
c.Address.Country == taxInfo.BillingAddressCountry && c.Address.Country == taxInfo.BillingAddressCountry &&
c.Address.PostalCode == taxInfo.BillingAddressPostalCode && c.Address.PostalCode == taxInfo.BillingAddressPostalCode &&
@ -101,6 +106,9 @@ public class StripePaymentServiceTests
Id = "S-1", Id = "S-1",
CurrentPeriodEnd = DateTime.Today.AddDays(10), CurrentPeriodEnd = DateTime.Today.AddDays(10),
}); });
sutProvider.GetDependency<IGlobalSettings>()
.BaseServiceUri.CloudRegion
.Returns("US");
var result = await sutProvider.Sut.PurchaseOrganizationAsync(organization, PaymentMethodType.Card, paymentToken, plan, 0, 0, false, taxInfo); var result = await sutProvider.Sut.PurchaseOrganizationAsync(organization, PaymentMethodType.Card, paymentToken, plan, 0, 0, false, taxInfo);
@ -116,7 +124,8 @@ public class StripePaymentServiceTests
c.Email == organization.BillingEmail && c.Email == organization.BillingEmail &&
c.Source == paymentToken && c.Source == paymentToken &&
c.PaymentMethod == null && c.PaymentMethod == null &&
!c.Metadata.Any() && c.Metadata.Count == 1 &&
c.Metadata["region"] == "US" &&
c.InvoiceSettings.DefaultPaymentMethod == null && c.InvoiceSettings.DefaultPaymentMethod == null &&
c.InvoiceSettings.CustomFields != null && c.InvoiceSettings.CustomFields != null &&
c.InvoiceSettings.CustomFields[0].Name == "Organization" && c.InvoiceSettings.CustomFields[0].Name == "Organization" &&
@ -154,6 +163,9 @@ public class StripePaymentServiceTests
Id = "S-1", Id = "S-1",
CurrentPeriodEnd = DateTime.Today.AddDays(10), CurrentPeriodEnd = DateTime.Today.AddDays(10),
}); });
sutProvider.GetDependency<IGlobalSettings>()
.BaseServiceUri.CloudRegion
.Returns("US");
var result = await sutProvider.Sut.PurchaseOrganizationAsync(organization, PaymentMethodType.Card, paymentToken, plan, 0, 0, false, taxInfo); var result = await sutProvider.Sut.PurchaseOrganizationAsync(organization, PaymentMethodType.Card, paymentToken, plan, 0, 0, false, taxInfo);
@ -169,7 +181,8 @@ public class StripePaymentServiceTests
c.Email == organization.BillingEmail && c.Email == organization.BillingEmail &&
c.Source == null && c.Source == null &&
c.PaymentMethod == paymentToken && c.PaymentMethod == paymentToken &&
!c.Metadata.Any() && c.Metadata.Count == 1 &&
c.Metadata["region"] == "US" &&
c.InvoiceSettings.DefaultPaymentMethod == paymentToken && c.InvoiceSettings.DefaultPaymentMethod == paymentToken &&
c.InvoiceSettings.CustomFields != null && c.InvoiceSettings.CustomFields != null &&
c.InvoiceSettings.CustomFields[0].Name == "Organization" && c.InvoiceSettings.CustomFields[0].Name == "Organization" &&
@ -300,6 +313,10 @@ public class StripePaymentServiceTests
CurrentPeriodEnd = DateTime.Today.AddDays(10), CurrentPeriodEnd = DateTime.Today.AddDays(10),
}); });
sutProvider.GetDependency<IGlobalSettings>()
.BaseServiceUri.CloudRegion
.Returns("US");
var customer = Substitute.For<Customer>(); var customer = Substitute.For<Customer>();
customer.Id.ReturnsForAnyArgs("Braintree-Id"); customer.Id.ReturnsForAnyArgs("Braintree-Id");
customer.PaymentMethods.ReturnsForAnyArgs(new[] { Substitute.For<PaymentMethod>() }); customer.PaymentMethods.ReturnsForAnyArgs(new[] { Substitute.For<PaymentMethod>() });
@ -323,8 +340,9 @@ public class StripePaymentServiceTests
c.Description == organization.BusinessName && c.Description == organization.BusinessName &&
c.Email == organization.BillingEmail && c.Email == organization.BillingEmail &&
c.PaymentMethod == null && c.PaymentMethod == null &&
c.Metadata.Count == 1 && c.Metadata.Count == 2 &&
c.Metadata["btCustomerId"] == "Braintree-Id" && c.Metadata["btCustomerId"] == "Braintree-Id" &&
c.Metadata["region"] == "US" &&
c.InvoiceSettings.DefaultPaymentMethod == null && c.InvoiceSettings.DefaultPaymentMethod == null &&
c.Address.Country == taxInfo.BillingAddressCountry && c.Address.Country == taxInfo.BillingAddressCountry &&
c.Address.PostalCode == taxInfo.BillingAddressPostalCode && c.Address.PostalCode == taxInfo.BillingAddressPostalCode &&