1
0
mirror of https://github.com/bitwarden/server.git synced 2024-11-21 12:05:42 +01:00

[AC-1923] Add endpoint to create client organization (#3977)

* Add new endpoint for creating client organizations in consolidated billing

* Create empty org and then assign seats for code re-use

* Fixes made from debugging client side

* few more small fixes

* Vincent's feedback
This commit is contained in:
Alex Morask 2024-04-16 13:55:00 -04:00 committed by GitHub
parent 73e049f878
commit c4ba0dc2a5
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
36 changed files with 1462 additions and 441 deletions

View File

@ -38,13 +38,14 @@ public class ProviderService : IProviderService
private readonly IOrganizationService _organizationService;
private readonly ICurrentContext _currentContext;
private readonly IStripeAdapter _stripeAdapter;
private readonly IFeatureService _featureService;
public ProviderService(IProviderRepository providerRepository, IProviderUserRepository providerUserRepository,
IProviderOrganizationRepository providerOrganizationRepository, IUserRepository userRepository,
IUserService userService, IOrganizationService organizationService, IMailService mailService,
IDataProtectionProvider dataProtectionProvider, IEventService eventService,
IOrganizationRepository organizationRepository, GlobalSettings globalSettings,
ICurrentContext currentContext, IStripeAdapter stripeAdapter)
ICurrentContext currentContext, IStripeAdapter stripeAdapter, IFeatureService featureService)
{
_providerRepository = providerRepository;
_providerUserRepository = providerUserRepository;
@ -59,6 +60,7 @@ public class ProviderService : IProviderService
_dataProtector = dataProtectionProvider.CreateProtector("ProviderServiceDataProtector");
_currentContext = currentContext;
_stripeAdapter = stripeAdapter;
_featureService = featureService;
}
public async Task<Provider> CompleteSetupAsync(Provider provider, Guid ownerUserId, string token, string key)
@ -360,6 +362,7 @@ public class ProviderService : IProviderService
}
var organization = await _organizationRepository.GetByIdAsync(organizationId);
ThrowOnInvalidPlanType(organization.PlanType);
if (organization.UseSecretsManager)
@ -507,9 +510,13 @@ public class ProviderService : IProviderService
public async Task<ProviderOrganization> CreateOrganizationAsync(Guid providerId,
OrganizationSignup organizationSignup, string clientOwnerEmail, User user)
{
ThrowOnInvalidPlanType(organizationSignup.Plan);
var consolidatedBillingEnabled = _featureService.IsEnabled(FeatureFlagKeys.EnableConsolidatedBilling);
var (organization, _, defaultCollection) = await _organizationService.SignUpAsync(organizationSignup, true);
ThrowOnInvalidPlanType(organizationSignup.Plan, consolidatedBillingEnabled);
var (organization, _, defaultCollection) = consolidatedBillingEnabled
? await _organizationService.SignupClientAsync(organizationSignup)
: await _organizationService.SignUpAsync(organizationSignup, true);
var providerOrganization = new ProviderOrganization
{
@ -611,8 +618,13 @@ public class ProviderService : IProviderService
return confirmedOwnersIds.Except(providerUserIds).Any();
}
private void ThrowOnInvalidPlanType(PlanType requestedType)
private void ThrowOnInvalidPlanType(PlanType requestedType, bool consolidatedBillingEnabled = false)
{
if (consolidatedBillingEnabled && requestedType is not (PlanType.TeamsMonthly or PlanType.EnterpriseMonthly))
{
throw new BadRequestException($"Providers cannot manage organizations with the plan type {requestedType}. Only Teams (Monthly) and Enterprise (Monthly) are allowed.");
}
if (ProviderDisallowedOrganizationTypes.Contains(requestedType))
{
throw new BadRequestException($"Providers cannot manage organizations with the requested plan type ({requestedType}). Only Teams and Enterprise accounts are allowed.");

View File

@ -1,5 +1,6 @@
using Bit.Commercial.Core.AdminConsole.Services;
using Bit.Commercial.Core.Test.AdminConsole.AutoFixture;
using Bit.Core;
using Bit.Core.AdminConsole.Entities;
using Bit.Core.AdminConsole.Entities.Provider;
using Bit.Core.AdminConsole.Enums.Provider;
@ -638,6 +639,79 @@ public class ProviderServiceTests
t.First().Item2 == null));
}
[Theory, OrganizationCustomize(FlexibleCollections = false), BitAutoData]
public async Task CreateOrganizationAsync_ConsolidatedBillingEnabled_InvalidPlanType_ThrowsBadRequestException(
Provider provider,
OrganizationSignup organizationSignup,
Organization organization,
string clientOwnerEmail,
User user,
SutProvider<ProviderService> sutProvider)
{
sutProvider.GetDependency<IFeatureService>().IsEnabled(FeatureFlagKeys.EnableConsolidatedBilling).Returns(true);
organizationSignup.Plan = PlanType.EnterpriseAnnually;
sutProvider.GetDependency<IProviderRepository>().GetByIdAsync(provider.Id).Returns(provider);
var providerOrganizationRepository = sutProvider.GetDependency<IProviderOrganizationRepository>();
sutProvider.GetDependency<IOrganizationService>().SignupClientAsync(organizationSignup)
.Returns((organization, null as OrganizationUser, new Collection()));
await Assert.ThrowsAsync<BadRequestException>(() =>
sutProvider.Sut.CreateOrganizationAsync(provider.Id, organizationSignup, clientOwnerEmail, user));
await providerOrganizationRepository.DidNotReceiveWithAnyArgs().CreateAsync(default);
}
[Theory, OrganizationCustomize(FlexibleCollections = false), BitAutoData]
public async Task CreateOrganizationAsync_ConsolidatedBillingEnabled_InvokeSignupClientAsync(
Provider provider,
OrganizationSignup organizationSignup,
Organization organization,
string clientOwnerEmail,
User user,
SutProvider<ProviderService> sutProvider)
{
sutProvider.GetDependency<IFeatureService>().IsEnabled(FeatureFlagKeys.EnableConsolidatedBilling).Returns(true);
organizationSignup.Plan = PlanType.EnterpriseMonthly;
sutProvider.GetDependency<IProviderRepository>().GetByIdAsync(provider.Id).Returns(provider);
var providerOrganizationRepository = sutProvider.GetDependency<IProviderOrganizationRepository>();
sutProvider.GetDependency<IOrganizationService>().SignupClientAsync(organizationSignup)
.Returns((organization, null as OrganizationUser, new Collection()));
var providerOrganization = await sutProvider.Sut.CreateOrganizationAsync(provider.Id, organizationSignup, clientOwnerEmail, user);
await providerOrganizationRepository.Received(1).CreateAsync(Arg.Is<ProviderOrganization>(
po =>
po.ProviderId == provider.Id &&
po.OrganizationId == organization.Id));
await sutProvider.GetDependency<IEventService>()
.Received()
.LogProviderOrganizationEventAsync(providerOrganization, EventType.ProviderOrganization_Created);
await sutProvider.GetDependency<IOrganizationService>()
.Received()
.InviteUsersAsync(
organization.Id,
user.Id,
Arg.Is<IEnumerable<(OrganizationUserInvite, string)>>(
t =>
t.Count() == 1 &&
t.First().Item1.Emails.Count() == 1 &&
t.First().Item1.Emails.First() == clientOwnerEmail &&
t.First().Item1.Type == OrganizationUserType.Owner &&
t.First().Item1.AccessAll &&
!t.First().Item1.Collections.Any() &&
t.First().Item2 == null));
}
[Theory, OrganizationCustomize(FlexibleCollections = true), BitAutoData]
public async Task CreateOrganizationAsync_WithFlexibleCollections_SetsAccessAllToFalse
(Provider provider, OrganizationSignup organizationSignup, Organization organization, string clientOwnerEmail,

View File

@ -1,4 +1,4 @@
using Bit.Api.Billing.Models;
using Bit.Api.Billing.Models.Responses;
using Bit.Core;
using Bit.Core.Billing.Queries;
using Bit.Core.Context;
@ -28,17 +28,17 @@ public class ProviderBillingController(
return TypedResults.Unauthorized();
}
var subscriptionData = await providerBillingQueries.GetSubscriptionData(providerId);
var providerSubscriptionDTO = await providerBillingQueries.GetSubscriptionDTO(providerId);
if (subscriptionData == null)
if (providerSubscriptionDTO == null)
{
return TypedResults.NotFound();
}
var (providerPlans, subscription) = subscriptionData;
var (providerPlans, subscription) = providerSubscriptionDTO;
var providerSubscriptionDTO = ProviderSubscriptionDTO.From(providerPlans, subscription);
var providerSubscriptionResponse = ProviderSubscriptionResponse.From(providerPlans, subscription);
return TypedResults.Ok(providerSubscriptionDTO);
return TypedResults.Ok(providerSubscriptionResponse);
}
}

View File

@ -0,0 +1,143 @@
using Bit.Api.Billing.Models.Requests;
using Bit.Core;
using Bit.Core.AdminConsole.Repositories;
using Bit.Core.AdminConsole.Services;
using Bit.Core.Billing.Commands;
using Bit.Core.Context;
using Bit.Core.Enums;
using Bit.Core.Models.Business;
using Bit.Core.Repositories;
using Bit.Core.Services;
using Microsoft.AspNetCore.Mvc;
namespace Bit.Api.Billing.Controllers;
[Route("providers/{providerId:guid}/clients")]
public class ProviderClientsController(
IAssignSeatsToClientOrganizationCommand assignSeatsToClientOrganizationCommand,
ICreateCustomerCommand createCustomerCommand,
ICurrentContext currentContext,
IFeatureService featureService,
ILogger<ProviderClientsController> logger,
IOrganizationRepository organizationRepository,
IProviderOrganizationRepository providerOrganizationRepository,
IProviderRepository providerRepository,
IProviderService providerService,
IScaleSeatsCommand scaleSeatsCommand,
IUserService userService) : Controller
{
[HttpPost]
public async Task<IResult> CreateAsync(
[FromRoute] Guid providerId,
[FromBody] CreateClientOrganizationRequestBody requestBody)
{
if (!featureService.IsEnabled(FeatureFlagKeys.EnableConsolidatedBilling))
{
return TypedResults.NotFound();
}
var user = await userService.GetUserByPrincipalAsync(User);
if (user == null)
{
return TypedResults.Unauthorized();
}
if (!currentContext.ManageProviderOrganizations(providerId))
{
return TypedResults.Unauthorized();
}
var provider = await providerRepository.GetByIdAsync(providerId);
if (provider == null)
{
return TypedResults.NotFound();
}
var organizationSignup = new OrganizationSignup
{
Name = requestBody.Name,
Plan = requestBody.PlanType,
AdditionalSeats = requestBody.Seats,
Owner = user,
BillingEmail = provider.BillingEmail,
OwnerKey = requestBody.Key,
PublicKey = requestBody.KeyPair.PublicKey,
PrivateKey = requestBody.KeyPair.EncryptedPrivateKey,
CollectionName = requestBody.CollectionName
};
var providerOrganization = await providerService.CreateOrganizationAsync(
providerId,
organizationSignup,
requestBody.OwnerEmail,
user);
var clientOrganization = await organizationRepository.GetByIdAsync(providerOrganization.OrganizationId);
if (clientOrganization == null)
{
logger.LogError("Newly created client organization ({ID}) could not be found", providerOrganization.OrganizationId);
return TypedResults.Problem();
}
await scaleSeatsCommand.ScalePasswordManagerSeats(
provider,
requestBody.PlanType,
requestBody.Seats);
await createCustomerCommand.CreateCustomer(
provider,
clientOrganization);
clientOrganization.Status = OrganizationStatusType.Managed;
await organizationRepository.ReplaceAsync(clientOrganization);
return TypedResults.Ok();
}
[HttpPut("{providerOrganizationId:guid}")]
public async Task<IResult> UpdateAsync(
[FromRoute] Guid providerId,
[FromRoute] Guid providerOrganizationId,
[FromBody] UpdateClientOrganizationRequestBody requestBody)
{
if (!featureService.IsEnabled(FeatureFlagKeys.EnableConsolidatedBilling))
{
return TypedResults.NotFound();
}
if (!currentContext.ProviderProviderAdmin(providerId))
{
return TypedResults.Unauthorized();
}
var provider = await providerRepository.GetByIdAsync(providerId);
var providerOrganization = await providerOrganizationRepository.GetByIdAsync(providerOrganizationId);
if (provider == null || providerOrganization == null)
{
return TypedResults.NotFound();
}
var clientOrganization = await organizationRepository.GetByIdAsync(providerOrganization.OrganizationId);
if (clientOrganization == null)
{
logger.LogError("The client organization ({OrganizationID}) represented by provider organization ({ProviderOrganizationID}) could not be found.", providerOrganization.OrganizationId, providerOrganization.Id);
return TypedResults.Problem();
}
await assignSeatsToClientOrganizationCommand.AssignSeatsToClientOrganization(
provider,
clientOrganization,
requestBody.AssignedSeats);
return TypedResults.Ok();
}
}

