1
0
mirror of https://github.com/bitwarden/server.git synced 2024-11-21 12:05:42 +01:00

[AC-2568] Added invoices and transaction history endpoints. Added cursor paging for each (#4692)

* Added invoices and transaction history endpoints. Added cursor paging for each

* Removed try/catch since it's handled by middleware. Updated condition to use pattern matching

* Added unit tests for PaymentHistoryService

* Removed organizationId from account billing controller endpoints
This commit is contained in:
Conner Turnbull 2024-09-09 09:38:58 -04:00 committed by GitHub
parent ebf8bc0b85
commit 46ac2a9b3b
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
16 changed files with 385 additions and 34 deletions

View File

@ -1,4 +1,5 @@
using Bit.Api.Billing.Models.Responses;
using Bit.Core.Billing.Services;
using Bit.Core.Services;
using Bit.Core.Utilities;
using Microsoft.AspNetCore.Authorization;
@ -10,7 +11,8 @@ namespace Bit.Api.Billing.Controllers;
[Authorize("Application")]
public class AccountsBillingController(
IPaymentService paymentService,
IUserService userService) : Controller
IUserService userService,
IPaymentHistoryService paymentHistoryService) : Controller
{
[HttpGet("history")]
[SelfHosted(NotSelfHostedOnly = true)]
@ -39,4 +41,38 @@ public class AccountsBillingController(
var billingInfo = await paymentService.GetBillingAsync(user);
return new BillingPaymentResponseModel(billingInfo);
}
[HttpGet("invoices")]
public async Task<IResult> GetInvoicesAsync([FromQuery] string startAfter = null)
{
var user = await userService.GetUserByPrincipalAsync(User);
if (user == null)
{
throw new UnauthorizedAccessException();
}
var invoices = await paymentHistoryService.GetInvoiceHistoryAsync(
user,
5,
startAfter);
return TypedResults.Ok(invoices);
}
[HttpGet("transactions")]
public async Task<IResult> GetTransactionsAsync([FromQuery] DateTime? startAfter = null)
{
var user = await userService.GetUserByPrincipalAsync(User);
if (user == null)
{
throw new UnauthorizedAccessException();
}
var transactions = await paymentHistoryService.GetTransactionHistoryAsync(
user,
5,
startAfter);
return TypedResults.Ok(transactions);
}
}

View File

