diff --git a/src/Core/Billing/Migration/Models/ProviderMigrationTracker.cs b/src/Core/Billing/Migration/Models/ProviderMigrationTracker.cs index b6a58f82f..7bfef8a93 100644 --- a/src/Core/Billing/Migration/Models/ProviderMigrationTracker.cs +++ b/src/Core/Billing/Migration/Models/ProviderMigrationTracker.cs @@ -3,17 +3,14 @@ public enum ProviderMigrationProgress { Started = 1, - ClientsMigrated = 2, - TeamsPlanConfigured = 3, - EnterprisePlanConfigured = 4, - CustomerSetup = 5, - SubscriptionSetup = 6, - CreditApplied = 7, - Completed = 8, - - Reversing = 9, - ReversedClientMigrations = 10, - RemovedProviderPlans = 11 + NoClients = 2, + ClientsMigrated = 3, + TeamsPlanConfigured = 4, + EnterprisePlanConfigured = 5, + CustomerSetup = 6, + SubscriptionSetup = 7, + CreditApplied = 8, + Completed = 9, } public class ProviderMigrationTracker diff --git a/src/Core/Billing/Migration/Services/Implementations/ProviderMigrator.cs b/src/Core/Billing/Migration/Services/Implementations/ProviderMigrator.cs index 46c014cdb..9ca515a26 100644 --- a/src/Core/Billing/Migration/Services/Implementations/ProviderMigrator.cs +++ b/src/Core/Billing/Migration/Services/Implementations/ProviderMigrator.cs @@ -41,7 +41,18 @@ public class ProviderMigrator( await migrationTrackerCache.StartTracker(provider); - await MigrateClientsAsync(providerId); + var organizations = await GetClientsAsync(provider.Id); + + if (organizations.Count == 0) + { + logger.LogInformation("CB: Skipping migration for provider ({ProviderID}) with no clients", providerId); + + await migrationTrackerCache.UpdateTrackingStatus(providerId, ProviderMigrationProgress.NoClients); + + return; + } + + await MigrateClientsAsync(providerId, organizations); await ConfigureTeamsPlanAsync(providerId); @@ -65,6 +76,16 @@ public class ProviderMigrator( return null; } + if (providerTracker.Progress == ProviderMigrationProgress.NoClients) + { + return new ProviderMigrationResult + { + ProviderId = providerTracker.ProviderId, + ProviderName = providerTracker.ProviderName, + Result = providerTracker.Progress.ToString() + }; + } + var clientTrackers = await Task.WhenAll(providerTracker.OrganizationIds.Select(organizationId => migrationTrackerCache.GetTracker(providerId, organizationId))); @@ -99,12 +120,10 @@ public class ProviderMigrator( #region Steps - private async Task MigrateClientsAsync(Guid providerId) + private async Task MigrateClientsAsync(Guid providerId, List organizations) { logger.LogInformation("CB: Migrating clients for provider ({ProviderID})", providerId); - var organizations = await GetEnabledClientsAsync(providerId); - var organizationIds = organizations.Select(organization => organization.Id); await migrationTrackerCache.SetOrganizationIds(providerId, organizationIds); @@ -129,7 +148,7 @@ public class ProviderMigrator( { logger.LogInformation("CB: Configuring Teams plan for provider ({ProviderID})", providerId); - var organizations = await GetEnabledClientsAsync(providerId); + var organizations = await GetClientsAsync(providerId); var teamsSeats = organizations .Where(IsTeams) @@ -172,7 +191,7 @@ public class ProviderMigrator( { logger.LogInformation("CB: Configuring Enterprise plan for provider ({ProviderID})", providerId); - var organizations = await GetEnabledClientsAsync(providerId); + var organizations = await GetClientsAsync(providerId); var enterpriseSeats = organizations .Where(IsEnterprise) @@ -215,7 +234,7 @@ public class ProviderMigrator( { if (string.IsNullOrEmpty(provider.GatewayCustomerId)) { - var organizations = await GetEnabledClientsAsync(provider.Id); + var organizations = await GetClientsAsync(provider.Id); var sampleOrganization = organizations.FirstOrDefault(organization => !string.IsNullOrEmpty(organization.GatewayCustomerId)); @@ -299,7 +318,7 @@ public class ProviderMigrator( private async Task ApplyCreditAsync(Provider provider) { - var organizations = await GetEnabledClientsAsync(provider.Id); + var organizations = await GetClientsAsync(provider.Id); var organizationCustomers = await Task.WhenAll(organizations.Select(organization => stripeAdapter.CustomerGetAsync(organization.GatewayCustomerId))); @@ -355,13 +374,12 @@ public class ProviderMigrator( #region Utilities - private async Task> GetEnabledClientsAsync(Guid providerId) + private async Task> GetClientsAsync(Guid providerId) { var providerOrganizations = await providerOrganizationRepository.GetManyDetailsByProviderAsync(providerId); return (await Task.WhenAll(providerOrganizations.Select(providerOrganization => organizationRepository.GetByIdAsync(providerOrganization.OrganizationId)))) - .Where(organization => organization.Enabled) .ToList(); }