View File

@ -1,63 +0,0 @@
using Bit.Api.Billing.Models;
using Bit.Core;
using Bit.Core.AdminConsole.Repositories;
using Bit.Core.Billing.Commands;
using Bit.Core.Context;
using Bit.Core.Repositories;
using Bit.Core.Services;
using Microsoft.AspNetCore.Mvc;
namespace Bit.Api.Billing.Controllers;
[Route("providers/{providerId:guid}/organizations")]
public class ProviderOrganizationController(
IAssignSeatsToClientOrganizationCommand assignSeatsToClientOrganizationCommand,
ICurrentContext currentContext,
IFeatureService featureService,
ILogger<ProviderOrganizationController> logger,
IOrganizationRepository organizationRepository,
IProviderRepository providerRepository,
IProviderOrganizationRepository providerOrganizationRepository) : Controller
{
[HttpPut("{providerOrganizationId:guid}")]
public async Task<IResult> UpdateAsync(
[FromRoute] Guid providerId,
[FromRoute] Guid providerOrganizationId,
[FromBody] UpdateProviderOrganizationRequestBody requestBody)
{
if (!featureService.IsEnabled(FeatureFlagKeys.EnableConsolidatedBilling))
{
return TypedResults.NotFound();
}
if (!currentContext.ProviderProviderAdmin(providerId))
{
return TypedResults.Unauthorized();
}
var provider = await providerRepository.GetByIdAsync(providerId);
var providerOrganization = await providerOrganizationRepository.GetByIdAsync(providerOrganizationId);
if (provider == null || providerOrganization == null)
{
return TypedResults.NotFound();
}
var organization = await organizationRepository.GetByIdAsync(providerOrganization.OrganizationId);
if (organization == null)
{
logger.LogError("The organization ({OrganizationID}) represented by provider organization ({ProviderOrganizationID}) could not be found.", providerOrganization.OrganizationId, providerOrganization.Id);
return TypedResults.Problem();
}
await assignSeatsToClientOrganizationCommand.AssignSeatsToClientOrganization(
provider,
organization,
requestBody.AssignedSeats);
return TypedResults.Ok();
}
}

View File

@ -0,0 +1,29 @@
using System.ComponentModel.DataAnnotations;
using Bit.Api.Utilities;
using Bit.Core.Enums;
namespace Bit.Api.Billing.Models.Requests;
public class CreateClientOrganizationRequestBody
{
[Required(ErrorMessage = "'name' must be provided")]
public string Name { get; set; }
[Required(ErrorMessage = "'ownerEmail' must be provided")]
public string OwnerEmail { get; set; }
[EnumMatches<PlanType>(PlanType.TeamsMonthly, PlanType.EnterpriseMonthly, ErrorMessage = "'planType' must be Teams (Monthly) or Enterprise (Monthly)")]
public PlanType PlanType { get; set; }
[Range(1, int.MaxValue, ErrorMessage = "'seats' must be greater than 0")]
public int Seats { get; set; }
[Required(ErrorMessage = "'key' must be provided")]
public string Key { get; set; }
[Required(ErrorMessage = "'keyPair' must be provided")]
public KeyPairRequestBody KeyPair { get; set; }
[Required(ErrorMessage = "'collectionName' must be provided")]
public string CollectionName { get; set; }
}

View File

@ -0,0 +1,12 @@
using System.ComponentModel.DataAnnotations;
namespace Bit.Api.Billing.Models.Requests;
// ReSharper disable once ClassNeverInstantiated.Global
public class KeyPairRequestBody
{
[Required(ErrorMessage = "'publicKey' must be provided")]
public string PublicKey { get; set; }
[Required(ErrorMessage = "'encryptedPrivateKey' must be provided")]
public string EncryptedPrivateKey { get; set; }
}

View File

@ -0,0 +1,10 @@
using System.ComponentModel.DataAnnotations;
namespace Bit.Api.Billing.Models.Requests;
public class UpdateClientOrganizationRequestBody
{
[Required]
[Range(0, int.MaxValue, ErrorMessage = "You cannot assign negative seats to a client organization.")]
public int AssignedSeats { get; set; }
}

View File