@ -19,7 +19,8 @@ public class OrganizationBillingController(
IOrganizationBillingService organizationBillingService,
IOrganizationRepository organizationRepository,
IPaymentService paymentService,
ISubscriberService subscriberService) : BaseBillingController
ISubscriberService subscriberService,
IPaymentHistoryService paymentHistoryService) : BaseBillingController
{
[HttpGet("metadata")]
public async Task<IResult> GetMetadataAsync([FromRoute] Guid organizationId)
@ -61,6 +62,52 @@ public class OrganizationBillingController(
return TypedResults.Ok(billingInfo);
}
[HttpGet("invoices")]
public async Task<IResult> GetInvoicesAsync([FromRoute] Guid organizationId, [FromQuery] string startAfter = null)
{
if (!await currentContext.ViewBillingHistory(organizationId))
{
return TypedResults.Unauthorized();
}
var organization = await organizationRepository.GetByIdAsync(organizationId);
if (organization == null)
{
return TypedResults.NotFound();
}
var invoices = await paymentHistoryService.GetInvoiceHistoryAsync(
organization,
5,
startAfter);
return TypedResults.Ok(invoices);
}
[HttpGet("transactions")]
public async Task<IResult> GetTransactionsAsync([FromRoute] Guid organizationId, [FromQuery] DateTime? startAfter = null)
{
if (!await currentContext.ViewBillingHistory(organizationId))
{
return TypedResults.Unauthorized();
}
var organization = await organizationRepository.GetByIdAsync(organizationId);
if (organization == null)
{
return TypedResults.NotFound();
}
var transactions = await paymentHistoryService.GetTransactionHistoryAsync(
organization,
5,
startAfter);
return TypedResults.Ok(transactions);
}
[HttpGet]
[SelfHosted(NotSelfHostedOnly = true)]
public async Task<IResult> GetBillingAsync(Guid organizationId)

View File

@ -38,6 +38,7 @@ public class BillingHistoryInfo
{
public BillingInvoice(Invoice inv)
{
Id = inv.Id;
Date = inv.Created;
Url = inv.HostedInvoiceUrl;
PdfUrl = inv.InvoicePdf;
@ -46,6 +47,7 @@ public class BillingHistoryInfo
Amount = inv.Total / 100M;
}
public string Id { get; set; }
public decimal Amount { get; set; }
public DateTime? Date { get; set; }
public string Url { get; set; }

View File

@ -0,0 +1,17 @@
using Bit.Core.Billing.Models;
using Bit.Core.Entities;
namespace Bit.Core.Billing.Services;
public interface IPaymentHistoryService
{
Task<IEnumerable<BillingHistoryInfo.BillingInvoice>> GetInvoiceHistoryAsync(
ISubscriber subscriber,
int pageSize = 5,
string startAfter = null);
Task<IEnumerable<BillingHistoryInfo.BillingTransaction>> GetTransactionHistoryAsync(
ISubscriber subscriber,
int pageSize = 5,
DateTime? startAfter = null);
}

View File

@ -0,0 +1,53 @@
using Bit.Core.AdminConsole.Entities;
using Bit.Core.Billing.Models;
using Bit.Core.Entities;
using Bit.Core.Models.BitStripe;
using Bit.Core.Repositories;
using Bit.Core.Services;
using Microsoft.Extensions.Logging;
namespace Bit.Core.Billing.Services.Implementations;
public class PaymentHistoryService(
IStripeAdapter stripeAdapter,
ITransactionRepository transactionRepository,
ILogger<PaymentHistoryService> logger) : IPaymentHistoryService
{
public async Task<IEnumerable<BillingHistoryInfo.BillingInvoice>> GetInvoiceHistoryAsync(
ISubscriber subscriber,
int pageSize = 5,
string startAfter = null)
{
if (subscriber is not { GatewayCustomerId: not null, GatewaySubscriptionId: not null })
{
return null;
}
var invoices = await stripeAdapter.InvoiceListAsync(new StripeInvoiceListOptions
{
Customer = subscriber.GatewayCustomerId,
Subscription = subscriber.GatewaySubscriptionId,
Limit = pageSize,
StartingAfter = startAfter
});
return invoices.Select(invoice => new BillingHistoryInfo.BillingInvoice(invoice));
}
public async Task<IEnumerable<BillingHistoryInfo.BillingTransaction>> GetTransactionHistoryAsync(
ISubscriber subscriber,
int pageSize = 5,
DateTime? startAfter = null)
{
var transactions = subscriber switch
{
User => await transactionRepository.GetManyByUserIdAsync(subscriber.Id, pageSize, startAfter),
Organization => await transactionRepository.GetManyByOrganizationIdAsync(subscriber.Id, pageSize, startAfter),
_ => null
};
return transactions?.OrderByDescending(i => i.CreationDate)
.Select(t => new BillingHistoryInfo.BillingTransaction(t));
}
}

View File

@ -7,8 +7,8 @@ namespace Bit.Core.Repositories;
public interface ITransactionRepository : IRepository<Transaction, Guid>
{
Task<ICollection<Transaction>> GetManyByUserIdAsync(Guid userId, int? limit = null);
Task<ICollection<Transaction>> GetManyByOrganizationIdAsync(Guid organizationId, int? limit = null);
Task<ICollection<Transaction>> GetManyByProviderIdAsync(Guid providerId, int? limit = null);
Task<ICollection<Transaction>> GetManyByUserIdAsync(Guid userId, int? limit = null, DateTime? startAfter = null);
Task<ICollection<Transaction>> GetManyByOrganizationIdAsync(Guid organizationId, int? limit = null, DateTime? startAfter = null);
Task<ICollection<Transaction>> GetManyByProviderIdAsync(Guid providerId, int? limit = null, DateTime? startAfter = null);
Task<Transaction?> GetByGatewayIdAsync(GatewayType gatewayType, string gatewayId);
}

View File

@ -20,38 +20,60 @@ public class TransactionRepository : Repository<Transaction, Guid>, ITransaction
: base(connectionString, readOnlyConnectionString)
{ }
public async Task<ICollection<Transaction>> GetManyByUserIdAsync(Guid userId, int? limit = null)
{
using (var connection = new SqlConnection(ConnectionString))
{
var results = await connection.QueryAsync<Transaction>(
$"[{Schema}].[Transaction_ReadByUserId]",
new { UserId = userId, Limit = limit ?? int.MaxValue },
commandType: CommandType.StoredProcedure);
return results.ToList();
}
}
public async Task<ICollection<Transaction>> GetManyByOrganizationIdAsync(Guid organizationId, int? limit = null)
public async Task<ICollection<Transaction>> GetManyByUserIdAsync(
Guid userId,
int? limit = null,
DateTime? startAfter = null)
{
await using var connection = new SqlConnection(ConnectionString);
var results = await connection.QueryAsync<Transaction>(
$"[{Schema}].[Transaction_ReadByOrganizationId]",
new { OrganizationId = organizationId, Limit = limit ?? int.MaxValue },
$"[{Schema}].[Transaction_ReadByUserId]",
new
{
UserId = userId,
Limit = limit ?? int.MaxValue,
StartAfter = startAfter
},
commandType: CommandType.StoredProcedure);
return results.ToList();
}
public async Task<ICollection<Transaction>> GetManyByProviderIdAsync(Guid providerId, int? limit = null)
public async Task<ICollection<Transaction>> GetManyByOrganizationIdAsync(
Guid organizationId,
int? limit = null,
DateTime? startAfter = null)
{
await using var connection = new SqlConnection(ConnectionString);
var results = await connection.QueryAsync<Transaction>(
$"[{Schema}].[Transaction_ReadByOrganizationId]",
new
{
OrganizationId = organizationId,
Limit = limit ?? int.MaxValue,
StartAfter = startAfter
},
commandType: CommandType.StoredProcedure);
return results.ToList();
}
public async Task<ICollection<Transaction>> GetManyByProviderIdAsync(
Guid providerId,
int? limit = null,
DateTime? startAfter = null)
{
await using var sqlConnection = new SqlConnection(ConnectionString);
var results = await sqlConnection.QueryAsync<Transaction>(
$"[{Schema}].[Transaction_ReadByProviderId]",
new { ProviderId = providerId, Limit = limit ?? int.MaxValue },
new
{
ProviderId = providerId,
Limit = limit ?? int.MaxValue,
StartAfter = startAfter
},
commandType: CommandType.StoredProcedure);
return results.ToList();

View File

@ -24,7 +24,10 @@ public class TransactionRepository : Repository<Core.Entities.Transaction, Trans
return Mapper.Map<Core.Entities.Transaction>(results);
}
public async Task<ICollection<Core.Entities.Transaction>> GetManyByOrganizationIdAsync(Guid organizationId, int? limit = null)
public async Task<ICollection<Core.Entities.Transaction>> GetManyByOrganizationIdAsync(
Guid organizationId,
int? limit = null,
DateTime? startAfter = null)
{
using var scope = ServiceScopeFactory.CreateScope();
@ -32,6 +35,11 @@ public class TransactionRepository : Repository<Core.Entities.Transaction, Trans
var query = dbContext.Transactions
.Where(t => t.OrganizationId == organizationId && !t.UserId.HasValue);
if (startAfter.HasValue)
{
query = query.Where(t => t.CreationDate < startAfter.Value);
}
if (limit.HasValue)
{
query = query.OrderByDescending(o => o.CreationDate).Take(limit.Value);
@ -41,7 +49,10 @@ public class TransactionRepository : Repository<Core.Entities.Transaction, Trans
return Mapper.Map<List<Core.Entities.Transaction>>(results);
}
public async Task<ICollection<Core.Entities.Transaction>> GetManyByUserIdAsync(Guid userId, int? limit = null)
public async Task<ICollection<Core.Entities.Transaction>> GetManyByUserIdAsync(
Guid userId,
int? limit = null,
DateTime? startAfter = null)
{
using var scope = ServiceScopeFactory.CreateScope();
@ -49,6 +60,11 @@ public class TransactionRepository : Repository<Core.Entities.Transaction, Trans
var query = dbContext.Transactions
.Where(t => t.UserId == userId);
if (startAfter.HasValue)
{
query = query.Where(t => t.CreationDate < startAfter.Value);
}
if (limit.HasValue)
{
query = query.OrderByDescending(o => o.CreationDate).Take(limit.Value);
@ -59,13 +75,21 @@ public class TransactionRepository : Repository<Core.Entities.Transaction, Trans
return Mapper.Map<List<Core.Entities.Transaction>>(results);
}
public async Task<ICollection<Core.Entities.Transaction>> GetManyByProviderIdAsync(Guid providerId, int? limit = null)
public async Task<ICollection<Core.Entities.Transaction>> GetManyByProviderIdAsync(
Guid providerId,
int? limit = null,
DateTime? startAfter = null)
{
using var serviceScope = ServiceScopeFactory.CreateScope();
var databaseContext = GetDatabaseContext(serviceScope);
var query = databaseContext.Transactions
.Where(transaction => transaction.ProviderId == providerId);
if (startAfter.HasValue)
{
query = query.Where(transaction => transaction.CreationDate < startAfter.Value);
}
if (limit.HasValue)
{
query = query.Take(limit.Value);

View File

@ -16,6 +16,8 @@ using Bit.Core.Auth.Repositories;
using Bit.Core.Auth.Services;
using Bit.Core.Auth.Services.Implementations;
using Bit.Core.Auth.UserFeatures;
using Bit.Core.Billing.Services;
using Bit.Core.Billing.Services.Implementations;
using Bit.Core.Billing.TrialInitiation;
using Bit.Core.Entities;
using Bit.Core.Enums;
@ -221,6 +223,7 @@ public static class ServiceCollectionExtensions
};
});
services.AddScoped<IPaymentService, StripePaymentService>();
services.AddScoped<IPaymentHistoryService, PaymentHistoryService>();
services.AddSingleton<IStripeSyncService, StripeSyncService>();
services.AddSingleton<IMailService, HandlebarsMailService>();
services.AddSingleton<ILicensingService, LicensingService>();

View File

@ -1,16 +1,16 @@
CREATE PROCEDURE [dbo].[Transaction_ReadByOrganizationId]
@OrganizationId UNIQUEIDENTIFIER,
@Limit INT
@Limit INT,
@StartAfter DATETIME2 = NULL
AS
BEGIN
SET NOCOUNT ON
SELECT
TOP (@Limit) *
FROM
[dbo].[TransactionView]
SELECT TOP (@Limit) *
FROM [dbo].[TransactionView]
WHERE
[OrganizationId] = @OrganizationId
AND (@StartAfter IS NULL OR [CreationDate] < @StartAfter)
ORDER BY
[CreationDate] DESC
END

View File

@ -1,6 +1,7 @@
CREATE PROCEDURE [dbo].[Transaction_ReadByProviderId]
@ProviderId UNIQUEIDENTIFIER,
@Limit INT
@Limit INT,
@StartAfter DATETIME2 = NULL
AS
BEGIN
SET NOCOUNT ON
@ -11,6 +12,7 @@ BEGIN
[dbo].[TransactionView]
WHERE
[ProviderId] = @ProviderId
AND (@StartAfter IS NULL OR [CreationDate] < @StartAfter)
ORDER BY
[CreationDate] DESC
END

View File

@ -1,6 +1,7 @@
CREATE PROCEDURE [dbo].[Transaction_ReadByUserId]
@UserId UNIQUEIDENTIFIER,
@Limit INT
@Limit INT,
@StartAfter DATETIME2 = NULL
AS
BEGIN
SET NOCOUNT ON
@ -11,6 +12,7 @@ BEGIN
[dbo].[TransactionView]
WHERE
[UserId] = @UserId
AND (@StartAfter IS NULL OR [CreationDate] < @StartAfter)
ORDER BY
[CreationDate] DESC
END

View File

@ -0,0 +1,89 @@
using Bit.Core.AdminConsole.Entities;
using Bit.Core.Billing.Services.Implementations;
using Bit.Core.Entities;
using Bit.Core.Models.BitStripe;
using Bit.Core.Repositories;
using Bit.Core.Services;
using Microsoft.Extensions.Logging;
using NSubstitute;
using Stripe;
using Xunit;
namespace Bit.Core.Test.Billing.Services;
public class PaymentHistoryServiceTests
{
[Fact]
public async Task GetInvoiceHistoryAsync_Succeeds()
{
// Arrange
var subscriber = new Organization { GatewayCustomerId = "cus_id", GatewaySubscriptionId = "sub_id" };
var invoices = new List<Invoice> { new() { Id = "in_id" } };
var stripeAdapter = Substitute.For<IStripeAdapter>();
stripeAdapter.InvoiceListAsync(Arg.Any<StripeInvoiceListOptions>()).Returns(invoices);
var transactionRepository = Substitute.For<ITransactionRepository>();
var logger = Substitute.For<ILogger<PaymentHistoryService>>();
var paymentHistoryService = new PaymentHistoryService(stripeAdapter, transactionRepository, logger);
// Act
var result = await paymentHistoryService.GetInvoiceHistoryAsync(subscriber);
// Assert
Assert.NotNull(result);
Assert.Single(result);
await stripeAdapter.Received(1).InvoiceListAsync(Arg.Any<StripeInvoiceListOptions>());
}
[Fact]
public async Task GetInvoiceHistoryAsync_SubscriberNull_ReturnsNull()
{
// Arrange
var paymentHistoryService = new PaymentHistoryService(
Substitute.For<IStripeAdapter>(),
Substitute.For<ITransactionRepository>(),
Substitute.For<ILogger<PaymentHistoryService>>());
// Act
var result = await paymentHistoryService.GetInvoiceHistoryAsync(null);
// Assert
Assert.Null(result);
}
[Fact]
public async Task GetTransactionHistoryAsync_Succeeds()
{
// Arrange
var subscriber = new Organization { Id = Guid.NewGuid() };
var transactions = new List<Transaction> { new() { Id = Guid.NewGuid() } };
var transactionRepository = Substitute.For<ITransactionRepository>();
transactionRepository.GetManyByOrganizationIdAsync(subscriber.Id, Arg.Any<int>(), Arg.Any<DateTime?>()).Returns(transactions);
var stripeAdapter = Substitute.For<IStripeAdapter>();
var logger = Substitute.For<ILogger<PaymentHistoryService>>();
var paymentHistoryService = new PaymentHistoryService(stripeAdapter, transactionRepository, logger);
// Act
var result = await paymentHistoryService.GetTransactionHistoryAsync(subscriber);
// Assert
Assert.NotNull(result);
Assert.Single(result);
await transactionRepository.Received(1).GetManyByOrganizationIdAsync(subscriber.Id, Arg.Any<int>(), Arg.Any<DateTime?>());
}
[Fact]
public async Task GetTransactionHistoryAsync_SubscriberNull_ReturnsNull()
{
// Arrange
var paymentHistoryService = new PaymentHistoryService(
Substitute.For<IStripeAdapter>(),
Substitute.For<ITransactionRepository>(),
Substitute.For<ILogger<PaymentHistoryService>>());
// Act
var result = await paymentHistoryService.GetTransactionHistoryAsync(null);
// Assert
Assert.Null(result);
}
}

View File

@ -0,0 +1,18 @@
CREATE OR ALTER PROCEDURE [dbo].[Transaction_ReadByOrganizationId]
@OrganizationId UNIQUEIDENTIFIER,
@Limit INT,
@StartAfter DATETIME2 = NULL
AS
BEGIN
SET NOCOUNT ON
SELECT
TOP (@Limit) *
FROM
[dbo].[TransactionView]
WHERE
[OrganizationId] = @OrganizationId
AND (@StartAfter IS NULL OR [CreationDate] < @StartAfter)
ORDER BY
[CreationDate] DESC
END

View File

@ -0,0 +1,18 @@
CREATE OR ALTER PROCEDURE [dbo].[Transaction_ReadByProviderId]
@ProviderId UNIQUEIDENTIFIER,
@Limit INT,
@StartAfter DATETIME2 = NULL
AS
BEGIN
SET NOCOUNT ON
SELECT
TOP (@Limit) *
FROM
[dbo].[TransactionView]
WHERE
[ProviderId] = @ProviderId
AND (@StartAfter IS NULL OR [CreationDate] < @StartAfter)
ORDER BY
[CreationDate] DESC
END

View File

@ -0,0 +1,18 @@
CREATE OR ALTER PROCEDURE [dbo].[Transaction_ReadByUserId]
@UserId UNIQUEIDENTIFIER,
@Limit INT,
@StartAfter DATETIME2 = NULL
AS
BEGIN
SET NOCOUNT ON
SELECT
TOP (@Limit) *
FROM
[dbo].[TransactionView]
WHERE
[UserId] = @UserId
AND (@StartAfter IS NULL OR [CreationDate] < @StartAfter)
ORDER BY
[CreationDate] DESC
END