diff --git a/bitwarden_license/src/Commercial.Core/AdminConsole/Services/ProviderService.cs b/bitwarden_license/src/Commercial.Core/AdminConsole/Services/ProviderService.cs index 23e8cee4b..7b14f3ed3 100644 --- a/bitwarden_license/src/Commercial.Core/AdminConsole/Services/ProviderService.cs +++ b/bitwarden_license/src/Commercial.Core/AdminConsole/Services/ProviderService.cs @@ -38,13 +38,14 @@ public class ProviderService : IProviderService private readonly IOrganizationService _organizationService; private readonly ICurrentContext _currentContext; private readonly IStripeAdapter _stripeAdapter; + private readonly IFeatureService _featureService; public ProviderService(IProviderRepository providerRepository, IProviderUserRepository providerUserRepository, IProviderOrganizationRepository providerOrganizationRepository, IUserRepository userRepository, IUserService userService, IOrganizationService organizationService, IMailService mailService, IDataProtectionProvider dataProtectionProvider, IEventService eventService, IOrganizationRepository organizationRepository, GlobalSettings globalSettings, - ICurrentContext currentContext, IStripeAdapter stripeAdapter) + ICurrentContext currentContext, IStripeAdapter stripeAdapter, IFeatureService featureService) { _providerRepository = providerRepository; _providerUserRepository = providerUserRepository; @@ -59,6 +60,7 @@ public class ProviderService : IProviderService _dataProtector = dataProtectionProvider.CreateProtector("ProviderServiceDataProtector"); _currentContext = currentContext; _stripeAdapter = stripeAdapter; + _featureService = featureService; } public async Task CompleteSetupAsync(Provider provider, Guid ownerUserId, string token, string key) @@ -360,6 +362,7 @@ public class ProviderService : IProviderService } var organization = await _organizationRepository.GetByIdAsync(organizationId); + ThrowOnInvalidPlanType(organization.PlanType); if (organization.UseSecretsManager) @@ -507,9 +510,13 @@ public class ProviderService : IProviderService public async Task CreateOrganizationAsync(Guid providerId, OrganizationSignup organizationSignup, string clientOwnerEmail, User user) { - ThrowOnInvalidPlanType(organizationSignup.Plan); + var consolidatedBillingEnabled = _featureService.IsEnabled(FeatureFlagKeys.EnableConsolidatedBilling); - var (organization, _, defaultCollection) = await _organizationService.SignUpAsync(organizationSignup, true); + ThrowOnInvalidPlanType(organizationSignup.Plan, consolidatedBillingEnabled); + + var (organization, _, defaultCollection) = consolidatedBillingEnabled + ? await _organizationService.SignupClientAsync(organizationSignup) + : await _organizationService.SignUpAsync(organizationSignup, true); var providerOrganization = new ProviderOrganization { @@ -611,8 +618,13 @@ public class ProviderService : IProviderService return confirmedOwnersIds.Except(providerUserIds).Any(); } - private void ThrowOnInvalidPlanType(PlanType requestedType) + private void ThrowOnInvalidPlanType(PlanType requestedType, bool consolidatedBillingEnabled = false) { + if (consolidatedBillingEnabled && requestedType is not (PlanType.TeamsMonthly or PlanType.EnterpriseMonthly)) + { + throw new BadRequestException($"Providers cannot manage organizations with the plan type {requestedType}. Only Teams (Monthly) and Enterprise (Monthly) are allowed."); + } + if (ProviderDisallowedOrganizationTypes.Contains(requestedType)) { throw new BadRequestException($"Providers cannot manage organizations with the requested plan type ({requestedType}). Only Teams and Enterprise accounts are allowed."); diff --git a/bitwarden_license/test/Commercial.Core.Test/AdminConsole/Services/ProviderServiceTests.cs b/bitwarden_license/test/Commercial.Core.Test/AdminConsole/Services/ProviderServiceTests.cs index 0ab8c588f..22e8760cb 100644 --- a/bitwarden_license/test/Commercial.Core.Test/AdminConsole/Services/ProviderServiceTests.cs +++ b/bitwarden_license/test/Commercial.Core.Test/AdminConsole/Services/ProviderServiceTests.cs @@ -1,5 +1,6 @@ using Bit.Commercial.Core.AdminConsole.Services; using Bit.Commercial.Core.Test.AdminConsole.AutoFixture; +using Bit.Core; using Bit.Core.AdminConsole.Entities; using Bit.Core.AdminConsole.Entities.Provider; using Bit.Core.AdminConsole.Enums.Provider; @@ -638,6 +639,79 @@ public class ProviderServiceTests t.First().Item2 == null)); } + [Theory, OrganizationCustomize(FlexibleCollections = false), BitAutoData] + public async Task CreateOrganizationAsync_ConsolidatedBillingEnabled_InvalidPlanType_ThrowsBadRequestException( + Provider provider, + OrganizationSignup organizationSignup, + Organization organization, + string clientOwnerEmail, + User user, + SutProvider sutProvider) + { + sutProvider.GetDependency().IsEnabled(FeatureFlagKeys.EnableConsolidatedBilling).Returns(true); + + organizationSignup.Plan = PlanType.EnterpriseAnnually; + + sutProvider.GetDependency().GetByIdAsync(provider.Id).Returns(provider); + + var providerOrganizationRepository = sutProvider.GetDependency(); + + sutProvider.GetDependency().SignupClientAsync(organizationSignup) + .Returns((organization, null as OrganizationUser, new Collection())); + + await Assert.ThrowsAsync(() => + sutProvider.Sut.CreateOrganizationAsync(provider.Id, organizationSignup, clientOwnerEmail, user)); + + await providerOrganizationRepository.DidNotReceiveWithAnyArgs().CreateAsync(default); + } + + [Theory, OrganizationCustomize(FlexibleCollections = false), BitAutoData] + public async Task CreateOrganizationAsync_ConsolidatedBillingEnabled_InvokeSignupClientAsync( + Provider provider, + OrganizationSignup organizationSignup, + Organization organization, + string clientOwnerEmail, + User user, + SutProvider sutProvider) + { + sutProvider.GetDependency().IsEnabled(FeatureFlagKeys.EnableConsolidatedBilling).Returns(true); + + organizationSignup.Plan = PlanType.EnterpriseMonthly; + + sutProvider.GetDependency().GetByIdAsync(provider.Id).Returns(provider); + + var providerOrganizationRepository = sutProvider.GetDependency(); + + sutProvider.GetDependency().SignupClientAsync(organizationSignup) + .Returns((organization, null as OrganizationUser, new Collection())); + + var providerOrganization = await sutProvider.Sut.CreateOrganizationAsync(provider.Id, organizationSignup, clientOwnerEmail, user); + + await providerOrganizationRepository.Received(1).CreateAsync(Arg.Is( + po => + po.ProviderId == provider.Id && + po.OrganizationId == organization.Id)); + + await sutProvider.GetDependency() + .Received() + .LogProviderOrganizationEventAsync(providerOrganization, EventType.ProviderOrganization_Created); + + await sutProvider.GetDependency() + .Received() + .InviteUsersAsync( + organization.Id, + user.Id, + Arg.Is>( + t => + t.Count() == 1 && + t.First().Item1.Emails.Count() == 1 && + t.First().Item1.Emails.First() == clientOwnerEmail && + t.First().Item1.Type == OrganizationUserType.Owner && + t.First().Item1.AccessAll && + !t.First().Item1.Collections.Any() && + t.First().Item2 == null)); + } + [Theory, OrganizationCustomize(FlexibleCollections = true), BitAutoData] public async Task CreateOrganizationAsync_WithFlexibleCollections_SetsAccessAllToFalse (Provider provider, OrganizationSignup organizationSignup, Organization organization, string clientOwnerEmail, diff --git a/src/Api/Billing/Controllers/ProviderBillingController.cs b/src/Api/Billing/Controllers/ProviderBillingController.cs index 583a5937e..2f33dd50d 100644 --- a/src/Api/Billing/Controllers/ProviderBillingController.cs +++ b/src/Api/Billing/Controllers/ProviderBillingController.cs @@ -1,4 +1,4 @@ -using Bit.Api.Billing.Models; +using Bit.Api.Billing.Models.Responses; using Bit.Core; using Bit.Core.Billing.Queries; using Bit.Core.Context; @@ -28,17 +28,17 @@ public class ProviderBillingController( return TypedResults.Unauthorized(); } - var subscriptionData = await providerBillingQueries.GetSubscriptionData(providerId); + var providerSubscriptionDTO = await providerBillingQueries.GetSubscriptionDTO(providerId); - if (subscriptionData == null) + if (providerSubscriptionDTO == null) { return TypedResults.NotFound(); } - var (providerPlans, subscription) = subscriptionData; + var (providerPlans, subscription) = providerSubscriptionDTO; - var providerSubscriptionDTO = ProviderSubscriptionDTO.From(providerPlans, subscription); + var providerSubscriptionResponse = ProviderSubscriptionResponse.From(providerPlans, subscription); - return TypedResults.Ok(providerSubscriptionDTO); + return TypedResults.Ok(providerSubscriptionResponse); } } diff --git a/src/Api/Billing/Controllers/ProviderClientsController.cs b/src/Api/Billing/Controllers/ProviderClientsController.cs new file mode 100644 index 000000000..6f7bd809f --- /dev/null +++ b/src/Api/Billing/Controllers/ProviderClientsController.cs @@ -0,0 +1,143 @@ +using Bit.Api.Billing.Models.Requests; +using Bit.Core; +using Bit.Core.AdminConsole.Repositories; +using Bit.Core.AdminConsole.Services; +using Bit.Core.Billing.Commands; +using Bit.Core.Context; +using Bit.Core.Enums; +using Bit.Core.Models.Business; +using Bit.Core.Repositories; +using Bit.Core.Services; +using Microsoft.AspNetCore.Mvc; + +namespace Bit.Api.Billing.Controllers; + +[Route("providers/{providerId:guid}/clients")] +public class ProviderClientsController( + IAssignSeatsToClientOrganizationCommand assignSeatsToClientOrganizationCommand, + ICreateCustomerCommand createCustomerCommand, + ICurrentContext currentContext, + IFeatureService featureService, + ILogger logger, + IOrganizationRepository organizationRepository, + IProviderOrganizationRepository providerOrganizationRepository, + IProviderRepository providerRepository, + IProviderService providerService, + IScaleSeatsCommand scaleSeatsCommand, + IUserService userService) : Controller +{ + [HttpPost] + public async Task CreateAsync( + [FromRoute] Guid providerId, + [FromBody] CreateClientOrganizationRequestBody requestBody) + { + if (!featureService.IsEnabled(FeatureFlagKeys.EnableConsolidatedBilling)) + { + return TypedResults.NotFound(); + } + + var user = await userService.GetUserByPrincipalAsync(User); + + if (user == null) + { + return TypedResults.Unauthorized(); + } + + if (!currentContext.ManageProviderOrganizations(providerId)) + { + return TypedResults.Unauthorized(); + } + + var provider = await providerRepository.GetByIdAsync(providerId); + + if (provider == null) + { + return TypedResults.NotFound(); + } + + var organizationSignup = new OrganizationSignup + { + Name = requestBody.Name, + Plan = requestBody.PlanType, + AdditionalSeats = requestBody.Seats, + Owner = user, + BillingEmail = provider.BillingEmail, + OwnerKey = requestBody.Key, + PublicKey = requestBody.KeyPair.PublicKey, + PrivateKey = requestBody.KeyPair.EncryptedPrivateKey, + CollectionName = requestBody.CollectionName + }; + + var providerOrganization = await providerService.CreateOrganizationAsync( + providerId, + organizationSignup, + requestBody.OwnerEmail, + user); + + var clientOrganization = await organizationRepository.GetByIdAsync(providerOrganization.OrganizationId); + + if (clientOrganization == null) + { + logger.LogError("Newly created client organization ({ID}) could not be found", providerOrganization.OrganizationId); + + return TypedResults.Problem(); + } + + await scaleSeatsCommand.ScalePasswordManagerSeats( + provider, + requestBody.PlanType, + requestBody.Seats); + + await createCustomerCommand.CreateCustomer( + provider, + clientOrganization); + + clientOrganization.Status = OrganizationStatusType.Managed; + + await organizationRepository.ReplaceAsync(clientOrganization); + + return TypedResults.Ok(); + } + + [HttpPut("{providerOrganizationId:guid}")] + public async Task UpdateAsync( + [FromRoute] Guid providerId, + [FromRoute] Guid providerOrganizationId, + [FromBody] UpdateClientOrganizationRequestBody requestBody) + { + if (!featureService.IsEnabled(FeatureFlagKeys.EnableConsolidatedBilling)) + { + return TypedResults.NotFound(); + } + + if (!currentContext.ProviderProviderAdmin(providerId)) + { + return TypedResults.Unauthorized(); + } + + var provider = await providerRepository.GetByIdAsync(providerId); + + var providerOrganization = await providerOrganizationRepository.GetByIdAsync(providerOrganizationId); + + if (provider == null || providerOrganization == null) + { + return TypedResults.NotFound(); + } + + var clientOrganization = await organizationRepository.GetByIdAsync(providerOrganization.OrganizationId); + + if (clientOrganization == null) + { + logger.LogError("The client organization ({OrganizationID}) represented by provider organization ({ProviderOrganizationID}) could not be found.", providerOrganization.OrganizationId, providerOrganization.Id); + + return TypedResults.Problem(); + } + + await assignSeatsToClientOrganizationCommand.AssignSeatsToClientOrganization( + provider, + clientOrganization, + requestBody.AssignedSeats); + + return TypedResults.Ok(); + } +} diff --git a/src/Api/Billing/Controllers/ProviderOrganizationController.cs b/src/Api/Billing/Controllers/ProviderOrganizationController.cs deleted file mode 100644 index a5cc31c79..000000000 --- a/src/Api/Billing/Controllers/ProviderOrganizationController.cs +++ /dev/null @@ -1,63 +0,0 @@ -using Bit.Api.Billing.Models; -using Bit.Core; -using Bit.Core.AdminConsole.Repositories; -using Bit.Core.Billing.Commands; -using Bit.Core.Context; -using Bit.Core.Repositories; -using Bit.Core.Services; -using Microsoft.AspNetCore.Mvc; - -namespace Bit.Api.Billing.Controllers; - -[Route("providers/{providerId:guid}/organizations")] -public class ProviderOrganizationController( - IAssignSeatsToClientOrganizationCommand assignSeatsToClientOrganizationCommand, - ICurrentContext currentContext, - IFeatureService featureService, - ILogger logger, - IOrganizationRepository organizationRepository, - IProviderRepository providerRepository, - IProviderOrganizationRepository providerOrganizationRepository) : Controller -{ - [HttpPut("{providerOrganizationId:guid}")] - public async Task UpdateAsync( - [FromRoute] Guid providerId, - [FromRoute] Guid providerOrganizationId, - [FromBody] UpdateProviderOrganizationRequestBody requestBody) - { - if (!featureService.IsEnabled(FeatureFlagKeys.EnableConsolidatedBilling)) - { - return TypedResults.NotFound(); - } - - if (!currentContext.ProviderProviderAdmin(providerId)) - { - return TypedResults.Unauthorized(); - } - - var provider = await providerRepository.GetByIdAsync(providerId); - - var providerOrganization = await providerOrganizationRepository.GetByIdAsync(providerOrganizationId); - - if (provider == null || providerOrganization == null) - { - return TypedResults.NotFound(); - } - - var organization = await organizationRepository.GetByIdAsync(providerOrganization.OrganizationId); - - if (organization == null) - { - logger.LogError("The organization ({OrganizationID}) represented by provider organization ({ProviderOrganizationID}) could not be found.", providerOrganization.OrganizationId, providerOrganization.Id); - - return TypedResults.Problem(); - } - - await assignSeatsToClientOrganizationCommand.AssignSeatsToClientOrganization( - provider, - organization, - requestBody.AssignedSeats); - - return TypedResults.Ok(); - } -} diff --git a/src/Api/Billing/Models/Requests/CreateClientOrganizationRequestBody.cs b/src/Api/Billing/Models/Requests/CreateClientOrganizationRequestBody.cs new file mode 100644 index 000000000..c27fb4522 --- /dev/null +++ b/src/Api/Billing/Models/Requests/CreateClientOrganizationRequestBody.cs @@ -0,0 +1,29 @@ +using System.ComponentModel.DataAnnotations; +using Bit.Api.Utilities; +using Bit.Core.Enums; + +namespace Bit.Api.Billing.Models.Requests; + +public class CreateClientOrganizationRequestBody +{ + [Required(ErrorMessage = "'name' must be provided")] + public string Name { get; set; } + + [Required(ErrorMessage = "'ownerEmail' must be provided")] + public string OwnerEmail { get; set; } + + [EnumMatches(PlanType.TeamsMonthly, PlanType.EnterpriseMonthly, ErrorMessage = "'planType' must be Teams (Monthly) or Enterprise (Monthly)")] + public PlanType PlanType { get; set; } + + [Range(1, int.MaxValue, ErrorMessage = "'seats' must be greater than 0")] + public int Seats { get; set; } + + [Required(ErrorMessage = "'key' must be provided")] + public string Key { get; set; } + + [Required(ErrorMessage = "'keyPair' must be provided")] + public KeyPairRequestBody KeyPair { get; set; } + + [Required(ErrorMessage = "'collectionName' must be provided")] + public string CollectionName { get; set; } +} diff --git a/src/Api/Billing/Models/Requests/KeyPairRequestBody.cs b/src/Api/Billing/Models/Requests/KeyPairRequestBody.cs new file mode 100644 index 000000000..b4f2c00f4 --- /dev/null +++ b/src/Api/Billing/Models/Requests/KeyPairRequestBody.cs @@ -0,0 +1,12 @@ +using System.ComponentModel.DataAnnotations; + +namespace Bit.Api.Billing.Models.Requests; + +// ReSharper disable once ClassNeverInstantiated.Global +public class KeyPairRequestBody +{ + [Required(ErrorMessage = "'publicKey' must be provided")] + public string PublicKey { get; set; } + [Required(ErrorMessage = "'encryptedPrivateKey' must be provided")] + public string EncryptedPrivateKey { get; set; } +} diff --git a/src/Api/Billing/Models/Requests/UpdateClientOrganizationRequestBody.cs b/src/Api/Billing/Models/Requests/UpdateClientOrganizationRequestBody.cs new file mode 100644 index 000000000..c6e04aa79 --- /dev/null +++ b/src/Api/Billing/Models/Requests/UpdateClientOrganizationRequestBody.cs @@ -0,0 +1,10 @@ +using System.ComponentModel.DataAnnotations; + +namespace Bit.Api.Billing.Models.Requests; + +public class UpdateClientOrganizationRequestBody +{ + [Required] + [Range(0, int.MaxValue, ErrorMessage = "You cannot assign negative seats to a client organization.")] + public int AssignedSeats { get; set; } +} diff --git a/src/Api/Billing/Models/ProviderSubscriptionDTO.cs b/src/Api/Billing/Models/Responses/ProviderSubscriptionResponse.cs similarity index 84% rename from src/Api/Billing/Models/ProviderSubscriptionDTO.cs rename to src/Api/Billing/Models/Responses/ProviderSubscriptionResponse.cs index ad0714967..51ab67129 100644 --- a/src/Api/Billing/Models/ProviderSubscriptionDTO.cs +++ b/src/Api/Billing/Models/Responses/ProviderSubscriptionResponse.cs @@ -2,9 +2,9 @@ using Bit.Core.Utilities; using Stripe; -namespace Bit.Api.Billing.Models; +namespace Bit.Api.Billing.Models.Responses; -public record ProviderSubscriptionDTO( +public record ProviderSubscriptionResponse( string Status, DateTime CurrentPeriodEndDate, decimal? DiscountPercentage, @@ -13,8 +13,8 @@ public record ProviderSubscriptionDTO( private const string _annualCadence = "Annual"; private const string _monthlyCadence = "Monthly"; - public static ProviderSubscriptionDTO From( - IEnumerable providerPlans, + public static ProviderSubscriptionResponse From( + IEnumerable providerPlans, Subscription subscription) { var providerPlansDTO = providerPlans @@ -32,7 +32,7 @@ public record ProviderSubscriptionDTO( cadence); }); - return new ProviderSubscriptionDTO( + return new ProviderSubscriptionResponse( subscription.Status, subscription.CurrentPeriodEnd, subscription.Customer?.Discount?.Coupon?.PercentOff, diff --git a/src/Api/Billing/Models/UpdateProviderOrganizationRequestBody.cs b/src/Api/Billing/Models/UpdateProviderOrganizationRequestBody.cs deleted file mode 100644 index 7bac8fdef..000000000 --- a/src/Api/Billing/Models/UpdateProviderOrganizationRequestBody.cs +++ /dev/null @@ -1,6 +0,0 @@ -namespace Bit.Api.Billing.Models; - -public class UpdateProviderOrganizationRequestBody -{ - public int AssignedSeats { get; set; } -} diff --git a/src/Api/Utilities/EnumMatchesAttribute.cs b/src/Api/Utilities/EnumMatchesAttribute.cs new file mode 100644 index 000000000..a13b9d59d --- /dev/null +++ b/src/Api/Utilities/EnumMatchesAttribute.cs @@ -0,0 +1,26 @@ +using System.ComponentModel.DataAnnotations; + +namespace Bit.Api.Utilities; + +public class EnumMatchesAttribute(params T[] accepted) : ValidationAttribute + where T : Enum +{ + public override bool IsValid(object value) + { + if (value == null || accepted == null || accepted.Length == 0) + { + return false; + } + + var success = Enum.TryParse(typeof(T), value.ToString(), out var result); + + if (!success) + { + return false; + } + + var typed = (T)result; + + return accepted.Contains(typed); + } +} diff --git a/src/Core/AdminConsole/Enums/OrganizationStatusType.cs b/src/Core/AdminConsole/Enums/OrganizationStatusType.cs index 1f6fb8d39..8b45c7e88 100644 --- a/src/Core/AdminConsole/Enums/OrganizationStatusType.cs +++ b/src/Core/AdminConsole/Enums/OrganizationStatusType.cs @@ -3,5 +3,6 @@ public enum OrganizationStatusType : byte { Pending = 0, - Created = 1 + Created = 1, + Managed = 2, } diff --git a/src/Core/AdminConsole/Services/IOrganizationService.cs b/src/Core/AdminConsole/Services/IOrganizationService.cs index a9d3ee1cc..2c82518d8 100644 --- a/src/Core/AdminConsole/Services/IOrganizationService.cs +++ b/src/Core/AdminConsole/Services/IOrganizationService.cs @@ -26,6 +26,8 @@ public interface IOrganizationService /// A tuple containing the new organization, the initial organizationUser (if any) and the default collection (if any) #nullable enable Task<(Organization organization, OrganizationUser? organizationUser, Collection? defaultCollection)> SignUpAsync(OrganizationSignup organizationSignup, bool provider = false); + + Task<(Organization organization, OrganizationUser organizationUser, Collection defaultCollection)> SignupClientAsync(OrganizationSignup signup); #nullable disable /// /// Create a new organization on a self-hosted instance diff --git a/src/Core/AdminConsole/Services/Implementations/OrganizationService.cs b/src/Core/AdminConsole/Services/Implementations/OrganizationService.cs index 9c87ff40a..35404f85a 100644 --- a/src/Core/AdminConsole/Services/Implementations/OrganizationService.cs +++ b/src/Core/AdminConsole/Services/Implementations/OrganizationService.cs @@ -421,6 +421,89 @@ public class OrganizationService : IOrganizationService } } + public async Task<(Organization organization, OrganizationUser organizationUser, Collection defaultCollection)> SignupClientAsync(OrganizationSignup signup) + { + var consolidatedBillingEnabled = _featureService.IsEnabled(FeatureFlagKeys.EnableConsolidatedBilling); + + if (!consolidatedBillingEnabled) + { + throw new InvalidOperationException($"{nameof(SignupClientAsync)} is only for use within Consolidated Billing"); + } + + var plan = StaticStore.GetPlan(signup.Plan); + + ValidatePlan(plan, signup.AdditionalSeats, "Password Manager"); + + var flexibleCollectionsSignupEnabled = + _featureService.IsEnabled(FeatureFlagKeys.FlexibleCollectionsSignup); + + var flexibleCollectionsV1Enabled = + _featureService.IsEnabled(FeatureFlagKeys.FlexibleCollectionsV1); + + var organization = new Organization + { + // Pre-generate the org id so that we can save it with the Stripe subscription.. + Id = CoreHelpers.GenerateComb(), + Name = signup.Name, + BillingEmail = signup.BillingEmail, + PlanType = plan!.Type, + Seats = signup.AdditionalSeats, + MaxCollections = plan.PasswordManager.MaxCollections, + // Extra storage not available for purchase with Consolidated Billing. + MaxStorageGb = 0, + UsePolicies = plan.HasPolicies, + UseSso = plan.HasSso, + UseGroups = plan.HasGroups, + UseEvents = plan.HasEvents, + UseDirectory = plan.HasDirectory, + UseTotp = plan.HasTotp, + Use2fa = plan.Has2fa, + UseApi = plan.HasApi, + UseResetPassword = plan.HasResetPassword, + SelfHost = plan.HasSelfHost, + UsersGetPremium = plan.UsersGetPremium, + UseCustomPermissions = plan.HasCustomPermissions, + UseScim = plan.HasScim, + Plan = plan.Name, + Gateway = GatewayType.Stripe, + ReferenceData = signup.Owner.ReferenceData, + Enabled = true, + LicenseKey = CoreHelpers.SecureRandomString(20), + PublicKey = signup.PublicKey, + PrivateKey = signup.PrivateKey, + CreationDate = DateTime.UtcNow, + RevisionDate = DateTime.UtcNow, + Status = OrganizationStatusType.Created, + UsePasswordManager = true, + // Secrets Manager not available for purchase with Consolidated Billing. + UseSecretsManager = false, + + // This feature flag indicates that new organizations should be automatically onboarded to + // Flexible Collections enhancements + FlexibleCollections = flexibleCollectionsSignupEnabled, + + // These collection management settings smooth the migration for existing organizations by disabling some FC behavior. + // If the organization is onboarded to Flexible Collections on signup, we turn them OFF to enable all new behaviour. + // If the organization is NOT onboarded now, they will have to be migrated later, so they default to ON to limit FC changes on migration. + LimitCollectionCreationDeletion = !flexibleCollectionsSignupEnabled, + AllowAdminAccessToAllCollectionItems = !flexibleCollectionsV1Enabled + }; + + var returnValue = await SignUpAsync(organization, default, signup.OwnerKey, signup.CollectionName, false); + + await _referenceEventService.RaiseEventAsync( + new ReferenceEvent(ReferenceEventType.Signup, organization, _currentContext) + { + PlanName = plan.Name, + PlanType = plan.Type, + Seats = returnValue.Item1.Seats, + SignupInitiationPath = signup.InitiationPath, + Storage = returnValue.Item1.MaxStorageGb, + }); + + return returnValue; + } + /// /// Create a new organization in a cloud environment /// diff --git a/src/Core/Billing/Commands/IAssignSeatsToClientOrganizationCommand.cs b/src/Core/Billing/Commands/IAssignSeatsToClientOrganizationCommand.cs index db21926be..43adc73d8 100644 --- a/src/Core/Billing/Commands/IAssignSeatsToClientOrganizationCommand.cs +++ b/src/Core/Billing/Commands/IAssignSeatsToClientOrganizationCommand.cs @@ -5,6 +5,15 @@ namespace Bit.Core.Billing.Commands; public interface IAssignSeatsToClientOrganizationCommand { + /// + /// Assigns a specified number of to a client on behalf of + /// its . Seat adjustments for the client organization may autoscale the provider's Stripe + /// depending on the provider's seat minimum for the client 's + /// . + /// + /// The MSP that manages the client . + /// The client organization whose you want to update. + /// The number of seats to assign to the client organization. Task AssignSeatsToClientOrganization( Provider provider, Organization organization, diff --git a/src/Core/Billing/Commands/ICreateCustomerCommand.cs b/src/Core/Billing/Commands/ICreateCustomerCommand.cs new file mode 100644 index 000000000..0d7994223 --- /dev/null +++ b/src/Core/Billing/Commands/ICreateCustomerCommand.cs @@ -0,0 +1,17 @@ +using Bit.Core.AdminConsole.Entities; +using Bit.Core.AdminConsole.Entities.Provider; + +namespace Bit.Core.Billing.Commands; + +public interface ICreateCustomerCommand +{ + /// + /// Create a Stripe for the provided client utilizing + /// the address and tax information of its . + /// + /// The MSP that owns the client organization. + /// The client organization to create a Stripe for. + Task CreateCustomer( + Provider provider, + Organization organization); +} diff --git a/src/Core/Billing/Commands/IScaleSeatsCommand.cs b/src/Core/Billing/Commands/IScaleSeatsCommand.cs new file mode 100644 index 000000000..97fe9e2e3 --- /dev/null +++ b/src/Core/Billing/Commands/IScaleSeatsCommand.cs @@ -0,0 +1,12 @@ +using Bit.Core.AdminConsole.Entities.Provider; +using Bit.Core.Enums; + +namespace Bit.Core.Billing.Commands; + +public interface IScaleSeatsCommand +{ + Task ScalePasswordManagerSeats( + Provider provider, + PlanType planType, + int seatAdjustment); +} diff --git a/src/Core/Billing/Commands/IStartSubscriptionCommand.cs b/src/Core/Billing/Commands/IStartSubscriptionCommand.cs index 9a5ce7d79..74e9367c4 100644 --- a/src/Core/Billing/Commands/IStartSubscriptionCommand.cs +++ b/src/Core/Billing/Commands/IStartSubscriptionCommand.cs @@ -1,10 +1,19 @@ using Bit.Core.AdminConsole.Entities.Provider; +using Bit.Core.Enums; using Bit.Core.Models.Business; namespace Bit.Core.Billing.Commands; public interface IStartSubscriptionCommand { + /// + /// Starts a Stripe for the given utilizing the provided + /// to handle automatic taxation and non-US tax identification. subscriptions + /// will always be started with a for both the and + /// plan, and the quantity for each item will be equal the provider's seat minimum for each respective plan. + /// + /// The provider to create the for. + /// The tax information to use for automatic taxation and non-US tax identification. Task StartSubscription( Provider provider, TaxInfo taxInfo); diff --git a/src/Core/Billing/Commands/Implementations/CreateCustomerCommand.cs b/src/Core/Billing/Commands/Implementations/CreateCustomerCommand.cs new file mode 100644 index 000000000..9a9714f24 --- /dev/null +++ b/src/Core/Billing/Commands/Implementations/CreateCustomerCommand.cs @@ -0,0 +1,89 @@ +using Bit.Core.AdminConsole.Entities; +using Bit.Core.AdminConsole.Entities.Provider; +using Bit.Core.Billing.Queries; +using Bit.Core.Repositories; +using Bit.Core.Services; +using Bit.Core.Settings; +using Microsoft.Extensions.Logging; +using Stripe; + +namespace Bit.Core.Billing.Commands.Implementations; + +public class CreateCustomerCommand( + IGlobalSettings globalSettings, + ILogger logger, + IOrganizationRepository organizationRepository, + IStripeAdapter stripeAdapter, + ISubscriberQueries subscriberQueries) : ICreateCustomerCommand +{ + public async Task CreateCustomer( + Provider provider, + Organization organization) + { + ArgumentNullException.ThrowIfNull(provider); + ArgumentNullException.ThrowIfNull(organization); + + if (!string.IsNullOrEmpty(organization.GatewayCustomerId)) + { + logger.LogWarning("Client organization ({ID}) already has a populated {FieldName}", organization.Id, nameof(organization.GatewayCustomerId)); + + return; + } + + var providerCustomer = await subscriberQueries.GetCustomerOrThrow(provider, new CustomerGetOptions + { + Expand = ["tax_ids"] + }); + + var providerTaxId = providerCustomer.TaxIds.FirstOrDefault(); + + var organizationDisplayName = organization.DisplayName(); + + var customerCreateOptions = new CustomerCreateOptions + { + Address = new AddressOptions + { + Country = providerCustomer.Address?.Country, + PostalCode = providerCustomer.Address?.PostalCode, + Line1 = providerCustomer.Address?.Line1, + Line2 = providerCustomer.Address?.Line2, + City = providerCustomer.Address?.City, + State = providerCustomer.Address?.State + }, + Name = organizationDisplayName, + Description = $"{provider.Name} Client Organization", + Email = provider.BillingEmail, + InvoiceSettings = new CustomerInvoiceSettingsOptions + { + CustomFields = + [ + new CustomerInvoiceSettingsCustomFieldOptions + { + Name = organization.SubscriberType(), + Value = organizationDisplayName.Length <= 30 + ? organizationDisplayName + : organizationDisplayName[..30] + } + ] + }, + Metadata = new Dictionary + { + { "region", globalSettings.BaseServiceUri.CloudRegion } + }, + TaxIdData = providerTaxId == null ? null : + [ + new CustomerTaxIdDataOptions + { + Type = providerTaxId.Type, + Value = providerTaxId.Value + } + ] + }; + + var customer = await stripeAdapter.CustomerCreateAsync(customerCreateOptions); + + organization.GatewayCustomerId = customer.Id; + + await organizationRepository.ReplaceAsync(organization); + } +} diff --git a/src/Core/Billing/Commands/Implementations/ScaleSeatsCommand.cs b/src/Core/Billing/Commands/Implementations/ScaleSeatsCommand.cs new file mode 100644 index 000000000..8d6d9a90e --- /dev/null +++ b/src/Core/Billing/Commands/Implementations/ScaleSeatsCommand.cs @@ -0,0 +1,130 @@ +using Bit.Core.AdminConsole.Entities.Provider; +using Bit.Core.AdminConsole.Enums.Provider; +using Bit.Core.Billing.Entities; +using Bit.Core.Billing.Extensions; +using Bit.Core.Billing.Queries; +using Bit.Core.Billing.Repositories; +using Bit.Core.Enums; +using Bit.Core.Services; +using Bit.Core.Utilities; +using Microsoft.Extensions.Logging; +using static Bit.Core.Billing.Utilities; + +namespace Bit.Core.Billing.Commands.Implementations; + +public class ScaleSeatsCommand( + ILogger logger, + IPaymentService paymentService, + IProviderBillingQueries providerBillingQueries, + IProviderPlanRepository providerPlanRepository) : IScaleSeatsCommand +{ + public async Task ScalePasswordManagerSeats(Provider provider, PlanType planType, int seatAdjustment) + { + ArgumentNullException.ThrowIfNull(provider); + + if (provider.Type != ProviderType.Msp) + { + logger.LogError("Non-MSP provider ({ProviderID}) cannot scale their Password Manager seats", provider.Id); + + throw ContactSupport(); + } + + if (!planType.SupportsConsolidatedBilling()) + { + logger.LogError("Cannot scale provider ({ProviderID}) Password Manager seats for plan type {PlanType} as it does not support consolidated billing", provider.Id, planType.ToString()); + + throw ContactSupport(); + } + + var providerPlans = await providerPlanRepository.GetByProviderId(provider.Id); + + var providerPlan = providerPlans.FirstOrDefault(providerPlan => providerPlan.PlanType == planType); + + if (providerPlan == null || !providerPlan.IsConfigured()) + { + logger.LogError("Cannot scale provider ({ProviderID}) Password Manager seats for plan type {PlanType} when their matching provider plan is not configured", provider.Id, planType); + + throw ContactSupport(); + } + + var seatMinimum = providerPlan.SeatMinimum.GetValueOrDefault(0); + + var currentlyAssignedSeatTotal = + await providerBillingQueries.GetAssignedSeatTotalForPlanOrThrow(provider.Id, planType); + + var newlyAssignedSeatTotal = currentlyAssignedSeatTotal + seatAdjustment; + + var update = CurryUpdateFunction( + provider, + providerPlan, + newlyAssignedSeatTotal); + + /* + * Below the limit => Below the limit: + * No subscription update required. We can safely update the organization's seats. + */ + if (currentlyAssignedSeatTotal <= seatMinimum && + newlyAssignedSeatTotal <= seatMinimum) + { + providerPlan.AllocatedSeats = newlyAssignedSeatTotal; + + await providerPlanRepository.ReplaceAsync(providerPlan); + } + /* + * Below the limit => Above the limit: + * We have to scale the subscription up from the seat minimum to the newly assigned seat total. + */ + else if (currentlyAssignedSeatTotal <= seatMinimum && + newlyAssignedSeatTotal > seatMinimum) + { + await update( + seatMinimum, + newlyAssignedSeatTotal); + } + /* + * Above the limit => Above the limit: + * We have to scale the subscription from the currently assigned seat total to the newly assigned seat total. + */ + else if (currentlyAssignedSeatTotal > seatMinimum && + newlyAssignedSeatTotal > seatMinimum) + { + await update( + currentlyAssignedSeatTotal, + newlyAssignedSeatTotal); + } + /* + * Above the limit => Below the limit: + * We have to scale the subscription down from the currently assigned seat total to the seat minimum. + */ + else if (currentlyAssignedSeatTotal > seatMinimum && + newlyAssignedSeatTotal <= seatMinimum) + { + await update( + currentlyAssignedSeatTotal, + seatMinimum); + } + } + + private Func CurryUpdateFunction( + Provider provider, + ProviderPlan providerPlan, + int newlyAssignedSeats) => async (currentlySubscribedSeats, newlySubscribedSeats) => + { + var plan = StaticStore.GetPlan(providerPlan.PlanType); + + await paymentService.AdjustSeats( + provider, + plan, + currentlySubscribedSeats, + newlySubscribedSeats); + + var newlyPurchasedSeats = newlySubscribedSeats > providerPlan.SeatMinimum + ? newlySubscribedSeats - providerPlan.SeatMinimum + : 0; + + providerPlan.PurchasedSeats = newlyPurchasedSeats; + providerPlan.AllocatedSeats = newlyAssignedSeats; + + await providerPlanRepository.ReplaceAsync(providerPlan); + }; +} diff --git a/src/Core/Billing/Extensions/ServiceCollectionExtensions.cs b/src/Core/Billing/Extensions/ServiceCollectionExtensions.cs index c4f25e2f6..2d802dced 100644 --- a/src/Core/Billing/Extensions/ServiceCollectionExtensions.cs +++ b/src/Core/Billing/Extensions/ServiceCollectionExtensions.cs @@ -18,7 +18,9 @@ public static class ServiceCollectionExtensions // Commands services.AddTransient(); services.AddTransient(); + services.AddTransient(); services.AddTransient(); + services.AddTransient(); services.AddTransient(); } } diff --git a/src/Core/Billing/Models/ConfiguredProviderPlan.cs b/src/Core/Billing/Models/ConfiguredProviderPlanDTO.cs similarity index 78% rename from src/Core/Billing/Models/ConfiguredProviderPlan.cs rename to src/Core/Billing/Models/ConfiguredProviderPlanDTO.cs index d6bc2b752..519e2f406 100644 --- a/src/Core/Billing/Models/ConfiguredProviderPlan.cs +++ b/src/Core/Billing/Models/ConfiguredProviderPlanDTO.cs @@ -3,7 +3,7 @@ using Bit.Core.Enums; namespace Bit.Core.Billing.Models; -public record ConfiguredProviderPlan( +public record ConfiguredProviderPlanDTO( Guid Id, Guid ProviderId, PlanType PlanType, @@ -11,9 +11,9 @@ public record ConfiguredProviderPlan( int PurchasedSeats, int AssignedSeats) { - public static ConfiguredProviderPlan From(ProviderPlan providerPlan) => + public static ConfiguredProviderPlanDTO From(ProviderPlan providerPlan) => providerPlan.IsConfigured() - ? new ConfiguredProviderPlan( + ? new ConfiguredProviderPlanDTO( providerPlan.Id, providerPlan.ProviderId, providerPlan.PlanType, diff --git a/src/Core/Billing/Models/ProviderSubscriptionDTO.cs b/src/Core/Billing/Models/ProviderSubscriptionDTO.cs new file mode 100644 index 000000000..557a6b359 --- /dev/null +++ b/src/Core/Billing/Models/ProviderSubscriptionDTO.cs @@ -0,0 +1,7 @@ +using Stripe; + +namespace Bit.Core.Billing.Models; + +public record ProviderSubscriptionDTO( + List ProviderPlans, + Subscription Subscription); diff --git a/src/Core/Billing/Models/ProviderSubscriptionData.cs b/src/Core/Billing/Models/ProviderSubscriptionData.cs deleted file mode 100644 index 27da6cd22..000000000 --- a/src/Core/Billing/Models/ProviderSubscriptionData.cs +++ /dev/null @@ -1,7 +0,0 @@ -using Stripe; - -namespace Bit.Core.Billing.Models; - -public record ProviderSubscriptionData( - List ProviderPlans, - Subscription Subscription); diff --git a/src/Core/Billing/Queries/IProviderBillingQueries.cs b/src/Core/Billing/Queries/IProviderBillingQueries.cs index e4b7d0f14..1347ea4b8 100644 --- a/src/Core/Billing/Queries/IProviderBillingQueries.cs +++ b/src/Core/Billing/Queries/IProviderBillingQueries.cs @@ -21,7 +21,7 @@ public interface IProviderBillingQueries /// Retrieves a provider's billing subscription data. /// /// The ID of the provider to retrieve subscription data for. - /// A object containing the provider's Stripe and their s. + /// A object containing the provider's Stripe and their s. /// This method opts for returning rather than throwing exceptions, making it ideal for surfacing data from API endpoints. - Task GetSubscriptionData(Guid providerId); + Task GetSubscriptionDTO(Guid providerId); } diff --git a/src/Core/Billing/Queries/ISubscriberQueries.cs b/src/Core/Billing/Queries/ISubscriberQueries.cs index ea6c0d985..013ae3e1d 100644 --- a/src/Core/Billing/Queries/ISubscriberQueries.cs +++ b/src/Core/Billing/Queries/ISubscriberQueries.cs @@ -6,6 +6,18 @@ namespace Bit.Core.Billing.Queries; public interface ISubscriberQueries { + /// + /// Retrieves a Stripe using the 's property. + /// + /// The organization, provider or user to retrieve the customer for. + /// Optional parameters that can be passed to Stripe to expand or modify the . + /// A Stripe . + /// Thrown when the is . + /// This method opts for returning rather than throwing exceptions, making it ideal for surfacing data from API endpoints. + Task GetCustomer( + ISubscriber subscriber, + CustomerGetOptions customerGetOptions = null); + /// /// Retrieves a Stripe using the 's property. /// @@ -18,13 +30,29 @@ public interface ISubscriberQueries ISubscriber subscriber, SubscriptionGetOptions subscriptionGetOptions = null); + /// + /// Retrieves a Stripe using the 's property. + /// + /// The organization or user to retrieve the subscription for. + /// Optional parameters that can be passed to Stripe to expand or modify the . + /// A Stripe . + /// Thrown when the is . + /// Thrown when the subscriber's is or empty. + /// Thrown when the returned from Stripe's API is null. + Task GetCustomerOrThrow( + ISubscriber subscriber, + CustomerGetOptions customerGetOptions = null); + /// /// Retrieves a Stripe using the 's property. /// /// The organization or user to retrieve the subscription for. + /// Optional parameters that can be passed to Stripe to expand or modify the . /// A Stripe . /// Thrown when the is . /// Thrown when the subscriber's is or empty. /// Thrown when the returned from Stripe's API is null. - Task GetSubscriptionOrThrow(ISubscriber subscriber); + Task GetSubscriptionOrThrow( + ISubscriber subscriber, + SubscriptionGetOptions subscriptionGetOptions = null); } diff --git a/src/Core/Billing/Queries/Implementations/ProviderBillingQueries.cs b/src/Core/Billing/Queries/Implementations/ProviderBillingQueries.cs index f8bff9d3f..a941b6f94 100644 --- a/src/Core/Billing/Queries/Implementations/ProviderBillingQueries.cs +++ b/src/Core/Billing/Queries/Implementations/ProviderBillingQueries.cs @@ -44,11 +44,11 @@ public class ProviderBillingQueries( var plan = StaticStore.GetPlan(planType); return providerOrganizations - .Where(providerOrganization => providerOrganization.Plan == plan.Name) + .Where(providerOrganization => providerOrganization.Plan == plan.Name && providerOrganization.Status == OrganizationStatusType.Managed) .Sum(providerOrganization => providerOrganization.Seats ?? 0); } - public async Task GetSubscriptionData(Guid providerId) + public async Task GetSubscriptionDTO(Guid providerId) { var provider = await providerRepository.GetByIdAsync(providerId); @@ -82,10 +82,10 @@ public class ProviderBillingQueries( var configuredProviderPlans = providerPlans .Where(providerPlan => providerPlan.IsConfigured()) - .Select(ConfiguredProviderPlan.From) + .Select(ConfiguredProviderPlanDTO.From) .ToList(); - return new ProviderSubscriptionData( + return new ProviderSubscriptionDTO( configuredProviderPlans, subscription); } diff --git a/src/Core/Billing/Queries/Implementations/SubscriberQueries.cs b/src/Core/Billing/Queries/Implementations/SubscriberQueries.cs index a160a8759..f9420cd48 100644 --- a/src/Core/Billing/Queries/Implementations/SubscriberQueries.cs +++ b/src/Core/Billing/Queries/Implementations/SubscriberQueries.cs @@ -11,6 +11,31 @@ public class SubscriberQueries( ILogger logger, IStripeAdapter stripeAdapter) : ISubscriberQueries { + public async Task GetCustomer( + ISubscriber subscriber, + CustomerGetOptions customerGetOptions = null) + { + ArgumentNullException.ThrowIfNull(subscriber); + + if (string.IsNullOrEmpty(subscriber.GatewayCustomerId)) + { + logger.LogError("Cannot retrieve customer for subscriber ({SubscriberID}) with no {FieldName}", subscriber.Id, nameof(subscriber.GatewayCustomerId)); + + return null; + } + + var customer = await stripeAdapter.CustomerGetAsync(subscriber.GatewayCustomerId, customerGetOptions); + + if (customer != null) + { + return customer; + } + + logger.LogError("Could not find Stripe customer ({CustomerID}) for subscriber ({SubscriberID})", subscriber.GatewayCustomerId, subscriber.Id); + + return null; + } + public async Task GetSubscription( ISubscriber subscriber, SubscriptionGetOptions subscriptionGetOptions = null) @@ -19,7 +44,7 @@ public class SubscriberQueries( if (string.IsNullOrEmpty(subscriber.GatewaySubscriptionId)) { - logger.LogError("Cannot cancel subscription for subscriber ({ID}) with no GatewaySubscriptionId.", subscriber.Id); + logger.LogError("Cannot retrieve subscription for subscriber ({SubscriberID}) with no {FieldName}", subscriber.Id, nameof(subscriber.GatewaySubscriptionId)); return null; } @@ -31,30 +56,57 @@ public class SubscriberQueries( return subscription; } - logger.LogError("Could not find Stripe subscription ({ID}) to cancel.", subscriber.GatewaySubscriptionId); + logger.LogError("Could not find Stripe subscription ({SubscriptionID}) for subscriber ({SubscriberID})", subscriber.GatewaySubscriptionId, subscriber.Id); return null; } - public async Task GetSubscriptionOrThrow(ISubscriber subscriber) + public async Task GetSubscriptionOrThrow( + ISubscriber subscriber, + SubscriptionGetOptions subscriptionGetOptions = null) { ArgumentNullException.ThrowIfNull(subscriber); if (string.IsNullOrEmpty(subscriber.GatewaySubscriptionId)) { - logger.LogError("Cannot cancel subscription for subscriber ({ID}) with no GatewaySubscriptionId.", subscriber.Id); + logger.LogError("Cannot retrieve subscription for subscriber ({SubscriberID}) with no {FieldName}", subscriber.Id, nameof(subscriber.GatewaySubscriptionId)); throw ContactSupport(); } - var subscription = await stripeAdapter.SubscriptionGetAsync(subscriber.GatewaySubscriptionId); + var subscription = await stripeAdapter.SubscriptionGetAsync(subscriber.GatewaySubscriptionId, subscriptionGetOptions); if (subscription != null) { return subscription; } - logger.LogError("Could not find Stripe subscription ({ID}) to cancel.", subscriber.GatewaySubscriptionId); + logger.LogError("Could not find Stripe subscription ({SubscriptionID}) for subscriber ({SubscriberID})", subscriber.GatewaySubscriptionId, subscriber.Id); + + throw ContactSupport(); + } + + public async Task GetCustomerOrThrow( + ISubscriber subscriber, + CustomerGetOptions customerGetOptions = null) + { + ArgumentNullException.ThrowIfNull(subscriber); + + if (string.IsNullOrEmpty(subscriber.GatewayCustomerId)) + { + logger.LogError("Cannot retrieve customer for subscriber ({SubscriberID}) with no {FieldName}", subscriber.Id, nameof(subscriber.GatewayCustomerId)); + + throw ContactSupport(); + } + + var customer = await stripeAdapter.CustomerGetAsync(subscriber.GatewayCustomerId, customerGetOptions); + + if (customer != null) + { + return customer; + } + + logger.LogError("Could not find Stripe customer ({CustomerID}) for subscriber ({SubscriberID})", subscriber.GatewayCustomerId, subscriber.Id); throw ContactSupport(); } diff --git a/test/Api.Test/Billing/Controllers/ProviderBillingControllerTests.cs b/test/Api.Test/Billing/Controllers/ProviderBillingControllerTests.cs index 57480ac11..8e82e0209 100644 --- a/test/Api.Test/Billing/Controllers/ProviderBillingControllerTests.cs +++ b/test/Api.Test/Billing/Controllers/ProviderBillingControllerTests.cs @@ -1,5 +1,5 @@ using Bit.Api.Billing.Controllers; -using Bit.Api.Billing.Models; +using Bit.Api.Billing.Models.Responses; using Bit.Core; using Bit.Core.Billing.Models; using Bit.Core.Billing.Queries; @@ -61,7 +61,7 @@ public class ProviderBillingControllerTests sutProvider.GetDependency().ProviderProviderAdmin(providerId) .Returns(true); - sutProvider.GetDependency().GetSubscriptionData(providerId).ReturnsNull(); + sutProvider.GetDependency().GetSubscriptionDTO(providerId).ReturnsNull(); var result = await sutProvider.Sut.GetSubscriptionAsync(providerId); @@ -79,7 +79,7 @@ public class ProviderBillingControllerTests sutProvider.GetDependency().ProviderProviderAdmin(providerId) .Returns(true); - var configuredPlans = new List + var configuredProviderPlanDTOList = new List { new (Guid.NewGuid(), providerId, PlanType.TeamsMonthly, 50, 10, 30), new (Guid.NewGuid(), providerId, PlanType.EnterpriseMonthly, 100, 0, 90) @@ -92,25 +92,25 @@ public class ProviderBillingControllerTests Customer = new Customer { Discount = new Discount { Coupon = new Coupon { PercentOff = 10 } } } }; - var providerSubscriptionData = new ProviderSubscriptionData( - configuredPlans, + var providerSubscriptionDTO = new ProviderSubscriptionDTO( + configuredProviderPlanDTOList, subscription); - sutProvider.GetDependency().GetSubscriptionData(providerId) - .Returns(providerSubscriptionData); + sutProvider.GetDependency().GetSubscriptionDTO(providerId) + .Returns(providerSubscriptionDTO); var result = await sutProvider.Sut.GetSubscriptionAsync(providerId); - Assert.IsType>(result); + Assert.IsType>(result); - var providerSubscriptionDTO = ((Ok)result).Value; + var providerSubscriptionResponse = ((Ok)result).Value; - Assert.Equal(providerSubscriptionDTO.Status, subscription.Status); - Assert.Equal(providerSubscriptionDTO.CurrentPeriodEndDate, subscription.CurrentPeriodEnd); - Assert.Equal(providerSubscriptionDTO.DiscountPercentage, subscription.Customer!.Discount!.Coupon!.PercentOff); + Assert.Equal(providerSubscriptionResponse.Status, subscription.Status); + Assert.Equal(providerSubscriptionResponse.CurrentPeriodEndDate, subscription.CurrentPeriodEnd); + Assert.Equal(providerSubscriptionResponse.DiscountPercentage, subscription.Customer!.Discount!.Coupon!.PercentOff); var teamsPlan = StaticStore.GetPlan(PlanType.TeamsMonthly); - var providerTeamsPlan = providerSubscriptionDTO.Plans.FirstOrDefault(plan => plan.PlanName == teamsPlan.Name); + var providerTeamsPlan = providerSubscriptionResponse.Plans.FirstOrDefault(plan => plan.PlanName == teamsPlan.Name); Assert.NotNull(providerTeamsPlan); Assert.Equal(50, providerTeamsPlan.SeatMinimum); Assert.Equal(10, providerTeamsPlan.PurchasedSeats); @@ -119,7 +119,7 @@ public class ProviderBillingControllerTests Assert.Equal("Monthly", providerTeamsPlan.Cadence); var enterprisePlan = StaticStore.GetPlan(PlanType.EnterpriseMonthly); - var providerEnterprisePlan = providerSubscriptionDTO.Plans.FirstOrDefault(plan => plan.PlanName == enterprisePlan.Name); + var providerEnterprisePlan = providerSubscriptionResponse.Plans.FirstOrDefault(plan => plan.PlanName == enterprisePlan.Name); Assert.NotNull(providerEnterprisePlan); Assert.Equal(100, providerEnterprisePlan.SeatMinimum); Assert.Equal(0, providerEnterprisePlan.PurchasedSeats); diff --git a/test/Api.Test/Billing/Controllers/ProviderClientsControllerTests.cs b/test/Api.Test/Billing/Controllers/ProviderClientsControllerTests.cs new file mode 100644 index 000000000..b6acb73e4 --- /dev/null +++ b/test/Api.Test/Billing/Controllers/ProviderClientsControllerTests.cs @@ -0,0 +1,339 @@ +using System.Security.Claims; +using Bit.Api.Billing.Controllers; +using Bit.Api.Billing.Models.Requests; +using Bit.Core; +using Bit.Core.AdminConsole.Entities; +using Bit.Core.AdminConsole.Entities.Provider; +using Bit.Core.AdminConsole.Repositories; +using Bit.Core.AdminConsole.Services; +using Bit.Core.Billing.Commands; +using Bit.Core.Context; +using Bit.Core.Entities; +using Bit.Core.Models.Business; +using Bit.Core.Repositories; +using Bit.Core.Services; +using Bit.Test.Common.AutoFixture; +using Bit.Test.Common.AutoFixture.Attributes; +using Microsoft.AspNetCore.Http.HttpResults; +using NSubstitute; +using NSubstitute.ReturnsExtensions; +using Xunit; + +namespace Bit.Api.Test.Billing.Controllers; + +[ControllerCustomize(typeof(ProviderClientsController))] +[SutProviderCustomize] +public class ProviderClientsControllerTests +{ + #region CreateAsync + [Theory, BitAutoData] + public async Task CreateAsync_FFDisabled_NotFound( + Guid providerId, + CreateClientOrganizationRequestBody requestBody, + SutProvider sutProvider) + { + sutProvider.GetDependency().IsEnabled(FeatureFlagKeys.EnableConsolidatedBilling) + .Returns(false); + + var result = await sutProvider.Sut.CreateAsync(providerId, requestBody); + + Assert.IsType(result); + } + + [Theory, BitAutoData] + public async Task CreateAsync_NoPrincipalUser_Unauthorized( + Guid providerId, + CreateClientOrganizationRequestBody requestBody, + SutProvider sutProvider) + { + sutProvider.GetDependency().IsEnabled(FeatureFlagKeys.EnableConsolidatedBilling) + .Returns(true); + + sutProvider.GetDependency().GetUserByPrincipalAsync(Arg.Any()).ReturnsNull(); + + var result = await sutProvider.Sut.CreateAsync(providerId, requestBody); + + Assert.IsType(result); + } + + [Theory, BitAutoData] + public async Task CreateAsync_NotProviderAdmin_Unauthorized( + Guid providerId, + CreateClientOrganizationRequestBody requestBody, + SutProvider sutProvider) + { + sutProvider.GetDependency().IsEnabled(FeatureFlagKeys.EnableConsolidatedBilling) + .Returns(true); + + sutProvider.GetDependency().GetUserByPrincipalAsync(Arg.Any()).Returns(new User()); + + sutProvider.GetDependency().ManageProviderOrganizations(providerId) + .Returns(false); + + var result = await sutProvider.Sut.CreateAsync(providerId, requestBody); + + Assert.IsType(result); + } + + [Theory, BitAutoData] + public async Task CreateAsync_NoProvider_NotFound( + Guid providerId, + CreateClientOrganizationRequestBody requestBody, + SutProvider sutProvider) + { + sutProvider.GetDependency().IsEnabled(FeatureFlagKeys.EnableConsolidatedBilling) + .Returns(true); + + sutProvider.GetDependency().GetUserByPrincipalAsync(Arg.Any()).Returns(new User()); + + sutProvider.GetDependency().ManageProviderOrganizations(providerId) + .Returns(true); + + sutProvider.GetDependency().GetByIdAsync(providerId) + .ReturnsNull(); + + var result = await sutProvider.Sut.CreateAsync(providerId, requestBody); + + Assert.IsType(result); + } + + [Theory, BitAutoData] + public async Task CreateAsync_MissingClientOrganization_ServerError( + Guid providerId, + CreateClientOrganizationRequestBody requestBody, + SutProvider sutProvider) + { + sutProvider.GetDependency().IsEnabled(FeatureFlagKeys.EnableConsolidatedBilling) + .Returns(true); + + var user = new User(); + + sutProvider.GetDependency().GetUserByPrincipalAsync(Arg.Any()).Returns(user); + + sutProvider.GetDependency().ManageProviderOrganizations(providerId) + .Returns(true); + + sutProvider.GetDependency().GetByIdAsync(providerId) + .Returns(new Provider()); + + var clientOrganizationId = Guid.NewGuid(); + + sutProvider.GetDependency().CreateOrganizationAsync( + providerId, + Arg.Any(), + requestBody.OwnerEmail, + user) + .Returns(new ProviderOrganization + { + OrganizationId = clientOrganizationId + }); + + sutProvider.GetDependency().GetByIdAsync(clientOrganizationId).ReturnsNull(); + + var result = await sutProvider.Sut.CreateAsync(providerId, requestBody); + + Assert.IsType(result); + } + + [Theory, BitAutoData] + public async Task CreateAsync_OK( + Guid providerId, + CreateClientOrganizationRequestBody requestBody, + SutProvider sutProvider) + { + sutProvider.GetDependency().IsEnabled(FeatureFlagKeys.EnableConsolidatedBilling) + .Returns(true); + + var user = new User(); + + sutProvider.GetDependency().GetUserByPrincipalAsync(Arg.Any()) + .Returns(user); + + sutProvider.GetDependency().ManageProviderOrganizations(providerId) + .Returns(true); + + var provider = new Provider(); + + sutProvider.GetDependency().GetByIdAsync(providerId) + .Returns(provider); + + var clientOrganizationId = Guid.NewGuid(); + + sutProvider.GetDependency().CreateOrganizationAsync( + providerId, + Arg.Is(signup => + signup.Name == requestBody.Name && + signup.Plan == requestBody.PlanType && + signup.AdditionalSeats == requestBody.Seats && + signup.OwnerKey == requestBody.Key && + signup.PublicKey == requestBody.KeyPair.PublicKey && + signup.PrivateKey == requestBody.KeyPair.EncryptedPrivateKey && + signup.CollectionName == requestBody.CollectionName), + requestBody.OwnerEmail, + user) + .Returns(new ProviderOrganization + { + OrganizationId = clientOrganizationId + }); + + var clientOrganization = new Organization { Id = clientOrganizationId }; + + sutProvider.GetDependency().GetByIdAsync(clientOrganizationId) + .Returns(clientOrganization); + + var result = await sutProvider.Sut.CreateAsync(providerId, requestBody); + + Assert.IsType(result); + + await sutProvider.GetDependency().Received(1).CreateCustomer( + provider, + clientOrganization); + } + #endregion + + #region UpdateAsync + [Theory, BitAutoData] + public async Task UpdateAsync_FFDisabled_NotFound( + Guid providerId, + Guid providerOrganizationId, + UpdateClientOrganizationRequestBody requestBody, + SutProvider sutProvider) + { + sutProvider.GetDependency().IsEnabled(FeatureFlagKeys.EnableConsolidatedBilling) + .Returns(false); + + var result = await sutProvider.Sut.UpdateAsync(providerId, providerOrganizationId, requestBody); + + Assert.IsType(result); + } + + [Theory, BitAutoData] + public async Task UpdateAsync_NotProviderAdmin_Unauthorized( + Guid providerId, + Guid providerOrganizationId, + UpdateClientOrganizationRequestBody requestBody, + SutProvider sutProvider) + { + sutProvider.GetDependency().IsEnabled(FeatureFlagKeys.EnableConsolidatedBilling) + .Returns(true); + + sutProvider.GetDependency().ProviderProviderAdmin(providerId) + .Returns(false); + + var result = await sutProvider.Sut.UpdateAsync(providerId, providerOrganizationId, requestBody); + + Assert.IsType(result); + } + + [Theory, BitAutoData] + public async Task UpdateAsync_NoProvider_NotFound( + Guid providerId, + Guid providerOrganizationId, + UpdateClientOrganizationRequestBody requestBody, + SutProvider sutProvider) + { + sutProvider.GetDependency().IsEnabled(FeatureFlagKeys.EnableConsolidatedBilling) + .Returns(true); + + sutProvider.GetDependency().ProviderProviderAdmin(providerId) + .Returns(true); + + sutProvider.GetDependency().GetByIdAsync(providerId) + .ReturnsNull(); + + var result = await sutProvider.Sut.UpdateAsync(providerId, providerOrganizationId, requestBody); + + Assert.IsType(result); + } + + [Theory, BitAutoData] + public async Task UpdateAsync_NoProviderOrganization_NotFound( + Guid providerId, + Guid providerOrganizationId, + UpdateClientOrganizationRequestBody requestBody, + Provider provider, + SutProvider sutProvider) + { + sutProvider.GetDependency().IsEnabled(FeatureFlagKeys.EnableConsolidatedBilling) + .Returns(true); + + sutProvider.GetDependency().ProviderProviderAdmin(providerId) + .Returns(true); + + sutProvider.GetDependency().GetByIdAsync(providerId) + .Returns(provider); + + sutProvider.GetDependency().GetByIdAsync(providerOrganizationId) + .ReturnsNull(); + + var result = await sutProvider.Sut.UpdateAsync(providerId, providerOrganizationId, requestBody); + + Assert.IsType(result); + } + + [Theory, BitAutoData] + public async Task UpdateAsync_NoOrganization_ServerError( + Guid providerId, + Guid providerOrganizationId, + UpdateClientOrganizationRequestBody requestBody, + Provider provider, + ProviderOrganization providerOrganization, + SutProvider sutProvider) + { + sutProvider.GetDependency().IsEnabled(FeatureFlagKeys.EnableConsolidatedBilling) + .Returns(true); + + sutProvider.GetDependency().ProviderProviderAdmin(providerId) + .Returns(true); + + sutProvider.GetDependency().GetByIdAsync(providerId) + .Returns(provider); + + sutProvider.GetDependency().GetByIdAsync(providerOrganizationId) + .Returns(providerOrganization); + + sutProvider.GetDependency().GetByIdAsync(providerOrganization.OrganizationId) + .ReturnsNull(); + + var result = await sutProvider.Sut.UpdateAsync(providerId, providerOrganizationId, requestBody); + + Assert.IsType(result); + } + + [Theory, BitAutoData] + public async Task UpdateAsync_NoContent( + Guid providerId, + Guid providerOrganizationId, + UpdateClientOrganizationRequestBody requestBody, + Provider provider, + ProviderOrganization providerOrganization, + Organization organization, + SutProvider sutProvider) + { + sutProvider.GetDependency().IsEnabled(FeatureFlagKeys.EnableConsolidatedBilling) + .Returns(true); + + sutProvider.GetDependency().ProviderProviderAdmin(providerId) + .Returns(true); + + sutProvider.GetDependency().GetByIdAsync(providerId) + .Returns(provider); + + sutProvider.GetDependency().GetByIdAsync(providerOrganizationId) + .Returns(providerOrganization); + + sutProvider.GetDependency().GetByIdAsync(providerOrganization.OrganizationId) + .Returns(organization); + + var result = await sutProvider.Sut.UpdateAsync(providerId, providerOrganizationId, requestBody); + + await sutProvider.GetDependency().Received(1) + .AssignSeatsToClientOrganization( + provider, + organization, + requestBody.AssignedSeats); + + Assert.IsType(result); + } + #endregion +} diff --git a/test/Api.Test/Billing/Controllers/ProviderOrganizationControllerTests.cs b/test/Api.Test/Billing/Controllers/ProviderOrganizationControllerTests.cs deleted file mode 100644 index 805683de2..000000000 --- a/test/Api.Test/Billing/Controllers/ProviderOrganizationControllerTests.cs +++ /dev/null @@ -1,168 +0,0 @@ -using Bit.Api.Billing.Controllers; -using Bit.Api.Billing.Models; -using Bit.Core; -using Bit.Core.AdminConsole.Entities; -using Bit.Core.AdminConsole.Repositories; -using Bit.Core.Billing.Commands; -using Bit.Core.Context; -using Bit.Core.Repositories; -using Bit.Core.Services; -using Bit.Infrastructure.EntityFramework.AdminConsole.Models.Provider; -using Bit.Test.Common.AutoFixture; -using Bit.Test.Common.AutoFixture.Attributes; -using Microsoft.AspNetCore.Http.HttpResults; -using NSubstitute; -using NSubstitute.ReturnsExtensions; -using Xunit; -using ProviderOrganization = Bit.Core.AdminConsole.Entities.Provider.ProviderOrganization; - -namespace Bit.Api.Test.Billing.Controllers; - -[ControllerCustomize(typeof(ProviderOrganizationController))] -[SutProviderCustomize] -public class ProviderOrganizationControllerTests -{ - [Theory, BitAutoData] - public async Task UpdateAsync_FFDisabled_NotFound( - Guid providerId, - Guid providerOrganizationId, - UpdateProviderOrganizationRequestBody requestBody, - SutProvider sutProvider) - { - sutProvider.GetDependency().IsEnabled(FeatureFlagKeys.EnableConsolidatedBilling) - .Returns(false); - - var result = await sutProvider.Sut.UpdateAsync(providerId, providerOrganizationId, requestBody); - - Assert.IsType(result); - } - - [Theory, BitAutoData] - public async Task GetSubscriptionAsync_NotProviderAdmin_Unauthorized( - Guid providerId, - Guid providerOrganizationId, - UpdateProviderOrganizationRequestBody requestBody, - SutProvider sutProvider) - { - sutProvider.GetDependency().IsEnabled(FeatureFlagKeys.EnableConsolidatedBilling) - .Returns(true); - - sutProvider.GetDependency().ProviderProviderAdmin(providerId) - .Returns(false); - - var result = await sutProvider.Sut.UpdateAsync(providerId, providerOrganizationId, requestBody); - - Assert.IsType(result); - } - - [Theory, BitAutoData] - public async Task GetSubscriptionAsync_NoProvider_NotFound( - Guid providerId, - Guid providerOrganizationId, - UpdateProviderOrganizationRequestBody requestBody, - SutProvider sutProvider) - { - sutProvider.GetDependency().IsEnabled(FeatureFlagKeys.EnableConsolidatedBilling) - .Returns(true); - - sutProvider.GetDependency().ProviderProviderAdmin(providerId) - .Returns(true); - - sutProvider.GetDependency().GetByIdAsync(providerId) - .ReturnsNull(); - - var result = await sutProvider.Sut.UpdateAsync(providerId, providerOrganizationId, requestBody); - - Assert.IsType(result); - } - - [Theory, BitAutoData] - public async Task GetSubscriptionAsync_NoProviderOrganization_NotFound( - Guid providerId, - Guid providerOrganizationId, - UpdateProviderOrganizationRequestBody requestBody, - Provider provider, - SutProvider sutProvider) - { - sutProvider.GetDependency().IsEnabled(FeatureFlagKeys.EnableConsolidatedBilling) - .Returns(true); - - sutProvider.GetDependency().ProviderProviderAdmin(providerId) - .Returns(true); - - sutProvider.GetDependency().GetByIdAsync(providerId) - .Returns(provider); - - sutProvider.GetDependency().GetByIdAsync(providerOrganizationId) - .ReturnsNull(); - - var result = await sutProvider.Sut.UpdateAsync(providerId, providerOrganizationId, requestBody); - - Assert.IsType(result); - } - - [Theory, BitAutoData] - public async Task GetSubscriptionAsync_NoOrganization_ServerError( - Guid providerId, - Guid providerOrganizationId, - UpdateProviderOrganizationRequestBody requestBody, - Provider provider, - ProviderOrganization providerOrganization, - SutProvider sutProvider) - { - sutProvider.GetDependency().IsEnabled(FeatureFlagKeys.EnableConsolidatedBilling) - .Returns(true); - - sutProvider.GetDependency().ProviderProviderAdmin(providerId) - .Returns(true); - - sutProvider.GetDependency().GetByIdAsync(providerId) - .Returns(provider); - - sutProvider.GetDependency().GetByIdAsync(providerOrganizationId) - .Returns(providerOrganization); - - sutProvider.GetDependency().GetByIdAsync(providerOrganization.OrganizationId) - .ReturnsNull(); - - var result = await sutProvider.Sut.UpdateAsync(providerId, providerOrganizationId, requestBody); - - Assert.IsType(result); - } - - [Theory, BitAutoData] - public async Task GetSubscriptionAsync_NoContent( - Guid providerId, - Guid providerOrganizationId, - UpdateProviderOrganizationRequestBody requestBody, - Provider provider, - ProviderOrganization providerOrganization, - Organization organization, - SutProvider sutProvider) - { - sutProvider.GetDependency().IsEnabled(FeatureFlagKeys.EnableConsolidatedBilling) - .Returns(true); - - sutProvider.GetDependency().ProviderProviderAdmin(providerId) - .Returns(true); - - sutProvider.GetDependency().GetByIdAsync(providerId) - .Returns(provider); - - sutProvider.GetDependency().GetByIdAsync(providerOrganizationId) - .Returns(providerOrganization); - - sutProvider.GetDependency().GetByIdAsync(providerOrganization.OrganizationId) - .Returns(organization); - - var result = await sutProvider.Sut.UpdateAsync(providerId, providerOrganizationId, requestBody); - - await sutProvider.GetDependency().Received(1) - .AssignSeatsToClientOrganization( - provider, - organization, - requestBody.AssignedSeats); - - Assert.IsType(result); - } -} diff --git a/test/Api.Test/Utilities/EnumMatchesAttributeTests.cs b/test/Api.Test/Utilities/EnumMatchesAttributeTests.cs new file mode 100644 index 000000000..f1c4accbb --- /dev/null +++ b/test/Api.Test/Utilities/EnumMatchesAttributeTests.cs @@ -0,0 +1,63 @@ +using Bit.Api.Utilities; +using Bit.Core.Enums; +using Xunit; + +namespace Bit.Api.Test.Utilities; + +public class EnumMatchesAttributeTests +{ + [Fact] + public void IsValid_NullInput_False() + { + var enumMatchesAttribute = + new EnumMatchesAttribute(PlanType.TeamsMonthly, PlanType.EnterpriseMonthly); + + var result = enumMatchesAttribute.IsValid(null); + + Assert.False(result); + } + + [Fact] + public void IsValid_NullAccepted_False() + { + var enumMatchesAttribute = + new EnumMatchesAttribute(); + + var result = enumMatchesAttribute.IsValid(PlanType.TeamsMonthly); + + Assert.False(result); + } + + [Fact] + public void IsValid_EmptyAccepted_False() + { + var enumMatchesAttribute = + new EnumMatchesAttribute([]); + + var result = enumMatchesAttribute.IsValid(PlanType.TeamsMonthly); + + Assert.False(result); + } + + [Fact] + public void IsValid_ParseFails_False() + { + var enumMatchesAttribute = + new EnumMatchesAttribute(PlanType.TeamsMonthly, PlanType.EnterpriseMonthly); + + var result = enumMatchesAttribute.IsValid(GatewayType.Stripe); + + Assert.False(result); + } + + [Fact] + public void IsValid_Matches_True() + { + var enumMatchesAttribute = + new EnumMatchesAttribute(PlanType.TeamsMonthly, PlanType.EnterpriseMonthly); + + var result = enumMatchesAttribute.IsValid(PlanType.TeamsMonthly); + + Assert.True(result); + } +} diff --git a/test/Core.Test/AdminConsole/Services/OrganizationServiceTests.cs b/test/Core.Test/AdminConsole/Services/OrganizationServiceTests.cs index fd249a4ad..bd7c6d4c5 100644 --- a/test/Core.Test/AdminConsole/Services/OrganizationServiceTests.cs +++ b/test/Core.Test/AdminConsole/Services/OrganizationServiceTests.cs @@ -447,6 +447,47 @@ public class OrganizationServiceTests Assert.Contains("You can't subtract Machine Accounts!", exception.Message); } + [Theory, BitAutoData] + public async Task SignupClientAsync_Succeeds( + OrganizationSignup signup, + SutProvider sutProvider) + { + sutProvider.GetDependency().IsEnabled(FeatureFlagKeys.EnableConsolidatedBilling).Returns(true); + + signup.Plan = PlanType.TeamsMonthly; + + var (organization, _, _) = await sutProvider.Sut.SignupClientAsync(signup); + + var plan = StaticStore.GetPlan(signup.Plan); + + await sutProvider.GetDependency().Received(1).CreateAsync(Arg.Is(org => + org.Id == organization.Id && + org.Name == signup.Name && + org.Plan == plan.Name && + org.PlanType == plan.Type && + org.UsePolicies == plan.HasPolicies && + org.PublicKey == signup.PublicKey && + org.PrivateKey == signup.PrivateKey && + org.UseSecretsManager == false)); + + await sutProvider.GetDependency().Received(1) + .CreateAsync(Arg.Is(orgApiKey => + orgApiKey.OrganizationId == organization.Id)); + + await sutProvider.GetDependency().Received(1) + .UpsertOrganizationAbilityAsync(organization); + + await sutProvider.GetDependency().DidNotReceiveWithAnyArgs().CreateAsync(default); + + await sutProvider.GetDependency().Received(1) + .CreateAsync(Arg.Is(c => c.Name == signup.CollectionName && c.OrganizationId == organization.Id), null, null); + + await sutProvider.GetDependency().Received(1).RaiseEventAsync(Arg.Is( + re => + re.Type == ReferenceEventType.Signup && + re.PlanType == plan.Type)); + } + [Theory] [OrganizationInviteCustomize(InviteeUserType = OrganizationUserType.User, InvitorUserType = OrganizationUserType.Owner), OrganizationCustomize(FlexibleCollections = false), BitAutoData] diff --git a/test/Core.Test/Billing/Commands/CreateCustomerCommandTests.cs b/test/Core.Test/Billing/Commands/CreateCustomerCommandTests.cs new file mode 100644 index 000000000..b532879e9 --- /dev/null +++ b/test/Core.Test/Billing/Commands/CreateCustomerCommandTests.cs @@ -0,0 +1,129 @@ +using Bit.Core.AdminConsole.Entities; +using Bit.Core.AdminConsole.Entities.Provider; +using Bit.Core.Billing.Commands.Implementations; +using Bit.Core.Billing.Queries; +using Bit.Core.Entities; +using Bit.Core.Repositories; +using Bit.Core.Services; +using Bit.Core.Settings; +using Bit.Test.Common.AutoFixture; +using Bit.Test.Common.AutoFixture.Attributes; +using NSubstitute; +using Stripe; +using Xunit; +using GlobalSettings = Bit.Core.Settings.GlobalSettings; + +namespace Bit.Core.Test.Billing.Commands; + +[SutProviderCustomize] +public class CreateCustomerCommandTests +{ + private const string _customerId = "customer_id"; + + [Theory, BitAutoData] + public async Task CreateCustomer_ForClientOrg_ProviderNull_ThrowsArgumentNullException( + Organization organization, + SutProvider sutProvider) => + await Assert.ThrowsAsync(() => sutProvider.Sut.CreateCustomer(null, organization)); + + [Theory, BitAutoData] + public async Task CreateCustomer_ForClientOrg_OrganizationNull_ThrowsArgumentNullException( + Provider provider, + SutProvider sutProvider) => + await Assert.ThrowsAsync(() => sutProvider.Sut.CreateCustomer(provider, null)); + + [Theory, BitAutoData] + public async Task CreateCustomer_ForClientOrg_HasGatewayCustomerId_NoOp( + Provider provider, + Organization organization, + SutProvider sutProvider) + { + organization.GatewayCustomerId = _customerId; + + await sutProvider.Sut.CreateCustomer(provider, organization); + + await sutProvider.GetDependency().DidNotReceiveWithAnyArgs() + .GetCustomerOrThrow(Arg.Any(), Arg.Any()); + } + + [Theory, BitAutoData] + public async Task CreateCustomer_ForClientOrg_Succeeds( + Provider provider, + Organization organization, + SutProvider sutProvider) + { + organization.GatewayCustomerId = null; + organization.Name = "Name"; + organization.BusinessName = "BusinessName"; + + var providerCustomer = new Customer + { + Address = new Address + { + Country = "USA", + PostalCode = "12345", + Line1 = "123 Main St.", + Line2 = "Unit 4", + City = "Fake Town", + State = "Fake State" + }, + TaxIds = new StripeList + { + Data = + [ + new TaxId { Type = "TYPE", Value = "VALUE" } + ] + } + }; + + sutProvider.GetDependency().GetCustomerOrThrow(provider, Arg.Is( + options => options.Expand.FirstOrDefault() == "tax_ids")) + .Returns(providerCustomer); + + sutProvider.GetDependency().BaseServiceUri + .Returns(new GlobalSettings.BaseServiceUriSettings(new GlobalSettings()) { CloudRegion = "US" }); + + sutProvider.GetDependency().CustomerCreateAsync(Arg.Is( + options => + options.Address.Country == providerCustomer.Address.Country && + options.Address.PostalCode == providerCustomer.Address.PostalCode && + options.Address.Line1 == providerCustomer.Address.Line1 && + options.Address.Line2 == providerCustomer.Address.Line2 && + options.Address.City == providerCustomer.Address.City && + options.Address.State == providerCustomer.Address.State && + options.Name == organization.DisplayName() && + options.Description == $"{provider.Name} Client Organization" && + options.Email == provider.BillingEmail && + options.InvoiceSettings.CustomFields.FirstOrDefault().Name == "Organization" && + options.InvoiceSettings.CustomFields.FirstOrDefault().Value == "Name" && + options.Metadata["region"] == "US" && + options.TaxIdData.FirstOrDefault().Type == providerCustomer.TaxIds.FirstOrDefault().Type && + options.TaxIdData.FirstOrDefault().Value == providerCustomer.TaxIds.FirstOrDefault().Value)) + .Returns(new Customer + { + Id = "customer_id" + }); + + await sutProvider.Sut.CreateCustomer(provider, organization); + + await sutProvider.GetDependency().Received(1).CustomerCreateAsync(Arg.Is( + options => + options.Address.Country == providerCustomer.Address.Country && + options.Address.PostalCode == providerCustomer.Address.PostalCode && + options.Address.Line1 == providerCustomer.Address.Line1 && + options.Address.Line2 == providerCustomer.Address.Line2 && + options.Address.City == providerCustomer.Address.City && + options.Address.State == providerCustomer.Address.State && + options.Name == organization.DisplayName() && + options.Description == $"{provider.Name} Client Organization" && + options.Email == provider.BillingEmail && + options.InvoiceSettings.CustomFields.FirstOrDefault().Name == "Organization" && + options.InvoiceSettings.CustomFields.FirstOrDefault().Value == "Name" && + options.Metadata["region"] == "US" && + options.TaxIdData.FirstOrDefault().Type == providerCustomer.TaxIds.FirstOrDefault().Type && + options.TaxIdData.FirstOrDefault().Value == providerCustomer.TaxIds.FirstOrDefault().Value)); + + await sutProvider.GetDependency().Received(1).ReplaceAsync(Arg.Is( + org => org.GatewayCustomerId == "customer_id")); + } +} diff --git a/test/Core.Test/Billing/Queries/ProviderBillingQueriesTests.cs b/test/Core.Test/Billing/Queries/ProviderBillingQueriesTests.cs index 534444ba9..afa361781 100644 --- a/test/Core.Test/Billing/Queries/ProviderBillingQueriesTests.cs +++ b/test/Core.Test/Billing/Queries/ProviderBillingQueriesTests.cs @@ -29,7 +29,7 @@ public class ProviderBillingQueriesTests providerRepository.GetByIdAsync(providerId).ReturnsNull(); - var subscriptionData = await sutProvider.Sut.GetSubscriptionData(providerId); + var subscriptionData = await sutProvider.Sut.GetSubscriptionDTO(providerId); Assert.Null(subscriptionData); @@ -50,7 +50,7 @@ public class ProviderBillingQueriesTests subscriberQueries.GetSubscription(provider).ReturnsNull(); - var subscriptionData = await sutProvider.Sut.GetSubscriptionData(providerId); + var subscriptionData = await sutProvider.Sut.GetSubscriptionDTO(providerId); Assert.Null(subscriptionData); @@ -109,7 +109,7 @@ public class ProviderBillingQueriesTests providerPlanRepository.GetByProviderId(providerId).Returns(providerPlans); - var subscriptionData = await sutProvider.Sut.GetSubscriptionData(providerId); + var subscriptionData = await sutProvider.Sut.GetSubscriptionDTO(providerId); Assert.NotNull(subscriptionData); @@ -140,7 +140,7 @@ public class ProviderBillingQueriesTests return; - void Compare(ProviderPlan providerPlan, ConfiguredProviderPlan configuredProviderPlan) + void Compare(ProviderPlan providerPlan, ConfiguredProviderPlanDTO configuredProviderPlan) { Assert.NotNull(configuredProviderPlan); Assert.Equal(providerPlan.Id, configuredProviderPlan.Id); diff --git a/test/Core.Test/Billing/Queries/SubscriberQueriesTests.cs b/test/Core.Test/Billing/Queries/SubscriberQueriesTests.cs index 51682a666..8fcba59aa 100644 --- a/test/Core.Test/Billing/Queries/SubscriberQueriesTests.cs +++ b/test/Core.Test/Billing/Queries/SubscriberQueriesTests.cs @@ -1,7 +1,5 @@ using Bit.Core.AdminConsole.Entities; -using Bit.Core.AdminConsole.Entities.Provider; using Bit.Core.Billing.Queries.Implementations; -using Bit.Core.Entities; using Bit.Core.Services; using Bit.Test.Common.AutoFixture; using Bit.Test.Common.AutoFixture.Attributes; @@ -17,6 +15,56 @@ namespace Bit.Core.Test.Billing.Queries; [SutProviderCustomize] public class SubscriberQueriesTests { + #region GetCustomer + [Theory, BitAutoData] + public async Task GetCustomer_NullSubscriber_ThrowsArgumentNullException( + SutProvider sutProvider) + => await Assert.ThrowsAsync( + async () => await sutProvider.Sut.GetCustomer(null)); + + [Theory, BitAutoData] + public async Task GetCustomer_NoGatewayCustomerId_ReturnsNull( + Organization organization, + SutProvider sutProvider) + { + organization.GatewayCustomerId = null; + + var customer = await sutProvider.Sut.GetCustomer(organization); + + Assert.Null(customer); + } + + [Theory, BitAutoData] + public async Task GetCustomer_NoCustomer_ReturnsNull( + Organization organization, + SutProvider sutProvider) + { + sutProvider.GetDependency() + .CustomerGetAsync(organization.GatewayCustomerId) + .ReturnsNull(); + + var customer = await sutProvider.Sut.GetCustomer(organization); + + Assert.Null(customer); + } + + [Theory, BitAutoData] + public async Task GetCustomer_Succeeds( + Organization organization, + SutProvider sutProvider) + { + var customer = new Customer(); + + sutProvider.GetDependency() + .CustomerGetAsync(organization.GatewayCustomerId) + .Returns(customer); + + var gotCustomer = await sutProvider.Sut.GetCustomer(organization); + + Assert.Equivalent(customer, gotCustomer); + } + #endregion + #region GetSubscription [Theory, BitAutoData] public async Task GetSubscription_NullSubscriber_ThrowsArgumentNullException( @@ -25,123 +73,91 @@ public class SubscriberQueriesTests async () => await sutProvider.Sut.GetSubscription(null)); [Theory, BitAutoData] - public async Task GetSubscription_Organization_NoGatewaySubscriptionId_ReturnsNull( + public async Task GetSubscription_NoGatewaySubscriptionId_ReturnsNull( Organization organization, SutProvider sutProvider) { organization.GatewaySubscriptionId = null; - var gotSubscription = await sutProvider.Sut.GetSubscription(organization); + var subscription = await sutProvider.Sut.GetSubscription(organization); - Assert.Null(gotSubscription); + Assert.Null(subscription); } [Theory, BitAutoData] - public async Task GetSubscription_Organization_NoSubscription_ReturnsNull( + public async Task GetSubscription_NoSubscription_ReturnsNull( Organization organization, SutProvider sutProvider) { - sutProvider.GetDependency().SubscriptionGetAsync(organization.GatewaySubscriptionId) + sutProvider.GetDependency() + .SubscriptionGetAsync(organization.GatewaySubscriptionId) .ReturnsNull(); - var gotSubscription = await sutProvider.Sut.GetSubscription(organization); + var subscription = await sutProvider.Sut.GetSubscription(organization); - Assert.Null(gotSubscription); + Assert.Null(subscription); } [Theory, BitAutoData] - public async Task GetSubscription_Organization_Succeeds( + public async Task GetSubscription_Succeeds( Organization organization, SutProvider sutProvider) { var subscription = new Subscription(); - sutProvider.GetDependency().SubscriptionGetAsync(organization.GatewaySubscriptionId) + sutProvider.GetDependency() + .SubscriptionGetAsync(organization.GatewaySubscriptionId) .Returns(subscription); var gotSubscription = await sutProvider.Sut.GetSubscription(organization); Assert.Equivalent(subscription, gotSubscription); } + #endregion + + #region GetCustomerOrThrow + [Theory, BitAutoData] + public async Task GetCustomerOrThrow_NullSubscriber_ThrowsArgumentNullException( + SutProvider sutProvider) + => await Assert.ThrowsAsync( + async () => await sutProvider.Sut.GetCustomerOrThrow(null)); [Theory, BitAutoData] - public async Task GetSubscription_User_NoGatewaySubscriptionId_ReturnsNull( - User user, + public async Task GetCustomerOrThrow_NoGatewaySubscriptionId_ThrowsGatewayException( + Organization organization, SutProvider sutProvider) { - user.GatewaySubscriptionId = null; + organization.GatewayCustomerId = null; - var gotSubscription = await sutProvider.Sut.GetSubscription(user); - - Assert.Null(gotSubscription); + await ThrowsContactSupportAsync(async () => await sutProvider.Sut.GetCustomerOrThrow(organization)); } [Theory, BitAutoData] - public async Task GetSubscription_User_NoSubscription_ReturnsNull( - User user, + public async Task GetSubscriptionOrThrow_NoCustomer_ThrowsGatewayException( + Organization organization, SutProvider sutProvider) { - sutProvider.GetDependency().SubscriptionGetAsync(user.GatewaySubscriptionId) + sutProvider.GetDependency() + .CustomerGetAsync(organization.GatewayCustomerId) .ReturnsNull(); - var gotSubscription = await sutProvider.Sut.GetSubscription(user); - - Assert.Null(gotSubscription); + await ThrowsContactSupportAsync(async () => await sutProvider.Sut.GetCustomerOrThrow(organization)); } [Theory, BitAutoData] - public async Task GetSubscription_User_Succeeds( - User user, + public async Task GetCustomerOrThrow_Succeeds( + Organization organization, SutProvider sutProvider) { - var subscription = new Subscription(); + var customer = new Customer(); - sutProvider.GetDependency().SubscriptionGetAsync(user.GatewaySubscriptionId) - .Returns(subscription); + sutProvider.GetDependency() + .CustomerGetAsync(organization.GatewayCustomerId) + .Returns(customer); - var gotSubscription = await sutProvider.Sut.GetSubscription(user); + var gotCustomer = await sutProvider.Sut.GetCustomerOrThrow(organization); - Assert.Equivalent(subscription, gotSubscription); - } - - [Theory, BitAutoData] - public async Task GetSubscription_Provider_NoGatewaySubscriptionId_ReturnsNull( - Provider provider, - SutProvider sutProvider) - { - provider.GatewaySubscriptionId = null; - - var gotSubscription = await sutProvider.Sut.GetSubscription(provider); - - Assert.Null(gotSubscription); - } - - [Theory, BitAutoData] - public async Task GetSubscription_Provider_NoSubscription_ReturnsNull( - Provider provider, - SutProvider sutProvider) - { - sutProvider.GetDependency().SubscriptionGetAsync(provider.GatewaySubscriptionId) - .ReturnsNull(); - - var gotSubscription = await sutProvider.Sut.GetSubscription(provider); - - Assert.Null(gotSubscription); - } - - [Theory, BitAutoData] - public async Task GetSubscription_Provider_Succeeds( - Provider provider, - SutProvider sutProvider) - { - var subscription = new Subscription(); - - sutProvider.GetDependency().SubscriptionGetAsync(provider.GatewaySubscriptionId) - .Returns(subscription); - - var gotSubscription = await sutProvider.Sut.GetSubscription(provider); - - Assert.Equivalent(subscription, gotSubscription); + Assert.Equivalent(customer, gotCustomer); } #endregion @@ -153,7 +169,7 @@ public class SubscriberQueriesTests async () => await sutProvider.Sut.GetSubscriptionOrThrow(null)); [Theory, BitAutoData] - public async Task GetSubscriptionOrThrow_Organization_NoGatewaySubscriptionId_ThrowsGatewayException( + public async Task GetSubscriptionOrThrow_NoGatewaySubscriptionId_ThrowsGatewayException( Organization organization, SutProvider sutProvider) { @@ -163,101 +179,31 @@ public class SubscriberQueriesTests } [Theory, BitAutoData] - public async Task GetSubscriptionOrThrow_Organization_NoSubscription_ThrowsGatewayException( + public async Task GetSubscriptionOrThrow_NoSubscription_ThrowsGatewayException( Organization organization, SutProvider sutProvider) { - sutProvider.GetDependency().SubscriptionGetAsync(organization.GatewaySubscriptionId) + sutProvider.GetDependency() + .SubscriptionGetAsync(organization.GatewaySubscriptionId) .ReturnsNull(); await ThrowsContactSupportAsync(async () => await sutProvider.Sut.GetSubscriptionOrThrow(organization)); } [Theory, BitAutoData] - public async Task GetSubscriptionOrThrow_Organization_Succeeds( + public async Task GetSubscriptionOrThrow_Succeeds( Organization organization, SutProvider sutProvider) { var subscription = new Subscription(); - sutProvider.GetDependency().SubscriptionGetAsync(organization.GatewaySubscriptionId) + sutProvider.GetDependency() + .SubscriptionGetAsync(organization.GatewaySubscriptionId) .Returns(subscription); var gotSubscription = await sutProvider.Sut.GetSubscriptionOrThrow(organization); Assert.Equivalent(subscription, gotSubscription); } - - [Theory, BitAutoData] - public async Task GetSubscriptionOrThrow_User_NoGatewaySubscriptionId_ThrowsGatewayException( - User user, - SutProvider sutProvider) - { - user.GatewaySubscriptionId = null; - - await ThrowsContactSupportAsync(async () => await sutProvider.Sut.GetSubscriptionOrThrow(user)); - } - - [Theory, BitAutoData] - public async Task GetSubscriptionOrThrow_User_NoSubscription_ThrowsGatewayException( - User user, - SutProvider sutProvider) - { - sutProvider.GetDependency().SubscriptionGetAsync(user.GatewaySubscriptionId) - .ReturnsNull(); - - await ThrowsContactSupportAsync(async () => await sutProvider.Sut.GetSubscriptionOrThrow(user)); - } - - [Theory, BitAutoData] - public async Task GetSubscriptionOrThrow_User_Succeeds( - User user, - SutProvider sutProvider) - { - var subscription = new Subscription(); - - sutProvider.GetDependency().SubscriptionGetAsync(user.GatewaySubscriptionId) - .Returns(subscription); - - var gotSubscription = await sutProvider.Sut.GetSubscriptionOrThrow(user); - - Assert.Equivalent(subscription, gotSubscription); - } - - [Theory, BitAutoData] - public async Task GetSubscriptionOrThrow_Provider_NoGatewaySubscriptionId_ThrowsGatewayException( - Provider provider, - SutProvider sutProvider) - { - provider.GatewaySubscriptionId = null; - - await ThrowsContactSupportAsync(async () => await sutProvider.Sut.GetSubscriptionOrThrow(provider)); - } - - [Theory, BitAutoData] - public async Task GetSubscriptionOrThrow_Provider_NoSubscription_ThrowsGatewayException( - Provider provider, - SutProvider sutProvider) - { - sutProvider.GetDependency().SubscriptionGetAsync(provider.GatewaySubscriptionId) - .ReturnsNull(); - - await ThrowsContactSupportAsync(async () => await sutProvider.Sut.GetSubscriptionOrThrow(provider)); - } - - [Theory, BitAutoData] - public async Task GetSubscriptionOrThrow_Provider_Succeeds( - Provider provider, - SutProvider sutProvider) - { - var subscription = new Subscription(); - - sutProvider.GetDependency().SubscriptionGetAsync(provider.GatewaySubscriptionId) - .Returns(subscription); - - var gotSubscription = await sutProvider.Sut.GetSubscriptionOrThrow(provider); - - Assert.Equivalent(subscription, gotSubscription); - } #endregion }