From 75cae907e82fa2b2afc0ce7517a18d018ec08daa Mon Sep 17 00:00:00 2001 From: cyprain-okeke <108260115+cyprain-okeke@users.noreply.github.com> Date: Wed, 20 Dec 2023 22:54:45 +0100 Subject: [PATCH] [AC-1753] Automatically assign provider's pricing to new organizations (#3513) * Initial commit * resolve pr comment * adding some unit test * Resolve pr comments * Adding some unit test * Resolve pr comment * changes to find the bug * revert back changes on admin * Fix the failing Test * fix the bug --- .../AdminConsole/Services/ProviderService.cs | 106 +++++++++++++++++- .../Services/ProviderServiceTests.cs | 95 ++++++++++++++++ .../Providers/ProviderResponseModel.cs | 2 + .../Implementations/OrganizationService.cs | 11 +- src/Core/Constants.cs | 6 + 5 files changed, 206 insertions(+), 14 deletions(-) diff --git a/bitwarden_license/src/Commercial.Core/AdminConsole/Services/ProviderService.cs b/bitwarden_license/src/Commercial.Core/AdminConsole/Services/ProviderService.cs index a03e92c3f..c8b64da19 100644 --- a/bitwarden_license/src/Commercial.Core/AdminConsole/Services/ProviderService.cs +++ b/bitwarden_license/src/Commercial.Core/AdminConsole/Services/ProviderService.cs @@ -1,4 +1,7 @@ -using Bit.Core.AdminConsole.Entities.Provider; +using System.ComponentModel.DataAnnotations; +using Bit.Core; +using Bit.Core.AdminConsole.Entities; +using Bit.Core.AdminConsole.Entities.Provider; using Bit.Core.AdminConsole.Enums.Provider; using Bit.Core.AdminConsole.Models.Business.Provider; using Bit.Core.AdminConsole.Repositories; @@ -33,13 +36,14 @@ public class ProviderService : IProviderService private readonly IUserService _userService; private readonly IOrganizationService _organizationService; private readonly ICurrentContext _currentContext; + private readonly IStripeAdapter _stripeAdapter; public ProviderService(IProviderRepository providerRepository, IProviderUserRepository providerUserRepository, IProviderOrganizationRepository providerOrganizationRepository, IUserRepository userRepository, IUserService userService, IOrganizationService organizationService, IMailService mailService, IDataProtectionProvider dataProtectionProvider, IEventService eventService, IOrganizationRepository organizationRepository, GlobalSettings globalSettings, - ICurrentContext currentContext) + ICurrentContext currentContext, IStripeAdapter stripeAdapter) { _providerRepository = providerRepository; _providerUserRepository = providerUserRepository; @@ -53,6 +57,7 @@ public class ProviderService : IProviderService _globalSettings = globalSettings; _dataProtector = dataProtectionProvider.CreateProtector("ProviderServiceDataProtector"); _currentContext = currentContext; + _stripeAdapter = stripeAdapter; } public async Task CompleteSetupAsync(Provider provider, Guid ownerUserId, string token, string key) @@ -369,6 +374,7 @@ public class ProviderService : IProviderService Key = key, }; + await ApplyProviderPriceRateAsync(organizationId, providerId); await _providerOrganizationRepository.CreateAsync(providerOrganization); await _eventService.LogProviderOrganizationEventAsync(providerOrganization, EventType.ProviderOrganization_Added); } @@ -381,18 +387,110 @@ public class ProviderService : IProviderService throw new BadRequestException("Provider must be of type Reseller in order to assign Organizations to it."); } - var existingProviderOrganizationsCount = await _providerOrganizationRepository.GetCountByOrganizationIdsAsync(organizationIds); + var orgIdsList = organizationIds.ToList(); + var existingProviderOrganizationsCount = await _providerOrganizationRepository.GetCountByOrganizationIdsAsync(orgIdsList); if (existingProviderOrganizationsCount > 0) { throw new BadRequestException("Organizations must not be assigned to any Provider."); } - var providerOrganizationsToInsert = organizationIds.Select(orgId => new ProviderOrganization { ProviderId = providerId, OrganizationId = orgId }); + var providerOrganizationsToInsert = orgIdsList.Select(orgId => new ProviderOrganization { ProviderId = providerId, OrganizationId = orgId }); var insertedProviderOrganizations = await _providerOrganizationRepository.CreateManyAsync(providerOrganizationsToInsert); await _eventService.LogProviderOrganizationEventsAsync(insertedProviderOrganizations.Select(ipo => (ipo, EventType.ProviderOrganization_Added, (DateTime?)null))); } + private async Task ApplyProviderPriceRateAsync(Guid organizationId, Guid providerId) + { + var provider = await _providerRepository.GetByIdAsync(providerId); + // if a provider was created before Nov 6, 2023.If true, the organization plan assigned to that provider is updated to a 2020 plan. + if (provider.CreationDate >= Constants.ProviderCreatedPriorNov62023) + { + return; + } + + var organization = await _organizationRepository.GetByIdAsync(organizationId); + var subscriptionItem = await GetSubscriptionItemAsync(organization.GatewaySubscriptionId, GetStripeSeatPlanId(organization.PlanType)); + var extractedPlanType = PlanTypeMappings(organization); + if (subscriptionItem != null) + { + await UpdateSubscriptionAsync(subscriptionItem, GetStripeSeatPlanId(extractedPlanType), organization); + } + + await _organizationRepository.UpsertAsync(organization); + } + + private async Task GetSubscriptionItemAsync(string subscriptionId, string oldPlanId) + { + var subscriptionDetails = await _stripeAdapter.SubscriptionGetAsync(subscriptionId); + return subscriptionDetails.Items.Data.FirstOrDefault(item => item.Price.Id == oldPlanId); + } + + private static string GetStripeSeatPlanId(PlanType planType) + { + return StaticStore.GetPlan(planType).PasswordManager.StripeSeatPlanId; + } + + private async Task UpdateSubscriptionAsync(Stripe.SubscriptionItem subscriptionItem, string extractedPlanType, Organization organization) + { + try + { + if (subscriptionItem.Price.Id != extractedPlanType) + { + await _stripeAdapter.SubscriptionUpdateAsync(subscriptionItem.Subscription, + new Stripe.SubscriptionUpdateOptions + { + Items = new List + { + new() + { + Id = subscriptionItem.Id, + Price = extractedPlanType, + Quantity = organization.Seats.Value, + }, + } + }); + } + } + catch (Exception) + { + throw new Exception("Unable to update existing plan on stripe"); + } + + } + + private static PlanType PlanTypeMappings(Organization organization) + { + var planTypeMappings = new Dictionary + { + { PlanType.EnterpriseAnnually2020, GetEnumDisplayName(PlanType.EnterpriseAnnually2020) }, + { PlanType.EnterpriseMonthly2020, GetEnumDisplayName(PlanType.EnterpriseMonthly2020) }, + { PlanType.TeamsMonthly2020, GetEnumDisplayName(PlanType.TeamsMonthly2020) }, + { PlanType.TeamsAnnually2020, GetEnumDisplayName(PlanType.TeamsAnnually2020) } + }; + + foreach (var mapping in planTypeMappings) + { + if (mapping.Value.IndexOf(organization.Plan, StringComparison.Ordinal) != -1) + { + organization.PlanType = mapping.Key; + organization.Plan = mapping.Value; + return organization.PlanType; + } + } + + throw new ArgumentException("Invalid PlanType selected"); + } + + private static string GetEnumDisplayName(Enum value) + { + var fieldInfo = value.GetType().GetField(value.ToString()); + + var displayAttribute = (DisplayAttribute)Attribute.GetCustomAttribute(fieldInfo!, typeof(DisplayAttribute)); + + return displayAttribute?.Name ?? value.ToString(); + } + public async Task CreateOrganizationAsync(Guid providerId, OrganizationSignup organizationSignup, string clientOwnerEmail, User user) { diff --git a/bitwarden_license/test/Commercial.Core.Test/AdminConsole/Services/ProviderServiceTests.cs b/bitwarden_license/test/Commercial.Core.Test/AdminConsole/Services/ProviderServiceTests.cs index b7ee76da1..24167e714 100644 --- a/bitwarden_license/test/Commercial.Core.Test/AdminConsole/Services/ProviderServiceTests.cs +++ b/bitwarden_license/test/Commercial.Core.Test/AdminConsole/Services/ProviderServiceTests.cs @@ -18,6 +18,7 @@ using Bit.Test.Common.AutoFixture.Attributes; using Microsoft.AspNetCore.DataProtection; using NSubstitute; using NSubstitute.ReturnsExtensions; +using Stripe; using Xunit; using Provider = Bit.Core.AdminConsole.Entities.Provider.Provider; using ProviderUser = Bit.Core.AdminConsole.Entities.Provider.ProviderUser; @@ -598,4 +599,98 @@ public class ProviderServiceTests await sutProvider.GetDependency().Received() .LogProviderOrganizationEventAsync(providerOrganization, EventType.ProviderOrganization_Removed); } + + [Theory, BitAutoData] + public async Task AddOrganization_CreateAfterNov162023_PlanTypeDoesNotUpdated(Provider provider, Organization organization, string key, + SutProvider sutProvider) + { + provider.Type = ProviderType.Msp; + + sutProvider.GetDependency().GetByIdAsync(provider.Id).Returns(provider); + + var providerOrganizationRepository = sutProvider.GetDependency(); + var expectedPlanType = PlanType.EnterpriseAnnually; + organization.PlanType = PlanType.EnterpriseAnnually; + sutProvider.GetDependency().GetByIdAsync(organization.Id).Returns(organization); + + await sutProvider.Sut.AddOrganization(provider.Id, organization.Id, key); + + await providerOrganizationRepository.ReceivedWithAnyArgs().CreateAsync(default); + await sutProvider.GetDependency() + .Received().LogProviderOrganizationEventAsync(Arg.Any(), + EventType.ProviderOrganization_Added); + Assert.Equal(organization.PlanType, expectedPlanType); + } + + [Theory, BitAutoData] + public async Task AddOrganization_CreateBeforeNov162023_PlanTypeUpdated(Provider provider, Organization organization, string key, + SutProvider sutProvider) + { + var newCreationDate = DateTime.UtcNow.AddMonths(-3); + BackdateProviderCreationDate(provider, newCreationDate); + provider.Type = ProviderType.Msp; + + organization.PlanType = PlanType.EnterpriseAnnually; + organization.Plan = "Enterprise (Annually)"; + + var expectedPlanType = PlanType.EnterpriseAnnually2020; + + var expectedPlanId = "2020-enterprise-org-seat-annually"; + + sutProvider.GetDependency().GetByIdAsync(provider.Id).Returns(provider); + var providerOrganizationRepository = sutProvider.GetDependency(); + providerOrganizationRepository.GetByOrganizationId(organization.Id).ReturnsNull(); + sutProvider.GetDependency().GetByIdAsync(organization.Id).Returns(organization); + + sutProvider.GetDependency().GetByIdAsync(organization.Id).Returns(organization); + var subscriptionItem = GetSubscription(organization.GatewaySubscriptionId); + sutProvider.GetDependency().SubscriptionGetAsync(organization.GatewaySubscriptionId) + .Returns(GetSubscription(organization.GatewaySubscriptionId)); + await sutProvider.GetDependency().SubscriptionUpdateAsync( + organization.GatewaySubscriptionId, SubscriptionUpdateRequest(expectedPlanId, subscriptionItem)); + + await sutProvider.Sut.AddOrganization(provider.Id, organization.Id, key); + + await providerOrganizationRepository.ReceivedWithAnyArgs().CreateAsync(default); + await sutProvider.GetDependency() + .Received().LogProviderOrganizationEventAsync(Arg.Any(), + EventType.ProviderOrganization_Added); + + Assert.Equal(organization.PlanType, expectedPlanType); + } + + private static SubscriptionUpdateOptions SubscriptionUpdateRequest(string expectedPlanId, Subscription subscriptionItem) => + new() + { + Items = new List + { + new() { Id = subscriptionItem.Id, Price = expectedPlanId }, + } + }; + + private static Subscription GetSubscription(string subscriptionId) => + new() + { + Id = subscriptionId, + Items = new StripeList + { + Data = new List + { + new() + { + Id = "sub_item_123", + Price = new Price() + { + Id = "2023-enterprise-org-seat-annually" + } + } + } + } + }; + + private static void BackdateProviderCreationDate(Provider provider, DateTime newCreationDate) + { + // Set the CreationDate to the desired value + provider.GetType().GetProperty("CreationDate")?.SetValue(provider, newCreationDate, null); + } } diff --git a/src/Api/AdminConsole/Models/Response/Providers/ProviderResponseModel.cs b/src/Api/AdminConsole/Models/Response/Providers/ProviderResponseModel.cs index bc55093a0..a7280fd49 100644 --- a/src/Api/AdminConsole/Models/Response/Providers/ProviderResponseModel.cs +++ b/src/Api/AdminConsole/Models/Response/Providers/ProviderResponseModel.cs @@ -21,6 +21,7 @@ public class ProviderResponseModel : ResponseModel BusinessCountry = provider.BusinessCountry; BusinessTaxNumber = provider.BusinessTaxNumber; BillingEmail = provider.BillingEmail; + CreationDate = provider.CreationDate; } public Guid Id { get; set; } @@ -32,4 +33,5 @@ public class ProviderResponseModel : ResponseModel public string BusinessCountry { get; set; } public string BusinessTaxNumber { get; set; } public string BillingEmail { get; set; } + public DateTime CreationDate { get; set; } } diff --git a/src/Core/AdminConsole/Services/Implementations/OrganizationService.cs b/src/Core/AdminConsole/Services/Implementations/OrganizationService.cs index e9eca14a9..6961ce71b 100644 --- a/src/Core/AdminConsole/Services/Implementations/OrganizationService.cs +++ b/src/Core/AdminConsole/Services/Implementations/OrganizationService.cs @@ -1890,11 +1890,6 @@ public class OrganizationService : IOrganizationService public void ValidatePasswordManagerPlan(Models.StaticStore.Plan plan, OrganizationUpgrade upgrade) { - if (plan is not { LegacyYear: null }) - { - throw new BadRequestException("Invalid Password Manager plan selected."); - } - ValidatePlan(plan, upgrade.AdditionalSeats, "Password Manager"); if (plan.PasswordManager.BaseSeats + upgrade.AdditionalSeats <= 0) @@ -2409,12 +2404,8 @@ public class OrganizationService : IOrganizationService public async Task CreatePendingOrganization(Organization organization, string ownerEmail, ClaimsPrincipal user, IUserService userService, bool salesAssistedTrialStarted) { var plan = StaticStore.Plans.FirstOrDefault(p => p.Type == organization.PlanType); - if (plan is not { LegacyYear: null }) - { - throw new BadRequestException("Invalid plan selected."); - } - if (plan.Disabled) + if (plan!.Disabled) { throw new BadRequestException("Plan not found."); } diff --git a/src/Core/Constants.cs b/src/Core/Constants.cs index 706d6858a..a71032ea2 100644 --- a/src/Core/Constants.cs +++ b/src/Core/Constants.cs @@ -29,6 +29,12 @@ public static class Constants /// Used by IdentityServer to identify our own provider. /// public const string IdentityProvider = "bitwarden"; + + /// + /// Date identifier used in ProviderService to determine if a provider was created before Nov 6, 2023. + /// If true, the organization plan assigned to that provider is updated to a 2020 plan. + /// + public static readonly DateTime ProviderCreatedPriorNov62023 = new DateTime(2023, 11, 6); } public static class AuthConstants