diff --git a/src/Billing/Constants/HandledStripeWebhook.cs b/src/Billing/Constants/HandledStripeWebhook.cs
index f40b370f4..f7baa4675 100644
--- a/src/Billing/Constants/HandledStripeWebhook.cs
+++ b/src/Billing/Constants/HandledStripeWebhook.cs
@@ -2,12 +2,12 @@
public static class HandledStripeWebhook
{
- public static string SubscriptionDeleted => "customer.subscription.deleted";
- public static string SubscriptionUpdated => "customer.subscription.updated";
- public static string UpcomingInvoice => "invoice.upcoming";
- public static string ChargeSucceeded => "charge.succeeded";
- public static string ChargeRefunded => "charge.refunded";
- public static string PaymentSucceeded => "invoice.payment_succeeded";
- public static string PaymentFailed => "invoice.payment_failed";
- public static string InvoiceCreated => "invoice.created";
+ public const string SubscriptionDeleted = "customer.subscription.deleted";
+ public const string SubscriptionUpdated = "customer.subscription.updated";
+ public const string UpcomingInvoice = "invoice.upcoming";
+ public const string ChargeSucceeded = "charge.succeeded";
+ public const string ChargeRefunded = "charge.refunded";
+ public const string PaymentSucceeded = "invoice.payment_succeeded";
+ public const string PaymentFailed = "invoice.payment_failed";
+ public const string InvoiceCreated = "invoice.created";
}
diff --git a/src/Billing/Controllers/StripeController.cs b/src/Billing/Controllers/StripeController.cs
index 445be1f52..c825d55bd 100644
--- a/src/Billing/Controllers/StripeController.cs
+++ b/src/Billing/Controllers/StripeController.cs
@@ -14,6 +14,7 @@ using Microsoft.AspNetCore.Mvc;
using Microsoft.Data.SqlClient;
using Microsoft.Extensions.Options;
using Stripe;
+using Event = Stripe.Event;
using TaxRate = Bit.Core.Entities.TaxRate;
namespace Bit.Billing.Controllers;
@@ -41,6 +42,7 @@ public class StripeController : Controller
private readonly ITaxRateRepository _taxRateRepository;
private readonly IUserRepository _userRepository;
private readonly ICurrentContext _currentContext;
+ private readonly GlobalSettings _globalSettings;
public StripeController(
GlobalSettings globalSettings,
@@ -83,6 +85,7 @@ public class StripeController : Controller
PrivateKey = globalSettings.Braintree.PrivateKey
};
_currentContext = currentContext;
+ _globalSettings = globalSettings;
}
[HttpPost("webhook")]
@@ -114,6 +117,12 @@ public class StripeController : Controller
return new BadRequestResult();
}
+ // If the customer and server cloud regions don't match, early return 200 to avoid unnecessary errors
+ if (!await ValidateCloudRegionAsync(parsedEvent))
+ {
+ return new OkResult();
+ }
+
var subDeleted = parsedEvent.Type.Equals(HandledStripeWebhook.SubscriptionDeleted);
var subUpdated = parsedEvent.Type.Equals(HandledStripeWebhook.SubscriptionUpdated);
@@ -471,6 +480,68 @@ public class StripeController : Controller
return new OkResult();
}
+ ///
+ /// Ensures that the customer associated with the parsed event's data is in the correct region for this server.
+ /// We use the customer instead of the subscription given that all subscriptions have customers, but not all
+ /// customers have subscriptions
+ ///
+ ///
+ /// true if the customer's region and the server's region match, otherwise false
+ ///
+ private async Task ValidateCloudRegionAsync(Event parsedEvent)
+ {
+ string customerRegion;
+
+ var serverRegion = _globalSettings.BaseServiceUri.CloudRegion;
+ var eventType = parsedEvent.Type;
+
+ switch (eventType)
+ {
+ case HandledStripeWebhook.SubscriptionDeleted:
+ case HandledStripeWebhook.SubscriptionUpdated:
+ {
+ var subscription = await GetSubscriptionAsync(parsedEvent, true, new List { "customer" });
+ customerRegion = GetCustomerRegionFromMetadata(subscription.Customer.Metadata);
+ break;
+ }
+ case HandledStripeWebhook.ChargeSucceeded:
+ case HandledStripeWebhook.ChargeRefunded:
+ {
+ var charge = await GetChargeAsync(parsedEvent, true, new List { "customer" });
+ customerRegion = GetCustomerRegionFromMetadata(charge.Customer.Metadata);
+ break;
+ }
+ case HandledStripeWebhook.UpcomingInvoice:
+ case HandledStripeWebhook.PaymentSucceeded:
+ case HandledStripeWebhook.PaymentFailed:
+ case HandledStripeWebhook.InvoiceCreated:
+ {
+ var invoice = await GetInvoiceAsync(parsedEvent, true, new List { "customer" });
+ customerRegion = GetCustomerRegionFromMetadata(invoice.Customer.Metadata);
+ break;
+ }
+ default:
+ {
+ // For all Stripe events that we're not listening to, just return 200
+ return false;
+ }
+ }
+
+ return customerRegion == serverRegion;
+ }
+
+ ///
+ /// Gets the region from the customer metadata. If no region is present, defaults to "US"
+ ///
+ ///
+ ///
+ private static string GetCustomerRegionFromMetadata(Dictionary customerMetadata)
+ {
+ return customerMetadata.TryGetValue("region", out var value)
+ ? value
+ : "US";
+ }
+
private Tuple GetIdsFromMetaData(IDictionary metaData)
{
if (metaData == null || !metaData.Any())
@@ -732,7 +803,7 @@ public class StripeController : Controller
invoice.BillingReason == "subscription_cycle" && invoice.SubscriptionId != null;
}
- private async Task GetChargeAsync(Stripe.Event parsedEvent, bool fresh = false)
+ private async Task GetChargeAsync(Event parsedEvent, bool fresh = false, List expandOptions = null)
{
if (!(parsedEvent.Data.Object is Charge eventCharge))
{
@@ -743,7 +814,8 @@ public class StripeController : Controller
return eventCharge;
}
var chargeService = new ChargeService();
- var charge = await chargeService.GetAsync(eventCharge.Id);
+ var chargeGetOptions = new ChargeGetOptions { Expand = expandOptions };
+ var charge = await chargeService.GetAsync(eventCharge.Id, chargeGetOptions);
if (charge == null)
{
throw new Exception("Charge is null. " + eventCharge.Id);
@@ -751,7 +823,7 @@ public class StripeController : Controller
return charge;
}
- private async Task GetInvoiceAsync(Stripe.Event parsedEvent, bool fresh = false)
+ private async Task GetInvoiceAsync(Stripe.Event parsedEvent, bool fresh = false, List expandOptions = null)
{
if (!(parsedEvent.Data.Object is Invoice eventInvoice))
{
@@ -762,7 +834,8 @@ public class StripeController : Controller
return eventInvoice;
}
var invoiceService = new InvoiceService();
- var invoice = await invoiceService.GetAsync(eventInvoice.Id);
+ var invoiceGetOptions = new InvoiceGetOptions { Expand = expandOptions };
+ var invoice = await invoiceService.GetAsync(eventInvoice.Id, invoiceGetOptions);
if (invoice == null)
{
throw new Exception("Invoice is null. " + eventInvoice.Id);
@@ -770,9 +843,10 @@ public class StripeController : Controller
return invoice;
}
- private async Task GetSubscriptionAsync(Stripe.Event parsedEvent, bool fresh = false)
+ private async Task GetSubscriptionAsync(Stripe.Event parsedEvent, bool fresh = false,
+ List expandOptions = null)
{
- if (!(parsedEvent.Data.Object is Subscription eventSubscription))
+ if (parsedEvent.Data.Object is not Subscription eventSubscription)
{
throw new Exception("Subscription is null (from parsed event). " + parsedEvent.Id);
}
@@ -781,7 +855,8 @@ public class StripeController : Controller
return eventSubscription;
}
var subscriptionService = new SubscriptionService();
- var subscription = await subscriptionService.GetAsync(eventSubscription.Id);
+ var subscriptionGetOptions = new SubscriptionGetOptions { Expand = expandOptions };
+ var subscription = await subscriptionService.GetAsync(eventSubscription.Id, subscriptionGetOptions);
if (subscription == null)
{
throw new Exception("Subscription is null. " + eventSubscription.Id);
diff --git a/src/Core/Services/Implementations/StripePaymentService.cs b/src/Core/Services/Implementations/StripePaymentService.cs
index ba9956153..1776f1993 100644
--- a/src/Core/Services/Implementations/StripePaymentService.cs
+++ b/src/Core/Services/Implementations/StripePaymentService.cs
@@ -4,6 +4,7 @@ using Bit.Core.Enums;
using Bit.Core.Exceptions;
using Bit.Core.Models.Business;
using Bit.Core.Repositories;
+using Bit.Core.Settings;
using Microsoft.Extensions.Logging;
using StaticStore = Bit.Core.Models.StaticStore;
using TaxRate = Bit.Core.Entities.TaxRate;
@@ -25,6 +26,7 @@ public class StripePaymentService : IPaymentService
private readonly Braintree.IBraintreeGateway _btGateway;
private readonly ITaxRateRepository _taxRateRepository;
private readonly IStripeAdapter _stripeAdapter;
+ private readonly IGlobalSettings _globalSettings;
public StripePaymentService(
ITransactionRepository transactionRepository,
@@ -33,7 +35,8 @@ public class StripePaymentService : IPaymentService
ILogger logger,
ITaxRateRepository taxRateRepository,
IStripeAdapter stripeAdapter,
- Braintree.IBraintreeGateway braintreeGateway)
+ Braintree.IBraintreeGateway braintreeGateway,
+ IGlobalSettings globalSettings)
{
_transactionRepository = transactionRepository;
_userRepository = userRepository;
@@ -42,6 +45,7 @@ public class StripePaymentService : IPaymentService
_taxRateRepository = taxRateRepository;
_stripeAdapter = stripeAdapter;
_btGateway = braintreeGateway;
+ _globalSettings = globalSettings;
}
public async Task PurchaseOrganizationAsync(Organization org, PaymentMethodType paymentMethodType,
@@ -51,9 +55,12 @@ public class StripePaymentService : IPaymentService
Braintree.Customer braintreeCustomer = null;
string stipeCustomerSourceToken = null;
string stipeCustomerPaymentMethodId = null;
- var stripeCustomerMetadata = new Dictionary();
+ var stripeCustomerMetadata = new Dictionary
+ {
+ { "region", _globalSettings.BaseServiceUri.CloudRegion }
+ };
var stripePaymentMethod = paymentMethodType == PaymentMethodType.Card ||
- paymentMethodType == PaymentMethodType.BankAccount;
+ paymentMethodType == PaymentMethodType.BankAccount;
if (stripePaymentMethod && !string.IsNullOrWhiteSpace(paymentToken))
{
@@ -388,7 +395,7 @@ public class StripePaymentService : IPaymentService
if (customer == null && !string.IsNullOrWhiteSpace(paymentToken))
{
- var stripeCustomerMetadata = new Dictionary();
+ var stripeCustomerMetadata = new Dictionary { { "region", _globalSettings.BaseServiceUri.CloudRegion } };
if (paymentMethodType == PaymentMethodType.PayPal)
{
var randomSuffix = Utilities.CoreHelpers.RandomString(3, upper: false, numeric: false);
@@ -1185,9 +1192,12 @@ public class StripePaymentService : IPaymentService
Braintree.Customer braintreeCustomer = null;
string stipeCustomerSourceToken = null;
string stipeCustomerPaymentMethodId = null;
- var stripeCustomerMetadata = new Dictionary();
+ var stripeCustomerMetadata = new Dictionary
+ {
+ { "region", _globalSettings.BaseServiceUri.CloudRegion }
+ };
var stripePaymentMethod = paymentMethodType == PaymentMethodType.Card ||
- paymentMethodType == PaymentMethodType.BankAccount;
+ paymentMethodType == PaymentMethodType.BankAccount;
var inAppPurchase = paymentMethodType == PaymentMethodType.AppleInApp ||
paymentMethodType == PaymentMethodType.GoogleInApp;
diff --git a/test/Core.Test/Services/StripePaymentServiceTests.cs b/test/Core.Test/Services/StripePaymentServiceTests.cs
index c43e99fb9..d8a67e815 100644
--- a/test/Core.Test/Services/StripePaymentServiceTests.cs
+++ b/test/Core.Test/Services/StripePaymentServiceTests.cs
@@ -4,6 +4,7 @@ using Bit.Core.Exceptions;
using Bit.Core.Models.Business;
using Bit.Core.Repositories;
using Bit.Core.Services;
+using Bit.Core.Settings;
using Bit.Core.Utilities;
using Bit.Test.Common.AutoFixture;
using Bit.Test.Common.AutoFixture.Attributes;
@@ -51,6 +52,9 @@ public class StripePaymentServiceTests
Id = "S-1",
CurrentPeriodEnd = DateTime.Today.AddDays(10),
});
+ sutProvider.GetDependency()
+ .BaseServiceUri.CloudRegion
+ .Returns("US");
var result = await sutProvider.Sut.PurchaseOrganizationAsync(organization, PaymentMethodType.Card, paymentToken, plan, 0, 0, false, taxInfo, provider);
@@ -67,7 +71,8 @@ public class StripePaymentServiceTests
c.Source == paymentToken &&
c.PaymentMethod == null &&
c.Coupon == "msp-discount-35" &&
- !c.Metadata.Any() &&
+ c.Metadata.Count == 1 &&
+ c.Metadata["region"] == "US" &&
c.InvoiceSettings.DefaultPaymentMethod == null &&
c.Address.Country == taxInfo.BillingAddressCountry &&
c.Address.PostalCode == taxInfo.BillingAddressPostalCode &&
@@ -101,6 +106,9 @@ public class StripePaymentServiceTests
Id = "S-1",
CurrentPeriodEnd = DateTime.Today.AddDays(10),
});
+ sutProvider.GetDependency()
+ .BaseServiceUri.CloudRegion
+ .Returns("US");
var result = await sutProvider.Sut.PurchaseOrganizationAsync(organization, PaymentMethodType.Card, paymentToken, plan, 0, 0, false, taxInfo);
@@ -116,7 +124,8 @@ public class StripePaymentServiceTests
c.Email == organization.BillingEmail &&
c.Source == paymentToken &&
c.PaymentMethod == null &&
- !c.Metadata.Any() &&
+ c.Metadata.Count == 1 &&
+ c.Metadata["region"] == "US" &&
c.InvoiceSettings.DefaultPaymentMethod == null &&
c.InvoiceSettings.CustomFields != null &&
c.InvoiceSettings.CustomFields[0].Name == "Organization" &&
@@ -154,6 +163,9 @@ public class StripePaymentServiceTests
Id = "S-1",
CurrentPeriodEnd = DateTime.Today.AddDays(10),
});
+ sutProvider.GetDependency()
+ .BaseServiceUri.CloudRegion
+ .Returns("US");
var result = await sutProvider.Sut.PurchaseOrganizationAsync(organization, PaymentMethodType.Card, paymentToken, plan, 0, 0, false, taxInfo);
@@ -169,7 +181,8 @@ public class StripePaymentServiceTests
c.Email == organization.BillingEmail &&
c.Source == null &&
c.PaymentMethod == paymentToken &&
- !c.Metadata.Any() &&
+ c.Metadata.Count == 1 &&
+ c.Metadata["region"] == "US" &&
c.InvoiceSettings.DefaultPaymentMethod == paymentToken &&
c.InvoiceSettings.CustomFields != null &&
c.InvoiceSettings.CustomFields[0].Name == "Organization" &&
@@ -300,6 +313,10 @@ public class StripePaymentServiceTests
CurrentPeriodEnd = DateTime.Today.AddDays(10),
});
+ sutProvider.GetDependency()
+ .BaseServiceUri.CloudRegion
+ .Returns("US");
+
var customer = Substitute.For();
customer.Id.ReturnsForAnyArgs("Braintree-Id");
customer.PaymentMethods.ReturnsForAnyArgs(new[] { Substitute.For() });
@@ -323,8 +340,9 @@ public class StripePaymentServiceTests
c.Description == organization.BusinessName &&
c.Email == organization.BillingEmail &&
c.PaymentMethod == null &&
- c.Metadata.Count == 1 &&
+ c.Metadata.Count == 2 &&
c.Metadata["btCustomerId"] == "Braintree-Id" &&
+ c.Metadata["region"] == "US" &&
c.InvoiceSettings.DefaultPaymentMethod == null &&
c.Address.Country == taxInfo.BillingAddressCountry &&
c.Address.PostalCode == taxInfo.BillingAddressPostalCode &&