1
0
mirror of https://github.com/bitwarden/server.git synced 2024-11-21 12:05:42 +01:00

[AC-1693] Send InvoiceUpcoming Notification to Client Owners (#3319)

* Add Organization_ReadOwnerEmailAddresses SPROC

* Add IOrganizationRepository.GetOwnerEmailAddressesById

* Add SendInvoiceUpcoming overload for multiple emails

* Update InvoiceUpcoming handler to send multiple emails

* Cy's feedback

* Updates from testing

Hardened against missing entity IDs in Stripe events in the StripeEventService. Updated ValidateCloudRegion to not use a refresh/expansion for the customer because the invoice.upcoming event does not have an invoice.Id. Updated the StripeController's handling of invoice.upcoming to not use a refresh/expansion for the subscription because the invoice does not have an ID.

* Fix broken test
This commit is contained in:
Alex Morask 2023-10-23 13:46:29 -04:00 committed by GitHub
parent 18b43130e8
commit c442bae2bc
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
10 changed files with 190 additions and 39 deletions

View File

@ -52,6 +52,7 @@ public class StripeController : Controller
private readonly ICurrentContext _currentContext; private readonly ICurrentContext _currentContext;
private readonly GlobalSettings _globalSettings; private readonly GlobalSettings _globalSettings;
private readonly IStripeEventService _stripeEventService; private readonly IStripeEventService _stripeEventService;
private readonly IStripeFacade _stripeFacade;
public StripeController( public StripeController(
GlobalSettings globalSettings, GlobalSettings globalSettings,
@ -70,7 +71,8 @@ public class StripeController : Controller
ITaxRateRepository taxRateRepository, ITaxRateRepository taxRateRepository,
IUserRepository userRepository, IUserRepository userRepository,
ICurrentContext currentContext, ICurrentContext currentContext,
IStripeEventService stripeEventService) IStripeEventService stripeEventService,
IStripeFacade stripeFacade)
{ {
_billingSettings = billingSettings?.Value; _billingSettings = billingSettings?.Value;
_hostingEnvironment = hostingEnvironment; _hostingEnvironment = hostingEnvironment;
@ -97,6 +99,7 @@ public class StripeController : Controller
_currentContext = currentContext; _currentContext = currentContext;
_globalSettings = globalSettings; _globalSettings = globalSettings;
_stripeEventService = stripeEventService; _stripeEventService = stripeEventService;
_stripeFacade = stripeFacade;
} }
[HttpPost("webhook")] [HttpPost("webhook")]
@ -209,48 +212,71 @@ public class StripeController : Controller
else if (parsedEvent.Type.Equals(HandledStripeWebhook.UpcomingInvoice)) else if (parsedEvent.Type.Equals(HandledStripeWebhook.UpcomingInvoice))
{ {
var invoice = await _stripeEventService.GetInvoice(parsedEvent); var invoice = await _stripeEventService.GetInvoice(parsedEvent);
var subscriptionService = new SubscriptionService();
var subscription = await subscriptionService.GetAsync(invoice.SubscriptionId); if (string.IsNullOrEmpty(invoice.SubscriptionId))
{
_logger.LogWarning("Received 'invoice.upcoming' Event with ID '{eventId}' that did not include a Subscription ID", parsedEvent.Id);
return new OkResult();
}
var subscription = await _stripeFacade.GetSubscription(invoice.SubscriptionId);
if (subscription == null) if (subscription == null)
{ {
throw new Exception("Invoice subscription is null. " + invoice.Id); throw new Exception(
$"Received null Subscription from Stripe for ID '{invoice.SubscriptionId}' while processing Event with ID '{parsedEvent.Id}'");
} }
subscription = await VerifyCorrectTaxRateForCharge(invoice, subscription); var updatedSubscription = await VerifyCorrectTaxRateForCharge(invoice, subscription);
string email = null; var (organizationId, userId) = GetIdsFromMetaData(updatedSubscription.Metadata);
var ids = GetIdsFromMetaData(subscription.Metadata);
// org var invoiceLineItemDescriptions = invoice.Lines.Select(i => i.Description).ToList();
if (ids.Item1.HasValue)
async Task SendEmails(IEnumerable<string> emails)
{ {
// sponsored org var validEmails = emails.Where(e => !string.IsNullOrEmpty(e));
if (IsSponsoredSubscription(subscription))
{
await _validateSponsorshipCommand.ValidateSponsorshipAsync(ids.Item1.Value);
}
var org = await _organizationRepository.GetByIdAsync(ids.Item1.Value); if (invoice.NextPaymentAttempt.HasValue)
if (org != null && OrgPlanForInvoiceNotifications(org))
{ {
email = org.BillingEmail; await _mailService.SendInvoiceUpcoming(
validEmails,
invoice.AmountDue / 100M,
invoice.NextPaymentAttempt.Value,
invoiceLineItemDescriptions,
true);
} }
} }
// user
else if (ids.Item2.HasValue) if (organizationId.HasValue)
{ {
var user = await _userService.GetUserByIdAsync(ids.Item2.Value); if (IsSponsoredSubscription(updatedSubscription))
{
await _validateSponsorshipCommand.ValidateSponsorshipAsync(organizationId.Value);
}
var organization = await _organizationRepository.GetByIdAsync(organizationId.Value);
if (organization == null || !OrgPlanForInvoiceNotifications(organization))
{
return new OkResult();
}
await SendEmails(new List<string> { organization.BillingEmail });
var ownerEmails = await _organizationRepository.GetOwnerEmailAddressesById(organization.Id);
await SendEmails(ownerEmails);
}
else if (userId.HasValue)
{
var user = await _userService.GetUserByIdAsync(userId.Value);
if (user.Premium) if (user.Premium)
{ {
email = user.Email; await SendEmails(new List<string> { user.Email });
} }
} }
if (!string.IsNullOrWhiteSpace(email) && invoice.NextPaymentAttempt.HasValue)
{
var items = invoice.Lines.Select(i => i.Description).ToList();
await _mailService.SendInvoiceUpcomingAsync(email, invoice.AmountDue / 100M,
invoice.NextPaymentAttempt.Value, items, true);
}
} }
else if (parsedEvent.Type.Equals(HandledStripeWebhook.ChargeSucceeded)) else if (parsedEvent.Type.Equals(HandledStripeWebhook.ChargeSucceeded))
{ {

View File

@ -7,13 +7,16 @@ namespace Bit.Billing.Services.Implementations;
public class StripeEventService : IStripeEventService public class StripeEventService : IStripeEventService
{ {
private readonly GlobalSettings _globalSettings; private readonly GlobalSettings _globalSettings;
private readonly ILogger<StripeEventService> _logger;
private readonly IStripeFacade _stripeFacade; private readonly IStripeFacade _stripeFacade;
public StripeEventService( public StripeEventService(
GlobalSettings globalSettings, GlobalSettings globalSettings,
ILogger<StripeEventService> logger,
IStripeFacade stripeFacade) IStripeFacade stripeFacade)
{ {
_globalSettings = globalSettings; _globalSettings = globalSettings;
_logger = logger;
_stripeFacade = stripeFacade; _stripeFacade = stripeFacade;
} }
@ -26,6 +29,12 @@ public class StripeEventService : IStripeEventService
return eventCharge; return eventCharge;
} }
if (string.IsNullOrEmpty(eventCharge.Id))
{
_logger.LogWarning("Cannot retrieve up-to-date Charge for Event with ID '{eventId}' because no Charge ID was included in the Event.", stripeEvent.Id);
return eventCharge;
}
var charge = await _stripeFacade.GetCharge(eventCharge.Id, new ChargeGetOptions { Expand = expand }); var charge = await _stripeFacade.GetCharge(eventCharge.Id, new ChargeGetOptions { Expand = expand });
if (charge == null) if (charge == null)
@ -46,6 +55,12 @@ public class StripeEventService : IStripeEventService
return eventCustomer; return eventCustomer;
} }
if (string.IsNullOrEmpty(eventCustomer.Id))
{
_logger.LogWarning("Cannot retrieve up-to-date Customer for Event with ID '{eventId}' because no Customer ID was included in the Event.", stripeEvent.Id);
return eventCustomer;
}
var customer = await _stripeFacade.GetCustomer(eventCustomer.Id, new CustomerGetOptions { Expand = expand }); var customer = await _stripeFacade.GetCustomer(eventCustomer.Id, new CustomerGetOptions { Expand = expand });
if (customer == null) if (customer == null)
@ -66,6 +81,12 @@ public class StripeEventService : IStripeEventService
return eventInvoice; return eventInvoice;
} }
if (string.IsNullOrEmpty(eventInvoice.Id))
{
_logger.LogWarning("Cannot retrieve up-to-date Invoice for Event with ID '{eventId}' because no Invoice ID was included in the Event.", stripeEvent.Id);
return eventInvoice;
}
var invoice = await _stripeFacade.GetInvoice(eventInvoice.Id, new InvoiceGetOptions { Expand = expand }); var invoice = await _stripeFacade.GetInvoice(eventInvoice.Id, new InvoiceGetOptions { Expand = expand });
if (invoice == null) if (invoice == null)
@ -86,6 +107,12 @@ public class StripeEventService : IStripeEventService
return eventPaymentMethod; return eventPaymentMethod;
} }
if (string.IsNullOrEmpty(eventPaymentMethod.Id))
{
_logger.LogWarning("Cannot retrieve up-to-date Payment Method for Event with ID '{eventId}' because no Payment Method ID was included in the Event.", stripeEvent.Id);
return eventPaymentMethod;
}
var paymentMethod = await _stripeFacade.GetPaymentMethod(eventPaymentMethod.Id, new PaymentMethodGetOptions { Expand = expand }); var paymentMethod = await _stripeFacade.GetPaymentMethod(eventPaymentMethod.Id, new PaymentMethodGetOptions { Expand = expand });
if (paymentMethod == null) if (paymentMethod == null)
@ -106,6 +133,12 @@ public class StripeEventService : IStripeEventService
return eventSubscription; return eventSubscription;
} }
if (string.IsNullOrEmpty(eventSubscription.Id))
{
_logger.LogWarning("Cannot retrieve up-to-date Subscription for Event with ID '{eventId}' because no Subscription ID was included in the Event.", stripeEvent.Id);
return eventSubscription;
}
var subscription = await _stripeFacade.GetSubscription(eventSubscription.Id, new SubscriptionGetOptions { Expand = expand }); var subscription = await _stripeFacade.GetSubscription(eventSubscription.Id, new SubscriptionGetOptions { Expand = expand });
if (subscription == null) if (subscription == null)
@ -132,7 +165,7 @@ public class StripeEventService : IStripeEventService
(await GetCharge(stripeEvent, true, customerExpansion))?.Customer?.Metadata, (await GetCharge(stripeEvent, true, customerExpansion))?.Customer?.Metadata,
HandledStripeWebhook.UpcomingInvoice => HandledStripeWebhook.UpcomingInvoice =>
(await GetInvoice(stripeEvent, true, customerExpansion))?.Customer?.Metadata, await GetCustomerMetadataFromUpcomingInvoiceEvent(stripeEvent),
HandledStripeWebhook.PaymentSucceeded or HandledStripeWebhook.PaymentFailed or HandledStripeWebhook.InvoiceCreated => HandledStripeWebhook.PaymentSucceeded or HandledStripeWebhook.PaymentFailed or HandledStripeWebhook.InvoiceCreated =>
(await GetInvoice(stripeEvent, true, customerExpansion))?.Customer?.Metadata, (await GetInvoice(stripeEvent, true, customerExpansion))?.Customer?.Metadata,
@ -154,6 +187,20 @@ public class StripeEventService : IStripeEventService
var customerRegion = GetCustomerRegion(customerMetadata); var customerRegion = GetCustomerRegion(customerMetadata);
return customerRegion == serverRegion; return customerRegion == serverRegion;
/* In Stripe, when we receive an invoice.upcoming event, the event does not include an Invoice ID because
the invoice hasn't been created yet. Therefore, rather than retrieving the fresh Invoice with a 'customer'
expansion, we need to use the Customer ID on the event to retrieve the metadata. */
async Task<Dictionary<string, string>> GetCustomerMetadataFromUpcomingInvoiceEvent(Event localStripeEvent)
{
var invoice = await GetInvoice(localStripeEvent);
var customer = !string.IsNullOrEmpty(invoice.CustomerId)
? await _stripeFacade.GetCustomer(invoice.CustomerId)
: null;
return customer?.Metadata;
}
} }
private static T Extract<T>(Event stripeEvent) private static T Extract<T>(Event stripeEvent)

