diff --git a/src/Core/Repositories/IProviderUserRepository.cs b/src/Core/Repositories/IProviderUserRepository.cs index ba920a575..9f0374fee 100644 --- a/src/Core/Repositories/IProviderUserRepository.cs +++ b/src/Core/Repositories/IProviderUserRepository.cs @@ -18,4 +18,5 @@ public interface IProviderUserRepository : IRepository Task DeleteManyAsync(IEnumerable userIds); Task> GetManyPublicKeysByProviderUserAsync(Guid providerId, IEnumerable Ids); Task GetCountByOnlyOwnerAsync(Guid userId); + Task> GetManyByOrganizationAsync(Guid organizationId, ProviderUserStatusType? status = null); } diff --git a/src/Core/Services/Implementations/OrganizationService.cs b/src/Core/Services/Implementations/OrganizationService.cs index 547050bc4..82d39a9d6 100644 --- a/src/Core/Services/Implementations/OrganizationService.cs +++ b/src/Core/Services/Implementations/OrganizationService.cs @@ -993,7 +993,7 @@ public class OrganizationService : IOrganizationService } } - var (organizationUsers, events) = await SaveUsersSendInvitesAsync(organizationId, invites); + var (organizationUsers, events) = await SaveUsersSendInvitesAsync(organizationId, invites, systemUser: null); await _eventService.LogOrganizationUserEventsAsync(events); @@ -1003,7 +1003,7 @@ public class OrganizationService : IOrganizationService public async Task> InviteUsersAsync(Guid organizationId, EventSystemUser systemUser, IEnumerable<(OrganizationUserInvite invite, string externalId)> invites) { - var (organizationUsers, events) = await SaveUsersSendInvitesAsync(organizationId, invites); + var (organizationUsers, events) = await SaveUsersSendInvitesAsync(organizationId, invites, systemUser); await _eventService.LogOrganizationUserEventsAsync(events.Select(e => (e.Item1, e.Item2, systemUser, e.Item3))); @@ -1011,7 +1011,7 @@ public class OrganizationService : IOrganizationService } private async Task<(List organizationUsers, List<(OrganizationUser, EventType, DateTime?)> events)> SaveUsersSendInvitesAsync(Guid organizationId, - IEnumerable<(OrganizationUserInvite invite, string externalId)> invites) + IEnumerable<(OrganizationUserInvite invite, string externalId)> invites, EventSystemUser? systemUser) { var organization = await GetOrgById(organizationId); var initialSeatCount = organization.Seats; @@ -1040,7 +1040,7 @@ public class OrganizationService : IOrganizationService } var invitedAreAllOwners = invites.All(i => i.invite.Type == OrganizationUserType.Owner); - if (!invitedAreAllOwners && !await HasConfirmedOwnersExceptAsync(organizationId, new Guid[] { })) + if (!invitedAreAllOwners && !await HasConfirmedOwnersExceptAsync(organizationId, new Guid[] { }, includeProvider: true)) { throw new BadRequestException("Organization must have at least one confirmed owner."); } @@ -1596,7 +1596,7 @@ public class OrganizationService : IOrganizationService throw new BadRequestException("Only owners can delete other owners."); } - if (!await HasConfirmedOwnersExceptAsync(organizationId, new[] { organizationUserId })) + if (!await HasConfirmedOwnersExceptAsync(organizationId, new[] { organizationUserId }, includeProvider: true)) { throw new BadRequestException("Organization must have at least one confirmed owner."); } @@ -1700,7 +1700,7 @@ public class OrganizationService : IOrganizationService bool hasOtherOwner = confirmedOwnersIds.Except(organizationUsersId).Any(); if (!hasOtherOwner && includeProvider) { - return (await _currentContext.ProviderIdForOrg(organizationId)).HasValue; + return (await _providerUserRepository.GetManyByOrganizationAsync(organizationId, ProviderUserStatusType.Confirmed)).Any(); } return hasOtherOwner; } @@ -2272,7 +2272,7 @@ public class OrganizationService : IOrganizationService throw new BadRequestException("Already revoked."); } - if (!await HasConfirmedOwnersExceptAsync(organizationUser.OrganizationId, new[] { organizationUser.Id })) + if (!await HasConfirmedOwnersExceptAsync(organizationUser.OrganizationId, new[] { organizationUser.Id }, includeProvider: true)) { throw new BadRequestException("Organization must have at least one confirmed owner."); } diff --git a/src/Infrastructure.Dapper/Repositories/ProviderUserRepository.cs b/src/Infrastructure.Dapper/Repositories/ProviderUserRepository.cs index a0e2b6989..114ed53f5 100644 --- a/src/Infrastructure.Dapper/Repositories/ProviderUserRepository.cs +++ b/src/Infrastructure.Dapper/Repositories/ProviderUserRepository.cs @@ -160,4 +160,17 @@ public class ProviderUserRepository : Repository, IProviderU return results; } } + + public async Task> GetManyByOrganizationAsync(Guid organizationId, ProviderUserStatusType? status = null) + { + using (var connection = new SqlConnection(ConnectionString)) + { + var results = await connection.QueryAsync( + "[dbo].[ProviderUser_ReadByOrganizationIdStatus]", + new { OrganizationId = organizationId, Status = status }, + commandType: CommandType.StoredProcedure); + + return results.ToList(); + } + } } diff --git a/src/Infrastructure.EntityFramework/Repositories/ProviderUserRepository.cs b/src/Infrastructure.EntityFramework/Repositories/ProviderUserRepository.cs index cad923c26..e858347fe 100644 --- a/src/Infrastructure.EntityFramework/Repositories/ProviderUserRepository.cs +++ b/src/Infrastructure.EntityFramework/Repositories/ProviderUserRepository.cs @@ -180,4 +180,19 @@ public class ProviderUserRepository : .CountAsync(); } } + + public async Task> GetManyByOrganizationAsync(Guid organizationId, ProviderUserStatusType? status = null) + { + using (var scope = ServiceScopeFactory.CreateScope()) + { + var dbContext = GetDatabaseContext(scope); + var query = from pu in dbContext.ProviderUsers + join po in dbContext.ProviderOrganizations + on pu.ProviderId equals po.ProviderId + where po.OrganizationId == organizationId && + (status == null || pu.Status == status) + select pu; + return await query.ToArrayAsync(); + } + } } diff --git a/src/Sql/dbo/Stored Procedures/ProviderUser_ReadByOrganizationIdStatus.sql b/src/Sql/dbo/Stored Procedures/ProviderUser_ReadByOrganizationIdStatus.sql new file mode 100644 index 000000000..c581a4c9f --- /dev/null +++ b/src/Sql/dbo/Stored Procedures/ProviderUser_ReadByOrganizationIdStatus.sql @@ -0,0 +1,17 @@ +CREATE PROCEDURE [dbo].[ProviderUser_ReadByOrganizationIdStatus] + @OrganizationId UNIQUEIDENTIFIER, + @Status TINYINT +AS +BEGIN + SET NOCOUNT ON + + SELECT + PU.* + FROM + [dbo].[ProviderUserView] PU + INNER JOIN [dbo].[ProviderOrganizationView] as PO + ON PU.[ProviderId] = PO.[ProviderId] + WHERE + PO.[OrganizationId] = @OrganizationId + AND (@Status IS NULL OR PU.[Status] = @Status) +END \ No newline at end of file diff --git a/test/Core.Test/Services/OrganizationServiceTests.cs b/test/Core.Test/Services/OrganizationServiceTests.cs index 9275da54a..b208ddecb 100644 --- a/test/Core.Test/Services/OrganizationServiceTests.cs +++ b/test/Core.Test/Services/OrganizationServiceTests.cs @@ -6,7 +6,9 @@ using Bit.Core.Auth.Models.Data; using Bit.Core.Auth.Repositories; using Bit.Core.Context; using Bit.Core.Entities; +using Bit.Core.Entities.Provider; using Bit.Core.Enums; +using Bit.Core.Enums.Provider; using Bit.Core.Exceptions; using Bit.Core.Models.Business; using Bit.Core.Models.Data; @@ -1415,4 +1417,56 @@ public class OrganizationServiceTests await eventService.Received() .LogOrganizationUserEventAsync(organizationUser, EventType.OrganizationUser_Restored, eventSystemUser); } + + [Theory, BitAutoData] + public async Task HasConfirmedOwnersExcept_WithConfirmedOwner_ReturnsTrue(Organization organization, [OrganizationUser(OrganizationUserStatusType.Confirmed, OrganizationUserType.Owner)] OrganizationUser owner, SutProvider sutProvider) + { + sutProvider.GetDependency() + .GetManyByOrganizationAsync(organization.Id, OrganizationUserType.Owner) + .Returns(new List { owner }); + + var result = await sutProvider.Sut.HasConfirmedOwnersExceptAsync(organization.Id, new List(), true); + + Assert.True(result); + } + + [Theory, BitAutoData] + public async Task HasConfirmedOwnersExcept_ExcludingConfirmedOwner_ReturnsFalse(Organization organization, [OrganizationUser(OrganizationUserStatusType.Confirmed, OrganizationUserType.Owner)] OrganizationUser owner, SutProvider sutProvider) + { + sutProvider.GetDependency() + .GetManyByOrganizationAsync(organization.Id, OrganizationUserType.Owner) + .Returns(new List { owner }); + + var result = await sutProvider.Sut.HasConfirmedOwnersExceptAsync(organization.Id, new List { owner.Id }, true); + + Assert.False(result); + } + + [Theory, BitAutoData] + public async Task HasConfirmedOwnersExcept_WithInvitedOwner_ReturnsFalse(Organization organization, [OrganizationUser(OrganizationUserStatusType.Invited, OrganizationUserType.Owner)] OrganizationUser owner, SutProvider sutProvider) + { + sutProvider.GetDependency() + .GetManyByOrganizationAsync(organization.Id, OrganizationUserType.Owner) + .Returns(new List { owner }); + + var result = await sutProvider.Sut.HasConfirmedOwnersExceptAsync(organization.Id, new List(), true); + + Assert.False(result); + } + + [Theory] + [BitAutoData(true)] + [BitAutoData(false)] + public async Task HasConfirmedOwnersExcept_WithConfirmedProviderUser_IncludeProviderTrue_ReturnsTrue(bool includeProvider, Organization organization, ProviderUser providerUser, SutProvider sutProvider) + { + providerUser.Status = ProviderUserStatusType.Confirmed; + + sutProvider.GetDependency() + .GetManyByOrganizationAsync(organization.Id, ProviderUserStatusType.Confirmed) + .Returns(new List { providerUser }); + + var result = await sutProvider.Sut.HasConfirmedOwnersExceptAsync(organization.Id, new List(), includeProvider); + + Assert.Equal(includeProvider, result); + } } diff --git a/util/Migrator/DbScripts/2023-05-03_00_ProviderUserReadByOrganizationIdStatus.sql b/util/Migrator/DbScripts/2023-05-03_00_ProviderUserReadByOrganizationIdStatus.sql new file mode 100644 index 000000000..c1076e509 --- /dev/null +++ b/util/Migrator/DbScripts/2023-05-03_00_ProviderUserReadByOrganizationIdStatus.sql @@ -0,0 +1,18 @@ +CREATE OR ALTER PROCEDURE [dbo].[ProviderUser_ReadByOrganizationIdStatus] + @OrganizationId UNIQUEIDENTIFIER, + @Status TINYINT +AS +BEGIN + SET NOCOUNT ON + + SELECT + PU.* + FROM + [dbo].[ProviderUserView] PU + INNER JOIN [dbo].[ProviderOrganizationView] as PO + ON PU.[ProviderId] = PO.[ProviderId] + WHERE + PO.[OrganizationId] = @OrganizationId + AND (@Status IS NULL OR PU.[Status] = @Status) +END +GO \ No newline at end of file