diff --git a/src/Api/Billing/Controllers/AccountsBillingController.cs b/src/Api/Billing/Controllers/AccountsBillingController.cs index 63a9bb44e..a72d79673 100644 --- a/src/Api/Billing/Controllers/AccountsBillingController.cs +++ b/src/Api/Billing/Controllers/AccountsBillingController.cs @@ -1,4 +1,5 @@ using Bit.Api.Billing.Models.Responses; +using Bit.Core.Billing.Services; using Bit.Core.Services; using Bit.Core.Utilities; using Microsoft.AspNetCore.Authorization; @@ -10,7 +11,8 @@ namespace Bit.Api.Billing.Controllers; [Authorize("Application")] public class AccountsBillingController( IPaymentService paymentService, - IUserService userService) : Controller + IUserService userService, + IPaymentHistoryService paymentHistoryService) : Controller { [HttpGet("history")] [SelfHosted(NotSelfHostedOnly = true)] @@ -39,4 +41,38 @@ public class AccountsBillingController( var billingInfo = await paymentService.GetBillingAsync(user); return new BillingPaymentResponseModel(billingInfo); } + + [HttpGet("invoices")] + public async Task GetInvoicesAsync([FromQuery] string startAfter = null) + { + var user = await userService.GetUserByPrincipalAsync(User); + if (user == null) + { + throw new UnauthorizedAccessException(); + } + + var invoices = await paymentHistoryService.GetInvoiceHistoryAsync( + user, + 5, + startAfter); + + return TypedResults.Ok(invoices); + } + + [HttpGet("transactions")] + public async Task GetTransactionsAsync([FromQuery] DateTime? startAfter = null) + { + var user = await userService.GetUserByPrincipalAsync(User); + if (user == null) + { + throw new UnauthorizedAccessException(); + } + + var transactions = await paymentHistoryService.GetTransactionHistoryAsync( + user, + 5, + startAfter); + + return TypedResults.Ok(transactions); + } } diff --git a/src/Api/Billing/Controllers/OrganizationBillingController.cs b/src/Api/Billing/Controllers/OrganizationBillingController.cs index 1a2406d3d..b1c7feb56 100644 --- a/src/Api/Billing/Controllers/OrganizationBillingController.cs +++ b/src/Api/Billing/Controllers/OrganizationBillingController.cs @@ -19,7 +19,8 @@ public class OrganizationBillingController( IOrganizationBillingService organizationBillingService, IOrganizationRepository organizationRepository, IPaymentService paymentService, - ISubscriberService subscriberService) : BaseBillingController + ISubscriberService subscriberService, + IPaymentHistoryService paymentHistoryService) : BaseBillingController { [HttpGet("metadata")] public async Task GetMetadataAsync([FromRoute] Guid organizationId) @@ -61,6 +62,52 @@ public class OrganizationBillingController( return TypedResults.Ok(billingInfo); } + [HttpGet("invoices")] + public async Task GetInvoicesAsync([FromRoute] Guid organizationId, [FromQuery] string startAfter = null) + { + if (!await currentContext.ViewBillingHistory(organizationId)) + { + return TypedResults.Unauthorized(); + } + + var organization = await organizationRepository.GetByIdAsync(organizationId); + + if (organization == null) + { + return TypedResults.NotFound(); + } + + var invoices = await paymentHistoryService.GetInvoiceHistoryAsync( + organization, + 5, + startAfter); + + return TypedResults.Ok(invoices); + } + + [HttpGet("transactions")] + public async Task GetTransactionsAsync([FromRoute] Guid organizationId, [FromQuery] DateTime? startAfter = null) + { + if (!await currentContext.ViewBillingHistory(organizationId)) + { + return TypedResults.Unauthorized(); + } + + var organization = await organizationRepository.GetByIdAsync(organizationId); + + if (organization == null) + { + return TypedResults.NotFound(); + } + + var transactions = await paymentHistoryService.GetTransactionHistoryAsync( + organization, + 5, + startAfter); + + return TypedResults.Ok(transactions); + } + [HttpGet] [SelfHosted(NotSelfHostedOnly = true)] public async Task GetBillingAsync(Guid organizationId) diff --git a/src/Core/Billing/Models/BillingHistoryInfo.cs b/src/Core/Billing/Models/BillingHistoryInfo.cs index 2a7f2b758..03017b9b4 100644 --- a/src/Core/Billing/Models/BillingHistoryInfo.cs +++ b/src/Core/Billing/Models/BillingHistoryInfo.cs @@ -38,6 +38,7 @@ public class BillingHistoryInfo { public BillingInvoice(Invoice inv) { + Id = inv.Id; Date = inv.Created; Url = inv.HostedInvoiceUrl; PdfUrl = inv.InvoicePdf; @@ -46,6 +47,7 @@ public class BillingHistoryInfo Amount = inv.Total / 100M; } + public string Id { get; set; } public decimal Amount { get; set; } public DateTime? Date { get; set; } public string Url { get; set; } diff --git a/src/Core/Billing/Services/IPaymentHistoryService.cs b/src/Core/Billing/Services/IPaymentHistoryService.cs new file mode 100644 index 000000000..e38659b94 --- /dev/null +++ b/src/Core/Billing/Services/IPaymentHistoryService.cs @@ -0,0 +1,17 @@ +using Bit.Core.Billing.Models; +using Bit.Core.Entities; + +namespace Bit.Core.Billing.Services; + +public interface IPaymentHistoryService +{ + Task> GetInvoiceHistoryAsync( + ISubscriber subscriber, + int pageSize = 5, + string startAfter = null); + + Task> GetTransactionHistoryAsync( + ISubscriber subscriber, + int pageSize = 5, + DateTime? startAfter = null); +} diff --git a/src/Core/Billing/Services/Implementations/PaymentHistoryService.cs b/src/Core/Billing/Services/Implementations/PaymentHistoryService.cs new file mode 100644 index 000000000..1e5e3ea0e --- /dev/null +++ b/src/Core/Billing/Services/Implementations/PaymentHistoryService.cs @@ -0,0 +1,53 @@ +using Bit.Core.AdminConsole.Entities; +using Bit.Core.Billing.Models; +using Bit.Core.Entities; +using Bit.Core.Models.BitStripe; +using Bit.Core.Repositories; +using Bit.Core.Services; +using Microsoft.Extensions.Logging; + +namespace Bit.Core.Billing.Services.Implementations; + +public class PaymentHistoryService( + IStripeAdapter stripeAdapter, + ITransactionRepository transactionRepository, + ILogger logger) : IPaymentHistoryService +{ + public async Task> GetInvoiceHistoryAsync( + ISubscriber subscriber, + int pageSize = 5, + string startAfter = null) + { + if (subscriber is not { GatewayCustomerId: not null, GatewaySubscriptionId: not null }) + { + return null; + } + + var invoices = await stripeAdapter.InvoiceListAsync(new StripeInvoiceListOptions + { + Customer = subscriber.GatewayCustomerId, + Subscription = subscriber.GatewaySubscriptionId, + Limit = pageSize, + StartingAfter = startAfter + }); + + return invoices.Select(invoice => new BillingHistoryInfo.BillingInvoice(invoice)); + + } + + public async Task> GetTransactionHistoryAsync( + ISubscriber subscriber, + int pageSize = 5, + DateTime? startAfter = null) + { + var transactions = subscriber switch + { + User => await transactionRepository.GetManyByUserIdAsync(subscriber.Id, pageSize, startAfter), + Organization => await transactionRepository.GetManyByOrganizationIdAsync(subscriber.Id, pageSize, startAfter), + _ => null + }; + + return transactions?.OrderByDescending(i => i.CreationDate) + .Select(t => new BillingHistoryInfo.BillingTransaction(t)); + } +} diff --git a/src/Core/Repositories/ITransactionRepository.cs b/src/Core/Repositories/ITransactionRepository.cs index 8491ef201..1039a0d8f 100644 --- a/src/Core/Repositories/ITransactionRepository.cs +++ b/src/Core/Repositories/ITransactionRepository.cs @@ -7,8 +7,8 @@ namespace Bit.Core.Repositories; public interface ITransactionRepository : IRepository { - Task> GetManyByUserIdAsync(Guid userId, int? limit = null); - Task> GetManyByOrganizationIdAsync(Guid organizationId, int? limit = null); - Task> GetManyByProviderIdAsync(Guid providerId, int? limit = null); + Task> GetManyByUserIdAsync(Guid userId, int? limit = null, DateTime? startAfter = null); + Task> GetManyByOrganizationIdAsync(Guid organizationId, int? limit = null, DateTime? startAfter = null); + Task> GetManyByProviderIdAsync(Guid providerId, int? limit = null, DateTime? startAfter = null); Task GetByGatewayIdAsync(GatewayType gatewayType, string gatewayId); } diff --git a/src/Infrastructure.Dapper/Repositories/TransactionRepository.cs b/src/Infrastructure.Dapper/Repositories/TransactionRepository.cs index 88f10368c..0dad07c2c 100644 --- a/src/Infrastructure.Dapper/Repositories/TransactionRepository.cs +++ b/src/Infrastructure.Dapper/Repositories/TransactionRepository.cs @@ -20,38 +20,60 @@ public class TransactionRepository : Repository, ITransaction : base(connectionString, readOnlyConnectionString) { } - public async Task> GetManyByUserIdAsync(Guid userId, int? limit = null) - { - using (var connection = new SqlConnection(ConnectionString)) - { - var results = await connection.QueryAsync( - $"[{Schema}].[Transaction_ReadByUserId]", - new { UserId = userId, Limit = limit ?? int.MaxValue }, - commandType: CommandType.StoredProcedure); - - return results.ToList(); - } - } - - public async Task> GetManyByOrganizationIdAsync(Guid organizationId, int? limit = null) + public async Task> GetManyByUserIdAsync( + Guid userId, + int? limit = null, + DateTime? startAfter = null) { await using var connection = new SqlConnection(ConnectionString); - var results = await connection.QueryAsync( - $"[{Schema}].[Transaction_ReadByOrganizationId]", - new { OrganizationId = organizationId, Limit = limit ?? int.MaxValue }, + $"[{Schema}].[Transaction_ReadByUserId]", + new + { + UserId = userId, + Limit = limit ?? int.MaxValue, + StartAfter = startAfter + }, commandType: CommandType.StoredProcedure); return results.ToList(); } - public async Task> GetManyByProviderIdAsync(Guid providerId, int? limit = null) + public async Task> GetManyByOrganizationIdAsync( + Guid organizationId, + int? limit = null, + DateTime? startAfter = null) + { + await using var connection = new SqlConnection(ConnectionString); + + var results = await connection.QueryAsync( + $"[{Schema}].[Transaction_ReadByOrganizationId]", + new + { + OrganizationId = organizationId, + Limit = limit ?? int.MaxValue, + StartAfter = startAfter + }, + commandType: CommandType.StoredProcedure); + + return results.ToList(); + } + + public async Task> GetManyByProviderIdAsync( + Guid providerId, + int? limit = null, + DateTime? startAfter = null) { await using var sqlConnection = new SqlConnection(ConnectionString); var results = await sqlConnection.QueryAsync( $"[{Schema}].[Transaction_ReadByProviderId]", - new { ProviderId = providerId, Limit = limit ?? int.MaxValue }, + new + { + ProviderId = providerId, + Limit = limit ?? int.MaxValue, + StartAfter = startAfter + }, commandType: CommandType.StoredProcedure); return results.ToList(); diff --git a/src/Infrastructure.EntityFramework/Repositories/TransactionRepository.cs b/src/Infrastructure.EntityFramework/Repositories/TransactionRepository.cs index 2150bb8fe..2aba3416d 100644 --- a/src/Infrastructure.EntityFramework/Repositories/TransactionRepository.cs +++ b/src/Infrastructure.EntityFramework/Repositories/TransactionRepository.cs @@ -24,7 +24,10 @@ public class TransactionRepository : Repository(results); } - public async Task> GetManyByOrganizationIdAsync(Guid organizationId, int? limit = null) + public async Task> GetManyByOrganizationIdAsync( + Guid organizationId, + int? limit = null, + DateTime? startAfter = null) { using var scope = ServiceScopeFactory.CreateScope(); @@ -32,6 +35,11 @@ public class TransactionRepository : Repository t.OrganizationId == organizationId && !t.UserId.HasValue); + if (startAfter.HasValue) + { + query = query.Where(t => t.CreationDate < startAfter.Value); + } + if (limit.HasValue) { query = query.OrderByDescending(o => o.CreationDate).Take(limit.Value); @@ -41,7 +49,10 @@ public class TransactionRepository : Repository>(results); } - public async Task> GetManyByUserIdAsync(Guid userId, int? limit = null) + public async Task> GetManyByUserIdAsync( + Guid userId, + int? limit = null, + DateTime? startAfter = null) { using var scope = ServiceScopeFactory.CreateScope(); @@ -49,6 +60,11 @@ public class TransactionRepository : Repository t.UserId == userId); + if (startAfter.HasValue) + { + query = query.Where(t => t.CreationDate < startAfter.Value); + } + if (limit.HasValue) { query = query.OrderByDescending(o => o.CreationDate).Take(limit.Value); @@ -59,13 +75,21 @@ public class TransactionRepository : Repository>(results); } - public async Task> GetManyByProviderIdAsync(Guid providerId, int? limit = null) + public async Task> GetManyByProviderIdAsync( + Guid providerId, + int? limit = null, + DateTime? startAfter = null) { using var serviceScope = ServiceScopeFactory.CreateScope(); var databaseContext = GetDatabaseContext(serviceScope); var query = databaseContext.Transactions .Where(transaction => transaction.ProviderId == providerId); + if (startAfter.HasValue) + { + query = query.Where(transaction => transaction.CreationDate < startAfter.Value); + } + if (limit.HasValue) { query = query.Take(limit.Value); diff --git a/src/SharedWeb/Utilities/ServiceCollectionExtensions.cs b/src/SharedWeb/Utilities/ServiceCollectionExtensions.cs index c2d670bd6..be451ea31 100644 --- a/src/SharedWeb/Utilities/ServiceCollectionExtensions.cs +++ b/src/SharedWeb/Utilities/ServiceCollectionExtensions.cs @@ -16,6 +16,8 @@ using Bit.Core.Auth.Repositories; using Bit.Core.Auth.Services; using Bit.Core.Auth.Services.Implementations; using Bit.Core.Auth.UserFeatures; +using Bit.Core.Billing.Services; +using Bit.Core.Billing.Services.Implementations; using Bit.Core.Billing.TrialInitiation; using Bit.Core.Entities; using Bit.Core.Enums; @@ -221,6 +223,7 @@ public static class ServiceCollectionExtensions }; }); services.AddScoped(); + services.AddScoped(); services.AddSingleton(); services.AddSingleton(); services.AddSingleton(); diff --git a/src/Sql/dbo/Stored Procedures/Transaction_ReadByOrganizationId.sql b/src/Sql/dbo/Stored Procedures/Transaction_ReadByOrganizationId.sql index e6f600c1f..f4e7dc030 100644 --- a/src/Sql/dbo/Stored Procedures/Transaction_ReadByOrganizationId.sql +++ b/src/Sql/dbo/Stored Procedures/Transaction_ReadByOrganizationId.sql @@ -1,16 +1,16 @@ CREATE PROCEDURE [dbo].[Transaction_ReadByOrganizationId] @OrganizationId UNIQUEIDENTIFIER, - @Limit INT + @Limit INT, + @StartAfter DATETIME2 = NULL AS BEGIN SET NOCOUNT ON - SELECT - TOP (@Limit) * - FROM - [dbo].[TransactionView] + SELECT TOP (@Limit) * + FROM [dbo].[TransactionView] WHERE [OrganizationId] = @OrganizationId + AND (@StartAfter IS NULL OR [CreationDate] < @StartAfter) ORDER BY [CreationDate] DESC END diff --git a/src/Sql/dbo/Stored Procedures/Transaction_ReadByProviderId.sql b/src/Sql/dbo/Stored Procedures/Transaction_ReadByProviderId.sql index 5b5ccd3d0..42bbedb13 100644 --- a/src/Sql/dbo/Stored Procedures/Transaction_ReadByProviderId.sql +++ b/src/Sql/dbo/Stored Procedures/Transaction_ReadByProviderId.sql @@ -1,6 +1,7 @@ CREATE PROCEDURE [dbo].[Transaction_ReadByProviderId] @ProviderId UNIQUEIDENTIFIER, - @Limit INT + @Limit INT, + @StartAfter DATETIME2 = NULL AS BEGIN SET NOCOUNT ON @@ -11,6 +12,7 @@ BEGIN [dbo].[TransactionView] WHERE [ProviderId] = @ProviderId + AND (@StartAfter IS NULL OR [CreationDate] < @StartAfter) ORDER BY [CreationDate] DESC END diff --git a/src/Sql/dbo/Stored Procedures/Transaction_ReadByUserId.sql b/src/Sql/dbo/Stored Procedures/Transaction_ReadByUserId.sql index 4d905d88c..18ba3fb0a 100644 --- a/src/Sql/dbo/Stored Procedures/Transaction_ReadByUserId.sql +++ b/src/Sql/dbo/Stored Procedures/Transaction_ReadByUserId.sql @@ -1,6 +1,7 @@ CREATE PROCEDURE [dbo].[Transaction_ReadByUserId] @UserId UNIQUEIDENTIFIER, - @Limit INT + @Limit INT, + @StartAfter DATETIME2 = NULL AS BEGIN SET NOCOUNT ON @@ -11,6 +12,7 @@ BEGIN [dbo].[TransactionView] WHERE [UserId] = @UserId + AND (@StartAfter IS NULL OR [CreationDate] < @StartAfter) ORDER BY [CreationDate] DESC END diff --git a/test/Core.Test/Billing/Services/PaymentHistoryServiceTests.cs b/test/Core.Test/Billing/Services/PaymentHistoryServiceTests.cs new file mode 100644 index 000000000..34bbb368b --- /dev/null +++ b/test/Core.Test/Billing/Services/PaymentHistoryServiceTests.cs @@ -0,0 +1,89 @@ +using Bit.Core.AdminConsole.Entities; +using Bit.Core.Billing.Services.Implementations; +using Bit.Core.Entities; +using Bit.Core.Models.BitStripe; +using Bit.Core.Repositories; +using Bit.Core.Services; +using Microsoft.Extensions.Logging; +using NSubstitute; +using Stripe; +using Xunit; + +namespace Bit.Core.Test.Billing.Services; + +public class PaymentHistoryServiceTests +{ + [Fact] + public async Task GetInvoiceHistoryAsync_Succeeds() + { + // Arrange + var subscriber = new Organization { GatewayCustomerId = "cus_id", GatewaySubscriptionId = "sub_id" }; + var invoices = new List { new() { Id = "in_id" } }; + var stripeAdapter = Substitute.For(); + stripeAdapter.InvoiceListAsync(Arg.Any()).Returns(invoices); + var transactionRepository = Substitute.For(); + var logger = Substitute.For>(); + var paymentHistoryService = new PaymentHistoryService(stripeAdapter, transactionRepository, logger); + + // Act + var result = await paymentHistoryService.GetInvoiceHistoryAsync(subscriber); + + // Assert + Assert.NotNull(result); + Assert.Single(result); + await stripeAdapter.Received(1).InvoiceListAsync(Arg.Any()); + } + + [Fact] + public async Task GetInvoiceHistoryAsync_SubscriberNull_ReturnsNull() + { + // Arrange + var paymentHistoryService = new PaymentHistoryService( + Substitute.For(), + Substitute.For(), + Substitute.For>()); + + // Act + var result = await paymentHistoryService.GetInvoiceHistoryAsync(null); + + // Assert + Assert.Null(result); + } + + [Fact] + public async Task GetTransactionHistoryAsync_Succeeds() + { + // Arrange + var subscriber = new Organization { Id = Guid.NewGuid() }; + var transactions = new List { new() { Id = Guid.NewGuid() } }; + var transactionRepository = Substitute.For(); + transactionRepository.GetManyByOrganizationIdAsync(subscriber.Id, Arg.Any(), Arg.Any()).Returns(transactions); + var stripeAdapter = Substitute.For(); + var logger = Substitute.For>(); + var paymentHistoryService = new PaymentHistoryService(stripeAdapter, transactionRepository, logger); + + // Act + var result = await paymentHistoryService.GetTransactionHistoryAsync(subscriber); + + // Assert + Assert.NotNull(result); + Assert.Single(result); + await transactionRepository.Received(1).GetManyByOrganizationIdAsync(subscriber.Id, Arg.Any(), Arg.Any()); + } + + [Fact] + public async Task GetTransactionHistoryAsync_SubscriberNull_ReturnsNull() + { + // Arrange + var paymentHistoryService = new PaymentHistoryService( + Substitute.For(), + Substitute.For(), + Substitute.For>()); + + // Act + var result = await paymentHistoryService.GetTransactionHistoryAsync(null); + + // Assert + Assert.Null(result); + } +} diff --git a/util/Migrator/DbScripts/2024-08-21_00_OrganizationTransactionsReadCursor.sql b/util/Migrator/DbScripts/2024-08-21_00_OrganizationTransactionsReadCursor.sql new file mode 100644 index 000000000..0c3631052 --- /dev/null +++ b/util/Migrator/DbScripts/2024-08-21_00_OrganizationTransactionsReadCursor.sql @@ -0,0 +1,18 @@ +CREATE OR ALTER PROCEDURE [dbo].[Transaction_ReadByOrganizationId] + @OrganizationId UNIQUEIDENTIFIER, + @Limit INT, + @StartAfter DATETIME2 = NULL +AS +BEGIN + SET NOCOUNT ON + + SELECT + TOP (@Limit) * + FROM + [dbo].[TransactionView] + WHERE + [OrganizationId] = @OrganizationId + AND (@StartAfter IS NULL OR [CreationDate] < @StartAfter) + ORDER BY + [CreationDate] DESC +END diff --git a/util/Migrator/DbScripts/2024-08-21_01_ProviderTransactionsReadCursor.sql b/util/Migrator/DbScripts/2024-08-21_01_ProviderTransactionsReadCursor.sql new file mode 100644 index 000000000..174b32f40 --- /dev/null +++ b/util/Migrator/DbScripts/2024-08-21_01_ProviderTransactionsReadCursor.sql @@ -0,0 +1,18 @@ +CREATE OR ALTER PROCEDURE [dbo].[Transaction_ReadByProviderId] + @ProviderId UNIQUEIDENTIFIER, + @Limit INT, + @StartAfter DATETIME2 = NULL +AS +BEGIN + SET NOCOUNT ON + + SELECT + TOP (@Limit) * + FROM + [dbo].[TransactionView] + WHERE + [ProviderId] = @ProviderId + AND (@StartAfter IS NULL OR [CreationDate] < @StartAfter) + ORDER BY + [CreationDate] DESC +END diff --git a/util/Migrator/DbScripts/2024-08-21_02_UserTransactionsReadCursor.sql b/util/Migrator/DbScripts/2024-08-21_02_UserTransactionsReadCursor.sql new file mode 100644 index 000000000..21032b167 --- /dev/null +++ b/util/Migrator/DbScripts/2024-08-21_02_UserTransactionsReadCursor.sql @@ -0,0 +1,18 @@ +CREATE OR ALTER PROCEDURE [dbo].[Transaction_ReadByUserId] + @UserId UNIQUEIDENTIFIER, + @Limit INT, + @StartAfter DATETIME2 = NULL +AS +BEGIN + SET NOCOUNT ON + + SELECT + TOP (@Limit) * + FROM + [dbo].[TransactionView] + WHERE + [UserId] = @UserId + AND (@StartAfter IS NULL OR [CreationDate] < @StartAfter) + ORDER BY + [CreationDate] DESC +END