View File

@ -14,4 +14,5 @@ public interface IOrganizationRepository : IRepository<Organization, Guid>
Task<Organization> GetByLicenseKeyAsync(string licenseKey); Task<Organization> GetByLicenseKeyAsync(string licenseKey);
Task<SelfHostedOrganizationDetails> GetSelfHostedOrganizationDetailsById(Guid id); Task<SelfHostedOrganizationDetails> GetSelfHostedOrganizationDetailsById(Guid id);
Task<ICollection<Organization>> SearchUnassignedToProviderAsync(string name, string ownerEmail, int skip, int take); Task<ICollection<Organization>> SearchUnassignedToProviderAsync(string name, string ownerEmail, int skip, int take);
Task<IEnumerable<string>> GetOwnerEmailAddressesById(Guid organizationId);
} }

View File

@ -24,7 +24,17 @@ public interface IMailService
Task SendOrganizationConfirmedEmailAsync(string organizationName, string email); Task SendOrganizationConfirmedEmailAsync(string organizationName, string email);
Task SendOrganizationUserRemovedForPolicyTwoStepEmailAsync(string organizationName, string email); Task SendOrganizationUserRemovedForPolicyTwoStepEmailAsync(string organizationName, string email);
Task SendPasswordlessSignInAsync(string returnUrl, string token, string email); Task SendPasswordlessSignInAsync(string returnUrl, string token, string email);
Task SendInvoiceUpcomingAsync(string email, decimal amount, DateTime dueDate, List<string> items, Task SendInvoiceUpcoming(
string email,
decimal amount,
DateTime dueDate,
List<string> items,
bool mentionInvoices);
Task SendInvoiceUpcoming(
IEnumerable<string> email,
decimal amount,
DateTime dueDate,
List<string> items,
bool mentionInvoices); bool mentionInvoices);
Task SendPaymentFailedAsync(string email, decimal amount, bool mentionInvoices); Task SendPaymentFailedAsync(string email, decimal amount, bool mentionInvoices);
Task SendAddedCreditAsync(string email, decimal amount); Task SendAddedCreditAsync(string email, decimal amount);

