1
0
mirror of https://github.com/bitwarden/server.git synced 2024-11-21 12:05:42 +01:00

[AC-1938] Update provider payment method (#4140)

* Refactored GET provider subscription

Refactoring this endpoint and its associated tests in preparation for the addition of more endpoints that share similar patterns

* Replaced StripePaymentService call in AccountsController, OrganizationsController

This was made in error during a previous PR. Since this is not related to Consolidated Billing, we want to try not to include it in these changes.

* Removing GetPaymentInformation call from ProviderBillingService

This method is a good call for the SubscriberService as we'll want to extend the functionality to all subscriber types

* Refactored GetTaxInformation to use Billing owned DTO

* Add UpdateTaxInformation to SubscriberService

* Added GetTaxInformation and UpdateTaxInformation endpoints to ProviderBillingController

* Added controller to manage creation of Stripe SetupIntents

With the deprecation of the Sources API, we need to move the bank account creation process to using SetupIntents. This controller brings both the creation of "card" and "us_bank_account" SetupIntents
under billing management.

* Added UpdatePaymentMethod method to SubscriberService

This method utilizes the SetupIntents created by the StripeController from the previous commit when a customer adds a card or us_bank_account payment method (Stripe). We need to cache the most recent SetupIntent for the subscriber so that we know which PaymentMethod is their most recent even when it hasn't been confirmed yet.

* Refactored GetPaymentMethod to use billing owned DTO and check setup intents

* Added GetPaymentMethod and UpdatePaymentMethod endpoints to ProviderBillingController

* Re-added GetPaymentInformation endpoint to consolidate API calls on the payment method page

* Added VerifyBankAccount endpoint to ProviderBillingController in order to finalize bank account payment methods

* Updated BitPayInvoiceRequestModel to support providers

* run dotnet format

* Conner's feedback

* Run dotnet format'
This commit is contained in:
Alex Morask 2024-06-03 11:00:52 -04:00 committed by GitHub
parent b42ebe6f1b
commit 2b43cde99b
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
34 changed files with 2478 additions and 540 deletions

View File

@ -226,22 +226,14 @@ public class ProviderBillingService(
.Sum(providerOrganization => providerOrganization.Seats ?? 0);
}
public async Task<ProviderSubscriptionDTO> GetSubscriptionDTO(Guid providerId)
public async Task<ConsolidatedBillingSubscriptionDTO> GetConsolidatedBillingSubscription(
Provider provider)
{
var provider = await providerRepository.GetByIdAsync(providerId);
if (provider == null)
{
logger.LogError(
"Could not find provider ({ID}) when retrieving subscription data.",
providerId);
return null;
}
ArgumentNullException.ThrowIfNull(provider);
if (provider.Type == ProviderType.Reseller)
{
logger.LogError("Subscription data cannot be retrieved for reseller-type provider ({ID})", providerId);
logger.LogError("Consolidated billing subscription cannot be retrieved for reseller-type provider ({ID})", provider.Id);
throw ContactSupport("Consolidated billing does not support reseller-type providers");
}
@ -256,14 +248,14 @@ public class ProviderBillingService(
return null;
}
var providerPlans = await providerPlanRepository.GetByProviderId(providerId);
var providerPlans = await providerPlanRepository.GetByProviderId(provider.Id);
var configuredProviderPlans = providerPlans
.Where(providerPlan => providerPlan.IsConfigured())
.Select(ConfiguredProviderPlanDTO.From)
.ToList();
return new ProviderSubscriptionDTO(
return new ConsolidatedBillingSubscriptionDTO(
configuredProviderPlans,
subscription);
}
@ -454,39 +446,6 @@ public class ProviderBillingService(
await providerRepository.ReplaceAsync(provider);
}
public async Task<ProviderPaymentInfoDTO> GetPaymentInformationAsync(Guid providerId)
{
var provider = await providerRepository.GetByIdAsync(providerId);
if (provider == null)
{
logger.LogError(
"Could not find provider ({ID}) when retrieving payment information.",
providerId);
return null;
}
if (provider.Type == ProviderType.Reseller)
{
logger.LogError("payment information cannot be retrieved for reseller-type provider ({ID})", providerId);
throw ContactSupport("Consolidated billing does not support reseller-type providers");
}
var taxInformation = await subscriberService.GetTaxInformationAsync(provider);
var billingInformation = await subscriberService.GetPaymentMethodAsync(provider);
if (taxInformation == null && billingInformation == null)
{
return null;
}
return new ProviderPaymentInfoDTO(
billingInformation,
taxInformation);
}
private Func<int, int, Task> CurrySeatScalingUpdate(
Provider provider,
ProviderPlan providerPlan,

View File

@ -21,7 +21,6 @@ using Bit.Core.Utilities;
using Bit.Test.Common.AutoFixture;
using Bit.Test.Common.AutoFixture.Attributes;
using NSubstitute;
using NSubstitute.ReturnsExtensions;
using Stripe;
using Xunit;
using static Bit.Core.Test.Billing.Utilities;
@ -701,60 +700,33 @@ public class ProviderBillingServiceTests
#endregion
#region GetSubscriptionData
#region GetConsolidatedBillingSubscription
[Theory, BitAutoData]
public async Task GetSubscriptionData_NullProvider_ReturnsNull(
SutProvider<ProviderBillingService> sutProvider,
Guid providerId)
{
var providerRepository = sutProvider.GetDependency<IProviderRepository>();
providerRepository.GetByIdAsync(providerId).ReturnsNull();
var subscriptionData = await sutProvider.Sut.GetSubscriptionDTO(providerId);
Assert.Null(subscriptionData);
await providerRepository.Received(1).GetByIdAsync(providerId);
}
public async Task GetConsolidatedBillingSubscription_NullProvider_ThrowsArgumentNullException(
SutProvider<ProviderBillingService> sutProvider) =>
await Assert.ThrowsAsync<ArgumentNullException>(() => sutProvider.Sut.GetConsolidatedBillingSubscription(null));
[Theory, BitAutoData]
public async Task GetSubscriptionData_NullSubscription_ReturnsNull(
public async Task GetConsolidatedBillingSubscription_NullSubscription_ReturnsNull(
SutProvider<ProviderBillingService> sutProvider,
Guid providerId,
Provider provider)
{
var providerRepository = sutProvider.GetDependency<IProviderRepository>();
var consolidatedBillingSubscription = await sutProvider.Sut.GetConsolidatedBillingSubscription(provider);
providerRepository.GetByIdAsync(providerId).Returns(provider);
Assert.Null(consolidatedBillingSubscription);
var subscriberService = sutProvider.GetDependency<ISubscriberService>();
subscriberService.GetSubscription(provider).ReturnsNull();
var subscriptionData = await sutProvider.Sut.GetSubscriptionDTO(providerId);
Assert.Null(subscriptionData);
await providerRepository.Received(1).GetByIdAsync(providerId);
await subscriberService.Received(1).GetSubscription(
await sutProvider.GetDependency<ISubscriberService>().Received(1).GetSubscription(
provider,
Arg.Is<SubscriptionGetOptions>(
options => options.Expand.Count == 1 && options.Expand.First() == "customer"));
}
[Theory, BitAutoData]
public async Task GetSubscriptionData_Success(
public async Task GetConsolidatedBillingSubscription_Success(
SutProvider<ProviderBillingService> sutProvider,
Guid providerId,
Provider provider)
{
var providerRepository = sutProvider.GetDependency<IProviderRepository>();
providerRepository.GetByIdAsync(providerId).Returns(provider);
var subscriberService = sutProvider.GetDependency<ISubscriberService>();
var subscription = new Subscription();
@ -767,7 +739,7 @@ public class ProviderBillingServiceTests
var enterprisePlan = new ProviderPlan
{
Id = Guid.NewGuid(),
ProviderId = providerId,
ProviderId = provider.Id,
PlanType = PlanType.EnterpriseMonthly,
SeatMinimum = 100,
PurchasedSeats = 0,
@ -777,7 +749,7 @@ public class ProviderBillingServiceTests
var teamsPlan = new ProviderPlan
{
Id = Guid.NewGuid(),
ProviderId = providerId,
ProviderId = provider.Id,
PlanType = PlanType.TeamsMonthly,
SeatMinimum = 50,
PurchasedSeats = 10,
@ -786,37 +758,28 @@ public class ProviderBillingServiceTests
var providerPlans = new List<ProviderPlan> { enterprisePlan, teamsPlan, };
providerPlanRepository.GetByProviderId(providerId).Returns(providerPlans);
providerPlanRepository.GetByProviderId(provider.Id).Returns(providerPlans);
var subscriptionData = await sutProvider.Sut.GetSubscriptionDTO(providerId);
var consolidatedBillingSubscription = await sutProvider.Sut.GetConsolidatedBillingSubscription(provider);
Assert.NotNull(subscriptionData);
Assert.NotNull(consolidatedBillingSubscription);
Assert.Equivalent(subscriptionData.Subscription, subscription);
Assert.Equivalent(consolidatedBillingSubscription.Subscription, subscription);
Assert.Equal(2, subscriptionData.ProviderPlans.Count);
Assert.Equal(2, consolidatedBillingSubscription.ProviderPlans.Count);
var configuredEnterprisePlan =
subscriptionData.ProviderPlans.FirstOrDefault(configuredPlan =>
consolidatedBillingSubscription.ProviderPlans.FirstOrDefault(configuredPlan =>
configuredPlan.PlanType == PlanType.EnterpriseMonthly);
var configuredTeamsPlan =
subscriptionData.ProviderPlans.FirstOrDefault(configuredPlan =>
consolidatedBillingSubscription.ProviderPlans.FirstOrDefault(configuredPlan =>
configuredPlan.PlanType == PlanType.TeamsMonthly);
Compare(enterprisePlan, configuredEnterprisePlan);
Compare(teamsPlan, configuredTeamsPlan);
await providerRepository.Received(1).GetByIdAsync(providerId);
await subscriberService.Received(1).GetSubscription(
provider,
Arg.Is<SubscriptionGetOptions>(
options => options.Expand.Count == 1 && options.Expand.First() == "customer"));
await providerPlanRepository.Received(1).GetByProviderId(providerId);
return;
void Compare(ProviderPlan providerPlan, ConfiguredProviderPlanDTO configuredProviderPlan)
@ -1005,106 +968,4 @@ public class ProviderBillingServiceTests
}
#endregion
#region GetPaymentInformationAsync
[Theory, BitAutoData]
public async Task GetPaymentInformationAsync_NullProvider_ReturnsNull(
SutProvider<ProviderBillingService> sutProvider,
Guid providerId)
{
var providerRepository = sutProvider.GetDependency<IProviderRepository>();
providerRepository.GetByIdAsync(providerId).ReturnsNull();
var paymentService = sutProvider.GetDependency<ISubscriberService>();
paymentService.GetTaxInformationAsync(Arg.Any<Provider>()).ReturnsNull();
paymentService.GetPaymentMethodAsync(Arg.Any<Provider>()).ReturnsNull();
var sut = sutProvider.Sut;
var paymentInfo = await sut.GetPaymentInformationAsync(providerId);
Assert.Null(paymentInfo);
await providerRepository.Received(1).GetByIdAsync(providerId);
await paymentService.DidNotReceive().GetTaxInformationAsync(Arg.Any<Provider>());
await paymentService.DidNotReceive().GetPaymentMethodAsync(Arg.Any<Provider>());
}
[Theory, BitAutoData]
public async Task GetPaymentInformationAsync_NullSubscription_ReturnsNull(
SutProvider<ProviderBillingService> sutProvider,
Guid providerId,
Provider provider)
{
var providerRepository = sutProvider.GetDependency<IProviderRepository>();
providerRepository.GetByIdAsync(providerId).Returns(provider);
var subscriberService = sutProvider.GetDependency<ISubscriberService>();
subscriberService.GetTaxInformationAsync(provider).ReturnsNull();
subscriberService.GetPaymentMethodAsync(provider).ReturnsNull();
var paymentInformation = await sutProvider.Sut.GetPaymentInformationAsync(providerId);
Assert.Null(paymentInformation);
await providerRepository.Received(1).GetByIdAsync(providerId);
await subscriberService.Received(1).GetTaxInformationAsync(provider);
await subscriberService.Received(1).GetPaymentMethodAsync(provider);
}
[Theory, BitAutoData]
public async Task GetPaymentInformationAsync_ResellerProvider_ThrowContactSupport(
SutProvider<ProviderBillingService> sutProvider,
Guid providerId,
Provider provider)
{
provider.Id = providerId;
provider.Type = ProviderType.Reseller;
var providerRepository = sutProvider.GetDependency<IProviderRepository>();
providerRepository.GetByIdAsync(providerId).Returns(provider);
var exception = await Assert.ThrowsAsync<BillingException>(
() => sutProvider.Sut.GetPaymentInformationAsync(providerId));
Assert.Equal("Consolidated billing does not support reseller-type providers", exception.Message);
}
[Theory, BitAutoData]
public async Task GetPaymentInformationAsync_Success_ReturnsProviderPaymentInfoDTO(
SutProvider<ProviderBillingService> sutProvider,
Guid providerId,
Provider provider)
{
provider.Id = providerId;
provider.Type = ProviderType.Msp;
var taxInformation = new TaxInfo { TaxIdNumber = "12345" };
var paymentMethod = new PaymentMethod
{
Id = "pm_test123",
Type = "card",
Card = new PaymentMethodCard
{
Brand = "visa",
Last4 = "4242",
ExpMonth = 12,
ExpYear = 2024
}
};
var billingInformation = new BillingInfo { PaymentSource = new BillingInfo.BillingSource(paymentMethod) };
var providerRepository = sutProvider.GetDependency<IProviderRepository>();
providerRepository.GetByIdAsync(providerId).Returns(provider);
var subscriberService = sutProvider.GetDependency<ISubscriberService>();
subscriberService.GetTaxInformationAsync(provider).Returns(taxInformation);
subscriberService.GetPaymentMethodAsync(provider).Returns(billingInformation.PaymentSource);
var result = await sutProvider.Sut.GetPaymentInformationAsync(providerId);
// Assert
Assert.NotNull(result);
Assert.Equal(billingInformation.PaymentSource, result.billingSource);
Assert.Equal(taxInformation, result.taxInfo);
}
#endregion
}

View File

@ -835,7 +835,7 @@ public class AccountsController : Controller
throw new UnauthorizedAccessException();
}
var taxInfo = await _subscriberService.GetTaxInformationAsync(user);
var taxInfo = await _paymentService.GetTaxInfoAsync(user);
return new TaxInfoResponseModel(taxInfo);
}

View File

@ -304,7 +304,7 @@ public class OrganizationsController(
throw new NotFoundException();
}
var taxInfo = await subscriberService.GetTaxInformationAsync(organization);
var taxInfo = await paymentService.GetTaxInfoAsync(organization);
return new TaxInfoResponseModel(taxInfo);
}

View File

@ -1,10 +1,17 @@
using Bit.Api.Billing.Models.Responses;
using Bit.Api.Billing.Models.Requests;
using Bit.Api.Billing.Models.Responses;
using Bit.Core;
using Bit.Core.AdminConsole.Entities.Provider;
using Bit.Core.AdminConsole.Repositories;
using Bit.Core.Billing.Constants;
using Bit.Core.Billing.Extensions;
using Bit.Core.Billing.Models;
using Bit.Core.Billing.Services;
using Bit.Core.Context;
using Bit.Core.Services;
using Microsoft.AspNetCore.Authorization;
using Microsoft.AspNetCore.Mvc;
using Stripe;
namespace Bit.Api.Billing.Controllers;
@ -13,59 +20,194 @@ namespace Bit.Api.Billing.Controllers;
public class ProviderBillingController(
ICurrentContext currentContext,
IFeatureService featureService,
IProviderBillingService providerBillingService) : Controller
IProviderBillingService providerBillingService,
IProviderRepository providerRepository,
IStripeAdapter stripeAdapter,
ISubscriberService subscriberService) : Controller
{
[HttpGet("subscription")]
public async Task<IResult> GetSubscriptionAsync([FromRoute] Guid providerId)
{
if (!featureService.IsEnabled(FeatureFlagKeys.EnableConsolidatedBilling))
{
return TypedResults.NotFound();
}
if (!currentContext.ProviderProviderAdmin(providerId))
{
return TypedResults.Unauthorized();
}
var providerSubscriptionDTO = await providerBillingService.GetSubscriptionDTO(providerId);
if (providerSubscriptionDTO == null)
{
return TypedResults.NotFound();
}
var (providerPlans, subscription) = providerSubscriptionDTO;
var providerSubscriptionResponse = ProviderSubscriptionResponse.From(providerPlans, subscription);
return TypedResults.Ok(providerSubscriptionResponse);
}
[HttpGet("payment-information")]
public async Task<IResult> GetPaymentInformationAsync([FromRoute] Guid providerId)
{
if (!featureService.IsEnabled(FeatureFlagKeys.EnableConsolidatedBilling))
var (provider, result) = await GetAuthorizedBillableProviderOrResultAsync(providerId);
if (provider == null)
{
return result;
}
var paymentInformation = await subscriberService.GetPaymentInformation(provider);
if (paymentInformation == null)
{
return TypedResults.NotFound();
}
var response = PaymentInformationResponse.From(paymentInformation);
return TypedResults.Ok(response);
}
[HttpGet("payment-method")]
public async Task<IResult> GetPaymentMethodAsync([FromRoute] Guid providerId)
{
var (provider, result) = await GetAuthorizedBillableProviderOrResultAsync(providerId);
if (provider == null)
{
return result;
}
var maskedPaymentMethod = await subscriberService.GetPaymentMethod(provider);
if (maskedPaymentMethod == null)
{
return TypedResults.NotFound();
}
var response = MaskedPaymentMethodResponse.From(maskedPaymentMethod);
return TypedResults.Ok(response);
}
[HttpPut("payment-method")]
public async Task<IResult> UpdatePaymentMethodAsync(
[FromRoute] Guid providerId,
[FromBody] TokenizedPaymentMethodRequestBody requestBody)
{
var (provider, result) = await GetAuthorizedBillableProviderOrResultAsync(providerId);
if (provider == null)
{
return result;
}
var tokenizedPaymentMethod = new TokenizedPaymentMethodDTO(
requestBody.Type,
requestBody.Token);
await subscriberService.UpdatePaymentMethod(provider, tokenizedPaymentMethod);
await stripeAdapter.SubscriptionUpdateAsync(provider.GatewaySubscriptionId,
new SubscriptionUpdateOptions
{
CollectionMethod = StripeConstants.CollectionMethod.ChargeAutomatically
});
return TypedResults.Ok();
}
[HttpPost]
[Route("payment-method/verify-bank-account")]
public async Task<IResult> VerifyBankAccountAsync(
[FromRoute] Guid providerId,
[FromBody] VerifyBankAccountRequestBody requestBody)
{
var (provider, result) = await GetAuthorizedBillableProviderOrResultAsync(providerId);
if (provider == null)
{
return result;
}
await subscriberService.VerifyBankAccount(provider, (requestBody.Amount1, requestBody.Amount2));
return TypedResults.Ok();
}
[HttpGet("subscription")]
public async Task<IResult> GetSubscriptionAsync([FromRoute] Guid providerId)
{
var (provider, result) = await GetAuthorizedBillableProviderOrResultAsync(providerId);
if (provider == null)
{
return result;
}
var consolidatedBillingSubscription = await providerBillingService.GetConsolidatedBillingSubscription(provider);
if (consolidatedBillingSubscription == null)
{
return TypedResults.NotFound();
}
var response = ConsolidatedBillingSubscriptionResponse.From(consolidatedBillingSubscription);
return TypedResults.Ok(response);
}
[HttpGet("tax-information")]
public async Task<IResult> GetTaxInformationAsync([FromRoute] Guid providerId)
{
var (provider, result) = await GetAuthorizedBillableProviderOrResultAsync(providerId);
if (provider == null)
{
return result;
}
var taxInformation = await subscriberService.GetTaxInformation(provider);
if (taxInformation == null)
{
return TypedResults.NotFound();
}
var response = TaxInformationResponse.From(taxInformation);
return TypedResults.Ok(response);
}
[HttpPut("tax-information")]
public async Task<IResult> UpdateTaxInformationAsync(
[FromRoute] Guid providerId,
[FromBody] TaxInformationRequestBody requestBody)
{
var (provider, result) = await GetAuthorizedBillableProviderOrResultAsync(providerId);
if (provider == null)
{
return result;
}
var taxInformation = new TaxInformationDTO(
requestBody.Country,
requestBody.PostalCode,
requestBody.TaxId,
requestBody.Line1,
requestBody.Line2,
requestBody.City,
requestBody.State);
await subscriberService.UpdateTaxInformation(provider, taxInformation);
return TypedResults.Ok();
}
private async Task<(Provider, IResult)> GetAuthorizedBillableProviderOrResultAsync(Guid providerId)
{
if (!featureService.IsEnabled(FeatureFlagKeys.EnableConsolidatedBilling))
{
return (null, TypedResults.NotFound());
}
var provider = await providerRepository.GetByIdAsync(providerId);
if (provider == null)
{
return (null, TypedResults.NotFound());
}
if (!currentContext.ProviderProviderAdmin(providerId))
{
return TypedResults.Unauthorized();
return (null, TypedResults.Unauthorized());
}
var providerPaymentInformationDto = await providerBillingService.GetPaymentInformationAsync(providerId);
if (providerPaymentInformationDto == null)
if (!provider.IsBillable())
{
return TypedResults.NotFound();
return (null, TypedResults.Unauthorized());
}
var (paymentSource, taxInfo) = providerPaymentInformationDto;
var providerPaymentInformationResponse = PaymentInformationResponse.From(paymentSource, taxInfo);
return TypedResults.Ok(providerPaymentInformationResponse);
return (provider, null);
}
}

View File

@ -0,0 +1,49 @@
using Bit.Core.Services;
using Microsoft.AspNetCore.Authorization;
using Microsoft.AspNetCore.Http.HttpResults;
using Microsoft.AspNetCore.Mvc;
using Stripe;
namespace Bit.Api.Billing.Controllers;
[Authorize("Application")]
public class StripeController(
IStripeAdapter stripeAdapter) : Controller
{
[HttpPost]
[Route("~/setup-intent/bank-account")]
public async Task<Ok<string>> CreateSetupIntentForBankAccountAsync()
{
var options = new SetupIntentCreateOptions
{
PaymentMethodOptions = new SetupIntentPaymentMethodOptionsOptions
{
UsBankAccount = new SetupIntentPaymentMethodOptionsUsBankAccountOptions
{
VerificationMethod = "microdeposits"
}
},
PaymentMethodTypes = ["us_bank_account"],
Usage = "off_session"
};
var setupIntent = await stripeAdapter.SetupIntentCreate(options);
return TypedResults.Ok(setupIntent.ClientSecret);
}
[HttpPost]
[Route("~/setup-intent/card")]
public async Task<Ok<string>> CreateSetupIntentForCardAsync()
{
var options = new SetupIntentCreateOptions
{
PaymentMethodTypes = ["card"],
Usage = "off_session"
};
var setupIntent = await stripeAdapter.SetupIntentCreate(options);
return TypedResults.Ok(setupIntent.ClientSecret);
}
}

View File

@ -0,0 +1,16 @@
using System.ComponentModel.DataAnnotations;
namespace Bit.Api.Billing.Models.Requests;
public class TaxInformationRequestBody
{
[Required]
public string Country { get; set; }
[Required]
public string PostalCode { get; set; }
public string TaxId { get; set; }
public string Line1 { get; set; }
public string Line2 { get; set; }
public string City { get; set; }
public string State { get; set; }
}

View File

@ -0,0 +1,18 @@
using System.ComponentModel.DataAnnotations;
using Bit.Api.Utilities;
using Bit.Core.Enums;
namespace Bit.Api.Billing.Models.Requests;
public class TokenizedPaymentMethodRequestBody
{
[Required]
[EnumMatches<PaymentMethodType>(
PaymentMethodType.BankAccount,
PaymentMethodType.Card,
PaymentMethodType.PayPal,
ErrorMessage = "'type' must be BankAccount, Card or PayPal")]
public PaymentMethodType Type { get; set; }
[Required]
public string Token { get; set; }
}

View File

@ -0,0 +1,11 @@
using System.ComponentModel.DataAnnotations;
namespace Bit.Api.Billing.Models.Requests;
public class VerifyBankAccountRequestBody
{
[Range(0, 99)]
public long Amount1 { get; set; }
[Range(0, 99)]
public long Amount2 { get; set; }
}

View File

@ -1,29 +1,29 @@
using Bit.Core.Billing.Models;
using Bit.Core.Utilities;
using Stripe;
namespace Bit.Api.Billing.Models.Responses;
public record ProviderSubscriptionResponse(
public record ConsolidatedBillingSubscriptionResponse(
string Status,
DateTime CurrentPeriodEndDate,
decimal? DiscountPercentage,
IEnumerable<ProviderPlanDTO> Plans)
IEnumerable<ProviderPlanResponse> Plans)
{
private const string _annualCadence = "Annual";
private const string _monthlyCadence = "Monthly";
public static ProviderSubscriptionResponse From(
IEnumerable<ConfiguredProviderPlanDTO> providerPlans,
Subscription subscription)
public static ConsolidatedBillingSubscriptionResponse From(
ConsolidatedBillingSubscriptionDTO consolidatedBillingSubscription)
{
var (providerPlans, subscription) = consolidatedBillingSubscription;
var providerPlansDTO = providerPlans
.Select(providerPlan =>
{
var plan = StaticStore.GetPlan(providerPlan.PlanType);
var cost = (providerPlan.SeatMinimum + providerPlan.PurchasedSeats) * plan.PasswordManager.SeatPrice;
var cadence = plan.IsAnnual ? _annualCadence : _monthlyCadence;
return new ProviderPlanDTO(
return new ProviderPlanResponse(
plan.Name,
providerPlan.SeatMinimum,
providerPlan.PurchasedSeats,
@ -32,7 +32,7 @@ public record ProviderSubscriptionResponse(
cadence);
});
return new ProviderSubscriptionResponse(
return new ConsolidatedBillingSubscriptionResponse(
subscription.Status,
subscription.CurrentPeriodEnd,
subscription.Customer?.Discount?.Coupon?.PercentOff,
@ -40,7 +40,7 @@ public record ProviderSubscriptionResponse(
}
}
public record ProviderPlanDTO(
public record ProviderPlanResponse(
string PlanName,
int SeatMinimum,
int PurchasedSeats,

View File

@ -0,0 +1,16 @@
using Bit.Core.Billing.Models;
using Bit.Core.Enums;
namespace Bit.Api.Billing.Models.Responses;
public record MaskedPaymentMethodResponse(
PaymentMethodType Type,
string Description,
bool NeedsVerification)
{
public static MaskedPaymentMethodResponse From(MaskedPaymentMethodDTO maskedPaymentMethod)
=> new(
maskedPaymentMethod.Type,
maskedPaymentMethod.Description,
maskedPaymentMethod.NeedsVerification);
}

View File

@ -1,37 +1,15 @@
using Bit.Core.Enums;
using Bit.Core.Models.Business;
using Bit.Core.Billing.Models;
namespace Bit.Api.Billing.Models.Responses;
public record PaymentInformationResponse(PaymentMethod PaymentMethod, TaxInformation TaxInformation)
public record PaymentInformationResponse(
long AccountCredit,
MaskedPaymentMethodDTO PaymentMethod,
TaxInformationDTO TaxInformation)
{
public static PaymentInformationResponse From(BillingInfo.BillingSource billingSource, TaxInfo taxInfo)
{
var paymentMethodDto = new PaymentMethod(
billingSource.Type, billingSource.Description, billingSource.CardBrand
);
var taxInformationDto = new TaxInformation(
taxInfo.BillingAddressCountry, taxInfo.BillingAddressPostalCode, taxInfo.TaxIdNumber,
taxInfo.BillingAddressLine1, taxInfo.BillingAddressLine2, taxInfo.BillingAddressCity,
taxInfo.BillingAddressState
);
return new PaymentInformationResponse(paymentMethodDto, taxInformationDto);
}
public static PaymentInformationResponse From(PaymentInformationDTO paymentInformation) =>
new(
paymentInformation.AccountCredit,
paymentInformation.PaymentMethod,
paymentInformation.TaxInformation);
}
public record PaymentMethod(
PaymentMethodType Type,
string Description,
string CardBrand);
public record TaxInformation(
string Country,
string PostalCode,
string TaxId,
string Line1,
string Line2,
string City,
string State);

View File

@ -0,0 +1,23 @@
using Bit.Core.Billing.Models;
namespace Bit.Api.Billing.Models.Responses;
public record TaxInformationResponse(
string Country,
string PostalCode,
string TaxId,
string Line1,
string Line2,
string City,
string State)
{
public static TaxInformationResponse From(TaxInformationDTO taxInformation)
=> new(
taxInformation.Country,
taxInformation.PostalCode,
taxInformation.TaxId,
taxInformation.Line1,
taxInformation.Line2,
taxInformation.City,
taxInformation.State);
}

View File

@ -7,6 +7,7 @@ public class BitPayInvoiceRequestModel : IValidatableObject
{
public Guid? UserId { get; set; }
public Guid? OrganizationId { get; set; }
public Guid? ProviderId { get; set; }
public bool Credit { get; set; }
[Required]
public decimal? Amount { get; set; }
@ -40,6 +41,10 @@ public class BitPayInvoiceRequestModel : IValidatableObject
{
posData = "organizationId:" + OrganizationId.Value;
}
else if (ProviderId.HasValue)
{
posData = "providerId:" + ProviderId.Value;
}
if (Credit)
{
@ -57,9 +62,9 @@ public class BitPayInvoiceRequestModel : IValidatableObject
public IEnumerable<ValidationResult> Validate(ValidationContext validationContext)
{
if (!UserId.HasValue && !OrganizationId.HasValue)
if (!UserId.HasValue && !OrganizationId.HasValue && !ProviderId.HasValue)
{
yield return new ValidationResult("User or Organization is required.");
yield return new ValidationResult("User, Organization or Provider is required.");
}
}
}

View File

@ -0,0 +1,10 @@
namespace Bit.Core.Billing.Caches;
public interface ISetupIntentCache
{
Task<string> Get(Guid subscriberId);
Task Remove(Guid subscriberId);
Task Set(Guid subscriberId, string setupIntentId);
}

View File

@ -0,0 +1,32 @@
using Microsoft.Extensions.Caching.Distributed;
using Microsoft.Extensions.DependencyInjection;
namespace Bit.Core.Billing.Caches.Implementations;
public class SetupIntentDistributedCache(
[FromKeyedServices("persistent")]
IDistributedCache distributedCache) : ISetupIntentCache
{
public async Task<string> Get(Guid subscriberId)
{
var cacheKey = GetCacheKey(subscriberId);
return await distributedCache.GetStringAsync(cacheKey);
}
public async Task Remove(Guid subscriberId)
{
var cacheKey = GetCacheKey(subscriberId);
await distributedCache.RemoveAsync(cacheKey);
}
public async Task Set(Guid subscriberId, string setupIntentId)
{
var cacheKey = GetCacheKey(subscriberId);
await distributedCache.SetStringAsync(cacheKey, setupIntentId);
}
private static string GetCacheKey(Guid subscriberId) => $"pending_bank_account_{subscriberId}";
}

View File

@ -21,6 +21,12 @@ public static class StripeConstants
public const string SecretsManagerStandalone = "sm-standalone";
}
public static class PaymentMethodTypes
{
public const string Card = "card";
public const string USBankAccount = "us_bank_account";
}
public static class ProrationBehavior
{
public const string AlwaysInvoice = "always_invoice";

View File

@ -2,6 +2,7 @@
using Bit.Core.AdminConsole.Entities.Provider;
using Bit.Core.AdminConsole.Enums.Provider;
using Bit.Core.Enums;
using Stripe;
namespace Bit.Core.Billing.Extensions;
@ -26,6 +27,20 @@ public static class BillingExtensions
=> !string.IsNullOrEmpty(organization.GatewayCustomerId) &&
!string.IsNullOrEmpty(organization.GatewaySubscriptionId);
public static bool IsUnverifiedBankAccount(this SetupIntent setupIntent) =>
setupIntent is
{
Status: "requires_action",
NextAction:
{
VerifyWithMicrodeposits: not null
},
PaymentMethod:
{
UsBankAccount: not null
}
};
public static bool SupportsConsolidatedBilling(this PlanType planType)
=> planType is PlanType.TeamsMonthly or PlanType.EnterpriseMonthly;
}

View File

@ -1,4 +1,6 @@
using Bit.Core.Billing.Services;
using Bit.Core.Billing.Caches;
using Bit.Core.Billing.Caches.Implementations;
using Bit.Core.Billing.Services;
using Bit.Core.Billing.Services.Implementations;
namespace Bit.Core.Billing.Extensions;
@ -10,6 +12,7 @@ public static class ServiceCollectionExtensions
public static void AddBillingOperations(this IServiceCollection services)
{
services.AddTransient<IOrganizationBillingService, OrganizationBillingService>();
services.AddTransient<ISetupIntentCache, SetupIntentDistributedCache>();
services.AddTransient<ISubscriberService, SubscriberService>();
}
}

View File

@ -2,6 +2,6 @@
namespace Bit.Core.Billing.Models;
public record ProviderSubscriptionDTO(
public record ConsolidatedBillingSubscriptionDTO(
List<ConfiguredProviderPlanDTO> ProviderPlans,
Subscription Subscription);

View File

@ -0,0 +1,156 @@
using Bit.Core.Billing.Extensions;
using Bit.Core.Enums;
namespace Bit.Core.Billing.Models;
public record MaskedPaymentMethodDTO(
PaymentMethodType Type,
string Description,
bool NeedsVerification)
{
public static MaskedPaymentMethodDTO From(Stripe.Customer customer)
{
var defaultPaymentMethod = customer.InvoiceSettings?.DefaultPaymentMethod;
if (defaultPaymentMethod == null)
{
return customer.DefaultSource != null ? FromStripeLegacyPaymentSource(customer.DefaultSource) : null;
}
return defaultPaymentMethod.Type switch
{
"card" => FromStripeCardPaymentMethod(defaultPaymentMethod.Card),
"us_bank_account" => FromStripeBankAccountPaymentMethod(defaultPaymentMethod.UsBankAccount),
_ => null
};
}
public static MaskedPaymentMethodDTO From(Stripe.SetupIntent setupIntent)
{
if (!setupIntent.IsUnverifiedBankAccount())
{
return null;
}
var bankAccount = setupIntent.PaymentMethod.UsBankAccount;
var description = $"{bankAccount.BankName}, *{bankAccount.Last4}";
return new MaskedPaymentMethodDTO(
PaymentMethodType.BankAccount,
description,
true);
}
public static MaskedPaymentMethodDTO From(Braintree.Customer customer)
{
var defaultPaymentMethod = customer.DefaultPaymentMethod;
if (defaultPaymentMethod == null)
{
return null;
}
switch (defaultPaymentMethod)
{
case Braintree.PayPalAccount payPalAccount:
{
return new MaskedPaymentMethodDTO(
PaymentMethodType.PayPal,
payPalAccount.Email,
false);
}
case Braintree.CreditCard creditCard:
{
var paddedExpirationMonth = creditCard.ExpirationMonth.PadLeft(2, '0');
var description =
$"{creditCard.CardType}, *{creditCard.LastFour}, {paddedExpirationMonth}/{creditCard.ExpirationYear}";
return new MaskedPaymentMethodDTO(
PaymentMethodType.Card,
description,
false);
}
case Braintree.UsBankAccount bankAccount:
{
return new MaskedPaymentMethodDTO(
PaymentMethodType.BankAccount,
$"{bankAccount.BankName}, *{bankAccount.Last4}",
false);
}
default:
{
return null;
}
}
}
private static MaskedPaymentMethodDTO FromStripeBankAccountPaymentMethod(
Stripe.PaymentMethodUsBankAccount bankAccount)
{
var description = $"{bankAccount.BankName}, *{bankAccount.Last4}";
return new MaskedPaymentMethodDTO(
PaymentMethodType.BankAccount,
description,
false);
}
private static MaskedPaymentMethodDTO FromStripeCardPaymentMethod(Stripe.PaymentMethodCard card)
=> new(
PaymentMethodType.Card,
GetCardDescription(card.Brand, card.Last4, card.ExpMonth, card.ExpYear),
false);
#region Legacy Source Payments
private static MaskedPaymentMethodDTO FromStripeLegacyPaymentSource(Stripe.IPaymentSource paymentSource)
=> paymentSource switch
{
Stripe.BankAccount bankAccount => FromStripeBankAccountLegacySource(bankAccount),
Stripe.Card card => FromStripeCardLegacySource(card),
Stripe.Source { Card: not null } source => FromStripeSourceCardLegacySource(source.Card),
_ => null
};
private static MaskedPaymentMethodDTO FromStripeBankAccountLegacySource(Stripe.BankAccount bankAccount)
{
var status = bankAccount.Status switch
{
"verified" => "Verified",
"errored" => "Invalid",
"verification_failed" => "Verification failed",
_ => "Unverified"
};
var description = $"{bankAccount.BankName}, *{bankAccount.Last4} - {status}";
var needsVerification = bankAccount.Status is "new" or "validated";
return new MaskedPaymentMethodDTO(
PaymentMethodType.BankAccount,
description,
needsVerification);
}
private static MaskedPaymentMethodDTO FromStripeCardLegacySource(Stripe.Card card)
=> new(
PaymentMethodType.Card,
GetCardDescription(card.Brand, card.Last4, card.ExpMonth, card.ExpYear),
false);
private static MaskedPaymentMethodDTO FromStripeSourceCardLegacySource(Stripe.SourceCard card)
=> new(
PaymentMethodType.Card,
GetCardDescription(card.Brand, card.Last4, card.ExpMonth, card.ExpYear),
false);
#endregion
private static string GetCardDescription(
string brand,
string last4,
long expirationMonth,
long expirationYear) => $"{brand.ToUpperInvariant()}, *{last4}, {expirationMonth:00}/{expirationYear}";
}

View File

@ -0,0 +1,6 @@
namespace Bit.Core.Billing.Models;
public record PaymentInformationDTO(
long AccountCredit,
MaskedPaymentMethodDTO PaymentMethod,
TaxInformationDTO TaxInformation);

View File

@ -1,6 +0,0 @@
using Bit.Core.Models.Business;
namespace Bit.Core.Billing.Models;
public record ProviderPaymentInfoDTO(BillingInfo.BillingSource billingSource,
TaxInfo taxInfo);

View File

@ -0,0 +1,149 @@
namespace Bit.Core.Billing.Models;
public record TaxInformationDTO(
string Country,
string PostalCode,
string TaxId,
string Line1,
string Line2,
string City,
string State)
{
public string GetTaxIdType()
{
if (string.IsNullOrEmpty(Country) || string.IsNullOrEmpty(TaxId))
{
return null;
}
switch (Country.ToUpper())
{
case "AD":
return "ad_nrt";
case "AE":
return "ae_trn";
case "AR":
return "ar_cuit";
case "AU":
return "au_abn";
case "BO":
return "bo_tin";
case "BR":
return "br_cnpj";
case "CA":
// May break for those in Québec given the assumption of QST
if (State?.Contains("bec") ?? false)
{
return "ca_qst";
}
return "ca_bn";
case "CH":
return "ch_vat";
case "CL":
return "cl_tin";
case "CN":
return "cn_tin";
case "CO":
return "co_nit";
case "CR":
return "cr_tin";
case "DO":
return "do_rcn";
case "EC":
return "ec_ruc";
case "EG":
return "eg_tin";
case "GE":
return "ge_vat";
case "ID":
return "id_npwp";
case "IL":
return "il_vat";
case "IS":
return "is_vat";
case "KE":
return "ke_pin";
case "AT":
case "BE":
case "BG":
case "CY":
case "CZ":
case "DE":
case "DK":
case "EE":
case "ES":
case "FI":
case "FR":
case "GB":
case "GR":
case "HR":
case "HU":
case "IE":
case "IT":
case "LT":
case "LU":
case "LV":
case "MT":
case "NL":
case "PL":
case "PT":
case "RO":
case "SE":
case "SI":
case "SK":
return "eu_vat";
case "HK":
return "hk_br";
case "IN":
return "in_gst";
case "JP":
return "jp_cn";
case "KR":
return "kr_brn";
case "LI":
return "li_uid";
case "MX":
return "mx_rfc";
case "MY":
return "my_sst";
case "NO":
return "no_vat";
case "NZ":
return "nz_gst";
case "PE":
return "pe_ruc";
case "PH":
return "ph_tin";
case "RS":
return "rs_pib";
case "RU":
return "ru_inn";
case "SA":
return "sa_vat";
case "SG":
return "sg_gst";
case "SV":
return "sv_nit";
case "TH":
return "th_vat";
case "TR":
return "tr_tin";
case "TW":
return "tw_vat";
case "UA":
return "ua_vat";
case "US":
return "us_ein";
case "UY":
return "uy_ruc";
case "VE":
return "ve_rif";
case "VN":
return "vn_tin";
case "ZA":
return "za_vat";
default:
return null;
}
}
}

View File

@ -0,0 +1,7 @@
using Bit.Core.Enums;
namespace Bit.Core.Billing.Models;
public record TokenizedPaymentMethodDTO(
PaymentMethodType Type,
string Token);

View File

@ -56,13 +56,13 @@ public interface IProviderBillingService
PlanType planType);
/// <summary>
/// Retrieves a provider's billing subscription data.
/// Retrieves the <paramref name="provider"/>'s consolidated billing subscription, which includes their Stripe subscription and configured provider plans.
/// </summary>
/// <param name="providerId">The ID of the provider to retrieve subscription data for.</param>
/// <returns>A <see cref="ProviderSubscriptionDTO"/> object containing the provider's Stripe <see cref="Stripe.Subscription"/> and their <see cref="ConfiguredProviderPlanDTO"/>s.</returns>
/// <param name="provider">The provider to retrieve the consolidated billing subscription for.</param>
/// <returns>A <see cref="ConsolidatedBillingSubscriptionDTO"/> containing the provider's Stripe <see cref="Stripe.Subscription"/> and a list of <see cref="ConfiguredProviderPlanDTO"/>s representing their configured plans.</returns>
/// <remarks>This method opts for returning <see langword="null"/> rather than throwing exceptions, making it ideal for surfacing data from API endpoints.</remarks>
Task<ProviderSubscriptionDTO> GetSubscriptionDTO(
Guid providerId);
Task<ConsolidatedBillingSubscriptionDTO> GetConsolidatedBillingSubscription(
Provider provider);
/// <summary>
/// Scales the <paramref name="provider"/>'s seats for the specified <paramref name="planType"/> using the provided <paramref name="seatAdjustment"/>.
@ -85,12 +85,4 @@ public interface IProviderBillingService
/// <param name="provider">The provider to create the <see cref="Stripe.Subscription"/> for.</param>
Task StartSubscription(
Provider provider);
/// <summary>
/// Retrieves a provider's billing payment information.
/// </summary>
/// <param name="providerId">The ID of the provider to retrieve payment information for.</param>
/// <returns>A <see cref="ProviderPaymentInfoDTO"/> object containing the provider's Stripe <see cref="Stripe.PaymentMethod"/> and their <see cref="TaxInfo"/>s.</returns>
/// <remarks>This method opts for returning <see langword="null"/> rather than throwing exceptions, making it ideal for surfacing data from API endpoints.</remarks>
Task<ProviderPaymentInfoDTO> GetPaymentInformationAsync(Guid providerId);
}

View File

@ -1,6 +1,6 @@
using Bit.Core.Billing.Models;
using Bit.Core.Entities;
using Bit.Core.Models.Business;
using Bit.Core.Enums;
using Stripe;
namespace Bit.Core.Billing.Services;
@ -46,6 +46,24 @@ public interface ISubscriberService
ISubscriber subscriber,
CustomerGetOptions customerGetOptions = null);
/// <summary>
/// Retrieves the account credit, a masked representation of the default payment method and the tax information for the
/// provided <paramref name="subscriber"/>. This is essentially a consolidated invocation of the <see cref="GetPaymentMethod"/>
/// and <see cref="GetTaxInformation"/> methods with a response that includes the customer's <see cref="Stripe.Customer.Balance"/> as account credit in order to cut down on Stripe API calls.
/// </summary>
/// <param name="subscriber">The subscriber to retrieve payment information for.</param>
/// <returns>A <see cref="PaymentInformationDTO"/> containing the subscriber's account credit, masked payment method and tax information.</returns>
Task<PaymentInformationDTO> GetPaymentInformation(
ISubscriber subscriber);
/// <summary>
/// Retrieves a masked representation of the subscriber's payment method for presentation to a client.
/// </summary>
/// <param name="subscriber">The subscriber to retrieve the masked payment method for.</param>
/// <returns>A <see cref="MaskedPaymentMethodDTO"/> containing a non-identifiable description of the subscriber's payment method.</returns>
Task<MaskedPaymentMethodDTO> GetPaymentMethod(
ISubscriber subscriber);
/// <summary>
/// Retrieves a Stripe <see cref="Subscription"/> using the <paramref name="subscriber"/>'s <see cref="ISubscriber.GatewaySubscriptionId"/> property.
/// </summary>
@ -71,6 +89,16 @@ public interface ISubscriberService
ISubscriber subscriber,
SubscriptionGetOptions subscriptionGetOptions = null);
/// <summary>
/// Retrieves the <see cref="subscriber"/>'s tax information using their Stripe <see cref="Stripe.Customer"/>'s <see cref="Stripe.Customer.Address"/>.
/// </summary>
/// <param name="subscriber">The subscriber to retrieve the tax information for.</param>
/// <returns>A <see cref="TaxInformationDTO"/> representing the <paramref name="subscriber"/>'s tax information.</returns>
/// <exception cref="ArgumentNullException">Thrown when the <paramref name="subscriber"/> is <see langword="null"/>.</exception>
/// <remarks>This method opts for returning <see langword="null"/> rather than throwing exceptions, making it ideal for surfacing data from API endpoints.</remarks>
Task<TaxInformationDTO> GetTaxInformation(
ISubscriber subscriber);
/// <summary>
/// Attempts to remove a subscriber's saved payment method. If the Stripe <see cref="Stripe.Customer"/> representing the
/// <paramref name="subscriber"/> contains a valid <b>"btCustomerId"</b> key in its <see cref="Stripe.Customer.Metadata"/> property,
@ -81,20 +109,34 @@ public interface ISubscriberService
Task RemovePaymentMethod(ISubscriber subscriber);
/// <summary>
/// Retrieves a Stripe <see cref="TaxInfo"/> using the <paramref name="subscriber"/>'s <see cref="ISubscriber.GatewayCustomerId"/> property.
/// Updates the payment method for the provided <paramref name="subscriber"/> using the <paramref name="tokenizedPaymentMethod"/>.
/// The following payment method types are supported: [<see cref="PaymentMethodType.Card"/>, <see cref="PaymentMethodType.BankAccount"/>, <see cref="PaymentMethodType.PayPal"/>].
/// For each type, updating the payment method will attempt to establish a new payment method using the token in the <see cref="TokenizedPaymentMethodDTO"/>. Then, it will
/// remove the exising payment method(s) linked to the subscriber's customer.
/// </summary>
/// <param name="subscriber">The subscriber to retrieve the Stripe customer for.</param>
/// <returns>A Stripe <see cref="TaxInfo"/>.</returns>
/// <exception cref="ArgumentNullException">Thrown when the <paramref name="subscriber"/> is <see langword="null"/>.</exception>
/// <remarks>This method opts for returning <see langword="null"/> rather than throwing exceptions, making it ideal for surfacing data from API endpoints.</remarks>
Task<TaxInfo> GetTaxInformationAsync(ISubscriber subscriber);
/// <param name="subscriber">The subscriber to update the payment method for.</param>
/// <param name="tokenizedPaymentMethod">A DTO representing a tokenized payment method.</param>
Task UpdatePaymentMethod(
ISubscriber subscriber,
TokenizedPaymentMethodDTO tokenizedPaymentMethod);
/// <summary>
/// Retrieves a Stripe <see cref="BillingInfo.BillingSource"/> using the <paramref name="subscriber"/>'s <see cref="ISubscriber.GatewayCustomerId"/> property.
/// Updates the tax information for the provided <paramref name="subscriber"/>.
/// </summary>
/// <param name="subscriber">The subscriber to retrieve the Stripe customer for.</param>
/// <returns>A Stripe <see cref="BillingInfo.BillingSource"/>.</returns>
/// <exception cref="ArgumentNullException">Thrown when the <paramref name="subscriber"/> is <see langword="null"/>.</exception>
/// <remarks>This method opts for returning <see langword="null"/> rather than throwing exceptions, making it ideal for surfacing data from API endpoints.</remarks>
Task<BillingInfo.BillingSource> GetPaymentMethodAsync(ISubscriber subscriber);
/// <param name="subscriber">The <paramref name="subscriber"/> to update the tax information for.</param>
/// <param name="taxInformation">A <see cref="TaxInformationDTO"/> representing the <paramref name="subscriber"/>'s updated tax information.</param>
Task UpdateTaxInformation(
ISubscriber subscriber,
TaxInformationDTO taxInformation);
/// <summary>
/// Verifies the subscriber's pending bank account using the provided <paramref name="microdeposits"/>.
/// </summary>
/// <param name="subscriber">The subscriber to verify the bank account for.</param>
/// <param name="microdeposits">Deposits made to the subscriber's bank account in order to ensure they have access to it.
/// <a href="https://docs.stripe.com/payments/ach-debit/set-up-payment">Learn more.</a></param>
/// <returns></returns>
Task VerifyBankAccount(
ISubscriber subscriber,
(long, long) microdeposits);
}

View File

@ -1,7 +1,10 @@
using Bit.Core.Billing.Models;
using Bit.Core.Billing.Caches;
using Bit.Core.Billing.Models;
using Bit.Core.Entities;
using Bit.Core.Models.Business;
using Bit.Core.Enums;
using Bit.Core.Services;
using Bit.Core.Settings;
using Bit.Core.Utilities;
using Braintree;
using Microsoft.Extensions.Logging;
using Stripe;
@ -14,7 +17,9 @@ namespace Bit.Core.Billing.Services.Implementations;
public class SubscriberService(
IBraintreeGateway braintreeGateway,
IGlobalSettings globalSettings,
ILogger<SubscriberService> logger,
ISetupIntentCache setupIntentCache,
IStripeAdapter stripeAdapter) : ISubscriberService
{
public async Task CancelSubscription(
@ -132,6 +137,46 @@ public class SubscriberService(
}
}
public async Task<PaymentInformationDTO> GetPaymentInformation(
ISubscriber subscriber)
{
ArgumentNullException.ThrowIfNull(subscriber);
var customer = await GetCustomer(subscriber, new CustomerGetOptions
{
Expand = ["default_source", "invoice_settings.default_payment_method", "tax_ids"]
});
if (customer == null)
{
return null;
}
var accountCredit = customer.Balance * -1 / 100;
var paymentMethod = await GetMaskedPaymentMethodDTOAsync(subscriber.Id, customer);
var taxInformation = GetTaxInformationDTOFrom(customer);
return new PaymentInformationDTO(
accountCredit,
paymentMethod,
taxInformation);
}
public async Task<MaskedPaymentMethodDTO> GetPaymentMethod(
ISubscriber subscriber)
{
ArgumentNullException.ThrowIfNull(subscriber);
var customer = await GetCustomerOrThrow(subscriber, new CustomerGetOptions
{
Expand = ["default_source", "invoice_settings.default_payment_method"]
});
return await GetMaskedPaymentMethodDTOAsync(subscriber.Id, customer);
}
public async Task<Customer> GetCustomerOrThrow(
ISubscriber subscriber,
CustomerGetOptions customerGetOptions = null)
@ -240,6 +285,16 @@ public class SubscriberService(
}
}
public async Task<TaxInformationDTO> GetTaxInformation(
ISubscriber subscriber)
{
ArgumentNullException.ThrowIfNull(subscriber);
var customer = await GetCustomerOrThrow(subscriber, new CustomerGetOptions { Expand = ["tax_ids"] });
return GetTaxInformationDTOFrom(customer);
}
public async Task RemovePaymentMethod(
ISubscriber subscriber)
{
@ -332,113 +387,438 @@ public class SubscriberService(
}
}
public async Task<TaxInfo> GetTaxInformationAsync(ISubscriber subscriber)
public async Task UpdatePaymentMethod(
ISubscriber subscriber,
TokenizedPaymentMethodDTO tokenizedPaymentMethod)
{
ArgumentNullException.ThrowIfNull(subscriber);
ArgumentNullException.ThrowIfNull(tokenizedPaymentMethod);
if (string.IsNullOrWhiteSpace(subscriber.GatewayCustomerId))
var customer = await GetCustomerOrThrow(subscriber);
var (type, token) = tokenizedPaymentMethod;
if (string.IsNullOrEmpty(token))
{
logger.LogError("Cannot retrieve GatewayCustomerId for subscriber ({SubscriberID}) with no {FieldName}", subscriber.Id, nameof(subscriber.GatewaySubscriptionId));
logger.LogError("Updated payment method for ({SubscriberID}) must contain a token", subscriber.Id);
return null;
throw ContactSupport();
}
var customer = await GetCustomerOrThrow(subscriber, new CustomerGetOptions { Expand = ["tax_ids"] });
if (customer is null)
// ReSharper disable once SwitchStatementHandlesSomeKnownEnumValuesWithDefault
switch (type)
{
logger.LogError("Could not find Stripe customer ({CustomerID}) for subscriber ({SubscriberID})",
subscriber.GatewayCustomerId, subscriber.Id);
case PaymentMethodType.BankAccount:
{
var getSetupIntentsForUpdatedPaymentMethod = stripeAdapter.SetupIntentList(new SetupIntentListOptions
{
PaymentMethod = token
});
return null;
var getExistingSetupIntentsForCustomer = stripeAdapter.SetupIntentList(new SetupIntentListOptions
{
Customer = subscriber.GatewayCustomerId
});
// Find the setup intent for the incoming payment method token.
var setupIntentsForUpdatedPaymentMethod = await getSetupIntentsForUpdatedPaymentMethod;
if (setupIntentsForUpdatedPaymentMethod.Count != 1)
{
logger.LogError("There were more than 1 setup intents for subscriber's ({SubscriberID}) updated payment method", subscriber.Id);
throw ContactSupport();
}
var matchingSetupIntent = setupIntentsForUpdatedPaymentMethod.First();
// Find the customer's existing setup intents that should be cancelled.
var existingSetupIntentsForCustomer = (await getExistingSetupIntentsForCustomer)
.Where(si =>
si.Status is "requires_payment_method" or "requires_confirmation" or "requires_action");
// Store the incoming payment method's setup intent ID in the cache for the subscriber so it can be verified later.
await setupIntentCache.Set(subscriber.Id, matchingSetupIntent.Id);
// Cancel the customer's other open setup intents.
var postProcessing = existingSetupIntentsForCustomer.Select(si =>
stripeAdapter.SetupIntentCancel(si.Id,
new SetupIntentCancelOptions { CancellationReason = "abandoned" })).ToList();
// Remove the customer's other attached Stripe payment methods.
postProcessing.Add(RemoveStripePaymentMethodsAsync(customer));
// Remove the customer's Braintree customer ID.
postProcessing.Add(RemoveBraintreeCustomerIdAsync(customer));
await Task.WhenAll(postProcessing);
break;
}
case PaymentMethodType.Card:
{
var getExistingSetupIntentsForCustomer = stripeAdapter.SetupIntentList(new SetupIntentListOptions
{
Customer = subscriber.GatewayCustomerId
});
// Remove the customer's other attached Stripe payment methods.
await RemoveStripePaymentMethodsAsync(customer);
// Attach the incoming payment method.
await stripeAdapter.PaymentMethodAttachAsync(token,
new PaymentMethodAttachOptions { Customer = subscriber.GatewayCustomerId });
// Find the customer's existing setup intents that should be cancelled.
var existingSetupIntentsForCustomer = (await getExistingSetupIntentsForCustomer)
.Where(si =>
si.Status is "requires_payment_method" or "requires_confirmation" or "requires_action");
// Cancel the customer's other open setup intents.
var postProcessing = existingSetupIntentsForCustomer.Select(si =>
stripeAdapter.SetupIntentCancel(si.Id,
new SetupIntentCancelOptions { CancellationReason = "abandoned" })).ToList();
var metadata = customer.Metadata;
if (metadata.ContainsKey(BraintreeCustomerIdKey))
{
metadata[BraintreeCustomerIdKey] = null;
}
// Set the customer's default payment method in Stripe and remove their Braintree customer ID.
postProcessing.Add(stripeAdapter.CustomerUpdateAsync(subscriber.GatewayCustomerId, new CustomerUpdateOptions
{
InvoiceSettings = new CustomerInvoiceSettingsOptions
{
DefaultPaymentMethod = token
},
Metadata = metadata
}));
await Task.WhenAll(postProcessing);
break;
}
case PaymentMethodType.PayPal:
{
string braintreeCustomerId;
if (customer.Metadata != null)
{
var hasBraintreeCustomerId = customer.Metadata.TryGetValue(BraintreeCustomerIdKey, out braintreeCustomerId);
if (hasBraintreeCustomerId)
{
var braintreeCustomer = await braintreeGateway.Customer.FindAsync(braintreeCustomerId);
if (braintreeCustomer == null)
{
logger.LogError("Failed to retrieve Braintree customer ({BraintreeCustomerId}) when updating payment method for subscriber ({SubscriberID})", braintreeCustomerId, subscriber.Id);
throw ContactSupport();
}
await ReplaceBraintreePaymentMethodAsync(braintreeCustomer, token);
return;
}
}
braintreeCustomerId = await CreateBraintreeCustomerAsync(subscriber, token);
await AddBraintreeCustomerIdAsync(customer, braintreeCustomerId);
break;
}
default:
{
logger.LogError("Cannot update subscriber's ({SubscriberID}) payment method to type ({PaymentMethodType}) as it is not supported", subscriber.Id, type.ToString());
throw ContactSupport();
}
}
var address = customer.Address;
// Line1 is required, so if missing we're using the subscriber name
// see: https://stripe.com/docs/api/customers/create#create_customer-address-line1
if (address is not null && string.IsNullOrWhiteSpace(address.Line1))
{
address.Line1 = null;
}
return MapToTaxInfo(customer);
}
public async Task<BillingInfo.BillingSource> GetPaymentMethodAsync(ISubscriber subscriber)
public async Task UpdateTaxInformation(
ISubscriber subscriber,
TaxInformationDTO taxInformation)
{
ArgumentNullException.ThrowIfNull(subscriber);
var customer = await GetCustomerOrThrow(subscriber, GetCustomerPaymentOptions());
if (customer == null)
ArgumentNullException.ThrowIfNull(taxInformation);
var customer = await GetCustomerOrThrow(subscriber, new CustomerGetOptions
{
Expand = ["tax_ids"]
});
await stripeAdapter.CustomerUpdateAsync(customer.Id, new CustomerUpdateOptions
{
Address = new AddressOptions
{
Country = taxInformation.Country,
PostalCode = taxInformation.PostalCode,
Line1 = taxInformation.Line1 ?? string.Empty,
Line2 = taxInformation.Line2,
City = taxInformation.City,
State = taxInformation.State
}
});
if (!subscriber.IsUser())
{
var taxId = customer.TaxIds?.FirstOrDefault();
if (taxId != null)
{
await stripeAdapter.TaxIdDeleteAsync(customer.Id, taxId.Id);
}
var taxIdType = taxInformation.GetTaxIdType();
if (!string.IsNullOrWhiteSpace(taxInformation.TaxId) &&
!string.IsNullOrWhiteSpace(taxIdType))
{
await stripeAdapter.TaxIdCreateAsync(customer.Id, new TaxIdCreateOptions
{
Type = taxIdType,
Value = taxInformation.TaxId,
});
}
}
}
public async Task VerifyBankAccount(
ISubscriber subscriber,
(long, long) microdeposits)
{
ArgumentNullException.ThrowIfNull(subscriber);
var setupIntentId = await setupIntentCache.Get(subscriber.Id);
if (string.IsNullOrEmpty(setupIntentId))
{
logger.LogError("No setup intent ID exists to verify for subscriber with ID ({SubscriberID})", subscriber.Id);
throw ContactSupport();
}
var (amount1, amount2) = microdeposits;
await stripeAdapter.SetupIntentVerifyMicroDeposit(setupIntentId, new SetupIntentVerifyMicrodepositsOptions
{
Amounts = [amount1, amount2]
});
var setupIntent = await stripeAdapter.SetupIntentGet(setupIntentId);
await stripeAdapter.PaymentMethodAttachAsync(setupIntent.PaymentMethodId, new PaymentMethodAttachOptions
{
Customer = subscriber.GatewayCustomerId
});
await stripeAdapter.CustomerUpdateAsync(subscriber.GatewayCustomerId,
new CustomerUpdateOptions
{
InvoiceSettings = new CustomerInvoiceSettingsOptions
{
DefaultPaymentMethod = setupIntent.PaymentMethodId
}
});
}
#region Shared Utilities
private async Task AddBraintreeCustomerIdAsync(
Customer customer,
string braintreeCustomerId)
{
var metadata = customer.Metadata ?? new Dictionary<string, string>();
metadata[BraintreeCustomerIdKey] = braintreeCustomerId;
await stripeAdapter.CustomerUpdateAsync(customer.Id, new CustomerUpdateOptions
{
Metadata = metadata
});
}
private async Task<string> CreateBraintreeCustomerAsync(
ISubscriber subscriber,
string paymentMethodNonce)
{
var braintreeCustomerId =
subscriber.BraintreeCustomerIdPrefix() +
subscriber.Id.ToString("N").ToLower() +
CoreHelpers.RandomString(3, upper: false, numeric: false);
var customerResult = await braintreeGateway.Customer.CreateAsync(new CustomerRequest
{
Id = braintreeCustomerId,
CustomFields = new Dictionary<string, string>
{
[subscriber.BraintreeIdField()] = subscriber.Id.ToString(),
[subscriber.BraintreeCloudRegionField()] = globalSettings.BaseServiceUri.CloudRegion
},
Email = subscriber.BillingEmailAddress(),
PaymentMethodNonce = paymentMethodNonce,
});
if (customerResult.IsSuccess())
{
return customerResult.Target.Id;
}
logger.LogError("Failed to create Braintree customer for subscriber ({ID})", subscriber.Id);
throw ContactSupport();
}
private async Task<MaskedPaymentMethodDTO> GetMaskedPaymentMethodDTOAsync(
Guid subscriberId,
Customer customer)
{
if (customer.Metadata != null)
{
var hasBraintreeCustomerId = customer.Metadata.TryGetValue(BraintreeCustomerIdKey, out var braintreeCustomerId);
if (hasBraintreeCustomerId)
{
var braintreeCustomer = await braintreeGateway.Customer.FindAsync(braintreeCustomerId);
return MaskedPaymentMethodDTO.From(braintreeCustomer);
}
}
var attachedPaymentMethodDTO = MaskedPaymentMethodDTO.From(customer);
if (attachedPaymentMethodDTO != null)
{
return attachedPaymentMethodDTO;
}
/*
* attachedPaymentMethodDTO being null represents a case where we could be looking for the SetupIntent for an unverified "us_bank_account".
* We store the ID of this SetupIntent in the cache when we originally update the payment method.
*/
var setupIntentId = await setupIntentCache.Get(subscriberId);
if (string.IsNullOrEmpty(setupIntentId))
{
logger.LogError("Could not find Stripe customer ({CustomerID}) for subscriber ({SubscriberID})",
subscriber.GatewayCustomerId, subscriber.Id);
return null;
}
if (customer.Metadata?.ContainsKey("btCustomerId") ?? false)
var setupIntent = await stripeAdapter.SetupIntentGet(setupIntentId, new SetupIntentGetOptions
{
try
Expand = ["payment_method"]
});
return MaskedPaymentMethodDTO.From(setupIntent);
}
private static TaxInformationDTO GetTaxInformationDTOFrom(
Customer customer)
{
if (customer.Address == null)
{
return null;
}
return new TaxInformationDTO(
customer.Address.Country,
customer.Address.PostalCode,
customer.TaxIds?.FirstOrDefault()?.Value,
customer.Address.Line1,
customer.Address.Line2,
customer.Address.City,
customer.Address.State);
}
private async Task RemoveBraintreeCustomerIdAsync(
Customer customer)
{
var metadata = customer.Metadata ?? new Dictionary<string, string>();
if (metadata.ContainsKey(BraintreeCustomerIdKey))
{
metadata[BraintreeCustomerIdKey] = null;
await stripeAdapter.CustomerUpdateAsync(customer.Id, new CustomerUpdateOptions
{
var braintreeCustomer = await braintreeGateway.Customer.FindAsync(
customer.Metadata["btCustomerId"]);
if (braintreeCustomer?.DefaultPaymentMethod != null)
Metadata = metadata
});
}
}
private async Task RemoveStripePaymentMethodsAsync(
Customer customer)
{
if (customer.Sources != null && customer.Sources.Any())
{
foreach (var source in customer.Sources)
{
switch (source)
{
return new BillingInfo.BillingSource(
braintreeCustomer.DefaultPaymentMethod);
case BankAccount:
await stripeAdapter.BankAccountDeleteAsync(customer.Id, source.Id);
break;
case Card:
await stripeAdapter.CardDeleteAsync(customer.Id, source.Id);
break;
}
}
catch (Braintree.Exceptions.NotFoundException ex)
}
var paymentMethods = await stripeAdapter.CustomerListPaymentMethods(customer.Id);
await Task.WhenAll(paymentMethods.Select(pm => stripeAdapter.PaymentMethodDetachAsync(pm.Id)));
}
private async Task ReplaceBraintreePaymentMethodAsync(
Braintree.Customer customer,
string defaultPaymentMethodToken)
{
var existingDefaultPaymentMethod = customer.DefaultPaymentMethod;
var createPaymentMethodResult = await braintreeGateway.PaymentMethod.CreateAsync(new PaymentMethodRequest
{
CustomerId = customer.Id,
PaymentMethodNonce = defaultPaymentMethodToken
});
if (!createPaymentMethodResult.IsSuccess())
{
logger.LogError("Failed to replace payment method for Braintree customer ({ID}) - Creation of new payment method failed | Error: {Error}", customer.Id, createPaymentMethodResult.Message);
throw ContactSupport();
}
var updateCustomerResult = await braintreeGateway.Customer.UpdateAsync(
customer.Id,
new CustomerRequest { DefaultPaymentMethodToken = createPaymentMethodResult.Target.Token });
if (!updateCustomerResult.IsSuccess())
{
logger.LogError("Failed to replace payment method for Braintree customer ({ID}) - Customer update failed | Error: {Error}",
customer.Id, updateCustomerResult.Message);
await braintreeGateway.PaymentMethod.DeleteAsync(createPaymentMethodResult.Target.Token);
throw ContactSupport();
}
if (existingDefaultPaymentMethod != null)
{
var deletePaymentMethodResult = await braintreeGateway.PaymentMethod.DeleteAsync(existingDefaultPaymentMethod.Token);
if (!deletePaymentMethodResult.IsSuccess())
{
logger.LogError("An error occurred while trying to retrieve braintree customer ({SubscriberID}): {Error}", subscriber.Id, ex.Message);
logger.LogWarning(
"Failed to delete replaced payment method for Braintree customer ({ID}) - outdated payment method still exists | Error: {Error}",
customer.Id, deletePaymentMethodResult.Message);
}
}
if (customer.InvoiceSettings?.DefaultPaymentMethod?.Type == "card")
{
return new BillingInfo.BillingSource(
customer.InvoiceSettings.DefaultPaymentMethod);
}
if (customer.DefaultSource != null &&
(customer.DefaultSource is Card || customer.DefaultSource is BankAccount))
{
return new BillingInfo.BillingSource(customer.DefaultSource);
}
var paymentMethod = GetLatestCardPaymentMethod(customer.Id);
return paymentMethod != null ? new BillingInfo.BillingSource(paymentMethod) : null;
}
private static CustomerGetOptions GetCustomerPaymentOptions()
{
var customerOptions = new CustomerGetOptions();
customerOptions.AddExpand("default_source");
customerOptions.AddExpand("invoice_settings.default_payment_method");
return customerOptions;
}
private Stripe.PaymentMethod GetLatestCardPaymentMethod(string customerId)
{
var cardPaymentMethods = stripeAdapter.PaymentMethodListAutoPaging(
new PaymentMethodListOptions { Customer = customerId, Type = "card" });
return cardPaymentMethods.MaxBy(m => m.Created);
}
private TaxInfo MapToTaxInfo(Customer customer)
{
var address = customer.Address;
var taxId = customer.TaxIds?.FirstOrDefault();
return new TaxInfo
{
TaxIdNumber = taxId?.Value,
BillingAddressLine1 = address?.Line1,
BillingAddressLine2 = address?.Line2,
BillingAddressCity = address?.City,
BillingAddressState = address?.State,
BillingAddressPostalCode = address?.PostalCode,
BillingAddressCountry = address?.Country,
};
}
#endregion
}

View File

@ -49,6 +49,7 @@ public interface IPaymentService
Task<BillingInfo> GetBillingHistoryAsync(ISubscriber subscriber);
Task<BillingInfo> GetBillingBalanceAndSourceAsync(ISubscriber subscriber);
Task<SubscriptionInfo> GetSubscriptionAsync(ISubscriber subscriber);
Task<TaxInfo> GetTaxInfoAsync(ISubscriber subscriber);
Task SaveTaxInfoAsync(ISubscriber subscriber, TaxInfo taxInfo);
Task<TaxRate> CreateTaxRateAsync(TaxRate taxRate);
Task UpdateTaxRateAsync(TaxRate taxRate);

View File

@ -9,6 +9,7 @@ public interface IStripeAdapter
Task<Stripe.Customer> CustomerGetAsync(string id, Stripe.CustomerGetOptions options = null);
Task<Stripe.Customer> CustomerUpdateAsync(string id, Stripe.CustomerUpdateOptions options = null);
Task<Stripe.Customer> CustomerDeleteAsync(string id);
Task<List<PaymentMethod>> CustomerListPaymentMethods(string id, CustomerListPaymentMethodsOptions options = null);
Task<Stripe.Subscription> SubscriptionCreateAsync(Stripe.SubscriptionCreateOptions subscriptionCreateOptions);
Task<Stripe.Subscription> SubscriptionGetAsync(string id, Stripe.SubscriptionGetOptions options = null);
Task<List<Stripe.Subscription>> SubscriptionListAsync(StripeSubscriptionListOptions subscriptionSearchOptions);
@ -38,5 +39,10 @@ public interface IStripeAdapter
Task<Stripe.BankAccount> BankAccountCreateAsync(string customerId, Stripe.BankAccountCreateOptions options = null);
Task<Stripe.BankAccount> BankAccountDeleteAsync(string customerId, string bankAccount, Stripe.BankAccountDeleteOptions options = null);
Task<Stripe.StripeList<Stripe.Price>> PriceListAsync(Stripe.PriceListOptions options = null);
Task<SetupIntent> SetupIntentCreate(SetupIntentCreateOptions options);
Task<List<SetupIntent>> SetupIntentList(SetupIntentListOptions options);
Task SetupIntentCancel(string id, SetupIntentCancelOptions options = null);
Task<SetupIntent> SetupIntentGet(string id, SetupIntentGetOptions options = null);
Task SetupIntentVerifyMicroDeposit(string id, SetupIntentVerifyMicrodepositsOptions options);
Task<List<Stripe.TestHelpers.TestClock>> TestClockListAsync();
}

View File

@ -16,6 +16,7 @@ public class StripeAdapter : IStripeAdapter
private readonly Stripe.CardService _cardService;
private readonly Stripe.BankAccountService _bankAccountService;
private readonly Stripe.PriceService _priceService;
private readonly Stripe.SetupIntentService _setupIntentService;
private readonly Stripe.TestHelpers.TestClockService _testClockService;
public StripeAdapter()
@ -31,6 +32,7 @@ public class StripeAdapter : IStripeAdapter
_cardService = new Stripe.CardService();
_bankAccountService = new Stripe.BankAccountService();
_priceService = new Stripe.PriceService();
_setupIntentService = new SetupIntentService();
_testClockService = new Stripe.TestHelpers.TestClockService();
}
@ -54,6 +56,13 @@ public class StripeAdapter : IStripeAdapter
return _customerService.DeleteAsync(id);
}
public async Task<List<PaymentMethod>> CustomerListPaymentMethods(string id,
CustomerListPaymentMethodsOptions options = null)
{
var paymentMethods = await _customerService.ListPaymentMethodsAsync(id, options);
return paymentMethods.Data;
}
public Task<Stripe.Subscription> SubscriptionCreateAsync(Stripe.SubscriptionCreateOptions options)
{
return _subscriptionService.CreateAsync(options);
@ -222,6 +231,25 @@ public class StripeAdapter : IStripeAdapter
return await _priceService.ListAsync(options);
}
public Task<SetupIntent> SetupIntentCreate(SetupIntentCreateOptions options)
=> _setupIntentService.CreateAsync(options);
public async Task<List<SetupIntent>> SetupIntentList(SetupIntentListOptions options)
{
var setupIntents = await _setupIntentService.ListAsync(options);
return setupIntents.Data;
}
public Task SetupIntentCancel(string id, SetupIntentCancelOptions options = null)
=> _setupIntentService.CancelAsync(id, options);
public Task<SetupIntent> SetupIntentGet(string id, SetupIntentGetOptions options = null)
=> _setupIntentService.GetAsync(id, options);
public Task SetupIntentVerifyMicroDeposit(string id, SetupIntentVerifyMicrodepositsOptions options)
=> _setupIntentService.VerifyMicrodepositsAsync(id, options);
public async Task<List<Stripe.TestHelpers.TestClock>> TestClockListAsync()
{
var items = new List<Stripe.TestHelpers.TestClock>();

View File

@ -1651,6 +1651,43 @@ public class StripePaymentService : IPaymentService
return subscriptionInfo;
}
public async Task<TaxInfo> GetTaxInfoAsync(ISubscriber subscriber)
{
if (subscriber == null || string.IsNullOrWhiteSpace(subscriber.GatewayCustomerId))
{
return null;
}
var customer = await _stripeAdapter.CustomerGetAsync(subscriber.GatewayCustomerId,
new CustomerGetOptions { Expand = ["tax_ids"] });
if (customer == null)
{
return null;
}
var address = customer.Address;
var taxId = customer.TaxIds?.FirstOrDefault();
// Line1 is required, so if missing we're using the subscriber name
// see: https://stripe.com/docs/api/customers/create#create_customer-address-line1
if (address != null && string.IsNullOrWhiteSpace(address.Line1))
{
address.Line1 = null;
}
return new TaxInfo
{
TaxIdNumber = taxId?.Value,
BillingAddressLine1 = address?.Line1,
BillingAddressLine2 = address?.Line2,
BillingAddressCity = address?.City,
BillingAddressState = address?.State,
BillingAddressPostalCode = address?.PostalCode,
BillingAddressCountry = address?.Country,
};
}
public async Task SaveTaxInfoAsync(ISubscriber subscriber, TaxInfo taxInfo)
{
if (subscriber != null && !string.IsNullOrWhiteSpace(subscriber.GatewayCustomerId))

View File

@ -1,6 +1,11 @@
using Bit.Api.Billing.Controllers;
using Bit.Api.Billing.Models.Requests;
using Bit.Api.Billing.Models.Responses;
using Bit.Core;
using Bit.Core.AdminConsole.Entities.Provider;
using Bit.Core.AdminConsole.Enums.Provider;
using Bit.Core.AdminConsole.Repositories;
using Bit.Core.Billing.Constants;
using Bit.Core.Billing.Models;
using Bit.Core.Billing.Services;
using Bit.Core.Context;
@ -21,6 +26,7 @@ namespace Bit.Api.Test.Billing.Controllers;
[SutProviderCustomize]
public class ProviderBillingControllerTests
{
#region GetSubscriptionAsync
[Theory, BitAutoData]
public async Task GetSubscriptionAsync_FFDisabled_NotFound(
Guid providerId,
@ -35,33 +41,14 @@ public class ProviderBillingControllerTests
}
[Theory, BitAutoData]
public async Task GetSubscriptionAsync_NotProviderAdmin_Unauthorized(
public async Task GetSubscriptionAsync_NullProvider_NotFound(
Guid providerId,
SutProvider<ProviderBillingController> sutProvider)
{
sutProvider.GetDependency<IFeatureService>().IsEnabled(FeatureFlagKeys.EnableConsolidatedBilling)
.Returns(true);
sutProvider.GetDependency<ICurrentContext>().ProviderProviderAdmin(providerId)
.Returns(false);
var result = await sutProvider.Sut.GetSubscriptionAsync(providerId);
Assert.IsType<UnauthorizedHttpResult>(result);
}
[Theory, BitAutoData]
public async Task GetSubscriptionAsync_NoSubscriptionData_NotFound(
Guid providerId,
SutProvider<ProviderBillingController> sutProvider)
{
sutProvider.GetDependency<IFeatureService>().IsEnabled(FeatureFlagKeys.EnableConsolidatedBilling)
.Returns(true);
sutProvider.GetDependency<ICurrentContext>().ProviderProviderAdmin(providerId)
.Returns(true);
sutProvider.GetDependency<IProviderBillingService>().GetSubscriptionDTO(providerId).ReturnsNull();
sutProvider.GetDependency<IProviderRepository>().GetByIdAsync(providerId).ReturnsNull();
var result = await sutProvider.Sut.GetSubscriptionAsync(providerId);
@ -69,20 +56,69 @@ public class ProviderBillingControllerTests
}
[Theory, BitAutoData]
public async Task GetSubscriptionAsync_OK(
Guid providerId,
public async Task GetSubscriptionAsync_NotProviderAdmin_Unauthorized(
Provider provider,
SutProvider<ProviderBillingController> sutProvider)
{
sutProvider.GetDependency<IFeatureService>().IsEnabled(FeatureFlagKeys.EnableConsolidatedBilling)
.Returns(true);
sutProvider.GetDependency<ICurrentContext>().ProviderProviderAdmin(providerId)
sutProvider.GetDependency<IProviderRepository>().GetByIdAsync(provider.Id).Returns(provider);
sutProvider.GetDependency<ICurrentContext>().ProviderProviderAdmin(provider.Id)
.Returns(false);
var result = await sutProvider.Sut.GetSubscriptionAsync(provider.Id);
Assert.IsType<UnauthorizedHttpResult>(result);
}
[Theory, BitAutoData]
public async Task GetSubscriptionAsync_ProviderNotBillable_Unauthorized(
Provider provider,
SutProvider<ProviderBillingController> sutProvider)
{
sutProvider.GetDependency<IFeatureService>().IsEnabled(FeatureFlagKeys.EnableConsolidatedBilling)
.Returns(true);
var configuredProviderPlanDTOList = new List<ConfiguredProviderPlanDTO>
provider.Type = ProviderType.Reseller;
provider.Status = ProviderStatusType.Created;
sutProvider.GetDependency<IProviderRepository>().GetByIdAsync(provider.Id).Returns(provider);
sutProvider.GetDependency<ICurrentContext>().ProviderProviderAdmin(provider.Id)
.Returns(false);
var result = await sutProvider.Sut.GetSubscriptionAsync(provider.Id);
Assert.IsType<UnauthorizedHttpResult>(result);
}
[Theory, BitAutoData]
public async Task GetSubscriptionAsync_NullConsolidatedBillingSubscription_NotFound(
Provider provider,
SutProvider<ProviderBillingController> sutProvider)
{
ConfigureStableInputs(provider, sutProvider);
sutProvider.GetDependency<IProviderBillingService>().GetConsolidatedBillingSubscription(provider).ReturnsNull();
var result = await sutProvider.Sut.GetSubscriptionAsync(provider.Id);
Assert.IsType<NotFound>(result);
}
[Theory, BitAutoData]
public async Task GetSubscriptionAsync_Ok(
Provider provider,
SutProvider<ProviderBillingController> sutProvider)
{
ConfigureStableInputs(provider, sutProvider);
var configuredProviderPlans = new List<ConfiguredProviderPlanDTO>
{
new (Guid.NewGuid(), providerId, PlanType.TeamsMonthly, 50, 10, 30),
new (Guid.NewGuid(), providerId, PlanType.EnterpriseMonthly, 100, 0, 90)
new (Guid.NewGuid(), provider.Id, PlanType.TeamsMonthly, 50, 10, 30),
new (Guid.NewGuid(), provider.Id , PlanType.EnterpriseMonthly, 100, 0, 90)
};
var subscription = new Subscription
@ -92,25 +128,25 @@ public class ProviderBillingControllerTests
Customer = new Customer { Discount = new Discount { Coupon = new Coupon { PercentOff = 10 } } }
};
var providerSubscriptionDTO = new ProviderSubscriptionDTO(
configuredProviderPlanDTOList,
var consolidatedBillingSubscription = new ConsolidatedBillingSubscriptionDTO(
configuredProviderPlans,
subscription);
sutProvider.GetDependency<IProviderBillingService>().GetSubscriptionDTO(providerId)
.Returns(providerSubscriptionDTO);
sutProvider.GetDependency<IProviderBillingService>().GetConsolidatedBillingSubscription(provider)
.Returns(consolidatedBillingSubscription);
var result = await sutProvider.Sut.GetSubscriptionAsync(providerId);
var result = await sutProvider.Sut.GetSubscriptionAsync(provider.Id);
Assert.IsType<Ok<ProviderSubscriptionResponse>>(result);
Assert.IsType<Ok<ConsolidatedBillingSubscriptionResponse>>(result);
var providerSubscriptionResponse = ((Ok<ProviderSubscriptionResponse>)result).Value;
var response = ((Ok<ConsolidatedBillingSubscriptionResponse>)result).Value;
Assert.Equal(providerSubscriptionResponse.Status, subscription.Status);
Assert.Equal(providerSubscriptionResponse.CurrentPeriodEndDate, subscription.CurrentPeriodEnd);
Assert.Equal(providerSubscriptionResponse.DiscountPercentage, subscription.Customer!.Discount!.Coupon!.PercentOff);
Assert.Equal(response.Status, subscription.Status);
Assert.Equal(response.CurrentPeriodEndDate, subscription.CurrentPeriodEnd);
Assert.Equal(response.DiscountPercentage, subscription.Customer!.Discount!.Coupon!.PercentOff);
var teamsPlan = StaticStore.GetPlan(PlanType.TeamsMonthly);
var providerTeamsPlan = providerSubscriptionResponse.Plans.FirstOrDefault(plan => plan.PlanName == teamsPlan.Name);
var providerTeamsPlan = response.Plans.FirstOrDefault(plan => plan.PlanName == teamsPlan.Name);
Assert.NotNull(providerTeamsPlan);
Assert.Equal(50, providerTeamsPlan.SeatMinimum);
Assert.Equal(10, providerTeamsPlan.PurchasedSeats);
@ -119,7 +155,7 @@ public class ProviderBillingControllerTests
Assert.Equal("Monthly", providerTeamsPlan.Cadence);
var enterprisePlan = StaticStore.GetPlan(PlanType.EnterpriseMonthly);
var providerEnterprisePlan = providerSubscriptionResponse.Plans.FirstOrDefault(plan => plan.PlanName == enterprisePlan.Name);
var providerEnterprisePlan = response.Plans.FirstOrDefault(plan => plan.PlanName == enterprisePlan.Name);
Assert.NotNull(providerEnterprisePlan);
Assert.Equal(100, providerEnterprisePlan.SeatMinimum);
Assert.Equal(0, providerEnterprisePlan.PurchasedSeats);
@ -127,4 +163,225 @@ public class ProviderBillingControllerTests
Assert.Equal(100 * enterprisePlan.PasswordManager.SeatPrice, providerEnterprisePlan.Cost);
Assert.Equal("Monthly", providerEnterprisePlan.Cadence);
}
#endregion
#region GetPaymentInformationAsync
[Theory, BitAutoData]
public async Task GetPaymentInformation_PaymentInformationNull_NotFound(
Provider provider,
SutProvider<ProviderBillingController> sutProvider)
{
ConfigureStableInputs(provider, sutProvider);
sutProvider.GetDependency<ISubscriberService>().GetPaymentInformation(provider).ReturnsNull();
var result = await sutProvider.Sut.GetSubscriptionAsync(provider.Id);
Assert.IsType<NotFound>(result);
}
[Theory, BitAutoData]
public async Task GetPaymentInformation_Ok(
Provider provider,
SutProvider<ProviderBillingController> sutProvider)
{
ConfigureStableInputs(provider, sutProvider);
var maskedPaymentMethod = new MaskedPaymentMethodDTO(PaymentMethodType.Card, "VISA *1234", false);
var taxInformation =
new TaxInformationDTO("US", "12345", "123456789", "123 Example St.", null, "Example Town", "NY");
sutProvider.GetDependency<ISubscriberService>().GetPaymentInformation(provider).Returns(new PaymentInformationDTO(
100,
maskedPaymentMethod,
taxInformation));
var result = await sutProvider.Sut.GetPaymentInformationAsync(provider.Id);
Assert.IsType<Ok<PaymentInformationResponse>>(result);
var response = ((Ok<PaymentInformationResponse>)result).Value;
Assert.Equal(100, response.AccountCredit);
Assert.Equal(maskedPaymentMethod.Description, response.PaymentMethod.Description);
Assert.Equal(taxInformation.TaxId, response.TaxInformation.TaxId);
}
#endregion
#region GetPaymentMethodAsync
[Theory, BitAutoData]
public async Task GetPaymentMethod_PaymentMethodNull_NotFound(
Provider provider,
SutProvider<ProviderBillingController> sutProvider)
{
ConfigureStableInputs(provider, sutProvider);
sutProvider.GetDependency<ISubscriberService>().GetPaymentMethod(provider).ReturnsNull();
var result = await sutProvider.Sut.GetSubscriptionAsync(provider.Id);
Assert.IsType<NotFound>(result);
}
[Theory, BitAutoData]
public async Task GetPaymentMethod_Ok(
Provider provider,
SutProvider<ProviderBillingController> sutProvider)
{
ConfigureStableInputs(provider, sutProvider);
sutProvider.GetDependency<ISubscriberService>().GetPaymentMethod(provider).Returns(new MaskedPaymentMethodDTO(
PaymentMethodType.Card, "Description", false));
var result = await sutProvider.Sut.GetPaymentMethodAsync(provider.Id);
Assert.IsType<Ok<MaskedPaymentMethodResponse>>(result);
var response = ((Ok<MaskedPaymentMethodResponse>)result).Value;
Assert.Equal(PaymentMethodType.Card, response.Type);
Assert.Equal("Description", response.Description);
Assert.False(response.NeedsVerification);
}
#endregion
#region GetTaxInformationAsync
[Theory, BitAutoData]
public async Task GetTaxInformation_TaxInformationNull_NotFound(
Provider provider,
SutProvider<ProviderBillingController> sutProvider)
{
ConfigureStableInputs(provider, sutProvider);
sutProvider.GetDependency<ISubscriberService>().GetTaxInformation(provider).ReturnsNull();
var result = await sutProvider.Sut.GetSubscriptionAsync(provider.Id);
Assert.IsType<NotFound>(result);
}
[Theory, BitAutoData]
public async Task GetTaxInformation_Ok(
Provider provider,
SutProvider<ProviderBillingController> sutProvider)
{
ConfigureStableInputs(provider, sutProvider);
sutProvider.GetDependency<ISubscriberService>().GetTaxInformation(provider).Returns(new TaxInformationDTO(
"US",
"12345",
"123456789",
"123 Example St.",
null,
"Example Town",
"NY"));
var result = await sutProvider.Sut.GetTaxInformationAsync(provider.Id);
Assert.IsType<Ok<TaxInformationResponse>>(result);
var response = ((Ok<TaxInformationResponse>)result).Value;
Assert.Equal("US", response.Country);
Assert.Equal("12345", response.PostalCode);
Assert.Equal("123456789", response.TaxId);
Assert.Equal("123 Example St.", response.Line1);
Assert.Null(response.Line2);
Assert.Equal("Example Town", response.City);
Assert.Equal("NY", response.State);
}
#endregion
#region UpdatePaymentMethodAsync
[Theory, BitAutoData]
public async Task UpdatePaymentMethod_Ok(
Provider provider,
TokenizedPaymentMethodRequestBody requestBody,
SutProvider<ProviderBillingController> sutProvider)
{
ConfigureStableInputs(provider, sutProvider);
await sutProvider.Sut.UpdatePaymentMethodAsync(provider.Id, requestBody);
await sutProvider.GetDependency<ISubscriberService>().Received(1).UpdatePaymentMethod(
provider, Arg.Is<TokenizedPaymentMethodDTO>(
options => options.Type == requestBody.Type && options.Token == requestBody.Token));
await sutProvider.GetDependency<IStripeAdapter>().Received(1).SubscriptionUpdateAsync(
provider.GatewaySubscriptionId, Arg.Is<SubscriptionUpdateOptions>(
options => options.CollectionMethod == StripeConstants.CollectionMethod.ChargeAutomatically));
}
#endregion
#region UpdateTaxInformationAsync
[Theory, BitAutoData]
public async Task UpdateTaxInformation_Ok(
Provider provider,
TaxInformationRequestBody requestBody,
SutProvider<ProviderBillingController> sutProvider)
{
ConfigureStableInputs(provider, sutProvider);
await sutProvider.Sut.UpdateTaxInformationAsync(provider.Id, requestBody);
await sutProvider.GetDependency<ISubscriberService>().Received(1).UpdateTaxInformation(
provider, Arg.Is<TaxInformationDTO>(
options =>
options.Country == requestBody.Country &&
options.PostalCode == requestBody.PostalCode &&
options.TaxId == requestBody.TaxId &&
options.Line1 == requestBody.Line1 &&
options.Line2 == requestBody.Line2 &&
options.City == requestBody.City &&
options.State == requestBody.State));
}
#endregion
#region VerifyBankAccount
[Theory, BitAutoData]
public async Task VerifyBankAccount_Ok(
Provider provider,
VerifyBankAccountRequestBody requestBody,
SutProvider<ProviderBillingController> sutProvider)
{
ConfigureStableInputs(provider, sutProvider);
var result = await sutProvider.Sut.VerifyBankAccountAsync(provider.Id, requestBody);
Assert.IsType<Ok>(result);
await sutProvider.GetDependency<ISubscriberService>().Received(1).VerifyBankAccount(
provider,
(requestBody.Amount1, requestBody.Amount2));
}
#endregion
private static void ConfigureStableInputs(
Provider provider,
SutProvider<ProviderBillingController> sutProvider)
{
sutProvider.GetDependency<IFeatureService>().IsEnabled(FeatureFlagKeys.EnableConsolidatedBilling)
.Returns(true);
provider.Type = ProviderType.Msp;
provider.Status = ProviderStatusType.Billable;
sutProvider.GetDependency<IProviderRepository>().GetByIdAsync(provider.Id).Returns(provider);
sutProvider.GetDependency<ICurrentContext>().ProviderProviderAdmin(provider.Id)
.Returns(true);
}
}

File diff suppressed because it is too large Load Diff