diff --git a/src/Billing/Constants/HandledStripeWebhook.cs b/src/Billing/Constants/HandledStripeWebhook.cs index f40b370f4..f7baa4675 100644 --- a/src/Billing/Constants/HandledStripeWebhook.cs +++ b/src/Billing/Constants/HandledStripeWebhook.cs @@ -2,12 +2,12 @@ public static class HandledStripeWebhook { - public static string SubscriptionDeleted => "customer.subscription.deleted"; - public static string SubscriptionUpdated => "customer.subscription.updated"; - public static string UpcomingInvoice => "invoice.upcoming"; - public static string ChargeSucceeded => "charge.succeeded"; - public static string ChargeRefunded => "charge.refunded"; - public static string PaymentSucceeded => "invoice.payment_succeeded"; - public static string PaymentFailed => "invoice.payment_failed"; - public static string InvoiceCreated => "invoice.created"; + public const string SubscriptionDeleted = "customer.subscription.deleted"; + public const string SubscriptionUpdated = "customer.subscription.updated"; + public const string UpcomingInvoice = "invoice.upcoming"; + public const string ChargeSucceeded = "charge.succeeded"; + public const string ChargeRefunded = "charge.refunded"; + public const string PaymentSucceeded = "invoice.payment_succeeded"; + public const string PaymentFailed = "invoice.payment_failed"; + public const string InvoiceCreated = "invoice.created"; } diff --git a/src/Billing/Controllers/StripeController.cs b/src/Billing/Controllers/StripeController.cs index 445be1f52..c825d55bd 100644 --- a/src/Billing/Controllers/StripeController.cs +++ b/src/Billing/Controllers/StripeController.cs @@ -14,6 +14,7 @@ using Microsoft.AspNetCore.Mvc; using Microsoft.Data.SqlClient; using Microsoft.Extensions.Options; using Stripe; +using Event = Stripe.Event; using TaxRate = Bit.Core.Entities.TaxRate; namespace Bit.Billing.Controllers; @@ -41,6 +42,7 @@ public class StripeController : Controller private readonly ITaxRateRepository _taxRateRepository; private readonly IUserRepository _userRepository; private readonly ICurrentContext _currentContext; + private readonly GlobalSettings _globalSettings; public StripeController( GlobalSettings globalSettings, @@ -83,6 +85,7 @@ public class StripeController : Controller PrivateKey = globalSettings.Braintree.PrivateKey }; _currentContext = currentContext; + _globalSettings = globalSettings; } [HttpPost("webhook")] @@ -114,6 +117,12 @@ public class StripeController : Controller return new BadRequestResult(); } + // If the customer and server cloud regions don't match, early return 200 to avoid unnecessary errors + if (!await ValidateCloudRegionAsync(parsedEvent)) + { + return new OkResult(); + } + var subDeleted = parsedEvent.Type.Equals(HandledStripeWebhook.SubscriptionDeleted); var subUpdated = parsedEvent.Type.Equals(HandledStripeWebhook.SubscriptionUpdated); @@ -471,6 +480,68 @@ public class StripeController : Controller return new OkResult(); } + /// + /// Ensures that the customer associated with the parsed event's data is in the correct region for this server. + /// We use the customer instead of the subscription given that all subscriptions have customers, but not all + /// customers have subscriptions + /// + /// + /// true if the customer's region and the server's region match, otherwise false + /// + private async Task ValidateCloudRegionAsync(Event parsedEvent) + { + string customerRegion; + + var serverRegion = _globalSettings.BaseServiceUri.CloudRegion; + var eventType = parsedEvent.Type; + + switch (eventType) + { + case HandledStripeWebhook.SubscriptionDeleted: + case HandledStripeWebhook.SubscriptionUpdated: + { + var subscription = await GetSubscriptionAsync(parsedEvent, true, new List { "customer" }); + customerRegion = GetCustomerRegionFromMetadata(subscription.Customer.Metadata); + break; + } + case HandledStripeWebhook.ChargeSucceeded: + case HandledStripeWebhook.ChargeRefunded: + { + var charge = await GetChargeAsync(parsedEvent, true, new List { "customer" }); + customerRegion = GetCustomerRegionFromMetadata(charge.Customer.Metadata); + break; + } + case HandledStripeWebhook.UpcomingInvoice: + case HandledStripeWebhook.PaymentSucceeded: + case HandledStripeWebhook.PaymentFailed: + case HandledStripeWebhook.InvoiceCreated: + { + var invoice = await GetInvoiceAsync(parsedEvent, true, new List { "customer" }); + customerRegion = GetCustomerRegionFromMetadata(invoice.Customer.Metadata); + break; + } + default: + { + // For all Stripe events that we're not listening to, just return 200 + return false; + } + } + + return customerRegion == serverRegion; + } + + /// + /// Gets the region from the customer metadata. If no region is present, defaults to "US" + /// + /// + /// + private static string GetCustomerRegionFromMetadata(Dictionary customerMetadata) + { + return customerMetadata.TryGetValue("region", out var value) + ? value + : "US"; + } + private Tuple GetIdsFromMetaData(IDictionary metaData) { if (metaData == null || !metaData.Any()) @@ -732,7 +803,7 @@ public class StripeController : Controller invoice.BillingReason == "subscription_cycle" && invoice.SubscriptionId != null; } - private async Task GetChargeAsync(Stripe.Event parsedEvent, bool fresh = false) + private async Task GetChargeAsync(Event parsedEvent, bool fresh = false, List expandOptions = null) { if (!(parsedEvent.Data.Object is Charge eventCharge)) { @@ -743,7 +814,8 @@ public class StripeController : Controller return eventCharge; } var chargeService = new ChargeService(); - var charge = await chargeService.GetAsync(eventCharge.Id); + var chargeGetOptions = new ChargeGetOptions { Expand = expandOptions }; + var charge = await chargeService.GetAsync(eventCharge.Id, chargeGetOptions); if (charge == null) { throw new Exception("Charge is null. " + eventCharge.Id); @@ -751,7 +823,7 @@ public class StripeController : Controller return charge; } - private async Task GetInvoiceAsync(Stripe.Event parsedEvent, bool fresh = false) + private async Task GetInvoiceAsync(Stripe.Event parsedEvent, bool fresh = false, List expandOptions = null) { if (!(parsedEvent.Data.Object is Invoice eventInvoice)) { @@ -762,7 +834,8 @@ public class StripeController : Controller return eventInvoice; } var invoiceService = new InvoiceService(); - var invoice = await invoiceService.GetAsync(eventInvoice.Id); + var invoiceGetOptions = new InvoiceGetOptions { Expand = expandOptions }; + var invoice = await invoiceService.GetAsync(eventInvoice.Id, invoiceGetOptions); if (invoice == null) { throw new Exception("Invoice is null. " + eventInvoice.Id); @@ -770,9 +843,10 @@ public class StripeController : Controller return invoice; } - private async Task GetSubscriptionAsync(Stripe.Event parsedEvent, bool fresh = false) + private async Task GetSubscriptionAsync(Stripe.Event parsedEvent, bool fresh = false, + List expandOptions = null) { - if (!(parsedEvent.Data.Object is Subscription eventSubscription)) + if (parsedEvent.Data.Object is not Subscription eventSubscription) { throw new Exception("Subscription is null (from parsed event). " + parsedEvent.Id); } @@ -781,7 +855,8 @@ public class StripeController : Controller return eventSubscription; } var subscriptionService = new SubscriptionService(); - var subscription = await subscriptionService.GetAsync(eventSubscription.Id); + var subscriptionGetOptions = new SubscriptionGetOptions { Expand = expandOptions }; + var subscription = await subscriptionService.GetAsync(eventSubscription.Id, subscriptionGetOptions); if (subscription == null) { throw new Exception("Subscription is null. " + eventSubscription.Id); diff --git a/src/Core/Services/Implementations/StripePaymentService.cs b/src/Core/Services/Implementations/StripePaymentService.cs index ba9956153..1776f1993 100644 --- a/src/Core/Services/Implementations/StripePaymentService.cs +++ b/src/Core/Services/Implementations/StripePaymentService.cs @@ -4,6 +4,7 @@ using Bit.Core.Enums; using Bit.Core.Exceptions; using Bit.Core.Models.Business; using Bit.Core.Repositories; +using Bit.Core.Settings; using Microsoft.Extensions.Logging; using StaticStore = Bit.Core.Models.StaticStore; using TaxRate = Bit.Core.Entities.TaxRate; @@ -25,6 +26,7 @@ public class StripePaymentService : IPaymentService private readonly Braintree.IBraintreeGateway _btGateway; private readonly ITaxRateRepository _taxRateRepository; private readonly IStripeAdapter _stripeAdapter; + private readonly IGlobalSettings _globalSettings; public StripePaymentService( ITransactionRepository transactionRepository, @@ -33,7 +35,8 @@ public class StripePaymentService : IPaymentService ILogger logger, ITaxRateRepository taxRateRepository, IStripeAdapter stripeAdapter, - Braintree.IBraintreeGateway braintreeGateway) + Braintree.IBraintreeGateway braintreeGateway, + IGlobalSettings globalSettings) { _transactionRepository = transactionRepository; _userRepository = userRepository; @@ -42,6 +45,7 @@ public class StripePaymentService : IPaymentService _taxRateRepository = taxRateRepository; _stripeAdapter = stripeAdapter; _btGateway = braintreeGateway; + _globalSettings = globalSettings; } public async Task PurchaseOrganizationAsync(Organization org, PaymentMethodType paymentMethodType, @@ -51,9 +55,12 @@ public class StripePaymentService : IPaymentService Braintree.Customer braintreeCustomer = null; string stipeCustomerSourceToken = null; string stipeCustomerPaymentMethodId = null; - var stripeCustomerMetadata = new Dictionary(); + var stripeCustomerMetadata = new Dictionary + { + { "region", _globalSettings.BaseServiceUri.CloudRegion } + }; var stripePaymentMethod = paymentMethodType == PaymentMethodType.Card || - paymentMethodType == PaymentMethodType.BankAccount; + paymentMethodType == PaymentMethodType.BankAccount; if (stripePaymentMethod && !string.IsNullOrWhiteSpace(paymentToken)) { @@ -388,7 +395,7 @@ public class StripePaymentService : IPaymentService if (customer == null && !string.IsNullOrWhiteSpace(paymentToken)) { - var stripeCustomerMetadata = new Dictionary(); + var stripeCustomerMetadata = new Dictionary { { "region", _globalSettings.BaseServiceUri.CloudRegion } }; if (paymentMethodType == PaymentMethodType.PayPal) { var randomSuffix = Utilities.CoreHelpers.RandomString(3, upper: false, numeric: false); @@ -1185,9 +1192,12 @@ public class StripePaymentService : IPaymentService Braintree.Customer braintreeCustomer = null; string stipeCustomerSourceToken = null; string stipeCustomerPaymentMethodId = null; - var stripeCustomerMetadata = new Dictionary(); + var stripeCustomerMetadata = new Dictionary + { + { "region", _globalSettings.BaseServiceUri.CloudRegion } + }; var stripePaymentMethod = paymentMethodType == PaymentMethodType.Card || - paymentMethodType == PaymentMethodType.BankAccount; + paymentMethodType == PaymentMethodType.BankAccount; var inAppPurchase = paymentMethodType == PaymentMethodType.AppleInApp || paymentMethodType == PaymentMethodType.GoogleInApp; diff --git a/test/Core.Test/Services/StripePaymentServiceTests.cs b/test/Core.Test/Services/StripePaymentServiceTests.cs index c43e99fb9..d8a67e815 100644 --- a/test/Core.Test/Services/StripePaymentServiceTests.cs +++ b/test/Core.Test/Services/StripePaymentServiceTests.cs @@ -4,6 +4,7 @@ using Bit.Core.Exceptions; using Bit.Core.Models.Business; using Bit.Core.Repositories; using Bit.Core.Services; +using Bit.Core.Settings; using Bit.Core.Utilities; using Bit.Test.Common.AutoFixture; using Bit.Test.Common.AutoFixture.Attributes; @@ -51,6 +52,9 @@ public class StripePaymentServiceTests Id = "S-1", CurrentPeriodEnd = DateTime.Today.AddDays(10), }); + sutProvider.GetDependency() + .BaseServiceUri.CloudRegion + .Returns("US"); var result = await sutProvider.Sut.PurchaseOrganizationAsync(organization, PaymentMethodType.Card, paymentToken, plan, 0, 0, false, taxInfo, provider); @@ -67,7 +71,8 @@ public class StripePaymentServiceTests c.Source == paymentToken && c.PaymentMethod == null && c.Coupon == "msp-discount-35" && - !c.Metadata.Any() && + c.Metadata.Count == 1 && + c.Metadata["region"] == "US" && c.InvoiceSettings.DefaultPaymentMethod == null && c.Address.Country == taxInfo.BillingAddressCountry && c.Address.PostalCode == taxInfo.BillingAddressPostalCode && @@ -101,6 +106,9 @@ public class StripePaymentServiceTests Id = "S-1", CurrentPeriodEnd = DateTime.Today.AddDays(10), }); + sutProvider.GetDependency() + .BaseServiceUri.CloudRegion + .Returns("US"); var result = await sutProvider.Sut.PurchaseOrganizationAsync(organization, PaymentMethodType.Card, paymentToken, plan, 0, 0, false, taxInfo); @@ -116,7 +124,8 @@ public class StripePaymentServiceTests c.Email == organization.BillingEmail && c.Source == paymentToken && c.PaymentMethod == null && - !c.Metadata.Any() && + c.Metadata.Count == 1 && + c.Metadata["region"] == "US" && c.InvoiceSettings.DefaultPaymentMethod == null && c.InvoiceSettings.CustomFields != null && c.InvoiceSettings.CustomFields[0].Name == "Organization" && @@ -154,6 +163,9 @@ public class StripePaymentServiceTests Id = "S-1", CurrentPeriodEnd = DateTime.Today.AddDays(10), }); + sutProvider.GetDependency() + .BaseServiceUri.CloudRegion + .Returns("US"); var result = await sutProvider.Sut.PurchaseOrganizationAsync(organization, PaymentMethodType.Card, paymentToken, plan, 0, 0, false, taxInfo); @@ -169,7 +181,8 @@ public class StripePaymentServiceTests c.Email == organization.BillingEmail && c.Source == null && c.PaymentMethod == paymentToken && - !c.Metadata.Any() && + c.Metadata.Count == 1 && + c.Metadata["region"] == "US" && c.InvoiceSettings.DefaultPaymentMethod == paymentToken && c.InvoiceSettings.CustomFields != null && c.InvoiceSettings.CustomFields[0].Name == "Organization" && @@ -300,6 +313,10 @@ public class StripePaymentServiceTests CurrentPeriodEnd = DateTime.Today.AddDays(10), }); + sutProvider.GetDependency() + .BaseServiceUri.CloudRegion + .Returns("US"); + var customer = Substitute.For(); customer.Id.ReturnsForAnyArgs("Braintree-Id"); customer.PaymentMethods.ReturnsForAnyArgs(new[] { Substitute.For() }); @@ -323,8 +340,9 @@ public class StripePaymentServiceTests c.Description == organization.BusinessName && c.Email == organization.BillingEmail && c.PaymentMethod == null && - c.Metadata.Count == 1 && + c.Metadata.Count == 2 && c.Metadata["btCustomerId"] == "Braintree-Id" && + c.Metadata["region"] == "US" && c.InvoiceSettings.DefaultPaymentMethod == null && c.Address.Country == taxInfo.BillingAddressCountry && c.Address.PostalCode == taxInfo.BillingAddressPostalCode &&