View File

@ -285,10 +285,21 @@ public class HandlebarsMailService : IMailService
await _mailDeliveryService.SendEmailAsync(message); await _mailDeliveryService.SendEmailAsync(message);
} }
public async Task SendInvoiceUpcomingAsync(string email, decimal amount, DateTime dueDate, public async Task SendInvoiceUpcoming(
List<string> items, bool mentionInvoices) string email,
decimal amount,
DateTime dueDate,
List<string> items,
bool mentionInvoices) => await SendInvoiceUpcoming(new List<string> { email }, amount, dueDate, items, mentionInvoices);
public async Task SendInvoiceUpcoming(
IEnumerable<string> emails,
decimal amount,
DateTime dueDate,
List<string> items,
bool mentionInvoices)
{ {
var message = CreateDefaultMessage("Your Subscription Will Renew Soon", email); var message = CreateDefaultMessage("Your Subscription Will Renew Soon", emails);
var model = new InvoiceUpcomingViewModel var model = new InvoiceUpcomingViewModel
{ {
WebVaultUrl = _globalSettings.BaseServiceUri.VaultWithHash, WebVaultUrl = _globalSettings.BaseServiceUri.VaultWithHash,

View File

@ -88,11 +88,19 @@ public class NoopMailService : IMailService
return Task.FromResult(0); return Task.FromResult(0);
} }
public Task SendInvoiceUpcomingAsync(string email, decimal amount, DateTime dueDate, public Task SendInvoiceUpcoming(
List<string> items, bool mentionInvoices) string email,
{ decimal amount,
return Task.FromResult(0); DateTime dueDate,
} List<string> items,
bool mentionInvoices) => Task.FromResult(0);
public Task SendInvoiceUpcoming(
IEnumerable<string> emails,
decimal amount,
DateTime dueDate,
List<string> items,
bool mentionInvoices) => Task.FromResult(0);
public Task SendPaymentFailedAsync(string email, decimal amount, bool mentionInvoices) public Task SendPaymentFailedAsync(string email, decimal amount, bool mentionInvoices)
{ {

View File

@ -149,4 +149,14 @@ public class OrganizationRepository : Repository<Organization, Guid>, IOrganizat
return results.ToList(); return results.ToList();
} }
} }
public async Task<IEnumerable<string>> GetOwnerEmailAddressesById(Guid organizationId)
{
await using var connection = new SqlConnection(ConnectionString);
return await connection.QueryAsync<string>(
$"[{Schema}].[{Table}_ReadOwnerEmailAddressesById]",
new { OrganizationId = organizationId },
commandType: CommandType.StoredProcedure);
}
} }