@ -2,9 +2,9 @@
using Bit.Core.Utilities;
using Stripe;
namespace Bit.Api.Billing.Models;
namespace Bit.Api.Billing.Models.Responses;
public record ProviderSubscriptionDTO(
public record ProviderSubscriptionResponse(
string Status,
DateTime CurrentPeriodEndDate,
decimal? DiscountPercentage,
@ -13,8 +13,8 @@ public record ProviderSubscriptionDTO(
private const string _annualCadence = "Annual";
private const string _monthlyCadence = "Monthly";
public static ProviderSubscriptionDTO From(
IEnumerable<ConfiguredProviderPlan> providerPlans,
public static ProviderSubscriptionResponse From(
IEnumerable<ConfiguredProviderPlanDTO> providerPlans,
Subscription subscription)
{
var providerPlansDTO = providerPlans
@ -32,7 +32,7 @@ public record ProviderSubscriptionDTO(
cadence);
});
return new ProviderSubscriptionDTO(
return new ProviderSubscriptionResponse(
subscription.Status,
subscription.CurrentPeriodEnd,
subscription.Customer?.Discount?.Coupon?.PercentOff,

View File

@ -1,6 +0,0 @@
namespace Bit.Api.Billing.Models;
public class UpdateProviderOrganizationRequestBody
{
public int AssignedSeats { get; set; }
}

View File

@ -0,0 +1,26 @@
using System.ComponentModel.DataAnnotations;
namespace Bit.Api.Utilities;
public class EnumMatchesAttribute<T>(params T[] accepted) : ValidationAttribute
where T : Enum
{
public override bool IsValid(object value)
{
if (value == null || accepted == null || accepted.Length == 0)
{
return false;
}
var success = Enum.TryParse(typeof(T), value.ToString(), out var result);
if (!success)
{
return false;
}
var typed = (T)result;
return accepted.Contains(typed);
}
}

View File

@ -3,5 +3,6 @@
public enum OrganizationStatusType : byte
{
Pending = 0,
Created = 1
Created = 1,
Managed = 2,
}

View File

@ -26,6 +26,8 @@ public interface IOrganizationService
/// <returns>A tuple containing the new organization, the initial organizationUser (if any) and the default collection (if any)</returns>
#nullable enable
Task<(Organization organization, OrganizationUser? organizationUser, Collection? defaultCollection)> SignUpAsync(OrganizationSignup organizationSignup, bool provider = false);
Task<(Organization organization, OrganizationUser organizationUser, Collection defaultCollection)> SignupClientAsync(OrganizationSignup signup);
#nullable disable
/// <summary>
/// Create a new organization on a self-hosted instance

View File

@ -421,6 +421,89 @@ public class OrganizationService : IOrganizationService
}
}
public async Task<(Organization organization, OrganizationUser organizationUser, Collection defaultCollection)> SignupClientAsync(OrganizationSignup signup)
{
var consolidatedBillingEnabled = _featureService.IsEnabled(FeatureFlagKeys.EnableConsolidatedBilling);
if (!consolidatedBillingEnabled)
{
throw new InvalidOperationException($"{nameof(SignupClientAsync)} is only for use within Consolidated Billing");
}
var plan = StaticStore.GetPlan(signup.Plan);
ValidatePlan(plan, signup.AdditionalSeats, "Password Manager");
var flexibleCollectionsSignupEnabled =
_featureService.IsEnabled(FeatureFlagKeys.FlexibleCollectionsSignup);
var flexibleCollectionsV1Enabled =
_featureService.IsEnabled(FeatureFlagKeys.FlexibleCollectionsV1);
var organization = new Organization
{
// Pre-generate the org id so that we can save it with the Stripe subscription..
Id = CoreHelpers.GenerateComb(),
Name = signup.Name,
BillingEmail = signup.BillingEmail,
PlanType = plan!.Type,
Seats = signup.AdditionalSeats,
MaxCollections = plan.PasswordManager.MaxCollections,
// Extra storage not available for purchase with Consolidated Billing.
MaxStorageGb = 0,
UsePolicies = plan.HasPolicies,
UseSso = plan.HasSso,
UseGroups = plan.HasGroups,
UseEvents = plan.HasEvents,
UseDirectory = plan.HasDirectory,
UseTotp = plan.HasTotp,
Use2fa = plan.Has2fa,
UseApi = plan.HasApi,
UseResetPassword = plan.HasResetPassword,
SelfHost = plan.HasSelfHost,
UsersGetPremium = plan.UsersGetPremium,
UseCustomPermissions = plan.HasCustomPermissions,
UseScim = plan.HasScim,
Plan = plan.Name,
Gateway = GatewayType.Stripe,
ReferenceData = signup.Owner.ReferenceData,
Enabled = true,
LicenseKey = CoreHelpers.SecureRandomString(20),
PublicKey = signup.PublicKey,
PrivateKey = signup.PrivateKey,
CreationDate = DateTime.UtcNow,
RevisionDate = DateTime.UtcNow,
Status = OrganizationStatusType.Created,
UsePasswordManager = true,
// Secrets Manager not available for purchase with Consolidated Billing.
UseSecretsManager = false,
// This feature flag indicates that new organizations should be automatically onboarded to
// Flexible Collections enhancements
FlexibleCollections = flexibleCollectionsSignupEnabled,
// These collection management settings smooth the migration for existing organizations by disabling some FC behavior.
// If the organization is onboarded to Flexible Collections on signup, we turn them OFF to enable all new behaviour.
// If the organization is NOT onboarded now, they will have to be migrated later, so they default to ON to limit FC changes on migration.
LimitCollectionCreationDeletion = !flexibleCollectionsSignupEnabled,
AllowAdminAccessToAllCollectionItems = !flexibleCollectionsV1Enabled
};
var returnValue = await SignUpAsync(organization, default, signup.OwnerKey, signup.CollectionName, false);
await _referenceEventService.RaiseEventAsync(
new ReferenceEvent(ReferenceEventType.Signup, organization, _currentContext)
{
PlanName = plan.Name,
PlanType = plan.Type,
Seats = returnValue.Item1.Seats,
SignupInitiationPath = signup.InitiationPath,
Storage = returnValue.Item1.MaxStorageGb,
});
return returnValue;
}
/// <summary>
/// Create a new organization in a cloud environment
/// </summary>

View File

@ -5,6 +5,15 @@ namespace Bit.Core.Billing.Commands;
public interface IAssignSeatsToClientOrganizationCommand
{
/// <summary>
/// Assigns a specified number of <paramref name="seats"/> to a client <paramref name="organization"/> on behalf of
/// its <paramref name="provider"/>. Seat adjustments for the client organization may autoscale the provider's Stripe
/// <see cref="Stripe.Subscription"/> depending on the provider's seat minimum for the client <paramref name="organization"/>'s
/// <see cref="Organization.PlanType"/>.
/// </summary>
/// <param name="provider">The MSP that manages the client <paramref name="organization"/>.</param>
/// <param name="organization">The client organization whose <see cref="seats"/> you want to update.</param>
/// <param name="seats">The number of seats to assign to the client organization.</param>
Task AssignSeatsToClientOrganization(
Provider provider,
Organization organization,

View File

@ -0,0 +1,17 @@
using Bit.Core.AdminConsole.Entities;
using Bit.Core.AdminConsole.Entities.Provider;
namespace Bit.Core.Billing.Commands;
public interface ICreateCustomerCommand
{
/// <summary>
/// Create a Stripe <see cref="Stripe.Customer"/> for the provided client <paramref name="organization"/> utilizing
/// the address and tax information of its <paramref name="provider"/>.
/// </summary>
/// <param name="provider">The MSP that owns the client organization.</param>
/// <param name="organization">The client organization to create a Stripe <see cref="Stripe.Customer"/> for.</param>
Task CreateCustomer(
Provider provider,
Organization organization);
}

View File

@ -0,0 +1,12 @@
using Bit.Core.AdminConsole.Entities.Provider;
using Bit.Core.Enums;
namespace Bit.Core.Billing.Commands;
public interface IScaleSeatsCommand
{
Task ScalePasswordManagerSeats(
Provider provider,
PlanType planType,
int seatAdjustment);
}

View File

@ -1,10 +1,19 @@
using Bit.Core.AdminConsole.Entities.Provider;
using Bit.Core.Enums;
using Bit.Core.Models.Business;
namespace Bit.Core.Billing.Commands;
public interface IStartSubscriptionCommand
{
/// <summary>
/// Starts a Stripe <see cref="Stripe.Subscription"/> for the given <paramref name="provider"/> utilizing the provided
/// <paramref name="taxInfo"/> to handle automatic taxation and non-US tax identification. <see cref="Provider"/> subscriptions
/// will always be started with a <see cref="Stripe.SubscriptionItem"/> for both the <see cref="PlanType.TeamsMonthly"/> and <see cref="PlanType.EnterpriseMonthly"/>
/// plan, and the quantity for each item will be equal the provider's seat minimum for each respective plan.
/// </summary>
/// <param name="provider">The provider to create the <see cref="Stripe.Subscription"/> for.</param>
/// <param name="taxInfo">The tax information to use for automatic taxation and non-US tax identification.</param>
Task StartSubscription(
Provider provider,
TaxInfo taxInfo);

View File

@ -0,0 +1,89 @@
using Bit.Core.AdminConsole.Entities;
using Bit.Core.AdminConsole.Entities.Provider;
using Bit.Core.Billing.Queries;
using Bit.Core.Repositories;
using Bit.Core.Services;
using Bit.Core.Settings;
using Microsoft.Extensions.Logging;
using Stripe;
namespace Bit.Core.Billing.Commands.Implementations;
public class CreateCustomerCommand(
IGlobalSettings globalSettings,
ILogger<CreateCustomerCommand> logger,
IOrganizationRepository organizationRepository,
IStripeAdapter stripeAdapter,
ISubscriberQueries subscriberQueries) : ICreateCustomerCommand
{
public async Task CreateCustomer(
Provider provider,
Organization organization)
{
ArgumentNullException.ThrowIfNull(provider);
ArgumentNullException.ThrowIfNull(organization);
if (!string.IsNullOrEmpty(organization.GatewayCustomerId))
{
logger.LogWarning("Client organization ({ID}) already has a populated {FieldName}", organization.Id, nameof(organization.GatewayCustomerId));
return;
}
var providerCustomer = await subscriberQueries.GetCustomerOrThrow(provider, new CustomerGetOptions
{
Expand = ["tax_ids"]
});
var providerTaxId = providerCustomer.TaxIds.FirstOrDefault();
var organizationDisplayName = organization.DisplayName();
var customerCreateOptions = new CustomerCreateOptions
{
Address = new AddressOptions
{
Country = providerCustomer.Address?.Country,
PostalCode = providerCustomer.Address?.PostalCode,
Line1 = providerCustomer.Address?.Line1,
Line2 = providerCustomer.Address?.Line2,
City = providerCustomer.Address?.City,
State = providerCustomer.Address?.State
},
Name = organizationDisplayName,
Description = $"{provider.Name} Client Organization",
Email = provider.BillingEmail,
InvoiceSettings = new CustomerInvoiceSettingsOptions
{
CustomFields =
[
new CustomerInvoiceSettingsCustomFieldOptions
{
Name = organization.SubscriberType(),
Value = organizationDisplayName.Length <= 30
? organizationDisplayName
: organizationDisplayName[..30]
}
]
},
Metadata = new Dictionary<string, string>
{
{ "region", globalSettings.BaseServiceUri.CloudRegion }
},
TaxIdData = providerTaxId == null ? null :
[
new CustomerTaxIdDataOptions
{
Type = providerTaxId.Type,
Value = providerTaxId.Value
}
]
};
var customer = await stripeAdapter.CustomerCreateAsync(customerCreateOptions);
organization.GatewayCustomerId = customer.Id;
await organizationRepository.ReplaceAsync(organization);
}
}

View File

@ -0,0 +1,130 @@
using Bit.Core.AdminConsole.Entities.Provider;
using Bit.Core.AdminConsole.Enums.Provider;
using Bit.Core.Billing.Entities;
using Bit.Core.Billing.Extensions;
using Bit.Core.Billing.Queries;
using Bit.Core.Billing.Repositories;
using Bit.Core.Enums;
using Bit.Core.Services;
using Bit.Core.Utilities;
using Microsoft.Extensions.Logging;
using static Bit.Core.Billing.Utilities;
namespace Bit.Core.Billing.Commands.Implementations;
public class ScaleSeatsCommand(
ILogger<ScaleSeatsCommand> logger,
IPaymentService paymentService,
IProviderBillingQueries providerBillingQueries,
IProviderPlanRepository providerPlanRepository) : IScaleSeatsCommand
{
public async Task ScalePasswordManagerSeats(Provider provider, PlanType planType, int seatAdjustment)
{
ArgumentNullException.ThrowIfNull(provider);
if (provider.Type != ProviderType.Msp)
{
logger.LogError("Non-MSP provider ({ProviderID}) cannot scale their Password Manager seats", provider.Id);
throw ContactSupport();
}
if (!planType.SupportsConsolidatedBilling())
{
logger.LogError("Cannot scale provider ({ProviderID}) Password Manager seats for plan type {PlanType} as it does not support consolidated billing", provider.Id, planType.ToString());
throw ContactSupport();
}
var providerPlans = await providerPlanRepository.GetByProviderId(provider.Id);
var providerPlan = providerPlans.FirstOrDefault(providerPlan => providerPlan.PlanType == planType);
if (providerPlan == null || !providerPlan.IsConfigured())
{
logger.LogError("Cannot scale provider ({ProviderID}) Password Manager seats for plan type {PlanType} when their matching provider plan is not configured", provider.Id, planType);
throw ContactSupport();
}
var seatMinimum = providerPlan.SeatMinimum.GetValueOrDefault(0);
var currentlyAssignedSeatTotal =
await providerBillingQueries.GetAssignedSeatTotalForPlanOrThrow(provider.Id, planType);
var newlyAssignedSeatTotal = currentlyAssignedSeatTotal + seatAdjustment;
var update = CurryUpdateFunction(
provider,
providerPlan,
newlyAssignedSeatTotal);
/*
* Below the limit => Below the limit:
* No subscription update required. We can safely update the organization's seats.
*/
if (currentlyAssignedSeatTotal <= seatMinimum &&
newlyAssignedSeatTotal <= seatMinimum)
{
providerPlan.AllocatedSeats = newlyAssignedSeatTotal;
await providerPlanRepository.ReplaceAsync(providerPlan);
}
/*
* Below the limit => Above the limit:
* We have to scale the subscription up from the seat minimum to the newly assigned seat total.
*/
else if (currentlyAssignedSeatTotal <= seatMinimum &&
newlyAssignedSeatTotal > seatMinimum)
{
await update(
seatMinimum,
newlyAssignedSeatTotal);
}
/*
* Above the limit => Above the limit:
* We have to scale the subscription from the currently assigned seat total to the newly assigned seat total.
*/
else if (currentlyAssignedSeatTotal > seatMinimum &&
newlyAssignedSeatTotal > seatMinimum)
{
await update(
currentlyAssignedSeatTotal,
newlyAssignedSeatTotal);
}
/*
* Above the limit => Below the limit:
* We have to scale the subscription down from the currently assigned seat total to the seat minimum.
*/
else if (currentlyAssignedSeatTotal > seatMinimum &&
newlyAssignedSeatTotal <= seatMinimum)
{
await update(
currentlyAssignedSeatTotal,
seatMinimum);
}
}
private Func<int, int, Task> CurryUpdateFunction(
Provider provider,
ProviderPlan providerPlan,
int newlyAssignedSeats) => async (currentlySubscribedSeats, newlySubscribedSeats) =>
{
var plan = StaticStore.GetPlan(providerPlan.PlanType);
await paymentService.AdjustSeats(
provider,
plan,
currentlySubscribedSeats,
newlySubscribedSeats);
var newlyPurchasedSeats = newlySubscribedSeats > providerPlan.SeatMinimum
? newlySubscribedSeats - providerPlan.SeatMinimum
: 0;
providerPlan.PurchasedSeats = newlyPurchasedSeats;
providerPlan.AllocatedSeats = newlyAssignedSeats;
await providerPlanRepository.ReplaceAsync(providerPlan);
};
}

View File

@ -18,7 +18,9 @@ public static class ServiceCollectionExtensions
// Commands
services.AddTransient<IAssignSeatsToClientOrganizationCommand, AssignSeatsToClientOrganizationCommand>();
services.AddTransient<ICancelSubscriptionCommand, CancelSubscriptionCommand>();
services.AddTransient<ICreateCustomerCommand, CreateCustomerCommand>();
services.AddTransient<IRemovePaymentMethodCommand, RemovePaymentMethodCommand>();
services.AddTransient<IScaleSeatsCommand, ScaleSeatsCommand>();
services.AddTransient<IStartSubscriptionCommand, StartSubscriptionCommand>();
}
}

View File

@ -3,7 +3,7 @@ using Bit.Core.Enums;
namespace Bit.Core.Billing.Models;
public record ConfiguredProviderPlan(
public record ConfiguredProviderPlanDTO(
Guid Id,
Guid ProviderId,
PlanType PlanType,
@ -11,9 +11,9 @@ public record ConfiguredProviderPlan(
int PurchasedSeats,
int AssignedSeats)
{
public static ConfiguredProviderPlan From(ProviderPlan providerPlan) =>
public static ConfiguredProviderPlanDTO From(ProviderPlan providerPlan) =>
providerPlan.IsConfigured()
? new ConfiguredProviderPlan(
? new ConfiguredProviderPlanDTO(
providerPlan.Id,
providerPlan.ProviderId,
providerPlan.PlanType,

View File

@ -0,0 +1,7 @@
using Stripe;
namespace Bit.Core.Billing.Models;
public record ProviderSubscriptionDTO(
List<ConfiguredProviderPlanDTO> ProviderPlans,
Subscription Subscription);

View File

@ -1,7 +0,0 @@
using Stripe;
namespace Bit.Core.Billing.Models;
public record ProviderSubscriptionData(
List<ConfiguredProviderPlan> ProviderPlans,
Subscription Subscription);

View File

@ -21,7 +21,7 @@ public interface IProviderBillingQueries
/// Retrieves a provider's billing subscription data.
/// </summary>
/// <param name="providerId">The ID of the provider to retrieve subscription data for.</param>
/// <returns>A <see cref="ProviderSubscriptionData"/> object containing the provider's Stripe <see cref="Stripe.Subscription"/> and their <see cref="ConfiguredProviderPlan"/>s.</returns>
/// <returns>A <see cref="ProviderSubscriptionDTO"/> object containing the provider's Stripe <see cref="Stripe.Subscription"/> and their <see cref="ConfiguredProviderPlanDTO"/>s.</returns>
/// <remarks>This method opts for returning <see langword="null"/> rather than throwing exceptions, making it ideal for surfacing data from API endpoints.</remarks>
Task<ProviderSubscriptionData> GetSubscriptionData(Guid providerId);
Task<ProviderSubscriptionDTO> GetSubscriptionDTO(Guid providerId);
}

View File

@ -6,6 +6,18 @@ namespace Bit.Core.Billing.Queries;
public interface ISubscriberQueries
{
/// <summary>
/// Retrieves a Stripe <see cref="Customer"/> using the <paramref name="subscriber"/>'s <see cref="ISubscriber.GatewayCustomerId"/> property.
/// </summary>
/// <param name="subscriber">The organization, provider or user to retrieve the customer for.</param>
/// <param name="customerGetOptions">Optional parameters that can be passed to Stripe to expand or modify the <see cref="Customer"/>.</param>
/// <returns>A Stripe <see cref="Customer"/>.</returns>
/// <exception cref="ArgumentNullException">Thrown when the <paramref name="subscriber"/> is <see langword="null"/>.</exception>
/// <remarks>This method opts for returning <see langword="null"/> rather than throwing exceptions, making it ideal for surfacing data from API endpoints.</remarks>
Task<Customer> GetCustomer(
ISubscriber subscriber,
CustomerGetOptions customerGetOptions = null);
/// <summary>
/// Retrieves a Stripe <see cref="Subscription"/> using the <paramref name="subscriber"/>'s <see cref="ISubscriber.GatewaySubscriptionId"/> property.
/// </summary>
@ -18,13 +30,29 @@ public interface ISubscriberQueries
ISubscriber subscriber,
SubscriptionGetOptions subscriptionGetOptions = null);
/// <summary>
/// Retrieves a Stripe <see cref="Customer"/> using the <paramref name="subscriber"/>'s <see cref="ISubscriber.GatewayCustomerId"/> property.
/// </summary>
/// <param name="subscriber">The organization or user to retrieve the subscription for.</param>
/// <param name="customerGetOptions">Optional parameters that can be passed to Stripe to expand or modify the <see cref="Customer"/>.</param>
/// <returns>A Stripe <see cref="Customer"/>.</returns>
/// <exception cref="ArgumentNullException">Thrown when the <paramref name="subscriber"/> is <see langword="null"/>.</exception>
/// <exception cref="GatewayException">Thrown when the subscriber's <see cref="ISubscriber.GatewayCustomerId"/> is <see langword="null"/> or empty.</exception>
/// <exception cref="GatewayException">Thrown when the <see cref="Customer"/> returned from Stripe's API is null.</exception>
Task<Customer> GetCustomerOrThrow(
ISubscriber subscriber,
CustomerGetOptions customerGetOptions = null);
/// <summary>
/// Retrieves a Stripe <see cref="Subscription"/> using the <paramref name="subscriber"/>'s <see cref="ISubscriber.GatewaySubscriptionId"/> property.
/// </summary>
/// <param name="subscriber">The organization or user to retrieve the subscription for.</param>
/// <param name="subscriptionGetOptions">Optional parameters that can be passed to Stripe to expand or modify the <see cref="Subscription"/>.</param>
/// <returns>A Stripe <see cref="Subscription"/>.</returns>
/// <exception cref="ArgumentNullException">Thrown when the <paramref name="subscriber"/> is <see langword="null"/>.</exception>
/// <exception cref="GatewayException">Thrown when the subscriber's <see cref="ISubscriber.GatewaySubscriptionId"/> is <see langword="null"/> or empty.</exception>
/// <exception cref="GatewayException">Thrown when the <see cref="Subscription"/> returned from Stripe's API is null.</exception>
Task<Subscription> GetSubscriptionOrThrow(ISubscriber subscriber);
Task<Subscription> GetSubscriptionOrThrow(
ISubscriber subscriber,
SubscriptionGetOptions subscriptionGetOptions = null);
}

View File

@ -44,11 +44,11 @@ public class ProviderBillingQueries(
var plan = StaticStore.GetPlan(planType);
return providerOrganizations
.Where(providerOrganization => providerOrganization.Plan == plan.Name)
.Where(providerOrganization => providerOrganization.Plan == plan.Name && providerOrganization.Status == OrganizationStatusType.Managed)
.Sum(providerOrganization => providerOrganization.Seats ?? 0);
}
public async Task<ProviderSubscriptionData> GetSubscriptionData(Guid providerId)
public async Task<ProviderSubscriptionDTO> GetSubscriptionDTO(Guid providerId)
{
var provider = await providerRepository.GetByIdAsync(providerId);
@ -82,10 +82,10 @@ public class ProviderBillingQueries(
var configuredProviderPlans = providerPlans
.Where(providerPlan => providerPlan.IsConfigured())
.Select(ConfiguredProviderPlan.From)
.Select(ConfiguredProviderPlanDTO.From)
.ToList();
return new ProviderSubscriptionData(
return new ProviderSubscriptionDTO(
configuredProviderPlans,
subscription);
}

View File

@ -11,6 +11,31 @@ public class SubscriberQueries(
ILogger<SubscriberQueries> logger,
IStripeAdapter stripeAdapter) : ISubscriberQueries
{
public async Task<Customer> GetCustomer(
ISubscriber subscriber,
CustomerGetOptions customerGetOptions = null)
{
ArgumentNullException.ThrowIfNull(subscriber);
if (string.IsNullOrEmpty(subscriber.GatewayCustomerId))
{
logger.LogError("Cannot retrieve customer for subscriber ({SubscriberID}) with no {FieldName}", subscriber.Id, nameof(subscriber.GatewayCustomerId));
return null;
}
var customer = await stripeAdapter.CustomerGetAsync(subscriber.GatewayCustomerId, customerGetOptions);
if (customer != null)
{
return customer;
}
logger.LogError("Could not find Stripe customer ({CustomerID}) for subscriber ({SubscriberID})", subscriber.GatewayCustomerId, subscriber.Id);
return null;
}
public async Task<Subscription> GetSubscription(
ISubscriber subscriber,
SubscriptionGetOptions subscriptionGetOptions = null)
@ -19,7 +44,7 @@ public class SubscriberQueries(
if (string.IsNullOrEmpty(subscriber.GatewaySubscriptionId))
{
logger.LogError("Cannot cancel subscription for subscriber ({ID}) with no GatewaySubscriptionId.", subscriber.Id);
logger.LogError("Cannot retrieve subscription for subscriber ({SubscriberID}) with no {FieldName}", subscriber.Id, nameof(subscriber.GatewaySubscriptionId));
return null;
}
@ -31,30 +56,57 @@ public class SubscriberQueries(
return subscription;
}
logger.LogError("Could not find Stripe subscription ({ID}) to cancel.", subscriber.GatewaySubscriptionId);
logger.LogError("Could not find Stripe subscription ({SubscriptionID}) for subscriber ({SubscriberID})", subscriber.GatewaySubscriptionId, subscriber.Id);
return null;
}
public async Task<Subscription> GetSubscriptionOrThrow(ISubscriber subscriber)
public async Task<Subscription> GetSubscriptionOrThrow(
ISubscriber subscriber,
SubscriptionGetOptions subscriptionGetOptions = null)
{
ArgumentNullException.ThrowIfNull(subscriber);
if (string.IsNullOrEmpty(subscriber.GatewaySubscriptionId))
{
logger.LogError("Cannot cancel subscription for subscriber ({ID}) with no GatewaySubscriptionId.", subscriber.Id);
logger.LogError("Cannot retrieve subscription for subscriber ({SubscriberID}) with no {FieldName}", subscriber.Id, nameof(subscriber.GatewaySubscriptionId));
throw ContactSupport();
}
var subscription = await stripeAdapter.SubscriptionGetAsync(subscriber.GatewaySubscriptionId);
var subscription = await stripeAdapter.SubscriptionGetAsync(subscriber.GatewaySubscriptionId, subscriptionGetOptions);
if (subscription != null)
{
return subscription;
}
logger.LogError("Could not find Stripe subscription ({ID}) to cancel.", subscriber.GatewaySubscriptionId);
logger.LogError("Could not find Stripe subscription ({SubscriptionID}) for subscriber ({SubscriberID})", subscriber.GatewaySubscriptionId, subscriber.Id);
throw ContactSupport();
}
public async Task<Customer> GetCustomerOrThrow(
ISubscriber subscriber,
CustomerGetOptions customerGetOptions = null)
{
ArgumentNullException.ThrowIfNull(subscriber);
if (string.IsNullOrEmpty(subscriber.GatewayCustomerId))
{
logger.LogError("Cannot retrieve customer for subscriber ({SubscriberID}) with no {FieldName}", subscriber.Id, nameof(subscriber.GatewayCustomerId));
throw ContactSupport();
}
var customer = await stripeAdapter.CustomerGetAsync(subscriber.GatewayCustomerId, customerGetOptions);
if (customer != null)
{
return customer;
}
logger.LogError("Could not find Stripe customer ({CustomerID}) for subscriber ({SubscriberID})", subscriber.GatewayCustomerId, subscriber.Id);
throw ContactSupport();
}

View File

@ -1,5 +1,5 @@
using Bit.Api.Billing.Controllers;
using Bit.Api.Billing.Models;
using Bit.Api.Billing.Models.Responses;
using Bit.Core;
using Bit.Core.Billing.Models;
using Bit.Core.Billing.Queries;
@ -61,7 +61,7 @@ public class ProviderBillingControllerTests
sutProvider.GetDependency<ICurrentContext>().ProviderProviderAdmin(providerId)
.Returns(true);
sutProvider.GetDependency<IProviderBillingQueries>().GetSubscriptionData(providerId).ReturnsNull();
sutProvider.GetDependency<IProviderBillingQueries>().GetSubscriptionDTO(providerId).ReturnsNull();
var result = await sutProvider.Sut.GetSubscriptionAsync(providerId);
@ -79,7 +79,7 @@ public class ProviderBillingControllerTests
sutProvider.GetDependency<ICurrentContext>().ProviderProviderAdmin(providerId)
.Returns(true);
var configuredPlans = new List<ConfiguredProviderPlan>
var configuredProviderPlanDTOList = new List<ConfiguredProviderPlanDTO>
{
new (Guid.NewGuid(), providerId, PlanType.TeamsMonthly, 50, 10, 30),
new (Guid.NewGuid(), providerId, PlanType.EnterpriseMonthly, 100, 0, 90)
@ -92,25 +92,25 @@ public class ProviderBillingControllerTests
Customer = new Customer { Discount = new Discount { Coupon = new Coupon { PercentOff = 10 } } }
};
var providerSubscriptionData = new ProviderSubscriptionData(
configuredPlans,
var providerSubscriptionDTO = new ProviderSubscriptionDTO(
configuredProviderPlanDTOList,
subscription);
sutProvider.GetDependency<IProviderBillingQueries>().GetSubscriptionData(providerId)
.Returns(providerSubscriptionData);
sutProvider.GetDependency<IProviderBillingQueries>().GetSubscriptionDTO(providerId)
.Returns(providerSubscriptionDTO);
var result = await sutProvider.Sut.GetSubscriptionAsync(providerId);
Assert.IsType<Ok<ProviderSubscriptionDTO>>(result);
Assert.IsType<Ok<ProviderSubscriptionResponse>>(result);
var providerSubscriptionDTO = ((Ok<ProviderSubscriptionDTO>)result).Value;
var providerSubscriptionResponse = ((Ok<ProviderSubscriptionResponse>)result).Value;
Assert.Equal(providerSubscriptionDTO.Status, subscription.Status);
Assert.Equal(providerSubscriptionDTO.CurrentPeriodEndDate, subscription.CurrentPeriodEnd);
Assert.Equal(providerSubscriptionDTO.DiscountPercentage, subscription.Customer!.Discount!.Coupon!.PercentOff);
Assert.Equal(providerSubscriptionResponse.Status, subscription.Status);
Assert.Equal(providerSubscriptionResponse.CurrentPeriodEndDate, subscription.CurrentPeriodEnd);
Assert.Equal(providerSubscriptionResponse.DiscountPercentage, subscription.Customer!.Discount!.Coupon!.PercentOff);
var teamsPlan = StaticStore.GetPlan(PlanType.TeamsMonthly);
var providerTeamsPlan = providerSubscriptionDTO.Plans.FirstOrDefault(plan => plan.PlanName == teamsPlan.Name);
var providerTeamsPlan = providerSubscriptionResponse.Plans.FirstOrDefault(plan => plan.PlanName == teamsPlan.Name);
Assert.NotNull(providerTeamsPlan);
Assert.Equal(50, providerTeamsPlan.SeatMinimum);
Assert.Equal(10, providerTeamsPlan.PurchasedSeats);
@ -119,7 +119,7 @@ public class ProviderBillingControllerTests
Assert.Equal("Monthly", providerTeamsPlan.Cadence);
var enterprisePlan = StaticStore.GetPlan(PlanType.EnterpriseMonthly);
var providerEnterprisePlan = providerSubscriptionDTO.Plans.FirstOrDefault(plan => plan.PlanName == enterprisePlan.Name);
var providerEnterprisePlan = providerSubscriptionResponse.Plans.FirstOrDefault(plan => plan.PlanName == enterprisePlan.Name);
Assert.NotNull(providerEnterprisePlan);
Assert.Equal(100, providerEnterprisePlan.SeatMinimum);
Assert.Equal(0, providerEnterprisePlan.PurchasedSeats);

View File

@ -0,0 +1,339 @@
using System.Security.Claims;
using Bit.Api.Billing.Controllers;
using Bit.Api.Billing.Models.Requests;
using Bit.Core;
using Bit.Core.AdminConsole.Entities;
using Bit.Core.AdminConsole.Entities.Provider;
using Bit.Core.AdminConsole.Repositories;
using Bit.Core.AdminConsole.Services;
using Bit.Core.Billing.Commands;
using Bit.Core.Context;
using Bit.Core.Entities;
using Bit.Core.Models.Business;
using Bit.Core.Repositories;
using Bit.Core.Services;
using Bit.Test.Common.AutoFixture;
using Bit.Test.Common.AutoFixture.Attributes;
using Microsoft.AspNetCore.Http.HttpResults;
using NSubstitute;
using NSubstitute.ReturnsExtensions;
using Xunit;
namespace Bit.Api.Test.Billing.Controllers;
[ControllerCustomize(typeof(ProviderClientsController))]
[SutProviderCustomize]
public class ProviderClientsControllerTests
{
#region CreateAsync
[Theory, BitAutoData]
public async Task CreateAsync_FFDisabled_NotFound(
Guid providerId,
CreateClientOrganizationRequestBody requestBody,
SutProvider<ProviderClientsController> sutProvider)
{
sutProvider.GetDependency<IFeatureService>().IsEnabled(FeatureFlagKeys.EnableConsolidatedBilling)
.Returns(false);
var result = await sutProvider.Sut.CreateAsync(providerId, requestBody);
Assert.IsType<NotFound>(result);
}
[Theory, BitAutoData]
public async Task CreateAsync_NoPrincipalUser_Unauthorized(
Guid providerId,
CreateClientOrganizationRequestBody requestBody,
SutProvider<ProviderClientsController> sutProvider)
{
sutProvider.GetDependency<IFeatureService>().IsEnabled(FeatureFlagKeys.EnableConsolidatedBilling)
.Returns(true);
sutProvider.GetDependency<IUserService>().GetUserByPrincipalAsync(Arg.Any<ClaimsPrincipal>()).ReturnsNull();
var result = await sutProvider.Sut.CreateAsync(providerId, requestBody);
Assert.IsType<UnauthorizedHttpResult>(result);
}
[Theory, BitAutoData]
public async Task CreateAsync_NotProviderAdmin_Unauthorized(
Guid providerId,
CreateClientOrganizationRequestBody requestBody,
SutProvider<ProviderClientsController> sutProvider)
{
sutProvider.GetDependency<IFeatureService>().IsEnabled(FeatureFlagKeys.EnableConsolidatedBilling)
.Returns(true);
sutProvider.GetDependency<IUserService>().GetUserByPrincipalAsync(Arg.Any<ClaimsPrincipal>()).Returns(new User());
sutProvider.GetDependency<ICurrentContext>().ManageProviderOrganizations(providerId)
.Returns(false);
var result = await sutProvider.Sut.CreateAsync(providerId, requestBody);
Assert.IsType<UnauthorizedHttpResult>(result);
}
[Theory, BitAutoData]
public async Task CreateAsync_NoProvider_NotFound(
Guid providerId,
CreateClientOrganizationRequestBody requestBody,
SutProvider<ProviderClientsController> sutProvider)
{
sutProvider.GetDependency<IFeatureService>().IsEnabled(FeatureFlagKeys.EnableConsolidatedBilling)
.Returns(true);
sutProvider.GetDependency<IUserService>().GetUserByPrincipalAsync(Arg.Any<ClaimsPrincipal>()).Returns(new User());
sutProvider.GetDependency<ICurrentContext>().ManageProviderOrganizations(providerId)
.Returns(true);
sutProvider.GetDependency<IProviderRepository>().GetByIdAsync(providerId)
.ReturnsNull();
var result = await sutProvider.Sut.CreateAsync(providerId, requestBody);
Assert.IsType<NotFound>(result);
}
[Theory, BitAutoData]
public async Task CreateAsync_MissingClientOrganization_ServerError(
Guid providerId,
CreateClientOrganizationRequestBody requestBody,
SutProvider<ProviderClientsController> sutProvider)
{
sutProvider.GetDependency<IFeatureService>().IsEnabled(FeatureFlagKeys.EnableConsolidatedBilling)
.Returns(true);
var user = new User();
sutProvider.GetDependency<IUserService>().GetUserByPrincipalAsync(Arg.Any<ClaimsPrincipal>()).Returns(user);
sutProvider.GetDependency<ICurrentContext>().ManageProviderOrganizations(providerId)
.Returns(true);
sutProvider.GetDependency<IProviderRepository>().GetByIdAsync(providerId)
.Returns(new Provider());
var clientOrganizationId = Guid.NewGuid();
sutProvider.GetDependency<IProviderService>().CreateOrganizationAsync(
providerId,
Arg.Any<OrganizationSignup>(),
requestBody.OwnerEmail,
user)
.Returns(new ProviderOrganization
{
OrganizationId = clientOrganizationId
});
sutProvider.GetDependency<IOrganizationRepository>().GetByIdAsync(clientOrganizationId).ReturnsNull();
var result = await sutProvider.Sut.CreateAsync(providerId, requestBody);
Assert.IsType<ProblemHttpResult>(result);
}
[Theory, BitAutoData]
public async Task CreateAsync_OK(
Guid providerId,
CreateClientOrganizationRequestBody requestBody,
SutProvider<ProviderClientsController> sutProvider)
{
sutProvider.GetDependency<IFeatureService>().IsEnabled(FeatureFlagKeys.EnableConsolidatedBilling)
.Returns(true);
var user = new User();
sutProvider.GetDependency<IUserService>().GetUserByPrincipalAsync(Arg.Any<ClaimsPrincipal>())
.Returns(user);
sutProvider.GetDependency<ICurrentContext>().ManageProviderOrganizations(providerId)
.Returns(true);
var provider = new Provider();
sutProvider.GetDependency<IProviderRepository>().GetByIdAsync(providerId)
.Returns(provider);
var clientOrganizationId = Guid.NewGuid();
sutProvider.GetDependency<IProviderService>().CreateOrganizationAsync(
providerId,
Arg.Is<OrganizationSignup>(signup =>
signup.Name == requestBody.Name &&
signup.Plan == requestBody.PlanType &&
signup.AdditionalSeats == requestBody.Seats &&
signup.OwnerKey == requestBody.Key &&
signup.PublicKey == requestBody.KeyPair.PublicKey &&
signup.PrivateKey == requestBody.KeyPair.EncryptedPrivateKey &&
signup.CollectionName == requestBody.CollectionName),
requestBody.OwnerEmail,
user)
.Returns(new ProviderOrganization
{
OrganizationId = clientOrganizationId
});
var clientOrganization = new Organization { Id = clientOrganizationId };
sutProvider.GetDependency<IOrganizationRepository>().GetByIdAsync(clientOrganizationId)
.Returns(clientOrganization);
var result = await sutProvider.Sut.CreateAsync(providerId, requestBody);
Assert.IsType<Ok>(result);
await sutProvider.GetDependency<ICreateCustomerCommand>().Received(1).CreateCustomer(
provider,
clientOrganization);
}
#endregion
#region UpdateAsync
[Theory, BitAutoData]
public async Task UpdateAsync_FFDisabled_NotFound(
Guid providerId,
Guid providerOrganizationId,
UpdateClientOrganizationRequestBody requestBody,
SutProvider<ProviderClientsController> sutProvider)
{
sutProvider.GetDependency<IFeatureService>().IsEnabled(FeatureFlagKeys.EnableConsolidatedBilling)
.Returns(false);
var result = await sutProvider.Sut.UpdateAsync(providerId, providerOrganizationId, requestBody);
Assert.IsType<NotFound>(result);
}
[Theory, BitAutoData]
public async Task UpdateAsync_NotProviderAdmin_Unauthorized(
Guid providerId,
Guid providerOrganizationId,
UpdateClientOrganizationRequestBody requestBody,
SutProvider<ProviderClientsController> sutProvider)
{
sutProvider.GetDependency<IFeatureService>().IsEnabled(FeatureFlagKeys.EnableConsolidatedBilling)
.Returns(true);
sutProvider.GetDependency<ICurrentContext>().ProviderProviderAdmin(providerId)
.Returns(false);
var result = await sutProvider.Sut.UpdateAsync(providerId, providerOrganizationId, requestBody);
Assert.IsType<UnauthorizedHttpResult>(result);
}
[Theory, BitAutoData]
public async Task UpdateAsync_NoProvider_NotFound(
Guid providerId,
Guid providerOrganizationId,
UpdateClientOrganizationRequestBody requestBody,
SutProvider<ProviderClientsController> sutProvider)
{
sutProvider.GetDependency<IFeatureService>().IsEnabled(FeatureFlagKeys.EnableConsolidatedBilling)
.Returns(true);
sutProvider.GetDependency<ICurrentContext>().ProviderProviderAdmin(providerId)
.Returns(true);
sutProvider.GetDependency<IProviderRepository>().GetByIdAsync(providerId)
.ReturnsNull();
var result = await sutProvider.Sut.UpdateAsync(providerId, providerOrganizationId, requestBody);
Assert.IsType<NotFound>(result);
}
[Theory, BitAutoData]
public async Task UpdateAsync_NoProviderOrganization_NotFound(
Guid providerId,
Guid providerOrganizationId,
UpdateClientOrganizationRequestBody requestBody,
Provider provider,
SutProvider<ProviderClientsController> sutProvider)
{
sutProvider.GetDependency<IFeatureService>().IsEnabled(FeatureFlagKeys.EnableConsolidatedBilling)
.Returns(true);
sutProvider.GetDependency<ICurrentContext>().ProviderProviderAdmin(providerId)
.Returns(true);
sutProvider.GetDependency<IProviderRepository>().GetByIdAsync(providerId)
.Returns(provider);
sutProvider.GetDependency<IProviderOrganizationRepository>().GetByIdAsync(providerOrganizationId)
.ReturnsNull();
var result = await sutProvider.Sut.UpdateAsync(providerId, providerOrganizationId, requestBody);
Assert.IsType<NotFound>(result);
}
[Theory, BitAutoData]
public async Task UpdateAsync_NoOrganization_ServerError(
Guid providerId,
Guid providerOrganizationId,
UpdateClientOrganizationRequestBody requestBody,
Provider provider,
ProviderOrganization providerOrganization,
SutProvider<ProviderClientsController> sutProvider)
{
sutProvider.GetDependency<IFeatureService>().IsEnabled(FeatureFlagKeys.EnableConsolidatedBilling)
.Returns(true);
sutProvider.GetDependency<ICurrentContext>().ProviderProviderAdmin(providerId)
.Returns(true);
sutProvider.GetDependency<IProviderRepository>().GetByIdAsync(providerId)
.Returns(provider);
sutProvider.GetDependency<IProviderOrganizationRepository>().GetByIdAsync(providerOrganizationId)
.Returns(providerOrganization);
sutProvider.GetDependency<IOrganizationRepository>().GetByIdAsync(providerOrganization.OrganizationId)
.ReturnsNull();
var result = await sutProvider.Sut.UpdateAsync(providerId, providerOrganizationId, requestBody);
Assert.IsType<ProblemHttpResult>(result);
}
[Theory, BitAutoData]
public async Task UpdateAsync_NoContent(
Guid providerId,
Guid providerOrganizationId,
UpdateClientOrganizationRequestBody requestBody,
Provider provider,
ProviderOrganization providerOrganization,
Organization organization,
SutProvider<ProviderClientsController> sutProvider)
{
sutProvider.GetDependency<IFeatureService>().IsEnabled(FeatureFlagKeys.EnableConsolidatedBilling)
.Returns(true);
sutProvider.GetDependency<ICurrentContext>().ProviderProviderAdmin(providerId)
.Returns(true);
sutProvider.GetDependency<IProviderRepository>().GetByIdAsync(providerId)
.Returns(provider);
sutProvider.GetDependency<IProviderOrganizationRepository>().GetByIdAsync(providerOrganizationId)
.Returns(providerOrganization);
sutProvider.GetDependency<IOrganizationRepository>().GetByIdAsync(providerOrganization.OrganizationId)
.Returns(organization);
var result = await sutProvider.Sut.UpdateAsync(providerId, providerOrganizationId, requestBody);
await sutProvider.GetDependency<IAssignSeatsToClientOrganizationCommand>().Received(1)
.AssignSeatsToClientOrganization(
provider,
organization,
requestBody.AssignedSeats);
Assert.IsType<Ok>(result);
}
#endregion
}

View File

@ -1,168 +0,0 @@
using Bit.Api.Billing.Controllers;
using Bit.Api.Billing.Models;
using Bit.Core;
using Bit.Core.AdminConsole.Entities;
using Bit.Core.AdminConsole.Repositories;
using Bit.Core.Billing.Commands;
using Bit.Core.Context;
using Bit.Core.Repositories;
using Bit.Core.Services;
using Bit.Infrastructure.EntityFramework.AdminConsole.Models.Provider;
using Bit.Test.Common.AutoFixture;
using Bit.Test.Common.AutoFixture.Attributes;
using Microsoft.AspNetCore.Http.HttpResults;
using NSubstitute;
using NSubstitute.ReturnsExtensions;
using Xunit;
using ProviderOrganization = Bit.Core.AdminConsole.Entities.Provider.ProviderOrganization;
namespace Bit.Api.Test.Billing.Controllers;
[ControllerCustomize(typeof(ProviderOrganizationController))]
[SutProviderCustomize]
public class ProviderOrganizationControllerTests
{
[Theory, BitAutoData]
public async Task UpdateAsync_FFDisabled_NotFound(
Guid providerId,
Guid providerOrganizationId,
UpdateProviderOrganizationRequestBody requestBody,
SutProvider<ProviderOrganizationController> sutProvider)
{
sutProvider.GetDependency<IFeatureService>().IsEnabled(FeatureFlagKeys.EnableConsolidatedBilling)
.Returns(false);
var result = await sutProvider.Sut.UpdateAsync(providerId, providerOrganizationId, requestBody);
Assert.IsType<NotFound>(result);
}
[Theory, BitAutoData]
public async Task GetSubscriptionAsync_NotProviderAdmin_Unauthorized(
Guid providerId,
Guid providerOrganizationId,
UpdateProviderOrganizationRequestBody requestBody,
SutProvider<ProviderOrganizationController> sutProvider)
{
sutProvider.GetDependency<IFeatureService>().IsEnabled(FeatureFlagKeys.EnableConsolidatedBilling)
.Returns(true);
sutProvider.GetDependency<ICurrentContext>().ProviderProviderAdmin(providerId)
.Returns(false);
var result = await sutProvider.Sut.UpdateAsync(providerId, providerOrganizationId, requestBody);
Assert.IsType<UnauthorizedHttpResult>(result);
}
[Theory, BitAutoData]
public async Task GetSubscriptionAsync_NoProvider_NotFound(
Guid providerId,
Guid providerOrganizationId,
UpdateProviderOrganizationRequestBody requestBody,
SutProvider<ProviderOrganizationController> sutProvider)
{
sutProvider.GetDependency<IFeatureService>().IsEnabled(FeatureFlagKeys.EnableConsolidatedBilling)
.Returns(true);
sutProvider.GetDependency<ICurrentContext>().ProviderProviderAdmin(providerId)
.Returns(true);
sutProvider.GetDependency<IProviderRepository>().GetByIdAsync(providerId)
.ReturnsNull();
var result = await sutProvider.Sut.UpdateAsync(providerId, providerOrganizationId, requestBody);
Assert.IsType<NotFound>(result);
}
[Theory, BitAutoData]
public async Task GetSubscriptionAsync_NoProviderOrganization_NotFound(
Guid providerId,
Guid providerOrganizationId,
UpdateProviderOrganizationRequestBody requestBody,
Provider provider,
SutProvider<ProviderOrganizationController> sutProvider)
{
sutProvider.GetDependency<IFeatureService>().IsEnabled(FeatureFlagKeys.EnableConsolidatedBilling)
.Returns(true);
sutProvider.GetDependency<ICurrentContext>().ProviderProviderAdmin(providerId)
.Returns(true);
sutProvider.GetDependency<IProviderRepository>().GetByIdAsync(providerId)
.Returns(provider);
sutProvider.GetDependency<IProviderOrganizationRepository>().GetByIdAsync(providerOrganizationId)
.ReturnsNull();
var result = await sutProvider.Sut.UpdateAsync(providerId, providerOrganizationId, requestBody);
Assert.IsType<NotFound>(result);
}
[Theory, BitAutoData]
public async Task GetSubscriptionAsync_NoOrganization_ServerError(
Guid providerId,
Guid providerOrganizationId,
UpdateProviderOrganizationRequestBody requestBody,
Provider provider,
ProviderOrganization providerOrganization,
SutProvider<ProviderOrganizationController> sutProvider)
{
sutProvider.GetDependency<IFeatureService>().IsEnabled(FeatureFlagKeys.EnableConsolidatedBilling)
.Returns(true);
sutProvider.GetDependency<ICurrentContext>().ProviderProviderAdmin(providerId)
.Returns(true);
sutProvider.GetDependency<IProviderRepository>().GetByIdAsync(providerId)
.Returns(provider);
sutProvider.GetDependency<IProviderOrganizationRepository>().GetByIdAsync(providerOrganizationId)
.Returns(providerOrganization);
sutProvider.GetDependency<IOrganizationRepository>().GetByIdAsync(providerOrganization.OrganizationId)
.ReturnsNull();
var result = await sutProvider.Sut.UpdateAsync(providerId, providerOrganizationId, requestBody);
Assert.IsType<ProblemHttpResult>(result);
}
[Theory, BitAutoData]
public async Task GetSubscriptionAsync_NoContent(
Guid providerId,
Guid providerOrganizationId,
UpdateProviderOrganizationRequestBody requestBody,
Provider provider,
ProviderOrganization providerOrganization,
Organization organization,
SutProvider<ProviderOrganizationController> sutProvider)
{
sutProvider.GetDependency<IFeatureService>().IsEnabled(FeatureFlagKeys.EnableConsolidatedBilling)
.Returns(true);
sutProvider.GetDependency<ICurrentContext>().ProviderProviderAdmin(providerId)
.Returns(true);
sutProvider.GetDependency<IProviderRepository>().GetByIdAsync(providerId)
.Returns(provider);
sutProvider.GetDependency<IProviderOrganizationRepository>().GetByIdAsync(providerOrganizationId)
.Returns(providerOrganization);
sutProvider.GetDependency<IOrganizationRepository>().GetByIdAsync(providerOrganization.OrganizationId)
.Returns(organization);
var result = await sutProvider.Sut.UpdateAsync(providerId, providerOrganizationId, requestBody);
await sutProvider.GetDependency<IAssignSeatsToClientOrganizationCommand>().Received(1)
.AssignSeatsToClientOrganization(
provider,
organization,
requestBody.AssignedSeats);
Assert.IsType<Ok>(result);
}
}

View File

@ -0,0 +1,63 @@
using Bit.Api.Utilities;
using Bit.Core.Enums;
using Xunit;
namespace Bit.Api.Test.Utilities;
public class EnumMatchesAttributeTests
{
[Fact]
public void IsValid_NullInput_False()
{
var enumMatchesAttribute =
new EnumMatchesAttribute<PlanType>(PlanType.TeamsMonthly, PlanType.EnterpriseMonthly);
var result = enumMatchesAttribute.IsValid(null);
Assert.False(result);
}
[Fact]
public void IsValid_NullAccepted_False()
{
var enumMatchesAttribute =
new EnumMatchesAttribute<PlanType>();
var result = enumMatchesAttribute.IsValid(PlanType.TeamsMonthly);
Assert.False(result);
}
[Fact]
public void IsValid_EmptyAccepted_False()
{
var enumMatchesAttribute =
new EnumMatchesAttribute<PlanType>([]);
var result = enumMatchesAttribute.IsValid(PlanType.TeamsMonthly);
Assert.False(result);
}
[Fact]
public void IsValid_ParseFails_False()
{
var enumMatchesAttribute =
new EnumMatchesAttribute<PlanType>(PlanType.TeamsMonthly, PlanType.EnterpriseMonthly);
var result = enumMatchesAttribute.IsValid(GatewayType.Stripe);
Assert.False(result);
}
[Fact]
public void IsValid_Matches_True()
{
var enumMatchesAttribute =
new EnumMatchesAttribute<PlanType>(PlanType.TeamsMonthly, PlanType.EnterpriseMonthly);
var result = enumMatchesAttribute.IsValid(PlanType.TeamsMonthly);
Assert.True(result);
}
}

View File

@ -447,6 +447,47 @@ public class OrganizationServiceTests
Assert.Contains("You can't subtract Machine Accounts!", exception.Message);
}
[Theory, BitAutoData]
public async Task SignupClientAsync_Succeeds(
OrganizationSignup signup,
SutProvider<OrganizationService> sutProvider)
{
sutProvider.GetDependency<IFeatureService>().IsEnabled(FeatureFlagKeys.EnableConsolidatedBilling).Returns(true);
signup.Plan = PlanType.TeamsMonthly;
var (organization, _, _) = await sutProvider.Sut.SignupClientAsync(signup);
var plan = StaticStore.GetPlan(signup.Plan);
await sutProvider.GetDependency<IOrganizationRepository>().Received(1).CreateAsync(Arg.Is<Organization>(org =>
org.Id == organization.Id &&
org.Name == signup.Name &&
org.Plan == plan.Name &&
org.PlanType == plan.Type &&
org.UsePolicies == plan.HasPolicies &&
org.PublicKey == signup.PublicKey &&
org.PrivateKey == signup.PrivateKey &&
org.UseSecretsManager == false));
await sutProvider.GetDependency<IOrganizationApiKeyRepository>().Received(1)
.CreateAsync(Arg.Is<OrganizationApiKey>(orgApiKey =>
orgApiKey.OrganizationId == organization.Id));
await sutProvider.GetDependency<IApplicationCacheService>().Received(1)
.UpsertOrganizationAbilityAsync(organization);
await sutProvider.GetDependency<IOrganizationUserRepository>().DidNotReceiveWithAnyArgs().CreateAsync(default);
await sutProvider.GetDependency<ICollectionRepository>().Received(1)
.CreateAsync(Arg.Is<Collection>(c => c.Name == signup.CollectionName && c.OrganizationId == organization.Id), null, null);
await sutProvider.GetDependency<IReferenceEventService>().Received(1).RaiseEventAsync(Arg.Is<ReferenceEvent>(
re =>
re.Type == ReferenceEventType.Signup &&
re.PlanType == plan.Type));
}
[Theory]
[OrganizationInviteCustomize(InviteeUserType = OrganizationUserType.User,
InvitorUserType = OrganizationUserType.Owner), OrganizationCustomize(FlexibleCollections = false), BitAutoData]

View File

@ -0,0 +1,129 @@
using Bit.Core.AdminConsole.Entities;
using Bit.Core.AdminConsole.Entities.Provider;
using Bit.Core.Billing.Commands.Implementations;
using Bit.Core.Billing.Queries;
using Bit.Core.Entities;
using Bit.Core.Repositories;
using Bit.Core.Services;
using Bit.Core.Settings;
using Bit.Test.Common.AutoFixture;
using Bit.Test.Common.AutoFixture.Attributes;
using NSubstitute;
using Stripe;
using Xunit;
using GlobalSettings = Bit.Core.Settings.GlobalSettings;
namespace Bit.Core.Test.Billing.Commands;
[SutProviderCustomize]
public class CreateCustomerCommandTests
{
private const string _customerId = "customer_id";
[Theory, BitAutoData]
public async Task CreateCustomer_ForClientOrg_ProviderNull_ThrowsArgumentNullException(
Organization organization,
SutProvider<CreateCustomerCommand> sutProvider) =>
await Assert.ThrowsAsync<ArgumentNullException>(() => sutProvider.Sut.CreateCustomer(null, organization));
[Theory, BitAutoData]
public async Task CreateCustomer_ForClientOrg_OrganizationNull_ThrowsArgumentNullException(
Provider provider,
SutProvider<CreateCustomerCommand> sutProvider) =>
await Assert.ThrowsAsync<ArgumentNullException>(() => sutProvider.Sut.CreateCustomer(provider, null));
[Theory, BitAutoData]
public async Task CreateCustomer_ForClientOrg_HasGatewayCustomerId_NoOp(
Provider provider,
Organization organization,
SutProvider<CreateCustomerCommand> sutProvider)
{
organization.GatewayCustomerId = _customerId;
await sutProvider.Sut.CreateCustomer(provider, organization);
await sutProvider.GetDependency<ISubscriberQueries>().DidNotReceiveWithAnyArgs()
.GetCustomerOrThrow(Arg.Any<ISubscriber>(), Arg.Any<CustomerGetOptions>());
}
[Theory, BitAutoData]
public async Task CreateCustomer_ForClientOrg_Succeeds(
Provider provider,
Organization organization,
SutProvider<CreateCustomerCommand> sutProvider)
{
organization.GatewayCustomerId = null;
organization.Name = "Name";
organization.BusinessName = "BusinessName";
var providerCustomer = new Customer
{
Address = new Address
{
Country = "USA",
PostalCode = "12345",
Line1 = "123 Main St.",
Line2 = "Unit 4",
City = "Fake Town",
State = "Fake State"
},
TaxIds = new StripeList<TaxId>
{
Data =
[
new TaxId { Type = "TYPE", Value = "VALUE" }
]
}
};
sutProvider.GetDependency<ISubscriberQueries>().GetCustomerOrThrow(provider, Arg.Is<CustomerGetOptions>(
options => options.Expand.FirstOrDefault() == "tax_ids"))
.Returns(providerCustomer);
sutProvider.GetDependency<IGlobalSettings>().BaseServiceUri
.Returns(new GlobalSettings.BaseServiceUriSettings(new GlobalSettings()) { CloudRegion = "US" });
sutProvider.GetDependency<IStripeAdapter>().CustomerCreateAsync(Arg.Is<CustomerCreateOptions>(
options =>
options.Address.Country == providerCustomer.Address.Country &&
options.Address.PostalCode == providerCustomer.Address.PostalCode &&
options.Address.Line1 == providerCustomer.Address.Line1 &&
options.Address.Line2 == providerCustomer.Address.Line2 &&
options.Address.City == providerCustomer.Address.City &&
options.Address.State == providerCustomer.Address.State &&
options.Name == organization.DisplayName() &&
options.Description == $"{provider.Name} Client Organization" &&
options.Email == provider.BillingEmail &&
options.InvoiceSettings.CustomFields.FirstOrDefault().Name == "Organization" &&
options.InvoiceSettings.CustomFields.FirstOrDefault().Value == "Name" &&
options.Metadata["region"] == "US" &&
options.TaxIdData.FirstOrDefault().Type == providerCustomer.TaxIds.FirstOrDefault().Type &&
options.TaxIdData.FirstOrDefault().Value == providerCustomer.TaxIds.FirstOrDefault().Value))
.Returns(new Customer
{
Id = "customer_id"
});
await sutProvider.Sut.CreateCustomer(provider, organization);
await sutProvider.GetDependency<IStripeAdapter>().Received(1).CustomerCreateAsync(Arg.Is<CustomerCreateOptions>(
options =>
options.Address.Country == providerCustomer.Address.Country &&
options.Address.PostalCode == providerCustomer.Address.PostalCode &&
options.Address.Line1 == providerCustomer.Address.Line1 &&
options.Address.Line2 == providerCustomer.Address.Line2 &&
options.Address.City == providerCustomer.Address.City &&
options.Address.State == providerCustomer.Address.State &&
options.Name == organization.DisplayName() &&
options.Description == $"{provider.Name} Client Organization" &&
options.Email == provider.BillingEmail &&
options.InvoiceSettings.CustomFields.FirstOrDefault().Name == "Organization" &&
options.InvoiceSettings.CustomFields.FirstOrDefault().Value == "Name" &&
options.Metadata["region"] == "US" &&
options.TaxIdData.FirstOrDefault().Type == providerCustomer.TaxIds.FirstOrDefault().Type &&
options.TaxIdData.FirstOrDefault().Value == providerCustomer.TaxIds.FirstOrDefault().Value));
await sutProvider.GetDependency<IOrganizationRepository>().Received(1).ReplaceAsync(Arg.Is<Organization>(
org => org.GatewayCustomerId == "customer_id"));
}
}

View File

@ -29,7 +29,7 @@ public class ProviderBillingQueriesTests
providerRepository.GetByIdAsync(providerId).ReturnsNull();
var subscriptionData = await sutProvider.Sut.GetSubscriptionData(providerId);
var subscriptionData = await sutProvider.Sut.GetSubscriptionDTO(providerId);
Assert.Null(subscriptionData);
@ -50,7 +50,7 @@ public class ProviderBillingQueriesTests
subscriberQueries.GetSubscription(provider).ReturnsNull();
var subscriptionData = await sutProvider.Sut.GetSubscriptionData(providerId);
var subscriptionData = await sutProvider.Sut.GetSubscriptionDTO(providerId);
Assert.Null(subscriptionData);
@ -109,7 +109,7 @@ public class ProviderBillingQueriesTests
providerPlanRepository.GetByProviderId(providerId).Returns(providerPlans);
var subscriptionData = await sutProvider.Sut.GetSubscriptionData(providerId);
var subscriptionData = await sutProvider.Sut.GetSubscriptionDTO(providerId);
Assert.NotNull(subscriptionData);
@ -140,7 +140,7 @@ public class ProviderBillingQueriesTests
return;
void Compare(ProviderPlan providerPlan, ConfiguredProviderPlan configuredProviderPlan)
void Compare(ProviderPlan providerPlan, ConfiguredProviderPlanDTO configuredProviderPlan)
{
Assert.NotNull(configuredProviderPlan);
Assert.Equal(providerPlan.Id, configuredProviderPlan.Id);

View File

@ -1,7 +1,5 @@
using Bit.Core.AdminConsole.Entities;
using Bit.Core.AdminConsole.Entities.Provider;
using Bit.Core.Billing.Queries.Implementations;
using Bit.Core.Entities;
using Bit.Core.Services;
using Bit.Test.Common.AutoFixture;
using Bit.Test.Common.AutoFixture.Attributes;
@ -17,6 +15,56 @@ namespace Bit.Core.Test.Billing.Queries;
[SutProviderCustomize]
public class SubscriberQueriesTests
{
#region GetCustomer
[Theory, BitAutoData]
public async Task GetCustomer_NullSubscriber_ThrowsArgumentNullException(
SutProvider<SubscriberQueries> sutProvider)
=> await Assert.ThrowsAsync<ArgumentNullException>(
async () => await sutProvider.Sut.GetCustomer(null));
[Theory, BitAutoData]
public async Task GetCustomer_NoGatewayCustomerId_ReturnsNull(
Organization organization,
SutProvider<SubscriberQueries> sutProvider)
{
organization.GatewayCustomerId = null;
var customer = await sutProvider.Sut.GetCustomer(organization);
Assert.Null(customer);
}
[Theory, BitAutoData]
public async Task GetCustomer_NoCustomer_ReturnsNull(
Organization organization,
SutProvider<SubscriberQueries> sutProvider)
{
sutProvider.GetDependency<IStripeAdapter>()
.CustomerGetAsync(organization.GatewayCustomerId)
.ReturnsNull();
var customer = await sutProvider.Sut.GetCustomer(organization);
Assert.Null(customer);
}
[Theory, BitAutoData]
public async Task GetCustomer_Succeeds(
Organization organization,
SutProvider<SubscriberQueries> sutProvider)
{
var customer = new Customer();
sutProvider.GetDependency<IStripeAdapter>()
.CustomerGetAsync(organization.GatewayCustomerId)
.Returns(customer);
var gotCustomer = await sutProvider.Sut.GetCustomer(organization);
Assert.Equivalent(customer, gotCustomer);
}
#endregion
#region GetSubscription
[Theory, BitAutoData]
public async Task GetSubscription_NullSubscriber_ThrowsArgumentNullException(
@ -25,123 +73,91 @@ public class SubscriberQueriesTests
async () => await sutProvider.Sut.GetSubscription(null));
[Theory, BitAutoData]
public async Task GetSubscription_Organization_NoGatewaySubscriptionId_ReturnsNull(
public async Task GetSubscription_NoGatewaySubscriptionId_ReturnsNull(
Organization organization,
SutProvider<SubscriberQueries> sutProvider)
{
organization.GatewaySubscriptionId = null;
var gotSubscription = await sutProvider.Sut.GetSubscription(organization);
var subscription = await sutProvider.Sut.GetSubscription(organization);
Assert.Null(gotSubscription);
Assert.Null(subscription);
}
[Theory, BitAutoData]
public async Task GetSubscription_Organization_NoSubscription_ReturnsNull(
public async Task GetSubscription_NoSubscription_ReturnsNull(
Organization organization,
SutProvider<SubscriberQueries> sutProvider)
{
sutProvider.GetDependency<IStripeAdapter>().SubscriptionGetAsync(organization.GatewaySubscriptionId)
sutProvider.GetDependency<IStripeAdapter>()
.SubscriptionGetAsync(organization.GatewaySubscriptionId)
.ReturnsNull();
var gotSubscription = await sutProvider.Sut.GetSubscription(organization);
var subscription = await sutProvider.Sut.GetSubscription(organization);
Assert.Null(gotSubscription);
Assert.Null(subscription);
}
[Theory, BitAutoData]
public async Task GetSubscription_Organization_Succeeds(
public async Task GetSubscription_Succeeds(
Organization organization,
SutProvider<SubscriberQueries> sutProvider)
{
var subscription = new Subscription();
sutProvider.GetDependency<IStripeAdapter>().SubscriptionGetAsync(organization.GatewaySubscriptionId)
sutProvider.GetDependency<IStripeAdapter>()
.SubscriptionGetAsync(organization.GatewaySubscriptionId)
.Returns(subscription);
var gotSubscription = await sutProvider.Sut.GetSubscription(organization);
Assert.Equivalent(subscription, gotSubscription);
}
#endregion
#region GetCustomerOrThrow
[Theory, BitAutoData]
public async Task GetCustomerOrThrow_NullSubscriber_ThrowsArgumentNullException(
SutProvider<SubscriberQueries> sutProvider)
=> await Assert.ThrowsAsync<ArgumentNullException>(
async () => await sutProvider.Sut.GetCustomerOrThrow(null));
[Theory, BitAutoData]
public async Task GetSubscription_User_NoGatewaySubscriptionId_ReturnsNull(
User user,
public async Task GetCustomerOrThrow_NoGatewaySubscriptionId_ThrowsGatewayException(
Organization organization,
SutProvider<SubscriberQueries> sutProvider)
{
user.GatewaySubscriptionId = null;
organization.GatewayCustomerId = null;
var gotSubscription = await sutProvider.Sut.GetSubscription(user);
Assert.Null(gotSubscription);
await ThrowsContactSupportAsync(async () => await sutProvider.Sut.GetCustomerOrThrow(organization));
}
[Theory, BitAutoData]
public async Task GetSubscription_User_NoSubscription_ReturnsNull(
User user,
public async Task GetSubscriptionOrThrow_NoCustomer_ThrowsGatewayException(
Organization organization,
SutProvider<SubscriberQueries> sutProvider)
{
sutProvider.GetDependency<IStripeAdapter>().SubscriptionGetAsync(user.GatewaySubscriptionId)
sutProvider.GetDependency<IStripeAdapter>()
.CustomerGetAsync(organization.GatewayCustomerId)
.ReturnsNull();
var gotSubscription = await sutProvider.Sut.GetSubscription(user);
Assert.Null(gotSubscription);
await ThrowsContactSupportAsync(async () => await sutProvider.Sut.GetCustomerOrThrow(organization));
}
[Theory, BitAutoData]
public async Task GetSubscription_User_Succeeds(
User user,
public async Task GetCustomerOrThrow_Succeeds(
Organization organization,
SutProvider<SubscriberQueries> sutProvider)
{
var subscription = new Subscription();
var customer = new Customer();
sutProvider.GetDependency<IStripeAdapter>().SubscriptionGetAsync(user.GatewaySubscriptionId)
.Returns(subscription);
sutProvider.GetDependency<IStripeAdapter>()
.CustomerGetAsync(organization.GatewayCustomerId)
.Returns(customer);
var gotSubscription = await sutProvider.Sut.GetSubscription(user);
var gotCustomer = await sutProvider.Sut.GetCustomerOrThrow(organization);
Assert.Equivalent(subscription, gotSubscription);
}
[Theory, BitAutoData]
public async Task GetSubscription_Provider_NoGatewaySubscriptionId_ReturnsNull(
Provider provider,
SutProvider<SubscriberQueries> sutProvider)
{
provider.GatewaySubscriptionId = null;
var gotSubscription = await sutProvider.Sut.GetSubscription(provider);
Assert.Null(gotSubscription);
}
[Theory, BitAutoData]
public async Task GetSubscription_Provider_NoSubscription_ReturnsNull(
Provider provider,
SutProvider<SubscriberQueries> sutProvider)
{
sutProvider.GetDependency<IStripeAdapter>().SubscriptionGetAsync(provider.GatewaySubscriptionId)
.ReturnsNull();
var gotSubscription = await sutProvider.Sut.GetSubscription(provider);
Assert.Null(gotSubscription);
}
[Theory, BitAutoData]
public async Task GetSubscription_Provider_Succeeds(
Provider provider,
SutProvider<SubscriberQueries> sutProvider)
{
var subscription = new Subscription();
sutProvider.GetDependency<IStripeAdapter>().SubscriptionGetAsync(provider.GatewaySubscriptionId)
.Returns(subscription);
var gotSubscription = await sutProvider.Sut.GetSubscription(provider);
Assert.Equivalent(subscription, gotSubscription);
Assert.Equivalent(customer, gotCustomer);
}
#endregion
@ -153,7 +169,7 @@ public class SubscriberQueriesTests
async () => await sutProvider.Sut.GetSubscriptionOrThrow(null));
[Theory, BitAutoData]
public async Task GetSubscriptionOrThrow_Organization_NoGatewaySubscriptionId_ThrowsGatewayException(
public async Task GetSubscriptionOrThrow_NoGatewaySubscriptionId_ThrowsGatewayException(
Organization organization,
SutProvider<SubscriberQueries> sutProvider)
{
@ -163,101 +179,31 @@ public class SubscriberQueriesTests
}
[Theory, BitAutoData]
public async Task GetSubscriptionOrThrow_Organization_NoSubscription_ThrowsGatewayException(
public async Task GetSubscriptionOrThrow_NoSubscription_ThrowsGatewayException(
Organization organization,
SutProvider<SubscriberQueries> sutProvider)
{
sutProvider.GetDependency<IStripeAdapter>().SubscriptionGetAsync(organization.GatewaySubscriptionId)
sutProvider.GetDependency<IStripeAdapter>()
.SubscriptionGetAsync(organization.GatewaySubscriptionId)
.ReturnsNull();
await ThrowsContactSupportAsync(async () => await sutProvider.Sut.GetSubscriptionOrThrow(organization));
}
[Theory, BitAutoData]
public async Task GetSubscriptionOrThrow_Organization_Succeeds(
public async Task GetSubscriptionOrThrow_Succeeds(
Organization organization,
SutProvider<SubscriberQueries> sutProvider)
{
var subscription = new Subscription();
sutProvider.GetDependency<IStripeAdapter>().SubscriptionGetAsync(organization.GatewaySubscriptionId)
sutProvider.GetDependency<IStripeAdapter>()
.SubscriptionGetAsync(organization.GatewaySubscriptionId)
.Returns(subscription);
var gotSubscription = await sutProvider.Sut.GetSubscriptionOrThrow(organization);
Assert.Equivalent(subscription, gotSubscription);
}
[Theory, BitAutoData]
public async Task GetSubscriptionOrThrow_User_NoGatewaySubscriptionId_ThrowsGatewayException(
User user,
SutProvider<SubscriberQueries> sutProvider)
{
user.GatewaySubscriptionId = null;
await ThrowsContactSupportAsync(async () => await sutProvider.Sut.GetSubscriptionOrThrow(user));
}
[Theory, BitAutoData]
public async Task GetSubscriptionOrThrow_User_NoSubscription_ThrowsGatewayException(
User user,
SutProvider<SubscriberQueries> sutProvider)
{
sutProvider.GetDependency<IStripeAdapter>().SubscriptionGetAsync(user.GatewaySubscriptionId)
.ReturnsNull();
await ThrowsContactSupportAsync(async () => await sutProvider.Sut.GetSubscriptionOrThrow(user));
}
[Theory, BitAutoData]
public async Task GetSubscriptionOrThrow_User_Succeeds(
User user,
SutProvider<SubscriberQueries> sutProvider)
{
var subscription = new Subscription();
sutProvider.GetDependency<IStripeAdapter>().SubscriptionGetAsync(user.GatewaySubscriptionId)
.Returns(subscription);
var gotSubscription = await sutProvider.Sut.GetSubscriptionOrThrow(user);
Assert.Equivalent(subscription, gotSubscription);
}
[Theory, BitAutoData]
public async Task GetSubscriptionOrThrow_Provider_NoGatewaySubscriptionId_ThrowsGatewayException(
Provider provider,
SutProvider<SubscriberQueries> sutProvider)
{
provider.GatewaySubscriptionId = null;
await ThrowsContactSupportAsync(async () => await sutProvider.Sut.GetSubscriptionOrThrow(provider));
}
[Theory, BitAutoData]
public async Task GetSubscriptionOrThrow_Provider_NoSubscription_ThrowsGatewayException(
Provider provider,
SutProvider<SubscriberQueries> sutProvider)
{
sutProvider.GetDependency<IStripeAdapter>().SubscriptionGetAsync(provider.GatewaySubscriptionId)
.ReturnsNull();
await ThrowsContactSupportAsync(async () => await sutProvider.Sut.GetSubscriptionOrThrow(provider));
}
[Theory, BitAutoData]
public async Task GetSubscriptionOrThrow_Provider_Succeeds(
Provider provider,
SutProvider<SubscriberQueries> sutProvider)
{
var subscription = new Subscription();
sutProvider.GetDependency<IStripeAdapter>().SubscriptionGetAsync(provider.GatewaySubscriptionId)
.Returns(subscription);
var gotSubscription = await sutProvider.Sut.GetSubscriptionOrThrow(provider);
Assert.Equivalent(subscription, gotSubscription);
}
#endregion
}