diff --git a/src/Api/Models/Response/ProfileProviderOrganizationResponseModel.cs b/src/Api/Models/Response/ProfileProviderOrganizationResponseModel.cs index 88c41d99c..1d508cae1 100644 --- a/src/Api/Models/Response/ProfileProviderOrganizationResponseModel.cs +++ b/src/Api/Models/Response/ProfileProviderOrganizationResponseModel.cs @@ -1,5 +1,6 @@ using Bit.Core.Enums; using Bit.Core.Models.Data; +using Bit.Core.Utilities; namespace Bit.Api.Models.Response; @@ -39,5 +40,6 @@ public class ProfileProviderOrganizationResponseModel : ProfileOrganizationRespo UserId = organization.UserId?.ToString(); ProviderId = organization.ProviderId?.ToString(); ProviderName = organization.ProviderName; + PlanProductType = StaticStore.GetPlan(organization.PlanType).Product; } } diff --git a/src/Core/Models/Data/Provider/ProviderUserOrganizationDetails.cs b/src/Core/Models/Data/Provider/ProviderUserOrganizationDetails.cs index e121962e6..2d06fc429 100644 --- a/src/Core/Models/Data/Provider/ProviderUserOrganizationDetails.cs +++ b/src/Core/Models/Data/Provider/ProviderUserOrganizationDetails.cs @@ -34,4 +34,5 @@ public class ProviderUserOrganizationDetails public Guid? ProviderId { get; set; } public Guid? ProviderUserId { get; set; } public string ProviderName { get; set; } + public Enums.PlanType PlanType { get; set; } } diff --git a/src/Core/Repositories/IProviderUserRepository.cs b/src/Core/Repositories/IProviderUserRepository.cs index 4a5db368e..ba920a575 100644 --- a/src/Core/Repositories/IProviderUserRepository.cs +++ b/src/Core/Repositories/IProviderUserRepository.cs @@ -11,7 +11,7 @@ public interface IProviderUserRepository : IRepository Task> GetManyByUserAsync(Guid userId); Task GetByProviderUserAsync(Guid providerId, Guid userId); Task> GetManyByProviderAsync(Guid providerId, ProviderUserType? type = null); - Task> GetManyDetailsByProviderAsync(Guid providerId); + Task> GetManyDetailsByProviderAsync(Guid providerId, ProviderUserStatusType? status = null); Task> GetManyDetailsByUserAsync(Guid userId, ProviderUserStatusType? status = null); Task> GetManyOrganizationDetailsByUserAsync(Guid userId, ProviderUserStatusType? status = null); diff --git a/src/Core/Services/Implementations/OrganizationService.cs b/src/Core/Services/Implementations/OrganizationService.cs index 88feb8c80..bff3b9fa4 100644 --- a/src/Core/Services/Implementations/OrganizationService.cs +++ b/src/Core/Services/Implementations/OrganizationService.cs @@ -2,6 +2,7 @@ using Bit.Core.Context; using Bit.Core.Entities; using Bit.Core.Enums; +using Bit.Core.Enums.Provider; using Bit.Core.Exceptions; using Bit.Core.Models.Business; using Bit.Core.Models.Data; @@ -43,6 +44,8 @@ public class OrganizationService : IOrganizationService private readonly IOrganizationConnectionRepository _organizationConnectionRepository; private readonly ICurrentContext _currentContext; private readonly ILogger _logger; + private readonly IProviderOrganizationRepository _providerOrganizationRepository; + private readonly IProviderUserRepository _providerUserRepository; public OrganizationService( IOrganizationRepository organizationRepository, @@ -69,7 +72,9 @@ public class OrganizationService : IOrganizationService IOrganizationApiKeyRepository organizationApiKeyRepository, IOrganizationConnectionRepository organizationConnectionRepository, ICurrentContext currentContext, - ILogger logger) + ILogger logger, + IProviderOrganizationRepository providerOrganizationRepository, + IProviderUserRepository providerUserRepository) { _organizationRepository = organizationRepository; _organizationUserRepository = organizationUserRepository; @@ -96,6 +101,8 @@ public class OrganizationService : IOrganizationService _organizationConnectionRepository = organizationConnectionRepository; _currentContext = currentContext; _logger = logger; + _providerOrganizationRepository = providerOrganizationRepository; + _providerUserRepository = providerUserRepository; } public async Task ReplacePaymentMethodAsync(Guid organizationId, string paymentToken, @@ -1635,8 +1642,19 @@ public class OrganizationService : IOrganizationService throw new BadRequestException(failureMessage); } - var ownerEmails = (await _organizationUserRepository.GetManyByMinimumRoleAsync(organization.Id, - OrganizationUserType.Owner)).Select(u => u.Email).Distinct(); + var providerOrg = await this._providerOrganizationRepository.GetByOrganizationId(organization.Id); + + IEnumerable ownerEmails; + if (providerOrg != null) + { + ownerEmails = (await _providerUserRepository.GetManyDetailsByProviderAsync(providerOrg.ProviderId, ProviderUserStatusType.Confirmed)) + .Select(u => u.Email).Distinct(); + } + else + { + ownerEmails = (await _organizationUserRepository.GetManyByMinimumRoleAsync(organization.Id, + OrganizationUserType.Owner)).Select(u => u.Email).Distinct(); + } var initialSeatCount = organization.Seats.Value; await AdjustSeatsAsync(organization, seatsToAdd, prorationDate, ownerEmails); diff --git a/src/Infrastructure.Dapper/Repositories/ProviderUserRepository.cs b/src/Infrastructure.Dapper/Repositories/ProviderUserRepository.cs index 0e1138e14..a0e2b6989 100644 --- a/src/Infrastructure.Dapper/Repositories/ProviderUserRepository.cs +++ b/src/Infrastructure.Dapper/Repositories/ProviderUserRepository.cs @@ -84,13 +84,13 @@ public class ProviderUserRepository : Repository, IProviderU } } - public async Task> GetManyDetailsByProviderAsync(Guid providerId) + public async Task> GetManyDetailsByProviderAsync(Guid providerId, ProviderUserStatusType? status) { using (var connection = new SqlConnection(ConnectionString)) { var results = await connection.QueryAsync( "[dbo].[ProviderUserUserDetails_ReadByProviderId]", - new { ProviderId = providerId }, + new { ProviderId = providerId, Status = status }, commandType: CommandType.StoredProcedure); return results.ToList(); diff --git a/src/Infrastructure.EntityFramework/Repositories/ProviderUserRepository.cs b/src/Infrastructure.EntityFramework/Repositories/ProviderUserRepository.cs index 76236a1c3..cad923c26 100644 --- a/src/Infrastructure.EntityFramework/Repositories/ProviderUserRepository.cs +++ b/src/Infrastructure.EntityFramework/Repositories/ProviderUserRepository.cs @@ -103,7 +103,7 @@ public class ProviderUserRepository : return await query.FirstOrDefaultAsync(); } } - public async Task> GetManyDetailsByProviderAsync(Guid providerId) + public async Task> GetManyDetailsByProviderAsync(Guid providerId, ProviderUserStatusType? status) { using (var scope = ServiceScopeFactory.CreateScope()) { @@ -113,17 +113,19 @@ public class ProviderUserRepository : on pu.UserId equals u.Id into u_g from u in u_g.DefaultIfEmpty() select new { pu, u }; - var data = await view.Where(e => e.pu.ProviderId == providerId).Select(e => new ProviderUserUserDetails - { - Id = e.pu.Id, - UserId = e.pu.UserId, - ProviderId = e.pu.ProviderId, - Name = e.u.Name, - Email = e.u.Email ?? e.pu.Email, - Status = e.pu.Status, - Type = e.pu.Type, - Permissions = e.pu.Permissions, - }).ToArrayAsync(); + var data = await view + .Where(e => e.pu.ProviderId == providerId && (status == null || e.pu.Status == status)) + .Select(e => new ProviderUserUserDetails + { + Id = e.pu.Id, + UserId = e.pu.UserId, + ProviderId = e.pu.ProviderId, + Name = e.u.Name, + Email = e.u.Email ?? e.pu.Email, + Status = e.pu.Status, + Type = e.pu.Type, + Permissions = e.pu.Permissions, + }).ToArrayAsync(); return data; } } diff --git a/src/Infrastructure.EntityFramework/Repositories/Queries/ProviderUserOrganizationDetailsViewQuery.cs b/src/Infrastructure.EntityFramework/Repositories/Queries/ProviderUserOrganizationDetailsViewQuery.cs index ef1359985..d578012f4 100644 --- a/src/Infrastructure.EntityFramework/Repositories/Queries/ProviderUserOrganizationDetailsViewQuery.cs +++ b/src/Infrastructure.EntityFramework/Repositories/Queries/ProviderUserOrganizationDetailsViewQuery.cs @@ -41,6 +41,7 @@ public class ProviderUserOrganizationDetailsViewQuery : IQuery sutProvider) + { + sutProvider.GetDependency().GetUserByPrincipalAsync(Arg.Any()).ReturnsNull(); + + async Task GetAction() + { + return await sutProvider.Sut.Get(); + } + + await Assert.ThrowsAsync((Func>)GetAction); + } + + [Theory] + [BitAutoData] + public async Task Get_Success_AtLeastOneEnabledOrg(User user, + List> userEquivalentDomains, + List userExcludedGlobalEquivalentDomains, + ICollection organizationUserDetails, + ICollection providerUserDetails, + IEnumerable providerUserOrganizationDetails, + ICollection folders, + ICollection ciphers, + ICollection sends, + ICollection policies, + ICollection collections, + SutProvider sutProvider) + { + // Get dependencies + var userService = sutProvider.GetDependency(); + var organizationUserRepository = sutProvider.GetDependency(); + var providerUserRepository = sutProvider.GetDependency(); + var folderRepository = sutProvider.GetDependency(); + var cipherRepository = sutProvider.GetDependency(); + var sendRepository = sutProvider.GetDependency(); + var policyRepository = sutProvider.GetDependency(); + var collectionRepository = sutProvider.GetDependency(); + var collectionCipherRepository = sutProvider.GetDependency(); + + // Adjust random data to match required formats / test intentions + user.EquivalentDomains = JsonSerializer.Serialize(userEquivalentDomains); + user.ExcludedGlobalEquivalentDomains = JsonSerializer.Serialize(userExcludedGlobalEquivalentDomains); + + // At least 1 org needs to be enabled to fully test + if (!organizationUserDetails.Any(o => o.Enabled)) + { + // We need at least 1 enabled org + if (organizationUserDetails.Count > 0) + { + organizationUserDetails.First().Enabled = true; + } + else + { + // create an enabled org + var enabledOrg = new Fixture().Create(); + enabledOrg.Enabled = true; + organizationUserDetails.Add((enabledOrg)); + } + } + + // Setup returns + userService.GetUserByPrincipalAsync(Arg.Any()).ReturnsForAnyArgs(user); + + organizationUserRepository + .GetManyDetailsByUserAsync(user.Id, OrganizationUserStatusType.Confirmed).Returns(organizationUserDetails); + + providerUserRepository + .GetManyDetailsByUserAsync(user.Id, ProviderUserStatusType.Confirmed).Returns(providerUserDetails); + + providerUserRepository + .GetManyOrganizationDetailsByUserAsync(user.Id, ProviderUserStatusType.Confirmed) + .Returns(providerUserOrganizationDetails); + + folderRepository.GetManyByUserIdAsync(user.Id).Returns(folders); + cipherRepository.GetManyByUserIdAsync(user.Id).Returns(ciphers); + + sendRepository + .GetManyByUserIdAsync(user.Id).Returns(sends); + + policyRepository.GetManyByUserIdAsync(user.Id).Returns(policies); + + // Returns for methods only called if we have enabled orgs + collectionRepository.GetManyByUserIdAsync(user.Id).Returns(collections); + collectionCipherRepository.GetManyByUserIdAsync(user.Id).Returns(new List()); + + // Back to standard test setup + userService.TwoFactorIsEnabledAsync(user).Returns(false); + userService.HasPremiumFromOrganization(user).Returns(false); + + // Execute GET + var result = await sutProvider.Sut.Get(); + + + // Asserts + // Assert that methods are called + var hasEnabledOrgs = organizationUserDetails.Any(o => o.Enabled); + this.AssertMethodsCalledAsync(userService, organizationUserRepository, providerUserRepository, folderRepository, + cipherRepository, sendRepository, collectionRepository, collectionCipherRepository, hasEnabledOrgs); + + Assert.IsType(result); + + // Collections should not be empty when at least 1 org is enabled + Assert.NotEmpty(result.Collections); + } + + + [Theory] + [BitAutoData] + public async Task Get_Success_AllDisabledOrgs(User user, + List> userEquivalentDomains, + List userExcludedGlobalEquivalentDomains, + ICollection organizationUserDetails, + ICollection providerUserDetails, + IEnumerable providerUserOrganizationDetails, + ICollection folders, + ICollection ciphers, + ICollection sends, + ICollection policies, + SutProvider sutProvider) + { + // Get dependencies + var userService = sutProvider.GetDependency(); + var organizationUserRepository = sutProvider.GetDependency(); + var providerUserRepository = sutProvider.GetDependency(); + var folderRepository = sutProvider.GetDependency(); + var cipherRepository = sutProvider.GetDependency(); + var sendRepository = sutProvider.GetDependency(); + var policyRepository = sutProvider.GetDependency(); + var collectionRepository = sutProvider.GetDependency(); + var collectionCipherRepository = sutProvider.GetDependency(); + + // Adjust random data to match required formats / test intentions + user.EquivalentDomains = JsonSerializer.Serialize(userEquivalentDomains); + user.ExcludedGlobalEquivalentDomains = JsonSerializer.Serialize(userExcludedGlobalEquivalentDomains); + + // All orgs disabled + if (organizationUserDetails.Count > 0) + { + foreach (var orgUserDetails in organizationUserDetails) + { + orgUserDetails.Enabled = false; + } + } + else + { + var disabledOrg = new Fixture().Create(); + disabledOrg.Enabled = false; + organizationUserDetails.Add((disabledOrg)); + } + + + // Setup returns + userService.GetUserByPrincipalAsync(Arg.Any()).ReturnsForAnyArgs(user); + + organizationUserRepository + .GetManyDetailsByUserAsync(user.Id, OrganizationUserStatusType.Confirmed).Returns(organizationUserDetails); + + providerUserRepository + .GetManyDetailsByUserAsync(user.Id, ProviderUserStatusType.Confirmed).Returns(providerUserDetails); + + providerUserRepository + .GetManyOrganizationDetailsByUserAsync(user.Id, ProviderUserStatusType.Confirmed) + .Returns(providerUserOrganizationDetails); + + folderRepository.GetManyByUserIdAsync(user.Id).Returns(folders); + cipherRepository.GetManyByUserIdAsync(user.Id).Returns(ciphers); + + sendRepository + .GetManyByUserIdAsync(user.Id).Returns(sends); + + policyRepository.GetManyByUserIdAsync(user.Id).Returns(policies); + + userService.TwoFactorIsEnabledAsync(user).Returns(false); + userService.HasPremiumFromOrganization(user).Returns(false); + + // Execute GET + var result = await sutProvider.Sut.Get(); + + + // Asserts + // Assert that methods are called + + var hasEnabledOrgs = organizationUserDetails.Any(o => o.Enabled); + this.AssertMethodsCalledAsync(userService, organizationUserRepository, providerUserRepository, folderRepository, + cipherRepository, sendRepository, collectionRepository, collectionCipherRepository, hasEnabledOrgs); + + Assert.IsType(result); + + // Collections should be empty when all standard orgs are disabled. + Assert.Empty(result.Collections); + } + + + // Test where provider org has specific plan type and assert plan type comes out on SyncResponseModel class on ProfileResponseModel + [Theory] + [BitAutoData] + public async Task Get_ProviderPlanTypeProperlyPopulated(User user, + List> userEquivalentDomains, + List userExcludedGlobalEquivalentDomains, + ICollection organizationUserDetails, + ICollection providerUserDetails, + IEnumerable providerUserOrganizationDetails, + ICollection folders, + ICollection ciphers, + ICollection sends, + ICollection policies, + ICollection collections, + SutProvider sutProvider) + { + // Get dependencies + var userService = sutProvider.GetDependency(); + var organizationUserRepository = sutProvider.GetDependency(); + var providerUserRepository = sutProvider.GetDependency(); + var folderRepository = sutProvider.GetDependency(); + var cipherRepository = sutProvider.GetDependency(); + var sendRepository = sutProvider.GetDependency(); + var policyRepository = sutProvider.GetDependency(); + var collectionRepository = sutProvider.GetDependency(); + var collectionCipherRepository = sutProvider.GetDependency(); + + // Adjust random data to match required formats / test intentions + user.EquivalentDomains = JsonSerializer.Serialize(userEquivalentDomains); + user.ExcludedGlobalEquivalentDomains = JsonSerializer.Serialize(userExcludedGlobalEquivalentDomains); + + + // Setup returns + userService.GetUserByPrincipalAsync(Arg.Any()).ReturnsForAnyArgs(user); + + organizationUserRepository + .GetManyDetailsByUserAsync(user.Id, OrganizationUserStatusType.Confirmed).Returns(organizationUserDetails); + + providerUserRepository + .GetManyDetailsByUserAsync(user.Id, ProviderUserStatusType.Confirmed).Returns(providerUserDetails); + + providerUserRepository + .GetManyOrganizationDetailsByUserAsync(user.Id, ProviderUserStatusType.Confirmed) + .Returns(providerUserOrganizationDetails); + + folderRepository.GetManyByUserIdAsync(user.Id).Returns(folders); + cipherRepository.GetManyByUserIdAsync(user.Id).Returns(ciphers); + + sendRepository + .GetManyByUserIdAsync(user.Id).Returns(sends); + + policyRepository.GetManyByUserIdAsync(user.Id).Returns(policies); + + // Returns for methods only called if we have enabled orgs + collectionRepository.GetManyByUserIdAsync(user.Id).Returns(collections); + collectionCipherRepository.GetManyByUserIdAsync(user.Id).Returns(new List()); + + // Back to standard test setup + userService.TwoFactorIsEnabledAsync(user).Returns(false); + userService.HasPremiumFromOrganization(user).Returns(false); + + // Execute GET + var result = await sutProvider.Sut.Get(); + + // Asserts + // Assert that methods are called + + var hasEnabledOrgs = organizationUserDetails.Any(o => o.Enabled); + this.AssertMethodsCalledAsync(userService, organizationUserRepository, providerUserRepository, folderRepository, + cipherRepository, sendRepository, collectionRepository, collectionCipherRepository, hasEnabledOrgs); + + Assert.IsType(result); + + // Look up ProviderOrg output and compare to ProviderOrg method inputs to ensure + // product type is set correctly. + foreach (var profProviderOrg in result.Profile.ProviderOrganizations) + { + var matchedProviderUserOrgDetails = + providerUserOrganizationDetails.FirstOrDefault(p => p.OrganizationId.ToString() == profProviderOrg.Id); + + if (matchedProviderUserOrgDetails != null) + { + var providerOrgProductType = StaticStore.GetPlan(matchedProviderUserOrgDetails.PlanType).Product; + Assert.Equal(providerOrgProductType, profProviderOrg.PlanProductType); + } + } + } + + + private async void AssertMethodsCalledAsync(IUserService userService, + IOrganizationUserRepository organizationUserRepository, + IProviderUserRepository providerUserRepository, IFolderRepository folderRepository, + ICipherRepository cipherRepository, ISendRepository sendRepository, + ICollectionRepository collectionRepository, + ICollectionCipherRepository collectionCipherRepository, + bool hasEnabledOrgs) + { + await userService.ReceivedWithAnyArgs(1).GetUserByPrincipalAsync(default); + await organizationUserRepository.ReceivedWithAnyArgs(1) + .GetManyDetailsByUserAsync(default); + await providerUserRepository.ReceivedWithAnyArgs(1) + .GetManyDetailsByUserAsync(default); + await providerUserRepository.ReceivedWithAnyArgs(1) + .GetManyOrganizationDetailsByUserAsync(default); + + await folderRepository.ReceivedWithAnyArgs(1) + .GetManyByUserIdAsync(default); + + await cipherRepository.ReceivedWithAnyArgs(1) + .GetManyByUserIdAsync(default); + + await sendRepository.ReceivedWithAnyArgs(1) + .GetManyByUserIdAsync(default); + + // These two are only called when at least 1 enabled org. + if (hasEnabledOrgs) + { + await collectionRepository.ReceivedWithAnyArgs(1) + .GetManyByUserIdAsync(default); + await collectionCipherRepository.ReceivedWithAnyArgs(1) + .GetManyByUserIdAsync(default); + } + else + { + // all disabled orgs + await collectionRepository.ReceivedWithAnyArgs(0) + .GetManyByUserIdAsync(default); + await collectionCipherRepository.ReceivedWithAnyArgs(0) + .GetManyByUserIdAsync(default); + } + + await userService.ReceivedWithAnyArgs(1) + .TwoFactorIsEnabledAsync(default); + await userService.ReceivedWithAnyArgs(1) + .HasPremiumFromOrganization(default); + } +} diff --git a/util/Migrator/DbScripts/2023-01-24_00_AutoscalingProviderOrgFixes.sql b/util/Migrator/DbScripts/2023-01-24_00_AutoscalingProviderOrgFixes.sql new file mode 100644 index 000000000..ccea51823 --- /dev/null +++ b/util/Migrator/DbScripts/2023-01-24_00_AutoscalingProviderOrgFixes.sql @@ -0,0 +1,71 @@ +-- SG-992 changes: add planType to provider orgs +CREATE OR ALTER VIEW [dbo].[ProviderUserProviderOrganizationDetailsView] +AS +SELECT + PU.[UserId], + PO.[OrganizationId], + O.[Name], + O.[Enabled], + O.[UsePolicies], + O.[UseSso], + O.[UseKeyConnector], + O.[UseScim], + O.[UseGroups], + O.[UseDirectory], + O.[UseEvents], + O.[UseTotp], + O.[Use2fa], + O.[UseApi], + O.[UseResetPassword], + O.[SelfHost], + O.[UsersGetPremium], + O.[UseCustomPermissions], + O.[Seats], + O.[MaxCollections], + O.[MaxStorageGb], + O.[Identifier], + PO.[Key], + O.[PublicKey], + O.[PrivateKey], + PU.[Status], + PU.[Type], + PO.[ProviderId], + PU.[Id] ProviderUserId, + P.[Name] ProviderName, + O.[PlanType] -- new prop +FROM + [dbo].[ProviderUser] PU + INNER JOIN + [dbo].[ProviderOrganization] PO ON PO.[ProviderId] = PU.[ProviderId] + INNER JOIN + [dbo].[Organization] O ON O.[Id] = PO.[OrganizationId] + INNER JOIN + [dbo].[Provider] P ON P.[Id] = PU.[ProviderId] + GO + + +-- Refresh metadata of stored procs & functions that use the updated view +IF OBJECT_ID('[dbo].[ProviderUserProviderOrganizationDetails_ReadByUserIdStatus]') IS NOT NULL +BEGIN + EXECUTE sp_refreshsqlmodule N'[dbo].[ProviderUserProviderOrganizationDetails_ReadByUserIdStatus]'; +END +GO + + +-- EC-591 / SG-996 changes: add optional status to stored proc +CREATE OR ALTER PROCEDURE [dbo].[ProviderUserUserDetails_ReadByProviderId] +@ProviderId UNIQUEIDENTIFIER, +@Status TINYINT = NULL -- new: this is required to be backwards compatible +AS +BEGIN + SET NOCOUNT ON + +SELECT + * +FROM + [dbo].[ProviderUserUserDetailsView] +WHERE + [ProviderId] = @ProviderId + AND [Status] = COALESCE(@Status, [Status]) -- new +END +GO