View File

@ -224,4 +224,24 @@ public class OrganizationRepository : Repository<Core.Entities.Organization, Org
return selfHostedOrganization; return selfHostedOrganization;
} }
} }
public async Task<IEnumerable<string>> GetOwnerEmailAddressesById(Guid organizationId)
{
using var scope = ServiceScopeFactory.CreateScope();
var dbContext = GetDatabaseContext(scope);
var query =
from u in dbContext.Users
join ou in dbContext.OrganizationUsers on u.Id equals ou.UserId
where
ou.OrganizationId == organizationId &&
ou.Type == OrganizationUserType.Owner &&
ou.Status == OrganizationUserStatusType.Confirmed
group u by u.Email
into grouped
select grouped.Key;
return await query.ToListAsync();
}
} }

View File

@ -3,6 +3,7 @@ using Bit.Billing.Services.Implementations;
using Bit.Billing.Test.Utilities; using Bit.Billing.Test.Utilities;
using Bit.Core.Settings; using Bit.Core.Settings;
using FluentAssertions; using FluentAssertions;
using Microsoft.Extensions.Logging;
using NSubstitute; using NSubstitute;
using Stripe; using Stripe;
using Xunit; using Xunit;
@ -21,7 +22,7 @@ public class StripeEventServiceTests
globalSettings.BaseServiceUri = baseServiceUriSettings; globalSettings.BaseServiceUri = baseServiceUriSettings;
_stripeFacade = Substitute.For<IStripeFacade>(); _stripeFacade = Substitute.For<IStripeFacade>();
_stripeEventService = new StripeEventService(globalSettings, _stripeFacade); _stripeEventService = new StripeEventService(globalSettings, Substitute.For<ILogger<StripeEventService>>(), _stripeFacade);
} }
#region GetCharge #region GetCharge

View File

@ -0,0 +1,17 @@
CREATE OR ALTER PROCEDURE [dbo].[Organization_ReadOwnerEmailAddressesById]
@OrganizationId UNIQUEIDENTIFIER
AS
BEGIN
SET NOCOUNT ON
SELECT
[U].[Email]
FROM [User] AS [U]
INNER JOIN [OrganizationUser] AS [OU] ON [U].[Id] = [OU].[UserId]
WHERE
[OU].[OrganizationId] = @OrganizationId AND
[OU].[Type] = 0 AND -- Owner
[OU].[Status] = 2 -- Confirmed
GROUP BY [U].[Email]
END
GO