1
0
mirror of https://github.com/bitwarden/server.git synced 2024-11-24 12:35:25 +01:00

Turn on file scoped namespaces (#2225)

This commit is contained in:
Justin Baur 2022-08-29 14:53:16 -04:00 committed by GitHub
parent 7c4521e0b4
commit 34fb4cca2a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1206 changed files with 73816 additions and 75022 deletions

View File

@ -114,6 +114,9 @@ csharp_new_line_before_finally = true
csharp_new_line_before_members_in_object_initializers = true csharp_new_line_before_members_in_object_initializers = true
csharp_new_line_before_members_in_anonymous_types = true csharp_new_line_before_members_in_anonymous_types = true
# Namespace settigns
csharp_style_namespace_declarations = file_scoped:warning
# All files # All files
[*] [*]
guidelines = 120 guidelines = 120

View File

@ -13,497 +13,496 @@ using Bit.Core.Settings;
using Bit.Core.Utilities; using Bit.Core.Utilities;
using Microsoft.AspNetCore.DataProtection; using Microsoft.AspNetCore.DataProtection;
namespace Bit.Commercial.Core.Services namespace Bit.Commercial.Core.Services;
public class ProviderService : IProviderService
{ {
public class ProviderService : IProviderService public static PlanType[] ProviderDisllowedOrganizationTypes = new[] { PlanType.Free, PlanType.FamiliesAnnually, PlanType.FamiliesAnnually2019 };
private readonly IDataProtector _dataProtector;
private readonly IMailService _mailService;
private readonly IEventService _eventService;
private readonly GlobalSettings _globalSettings;
private readonly IProviderRepository _providerRepository;
private readonly IProviderUserRepository _providerUserRepository;
private readonly IProviderOrganizationRepository _providerOrganizationRepository;
private readonly IOrganizationRepository _organizationRepository;
private readonly IUserRepository _userRepository;
private readonly IUserService _userService;
private readonly IOrganizationService _organizationService;
private readonly ICurrentContext _currentContext;
public ProviderService(IProviderRepository providerRepository, IProviderUserRepository providerUserRepository,
IProviderOrganizationRepository providerOrganizationRepository, IUserRepository userRepository,
IUserService userService, IOrganizationService organizationService, IMailService mailService,
IDataProtectionProvider dataProtectionProvider, IEventService eventService,
IOrganizationRepository organizationRepository, GlobalSettings globalSettings,
ICurrentContext currentContext)
{ {
public static PlanType[] ProviderDisllowedOrganizationTypes = new[] { PlanType.Free, PlanType.FamiliesAnnually, PlanType.FamiliesAnnually2019 }; _providerRepository = providerRepository;
_providerUserRepository = providerUserRepository;
_providerOrganizationRepository = providerOrganizationRepository;
_organizationRepository = organizationRepository;
_userRepository = userRepository;
_userService = userService;
_organizationService = organizationService;
_mailService = mailService;
_eventService = eventService;
_globalSettings = globalSettings;
_dataProtector = dataProtectionProvider.CreateProtector("ProviderServiceDataProtector");
_currentContext = currentContext;
}
private readonly IDataProtector _dataProtector; public async Task CreateAsync(string ownerEmail)
private readonly IMailService _mailService; {
private readonly IEventService _eventService; var owner = await _userRepository.GetByEmailAsync(ownerEmail);
private readonly GlobalSettings _globalSettings; if (owner == null)
private readonly IProviderRepository _providerRepository;
private readonly IProviderUserRepository _providerUserRepository;
private readonly IProviderOrganizationRepository _providerOrganizationRepository;
private readonly IOrganizationRepository _organizationRepository;
private readonly IUserRepository _userRepository;
private readonly IUserService _userService;
private readonly IOrganizationService _organizationService;
private readonly ICurrentContext _currentContext;
public ProviderService(IProviderRepository providerRepository, IProviderUserRepository providerUserRepository,
IProviderOrganizationRepository providerOrganizationRepository, IUserRepository userRepository,
IUserService userService, IOrganizationService organizationService, IMailService mailService,
IDataProtectionProvider dataProtectionProvider, IEventService eventService,
IOrganizationRepository organizationRepository, GlobalSettings globalSettings,
ICurrentContext currentContext)
{ {
_providerRepository = providerRepository; throw new BadRequestException("Invalid owner. Owner must be an existing Bitwarden user.");
_providerUserRepository = providerUserRepository;
_providerOrganizationRepository = providerOrganizationRepository;
_organizationRepository = organizationRepository;
_userRepository = userRepository;
_userService = userService;
_organizationService = organizationService;
_mailService = mailService;
_eventService = eventService;
_globalSettings = globalSettings;
_dataProtector = dataProtectionProvider.CreateProtector("ProviderServiceDataProtector");
_currentContext = currentContext;
} }
public async Task CreateAsync(string ownerEmail) var provider = new Provider
{ {
var owner = await _userRepository.GetByEmailAsync(ownerEmail); Status = ProviderStatusType.Pending,
if (owner == null) Enabled = true,
{ UseEvents = true,
throw new BadRequestException("Invalid owner. Owner must be an existing Bitwarden user."); };
} await _providerRepository.CreateAsync(provider);
var provider = new Provider var providerUser = new ProviderUser
{
ProviderId = provider.Id,
UserId = owner.Id,
Type = ProviderUserType.ProviderAdmin,
Status = ProviderUserStatusType.Confirmed,
};
await _providerUserRepository.CreateAsync(providerUser);
await SendProviderSetupInviteEmailAsync(provider, owner.Email);
}
public async Task<Provider> CompleteSetupAsync(Provider provider, Guid ownerUserId, string token, string key)
{
var owner = await _userService.GetUserByIdAsync(ownerUserId);
if (owner == null)
{
throw new BadRequestException("Invalid owner.");
}
if (provider.Status != ProviderStatusType.Pending)
{
throw new BadRequestException("Provider is already setup.");
}
if (!CoreHelpers.TokenIsValid("ProviderSetupInvite", _dataProtector, token, owner.Email, provider.Id,
_globalSettings.OrganizationInviteExpirationHours))
{
throw new BadRequestException("Invalid token.");
}
var providerUser = await _providerUserRepository.GetByProviderUserAsync(provider.Id, ownerUserId);
if (!(providerUser is { Type: ProviderUserType.ProviderAdmin }))
{
throw new BadRequestException("Invalid owner.");
}
provider.Status = ProviderStatusType.Created;
await _providerRepository.UpsertAsync(provider);
providerUser.Key = key;
await _providerUserRepository.ReplaceAsync(providerUser);
return provider;
}
public async Task UpdateAsync(Provider provider, bool updateBilling = false)
{
if (provider.Id == default)
{
throw new ArgumentException("Cannot create provider this way.");
}
await _providerRepository.ReplaceAsync(provider);
}
public async Task<List<ProviderUser>> InviteUserAsync(ProviderUserInvite<string> invite)
{
if (!_currentContext.ProviderManageUsers(invite.ProviderId))
{
throw new InvalidOperationException("Invalid permissions.");
}
var emails = invite?.UserIdentifiers;
var invitingUser = await _providerUserRepository.GetByProviderUserAsync(invite.ProviderId, invite.InvitingUserId);
var provider = await _providerRepository.GetByIdAsync(invite.ProviderId);
if (provider == null || emails == null || !emails.Any())
{
throw new NotFoundException();
}
var providerUsers = new List<ProviderUser>();
foreach (var email in emails)
{
// Make sure user is not already invited
var existingProviderUserCount =
await _providerUserRepository.GetCountByProviderAsync(invite.ProviderId, email, false);
if (existingProviderUserCount > 0)
{ {
Status = ProviderStatusType.Pending, continue;
Enabled = true, }
UseEvents = true,
};
await _providerRepository.CreateAsync(provider);
var providerUser = new ProviderUser var providerUser = new ProviderUser
{ {
ProviderId = provider.Id, ProviderId = invite.ProviderId,
UserId = owner.Id, UserId = null,
Type = ProviderUserType.ProviderAdmin, Email = email.ToLowerInvariant(),
Status = ProviderUserStatusType.Confirmed, Key = null,
Type = invite.Type,
Status = ProviderUserStatusType.Invited,
CreationDate = DateTime.UtcNow,
RevisionDate = DateTime.UtcNow,
}; };
await _providerUserRepository.CreateAsync(providerUser); await _providerUserRepository.CreateAsync(providerUser);
await SendProviderSetupInviteEmailAsync(provider, owner.Email);
await SendInviteAsync(providerUser, provider);
providerUsers.Add(providerUser);
} }
public async Task<Provider> CompleteSetupAsync(Provider provider, Guid ownerUserId, string token, string key) await _eventService.LogProviderUsersEventAsync(providerUsers.Select(pu => (pu, EventType.ProviderUser_Invited, null as DateTime?)));
return providerUsers;
}
public async Task<List<Tuple<ProviderUser, string>>> ResendInvitesAsync(ProviderUserInvite<Guid> invite)
{
if (!_currentContext.ProviderManageUsers(invite.ProviderId))
{ {
var owner = await _userService.GetUserByIdAsync(ownerUserId); throw new BadRequestException("Invalid permissions.");
if (owner == null)
{
throw new BadRequestException("Invalid owner.");
}
if (provider.Status != ProviderStatusType.Pending)
{
throw new BadRequestException("Provider is already setup.");
}
if (!CoreHelpers.TokenIsValid("ProviderSetupInvite", _dataProtector, token, owner.Email, provider.Id,
_globalSettings.OrganizationInviteExpirationHours))
{
throw new BadRequestException("Invalid token.");
}
var providerUser = await _providerUserRepository.GetByProviderUserAsync(provider.Id, ownerUserId);
if (!(providerUser is { Type: ProviderUserType.ProviderAdmin }))
{
throw new BadRequestException("Invalid owner.");
}
provider.Status = ProviderStatusType.Created;
await _providerRepository.UpsertAsync(provider);
providerUser.Key = key;
await _providerUserRepository.ReplaceAsync(providerUser);
return provider;
} }
public async Task UpdateAsync(Provider provider, bool updateBilling = false) var providerUsers = await _providerUserRepository.GetManyAsync(invite.UserIdentifiers);
var provider = await _providerRepository.GetByIdAsync(invite.ProviderId);
var result = new List<Tuple<ProviderUser, string>>();
foreach (var providerUser in providerUsers)
{ {
if (provider.Id == default) if (providerUser.Status != ProviderUserStatusType.Invited || providerUser.ProviderId != invite.ProviderId)
{ {
throw new ArgumentException("Cannot create provider this way."); result.Add(Tuple.Create(providerUser, "User invalid."));
continue;
} }
await _providerRepository.ReplaceAsync(provider); await SendInviteAsync(providerUser, provider);
result.Add(Tuple.Create(providerUser, ""));
} }
public async Task<List<ProviderUser>> InviteUserAsync(ProviderUserInvite<string> invite) return result;
}
public async Task<ProviderUser> AcceptUserAsync(Guid providerUserId, User user, string token)
{
var providerUser = await _providerUserRepository.GetByIdAsync(providerUserId);
if (providerUser == null)
{ {
if (!_currentContext.ProviderManageUsers(invite.ProviderId)) throw new BadRequestException("User invalid.");
}
if (providerUser.Status != ProviderUserStatusType.Invited)
{
throw new BadRequestException("Already accepted.");
}
if (!CoreHelpers.TokenIsValid("ProviderUserInvite", _dataProtector, token, user.Email, providerUser.Id,
_globalSettings.OrganizationInviteExpirationHours))
{
throw new BadRequestException("Invalid token.");
}
if (string.IsNullOrWhiteSpace(providerUser.Email) ||
!providerUser.Email.Equals(user.Email, StringComparison.InvariantCultureIgnoreCase))
{
throw new BadRequestException("User email does not match invite.");
}
providerUser.Status = ProviderUserStatusType.Accepted;
providerUser.UserId = user.Id;
providerUser.Email = null;
await _providerUserRepository.ReplaceAsync(providerUser);
return providerUser;
}
public async Task<List<Tuple<ProviderUser, string>>> ConfirmUsersAsync(Guid providerId, Dictionary<Guid, string> keys,
Guid confirmingUserId)
{
var providerUsers = await _providerUserRepository.GetManyAsync(keys.Keys);
var validProviderUsers = providerUsers
.Where(u => u.UserId != null)
.ToList();
if (!validProviderUsers.Any())
{
return new List<Tuple<ProviderUser, string>>();
}
var validOrganizationUserIds = validProviderUsers.Select(u => u.UserId.Value).ToList();
var provider = await _providerRepository.GetByIdAsync(providerId);
var users = await _userRepository.GetManyAsync(validOrganizationUserIds);
var keyedFilteredUsers = validProviderUsers.ToDictionary(u => u.UserId.Value, u => u);
var result = new List<Tuple<ProviderUser, string>>();
var events = new List<(ProviderUser, EventType, DateTime?)>();
foreach (var user in users)
{
if (!keyedFilteredUsers.ContainsKey(user.Id))
{ {
throw new InvalidOperationException("Invalid permissions."); continue;
} }
var providerUser = keyedFilteredUsers[user.Id];
var emails = invite?.UserIdentifiers; try
var invitingUser = await _providerUserRepository.GetByProviderUserAsync(invite.ProviderId, invite.InvitingUserId);
var provider = await _providerRepository.GetByIdAsync(invite.ProviderId);
if (provider == null || emails == null || !emails.Any())
{ {
throw new NotFoundException(); if (providerUser.Status != ProviderUserStatusType.Accepted || providerUser.ProviderId != providerId)
}
var providerUsers = new List<ProviderUser>();
foreach (var email in emails)
{
// Make sure user is not already invited
var existingProviderUserCount =
await _providerUserRepository.GetCountByProviderAsync(invite.ProviderId, email, false);
if (existingProviderUserCount > 0)
{ {
continue; throw new BadRequestException("Invalid user.");
} }
var providerUser = new ProviderUser providerUser.Status = ProviderUserStatusType.Confirmed;
{ providerUser.Key = keys[providerUser.Id];
ProviderId = invite.ProviderId, providerUser.Email = null;
UserId = null,
Email = email.ToLowerInvariant(),
Key = null,
Type = invite.Type,
Status = ProviderUserStatusType.Invited,
CreationDate = DateTime.UtcNow,
RevisionDate = DateTime.UtcNow,
};
await _providerUserRepository.CreateAsync(providerUser); await _providerUserRepository.ReplaceAsync(providerUser);
events.Add((providerUser, EventType.ProviderUser_Confirmed, null));
await SendInviteAsync(providerUser, provider); await _mailService.SendProviderConfirmedEmailAsync(provider.Name, user.Email);
providerUsers.Add(providerUser);
}
await _eventService.LogProviderUsersEventAsync(providerUsers.Select(pu => (pu, EventType.ProviderUser_Invited, null as DateTime?)));
return providerUsers;
}
public async Task<List<Tuple<ProviderUser, string>>> ResendInvitesAsync(ProviderUserInvite<Guid> invite)
{
if (!_currentContext.ProviderManageUsers(invite.ProviderId))
{
throw new BadRequestException("Invalid permissions.");
}
var providerUsers = await _providerUserRepository.GetManyAsync(invite.UserIdentifiers);
var provider = await _providerRepository.GetByIdAsync(invite.ProviderId);
var result = new List<Tuple<ProviderUser, string>>();
foreach (var providerUser in providerUsers)
{
if (providerUser.Status != ProviderUserStatusType.Invited || providerUser.ProviderId != invite.ProviderId)
{
result.Add(Tuple.Create(providerUser, "User invalid."));
continue;
}
await SendInviteAsync(providerUser, provider);
result.Add(Tuple.Create(providerUser, "")); result.Add(Tuple.Create(providerUser, ""));
} }
catch (BadRequestException e)
return result; {
result.Add(Tuple.Create(providerUser, e.Message));
}
} }
public async Task<ProviderUser> AcceptUserAsync(Guid providerUserId, User user, string token) await _eventService.LogProviderUsersEventAsync(events);
return result;
}
public async Task SaveUserAsync(ProviderUser user, Guid savingUserId)
{
if (user.Id.Equals(default))
{ {
var providerUser = await _providerUserRepository.GetByIdAsync(providerUserId); throw new BadRequestException("Invite the user first.");
if (providerUser == null)
{
throw new BadRequestException("User invalid.");
}
if (providerUser.Status != ProviderUserStatusType.Invited)
{
throw new BadRequestException("Already accepted.");
}
if (!CoreHelpers.TokenIsValid("ProviderUserInvite", _dataProtector, token, user.Email, providerUser.Id,
_globalSettings.OrganizationInviteExpirationHours))
{
throw new BadRequestException("Invalid token.");
}
if (string.IsNullOrWhiteSpace(providerUser.Email) ||
!providerUser.Email.Equals(user.Email, StringComparison.InvariantCultureIgnoreCase))
{
throw new BadRequestException("User email does not match invite.");
}
providerUser.Status = ProviderUserStatusType.Accepted;
providerUser.UserId = user.Id;
providerUser.Email = null;
await _providerUserRepository.ReplaceAsync(providerUser);
return providerUser;
} }
public async Task<List<Tuple<ProviderUser, string>>> ConfirmUsersAsync(Guid providerId, Dictionary<Guid, string> keys, if (user.Type != ProviderUserType.ProviderAdmin &&
Guid confirmingUserId) !await HasConfirmedProviderAdminExceptAsync(user.ProviderId, new[] { user.Id }))
{ {
var providerUsers = await _providerUserRepository.GetManyAsync(keys.Keys); throw new BadRequestException("Provider must have at least one confirmed ProviderAdmin.");
var validProviderUsers = providerUsers }
.Where(u => u.UserId != null)
.ToList();
if (!validProviderUsers.Any()) await _providerUserRepository.ReplaceAsync(user);
await _eventService.LogProviderUserEventAsync(user, EventType.ProviderUser_Updated);
}
public async Task<List<Tuple<ProviderUser, string>>> DeleteUsersAsync(Guid providerId,
IEnumerable<Guid> providerUserIds, Guid deletingUserId)
{
var provider = await _providerRepository.GetByIdAsync(providerId);
if (provider == null)
{
throw new NotFoundException();
}
var providerUsers = await _providerUserRepository.GetManyAsync(providerUserIds);
var users = await _userRepository.GetManyAsync(providerUsers.Where(pu => pu.UserId.HasValue)
.Select(pu => pu.UserId.Value));
var keyedUsers = users.ToDictionary(u => u.Id);
if (!await HasConfirmedProviderAdminExceptAsync(providerId, providerUserIds))
{
throw new BadRequestException("Provider must have at least one confirmed ProviderAdmin.");
}
var result = new List<Tuple<ProviderUser, string>>();
var deletedUserIds = new List<Guid>();
var events = new List<(ProviderUser, EventType, DateTime?)>();
foreach (var providerUser in providerUsers)
{
try
{ {
return new List<Tuple<ProviderUser, string>>(); if (providerUser.ProviderId != providerId)
}
var validOrganizationUserIds = validProviderUsers.Select(u => u.UserId.Value).ToList();
var provider = await _providerRepository.GetByIdAsync(providerId);
var users = await _userRepository.GetManyAsync(validOrganizationUserIds);
var keyedFilteredUsers = validProviderUsers.ToDictionary(u => u.UserId.Value, u => u);
var result = new List<Tuple<ProviderUser, string>>();
var events = new List<(ProviderUser, EventType, DateTime?)>();
foreach (var user in users)
{
if (!keyedFilteredUsers.ContainsKey(user.Id))
{ {
continue; throw new BadRequestException("Invalid user.");
} }
var providerUser = keyedFilteredUsers[user.Id]; if (providerUser.UserId == deletingUserId)
try
{ {
if (providerUser.Status != ProviderUserStatusType.Accepted || providerUser.ProviderId != providerId) throw new BadRequestException("You cannot remove yourself.");
}
events.Add((providerUser, EventType.ProviderUser_Removed, null));
var user = keyedUsers.GetValueOrDefault(providerUser.UserId.GetValueOrDefault());
var email = user == null ? providerUser.Email : user.Email;
if (!string.IsNullOrWhiteSpace(email))
{
await _mailService.SendProviderUserRemoved(provider.Name, email);
}
result.Add(Tuple.Create(providerUser, ""));
deletedUserIds.Add(providerUser.Id);
}
catch (BadRequestException e)
{
result.Add(Tuple.Create(providerUser, e.Message));
}
await _providerUserRepository.DeleteManyAsync(deletedUserIds);
}
await _eventService.LogProviderUsersEventAsync(events);
return result;
}
public async Task AddOrganization(Guid providerId, Guid organizationId, Guid addingUserId, string key)
{
var po = await _providerOrganizationRepository.GetByOrganizationId(organizationId);
if (po != null)
{
throw new BadRequestException("Organization already belongs to a provider.");
}
var organization = await _organizationRepository.GetByIdAsync(organizationId);
ThrowOnInvalidPlanType(organization.PlanType);
var providerOrganization = new ProviderOrganization
{
ProviderId = providerId,
OrganizationId = organizationId,
Key = key,
};
await _providerOrganizationRepository.CreateAsync(providerOrganization);
await _eventService.LogProviderOrganizationEventAsync(providerOrganization, EventType.ProviderOrganization_Added);
}
public async Task<ProviderOrganization> CreateOrganizationAsync(Guid providerId,
OrganizationSignup organizationSignup, string clientOwnerEmail, User user)
{
ThrowOnInvalidPlanType(organizationSignup.Plan);
var (organization, _) = await _organizationService.SignUpAsync(organizationSignup, true);
var providerOrganization = new ProviderOrganization
{
ProviderId = providerId,
OrganizationId = organization.Id,
Key = organizationSignup.OwnerKey,
};
await _providerOrganizationRepository.CreateAsync(providerOrganization);
await _eventService.LogProviderOrganizationEventAsync(providerOrganization, EventType.ProviderOrganization_Created);
await _organizationService.InviteUsersAsync(organization.Id, user.Id,
new (OrganizationUserInvite, string)[]
{
(
new OrganizationUserInvite
{ {
throw new BadRequestException("Invalid user."); Emails = new[] { clientOwnerEmail },
} AccessAll = true,
Type = OrganizationUserType.Owner,
Permissions = null,
Collections = Array.Empty<SelectionReadOnly>(),
},
null
)
});
providerUser.Status = ProviderUserStatusType.Confirmed; return providerOrganization;
providerUser.Key = keys[providerUser.Id]; }
providerUser.Email = null;
await _providerUserRepository.ReplaceAsync(providerUser); public async Task RemoveOrganizationAsync(Guid providerId, Guid providerOrganizationId, Guid removingUserId)
events.Add((providerUser, EventType.ProviderUser_Confirmed, null)); {
await _mailService.SendProviderConfirmedEmailAsync(provider.Name, user.Email); var providerOrganization = await _providerOrganizationRepository.GetByIdAsync(providerOrganizationId);
result.Add(Tuple.Create(providerUser, "")); if (providerOrganization == null || providerOrganization.ProviderId != providerId)
} {
catch (BadRequestException e) throw new BadRequestException("Invalid organization.");
{
result.Add(Tuple.Create(providerUser, e.Message));
}
}
await _eventService.LogProviderUsersEventAsync(events);
return result;
} }
public async Task SaveUserAsync(ProviderUser user, Guid savingUserId) if (!await _organizationService.HasConfirmedOwnersExceptAsync(providerOrganization.OrganizationId, new Guid[] { }, includeProvider: false))
{ {
if (user.Id.Equals(default)) throw new BadRequestException("Organization needs to have at least one confirmed owner.");
{
throw new BadRequestException("Invite the user first.");
}
if (user.Type != ProviderUserType.ProviderAdmin &&
!await HasConfirmedProviderAdminExceptAsync(user.ProviderId, new[] { user.Id }))
{
throw new BadRequestException("Provider must have at least one confirmed ProviderAdmin.");
}
await _providerUserRepository.ReplaceAsync(user);
await _eventService.LogProviderUserEventAsync(user, EventType.ProviderUser_Updated);
} }
public async Task<List<Tuple<ProviderUser, string>>> DeleteUsersAsync(Guid providerId, await _providerOrganizationRepository.DeleteAsync(providerOrganization);
IEnumerable<Guid> providerUserIds, Guid deletingUserId) await _eventService.LogProviderOrganizationEventAsync(providerOrganization, EventType.ProviderOrganization_Removed);
}
public async Task ResendProviderSetupInviteEmailAsync(Guid providerId, Guid ownerId)
{
var provider = await _providerRepository.GetByIdAsync(providerId);
var owner = await _userRepository.GetByIdAsync(ownerId);
if (owner == null)
{ {
var provider = await _providerRepository.GetByIdAsync(providerId); throw new BadRequestException("Invalid owner.");
}
await SendProviderSetupInviteEmailAsync(provider, owner.Email);
}
if (provider == null) private async Task SendProviderSetupInviteEmailAsync(Provider provider, string ownerEmail)
{ {
throw new NotFoundException(); var token = _dataProtector.Protect($"ProviderSetupInvite {provider.Id} {ownerEmail} {CoreHelpers.ToEpocMilliseconds(DateTime.UtcNow)}");
} await _mailService.SendProviderSetupInviteEmailAsync(provider, token, ownerEmail);
}
var providerUsers = await _providerUserRepository.GetManyAsync(providerUserIds); public async Task LogProviderAccessToOrganizationAsync(Guid organizationId)
var users = await _userRepository.GetManyAsync(providerUsers.Where(pu => pu.UserId.HasValue) {
.Select(pu => pu.UserId.Value)); if (organizationId == default)
var keyedUsers = users.ToDictionary(u => u.Id); {
return;
if (!await HasConfirmedProviderAdminExceptAsync(providerId, providerUserIds))
{
throw new BadRequestException("Provider must have at least one confirmed ProviderAdmin.");
}
var result = new List<Tuple<ProviderUser, string>>();
var deletedUserIds = new List<Guid>();
var events = new List<(ProviderUser, EventType, DateTime?)>();
foreach (var providerUser in providerUsers)
{
try
{
if (providerUser.ProviderId != providerId)
{
throw new BadRequestException("Invalid user.");
}
if (providerUser.UserId == deletingUserId)
{
throw new BadRequestException("You cannot remove yourself.");
}
events.Add((providerUser, EventType.ProviderUser_Removed, null));
var user = keyedUsers.GetValueOrDefault(providerUser.UserId.GetValueOrDefault());
var email = user == null ? providerUser.Email : user.Email;
if (!string.IsNullOrWhiteSpace(email))
{
await _mailService.SendProviderUserRemoved(provider.Name, email);
}
result.Add(Tuple.Create(providerUser, ""));
deletedUserIds.Add(providerUser.Id);
}
catch (BadRequestException e)
{
result.Add(Tuple.Create(providerUser, e.Message));
}
await _providerUserRepository.DeleteManyAsync(deletedUserIds);
}
await _eventService.LogProviderUsersEventAsync(events);
return result;
} }
public async Task AddOrganization(Guid providerId, Guid organizationId, Guid addingUserId, string key) var providerOrganization = await _providerOrganizationRepository.GetByOrganizationId(organizationId);
var organization = await _organizationRepository.GetByIdAsync(organizationId);
if (providerOrganization != null)
{ {
var po = await _providerOrganizationRepository.GetByOrganizationId(organizationId); await _eventService.LogProviderOrganizationEventAsync(providerOrganization, EventType.ProviderOrganization_VaultAccessed);
if (po != null)
{
throw new BadRequestException("Organization already belongs to a provider.");
}
var organization = await _organizationRepository.GetByIdAsync(organizationId);
ThrowOnInvalidPlanType(organization.PlanType);
var providerOrganization = new ProviderOrganization
{
ProviderId = providerId,
OrganizationId = organizationId,
Key = key,
};
await _providerOrganizationRepository.CreateAsync(providerOrganization);
await _eventService.LogProviderOrganizationEventAsync(providerOrganization, EventType.ProviderOrganization_Added);
} }
if (organization != null)
public async Task<ProviderOrganization> CreateOrganizationAsync(Guid providerId,
OrganizationSignup organizationSignup, string clientOwnerEmail, User user)
{ {
ThrowOnInvalidPlanType(organizationSignup.Plan); await _eventService.LogOrganizationEventAsync(organization, EventType.Organization_VaultAccessed);
var (organization, _) = await _organizationService.SignUpAsync(organizationSignup, true);
var providerOrganization = new ProviderOrganization
{
ProviderId = providerId,
OrganizationId = organization.Id,
Key = organizationSignup.OwnerKey,
};
await _providerOrganizationRepository.CreateAsync(providerOrganization);
await _eventService.LogProviderOrganizationEventAsync(providerOrganization, EventType.ProviderOrganization_Created);
await _organizationService.InviteUsersAsync(organization.Id, user.Id,
new (OrganizationUserInvite, string)[]
{
(
new OrganizationUserInvite
{
Emails = new[] { clientOwnerEmail },
AccessAll = true,
Type = OrganizationUserType.Owner,
Permissions = null,
Collections = Array.Empty<SelectionReadOnly>(),
},
null
)
});
return providerOrganization;
} }
}
public async Task RemoveOrganizationAsync(Guid providerId, Guid providerOrganizationId, Guid removingUserId) private async Task SendInviteAsync(ProviderUser providerUser, Provider provider)
{
var nowMillis = CoreHelpers.ToEpocMilliseconds(DateTime.UtcNow);
var token = _dataProtector.Protect(
$"ProviderUserInvite {providerUser.Id} {providerUser.Email} {nowMillis}");
await _mailService.SendProviderInviteEmailAsync(provider.Name, providerUser, token, providerUser.Email);
}
private async Task<bool> HasConfirmedProviderAdminExceptAsync(Guid providerId, IEnumerable<Guid> providerUserIds)
{
var providerAdmins = await _providerUserRepository.GetManyByProviderAsync(providerId,
ProviderUserType.ProviderAdmin);
var confirmedOwners = providerAdmins.Where(o => o.Status == ProviderUserStatusType.Confirmed);
var confirmedOwnersIds = confirmedOwners.Select(u => u.Id);
return confirmedOwnersIds.Except(providerUserIds).Any();
}
private void ThrowOnInvalidPlanType(PlanType requestedType)
{
if (ProviderDisllowedOrganizationTypes.Contains(requestedType))
{ {
var providerOrganization = await _providerOrganizationRepository.GetByIdAsync(providerOrganizationId); throw new BadRequestException($"Providers cannot manage organizations with the requested plan type ({requestedType}). Only Teams and Enterprise accounts are allowed.");
if (providerOrganization == null || providerOrganization.ProviderId != providerId)
{
throw new BadRequestException("Invalid organization.");
}
if (!await _organizationService.HasConfirmedOwnersExceptAsync(providerOrganization.OrganizationId, new Guid[] { }, includeProvider: false))
{
throw new BadRequestException("Organization needs to have at least one confirmed owner.");
}
await _providerOrganizationRepository.DeleteAsync(providerOrganization);
await _eventService.LogProviderOrganizationEventAsync(providerOrganization, EventType.ProviderOrganization_Removed);
}
public async Task ResendProviderSetupInviteEmailAsync(Guid providerId, Guid ownerId)
{
var provider = await _providerRepository.GetByIdAsync(providerId);
var owner = await _userRepository.GetByIdAsync(ownerId);
if (owner == null)
{
throw new BadRequestException("Invalid owner.");
}
await SendProviderSetupInviteEmailAsync(provider, owner.Email);
}
private async Task SendProviderSetupInviteEmailAsync(Provider provider, string ownerEmail)
{
var token = _dataProtector.Protect($"ProviderSetupInvite {provider.Id} {ownerEmail} {CoreHelpers.ToEpocMilliseconds(DateTime.UtcNow)}");
await _mailService.SendProviderSetupInviteEmailAsync(provider, token, ownerEmail);
}
public async Task LogProviderAccessToOrganizationAsync(Guid organizationId)
{
if (organizationId == default)
{
return;
}
var providerOrganization = await _providerOrganizationRepository.GetByOrganizationId(organizationId);
var organization = await _organizationRepository.GetByIdAsync(organizationId);
if (providerOrganization != null)
{
await _eventService.LogProviderOrganizationEventAsync(providerOrganization, EventType.ProviderOrganization_VaultAccessed);
}
if (organization != null)
{
await _eventService.LogOrganizationEventAsync(organization, EventType.Organization_VaultAccessed);
}
}
private async Task SendInviteAsync(ProviderUser providerUser, Provider provider)
{
var nowMillis = CoreHelpers.ToEpocMilliseconds(DateTime.UtcNow);
var token = _dataProtector.Protect(
$"ProviderUserInvite {providerUser.Id} {providerUser.Email} {nowMillis}");
await _mailService.SendProviderInviteEmailAsync(provider.Name, providerUser, token, providerUser.Email);
}
private async Task<bool> HasConfirmedProviderAdminExceptAsync(Guid providerId, IEnumerable<Guid> providerUserIds)
{
var providerAdmins = await _providerUserRepository.GetManyByProviderAsync(providerId,
ProviderUserType.ProviderAdmin);
var confirmedOwners = providerAdmins.Where(o => o.Status == ProviderUserStatusType.Confirmed);
var confirmedOwnersIds = confirmedOwners.Select(u => u.Id);
return confirmedOwnersIds.Except(providerUserIds).Any();
}
private void ThrowOnInvalidPlanType(PlanType requestedType)
{
if (ProviderDisllowedOrganizationTypes.Contains(requestedType))
{
throw new BadRequestException($"Providers cannot manage organizations with the requested plan type ({requestedType}). Only Teams and Enterprise accounts are allowed.");
}
} }
} }
} }

View File

@ -2,13 +2,12 @@
using Bit.Core.Services; using Bit.Core.Services;
using Microsoft.Extensions.DependencyInjection; using Microsoft.Extensions.DependencyInjection;
namespace Bit.Commercial.Core.Utilities namespace Bit.Commercial.Core.Utilities;
public static class ServiceCollectionExtensions
{ {
public static class ServiceCollectionExtensions public static void AddCommCoreServices(this IServiceCollection services)
{ {
public static void AddCommCoreServices(this IServiceCollection services) services.AddScoped<IProviderService, ProviderService>();
{
services.AddScoped<IProviderService, ProviderService>();
}
} }
} }

View File

@ -4,18 +4,17 @@ using Bit.Core.Models.OrganizationConnectionConfigs;
using Bit.Core.Repositories; using Bit.Core.Repositories;
using Bit.Core.Settings; using Bit.Core.Settings;
namespace Bit.Scim.Context namespace Bit.Scim.Context;
public interface IScimContext
{ {
public interface IScimContext ScimProviderType RequestScimProvider { get; set; }
{ ScimConfig ScimConfiguration { get; set; }
ScimProviderType RequestScimProvider { get; set; } Guid? OrganizationId { get; set; }
ScimConfig ScimConfiguration { get; set; } Organization Organization { get; set; }
Guid? OrganizationId { get; set; } Task BuildAsync(
Organization Organization { get; set; } HttpContext httpContext,
Task BuildAsync( GlobalSettings globalSettings,
HttpContext httpContext, IOrganizationRepository organizationRepository,
GlobalSettings globalSettings, IOrganizationConnectionRepository organizationConnectionRepository);
IOrganizationRepository organizationRepository,
IOrganizationConnectionRepository organizationConnectionRepository);
}
} }

View File

@ -4,61 +4,60 @@ using Bit.Core.Models.OrganizationConnectionConfigs;
using Bit.Core.Repositories; using Bit.Core.Repositories;
using Bit.Core.Settings; using Bit.Core.Settings;
namespace Bit.Scim.Context namespace Bit.Scim.Context;
public class ScimContext : IScimContext
{ {
public class ScimContext : IScimContext private bool _builtHttpContext;
public ScimProviderType RequestScimProvider { get; set; } = ScimProviderType.Default;
public ScimConfig ScimConfiguration { get; set; }
public Guid? OrganizationId { get; set; }
public Organization Organization { get; set; }
public async virtual Task BuildAsync(
HttpContext httpContext,
GlobalSettings globalSettings,
IOrganizationRepository organizationRepository,
IOrganizationConnectionRepository organizationConnectionRepository)
{ {
private bool _builtHttpContext; if (_builtHttpContext)
public ScimProviderType RequestScimProvider { get; set; } = ScimProviderType.Default;
public ScimConfig ScimConfiguration { get; set; }
public Guid? OrganizationId { get; set; }
public Organization Organization { get; set; }
public async virtual Task BuildAsync(
HttpContext httpContext,
GlobalSettings globalSettings,
IOrganizationRepository organizationRepository,
IOrganizationConnectionRepository organizationConnectionRepository)
{ {
if (_builtHttpContext) return;
{ }
return;
}
_builtHttpContext = true; _builtHttpContext = true;
string orgIdString = null; string orgIdString = null;
if (httpContext.Request.RouteValues.TryGetValue("organizationId", out var orgIdObject)) if (httpContext.Request.RouteValues.TryGetValue("organizationId", out var orgIdObject))
{ {
orgIdString = orgIdObject?.ToString(); orgIdString = orgIdObject?.ToString();
} }
if (Guid.TryParse(orgIdString, out var orgId)) if (Guid.TryParse(orgIdString, out var orgId))
{
OrganizationId = orgId;
Organization = await organizationRepository.GetByIdAsync(orgId);
if (Organization != null)
{ {
OrganizationId = orgId; var scimConnections = await organizationConnectionRepository.GetByOrganizationIdTypeAsync(Organization.Id,
Organization = await organizationRepository.GetByIdAsync(orgId); OrganizationConnectionType.Scim);
if (Organization != null) ScimConfiguration = scimConnections?.FirstOrDefault()?.GetConfig<ScimConfig>();
{
var scimConnections = await organizationConnectionRepository.GetByOrganizationIdTypeAsync(Organization.Id,
OrganizationConnectionType.Scim);
ScimConfiguration = scimConnections?.FirstOrDefault()?.GetConfig<ScimConfig>();
}
} }
}
if (RequestScimProvider == ScimProviderType.Default && if (RequestScimProvider == ScimProviderType.Default &&
httpContext.Request.Headers.TryGetValue("User-Agent", out var userAgent)) httpContext.Request.Headers.TryGetValue("User-Agent", out var userAgent))
{
if (userAgent.ToString().StartsWith("Okta"))
{ {
if (userAgent.ToString().StartsWith("Okta")) RequestScimProvider = ScimProviderType.Okta;
{
RequestScimProvider = ScimProviderType.Okta;
}
}
if (RequestScimProvider == ScimProviderType.Default &&
httpContext.Request.Headers.ContainsKey("Adscimversion"))
{
RequestScimProvider = ScimProviderType.AzureAd;
} }
} }
if (RequestScimProvider == ScimProviderType.Default &&
httpContext.Request.Headers.ContainsKey("Adscimversion"))
{
RequestScimProvider = ScimProviderType.AzureAd;
}
} }
} }

View File

@ -2,22 +2,21 @@
using Microsoft.AspNetCore.Authorization; using Microsoft.AspNetCore.Authorization;
using Microsoft.AspNetCore.Mvc; using Microsoft.AspNetCore.Mvc;
namespace Bit.Scim.Controllers namespace Bit.Scim.Controllers;
{
[AllowAnonymous]
public class InfoController : Controller
{
[HttpGet("~/alive")]
[HttpGet("~/now")]
public DateTime GetAlive()
{
return DateTime.UtcNow;
}
[HttpGet("~/version")] [AllowAnonymous]
public JsonResult GetVersion() public class InfoController : Controller
{ {
return Json(CoreHelpers.GetVersion()); [HttpGet("~/alive")]
} [HttpGet("~/now")]
public DateTime GetAlive()
{
return DateTime.UtcNow;
}
[HttpGet("~/version")]
public JsonResult GetVersion()
{
return Json(CoreHelpers.GetVersion());
} }
} }

View File

@ -8,321 +8,320 @@ using Microsoft.AspNetCore.Authorization;
using Microsoft.AspNetCore.Mvc; using Microsoft.AspNetCore.Mvc;
using Microsoft.Extensions.Options; using Microsoft.Extensions.Options;
namespace Bit.Scim.Controllers.v2 namespace Bit.Scim.Controllers.v2;
[Authorize("Scim")]
[Route("v2/{organizationId}/groups")]
public class GroupsController : Controller
{ {
[Authorize("Scim")] private readonly ScimSettings _scimSettings;
[Route("v2/{organizationId}/groups")] private readonly IGroupRepository _groupRepository;
public class GroupsController : Controller private readonly IGroupService _groupService;
private readonly IScimContext _scimContext;
private readonly ILogger<GroupsController> _logger;
public GroupsController(
IGroupRepository groupRepository,
IGroupService groupService,
IOptions<ScimSettings> scimSettings,
IScimContext scimContext,
ILogger<GroupsController> logger)
{ {
private readonly ScimSettings _scimSettings; _scimSettings = scimSettings?.Value;
private readonly IGroupRepository _groupRepository; _groupRepository = groupRepository;
private readonly IGroupService _groupService; _groupService = groupService;
private readonly IScimContext _scimContext; _scimContext = scimContext;
private readonly ILogger<GroupsController> _logger; _logger = logger;
}
public GroupsController( [HttpGet("{id}")]
IGroupRepository groupRepository, public async Task<IActionResult> Get(Guid organizationId, Guid id)
IGroupService groupService, {
IOptions<ScimSettings> scimSettings, var group = await _groupRepository.GetByIdAsync(id);
IScimContext scimContext, if (group == null || group.OrganizationId != organizationId)
ILogger<GroupsController> logger)
{ {
_scimSettings = scimSettings?.Value; return new NotFoundObjectResult(new ScimErrorResponseModel
_groupRepository = groupRepository; {
_groupService = groupService; Status = 404,
_scimContext = scimContext; Detail = "Group not found."
_logger = logger; });
}
return new ObjectResult(new ScimGroupResponseModel(group));
}
[HttpGet("")]
public async Task<IActionResult> Get(
Guid organizationId,
[FromQuery] string filter,
[FromQuery] int? count,
[FromQuery] int? startIndex)
{
string nameFilter = null;
string externalIdFilter = null;
if (!string.IsNullOrWhiteSpace(filter))
{
if (filter.StartsWith("displayName eq "))
{
nameFilter = filter.Substring(15).Trim('"');
}
else if (filter.StartsWith("externalId eq "))
{
externalIdFilter = filter.Substring(14).Trim('"');
}
} }
[HttpGet("{id}")] var groupList = new List<ScimGroupResponseModel>();
public async Task<IActionResult> Get(Guid organizationId, Guid id) var groups = await _groupRepository.GetManyByOrganizationIdAsync(organizationId);
var totalResults = 0;
if (!string.IsNullOrWhiteSpace(nameFilter))
{ {
var group = await _groupRepository.GetByIdAsync(id); var group = groups.FirstOrDefault(g => g.Name == nameFilter);
if (group == null || group.OrganizationId != organizationId) if (group != null)
{ {
return new NotFoundObjectResult(new ScimErrorResponseModel groupList.Add(new ScimGroupResponseModel(group));
{
Status = 404,
Detail = "Group not found."
});
} }
return new ObjectResult(new ScimGroupResponseModel(group)); totalResults = groupList.Count;
}
else if (!string.IsNullOrWhiteSpace(externalIdFilter))
{
var group = groups.FirstOrDefault(ou => ou.ExternalId == externalIdFilter);
if (group != null)
{
groupList.Add(new ScimGroupResponseModel(group));
}
totalResults = groupList.Count;
}
else if (string.IsNullOrWhiteSpace(filter) && startIndex.HasValue && count.HasValue)
{
groupList = groups.OrderBy(g => g.Name)
.Skip(startIndex.Value - 1)
.Take(count.Value)
.Select(g => new ScimGroupResponseModel(g))
.ToList();
totalResults = groups.Count;
} }
[HttpGet("")] var result = new ScimListResponseModel<ScimGroupResponseModel>
public async Task<IActionResult> Get(
Guid organizationId,
[FromQuery] string filter,
[FromQuery] int? count,
[FromQuery] int? startIndex)
{ {
string nameFilter = null; Resources = groupList,
string externalIdFilter = null; ItemsPerPage = count.GetValueOrDefault(groupList.Count),
if (!string.IsNullOrWhiteSpace(filter)) TotalResults = totalResults,
{ StartIndex = startIndex.GetValueOrDefault(1),
if (filter.StartsWith("displayName eq ")) };
{ return new ObjectResult(result);
nameFilter = filter.Substring(15).Trim('"'); }
}
else if (filter.StartsWith("externalId eq "))
{
externalIdFilter = filter.Substring(14).Trim('"');
}
}
var groupList = new List<ScimGroupResponseModel>(); [HttpPost("")]
var groups = await _groupRepository.GetManyByOrganizationIdAsync(organizationId); public async Task<IActionResult> Post(Guid organizationId, [FromBody] ScimGroupRequestModel model)
var totalResults = 0; {
if (!string.IsNullOrWhiteSpace(nameFilter)) if (string.IsNullOrWhiteSpace(model.DisplayName))
{ {
var group = groups.FirstOrDefault(g => g.Name == nameFilter); return new BadRequestResult();
if (group != null)
{
groupList.Add(new ScimGroupResponseModel(group));
}
totalResults = groupList.Count;
}
else if (!string.IsNullOrWhiteSpace(externalIdFilter))
{
var group = groups.FirstOrDefault(ou => ou.ExternalId == externalIdFilter);
if (group != null)
{
groupList.Add(new ScimGroupResponseModel(group));
}
totalResults = groupList.Count;
}
else if (string.IsNullOrWhiteSpace(filter) && startIndex.HasValue && count.HasValue)
{
groupList = groups.OrderBy(g => g.Name)
.Skip(startIndex.Value - 1)
.Take(count.Value)
.Select(g => new ScimGroupResponseModel(g))
.ToList();
totalResults = groups.Count;
}
var result = new ScimListResponseModel<ScimGroupResponseModel>
{
Resources = groupList,
ItemsPerPage = count.GetValueOrDefault(groupList.Count),
TotalResults = totalResults,
StartIndex = startIndex.GetValueOrDefault(1),
};
return new ObjectResult(result);
} }
[HttpPost("")] var groups = await _groupRepository.GetManyByOrganizationIdAsync(organizationId);
public async Task<IActionResult> Post(Guid organizationId, [FromBody] ScimGroupRequestModel model) if (!string.IsNullOrWhiteSpace(model.ExternalId) && groups.Any(g => g.ExternalId == model.ExternalId))
{ {
if (string.IsNullOrWhiteSpace(model.DisplayName)) return new ConflictResult();
{
return new BadRequestResult();
}
var groups = await _groupRepository.GetManyByOrganizationIdAsync(organizationId);
if (!string.IsNullOrWhiteSpace(model.ExternalId) && groups.Any(g => g.ExternalId == model.ExternalId))
{
return new ConflictResult();
}
var group = model.ToGroup(organizationId);
await _groupService.SaveAsync(group, null);
await UpdateGroupMembersAsync(group, model, true);
var response = new ScimGroupResponseModel(group);
return new CreatedResult(Url.Action(nameof(Get), new { group.OrganizationId, group.Id }), response);
} }
[HttpPut("{id}")] var group = model.ToGroup(organizationId);
public async Task<IActionResult> Put(Guid organizationId, Guid id, [FromBody] ScimGroupRequestModel model) await _groupService.SaveAsync(group, null);
{ await UpdateGroupMembersAsync(group, model, true);
var group = await _groupRepository.GetByIdAsync(id); var response = new ScimGroupResponseModel(group);
if (group == null || group.OrganizationId != organizationId) return new CreatedResult(Url.Action(nameof(Get), new { group.OrganizationId, group.Id }), response);
{ }
return new NotFoundObjectResult(new ScimErrorResponseModel
{
Status = 404,
Detail = "Group not found."
});
}
group.Name = model.DisplayName; [HttpPut("{id}")]
await _groupService.SaveAsync(group); public async Task<IActionResult> Put(Guid organizationId, Guid id, [FromBody] ScimGroupRequestModel model)
await UpdateGroupMembersAsync(group, model, false); {
return new ObjectResult(new ScimGroupResponseModel(group)); var group = await _groupRepository.GetByIdAsync(id);
if (group == null || group.OrganizationId != organizationId)
{
return new NotFoundObjectResult(new ScimErrorResponseModel
{
Status = 404,
Detail = "Group not found."
});
} }
[HttpPatch("{id}")] group.Name = model.DisplayName;
public async Task<IActionResult> Patch(Guid organizationId, Guid id, [FromBody] ScimPatchModel model) await _groupService.SaveAsync(group);
{ await UpdateGroupMembersAsync(group, model, false);
var group = await _groupRepository.GetByIdAsync(id); return new ObjectResult(new ScimGroupResponseModel(group));
if (group == null || group.OrganizationId != organizationId) }
{
return new NotFoundObjectResult(new ScimErrorResponseModel
{
Status = 404,
Detail = "Group not found."
});
}
var operationHandled = false; [HttpPatch("{id}")]
foreach (var operation in model.Operations) public async Task<IActionResult> Patch(Guid organizationId, Guid id, [FromBody] ScimPatchModel model)
{
var group = await _groupRepository.GetByIdAsync(id);
if (group == null || group.OrganizationId != organizationId)
{
return new NotFoundObjectResult(new ScimErrorResponseModel
{ {
// Replace operations Status = 404,
if (operation.Op?.ToLowerInvariant() == "replace") Detail = "Group not found."
});
}
var operationHandled = false;
foreach (var operation in model.Operations)
{
// Replace operations
if (operation.Op?.ToLowerInvariant() == "replace")
{
// Replace a list of members
if (operation.Path?.ToLowerInvariant() == "members")
{ {
// Replace a list of members var ids = GetOperationValueIds(operation.Value);
if (operation.Path?.ToLowerInvariant() == "members") await _groupRepository.UpdateUsersAsync(group.Id, ids);
{
var ids = GetOperationValueIds(operation.Value);
await _groupRepository.UpdateUsersAsync(group.Id, ids);
operationHandled = true;
}
// Replace group name from path
else if (operation.Path?.ToLowerInvariant() == "displayname")
{
group.Name = operation.Value.GetString();
await _groupService.SaveAsync(group);
operationHandled = true;
}
// Replace group name from value object
else if (string.IsNullOrWhiteSpace(operation.Path) &&
operation.Value.TryGetProperty("displayName", out var displayNameProperty))
{
group.Name = displayNameProperty.GetString();
await _groupService.SaveAsync(group);
operationHandled = true;
}
}
// Add a single member
else if (operation.Op?.ToLowerInvariant() == "add" &&
!string.IsNullOrWhiteSpace(operation.Path) &&
operation.Path.ToLowerInvariant().StartsWith("members[value eq "))
{
var addId = GetOperationPathId(operation.Path);
if (addId.HasValue)
{
var orgUserIds = (await _groupRepository.GetManyUserIdsByIdAsync(group.Id)).ToHashSet();
orgUserIds.Add(addId.Value);
await _groupRepository.UpdateUsersAsync(group.Id, orgUserIds);
operationHandled = true;
}
}
// Add a list of members
else if (operation.Op?.ToLowerInvariant() == "add" &&
operation.Path?.ToLowerInvariant() == "members")
{
var orgUserIds = (await _groupRepository.GetManyUserIdsByIdAsync(group.Id)).ToHashSet();
foreach (var v in GetOperationValueIds(operation.Value))
{
orgUserIds.Add(v);
}
await _groupRepository.UpdateUsersAsync(group.Id, orgUserIds);
operationHandled = true; operationHandled = true;
} }
// Remove a single member // Replace group name from path
else if (operation.Op?.ToLowerInvariant() == "remove" && else if (operation.Path?.ToLowerInvariant() == "displayname")
!string.IsNullOrWhiteSpace(operation.Path) &&
operation.Path.ToLowerInvariant().StartsWith("members[value eq "))
{ {
var removeId = GetOperationPathId(operation.Path); group.Name = operation.Value.GetString();
if (removeId.HasValue) await _groupService.SaveAsync(group);
{ operationHandled = true;
await _groupService.DeleteUserAsync(group, removeId.Value);
operationHandled = true;
}
} }
// Remove a list of members // Replace group name from value object
else if (operation.Op?.ToLowerInvariant() == "remove" && else if (string.IsNullOrWhiteSpace(operation.Path) &&
operation.Path?.ToLowerInvariant() == "members") operation.Value.TryGetProperty("displayName", out var displayNameProperty))
{
group.Name = displayNameProperty.GetString();
await _groupService.SaveAsync(group);
operationHandled = true;
}
}
// Add a single member
else if (operation.Op?.ToLowerInvariant() == "add" &&
!string.IsNullOrWhiteSpace(operation.Path) &&
operation.Path.ToLowerInvariant().StartsWith("members[value eq "))
{
var addId = GetOperationPathId(operation.Path);
if (addId.HasValue)
{ {
var orgUserIds = (await _groupRepository.GetManyUserIdsByIdAsync(group.Id)).ToHashSet(); var orgUserIds = (await _groupRepository.GetManyUserIdsByIdAsync(group.Id)).ToHashSet();
foreach (var v in GetOperationValueIds(operation.Value)) orgUserIds.Add(addId.Value);
{
orgUserIds.Remove(v);
}
await _groupRepository.UpdateUsersAsync(group.Id, orgUserIds); await _groupRepository.UpdateUsersAsync(group.Id, orgUserIds);
operationHandled = true; operationHandled = true;
} }
} }
// Add a list of members
if (!operationHandled) else if (operation.Op?.ToLowerInvariant() == "add" &&
operation.Path?.ToLowerInvariant() == "members")
{ {
_logger.LogWarning("Group patch operation not handled: {0} : ", var orgUserIds = (await _groupRepository.GetManyUserIdsByIdAsync(group.Id)).ToHashSet();
string.Join(", ", model.Operations.Select(o => $"{o.Op}:{o.Path}"))); foreach (var v in GetOperationValueIds(operation.Value))
}
return new NoContentResult();
}
[HttpDelete("{id}")]
public async Task<IActionResult> Delete(Guid organizationId, Guid id)
{
var group = await _groupRepository.GetByIdAsync(id);
if (group == null || group.OrganizationId != organizationId)
{
return new NotFoundObjectResult(new ScimErrorResponseModel
{ {
Status = 404, orgUserIds.Add(v);
Detail = "Group not found." }
}); await _groupRepository.UpdateUsersAsync(group.Id, orgUserIds);
operationHandled = true;
} }
await _groupService.DeleteAsync(group); // Remove a single member
return new NoContentResult(); else if (operation.Op?.ToLowerInvariant() == "remove" &&
} !string.IsNullOrWhiteSpace(operation.Path) &&
operation.Path.ToLowerInvariant().StartsWith("members[value eq "))
private List<Guid> GetOperationValueIds(JsonElement objArray)
{
var ids = new List<Guid>();
foreach (var obj in objArray.EnumerateArray())
{ {
if (obj.TryGetProperty("value", out var valueProperty)) var removeId = GetOperationPathId(operation.Path);
if (removeId.HasValue)
{ {
if (valueProperty.TryGetGuid(out var guid)) await _groupService.DeleteUserAsync(group, removeId.Value);
{ operationHandled = true;
ids.Add(guid);
}
} }
} }
return ids; // Remove a list of members
} else if (operation.Op?.ToLowerInvariant() == "remove" &&
operation.Path?.ToLowerInvariant() == "members")
private Guid? GetOperationPathId(string path)
{
// Parse Guid from string like: members[value eq "{GUID}"}]
if (Guid.TryParse(path.Substring(18).Replace("\"]", string.Empty), out var id))
{ {
return id; var orgUserIds = (await _groupRepository.GetManyUserIdsByIdAsync(group.Id)).ToHashSet();
} foreach (var v in GetOperationValueIds(operation.Value))
return null;
}
private async Task UpdateGroupMembersAsync(Group group, ScimGroupRequestModel model, bool skipIfEmpty)
{
if (_scimContext.RequestScimProvider != Core.Enums.ScimProviderType.Okta)
{
return;
}
if (model.Members == null)
{
return;
}
var memberIds = new List<Guid>();
foreach (var id in model.Members.Select(i => i.Value))
{
if (Guid.TryParse(id, out var guidId))
{ {
memberIds.Add(guidId); orgUserIds.Remove(v);
}
await _groupRepository.UpdateUsersAsync(group.Id, orgUserIds);
operationHandled = true;
}
}
if (!operationHandled)
{
_logger.LogWarning("Group patch operation not handled: {0} : ",
string.Join(", ", model.Operations.Select(o => $"{o.Op}:{o.Path}")));
}
return new NoContentResult();
}
[HttpDelete("{id}")]
public async Task<IActionResult> Delete(Guid organizationId, Guid id)
{
var group = await _groupRepository.GetByIdAsync(id);
if (group == null || group.OrganizationId != organizationId)
{
return new NotFoundObjectResult(new ScimErrorResponseModel
{
Status = 404,
Detail = "Group not found."
});
}
await _groupService.DeleteAsync(group);
return new NoContentResult();
}
private List<Guid> GetOperationValueIds(JsonElement objArray)
{
var ids = new List<Guid>();
foreach (var obj in objArray.EnumerateArray())
{
if (obj.TryGetProperty("value", out var valueProperty))
{
if (valueProperty.TryGetGuid(out var guid))
{
ids.Add(guid);
} }
} }
if (!memberIds.Any() && skipIfEmpty)
{
return;
}
await _groupRepository.UpdateUsersAsync(group.Id, memberIds);
} }
return ids;
}
private Guid? GetOperationPathId(string path)
{
// Parse Guid from string like: members[value eq "{GUID}"}]
if (Guid.TryParse(path.Substring(18).Replace("\"]", string.Empty), out var id))
{
return id;
}
return null;
}
private async Task UpdateGroupMembersAsync(Group group, ScimGroupRequestModel model, bool skipIfEmpty)
{
if (_scimContext.RequestScimProvider != Core.Enums.ScimProviderType.Okta)
{
return;
}
if (model.Members == null)
{
return;
}
var memberIds = new List<Guid>();
foreach (var id in model.Members.Select(i => i.Value))
{
if (Guid.TryParse(id, out var guidId))
{
memberIds.Add(guidId);
}
}
if (!memberIds.Any() && skipIfEmpty)
{
return;
}
await _groupRepository.UpdateUsersAsync(group.Id, memberIds);
} }
} }

View File

@ -9,287 +9,286 @@ using Microsoft.AspNetCore.Authorization;
using Microsoft.AspNetCore.Mvc; using Microsoft.AspNetCore.Mvc;
using Microsoft.Extensions.Options; using Microsoft.Extensions.Options;
namespace Bit.Scim.Controllers.v2 namespace Bit.Scim.Controllers.v2;
[Authorize("Scim")]
[Route("v2/{organizationId}/users")]
public class UsersController : Controller
{ {
[Authorize("Scim")] private readonly IUserService _userService;
[Route("v2/{organizationId}/users")] private readonly IUserRepository _userRepository;
public class UsersController : Controller private readonly IOrganizationUserRepository _organizationUserRepository;
private readonly IOrganizationService _organizationService;
private readonly IScimContext _scimContext;
private readonly ScimSettings _scimSettings;
private readonly ILogger<UsersController> _logger;
public UsersController(
IUserService userService,
IUserRepository userRepository,
IOrganizationUserRepository organizationUserRepository,
IOrganizationService organizationService,
IScimContext scimContext,
IOptions<ScimSettings> scimSettings,
ILogger<UsersController> logger)
{ {
private readonly IUserService _userService; _userService = userService;
private readonly IUserRepository _userRepository; _userRepository = userRepository;
private readonly IOrganizationUserRepository _organizationUserRepository; _organizationUserRepository = organizationUserRepository;
private readonly IOrganizationService _organizationService; _organizationService = organizationService;
private readonly IScimContext _scimContext; _scimContext = scimContext;
private readonly ScimSettings _scimSettings; _scimSettings = scimSettings?.Value;
private readonly ILogger<UsersController> _logger; _logger = logger;
}
public UsersController( [HttpGet("{id}")]
IUserService userService, public async Task<IActionResult> Get(Guid organizationId, Guid id)
IUserRepository userRepository, {
IOrganizationUserRepository organizationUserRepository, var orgUser = await _organizationUserRepository.GetDetailsByIdAsync(id);
IOrganizationService organizationService, if (orgUser == null || orgUser.OrganizationId != organizationId)
IScimContext scimContext,
IOptions<ScimSettings> scimSettings,
ILogger<UsersController> logger)
{ {
_userService = userService; return new NotFoundObjectResult(new ScimErrorResponseModel
_userRepository = userRepository;
_organizationUserRepository = organizationUserRepository;
_organizationService = organizationService;
_scimContext = scimContext;
_scimSettings = scimSettings?.Value;
_logger = logger;
}
[HttpGet("{id}")]
public async Task<IActionResult> Get(Guid organizationId, Guid id)
{
var orgUser = await _organizationUserRepository.GetDetailsByIdAsync(id);
if (orgUser == null || orgUser.OrganizationId != organizationId)
{ {
return new NotFoundObjectResult(new ScimErrorResponseModel Status = 404,
Detail = "User not found."
});
}
return new ObjectResult(new ScimUserResponseModel(orgUser));
}
[HttpGet("")]
public async Task<IActionResult> Get(
Guid organizationId,
[FromQuery] string filter,
[FromQuery] int? count,
[FromQuery] int? startIndex)
{
string emailFilter = null;
string usernameFilter = null;
string externalIdFilter = null;
if (!string.IsNullOrWhiteSpace(filter))
{
if (filter.StartsWith("userName eq "))
{
usernameFilter = filter.Substring(12).Trim('"').ToLowerInvariant();
if (usernameFilter.Contains("@"))
{ {
Status = 404, emailFilter = usernameFilter;
Detail = "User not found." }
}); }
else if (filter.StartsWith("externalId eq "))
{
externalIdFilter = filter.Substring(14).Trim('"');
} }
return new ObjectResult(new ScimUserResponseModel(orgUser));
} }
[HttpGet("")] var userList = new List<ScimUserResponseModel> { };
public async Task<IActionResult> Get( var orgUsers = await _organizationUserRepository.GetManyDetailsByOrganizationAsync(organizationId);
Guid organizationId, var totalResults = 0;
[FromQuery] string filter, if (!string.IsNullOrWhiteSpace(emailFilter))
[FromQuery] int? count,
[FromQuery] int? startIndex)
{ {
string emailFilter = null; var orgUser = orgUsers.FirstOrDefault(ou => ou.Email.ToLowerInvariant() == emailFilter);
string usernameFilter = null; if (orgUser != null)
string externalIdFilter = null;
if (!string.IsNullOrWhiteSpace(filter))
{ {
if (filter.StartsWith("userName eq ")) userList.Add(new ScimUserResponseModel(orgUser));
{ }
usernameFilter = filter.Substring(12).Trim('"').ToLowerInvariant(); totalResults = userList.Count;
if (usernameFilter.Contains("@")) }
else if (!string.IsNullOrWhiteSpace(externalIdFilter))
{
var orgUser = orgUsers.FirstOrDefault(ou => ou.ExternalId == externalIdFilter);
if (orgUser != null)
{
userList.Add(new ScimUserResponseModel(orgUser));
}
totalResults = userList.Count;
}
else if (string.IsNullOrWhiteSpace(filter) && startIndex.HasValue && count.HasValue)
{
userList = orgUsers.OrderBy(ou => ou.Email)
.Skip(startIndex.Value - 1)
.Take(count.Value)
.Select(ou => new ScimUserResponseModel(ou))
.ToList();
totalResults = orgUsers.Count;
}
var result = new ScimListResponseModel<ScimUserResponseModel>
{
Resources = userList,
ItemsPerPage = count.GetValueOrDefault(userList.Count),
TotalResults = totalResults,
StartIndex = startIndex.GetValueOrDefault(1),
};
return new ObjectResult(result);
}
[HttpPost("")]
public async Task<IActionResult> Post(Guid organizationId, [FromBody] ScimUserRequestModel model)
{
var email = model.PrimaryEmail?.ToLowerInvariant();
if (string.IsNullOrWhiteSpace(email))
{
switch (_scimContext.RequestScimProvider)
{
case ScimProviderType.AzureAd:
email = model.UserName?.ToLowerInvariant();
break;
default:
email = model.WorkEmail?.ToLowerInvariant();
if (string.IsNullOrWhiteSpace(email))
{ {
emailFilter = usernameFilter; email = model.Emails?.FirstOrDefault()?.Value?.ToLowerInvariant();
}
break;
}
}
if (string.IsNullOrWhiteSpace(email) || !model.Active)
{
return new BadRequestResult();
}
var orgUsers = await _organizationUserRepository.GetManyDetailsByOrganizationAsync(organizationId);
var orgUserByEmail = orgUsers.FirstOrDefault(ou => ou.Email?.ToLowerInvariant() == email);
if (orgUserByEmail != null)
{
return new ConflictResult();
}
string externalId = null;
if (!string.IsNullOrWhiteSpace(model.ExternalId))
{
externalId = model.ExternalId;
}
else if (!string.IsNullOrWhiteSpace(model.UserName))
{
externalId = model.UserName;
}
else
{
externalId = CoreHelpers.RandomString(15);
}
var orgUserByExternalId = orgUsers.FirstOrDefault(ou => ou.ExternalId == externalId);
if (orgUserByExternalId != null)
{
return new ConflictResult();
}
var invitedOrgUser = await _organizationService.InviteUserAsync(organizationId, null, email,
OrganizationUserType.User, false, externalId, new List<SelectionReadOnly>());
var orgUser = await _organizationUserRepository.GetDetailsByIdAsync(invitedOrgUser.Id);
var response = new ScimUserResponseModel(orgUser);
return new CreatedResult(Url.Action(nameof(Get), new { orgUser.OrganizationId, orgUser.Id }), response);
}
[HttpPut("{id}")]
public async Task<IActionResult> Put(Guid organizationId, Guid id, [FromBody] ScimUserRequestModel model)
{
var orgUser = await _organizationUserRepository.GetByIdAsync(id);
if (orgUser == null || orgUser.OrganizationId != organizationId)
{
return new NotFoundObjectResult(new ScimErrorResponseModel
{
Status = 404,
Detail = "User not found."
});
}
if (model.Active && orgUser.Status == OrganizationUserStatusType.Revoked)
{
await _organizationService.RestoreUserAsync(orgUser, null, _userService);
}
else if (!model.Active && orgUser.Status != OrganizationUserStatusType.Revoked)
{
await _organizationService.RevokeUserAsync(orgUser, null);
}
// Have to get full details object for response model
var orgUserDetails = await _organizationUserRepository.GetDetailsByIdAsync(id);
return new ObjectResult(new ScimUserResponseModel(orgUserDetails));
}
[HttpPatch("{id}")]
public async Task<IActionResult> Patch(Guid organizationId, Guid id, [FromBody] ScimPatchModel model)
{
var orgUser = await _organizationUserRepository.GetByIdAsync(id);
if (orgUser == null || orgUser.OrganizationId != organizationId)
{
return new NotFoundObjectResult(new ScimErrorResponseModel
{
Status = 404,
Detail = "User not found."
});
}
var operationHandled = false;
foreach (var operation in model.Operations)
{
// Replace operations
if (operation.Op?.ToLowerInvariant() == "replace")
{
// Active from path
if (operation.Path?.ToLowerInvariant() == "active")
{
var active = operation.Value.ToString()?.ToLowerInvariant();
var handled = await HandleActiveOperationAsync(orgUser, active == "true");
if (!operationHandled)
{
operationHandled = handled;
} }
} }
else if (filter.StartsWith("externalId eq ")) // Active from value object
else if (string.IsNullOrWhiteSpace(operation.Path) &&
operation.Value.TryGetProperty("active", out var activeProperty))
{ {
externalIdFilter = filter.Substring(14).Trim('"'); var handled = await HandleActiveOperationAsync(orgUser, activeProperty.GetBoolean());
} if (!operationHandled)
}
var userList = new List<ScimUserResponseModel> { };
var orgUsers = await _organizationUserRepository.GetManyDetailsByOrganizationAsync(organizationId);
var totalResults = 0;
if (!string.IsNullOrWhiteSpace(emailFilter))
{
var orgUser = orgUsers.FirstOrDefault(ou => ou.Email.ToLowerInvariant() == emailFilter);
if (orgUser != null)
{
userList.Add(new ScimUserResponseModel(orgUser));
}
totalResults = userList.Count;
}
else if (!string.IsNullOrWhiteSpace(externalIdFilter))
{
var orgUser = orgUsers.FirstOrDefault(ou => ou.ExternalId == externalIdFilter);
if (orgUser != null)
{
userList.Add(new ScimUserResponseModel(orgUser));
}
totalResults = userList.Count;
}
else if (string.IsNullOrWhiteSpace(filter) && startIndex.HasValue && count.HasValue)
{
userList = orgUsers.OrderBy(ou => ou.Email)
.Skip(startIndex.Value - 1)
.Take(count.Value)
.Select(ou => new ScimUserResponseModel(ou))
.ToList();
totalResults = orgUsers.Count;
}
var result = new ScimListResponseModel<ScimUserResponseModel>
{
Resources = userList,
ItemsPerPage = count.GetValueOrDefault(userList.Count),
TotalResults = totalResults,
StartIndex = startIndex.GetValueOrDefault(1),
};
return new ObjectResult(result);
}
[HttpPost("")]
public async Task<IActionResult> Post(Guid organizationId, [FromBody] ScimUserRequestModel model)
{
var email = model.PrimaryEmail?.ToLowerInvariant();
if (string.IsNullOrWhiteSpace(email))
{
switch (_scimContext.RequestScimProvider)
{
case ScimProviderType.AzureAd:
email = model.UserName?.ToLowerInvariant();
break;
default:
email = model.WorkEmail?.ToLowerInvariant();
if (string.IsNullOrWhiteSpace(email))
{
email = model.Emails?.FirstOrDefault()?.Value?.ToLowerInvariant();
}
break;
}
}
if (string.IsNullOrWhiteSpace(email) || !model.Active)
{
return new BadRequestResult();
}
var orgUsers = await _organizationUserRepository.GetManyDetailsByOrganizationAsync(organizationId);
var orgUserByEmail = orgUsers.FirstOrDefault(ou => ou.Email?.ToLowerInvariant() == email);
if (orgUserByEmail != null)
{
return new ConflictResult();
}
string externalId = null;
if (!string.IsNullOrWhiteSpace(model.ExternalId))
{
externalId = model.ExternalId;
}
else if (!string.IsNullOrWhiteSpace(model.UserName))
{
externalId = model.UserName;
}
else
{
externalId = CoreHelpers.RandomString(15);
}
var orgUserByExternalId = orgUsers.FirstOrDefault(ou => ou.ExternalId == externalId);
if (orgUserByExternalId != null)
{
return new ConflictResult();
}
var invitedOrgUser = await _organizationService.InviteUserAsync(organizationId, null, email,
OrganizationUserType.User, false, externalId, new List<SelectionReadOnly>());
var orgUser = await _organizationUserRepository.GetDetailsByIdAsync(invitedOrgUser.Id);
var response = new ScimUserResponseModel(orgUser);
return new CreatedResult(Url.Action(nameof(Get), new { orgUser.OrganizationId, orgUser.Id }), response);
}
[HttpPut("{id}")]
public async Task<IActionResult> Put(Guid organizationId, Guid id, [FromBody] ScimUserRequestModel model)
{
var orgUser = await _organizationUserRepository.GetByIdAsync(id);
if (orgUser == null || orgUser.OrganizationId != organizationId)
{
return new NotFoundObjectResult(new ScimErrorResponseModel
{
Status = 404,
Detail = "User not found."
});
}
if (model.Active && orgUser.Status == OrganizationUserStatusType.Revoked)
{
await _organizationService.RestoreUserAsync(orgUser, null, _userService);
}
else if (!model.Active && orgUser.Status != OrganizationUserStatusType.Revoked)
{
await _organizationService.RevokeUserAsync(orgUser, null);
}
// Have to get full details object for response model
var orgUserDetails = await _organizationUserRepository.GetDetailsByIdAsync(id);
return new ObjectResult(new ScimUserResponseModel(orgUserDetails));
}
[HttpPatch("{id}")]
public async Task<IActionResult> Patch(Guid organizationId, Guid id, [FromBody] ScimPatchModel model)
{
var orgUser = await _organizationUserRepository.GetByIdAsync(id);
if (orgUser == null || orgUser.OrganizationId != organizationId)
{
return new NotFoundObjectResult(new ScimErrorResponseModel
{
Status = 404,
Detail = "User not found."
});
}
var operationHandled = false;
foreach (var operation in model.Operations)
{
// Replace operations
if (operation.Op?.ToLowerInvariant() == "replace")
{
// Active from path
if (operation.Path?.ToLowerInvariant() == "active")
{ {
var active = operation.Value.ToString()?.ToLowerInvariant(); operationHandled = handled;
var handled = await HandleActiveOperationAsync(orgUser, active == "true");
if (!operationHandled)
{
operationHandled = handled;
}
}
// Active from value object
else if (string.IsNullOrWhiteSpace(operation.Path) &&
operation.Value.TryGetProperty("active", out var activeProperty))
{
var handled = await HandleActiveOperationAsync(orgUser, activeProperty.GetBoolean());
if (!operationHandled)
{
operationHandled = handled;
}
} }
} }
} }
if (!operationHandled)
{
_logger.LogWarning("User patch operation not handled: {operation} : ",
string.Join(", ", model.Operations.Select(o => $"{o.Op}:{o.Path}")));
}
return new NoContentResult();
} }
[HttpDelete("{id}")] if (!operationHandled)
public async Task<IActionResult> Delete(Guid organizationId, Guid id, [FromBody] ScimUserRequestModel model)
{ {
var orgUser = await _organizationUserRepository.GetByIdAsync(id); _logger.LogWarning("User patch operation not handled: {operation} : ",
if (orgUser == null || orgUser.OrganizationId != organizationId) string.Join(", ", model.Operations.Select(o => $"{o.Op}:{o.Path}")));
{
return new NotFoundObjectResult(new ScimErrorResponseModel
{
Status = 404,
Detail = "User not found."
});
}
await _organizationService.DeleteUserAsync(organizationId, id, null);
return new NoContentResult();
} }
private async Task<bool> HandleActiveOperationAsync(Core.Entities.OrganizationUser orgUser, bool active) return new NoContentResult();
}
[HttpDelete("{id}")]
public async Task<IActionResult> Delete(Guid organizationId, Guid id, [FromBody] ScimUserRequestModel model)
{
var orgUser = await _organizationUserRepository.GetByIdAsync(id);
if (orgUser == null || orgUser.OrganizationId != organizationId)
{ {
if (active && orgUser.Status == OrganizationUserStatusType.Revoked) return new NotFoundObjectResult(new ScimErrorResponseModel
{ {
await _organizationService.RestoreUserAsync(orgUser, null, _userService); Status = 404,
return true; Detail = "User not found."
} });
else if (!active && orgUser.Status != OrganizationUserStatusType.Revoked)
{
await _organizationService.RevokeUserAsync(orgUser, null);
return true;
}
return false;
} }
await _organizationService.DeleteUserAsync(organizationId, id, null);
return new NoContentResult();
}
private async Task<bool> HandleActiveOperationAsync(Core.Entities.OrganizationUser orgUser, bool active)
{
if (active && orgUser.Status == OrganizationUserStatusType.Revoked)
{
await _organizationService.RestoreUserAsync(orgUser, null, _userService);
return true;
}
else if (!active && orgUser.Status != OrganizationUserStatusType.Revoked)
{
await _organizationService.RevokeUserAsync(orgUser, null);
return true;
}
return false;
} }
} }

View File

@ -1,18 +1,17 @@
using Bit.Scim.Utilities; using Bit.Scim.Utilities;
namespace Bit.Scim.Models namespace Bit.Scim.Models;
{
public abstract class BaseScimGroupModel : BaseScimModel
{
public BaseScimGroupModel(bool initSchema = false)
{
if (initSchema)
{
Schemas = new List<string> { ScimConstants.Scim2SchemaGroup };
}
}
public string DisplayName { get; set; } public abstract class BaseScimGroupModel : BaseScimModel
public string ExternalId { get; set; } {
public BaseScimGroupModel(bool initSchema = false)
{
if (initSchema)
{
Schemas = new List<string> { ScimConstants.Scim2SchemaGroup };
}
} }
public string DisplayName { get; set; }
public string ExternalId { get; set; }
} }

View File

@ -1,15 +1,14 @@
namespace Bit.Scim.Models namespace Bit.Scim.Models;
public abstract class BaseScimModel
{ {
public abstract class BaseScimModel public BaseScimModel()
{ }
public BaseScimModel(string schema)
{ {
public BaseScimModel() Schemas = new List<string> { schema };
{ }
public BaseScimModel(string schema)
{
Schemas = new List<string> { schema };
}
public List<string> Schemas { get; set; }
} }
public List<string> Schemas { get; set; }
} }

View File

@ -1,56 +1,55 @@
using Bit.Scim.Utilities; using Bit.Scim.Utilities;
namespace Bit.Scim.Models namespace Bit.Scim.Models;
public abstract class BaseScimUserModel : BaseScimModel
{ {
public abstract class BaseScimUserModel : BaseScimModel public BaseScimUserModel(bool initSchema = false)
{ {
public BaseScimUserModel(bool initSchema = false) if (initSchema)
{ {
if (initSchema) Schemas = new List<string> { ScimConstants.Scim2SchemaUser };
{
Schemas = new List<string> { ScimConstants.Scim2SchemaUser };
}
}
public string UserName { get; set; }
public NameModel Name { get; set; }
public List<EmailModel> Emails { get; set; }
public string PrimaryEmail => Emails?.FirstOrDefault(e => e.Primary)?.Value;
public string WorkEmail => Emails?.FirstOrDefault(e => e.Type == "work")?.Value;
public string DisplayName { get; set; }
public bool Active { get; set; }
public List<string> Groups { get; set; }
public string ExternalId { get; set; }
public class NameModel
{
public NameModel() { }
public NameModel(string name)
{
Formatted = name;
}
public string Formatted { get; set; }
public string GivenName { get; set; }
public string MiddleName { get; set; }
public string FamilyName { get; set; }
}
public class EmailModel
{
public EmailModel() { }
public EmailModel(string email)
{
Primary = true;
Value = email;
Type = "work";
}
public bool Primary { get; set; }
public string Value { get; set; }
public string Type { get; set; }
} }
} }
public string UserName { get; set; }
public NameModel Name { get; set; }
public List<EmailModel> Emails { get; set; }
public string PrimaryEmail => Emails?.FirstOrDefault(e => e.Primary)?.Value;
public string WorkEmail => Emails?.FirstOrDefault(e => e.Type == "work")?.Value;
public string DisplayName { get; set; }
public bool Active { get; set; }
public List<string> Groups { get; set; }
public string ExternalId { get; set; }
public class NameModel
{
public NameModel() { }
public NameModel(string name)
{
Formatted = name;
}
public string Formatted { get; set; }
public string GivenName { get; set; }
public string MiddleName { get; set; }
public string FamilyName { get; set; }
}
public class EmailModel
{
public EmailModel() { }
public EmailModel(string email)
{
Primary = true;
Value = email;
Type = "work";
}
public bool Primary { get; set; }
public string Value { get; set; }
public string Type { get; set; }
}
} }

View File

@ -1,14 +1,13 @@
using Bit.Scim.Utilities; using Bit.Scim.Utilities;
namespace Bit.Scim.Models namespace Bit.Scim.Models;
{
public class ScimErrorResponseModel : BaseScimModel
{
public ScimErrorResponseModel()
: base(ScimConstants.Scim2SchemaError)
{ }
public string Detail { get; set; } public class ScimErrorResponseModel : BaseScimModel
public int Status { get; set; } {
} public ScimErrorResponseModel()
: base(ScimConstants.Scim2SchemaError)
{ }
public string Detail { get; set; }
public int Status { get; set; }
} }

View File

@ -1,31 +1,30 @@
using Bit.Core.Entities; using Bit.Core.Entities;
using Bit.Core.Utilities; using Bit.Core.Utilities;
namespace Bit.Scim.Models namespace Bit.Scim.Models;
public class ScimGroupRequestModel : BaseScimGroupModel
{ {
public class ScimGroupRequestModel : BaseScimGroupModel public ScimGroupRequestModel()
: base(false)
{ }
public Group ToGroup(Guid organizationId)
{ {
public ScimGroupRequestModel() var externalId = string.IsNullOrWhiteSpace(ExternalId) ? CoreHelpers.RandomString(15) : ExternalId;
: base(false) return new Group
{ }
public Group ToGroup(Guid organizationId)
{ {
var externalId = string.IsNullOrWhiteSpace(ExternalId) ? CoreHelpers.RandomString(15) : ExternalId; Name = DisplayName,
return new Group ExternalId = externalId,
{ OrganizationId = organizationId
Name = DisplayName, };
ExternalId = externalId, }
OrganizationId = organizationId
};
}
public List<GroupMembersModel> Members { get; set; } public List<GroupMembersModel> Members { get; set; }
public class GroupMembersModel public class GroupMembersModel
{ {
public string Value { get; set; } public string Value { get; set; }
public string Display { get; set; } public string Display { get; set; }
}
} }
} }

View File

@ -1,26 +1,25 @@
using Bit.Core.Entities; using Bit.Core.Entities;
namespace Bit.Scim.Models namespace Bit.Scim.Models;
public class ScimGroupResponseModel : BaseScimGroupModel
{ {
public class ScimGroupResponseModel : BaseScimGroupModel public ScimGroupResponseModel()
: base(true)
{ {
public ScimGroupResponseModel() Meta = new ScimMetaModel("Group");
: base(true)
{
Meta = new ScimMetaModel("Group");
}
public ScimGroupResponseModel(Group group)
: this()
{
Id = group.Id.ToString();
DisplayName = group.Name;
ExternalId = group.ExternalId;
Meta.Created = group.CreationDate;
Meta.LastModified = group.RevisionDate;
}
public string Id { get; set; }
public ScimMetaModel Meta { get; private set; }
} }
public ScimGroupResponseModel(Group group)
: this()
{
Id = group.Id.ToString();
DisplayName = group.Name;
ExternalId = group.ExternalId;
Meta.Created = group.CreationDate;
Meta.LastModified = group.RevisionDate;
}
public string Id { get; set; }
public ScimMetaModel Meta { get; private set; }
} }

View File

@ -1,16 +1,15 @@
using Bit.Scim.Utilities; using Bit.Scim.Utilities;
namespace Bit.Scim.Models namespace Bit.Scim.Models;
{
public class ScimListResponseModel<T> : BaseScimModel
{
public ScimListResponseModel()
: base(ScimConstants.Scim2SchemaListResponse)
{ }
public int TotalResults { get; set; } public class ScimListResponseModel<T> : BaseScimModel
public int StartIndex { get; set; } {
public int ItemsPerPage { get; set; } public ScimListResponseModel()
public List<T> Resources { get; set; } : base(ScimConstants.Scim2SchemaListResponse)
} { }
public int TotalResults { get; set; }
public int StartIndex { get; set; }
public int ItemsPerPage { get; set; }
public List<T> Resources { get; set; }
} }

View File

@ -1,14 +1,13 @@
namespace Bit.Scim.Models namespace Bit.Scim.Models;
{
public class ScimMetaModel
{
public ScimMetaModel(string resourceType)
{
ResourceType = resourceType;
}
public string ResourceType { get; set; } public class ScimMetaModel
public DateTime? Created { get; set; } {
public DateTime? LastModified { get; set; } public ScimMetaModel(string resourceType)
{
ResourceType = resourceType;
} }
public string ResourceType { get; set; }
public DateTime? Created { get; set; }
public DateTime? LastModified { get; set; }
} }

View File

@ -1,19 +1,18 @@
using System.Text.Json; using System.Text.Json;
namespace Bit.Scim.Models namespace Bit.Scim.Models;
public class ScimPatchModel : BaseScimModel
{ {
public class ScimPatchModel : BaseScimModel public ScimPatchModel()
: base() { }
public List<OperationModel> Operations { get; set; }
public class OperationModel
{ {
public ScimPatchModel() public string Op { get; set; }
: base() { } public string Path { get; set; }
public JsonElement Value { get; set; }
public List<OperationModel> Operations { get; set; }
public class OperationModel
{
public string Op { get; set; }
public string Path { get; set; }
public JsonElement Value { get; set; }
}
} }
} }

View File

@ -1,9 +1,8 @@
namespace Bit.Scim.Models namespace Bit.Scim.Models;
public class ScimUserRequestModel : BaseScimUserModel
{ {
public class ScimUserRequestModel : BaseScimUserModel public ScimUserRequestModel()
{ : base(false)
public ScimUserRequestModel() { }
: base(false)
{ }
}
} }

View File

@ -1,29 +1,28 @@
using Bit.Core.Models.Data.Organizations.OrganizationUsers; using Bit.Core.Models.Data.Organizations.OrganizationUsers;
namespace Bit.Scim.Models namespace Bit.Scim.Models;
public class ScimUserResponseModel : BaseScimUserModel
{ {
public class ScimUserResponseModel : BaseScimUserModel public ScimUserResponseModel()
: base(true)
{ {
public ScimUserResponseModel() Meta = new ScimMetaModel("User");
: base(true) Groups = new List<string>();
{
Meta = new ScimMetaModel("User");
Groups = new List<string>();
}
public ScimUserResponseModel(OrganizationUserUserDetails orgUser)
: this()
{
Id = orgUser.Id.ToString();
ExternalId = orgUser.ExternalId;
UserName = orgUser.Email;
DisplayName = orgUser.Name;
Emails = new List<EmailModel> { new EmailModel(orgUser.Email) };
Name = new NameModel(orgUser.Name);
Active = orgUser.Status != Core.Enums.OrganizationUserStatusType.Revoked;
}
public string Id { get; set; }
public ScimMetaModel Meta { get; private set; }
} }
public ScimUserResponseModel(OrganizationUserUserDetails orgUser)
: this()
{
Id = orgUser.Id.ToString();
ExternalId = orgUser.ExternalId;
UserName = orgUser.Email;
DisplayName = orgUser.Name;
Emails = new List<EmailModel> { new EmailModel(orgUser.Email) };
Name = new NameModel(orgUser.Name);
Active = orgUser.Status != Core.Enums.OrganizationUserStatusType.Revoked;
}
public string Id { get; set; }
public ScimMetaModel Meta { get; private set; }
} }

View File

@ -1,34 +1,33 @@
using Bit.Core.Utilities; using Bit.Core.Utilities;
using Serilog.Events; using Serilog.Events;
namespace Bit.Scim namespace Bit.Scim;
public class Program
{ {
public class Program public static void Main(string[] args)
{ {
public static void Main(string[] args) Host
{ .CreateDefaultBuilder(args)
Host .ConfigureWebHostDefaults(webBuilder =>
.CreateDefaultBuilder(args) {
.ConfigureWebHostDefaults(webBuilder => webBuilder.UseStartup<Startup>();
{ webBuilder.ConfigureLogging((hostingContext, logging) =>
webBuilder.UseStartup<Startup>(); logging.AddSerilog(hostingContext, e =>
webBuilder.ConfigureLogging((hostingContext, logging) => {
logging.AddSerilog(hostingContext, e => var context = e.Properties["SourceContext"].ToString();
if (e.Properties.ContainsKey("RequestPath") &&
!string.IsNullOrWhiteSpace(e.Properties["RequestPath"]?.ToString()) &&
(context.Contains(".Server.Kestrel") || context.Contains(".Core.IISHttpServer")))
{ {
var context = e.Properties["SourceContext"].ToString(); return false;
}
if (e.Properties.ContainsKey("RequestPath") && return e.Level >= LogEventLevel.Warning;
!string.IsNullOrWhiteSpace(e.Properties["RequestPath"]?.ToString()) && }));
(context.Contains(".Server.Kestrel") || context.Contains(".Core.IISHttpServer"))) })
{ .Build()
return false; .Run();
}
return e.Level >= LogEventLevel.Warning;
}));
})
.Build()
.Run();
}
} }
} }

View File

@ -1,6 +1,5 @@
namespace Bit.Scim namespace Bit.Scim;
public class ScimSettings
{ {
public class ScimSettings
{
}
} }

View File

@ -9,108 +9,107 @@ using IdentityModel;
using Microsoft.Extensions.DependencyInjection.Extensions; using Microsoft.Extensions.DependencyInjection.Extensions;
using Stripe; using Stripe;
namespace Bit.Scim namespace Bit.Scim;
public class Startup
{ {
public class Startup public Startup(IWebHostEnvironment env, IConfiguration configuration)
{ {
public Startup(IWebHostEnvironment env, IConfiguration configuration) CultureInfo.DefaultThreadCurrentCulture = new CultureInfo("en-US");
Configuration = configuration;
Environment = env;
}
public IConfiguration Configuration { get; }
public IWebHostEnvironment Environment { get; set; }
public void ConfigureServices(IServiceCollection services)
{
// Options
services.AddOptions();
// Settings
var globalSettings = services.AddGlobalSettingsServices(Configuration, Environment);
services.Configure<ScimSettings>(Configuration.GetSection("ScimSettings"));
// Data Protection
services.AddCustomDataProtectionServices(Environment, globalSettings);
// Stripe Billing
StripeConfiguration.ApiKey = globalSettings.Stripe.ApiKey;
StripeConfiguration.MaxNetworkRetries = globalSettings.Stripe.MaxNetworkRetries;
// Repositories
services.AddSqlServerRepositories(globalSettings);
// Context
services.AddScoped<ICurrentContext, CurrentContext>();
services.AddScoped<IScimContext, ScimContext>();
// Authentication
services.AddAuthentication(ApiKeyAuthenticationOptions.DefaultScheme)
.AddScheme<ApiKeyAuthenticationOptions, ApiKeyAuthenticationHandler>(
ApiKeyAuthenticationOptions.DefaultScheme, null);
services.AddAuthorization(config =>
{ {
CultureInfo.DefaultThreadCurrentCulture = new CultureInfo("en-US"); config.AddPolicy("Scim", policy =>
Configuration = configuration;
Environment = env;
}
public IConfiguration Configuration { get; }
public IWebHostEnvironment Environment { get; set; }
public void ConfigureServices(IServiceCollection services)
{
// Options
services.AddOptions();
// Settings
var globalSettings = services.AddGlobalSettingsServices(Configuration, Environment);
services.Configure<ScimSettings>(Configuration.GetSection("ScimSettings"));
// Data Protection
services.AddCustomDataProtectionServices(Environment, globalSettings);
// Stripe Billing
StripeConfiguration.ApiKey = globalSettings.Stripe.ApiKey;
StripeConfiguration.MaxNetworkRetries = globalSettings.Stripe.MaxNetworkRetries;
// Repositories
services.AddSqlServerRepositories(globalSettings);
// Context
services.AddScoped<ICurrentContext, CurrentContext>();
services.AddScoped<IScimContext, ScimContext>();
// Authentication
services.AddAuthentication(ApiKeyAuthenticationOptions.DefaultScheme)
.AddScheme<ApiKeyAuthenticationOptions, ApiKeyAuthenticationHandler>(
ApiKeyAuthenticationOptions.DefaultScheme, null);
services.AddAuthorization(config =>
{ {
config.AddPolicy("Scim", policy => policy.RequireAuthenticatedUser();
{ policy.RequireClaim(JwtClaimTypes.Scope, "api.scim");
policy.RequireAuthenticatedUser();
policy.RequireClaim(JwtClaimTypes.Scope, "api.scim");
});
}); });
});
// Identity // Identity
services.AddCustomIdentityServices(globalSettings); services.AddCustomIdentityServices(globalSettings);
// Services // Services
services.AddBaseServices(globalSettings); services.AddBaseServices(globalSettings);
services.AddDefaultServices(globalSettings); services.AddDefaultServices(globalSettings);
services.TryAddSingleton<IHttpContextAccessor, HttpContextAccessor>(); services.TryAddSingleton<IHttpContextAccessor, HttpContextAccessor>();
// Mvc // Mvc
services.AddMvc(config => services.AddMvc(config =>
{
config.Filters.Add(new LoggingExceptionHandlerFilterAttribute());
});
services.Configure<RouteOptions>(options => options.LowercaseUrls = true);
}
public void Configure(
IApplicationBuilder app,
IWebHostEnvironment env,
IHostApplicationLifetime appLifetime,
GlobalSettings globalSettings)
{ {
app.UseSerilog(env, appLifetime, globalSettings); config.Filters.Add(new LoggingExceptionHandlerFilterAttribute());
});
services.Configure<RouteOptions>(options => options.LowercaseUrls = true);
}
// Add general security headers public void Configure(
app.UseMiddleware<SecurityHeadersMiddleware>(); IApplicationBuilder app,
IWebHostEnvironment env,
IHostApplicationLifetime appLifetime,
GlobalSettings globalSettings)
{
app.UseSerilog(env, appLifetime, globalSettings);
if (env.IsDevelopment()) // Add general security headers
{ app.UseMiddleware<SecurityHeadersMiddleware>();
app.UseDeveloperExceptionPage();
}
// Default Middleware if (env.IsDevelopment())
app.UseDefaultMiddleware(env, globalSettings); {
app.UseDeveloperExceptionPage();
// Add routing
app.UseRouting();
// Add Scim context
app.UseMiddleware<ScimContextMiddleware>();
// Add authentication and authorization to the request pipeline.
app.UseAuthentication();
app.UseAuthorization();
// Add current context
app.UseMiddleware<CurrentContextMiddleware>();
// Add MVC to the request pipeline.
app.UseEndpoints(endpoints => endpoints.MapDefaultControllerRoute());
} }
// Default Middleware
app.UseDefaultMiddleware(env, globalSettings);
// Add routing
app.UseRouting();
// Add Scim context
app.UseMiddleware<ScimContextMiddleware>();
// Add authentication and authorization to the request pipeline.
app.UseAuthentication();
app.UseAuthorization();
// Add current context
app.UseMiddleware<CurrentContextMiddleware>();
// Add MVC to the request pipeline.
app.UseEndpoints(endpoints => endpoints.MapDefaultControllerRoute());
} }
} }

View File

@ -8,83 +8,82 @@ using Microsoft.AspNetCore.Authentication;
using Microsoft.AspNetCore.Authorization; using Microsoft.AspNetCore.Authorization;
using Microsoft.Extensions.Options; using Microsoft.Extensions.Options;
namespace Bit.Scim.Utilities namespace Bit.Scim.Utilities;
public class ApiKeyAuthenticationHandler : AuthenticationHandler<ApiKeyAuthenticationOptions>
{ {
public class ApiKeyAuthenticationHandler : AuthenticationHandler<ApiKeyAuthenticationOptions> private readonly IOrganizationRepository _organizationRepository;
private readonly IOrganizationApiKeyRepository _organizationApiKeyRepository;
private readonly IScimContext _scimContext;
public ApiKeyAuthenticationHandler(
IOptionsMonitor<ApiKeyAuthenticationOptions> options,
ILoggerFactory logger,
UrlEncoder encoder,
ISystemClock clock,
IOrganizationRepository organizationRepository,
IOrganizationApiKeyRepository organizationApiKeyRepository,
IScimContext scimContext) :
base(options, logger, encoder, clock)
{ {
private readonly IOrganizationRepository _organizationRepository; _organizationRepository = organizationRepository;
private readonly IOrganizationApiKeyRepository _organizationApiKeyRepository; _organizationApiKeyRepository = organizationApiKeyRepository;
private readonly IScimContext _scimContext; _scimContext = scimContext;
}
public ApiKeyAuthenticationHandler( protected override async Task<AuthenticateResult> HandleAuthenticateAsync()
IOptionsMonitor<ApiKeyAuthenticationOptions> options, {
ILoggerFactory logger, var endpoint = Context.GetEndpoint();
UrlEncoder encoder, if (endpoint?.Metadata?.GetMetadata<IAllowAnonymous>() != null)
ISystemClock clock,
IOrganizationRepository organizationRepository,
IOrganizationApiKeyRepository organizationApiKeyRepository,
IScimContext scimContext) :
base(options, logger, encoder, clock)
{ {
_organizationRepository = organizationRepository; return AuthenticateResult.NoResult();
_organizationApiKeyRepository = organizationApiKeyRepository;
_scimContext = scimContext;
} }
protected override async Task<AuthenticateResult> HandleAuthenticateAsync() if (!_scimContext.OrganizationId.HasValue || _scimContext.Organization == null)
{ {
var endpoint = Context.GetEndpoint(); Logger.LogWarning("No organization.");
if (endpoint?.Metadata?.GetMetadata<IAllowAnonymous>() != null) return AuthenticateResult.Fail("Invalid parameters");
{
return AuthenticateResult.NoResult();
}
if (!_scimContext.OrganizationId.HasValue || _scimContext.Organization == null)
{
Logger.LogWarning("No organization.");
return AuthenticateResult.Fail("Invalid parameters");
}
if (!Request.Headers.TryGetValue("Authorization", out var authHeader) || authHeader.Count != 1)
{
Logger.LogWarning("An API request was received without the Authorization header");
return AuthenticateResult.Fail("Invalid parameters");
}
var apiKey = authHeader.ToString();
if (apiKey.StartsWith("Bearer "))
{
apiKey = apiKey.Substring(7);
}
if (!_scimContext.Organization.Enabled || !_scimContext.Organization.UseScim ||
_scimContext.ScimConfiguration == null || !_scimContext.ScimConfiguration.Enabled)
{
Logger.LogInformation("Org {organizationId} not able to use Scim.", _scimContext.OrganizationId);
return AuthenticateResult.Fail("Invalid parameters");
}
var orgApiKey = (await _organizationApiKeyRepository
.GetManyByOrganizationIdTypeAsync(_scimContext.Organization.Id, OrganizationApiKeyType.Scim))
.FirstOrDefault();
if (orgApiKey?.ApiKey != apiKey)
{
Logger.LogWarning("An API request was received with an invalid API key: {apiKey}", apiKey);
return AuthenticateResult.Fail("Invalid parameters");
}
Logger.LogInformation("Org {organizationId} authenticated", _scimContext.OrganizationId);
var claims = new[]
{
new Claim(JwtClaimTypes.ClientId, $"organization.{_scimContext.OrganizationId.Value}"),
new Claim("client_sub", _scimContext.OrganizationId.Value.ToString()),
new Claim(JwtClaimTypes.Scope, "api.scim"),
};
var identity = new ClaimsIdentity(claims, nameof(ApiKeyAuthenticationHandler));
var ticket = new AuthenticationTicket(new ClaimsPrincipal(identity),
ApiKeyAuthenticationOptions.DefaultScheme);
return AuthenticateResult.Success(ticket);
} }
if (!Request.Headers.TryGetValue("Authorization", out var authHeader) || authHeader.Count != 1)
{
Logger.LogWarning("An API request was received without the Authorization header");
return AuthenticateResult.Fail("Invalid parameters");
}
var apiKey = authHeader.ToString();
if (apiKey.StartsWith("Bearer "))
{
apiKey = apiKey.Substring(7);
}
if (!_scimContext.Organization.Enabled || !_scimContext.Organization.UseScim ||
_scimContext.ScimConfiguration == null || !_scimContext.ScimConfiguration.Enabled)
{
Logger.LogInformation("Org {organizationId} not able to use Scim.", _scimContext.OrganizationId);
return AuthenticateResult.Fail("Invalid parameters");
}
var orgApiKey = (await _organizationApiKeyRepository
.GetManyByOrganizationIdTypeAsync(_scimContext.Organization.Id, OrganizationApiKeyType.Scim))
.FirstOrDefault();
if (orgApiKey?.ApiKey != apiKey)
{
Logger.LogWarning("An API request was received with an invalid API key: {apiKey}", apiKey);
return AuthenticateResult.Fail("Invalid parameters");
}
Logger.LogInformation("Org {organizationId} authenticated", _scimContext.OrganizationId);
var claims = new[]
{
new Claim(JwtClaimTypes.ClientId, $"organization.{_scimContext.OrganizationId.Value}"),
new Claim("client_sub", _scimContext.OrganizationId.Value.ToString()),
new Claim(JwtClaimTypes.Scope, "api.scim"),
};
var identity = new ClaimsIdentity(claims, nameof(ApiKeyAuthenticationHandler));
var ticket = new AuthenticationTicket(new ClaimsPrincipal(identity),
ApiKeyAuthenticationOptions.DefaultScheme);
return AuthenticateResult.Success(ticket);
} }
} }

View File

@ -1,9 +1,8 @@
using Microsoft.AspNetCore.Authentication; using Microsoft.AspNetCore.Authentication;
namespace Bit.Scim.Utilities namespace Bit.Scim.Utilities;
public class ApiKeyAuthenticationOptions : AuthenticationSchemeOptions
{ {
public class ApiKeyAuthenticationOptions : AuthenticationSchemeOptions public const string DefaultScheme = "ScimApiKey";
{
public const string DefaultScheme = "ScimApiKey";
}
} }

View File

@ -1,10 +1,9 @@
namespace Bit.Scim.Utilities namespace Bit.Scim.Utilities;
public static class ScimConstants
{ {
public static class ScimConstants public const string Scim2SchemaListResponse = "urn:ietf:params:scim:api:messages:2.0:ListResponse";
{ public const string Scim2SchemaError = "urn:ietf:params:scim:api:messages:2.0:Error";
public const string Scim2SchemaListResponse = "urn:ietf:params:scim:api:messages:2.0:ListResponse"; public const string Scim2SchemaUser = "urn:ietf:params:scim:schemas:core:2.0:User";
public const string Scim2SchemaError = "urn:ietf:params:scim:api:messages:2.0:Error"; public const string Scim2SchemaGroup = "urn:ietf:params:scim:schemas:core:2.0:Group";
public const string Scim2SchemaUser = "urn:ietf:params:scim:schemas:core:2.0:User";
public const string Scim2SchemaGroup = "urn:ietf:params:scim:schemas:core:2.0:Group";
}
} }

View File

@ -2,22 +2,21 @@
using Bit.Core.Settings; using Bit.Core.Settings;
using Bit.Scim.Context; using Bit.Scim.Context;
namespace Bit.Scim.Utilities namespace Bit.Scim.Utilities;
public class ScimContextMiddleware
{ {
public class ScimContextMiddleware private readonly RequestDelegate _next;
public ScimContextMiddleware(RequestDelegate next)
{ {
private readonly RequestDelegate _next; _next = next;
}
public ScimContextMiddleware(RequestDelegate next) public async Task Invoke(HttpContext httpContext, IScimContext scimContext, GlobalSettings globalSettings,
{ IOrganizationRepository organizationRepository, IOrganizationConnectionRepository organizationConnectionRepository)
_next = next; {
} await scimContext.BuildAsync(httpContext, globalSettings, organizationRepository, organizationConnectionRepository);
await _next.Invoke(httpContext);
public async Task Invoke(HttpContext httpContext, IScimContext scimContext, GlobalSettings globalSettings,
IOrganizationRepository organizationRepository, IOrganizationConnectionRepository organizationConnectionRepository)
{
await scimContext.BuildAsync(httpContext, globalSettings, organizationRepository, organizationConnectionRepository);
await _next.Invoke(httpContext);
}
} }
} }

File diff suppressed because it is too large Load Diff

View File

@ -5,51 +5,50 @@ using Microsoft.AspNetCore.Authorization;
using Microsoft.AspNetCore.Diagnostics; using Microsoft.AspNetCore.Diagnostics;
using Microsoft.AspNetCore.Mvc; using Microsoft.AspNetCore.Mvc;
namespace Bit.Sso.Controllers namespace Bit.Sso.Controllers;
public class HomeController : Controller
{ {
public class HomeController : Controller private readonly IIdentityServerInteractionService _interaction;
public HomeController(IIdentityServerInteractionService interaction)
{ {
private readonly IIdentityServerInteractionService _interaction; _interaction = interaction;
}
public HomeController(IIdentityServerInteractionService interaction) [Route("~/Error")]
[Route("~/Home/Error")]
[AllowAnonymous]
public async Task<IActionResult> Error(string errorId)
{
var vm = new ErrorViewModel();
// retrieve error details from identityserver
var message = string.IsNullOrWhiteSpace(errorId) ? null :
await _interaction.GetErrorContextAsync(errorId);
if (message != null)
{ {
_interaction = interaction; vm.Error = message;
}
else
{
vm.RequestId = Activity.Current?.Id ?? HttpContext.TraceIdentifier;
var exceptionHandlerPathFeature = HttpContext.Features.Get<IExceptionHandlerPathFeature>();
var exception = exceptionHandlerPathFeature?.Error;
if (exception is InvalidOperationException opEx && opEx.Message.Contains("schemes are: "))
{
// Messages coming from aspnetcore with a message
// similar to "The registered sign-in schemes are: {schemes}."
// will expose other Org IDs and sign-in schemes enabled on
// the server. These errors should be truncated to just the
// scheme impacted (always the first sentence)
var cleanupPoint = opEx.Message.IndexOf(". ") + 1;
var exMessage = opEx.Message.Substring(0, cleanupPoint);
exception = new InvalidOperationException(exMessage, opEx);
}
vm.Exception = exception;
} }
[Route("~/Error")] return View("Error", vm);
[Route("~/Home/Error")]
[AllowAnonymous]
public async Task<IActionResult> Error(string errorId)
{
var vm = new ErrorViewModel();
// retrieve error details from identityserver
var message = string.IsNullOrWhiteSpace(errorId) ? null :
await _interaction.GetErrorContextAsync(errorId);
if (message != null)
{
vm.Error = message;
}
else
{
vm.RequestId = Activity.Current?.Id ?? HttpContext.TraceIdentifier;
var exceptionHandlerPathFeature = HttpContext.Features.Get<IExceptionHandlerPathFeature>();
var exception = exceptionHandlerPathFeature?.Error;
if (exception is InvalidOperationException opEx && opEx.Message.Contains("schemes are: "))
{
// Messages coming from aspnetcore with a message
// similar to "The registered sign-in schemes are: {schemes}."
// will expose other Org IDs and sign-in schemes enabled on
// the server. These errors should be truncated to just the
// scheme impacted (always the first sentence)
var cleanupPoint = opEx.Message.IndexOf(". ") + 1;
var exMessage = opEx.Message.Substring(0, cleanupPoint);
exception = new InvalidOperationException(exMessage, opEx);
}
vm.Exception = exception;
}
return View("Error", vm);
}
} }
} }

View File

@ -1,21 +1,20 @@
using Bit.Core.Utilities; using Bit.Core.Utilities;
using Microsoft.AspNetCore.Mvc; using Microsoft.AspNetCore.Mvc;
namespace Bit.Sso.Controllers namespace Bit.Sso.Controllers;
{
public class InfoController : Controller
{
[HttpGet("~/alive")]
[HttpGet("~/now")]
public DateTime GetAlive()
{
return DateTime.UtcNow;
}
[HttpGet("~/version")] public class InfoController : Controller
public JsonResult GetVersion() {
{ [HttpGet("~/alive")]
return Json(CoreHelpers.GetVersion()); [HttpGet("~/now")]
} public DateTime GetAlive()
{
return DateTime.UtcNow;
}
[HttpGet("~/version")]
public JsonResult GetVersion()
{
return Json(CoreHelpers.GetVersion());
} }
} }

View File

@ -5,66 +5,65 @@ using Microsoft.AspNetCore.Mvc;
using Sustainsys.Saml2.AspNetCore2; using Sustainsys.Saml2.AspNetCore2;
using Sustainsys.Saml2.WebSso; using Sustainsys.Saml2.WebSso;
namespace Bit.Sso.Controllers namespace Bit.Sso.Controllers;
public class MetadataController : Controller
{ {
public class MetadataController : Controller private readonly IAuthenticationSchemeProvider _schemeProvider;
public MetadataController(
IAuthenticationSchemeProvider schemeProvider)
{ {
private readonly IAuthenticationSchemeProvider _schemeProvider; _schemeProvider = schemeProvider;
}
public MetadataController( [HttpGet("saml2/{scheme}")]
IAuthenticationSchemeProvider schemeProvider) public async Task<IActionResult> ViewAsync(string scheme)
{
if (string.IsNullOrWhiteSpace(scheme))
{ {
_schemeProvider = schemeProvider; return NotFound();
} }
[HttpGet("saml2/{scheme}")] var authScheme = await _schemeProvider.GetSchemeAsync(scheme);
public async Task<IActionResult> ViewAsync(string scheme) if (authScheme == null ||
!(authScheme is DynamicAuthenticationScheme dynamicAuthScheme) ||
dynamicAuthScheme?.SsoType != SsoType.Saml2)
{ {
if (string.IsNullOrWhiteSpace(scheme)) return NotFound();
{
return NotFound();
}
var authScheme = await _schemeProvider.GetSchemeAsync(scheme);
if (authScheme == null ||
!(authScheme is DynamicAuthenticationScheme dynamicAuthScheme) ||
dynamicAuthScheme?.SsoType != SsoType.Saml2)
{
return NotFound();
}
if (!(dynamicAuthScheme.Options is Saml2Options options))
{
return NotFound();
}
var uri = new Uri(
Request.Scheme
+ "://"
+ Request.Host
+ Request.Path
+ Request.QueryString);
var pathBase = Request.PathBase.Value;
pathBase = string.IsNullOrEmpty(pathBase) ? "/" : pathBase;
var requestdata = new HttpRequestData(
Request.Method,
uri,
pathBase,
null,
Request.Cookies,
(data) => data);
var metadataResult = CommandFactory
.GetCommand(CommandFactory.MetadataCommand)
.Run(requestdata, options);
//Response.Headers.Add("Content-Disposition", $"filename= bitwarden-saml2-meta-{scheme}.xml");
return new ContentResult
{
Content = metadataResult.Content,
ContentType = "text/xml",
};
} }
if (!(dynamicAuthScheme.Options is Saml2Options options))
{
return NotFound();
}
var uri = new Uri(
Request.Scheme
+ "://"
+ Request.Host
+ Request.Path
+ Request.QueryString);
var pathBase = Request.PathBase.Value;
pathBase = string.IsNullOrEmpty(pathBase) ? "/" : pathBase;
var requestdata = new HttpRequestData(
Request.Method,
uri,
pathBase,
null,
Request.Cookies,
(data) => data);
var metadataResult = CommandFactory
.GetCommand(CommandFactory.MetadataCommand)
.Run(requestdata, options);
//Response.Headers.Add("Content-Disposition", $"filename= bitwarden-saml2-meta-{scheme}.xml");
return new ContentResult
{
Content = metadataResult.Content,
ContentType = "text/xml",
};
} }
} }

View File

@ -1,27 +1,26 @@
using IdentityServer4.Models; using IdentityServer4.Models;
namespace Bit.Sso.Models namespace Bit.Sso.Models;
public class ErrorViewModel
{ {
public class ErrorViewModel private string _requestId;
public ErrorMessage Error { get; set; }
public Exception Exception { get; set; }
public string Message => Error?.Error;
public string Description => Error?.ErrorDescription ?? Exception?.Message;
public string RedirectUri => Error?.RedirectUri;
public string RequestId
{ {
private string _requestId; get
public ErrorMessage Error { get; set; }
public Exception Exception { get; set; }
public string Message => Error?.Error;
public string Description => Error?.ErrorDescription ?? Exception?.Message;
public string RedirectUri => Error?.RedirectUri;
public string RequestId
{ {
get return Error?.RequestId ?? _requestId;
{ }
return Error?.RequestId ?? _requestId; set
} {
set _requestId = value;
{
_requestId = value;
}
} }
} }
} }

View File

@ -1,7 +1,6 @@
namespace Bit.Sso.Models namespace Bit.Sso.Models;
public class RedirectViewModel
{ {
public class RedirectViewModel public string RedirectUrl { get; set; }
{
public string RedirectUrl { get; set; }
}
} }

View File

@ -1,9 +1,8 @@
using System.Security.Cryptography.X509Certificates; using System.Security.Cryptography.X509Certificates;
namespace Bit.Sso.Models namespace Bit.Sso.Models;
public class SamlEnvironment
{ {
public class SamlEnvironment public X509Certificate2 SpSigningCertificate { get; set; }
{
public X509Certificate2 SpSigningCertificate { get; set; }
}
} }

View File

@ -1,13 +1,12 @@
using Microsoft.AspNetCore.Mvc; using Microsoft.AspNetCore.Mvc;
namespace Bit.Sso.Models namespace Bit.Sso.Models;
public class SsoPreValidateResponseModel : JsonResult
{ {
public class SsoPreValidateResponseModel : JsonResult public SsoPreValidateResponseModel(string token) : base(new
{ {
public SsoPreValidateResponseModel(string token) : base(new token
{ })
token { }
})
{ }
}
} }

View File

@ -2,33 +2,32 @@
using Serilog; using Serilog;
using Serilog.Events; using Serilog.Events;
namespace Bit.Sso namespace Bit.Sso;
public class Program
{ {
public class Program public static void Main(string[] args)
{ {
public static void Main(string[] args) Host
{ .CreateDefaultBuilder(args)
Host .ConfigureCustomAppConfiguration(args)
.CreateDefaultBuilder(args) .ConfigureWebHostDefaults(webBuilder =>
.ConfigureCustomAppConfiguration(args) {
.ConfigureWebHostDefaults(webBuilder => webBuilder.UseStartup<Startup>();
webBuilder.ConfigureLogging((hostingContext, logging) =>
logging.AddSerilog(hostingContext, e =>
{ {
webBuilder.UseStartup<Startup>(); var context = e.Properties["SourceContext"].ToString();
webBuilder.ConfigureLogging((hostingContext, logging) => if (e.Properties.ContainsKey("RequestPath") &&
logging.AddSerilog(hostingContext, e => !string.IsNullOrWhiteSpace(e.Properties["RequestPath"]?.ToString()) &&
(context.Contains(".Server.Kestrel") || context.Contains(".Core.IISHttpServer")))
{ {
var context = e.Properties["SourceContext"].ToString(); return false;
if (e.Properties.ContainsKey("RequestPath") && }
!string.IsNullOrWhiteSpace(e.Properties["RequestPath"]?.ToString()) && return e.Level >= LogEventLevel.Error;
(context.Contains(".Server.Kestrel") || context.Contains(".Core.IISHttpServer"))) }));
{ })
return false; .Build()
} .Run();
return e.Level >= LogEventLevel.Error;
}));
})
.Build()
.Run();
}
} }
} }

View File

@ -8,148 +8,147 @@ using IdentityServer4.Extensions;
using Microsoft.IdentityModel.Logging; using Microsoft.IdentityModel.Logging;
using Stripe; using Stripe;
namespace Bit.Sso namespace Bit.Sso;
public class Startup
{ {
public class Startup public Startup(IWebHostEnvironment env, IConfiguration configuration)
{ {
public Startup(IWebHostEnvironment env, IConfiguration configuration) Configuration = configuration;
Environment = env;
}
public IConfiguration Configuration { get; }
public IWebHostEnvironment Environment { get; set; }
public void ConfigureServices(IServiceCollection services)
{
// Options
services.AddOptions();
// Settings
var globalSettings = services.AddGlobalSettingsServices(Configuration, Environment);
// Stripe Billing
StripeConfiguration.ApiKey = globalSettings.Stripe.ApiKey;
StripeConfiguration.MaxNetworkRetries = globalSettings.Stripe.MaxNetworkRetries;
// Data Protection
services.AddCustomDataProtectionServices(Environment, globalSettings);
// Repositories
services.AddSqlServerRepositories(globalSettings);
// Context
services.AddScoped<ICurrentContext, CurrentContext>();
// Caching
services.AddMemoryCache();
services.AddDistributedCache(globalSettings);
// Mvc
services.AddControllersWithViews();
// Cookies
if (Environment.IsDevelopment())
{ {
Configuration = configuration; services.Configure<CookiePolicyOptions>(options =>
Environment = env;
}
public IConfiguration Configuration { get; }
public IWebHostEnvironment Environment { get; set; }
public void ConfigureServices(IServiceCollection services)
{
// Options
services.AddOptions();
// Settings
var globalSettings = services.AddGlobalSettingsServices(Configuration, Environment);
// Stripe Billing
StripeConfiguration.ApiKey = globalSettings.Stripe.ApiKey;
StripeConfiguration.MaxNetworkRetries = globalSettings.Stripe.MaxNetworkRetries;
// Data Protection
services.AddCustomDataProtectionServices(Environment, globalSettings);
// Repositories
services.AddSqlServerRepositories(globalSettings);
// Context
services.AddScoped<ICurrentContext, CurrentContext>();
// Caching
services.AddMemoryCache();
services.AddDistributedCache(globalSettings);
// Mvc
services.AddControllersWithViews();
// Cookies
if (Environment.IsDevelopment())
{ {
services.Configure<CookiePolicyOptions>(options => options.MinimumSameSitePolicy = Microsoft.AspNetCore.Http.SameSiteMode.Unspecified;
options.OnAppendCookie = ctx =>
{ {
options.MinimumSameSitePolicy = Microsoft.AspNetCore.Http.SameSiteMode.Unspecified; ctx.CookieOptions.SameSite = Microsoft.AspNetCore.Http.SameSiteMode.Unspecified;
options.OnAppendCookie = ctx => };
{
ctx.CookieOptions.SameSite = Microsoft.AspNetCore.Http.SameSiteMode.Unspecified;
};
});
}
// Authentication
services.AddDistributedIdentityServices(globalSettings);
services.AddAuthentication()
.AddCookie(AuthenticationSchemes.BitwardenExternalCookieAuthenticationScheme);
services.AddSsoServices(globalSettings);
// IdentityServer
services.AddSsoIdentityServerServices(Environment, globalSettings);
// Identity
services.AddCustomIdentityServices(globalSettings);
// Services
services.AddBaseServices(globalSettings);
services.AddDefaultServices(globalSettings);
services.AddCoreLocalizationServices();
}
public void Configure(
IApplicationBuilder app,
IWebHostEnvironment env,
IHostApplicationLifetime appLifetime,
GlobalSettings globalSettings,
ILogger<Startup> logger)
{
if (env.IsDevelopment() || globalSettings.SelfHosted)
{
IdentityModelEventSource.ShowPII = true;
}
app.UseSerilog(env, appLifetime, globalSettings);
// Add general security headers
app.UseMiddleware<SecurityHeadersMiddleware>();
if (!env.IsDevelopment())
{
var uri = new Uri(globalSettings.BaseServiceUri.Sso);
app.Use(async (ctx, next) =>
{
ctx.SetIdentityServerOrigin($"{uri.Scheme}://{uri.Host}");
await next();
});
}
if (globalSettings.SelfHosted)
{
app.UsePathBase("/sso");
app.UseForwardedHeaders(globalSettings);
}
if (env.IsDevelopment())
{
app.UseDeveloperExceptionPage();
app.UseCookiePolicy();
}
else
{
app.UseExceptionHandler("/Error");
}
app.UseCoreLocalization();
// Add static files to the request pipeline.
app.UseStaticFiles();
// Add routing
app.UseRouting();
// Add Cors
app.UseCors(policy => policy.SetIsOriginAllowed(o => CoreHelpers.IsCorsOriginAllowed(o, globalSettings))
.AllowAnyMethod().AllowAnyHeader().AllowCredentials());
// Add current context
app.UseMiddleware<CurrentContextMiddleware>();
// Add IdentityServer to the request pipeline.
app.UseIdentityServer(new IdentityServerMiddlewareOptions
{
AuthenticationMiddleware = app => app.UseMiddleware<SsoAuthenticationMiddleware>()
}); });
// Add Mvc stuff
app.UseAuthorization();
app.UseEndpoints(endpoints => endpoints.MapDefaultControllerRoute());
// Log startup
logger.LogInformation(Constants.BypassFiltersEventId, globalSettings.ProjectName + " started.");
} }
// Authentication
services.AddDistributedIdentityServices(globalSettings);
services.AddAuthentication()
.AddCookie(AuthenticationSchemes.BitwardenExternalCookieAuthenticationScheme);
services.AddSsoServices(globalSettings);
// IdentityServer
services.AddSsoIdentityServerServices(Environment, globalSettings);
// Identity
services.AddCustomIdentityServices(globalSettings);
// Services
services.AddBaseServices(globalSettings);
services.AddDefaultServices(globalSettings);
services.AddCoreLocalizationServices();
}
public void Configure(
IApplicationBuilder app,
IWebHostEnvironment env,
IHostApplicationLifetime appLifetime,
GlobalSettings globalSettings,
ILogger<Startup> logger)
{
if (env.IsDevelopment() || globalSettings.SelfHosted)
{
IdentityModelEventSource.ShowPII = true;
}
app.UseSerilog(env, appLifetime, globalSettings);
// Add general security headers
app.UseMiddleware<SecurityHeadersMiddleware>();
if (!env.IsDevelopment())
{
var uri = new Uri(globalSettings.BaseServiceUri.Sso);
app.Use(async (ctx, next) =>
{
ctx.SetIdentityServerOrigin($"{uri.Scheme}://{uri.Host}");
await next();
});
}
if (globalSettings.SelfHosted)
{
app.UsePathBase("/sso");
app.UseForwardedHeaders(globalSettings);
}
if (env.IsDevelopment())
{
app.UseDeveloperExceptionPage();
app.UseCookiePolicy();
}
else
{
app.UseExceptionHandler("/Error");
}
app.UseCoreLocalization();
// Add static files to the request pipeline.
app.UseStaticFiles();
// Add routing
app.UseRouting();
// Add Cors
app.UseCors(policy => policy.SetIsOriginAllowed(o => CoreHelpers.IsCorsOriginAllowed(o, globalSettings))
.AllowAnyMethod().AllowAnyHeader().AllowCredentials());
// Add current context
app.UseMiddleware<CurrentContextMiddleware>();
// Add IdentityServer to the request pipeline.
app.UseIdentityServer(new IdentityServerMiddlewareOptions
{
AuthenticationMiddleware = app => app.UseMiddleware<SsoAuthenticationMiddleware>()
});
// Add Mvc stuff
app.UseAuthorization();
app.UseEndpoints(endpoints => endpoints.MapDefaultControllerRoute());
// Log startup
logger.LogInformation(Constants.BypassFiltersEventId, globalSettings.ProjectName + " started.");
} }
} }

View File

@ -1,46 +1,45 @@
using System.Security.Claims; using System.Security.Claims;
using System.Text.RegularExpressions; using System.Text.RegularExpressions;
namespace Bit.Sso.Utilities namespace Bit.Sso.Utilities;
public static class ClaimsExtensions
{ {
public static class ClaimsExtensions private static readonly Regex _normalizeTextRegEx =
new Regex(@"[^a-zA-Z]", RegexOptions.CultureInvariant | RegexOptions.Singleline);
public static string GetFirstMatch(this IEnumerable<Claim> claims, params string[] possibleNames)
{ {
private static readonly Regex _normalizeTextRegEx = var normalizedClaims = claims.Select(c => (Normalize(c.Type), c.Value)).ToList();
new Regex(@"[^a-zA-Z]", RegexOptions.CultureInvariant | RegexOptions.Singleline);
public static string GetFirstMatch(this IEnumerable<Claim> claims, params string[] possibleNames) // Order of prescendence is by passed in names
foreach (var name in possibleNames.Select(Normalize))
{ {
var normalizedClaims = claims.Select(c => (Normalize(c.Type), c.Value)).ToList(); // Second by order of claims (find claim by name)
foreach (var claim in normalizedClaims)
// Order of prescendence is by passed in names
foreach (var name in possibleNames.Select(Normalize))
{ {
// Second by order of claims (find claim by name) if (Equals(claim.Item1, name))
foreach (var claim in normalizedClaims)
{ {
if (Equals(claim.Item1, name)) return claim.Value;
{
return claim.Value;
}
} }
} }
return null;
} }
return null;
}
private static bool Equals(string text, string compare) private static bool Equals(string text, string compare)
{ {
return text == compare || return text == compare ||
(string.IsNullOrWhiteSpace(text) && string.IsNullOrWhiteSpace(compare)) || (string.IsNullOrWhiteSpace(text) && string.IsNullOrWhiteSpace(compare)) ||
string.Equals(Normalize(text), compare, StringComparison.InvariantCultureIgnoreCase); string.Equals(Normalize(text), compare, StringComparison.InvariantCultureIgnoreCase);
} }
private static string Normalize(string text) private static string Normalize(string text)
{
if (string.IsNullOrWhiteSpace(text))
{ {
if (string.IsNullOrWhiteSpace(text)) return text;
{
return text;
}
return _normalizeTextRegEx.Replace(text, string.Empty);
} }
return _normalizeTextRegEx.Replace(text, string.Empty);
} }
} }

View File

@ -5,32 +5,31 @@ using IdentityServer4.Services;
using IdentityServer4.Stores; using IdentityServer4.Stores;
using IdentityServer4.Validation; using IdentityServer4.Validation;
namespace Bit.Sso.Utilities namespace Bit.Sso.Utilities;
public class DiscoveryResponseGenerator : IdentityServer4.ResponseHandling.DiscoveryResponseGenerator
{ {
public class DiscoveryResponseGenerator : IdentityServer4.ResponseHandling.DiscoveryResponseGenerator private readonly GlobalSettings _globalSettings;
public DiscoveryResponseGenerator(
IdentityServerOptions options,
IResourceStore resourceStore,
IKeyMaterialService keys,
ExtensionGrantValidator extensionGrants,
ISecretsListParser secretParsers,
IResourceOwnerPasswordValidator resourceOwnerValidator,
ILogger<DiscoveryResponseGenerator> logger,
GlobalSettings globalSettings)
: base(options, resourceStore, keys, extensionGrants, secretParsers, resourceOwnerValidator, logger)
{ {
private readonly GlobalSettings _globalSettings; _globalSettings = globalSettings;
}
public DiscoveryResponseGenerator( public override async Task<Dictionary<string, object>> CreateDiscoveryDocumentAsync(
IdentityServerOptions options, string baseUrl, string issuerUri)
IResourceStore resourceStore, {
IKeyMaterialService keys, var dict = await base.CreateDiscoveryDocumentAsync(baseUrl, issuerUri);
ExtensionGrantValidator extensionGrants, return CoreHelpers.AdjustIdentityServerConfig(dict, _globalSettings.BaseServiceUri.Sso,
ISecretsListParser secretParsers, _globalSettings.BaseServiceUri.InternalSso);
IResourceOwnerPasswordValidator resourceOwnerValidator,
ILogger<DiscoveryResponseGenerator> logger,
GlobalSettings globalSettings)
: base(options, resourceStore, keys, extensionGrants, secretParsers, resourceOwnerValidator, logger)
{
_globalSettings = globalSettings;
}
public override async Task<Dictionary<string, object>> CreateDiscoveryDocumentAsync(
string baseUrl, string issuerUri)
{
var dict = await base.CreateDiscoveryDocumentAsync(baseUrl, issuerUri);
return CoreHelpers.AdjustIdentityServerConfig(dict, _globalSettings.BaseServiceUri.Sso,
_globalSettings.BaseServiceUri.InternalSso);
}
} }
} }

View File

@ -3,88 +3,87 @@ using Microsoft.AspNetCore.Authentication;
using Microsoft.AspNetCore.Authentication.OpenIdConnect; using Microsoft.AspNetCore.Authentication.OpenIdConnect;
using Sustainsys.Saml2.AspNetCore2; using Sustainsys.Saml2.AspNetCore2;
namespace Bit.Sso.Utilities namespace Bit.Sso.Utilities;
public class DynamicAuthenticationScheme : AuthenticationScheme, IDynamicAuthenticationScheme
{ {
public class DynamicAuthenticationScheme : AuthenticationScheme, IDynamicAuthenticationScheme public DynamicAuthenticationScheme(string name, string displayName, Type handlerType,
AuthenticationSchemeOptions options)
: base(name, displayName, handlerType)
{ {
public DynamicAuthenticationScheme(string name, string displayName, Type handlerType, Options = options;
AuthenticationSchemeOptions options) }
: base(name, displayName, handlerType) public DynamicAuthenticationScheme(string name, string displayName, Type handlerType,
{ AuthenticationSchemeOptions options, SsoType ssoType)
Options = options; : this(name, displayName, handlerType, options)
} {
public DynamicAuthenticationScheme(string name, string displayName, Type handlerType, SsoType = ssoType;
AuthenticationSchemeOptions options, SsoType ssoType) }
: this(name, displayName, handlerType, options)
{
SsoType = ssoType;
}
public AuthenticationSchemeOptions Options { get; set; } public AuthenticationSchemeOptions Options { get; set; }
public SsoType SsoType { get; set; } public SsoType SsoType { get; set; }
public async Task Validate() public async Task Validate()
{
switch (SsoType)
{ {
switch (SsoType) case SsoType.OpenIdConnect:
{ await ValidateOpenIdConnectAsync();
case SsoType.OpenIdConnect: break;
await ValidateOpenIdConnectAsync(); case SsoType.Saml2:
break; ValidateSaml();
case SsoType.Saml2: break;
ValidateSaml(); default:
break; break;
default:
break;
}
} }
}
private void ValidateSaml() private void ValidateSaml()
{
if (SsoType != SsoType.Saml2)
{ {
if (SsoType != SsoType.Saml2) return;
{
return;
}
if (!(Options is Saml2Options samlOptions))
{
throw new Exception("InvalidAuthenticationOptionsForSaml2SchemeError");
}
samlOptions.Validate(Name);
} }
if (!(Options is Saml2Options samlOptions))
private async Task ValidateOpenIdConnectAsync()
{ {
if (SsoType != SsoType.OpenIdConnect) throw new Exception("InvalidAuthenticationOptionsForSaml2SchemeError");
}
samlOptions.Validate(Name);
}
private async Task ValidateOpenIdConnectAsync()
{
if (SsoType != SsoType.OpenIdConnect)
{
return;
}
if (!(Options is OpenIdConnectOptions oidcOptions))
{
throw new Exception("InvalidAuthenticationOptionsForOidcSchemeError");
}
oidcOptions.Validate();
if (oidcOptions.Configuration == null)
{
if (oidcOptions.ConfigurationManager == null)
{ {
return; throw new Exception("PostConfigurationNotExecutedError");
}
if (!(Options is OpenIdConnectOptions oidcOptions))
{
throw new Exception("InvalidAuthenticationOptionsForOidcSchemeError");
}
oidcOptions.Validate();
if (oidcOptions.Configuration == null)
{
if (oidcOptions.ConfigurationManager == null)
{
throw new Exception("PostConfigurationNotExecutedError");
}
if (oidcOptions.Configuration == null)
{
try
{
oidcOptions.Configuration = await oidcOptions.ConfigurationManager
.GetConfigurationAsync(CancellationToken.None);
}
catch (Exception ex)
{
throw new Exception("ReadingOpenIdConnectMetadataFailedError", ex);
}
}
} }
if (oidcOptions.Configuration == null) if (oidcOptions.Configuration == null)
{ {
throw new Exception("NoOpenIdConnectMetadataError"); try
{
oidcOptions.Configuration = await oidcOptions.ConfigurationManager
.GetConfigurationAsync(CancellationToken.None);
}
catch (Exception ex)
{
throw new Exception("ReadingOpenIdConnectMetadataFailedError", ex);
}
} }
} }
if (oidcOptions.Configuration == null)
{
throw new Exception("NoOpenIdConnectMetadataError");
}
} }
} }

View File

@ -18,441 +18,440 @@ using Sustainsys.Saml2.AspNetCore2;
using Sustainsys.Saml2.Configuration; using Sustainsys.Saml2.Configuration;
using Sustainsys.Saml2.Saml2P; using Sustainsys.Saml2.Saml2P;
namespace Bit.Core.Business.Sso namespace Bit.Core.Business.Sso;
public class DynamicAuthenticationSchemeProvider : AuthenticationSchemeProvider
{ {
public class DynamicAuthenticationSchemeProvider : AuthenticationSchemeProvider private readonly IPostConfigureOptions<OpenIdConnectOptions> _oidcPostConfigureOptions;
private readonly IExtendedOptionsMonitorCache<OpenIdConnectOptions> _extendedOidcOptionsMonitorCache;
private readonly IPostConfigureOptions<Saml2Options> _saml2PostConfigureOptions;
private readonly IExtendedOptionsMonitorCache<Saml2Options> _extendedSaml2OptionsMonitorCache;
private readonly ISsoConfigRepository _ssoConfigRepository;
private readonly ILogger _logger;
private readonly GlobalSettings _globalSettings;
private readonly SamlEnvironment _samlEnvironment;
private readonly TimeSpan _schemeCacheLifetime;
private readonly Dictionary<string, DynamicAuthenticationScheme> _cachedSchemes;
private readonly Dictionary<string, DynamicAuthenticationScheme> _cachedHandlerSchemes;
private readonly SemaphoreSlim _semaphore;
private readonly IHttpContextAccessor _httpContextAccessor;
private DateTime? _lastSchemeLoad;
private IEnumerable<DynamicAuthenticationScheme> _schemesCopy = Array.Empty<DynamicAuthenticationScheme>();
private IEnumerable<DynamicAuthenticationScheme> _handlerSchemesCopy = Array.Empty<DynamicAuthenticationScheme>();
public DynamicAuthenticationSchemeProvider(
IOptions<AuthenticationOptions> options,
IPostConfigureOptions<OpenIdConnectOptions> oidcPostConfigureOptions,
IOptionsMonitorCache<OpenIdConnectOptions> oidcOptionsMonitorCache,
IPostConfigureOptions<Saml2Options> saml2PostConfigureOptions,
IOptionsMonitorCache<Saml2Options> saml2OptionsMonitorCache,
ISsoConfigRepository ssoConfigRepository,
ILogger<DynamicAuthenticationSchemeProvider> logger,
GlobalSettings globalSettings,
SamlEnvironment samlEnvironment,
IHttpContextAccessor httpContextAccessor)
: base(options)
{ {
private readonly IPostConfigureOptions<OpenIdConnectOptions> _oidcPostConfigureOptions; _oidcPostConfigureOptions = oidcPostConfigureOptions;
private readonly IExtendedOptionsMonitorCache<OpenIdConnectOptions> _extendedOidcOptionsMonitorCache; _extendedOidcOptionsMonitorCache = oidcOptionsMonitorCache as
private readonly IPostConfigureOptions<Saml2Options> _saml2PostConfigureOptions; IExtendedOptionsMonitorCache<OpenIdConnectOptions>;
private readonly IExtendedOptionsMonitorCache<Saml2Options> _extendedSaml2OptionsMonitorCache; if (_extendedOidcOptionsMonitorCache == null)
private readonly ISsoConfigRepository _ssoConfigRepository;
private readonly ILogger _logger;
private readonly GlobalSettings _globalSettings;
private readonly SamlEnvironment _samlEnvironment;
private readonly TimeSpan _schemeCacheLifetime;
private readonly Dictionary<string, DynamicAuthenticationScheme> _cachedSchemes;
private readonly Dictionary<string, DynamicAuthenticationScheme> _cachedHandlerSchemes;
private readonly SemaphoreSlim _semaphore;
private readonly IHttpContextAccessor _httpContextAccessor;
private DateTime? _lastSchemeLoad;
private IEnumerable<DynamicAuthenticationScheme> _schemesCopy = Array.Empty<DynamicAuthenticationScheme>();
private IEnumerable<DynamicAuthenticationScheme> _handlerSchemesCopy = Array.Empty<DynamicAuthenticationScheme>();
public DynamicAuthenticationSchemeProvider(
IOptions<AuthenticationOptions> options,
IPostConfigureOptions<OpenIdConnectOptions> oidcPostConfigureOptions,
IOptionsMonitorCache<OpenIdConnectOptions> oidcOptionsMonitorCache,
IPostConfigureOptions<Saml2Options> saml2PostConfigureOptions,
IOptionsMonitorCache<Saml2Options> saml2OptionsMonitorCache,
ISsoConfigRepository ssoConfigRepository,
ILogger<DynamicAuthenticationSchemeProvider> logger,
GlobalSettings globalSettings,
SamlEnvironment samlEnvironment,
IHttpContextAccessor httpContextAccessor)
: base(options)
{ {
_oidcPostConfigureOptions = oidcPostConfigureOptions; throw new ArgumentNullException("_extendedOidcOptionsMonitorCache could not be resolved.");
_extendedOidcOptionsMonitorCache = oidcOptionsMonitorCache as
IExtendedOptionsMonitorCache<OpenIdConnectOptions>;
if (_extendedOidcOptionsMonitorCache == null)
{
throw new ArgumentNullException("_extendedOidcOptionsMonitorCache could not be resolved.");
}
_saml2PostConfigureOptions = saml2PostConfigureOptions;
_extendedSaml2OptionsMonitorCache = saml2OptionsMonitorCache as
IExtendedOptionsMonitorCache<Saml2Options>;
if (_extendedSaml2OptionsMonitorCache == null)
{
throw new ArgumentNullException("_extendedSaml2OptionsMonitorCache could not be resolved.");
}
_ssoConfigRepository = ssoConfigRepository;
_logger = logger;
_globalSettings = globalSettings;
_schemeCacheLifetime = TimeSpan.FromSeconds(_globalSettings.Sso?.CacheLifetimeInSeconds ?? 30);
_samlEnvironment = samlEnvironment;
_cachedSchemes = new Dictionary<string, DynamicAuthenticationScheme>();
_cachedHandlerSchemes = new Dictionary<string, DynamicAuthenticationScheme>();
_semaphore = new SemaphoreSlim(1);
_httpContextAccessor = httpContextAccessor ?? throw new ArgumentNullException(nameof(httpContextAccessor));
} }
private bool CacheIsValid _saml2PostConfigureOptions = saml2PostConfigureOptions;
_extendedSaml2OptionsMonitorCache = saml2OptionsMonitorCache as
IExtendedOptionsMonitorCache<Saml2Options>;
if (_extendedSaml2OptionsMonitorCache == null)
{ {
get => _lastSchemeLoad.HasValue throw new ArgumentNullException("_extendedSaml2OptionsMonitorCache could not be resolved.");
&& _lastSchemeLoad.Value.Add(_schemeCacheLifetime) >= DateTime.UtcNow;
} }
public override async Task<AuthenticationScheme> GetSchemeAsync(string name) _ssoConfigRepository = ssoConfigRepository;
_logger = logger;
_globalSettings = globalSettings;
_schemeCacheLifetime = TimeSpan.FromSeconds(_globalSettings.Sso?.CacheLifetimeInSeconds ?? 30);
_samlEnvironment = samlEnvironment;
_cachedSchemes = new Dictionary<string, DynamicAuthenticationScheme>();
_cachedHandlerSchemes = new Dictionary<string, DynamicAuthenticationScheme>();
_semaphore = new SemaphoreSlim(1);
_httpContextAccessor = httpContextAccessor ?? throw new ArgumentNullException(nameof(httpContextAccessor));
}
private bool CacheIsValid
{
get => _lastSchemeLoad.HasValue
&& _lastSchemeLoad.Value.Add(_schemeCacheLifetime) >= DateTime.UtcNow;
}
public override async Task<AuthenticationScheme> GetSchemeAsync(string name)
{
var scheme = await base.GetSchemeAsync(name);
if (scheme != null)
{ {
var scheme = await base.GetSchemeAsync(name);
if (scheme != null)
{
return scheme;
}
try
{
var dynamicScheme = await GetDynamicSchemeAsync(name);
return dynamicScheme;
}
catch (Exception ex)
{
_logger.LogError(ex, "Unable to load a dynamic authentication scheme for '{0}'", name);
}
return null;
}
public override async Task<IEnumerable<AuthenticationScheme>> GetAllSchemesAsync()
{
var existingSchemes = await base.GetAllSchemesAsync();
var schemes = new List<AuthenticationScheme>();
schemes.AddRange(existingSchemes);
await LoadAllDynamicSchemesIntoCacheAsync();
schemes.AddRange(_schemesCopy);
return schemes.ToArray();
}
public override async Task<IEnumerable<AuthenticationScheme>> GetRequestHandlerSchemesAsync()
{
var existingSchemes = await base.GetRequestHandlerSchemesAsync();
var schemes = new List<AuthenticationScheme>();
schemes.AddRange(existingSchemes);
await LoadAllDynamicSchemesIntoCacheAsync();
schemes.AddRange(_handlerSchemesCopy);
return schemes.ToArray();
}
private async Task LoadAllDynamicSchemesIntoCacheAsync()
{
if (CacheIsValid)
{
// Our cache hasn't expired or been invalidated, ignore request
return;
}
await _semaphore.WaitAsync();
try
{
if (CacheIsValid)
{
// Just in case (double-checked locking pattern)
return;
}
// Save time just in case the following operation takes longer
var now = DateTime.UtcNow;
var newSchemes = await _ssoConfigRepository.GetManyByRevisionNotBeforeDate(_lastSchemeLoad);
foreach (var config in newSchemes)
{
DynamicAuthenticationScheme scheme;
try
{
scheme = GetSchemeFromSsoConfig(config);
}
catch (Exception ex)
{
_logger.LogError(ex, "Error converting configuration to scheme for '{0}'", config.Id);
continue;
}
if (scheme == null)
{
continue;
}
SetSchemeInCache(scheme);
}
if (newSchemes.Any())
{
// Maintain "safe" copy for use in enumeration routines
_schemesCopy = _cachedSchemes.Values.ToArray();
_handlerSchemesCopy = _cachedHandlerSchemes.Values.ToArray();
}
_lastSchemeLoad = now;
}
finally
{
_semaphore.Release();
}
}
private DynamicAuthenticationScheme SetSchemeInCache(DynamicAuthenticationScheme scheme)
{
if (!PostConfigureDynamicScheme(scheme))
{
return null;
}
_cachedSchemes[scheme.Name] = scheme;
if (typeof(IAuthenticationRequestHandler).IsAssignableFrom(scheme.HandlerType))
{
_cachedHandlerSchemes[scheme.Name] = scheme;
}
return scheme; return scheme;
} }
private async Task<DynamicAuthenticationScheme> GetDynamicSchemeAsync(string name) try
{ {
if (_cachedSchemes.TryGetValue(name, out var cachedScheme)) var dynamicScheme = await GetDynamicSchemeAsync(name);
return dynamicScheme;
}
catch (Exception ex)
{
_logger.LogError(ex, "Unable to load a dynamic authentication scheme for '{0}'", name);
}
return null;
}
public override async Task<IEnumerable<AuthenticationScheme>> GetAllSchemesAsync()
{
var existingSchemes = await base.GetAllSchemesAsync();
var schemes = new List<AuthenticationScheme>();
schemes.AddRange(existingSchemes);
await LoadAllDynamicSchemesIntoCacheAsync();
schemes.AddRange(_schemesCopy);
return schemes.ToArray();
}
public override async Task<IEnumerable<AuthenticationScheme>> GetRequestHandlerSchemesAsync()
{
var existingSchemes = await base.GetRequestHandlerSchemesAsync();
var schemes = new List<AuthenticationScheme>();
schemes.AddRange(existingSchemes);
await LoadAllDynamicSchemesIntoCacheAsync();
schemes.AddRange(_handlerSchemesCopy);
return schemes.ToArray();
}
private async Task LoadAllDynamicSchemesIntoCacheAsync()
{
if (CacheIsValid)
{
// Our cache hasn't expired or been invalidated, ignore request
return;
}
await _semaphore.WaitAsync();
try
{
if (CacheIsValid)
{ {
return cachedScheme; // Just in case (double-checked locking pattern)
return;
} }
var scheme = await GetSchemeFromSsoConfigAsync(name); // Save time just in case the following operation takes longer
var now = DateTime.UtcNow;
var newSchemes = await _ssoConfigRepository.GetManyByRevisionNotBeforeDate(_lastSchemeLoad);
foreach (var config in newSchemes)
{
DynamicAuthenticationScheme scheme;
try
{
scheme = GetSchemeFromSsoConfig(config);
}
catch (Exception ex)
{
_logger.LogError(ex, "Error converting configuration to scheme for '{0}'", config.Id);
continue;
}
if (scheme == null)
{
continue;
}
SetSchemeInCache(scheme);
}
if (newSchemes.Any())
{
// Maintain "safe" copy for use in enumeration routines
_schemesCopy = _cachedSchemes.Values.ToArray();
_handlerSchemesCopy = _cachedHandlerSchemes.Values.ToArray();
}
_lastSchemeLoad = now;
}
finally
{
_semaphore.Release();
}
}
private DynamicAuthenticationScheme SetSchemeInCache(DynamicAuthenticationScheme scheme)
{
if (!PostConfigureDynamicScheme(scheme))
{
return null;
}
_cachedSchemes[scheme.Name] = scheme;
if (typeof(IAuthenticationRequestHandler).IsAssignableFrom(scheme.HandlerType))
{
_cachedHandlerSchemes[scheme.Name] = scheme;
}
return scheme;
}
private async Task<DynamicAuthenticationScheme> GetDynamicSchemeAsync(string name)
{
if (_cachedSchemes.TryGetValue(name, out var cachedScheme))
{
return cachedScheme;
}
var scheme = await GetSchemeFromSsoConfigAsync(name);
if (scheme == null)
{
return null;
}
await _semaphore.WaitAsync();
try
{
scheme = SetSchemeInCache(scheme);
if (scheme == null) if (scheme == null)
{ {
return null; return null;
} }
await _semaphore.WaitAsync(); if (typeof(IAuthenticationRequestHandler).IsAssignableFrom(scheme.HandlerType))
try
{ {
scheme = SetSchemeInCache(scheme); _handlerSchemesCopy = _cachedHandlerSchemes.Values.ToArray();
if (scheme == null) }
{ _schemesCopy = _cachedSchemes.Values.ToArray();
return null; }
} finally
{
// Note: _lastSchemeLoad is not set here, this is a one-off
// and should not impact loading further cache updates
_semaphore.Release();
}
return scheme;
}
if (typeof(IAuthenticationRequestHandler).IsAssignableFrom(scheme.HandlerType)) private bool PostConfigureDynamicScheme(DynamicAuthenticationScheme scheme)
{ {
_handlerSchemesCopy = _cachedHandlerSchemes.Values.ToArray(); try
} {
_schemesCopy = _cachedSchemes.Values.ToArray(); if (scheme.SsoType == SsoType.OpenIdConnect && scheme.Options is OpenIdConnectOptions oidcOptions)
}
finally
{ {
// Note: _lastSchemeLoad is not set here, this is a one-off _oidcPostConfigureOptions.PostConfigure(scheme.Name, oidcOptions);
// and should not impact loading further cache updates _extendedOidcOptionsMonitorCache.AddOrUpdate(scheme.Name, oidcOptions);
_semaphore.Release();
} }
return scheme; else if (scheme.SsoType == SsoType.Saml2 && scheme.Options is Saml2Options saml2Options)
{
_saml2PostConfigureOptions.PostConfigure(scheme.Name, saml2Options);
_extendedSaml2OptionsMonitorCache.AddOrUpdate(scheme.Name, saml2Options);
}
return true;
}
catch (Exception ex)
{
_logger.LogError(ex, "Error performing post configuration for '{0}' ({1})",
scheme.Name, scheme.DisplayName);
}
return false;
}
private DynamicAuthenticationScheme GetSchemeFromSsoConfig(SsoConfig config)
{
var data = config.GetData();
return data.ConfigType switch
{
SsoType.OpenIdConnect => GetOidcAuthenticationScheme(config.OrganizationId.ToString(), data),
SsoType.Saml2 => GetSaml2AuthenticationScheme(config.OrganizationId.ToString(), data),
_ => throw new Exception($"SSO Config Type, '{data.ConfigType}', not supported"),
};
}
private async Task<DynamicAuthenticationScheme> GetSchemeFromSsoConfigAsync(string name)
{
if (!Guid.TryParse(name, out var organizationId))
{
_logger.LogWarning("Could not determine organization id from name, '{0}'", name);
return null;
}
var ssoConfig = await _ssoConfigRepository.GetByOrganizationIdAsync(organizationId);
if (ssoConfig == null || !ssoConfig.Enabled)
{
_logger.LogWarning("Could not find SSO config or config was not enabled for '{0}'", name);
return null;
} }
private bool PostConfigureDynamicScheme(DynamicAuthenticationScheme scheme) return GetSchemeFromSsoConfig(ssoConfig);
}
private DynamicAuthenticationScheme GetOidcAuthenticationScheme(string name, SsoConfigurationData config)
{
var oidcOptions = new OpenIdConnectOptions
{ {
try Authority = config.Authority,
ClientId = config.ClientId,
ClientSecret = config.ClientSecret,
ResponseType = "code",
ResponseMode = "form_post",
SignInScheme = AuthenticationSchemes.BitwardenExternalCookieAuthenticationScheme,
SignOutScheme = IdentityServerConstants.SignoutScheme,
SaveTokens = false, // reduce overall request size
TokenValidationParameters = new TokenValidationParameters
{ {
if (scheme.SsoType == SsoType.OpenIdConnect && scheme.Options is OpenIdConnectOptions oidcOptions) NameClaimType = JwtClaimTypes.Name,
{ RoleClaimType = JwtClaimTypes.Role,
_oidcPostConfigureOptions.PostConfigure(scheme.Name, oidcOptions); },
_extendedOidcOptionsMonitorCache.AddOrUpdate(scheme.Name, oidcOptions); CallbackPath = SsoConfigurationData.BuildCallbackPath(),
} SignedOutCallbackPath = SsoConfigurationData.BuildSignedOutCallbackPath(),
else if (scheme.SsoType == SsoType.Saml2 && scheme.Options is Saml2Options saml2Options) MetadataAddress = config.MetadataAddress,
{ // Prevents URLs that go beyond 1024 characters which may break for some servers
_saml2PostConfigureOptions.PostConfigure(scheme.Name, saml2Options); AuthenticationMethod = config.RedirectBehavior,
_extendedSaml2OptionsMonitorCache.AddOrUpdate(scheme.Name, saml2Options); GetClaimsFromUserInfoEndpoint = config.GetClaimsFromUserInfoEndpoint,
} };
return true; oidcOptions.Scope
} .AddIfNotExists(OpenIdConnectScopes.OpenId)
catch (Exception ex) .AddIfNotExists(OpenIdConnectScopes.Email)
{ .AddIfNotExists(OpenIdConnectScopes.Profile);
_logger.LogError(ex, "Error performing post configuration for '{0}' ({1})", foreach (var scope in config.GetAdditionalScopes())
scheme.Name, scheme.DisplayName); {
} oidcOptions.Scope.AddIfNotExists(scope);
return false; }
if (!string.IsNullOrWhiteSpace(config.ExpectedReturnAcrValue))
{
oidcOptions.Scope.AddIfNotExists(OpenIdConnectScopes.Acr);
} }
private DynamicAuthenticationScheme GetSchemeFromSsoConfig(SsoConfig config) oidcOptions.StateDataFormat = new DistributedCacheStateDataFormatter(_httpContextAccessor, name);
// see: https://openid.net/specs/openid-connect-core-1_0.html#AuthRequest (acr_values)
if (!string.IsNullOrWhiteSpace(config.AcrValues))
{ {
var data = config.GetData(); oidcOptions.Events ??= new OpenIdConnectEvents();
return data.ConfigType switch oidcOptions.Events.OnRedirectToIdentityProvider = ctx =>
{ {
SsoType.OpenIdConnect => GetOidcAuthenticationScheme(config.OrganizationId.ToString(), data), ctx.ProtocolMessage.AcrValues = config.AcrValues;
SsoType.Saml2 => GetSaml2AuthenticationScheme(config.OrganizationId.ToString(), data), return Task.CompletedTask;
_ => throw new Exception($"SSO Config Type, '{data.ConfigType}', not supported"),
}; };
} }
private async Task<DynamicAuthenticationScheme> GetSchemeFromSsoConfigAsync(string name) return new DynamicAuthenticationScheme(name, name, typeof(OpenIdConnectHandler),
{ oidcOptions, SsoType.OpenIdConnect);
if (!Guid.TryParse(name, out var organizationId)) }
{
_logger.LogWarning("Could not determine organization id from name, '{0}'", name);
return null;
}
var ssoConfig = await _ssoConfigRepository.GetByOrganizationIdAsync(organizationId);
if (ssoConfig == null || !ssoConfig.Enabled)
{
_logger.LogWarning("Could not find SSO config or config was not enabled for '{0}'", name);
return null;
}
return GetSchemeFromSsoConfig(ssoConfig); private DynamicAuthenticationScheme GetSaml2AuthenticationScheme(string name, SsoConfigurationData config)
{
if (_samlEnvironment == null)
{
throw new Exception($"SSO SAML2 Service Provider profile is missing for {name}");
} }
private DynamicAuthenticationScheme GetOidcAuthenticationScheme(string name, SsoConfigurationData config) var spEntityId = new Sustainsys.Saml2.Metadata.EntityId(
SsoConfigurationData.BuildSaml2ModulePath(_globalSettings.BaseServiceUri.Sso));
bool? allowCreate = null;
if (config.SpNameIdFormat != Saml2NameIdFormat.Transient)
{ {
var oidcOptions = new OpenIdConnectOptions allowCreate = true;
{ }
Authority = config.Authority, var spOptions = new SPOptions
ClientId = config.ClientId, {
ClientSecret = config.ClientSecret, EntityId = spEntityId,
ResponseType = "code", ModulePath = SsoConfigurationData.BuildSaml2ModulePath(null, name),
ResponseMode = "form_post", NameIdPolicy = new Saml2NameIdPolicy(allowCreate, GetNameIdFormat(config.SpNameIdFormat)),
SignInScheme = AuthenticationSchemes.BitwardenExternalCookieAuthenticationScheme, WantAssertionsSigned = config.SpWantAssertionsSigned,
SignOutScheme = IdentityServerConstants.SignoutScheme, AuthenticateRequestSigningBehavior = GetSigningBehavior(config.SpSigningBehavior),
SaveTokens = false, // reduce overall request size ValidateCertificates = config.SpValidateCertificates,
TokenValidationParameters = new TokenValidationParameters };
{ if (!string.IsNullOrWhiteSpace(config.SpMinIncomingSigningAlgorithm))
NameClaimType = JwtClaimTypes.Name, {
RoleClaimType = JwtClaimTypes.Role, spOptions.MinIncomingSigningAlgorithm = config.SpMinIncomingSigningAlgorithm;
}, }
CallbackPath = SsoConfigurationData.BuildCallbackPath(), if (!string.IsNullOrWhiteSpace(config.SpOutboundSigningAlgorithm))
SignedOutCallbackPath = SsoConfigurationData.BuildSignedOutCallbackPath(), {
MetadataAddress = config.MetadataAddress, spOptions.OutboundSigningAlgorithm = config.SpOutboundSigningAlgorithm;
// Prevents URLs that go beyond 1024 characters which may break for some servers }
AuthenticationMethod = config.RedirectBehavior, if (_samlEnvironment.SpSigningCertificate != null)
GetClaimsFromUserInfoEndpoint = config.GetClaimsFromUserInfoEndpoint, {
}; spOptions.ServiceCertificates.Add(_samlEnvironment.SpSigningCertificate);
oidcOptions.Scope
.AddIfNotExists(OpenIdConnectScopes.OpenId)
.AddIfNotExists(OpenIdConnectScopes.Email)
.AddIfNotExists(OpenIdConnectScopes.Profile);
foreach (var scope in config.GetAdditionalScopes())
{
oidcOptions.Scope.AddIfNotExists(scope);
}
if (!string.IsNullOrWhiteSpace(config.ExpectedReturnAcrValue))
{
oidcOptions.Scope.AddIfNotExists(OpenIdConnectScopes.Acr);
}
oidcOptions.StateDataFormat = new DistributedCacheStateDataFormatter(_httpContextAccessor, name);
// see: https://openid.net/specs/openid-connect-core-1_0.html#AuthRequest (acr_values)
if (!string.IsNullOrWhiteSpace(config.AcrValues))
{
oidcOptions.Events ??= new OpenIdConnectEvents();
oidcOptions.Events.OnRedirectToIdentityProvider = ctx =>
{
ctx.ProtocolMessage.AcrValues = config.AcrValues;
return Task.CompletedTask;
};
}
return new DynamicAuthenticationScheme(name, name, typeof(OpenIdConnectHandler),
oidcOptions, SsoType.OpenIdConnect);
} }
private DynamicAuthenticationScheme GetSaml2AuthenticationScheme(string name, SsoConfigurationData config) var idpEntityId = new Sustainsys.Saml2.Metadata.EntityId(config.IdpEntityId);
var idp = new Sustainsys.Saml2.IdentityProvider(idpEntityId, spOptions)
{ {
if (_samlEnvironment == null) Binding = GetBindingType(config.IdpBindingType),
{ AllowUnsolicitedAuthnResponse = config.IdpAllowUnsolicitedAuthnResponse,
throw new Exception($"SSO SAML2 Service Provider profile is missing for {name}"); DisableOutboundLogoutRequests = config.IdpDisableOutboundLogoutRequests,
} WantAuthnRequestsSigned = config.IdpWantAuthnRequestsSigned,
};
var spEntityId = new Sustainsys.Saml2.Metadata.EntityId( if (!string.IsNullOrWhiteSpace(config.IdpSingleSignOnServiceUrl))
SsoConfigurationData.BuildSaml2ModulePath(_globalSettings.BaseServiceUri.Sso));
bool? allowCreate = null;
if (config.SpNameIdFormat != Saml2NameIdFormat.Transient)
{
allowCreate = true;
}
var spOptions = new SPOptions
{
EntityId = spEntityId,
ModulePath = SsoConfigurationData.BuildSaml2ModulePath(null, name),
NameIdPolicy = new Saml2NameIdPolicy(allowCreate, GetNameIdFormat(config.SpNameIdFormat)),
WantAssertionsSigned = config.SpWantAssertionsSigned,
AuthenticateRequestSigningBehavior = GetSigningBehavior(config.SpSigningBehavior),
ValidateCertificates = config.SpValidateCertificates,
};
if (!string.IsNullOrWhiteSpace(config.SpMinIncomingSigningAlgorithm))
{
spOptions.MinIncomingSigningAlgorithm = config.SpMinIncomingSigningAlgorithm;
}
if (!string.IsNullOrWhiteSpace(config.SpOutboundSigningAlgorithm))
{
spOptions.OutboundSigningAlgorithm = config.SpOutboundSigningAlgorithm;
}
if (_samlEnvironment.SpSigningCertificate != null)
{
spOptions.ServiceCertificates.Add(_samlEnvironment.SpSigningCertificate);
}
var idpEntityId = new Sustainsys.Saml2.Metadata.EntityId(config.IdpEntityId);
var idp = new Sustainsys.Saml2.IdentityProvider(idpEntityId, spOptions)
{
Binding = GetBindingType(config.IdpBindingType),
AllowUnsolicitedAuthnResponse = config.IdpAllowUnsolicitedAuthnResponse,
DisableOutboundLogoutRequests = config.IdpDisableOutboundLogoutRequests,
WantAuthnRequestsSigned = config.IdpWantAuthnRequestsSigned,
};
if (!string.IsNullOrWhiteSpace(config.IdpSingleSignOnServiceUrl))
{
idp.SingleSignOnServiceUrl = new Uri(config.IdpSingleSignOnServiceUrl);
}
if (!string.IsNullOrWhiteSpace(config.IdpSingleLogoutServiceUrl))
{
idp.SingleLogoutServiceUrl = new Uri(config.IdpSingleLogoutServiceUrl);
}
if (!string.IsNullOrWhiteSpace(config.IdpOutboundSigningAlgorithm))
{
idp.OutboundSigningAlgorithm = config.IdpOutboundSigningAlgorithm;
}
if (!string.IsNullOrWhiteSpace(config.IdpX509PublicCert))
{
var cert = CoreHelpers.Base64UrlDecode(config.IdpX509PublicCert);
idp.SigningKeys.AddConfiguredKey(new X509Certificate2(cert));
}
idp.ArtifactResolutionServiceUrls.Clear();
// This must happen last since it calls Validate() internally.
idp.LoadMetadata = false;
var options = new Saml2Options
{
SPOptions = spOptions,
SignInScheme = AuthenticationSchemes.BitwardenExternalCookieAuthenticationScheme,
SignOutScheme = IdentityServerConstants.DefaultCookieAuthenticationScheme,
CookieManager = new IdentityServer.DistributedCacheCookieManager(),
};
options.IdentityProviders.Add(idp);
return new DynamicAuthenticationScheme(name, name, typeof(Saml2Handler), options, SsoType.Saml2);
}
private NameIdFormat GetNameIdFormat(Saml2NameIdFormat format)
{ {
return format switch idp.SingleSignOnServiceUrl = new Uri(config.IdpSingleSignOnServiceUrl);
{
Saml2NameIdFormat.Unspecified => NameIdFormat.Unspecified,
Saml2NameIdFormat.EmailAddress => NameIdFormat.EmailAddress,
Saml2NameIdFormat.X509SubjectName => NameIdFormat.X509SubjectName,
Saml2NameIdFormat.WindowsDomainQualifiedName => NameIdFormat.WindowsDomainQualifiedName,
Saml2NameIdFormat.KerberosPrincipalName => NameIdFormat.KerberosPrincipalName,
Saml2NameIdFormat.EntityIdentifier => NameIdFormat.EntityIdentifier,
Saml2NameIdFormat.Persistent => NameIdFormat.Persistent,
Saml2NameIdFormat.Transient => NameIdFormat.Transient,
_ => NameIdFormat.NotConfigured,
};
} }
if (!string.IsNullOrWhiteSpace(config.IdpSingleLogoutServiceUrl))
private SigningBehavior GetSigningBehavior(Saml2SigningBehavior behavior)
{ {
return behavior switch idp.SingleLogoutServiceUrl = new Uri(config.IdpSingleLogoutServiceUrl);
{
Saml2SigningBehavior.IfIdpWantAuthnRequestsSigned => SigningBehavior.IfIdpWantAuthnRequestsSigned,
Saml2SigningBehavior.Always => SigningBehavior.Always,
Saml2SigningBehavior.Never => SigningBehavior.Never,
_ => SigningBehavior.IfIdpWantAuthnRequestsSigned,
};
} }
if (!string.IsNullOrWhiteSpace(config.IdpOutboundSigningAlgorithm))
private Sustainsys.Saml2.WebSso.Saml2BindingType GetBindingType(Saml2BindingType bindingType)
{ {
return bindingType switch idp.OutboundSigningAlgorithm = config.IdpOutboundSigningAlgorithm;
{
Saml2BindingType.HttpRedirect => Sustainsys.Saml2.WebSso.Saml2BindingType.HttpRedirect,
Saml2BindingType.HttpPost => Sustainsys.Saml2.WebSso.Saml2BindingType.HttpPost,
_ => Sustainsys.Saml2.WebSso.Saml2BindingType.HttpPost,
};
} }
if (!string.IsNullOrWhiteSpace(config.IdpX509PublicCert))
{
var cert = CoreHelpers.Base64UrlDecode(config.IdpX509PublicCert);
idp.SigningKeys.AddConfiguredKey(new X509Certificate2(cert));
}
idp.ArtifactResolutionServiceUrls.Clear();
// This must happen last since it calls Validate() internally.
idp.LoadMetadata = false;
var options = new Saml2Options
{
SPOptions = spOptions,
SignInScheme = AuthenticationSchemes.BitwardenExternalCookieAuthenticationScheme,
SignOutScheme = IdentityServerConstants.DefaultCookieAuthenticationScheme,
CookieManager = new IdentityServer.DistributedCacheCookieManager(),
};
options.IdentityProviders.Add(idp);
return new DynamicAuthenticationScheme(name, name, typeof(Saml2Handler), options, SsoType.Saml2);
}
private NameIdFormat GetNameIdFormat(Saml2NameIdFormat format)
{
return format switch
{
Saml2NameIdFormat.Unspecified => NameIdFormat.Unspecified,
Saml2NameIdFormat.EmailAddress => NameIdFormat.EmailAddress,
Saml2NameIdFormat.X509SubjectName => NameIdFormat.X509SubjectName,
Saml2NameIdFormat.WindowsDomainQualifiedName => NameIdFormat.WindowsDomainQualifiedName,
Saml2NameIdFormat.KerberosPrincipalName => NameIdFormat.KerberosPrincipalName,
Saml2NameIdFormat.EntityIdentifier => NameIdFormat.EntityIdentifier,
Saml2NameIdFormat.Persistent => NameIdFormat.Persistent,
Saml2NameIdFormat.Transient => NameIdFormat.Transient,
_ => NameIdFormat.NotConfigured,
};
}
private SigningBehavior GetSigningBehavior(Saml2SigningBehavior behavior)
{
return behavior switch
{
Saml2SigningBehavior.IfIdpWantAuthnRequestsSigned => SigningBehavior.IfIdpWantAuthnRequestsSigned,
Saml2SigningBehavior.Always => SigningBehavior.Always,
Saml2SigningBehavior.Never => SigningBehavior.Never,
_ => SigningBehavior.IfIdpWantAuthnRequestsSigned,
};
}
private Sustainsys.Saml2.WebSso.Saml2BindingType GetBindingType(Saml2BindingType bindingType)
{
return bindingType switch
{
Saml2BindingType.HttpRedirect => Sustainsys.Saml2.WebSso.Saml2BindingType.HttpRedirect,
Saml2BindingType.HttpPost => Sustainsys.Saml2.WebSso.Saml2BindingType.HttpPost,
_ => Sustainsys.Saml2.WebSso.Saml2BindingType.HttpPost,
};
} }
} }

View File

@ -1,37 +1,36 @@
using System.Collections.Concurrent; using System.Collections.Concurrent;
using Microsoft.Extensions.Options; using Microsoft.Extensions.Options;
namespace Bit.Sso.Utilities namespace Bit.Sso.Utilities;
public class ExtendedOptionsMonitorCache<TOptions> : IExtendedOptionsMonitorCache<TOptions> where TOptions : class
{ {
public class ExtendedOptionsMonitorCache<TOptions> : IExtendedOptionsMonitorCache<TOptions> where TOptions : class private readonly ConcurrentDictionary<string, Lazy<TOptions>> _cache =
new ConcurrentDictionary<string, Lazy<TOptions>>(StringComparer.Ordinal);
public void AddOrUpdate(string name, TOptions options)
{ {
private readonly ConcurrentDictionary<string, Lazy<TOptions>> _cache = _cache.AddOrUpdate(name ?? Options.DefaultName, new Lazy<TOptions>(() => options),
new ConcurrentDictionary<string, Lazy<TOptions>>(StringComparer.Ordinal); (string s, Lazy<TOptions> lazy) => new Lazy<TOptions>(() => options));
}
public void AddOrUpdate(string name, TOptions options) public void Clear()
{ {
_cache.AddOrUpdate(name ?? Options.DefaultName, new Lazy<TOptions>(() => options), _cache.Clear();
(string s, Lazy<TOptions> lazy) => new Lazy<TOptions>(() => options)); }
}
public void Clear() public TOptions GetOrAdd(string name, Func<TOptions> createOptions)
{ {
_cache.Clear(); return _cache.GetOrAdd(name ?? Options.DefaultName, new Lazy<TOptions>(createOptions)).Value;
} }
public TOptions GetOrAdd(string name, Func<TOptions> createOptions) public bool TryAdd(string name, TOptions options)
{ {
return _cache.GetOrAdd(name ?? Options.DefaultName, new Lazy<TOptions>(createOptions)).Value; return _cache.TryAdd(name ?? Options.DefaultName, new Lazy<TOptions>(() => options));
} }
public bool TryAdd(string name, TOptions options) public bool TryRemove(string name)
{ {
return _cache.TryAdd(name ?? Options.DefaultName, new Lazy<TOptions>(() => options)); return _cache.TryRemove(name ?? Options.DefaultName, out _);
}
public bool TryRemove(string name)
{
return _cache.TryRemove(name ?? Options.DefaultName, out _);
}
} }
} }

View File

@ -1,13 +1,12 @@
using Bit.Core.Enums; using Bit.Core.Enums;
using Microsoft.AspNetCore.Authentication; using Microsoft.AspNetCore.Authentication;
namespace Bit.Sso.Utilities namespace Bit.Sso.Utilities;
{
public interface IDynamicAuthenticationScheme
{
AuthenticationSchemeOptions Options { get; set; }
SsoType SsoType { get; set; }
Task Validate(); public interface IDynamicAuthenticationScheme
} {
AuthenticationSchemeOptions Options { get; set; }
SsoType SsoType { get; set; }
Task Validate();
} }

View File

@ -1,9 +1,8 @@
using Microsoft.Extensions.Options; using Microsoft.Extensions.Options;
namespace Bit.Sso.Utilities namespace Bit.Sso.Utilities;
public interface IExtendedOptionsMonitorCache<TOptions> : IOptionsMonitorCache<TOptions> where TOptions : class
{ {
public interface IExtendedOptionsMonitorCache<TOptions> : IOptionsMonitorCache<TOptions> where TOptions : class void AddOrUpdate(string name, TOptions options);
{
void AddOrUpdate(string name, TOptions options);
}
} }

View File

@ -1,63 +1,62 @@
using Microsoft.AspNetCore.Authentication.OpenIdConnect; using Microsoft.AspNetCore.Authentication.OpenIdConnect;
using Microsoft.IdentityModel.Protocols.OpenIdConnect; using Microsoft.IdentityModel.Protocols.OpenIdConnect;
namespace Bit.Sso.Utilities namespace Bit.Sso.Utilities;
public static class OpenIdConnectOptionsExtensions
{ {
public static class OpenIdConnectOptionsExtensions public static async Task<bool> CouldHandleAsync(this OpenIdConnectOptions options, string scheme, HttpContext context)
{ {
public static async Task<bool> CouldHandleAsync(this OpenIdConnectOptions options, string scheme, HttpContext context) // Determine this is a valid request for our handler
if (options.CallbackPath != context.Request.Path &&
options.RemoteSignOutPath != context.Request.Path &&
options.SignedOutCallbackPath != context.Request.Path)
{ {
// Determine this is a valid request for our handler
if (options.CallbackPath != context.Request.Path &&
options.RemoteSignOutPath != context.Request.Path &&
options.SignedOutCallbackPath != context.Request.Path)
{
return false;
}
if (context.Request.Query["scheme"].FirstOrDefault() == scheme)
{
return true;
}
try
{
// Parse out the message
OpenIdConnectMessage message = null;
if (string.Equals(context.Request.Method, "GET", StringComparison.OrdinalIgnoreCase))
{
message = new OpenIdConnectMessage(context.Request.Query.Select(pair => new KeyValuePair<string, string[]>(pair.Key, pair.Value)));
}
else if (string.Equals(context.Request.Method, "POST", StringComparison.OrdinalIgnoreCase) &&
!string.IsNullOrEmpty(context.Request.ContentType) &&
context.Request.ContentType.StartsWith("application/x-www-form-urlencoded", StringComparison.OrdinalIgnoreCase) &&
context.Request.Body.CanRead)
{
var form = await context.Request.ReadFormAsync();
message = new OpenIdConnectMessage(form.Select(pair => new KeyValuePair<string, string[]>(pair.Key, pair.Value)));
}
var state = message?.State;
if (string.IsNullOrWhiteSpace(state))
{
// State is required, it will fail later on for this reason.
return false;
}
// Handle State if we've gotten that back
var decodedState = options.StateDataFormat.Unprotect(state);
if (decodedState != null && decodedState.Items.ContainsKey("scheme"))
{
return decodedState.Items["scheme"] == scheme;
}
}
catch
{
return false;
}
// This is likely not an appropriate handler
return false; return false;
} }
if (context.Request.Query["scheme"].FirstOrDefault() == scheme)
{
return true;
}
try
{
// Parse out the message
OpenIdConnectMessage message = null;
if (string.Equals(context.Request.Method, "GET", StringComparison.OrdinalIgnoreCase))
{
message = new OpenIdConnectMessage(context.Request.Query.Select(pair => new KeyValuePair<string, string[]>(pair.Key, pair.Value)));
}
else if (string.Equals(context.Request.Method, "POST", StringComparison.OrdinalIgnoreCase) &&
!string.IsNullOrEmpty(context.Request.ContentType) &&
context.Request.ContentType.StartsWith("application/x-www-form-urlencoded", StringComparison.OrdinalIgnoreCase) &&
context.Request.Body.CanRead)
{
var form = await context.Request.ReadFormAsync();
message = new OpenIdConnectMessage(form.Select(pair => new KeyValuePair<string, string[]>(pair.Key, pair.Value)));
}
var state = message?.State;
if (string.IsNullOrWhiteSpace(state))
{
// State is required, it will fail later on for this reason.
return false;
}
// Handle State if we've gotten that back
var decodedState = options.StateDataFormat.Unprotect(state);
if (decodedState != null && decodedState.Items.ContainsKey("scheme"))
{
return decodedState.Items["scheme"] == scheme;
}
}
catch
{
return false;
}
// This is likely not an appropriate handler
return false;
} }
} }

View File

@ -1,64 +1,63 @@
namespace Bit.Sso.Utilities namespace Bit.Sso.Utilities;
/// <summary>
/// OpenID Connect Clients use scope values as defined in 3.3 of OAuth 2.0
/// [RFC6749]. These values represent the standard scope values supported
/// by OAuth 2.0 and therefore OIDC.
/// </summary>
/// <remarks>
/// See: https://openid.net/specs/openid-connect-basic-1_0.html#Scopes
/// </remarks>
public static class OpenIdConnectScopes
{ {
/// <summary> /// <summary>
/// OpenID Connect Clients use scope values as defined in 3.3 of OAuth 2.0 /// REQUIRED. Informs the Authorization Server that the Client is making
/// [RFC6749]. These values represent the standard scope values supported /// an OpenID Connect request. If the openid scope value is not present,
/// by OAuth 2.0 and therefore OIDC. /// the behavior is entirely unspecified.
/// </summary>
public const string OpenId = "openid";
/// <summary>
/// OPTIONAL. This scope value requests access to the End-User's default
/// profile Claims, which are: name, family_name, given_name,
/// middle_name, nickname, preferred_username, profile, picture,
/// website, gender, birthdate, zoneinfo, locale, and updated_at.
/// </summary>
public const string Profile = "profile";
/// <summary>
/// OPTIONAL. This scope value requests access to the email and
/// email_verified Claims.
/// </summary>
public const string Email = "email";
/// <summary>
/// OPTIONAL. This scope value requests access to the address Claim.
/// </summary>
public const string Address = "address";
/// <summary>
/// OPTIONAL. This scope value requests access to the phone_number and
/// phone_number_verified Claims.
/// </summary>
public const string Phone = "phone";
/// <summary>
/// OPTIONAL. This scope value requests that an OAuth 2.0 Refresh Token
/// be issued that can be used to obtain an Access Token that grants
/// access to the End-User's UserInfo Endpoint even when the End-User is
/// not present (not logged in).
/// </summary>
public const string OfflineAccess = "offline_access";
/// <summary>
/// OPTIONAL. Authentication Context Class Reference. String specifying
/// an Authentication Context Class Reference value that identifies the
/// Authentication Context Class that the authentication performed
/// satisfied.
/// </summary> /// </summary>
/// <remarks> /// <remarks>
/// See: https://openid.net/specs/openid-connect-basic-1_0.html#Scopes /// See: https://openid.net/specs/openid-connect-core-1_0.html#rfc.section.2
/// </remarks> /// </remarks>
public static class OpenIdConnectScopes public const string Acr = "acr";
{
/// <summary>
/// REQUIRED. Informs the Authorization Server that the Client is making
/// an OpenID Connect request. If the openid scope value is not present,
/// the behavior is entirely unspecified.
/// </summary>
public const string OpenId = "openid";
/// <summary>
/// OPTIONAL. This scope value requests access to the End-User's default
/// profile Claims, which are: name, family_name, given_name,
/// middle_name, nickname, preferred_username, profile, picture,
/// website, gender, birthdate, zoneinfo, locale, and updated_at.
/// </summary>
public const string Profile = "profile";
/// <summary>
/// OPTIONAL. This scope value requests access to the email and
/// email_verified Claims.
/// </summary>
public const string Email = "email";
/// <summary>
/// OPTIONAL. This scope value requests access to the address Claim.
/// </summary>
public const string Address = "address";
/// <summary>
/// OPTIONAL. This scope value requests access to the phone_number and
/// phone_number_verified Claims.
/// </summary>
public const string Phone = "phone";
/// <summary>
/// OPTIONAL. This scope value requests that an OAuth 2.0 Refresh Token
/// be issued that can be used to obtain an Access Token that grants
/// access to the End-User's UserInfo Endpoint even when the End-User is
/// not present (not logged in).
/// </summary>
public const string OfflineAccess = "offline_access";
/// <summary>
/// OPTIONAL. Authentication Context Class Reference. String specifying
/// an Authentication Context Class Reference value that identifies the
/// Authentication Context Class that the authentication performed
/// satisfied.
/// </summary>
/// <remarks>
/// See: https://openid.net/specs/openid-connect-core-1_0.html#rfc.section.2
/// </remarks>
public const string Acr = "acr";
}
} }

View File

@ -4,102 +4,101 @@ using System.Xml;
using Sustainsys.Saml2; using Sustainsys.Saml2;
using Sustainsys.Saml2.AspNetCore2; using Sustainsys.Saml2.AspNetCore2;
namespace Bit.Sso.Utilities namespace Bit.Sso.Utilities;
public static class Saml2OptionsExtensions
{ {
public static class Saml2OptionsExtensions public static async Task<bool> CouldHandleAsync(this Saml2Options options, string scheme, HttpContext context)
{ {
public static async Task<bool> CouldHandleAsync(this Saml2Options options, string scheme, HttpContext context) // Determine this is a valid request for our handler
if (!context.Request.Path.StartsWithSegments(options.SPOptions.ModulePath, StringComparison.Ordinal))
{ {
// Determine this is a valid request for our handler return false;
if (!context.Request.Path.StartsWithSegments(options.SPOptions.ModulePath, StringComparison.Ordinal)) }
{
return false;
}
var idp = options.IdentityProviders.IsEmpty ? null : options.IdentityProviders.Default; var idp = options.IdentityProviders.IsEmpty ? null : options.IdentityProviders.Default;
if (idp == null) if (idp == null)
{ {
return false; return false;
} }
if (context.Request.Query["scheme"].FirstOrDefault() == scheme)
{
return true;
}
// We need to pull out and parse the response or request SAML envelope
XmlElement envelope = null;
try
{
if (string.Equals(context.Request.Method, "POST", StringComparison.OrdinalIgnoreCase) &&
context.Request.HasFormContentType)
{
string encodedMessage;
if (context.Request.Form.TryGetValue("SAMLResponse", out var response))
{
encodedMessage = response.FirstOrDefault();
}
else
{
encodedMessage = context.Request.Form["SAMLRequest"];
}
if (string.IsNullOrWhiteSpace(encodedMessage))
{
return false;
}
envelope = XmlHelpers.XmlDocumentFromString(
Encoding.UTF8.GetString(Convert.FromBase64String(encodedMessage)))?.DocumentElement;
}
else if (string.Equals(context.Request.Method, "GET", StringComparison.OrdinalIgnoreCase))
{
var encodedPayload = context.Request.Query["SAMLRequest"].FirstOrDefault() ??
context.Request.Query["SAMLResponse"].FirstOrDefault();
try
{
var payload = Convert.FromBase64String(encodedPayload);
using var compressed = new MemoryStream(payload);
using var decompressedStream = new DeflateStream(compressed, CompressionMode.Decompress, true);
using var deCompressed = new MemoryStream();
await decompressedStream.CopyToAsync(deCompressed);
envelope = XmlHelpers.XmlDocumentFromString(
Encoding.UTF8.GetString(deCompressed.GetBuffer(), 0, (int)deCompressed.Length))?.DocumentElement;
}
catch (FormatException ex)
{
throw new FormatException($"\'{encodedPayload}\' is not a valid Base64 encoded string: {ex.Message}", ex);
}
}
}
catch
{
return false;
}
if (envelope == null)
{
return false;
}
// Double check the entity Ids
var entityId = envelope["Issuer", Saml2Namespaces.Saml2Name]?.InnerText.Trim();
if (!string.Equals(entityId, idp.EntityId.Id, StringComparison.InvariantCultureIgnoreCase))
{
return false;
}
if (options.SPOptions.WantAssertionsSigned)
{
var assertion = envelope["Assertion", Saml2Namespaces.Saml2Name];
var isAssertionSigned = assertion != null && XmlHelpers.IsSignedByAny(assertion, idp.SigningKeys,
options.SPOptions.ValidateCertificates, options.SPOptions.MinIncomingSigningAlgorithm);
if (!isAssertionSigned)
{
throw new Exception("Cannot verify SAML assertion signature.");
}
}
if (context.Request.Query["scheme"].FirstOrDefault() == scheme)
{
return true; return true;
} }
// We need to pull out and parse the response or request SAML envelope
XmlElement envelope = null;
try
{
if (string.Equals(context.Request.Method, "POST", StringComparison.OrdinalIgnoreCase) &&
context.Request.HasFormContentType)
{
string encodedMessage;
if (context.Request.Form.TryGetValue("SAMLResponse", out var response))
{
encodedMessage = response.FirstOrDefault();
}
else
{
encodedMessage = context.Request.Form["SAMLRequest"];
}
if (string.IsNullOrWhiteSpace(encodedMessage))
{
return false;
}
envelope = XmlHelpers.XmlDocumentFromString(
Encoding.UTF8.GetString(Convert.FromBase64String(encodedMessage)))?.DocumentElement;
}
else if (string.Equals(context.Request.Method, "GET", StringComparison.OrdinalIgnoreCase))
{
var encodedPayload = context.Request.Query["SAMLRequest"].FirstOrDefault() ??
context.Request.Query["SAMLResponse"].FirstOrDefault();
try
{
var payload = Convert.FromBase64String(encodedPayload);
using var compressed = new MemoryStream(payload);
using var decompressedStream = new DeflateStream(compressed, CompressionMode.Decompress, true);
using var deCompressed = new MemoryStream();
await decompressedStream.CopyToAsync(deCompressed);
envelope = XmlHelpers.XmlDocumentFromString(
Encoding.UTF8.GetString(deCompressed.GetBuffer(), 0, (int)deCompressed.Length))?.DocumentElement;
}
catch (FormatException ex)
{
throw new FormatException($"\'{encodedPayload}\' is not a valid Base64 encoded string: {ex.Message}", ex);
}
}
}
catch
{
return false;
}
if (envelope == null)
{
return false;
}
// Double check the entity Ids
var entityId = envelope["Issuer", Saml2Namespaces.Saml2Name]?.InnerText.Trim();
if (!string.Equals(entityId, idp.EntityId.Id, StringComparison.InvariantCultureIgnoreCase))
{
return false;
}
if (options.SPOptions.WantAssertionsSigned)
{
var assertion = envelope["Assertion", Saml2Namespaces.Saml2Name];
var isAssertionSigned = assertion != null && XmlHelpers.IsSignedByAny(assertion, idp.SigningKeys,
options.SPOptions.ValidateCertificates, options.SPOptions.MinIncomingSigningAlgorithm);
if (!isAssertionSigned)
{
throw new Exception("Cannot verify SAML assertion signature.");
}
}
return true;
} }
} }

View File

@ -1,12 +1,11 @@
namespace Bit.Sso.Utilities namespace Bit.Sso.Utilities;
public static class SamlClaimTypes
{ {
public static class SamlClaimTypes public const string Email = "urn:oid:0.9.2342.19200300.100.1.3";
{ public const string GivenName = "urn:oid:2.5.4.42";
public const string Email = "urn:oid:0.9.2342.19200300.100.1.3"; public const string Surname = "urn:oid:2.5.4.4";
public const string GivenName = "urn:oid:2.5.4.42"; public const string DisplayName = "urn:oid:2.16.840.1.113730.3.1.241";
public const string Surname = "urn:oid:2.5.4.4"; public const string CommonName = "urn:oid:2.5.4.3";
public const string DisplayName = "urn:oid:2.16.840.1.113730.3.1.241"; public const string UserId = "urn:oid:0.9.2342.19200300.100.1.1";
public const string CommonName = "urn:oid:2.5.4.3";
public const string UserId = "urn:oid:0.9.2342.19200300.100.1.1";
}
} }

View File

@ -1,18 +1,17 @@
namespace Bit.Sso.Utilities namespace Bit.Sso.Utilities;
public static class SamlNameIdFormats
{ {
public static class SamlNameIdFormats // Common
{ public const string Unspecified = "urn:oasis:names:tc:SAML:1.1:nameid-format:unspecified";
// Common public const string Email = "urn:oasis:names:tc:SAML:1.1:nameid-format:emailAddress";
public const string Unspecified = "urn:oasis:names:tc:SAML:1.1:nameid-format:unspecified"; public const string Persistent = "urn:oasis:names:tc:SAML:2.0:nameid-format:persistent";
public const string Email = "urn:oasis:names:tc:SAML:1.1:nameid-format:emailAddress"; public const string Transient = "urn:oasis:names:tc:SAML:2.0:nameid-format:transient";
public const string Persistent = "urn:oasis:names:tc:SAML:2.0:nameid-format:persistent"; // Not-so-common
public const string Transient = "urn:oasis:names:tc:SAML:2.0:nameid-format:transient"; public const string Upn = "http://schemas.xmlsoap.org/claims/UPN";
// Not-so-common public const string CommonName = "http://schemas.xmlsoap.org/claims/CommonName";
public const string Upn = "http://schemas.xmlsoap.org/claims/UPN"; public const string X509SubjectName = "urn:oasis:names:tc:SAML:1.1:nameid-format:X509SubjectName";
public const string CommonName = "http://schemas.xmlsoap.org/claims/CommonName"; public const string WindowsQualifiedDomainName = "urn:oasis:names:tc:SAML:1.1:nameid-format:WindowsDomainQualifiedName";
public const string X509SubjectName = "urn:oasis:names:tc:SAML:1.1:nameid-format:X509SubjectName"; public const string KerberosPrincipalName = "urn:oasis:names:tc:SAML:2.0:nameid-format:kerberos";
public const string WindowsQualifiedDomainName = "urn:oasis:names:tc:SAML:1.1:nameid-format:WindowsDomainQualifiedName"; public const string EntityIdentifier = "urn:oasis:names:tc:SAML:2.0:nameid-format:entity";
public const string KerberosPrincipalName = "urn:oasis:names:tc:SAML:2.0:nameid-format:kerberos";
public const string EntityIdentifier = "urn:oasis:names:tc:SAML:2.0:nameid-format:entity";
}
} }

View File

@ -1,7 +1,6 @@
namespace Bit.Sso.Utilities namespace Bit.Sso.Utilities;
public static class SamlPropertyKeys
{ {
public static class SamlPropertyKeys public const string ClaimFormat = "http://schemas.xmlsoap.org/ws/2005/05/identity/claimproperties/format";
{
public const string ClaimFormat = "http://schemas.xmlsoap.org/ws/2005/05/identity/claimproperties/format";
}
} }

View File

@ -9,70 +9,69 @@ using IdentityServer4.ResponseHandling;
using Microsoft.AspNetCore.Authentication.OpenIdConnect; using Microsoft.AspNetCore.Authentication.OpenIdConnect;
using Sustainsys.Saml2.AspNetCore2; using Sustainsys.Saml2.AspNetCore2;
namespace Bit.Sso.Utilities namespace Bit.Sso.Utilities;
public static class ServiceCollectionExtensions
{ {
public static class ServiceCollectionExtensions public static IServiceCollection AddSsoServices(this IServiceCollection services,
GlobalSettings globalSettings)
{ {
public static IServiceCollection AddSsoServices(this IServiceCollection services, // SAML SP Configuration
GlobalSettings globalSettings) var samlEnvironment = new SamlEnvironment
{ {
// SAML SP Configuration SpSigningCertificate = CoreHelpers.GetIdentityServerCertificate(globalSettings),
var samlEnvironment = new SamlEnvironment };
services.AddSingleton(s => samlEnvironment);
services.AddSingleton<Microsoft.AspNetCore.Authentication.IAuthenticationSchemeProvider,
DynamicAuthenticationSchemeProvider>();
// Oidc
services.AddSingleton<Microsoft.Extensions.Options.IPostConfigureOptions<OpenIdConnectOptions>,
OpenIdConnectPostConfigureOptions>();
services.AddSingleton<Microsoft.Extensions.Options.IOptionsMonitorCache<OpenIdConnectOptions>,
ExtendedOptionsMonitorCache<OpenIdConnectOptions>>();
// Saml2
services.AddSingleton<Microsoft.Extensions.Options.IPostConfigureOptions<Saml2Options>,
PostConfigureSaml2Options>();
services.AddSingleton<Microsoft.Extensions.Options.IOptionsMonitorCache<Saml2Options>,
ExtendedOptionsMonitorCache<Saml2Options>>();
return services;
}
public static IIdentityServerBuilder AddSsoIdentityServerServices(this IServiceCollection services,
IWebHostEnvironment env, GlobalSettings globalSettings)
{
services.AddTransient<IDiscoveryResponseGenerator, DiscoveryResponseGenerator>();
var issuerUri = new Uri(globalSettings.BaseServiceUri.InternalSso);
var identityServerBuilder = services
.AddIdentityServer(options =>
{ {
SpSigningCertificate = CoreHelpers.GetIdentityServerCertificate(globalSettings), options.IssuerUri = $"{issuerUri.Scheme}://{issuerUri.Host}";
}; if (env.IsDevelopment())
services.AddSingleton(s => samlEnvironment);
services.AddSingleton<Microsoft.AspNetCore.Authentication.IAuthenticationSchemeProvider,
DynamicAuthenticationSchemeProvider>();
// Oidc
services.AddSingleton<Microsoft.Extensions.Options.IPostConfigureOptions<OpenIdConnectOptions>,
OpenIdConnectPostConfigureOptions>();
services.AddSingleton<Microsoft.Extensions.Options.IOptionsMonitorCache<OpenIdConnectOptions>,
ExtendedOptionsMonitorCache<OpenIdConnectOptions>>();
// Saml2
services.AddSingleton<Microsoft.Extensions.Options.IPostConfigureOptions<Saml2Options>,
PostConfigureSaml2Options>();
services.AddSingleton<Microsoft.Extensions.Options.IOptionsMonitorCache<Saml2Options>,
ExtendedOptionsMonitorCache<Saml2Options>>();
return services;
}
public static IIdentityServerBuilder AddSsoIdentityServerServices(this IServiceCollection services,
IWebHostEnvironment env, GlobalSettings globalSettings)
{
services.AddTransient<IDiscoveryResponseGenerator, DiscoveryResponseGenerator>();
var issuerUri = new Uri(globalSettings.BaseServiceUri.InternalSso);
var identityServerBuilder = services
.AddIdentityServer(options =>
{ {
options.IssuerUri = $"{issuerUri.Scheme}://{issuerUri.Host}"; options.Authentication.CookieSameSiteMode = Microsoft.AspNetCore.Http.SameSiteMode.Unspecified;
if (env.IsDevelopment()) }
{ else
options.Authentication.CookieSameSiteMode = Microsoft.AspNetCore.Http.SameSiteMode.Unspecified;
}
else
{
options.UserInteraction.ErrorUrl = "/Error";
options.UserInteraction.ErrorIdParameter = "errorId";
}
options.InputLengthRestrictions.UserName = 256;
})
.AddInMemoryCaching()
.AddInMemoryClients(new List<Client>
{ {
new OidcIdentityClient(globalSettings) options.UserInteraction.ErrorUrl = "/Error";
}) options.UserInteraction.ErrorIdParameter = "errorId";
.AddInMemoryIdentityResources(new List<IdentityResource> }
{ options.InputLengthRestrictions.UserName = 256;
new IdentityResources.OpenId(), })
new IdentityResources.Profile() .AddInMemoryCaching()
}) .AddInMemoryClients(new List<Client>
.AddIdentityServerCertificate(env, globalSettings); {
new OidcIdentityClient(globalSettings)
})
.AddInMemoryIdentityResources(new List<IdentityResource>
{
new IdentityResources.OpenId(),
new IdentityResources.Profile()
})
.AddIdentityServerCertificate(env, globalSettings);
return identityServerBuilder; return identityServerBuilder;
}
} }
} }

View File

@ -3,83 +3,82 @@ using Microsoft.AspNetCore.Authentication;
using Microsoft.AspNetCore.Authentication.OpenIdConnect; using Microsoft.AspNetCore.Authentication.OpenIdConnect;
using Sustainsys.Saml2.AspNetCore2; using Sustainsys.Saml2.AspNetCore2;
namespace Bit.Sso.Utilities namespace Bit.Sso.Utilities;
public class SsoAuthenticationMiddleware
{ {
public class SsoAuthenticationMiddleware private readonly RequestDelegate _next;
public SsoAuthenticationMiddleware(RequestDelegate next, IAuthenticationSchemeProvider schemes)
{ {
private readonly RequestDelegate _next; _next = next ?? throw new ArgumentNullException(nameof(next));
Schemes = schemes ?? throw new ArgumentNullException(nameof(schemes));
}
public SsoAuthenticationMiddleware(RequestDelegate next, IAuthenticationSchemeProvider schemes) public IAuthenticationSchemeProvider Schemes { get; set; }
public async Task Invoke(HttpContext context)
{
if ((context.Request.Method == "GET" && context.Request.Query.ContainsKey("SAMLart"))
|| (context.Request.Method == "POST" && context.Request.Form.ContainsKey("SAMLart")))
{ {
_next = next ?? throw new ArgumentNullException(nameof(next)); throw new Exception("SAMLart parameter detected. SAML Artifact binding is not allowed.");
Schemes = schemes ?? throw new ArgumentNullException(nameof(schemes));
} }
public IAuthenticationSchemeProvider Schemes { get; set; } context.Features.Set<IAuthenticationFeature>(new AuthenticationFeature
public async Task Invoke(HttpContext context)
{ {
if ((context.Request.Method == "GET" && context.Request.Query.ContainsKey("SAMLart")) OriginalPath = context.Request.Path,
|| (context.Request.Method == "POST" && context.Request.Form.ContainsKey("SAMLart"))) OriginalPathBase = context.Request.PathBase
{ });
throw new Exception("SAMLart parameter detected. SAML Artifact binding is not allowed.");
}
context.Features.Set<IAuthenticationFeature>(new AuthenticationFeature // Give any IAuthenticationRequestHandler schemes a chance to handle the request
var handlers = context.RequestServices.GetRequiredService<IAuthenticationHandlerProvider>();
foreach (var scheme in await Schemes.GetRequestHandlerSchemesAsync())
{
// Determine if scheme is appropriate for the current context FIRST
if (scheme is IDynamicAuthenticationScheme dynamicScheme)
{ {
OriginalPath = context.Request.Path, switch (dynamicScheme.SsoType)
OriginalPathBase = context.Request.PathBase
});
// Give any IAuthenticationRequestHandler schemes a chance to handle the request
var handlers = context.RequestServices.GetRequiredService<IAuthenticationHandlerProvider>();
foreach (var scheme in await Schemes.GetRequestHandlerSchemesAsync())
{
// Determine if scheme is appropriate for the current context FIRST
if (scheme is IDynamicAuthenticationScheme dynamicScheme)
{ {
switch (dynamicScheme.SsoType) case SsoType.OpenIdConnect:
{ default:
case SsoType.OpenIdConnect: if (dynamicScheme.Options is OpenIdConnectOptions oidcOptions &&
default: !await oidcOptions.CouldHandleAsync(scheme.Name, context))
if (dynamicScheme.Options is OpenIdConnectOptions oidcOptions && {
!await oidcOptions.CouldHandleAsync(scheme.Name, context)) // It's OIDC and Dynamic, but not a good fit
{ continue;
// It's OIDC and Dynamic, but not a good fit }
continue; break;
} case SsoType.Saml2:
break; if (dynamicScheme.Options is Saml2Options samlOptions &&
case SsoType.Saml2: !await samlOptions.CouldHandleAsync(scheme.Name, context))
if (dynamicScheme.Options is Saml2Options samlOptions && {
!await samlOptions.CouldHandleAsync(scheme.Name, context)) // It's SAML and Dynamic, but not a good fit
{ continue;
// It's SAML and Dynamic, but not a good fit }
continue; break;
}
break;
}
}
// This far it's not dynamic OR it is but "could" be handled
if (await handlers.GetHandlerAsync(context, scheme.Name) is IAuthenticationRequestHandler handler &&
await handler.HandleRequestAsync())
{
return;
} }
} }
// Fallback to the default scheme from the provider // This far it's not dynamic OR it is but "could" be handled
var defaultAuthenticate = await Schemes.GetDefaultAuthenticateSchemeAsync(); if (await handlers.GetHandlerAsync(context, scheme.Name) is IAuthenticationRequestHandler handler &&
if (defaultAuthenticate != null) await handler.HandleRequestAsync())
{ {
var result = await context.AuthenticateAsync(defaultAuthenticate.Name); return;
if (result?.Principal != null)
{
context.User = result.Principal;
}
} }
await _next(context);
} }
// Fallback to the default scheme from the provider
var defaultAuthenticate = await Schemes.GetDefaultAuthenticateSchemeAsync();
if (defaultAuthenticate != null)
{
var result = await context.AuthenticateAsync(defaultAuthenticate.Name);
if (result?.Principal != null)
{
context.User = result.Principal;
}
}
await _next(context);
} }
} }

View File

@ -3,43 +3,42 @@ using AutoFixture;
using AutoFixture.Xunit2; using AutoFixture.Xunit2;
using Bit.Core.Enums.Provider; using Bit.Core.Enums.Provider;
namespace Bit.Commercial.Core.Test.AutoFixture namespace Bit.Commercial.Core.Test.AutoFixture;
internal class ProviderUser : ICustomization
{ {
internal class ProviderUser : ICustomization public ProviderUserStatusType Status { get; set; }
public ProviderUserType Type { get; set; }
public ProviderUser(ProviderUserStatusType status, ProviderUserType type)
{ {
public ProviderUserStatusType Status { get; set; } Status = status;
public ProviderUserType Type { get; set; } Type = type;
public ProviderUser(ProviderUserStatusType status, ProviderUserType type)
{
Status = status;
Type = type;
}
public void Customize(IFixture fixture)
{
fixture.Customize<Bit.Core.Entities.Provider.ProviderUser>(composer => composer
.With(o => o.Type, Type)
.With(o => o.Status, Status));
}
} }
public class ProviderUserAttribute : CustomizeAttribute public void Customize(IFixture fixture)
{ {
private readonly ProviderUserStatusType _status; fixture.Customize<Bit.Core.Entities.Provider.ProviderUser>(composer => composer
private readonly ProviderUserType _type; .With(o => o.Type, Type)
.With(o => o.Status, Status));
public ProviderUserAttribute( }
ProviderUserStatusType status = ProviderUserStatusType.Confirmed, }
ProviderUserType type = ProviderUserType.ProviderAdmin)
{ public class ProviderUserAttribute : CustomizeAttribute
_status = status; {
_type = type; private readonly ProviderUserStatusType _status;
} private readonly ProviderUserType _type;
public override ICustomization GetCustomization(ParameterInfo parameter) public ProviderUserAttribute(
{ ProviderUserStatusType status = ProviderUserStatusType.Confirmed,
return new ProviderUser(_status, _type); ProviderUserType type = ProviderUserType.ProviderAdmin)
} {
_status = status;
_type = type;
}
public override ICustomization GetCustomization(ParameterInfo parameter)
{
return new ProviderUser(_status, _type);
} }
} }

View File

@ -1,16 +1,15 @@
namespace Bit.Admin namespace Bit.Admin;
{
public class AdminSettings
{
public virtual string Admins { get; set; }
public virtual CloudflareSettings Cloudflare { get; set; }
public int? DeleteTrashDaysAgo { get; set; }
public class CloudflareSettings public class AdminSettings
{ {
public string ZoneId { get; set; } public virtual string Admins { get; set; }
public string AuthEmail { get; set; } public virtual CloudflareSettings Cloudflare { get; set; }
public string AuthKey { get; set; } public int? DeleteTrashDaysAgo { get; set; }
}
public class CloudflareSettings
{
public string ZoneId { get; set; }
public string AuthEmail { get; set; }
public string AuthKey { get; set; }
} }
} }

View File

@ -1,24 +1,23 @@
using Microsoft.AspNetCore.Diagnostics; using Microsoft.AspNetCore.Diagnostics;
using Microsoft.AspNetCore.Mvc; using Microsoft.AspNetCore.Mvc;
namespace Bit.Admin.Controllers namespace Bit.Admin.Controllers;
{
public class ErrorController : Controller
{
[Route("/error")]
public IActionResult Error(int? statusCode = null)
{
var exceptionHandlerPathFeature = HttpContext.Features.Get<IExceptionHandlerPathFeature>();
TempData["Error"] = HttpContext.Features.Get<IExceptionHandlerFeature>()?.Error.Message;
if (exceptionHandlerPathFeature != null) public class ErrorController : Controller
{ {
return Redirect(exceptionHandlerPathFeature.Path); [Route("/error")]
} public IActionResult Error(int? statusCode = null)
else {
{ var exceptionHandlerPathFeature = HttpContext.Features.Get<IExceptionHandlerPathFeature>();
return Redirect("/Home"); TempData["Error"] = HttpContext.Features.Get<IExceptionHandlerFeature>()?.Error.Message;
}
if (exceptionHandlerPathFeature != null)
{
return Redirect(exceptionHandlerPathFeature.Path);
}
else
{
return Redirect("/Home");
} }
} }
} }

View File

@ -6,109 +6,108 @@ using Microsoft.AspNetCore.Authorization;
using Microsoft.AspNetCore.Mvc; using Microsoft.AspNetCore.Mvc;
using Newtonsoft.Json; using Newtonsoft.Json;
namespace Bit.Admin.Controllers namespace Bit.Admin.Controllers;
public class HomeController : Controller
{ {
public class HomeController : Controller private readonly GlobalSettings _globalSettings;
private readonly HttpClient _httpClient = new HttpClient();
private readonly ILogger<HomeController> _logger;
public HomeController(GlobalSettings globalSettings, ILogger<HomeController> logger)
{ {
private readonly GlobalSettings _globalSettings; _globalSettings = globalSettings;
private readonly HttpClient _httpClient = new HttpClient(); _logger = logger;
private readonly ILogger<HomeController> _logger;
public HomeController(GlobalSettings globalSettings, ILogger<HomeController> logger)
{
_globalSettings = globalSettings;
_logger = logger;
}
[Authorize]
public IActionResult Index()
{
return View(new HomeModel
{
GlobalSettings = _globalSettings,
CurrentVersion = Core.Utilities.CoreHelpers.GetVersion()
});
}
public IActionResult Error()
{
return View(new ErrorViewModel
{
RequestId = Activity.Current?.Id ?? HttpContext.TraceIdentifier
});
}
public async Task<IActionResult> GetLatestVersion(ProjectType project, CancellationToken cancellationToken)
{
var requestUri = $"https://selfhost.bitwarden.com/version.json";
try
{
var response = await _httpClient.GetAsync(requestUri, cancellationToken);
if (response.IsSuccessStatusCode)
{
var latestVersions = JsonConvert.DeserializeObject<LatestVersions>(await response.Content.ReadAsStringAsync());
return project switch
{
ProjectType.Core => new JsonResult(latestVersions.Versions.CoreVersion),
ProjectType.Web => new JsonResult(latestVersions.Versions.WebVersion),
_ => throw new System.NotImplementedException(),
};
}
}
catch (HttpRequestException e)
{
_logger.LogError(e, $"Error encountered while sending GET request to {requestUri}");
return new JsonResult("Unable to fetch latest version") { StatusCode = StatusCodes.Status500InternalServerError };
}
return new JsonResult("-");
}
public async Task<IActionResult> GetInstalledWebVersion(CancellationToken cancellationToken)
{
var requestUri = $"{_globalSettings.BaseServiceUri.InternalVault}/version.json";
try
{
var response = await _httpClient.GetAsync(requestUri, cancellationToken);
if (response.IsSuccessStatusCode)
{
using var jsonDocument = await JsonDocument.ParseAsync(await response.Content.ReadAsStreamAsync(cancellationToken), cancellationToken: cancellationToken);
var root = jsonDocument.RootElement;
return new JsonResult(root.GetProperty("version").GetString());
}
}
catch (HttpRequestException e)
{
_logger.LogError(e, $"Error encountered while sending GET request to {requestUri}");
return new JsonResult("Unable to fetch installed version") { StatusCode = StatusCodes.Status500InternalServerError };
}
return new JsonResult("-");
}
private class LatestVersions
{
[JsonProperty("versions")]
public Versions Versions { get; set; }
}
private class Versions
{
[JsonProperty("coreVersion")]
public string CoreVersion { get; set; }
[JsonProperty("webVersion")]
public string WebVersion { get; set; }
[JsonProperty("keyConnectorVersion")]
public string KeyConnectorVersion { get; set; }
}
} }
public enum ProjectType [Authorize]
public IActionResult Index()
{ {
Core, return View(new HomeModel
Web, {
GlobalSettings = _globalSettings,
CurrentVersion = Core.Utilities.CoreHelpers.GetVersion()
});
}
public IActionResult Error()
{
return View(new ErrorViewModel
{
RequestId = Activity.Current?.Id ?? HttpContext.TraceIdentifier
});
}
public async Task<IActionResult> GetLatestVersion(ProjectType project, CancellationToken cancellationToken)
{
var requestUri = $"https://selfhost.bitwarden.com/version.json";
try
{
var response = await _httpClient.GetAsync(requestUri, cancellationToken);
if (response.IsSuccessStatusCode)
{
var latestVersions = JsonConvert.DeserializeObject<LatestVersions>(await response.Content.ReadAsStringAsync());
return project switch
{
ProjectType.Core => new JsonResult(latestVersions.Versions.CoreVersion),
ProjectType.Web => new JsonResult(latestVersions.Versions.WebVersion),
_ => throw new System.NotImplementedException(),
};
}
}
catch (HttpRequestException e)
{
_logger.LogError(e, $"Error encountered while sending GET request to {requestUri}");
return new JsonResult("Unable to fetch latest version") { StatusCode = StatusCodes.Status500InternalServerError };
}
return new JsonResult("-");
}
public async Task<IActionResult> GetInstalledWebVersion(CancellationToken cancellationToken)
{
var requestUri = $"{_globalSettings.BaseServiceUri.InternalVault}/version.json";
try
{
var response = await _httpClient.GetAsync(requestUri, cancellationToken);
if (response.IsSuccessStatusCode)
{
using var jsonDocument = await JsonDocument.ParseAsync(await response.Content.ReadAsStreamAsync(cancellationToken), cancellationToken: cancellationToken);
var root = jsonDocument.RootElement;
return new JsonResult(root.GetProperty("version").GetString());
}
}
catch (HttpRequestException e)
{
_logger.LogError(e, $"Error encountered while sending GET request to {requestUri}");
return new JsonResult("Unable to fetch installed version") { StatusCode = StatusCodes.Status500InternalServerError };
}
return new JsonResult("-");
}
private class LatestVersions
{
[JsonProperty("versions")]
public Versions Versions { get; set; }
}
private class Versions
{
[JsonProperty("coreVersion")]
public string CoreVersion { get; set; }
[JsonProperty("webVersion")]
public string WebVersion { get; set; }
[JsonProperty("keyConnectorVersion")]
public string KeyConnectorVersion { get; set; }
} }
} }
public enum ProjectType
{
Core,
Web,
}

View File

@ -1,21 +1,20 @@
using Bit.Core.Utilities; using Bit.Core.Utilities;
using Microsoft.AspNetCore.Mvc; using Microsoft.AspNetCore.Mvc;
namespace Bit.Admin.Controllers namespace Bit.Admin.Controllers;
{
public class InfoController : Controller
{
[HttpGet("~/alive")]
[HttpGet("~/now")]
public DateTime GetAlive()
{
return DateTime.UtcNow;
}
[HttpGet("~/version")] public class InfoController : Controller
public JsonResult GetVersion() {
{ [HttpGet("~/alive")]
return Json(CoreHelpers.GetVersion()); [HttpGet("~/now")]
} public DateTime GetAlive()
{
return DateTime.UtcNow;
}
[HttpGet("~/version")]
public JsonResult GetVersion()
{
return Json(CoreHelpers.GetVersion());
} }
} }

View File

@ -3,91 +3,90 @@ using Bit.Core.Identity;
using Microsoft.AspNetCore.Identity; using Microsoft.AspNetCore.Identity;
using Microsoft.AspNetCore.Mvc; using Microsoft.AspNetCore.Mvc;
namespace Bit.Admin.Controllers namespace Bit.Admin.Controllers;
public class LoginController : Controller
{ {
public class LoginController : Controller private readonly PasswordlessSignInManager<IdentityUser> _signInManager;
public LoginController(
PasswordlessSignInManager<IdentityUser> signInManager)
{ {
private readonly PasswordlessSignInManager<IdentityUser> _signInManager; _signInManager = signInManager;
}
public LoginController( public IActionResult Index(string returnUrl = null, int? error = null, int? success = null,
PasswordlessSignInManager<IdentityUser> signInManager) bool accessDenied = false)
{
if (!error.HasValue && accessDenied)
{ {
_signInManager = signInManager; error = 4;
} }
public IActionResult Index(string returnUrl = null, int? error = null, int? success = null, return View(new LoginModel
bool accessDenied = false)
{ {
if (!error.HasValue && accessDenied) ReturnUrl = returnUrl,
{ Error = GetMessage(error),
error = 4; Success = GetMessage(success)
} });
}
return View(new LoginModel [HttpPost]
{ [ValidateAntiForgeryToken]
ReturnUrl = returnUrl, public async Task<IActionResult> Index(LoginModel model)
Error = GetMessage(error), {
Success = GetMessage(success) if (ModelState.IsValid)
});
}
[HttpPost]
[ValidateAntiForgeryToken]
public async Task<IActionResult> Index(LoginModel model)
{ {
if (ModelState.IsValid) await _signInManager.PasswordlessSignInAsync(model.Email, model.ReturnUrl);
{
await _signInManager.PasswordlessSignInAsync(model.Email, model.ReturnUrl);
return RedirectToAction("Index", new
{
success = 3
});
}
return View(model);
}
public async Task<IActionResult> Confirm(string email, string token, string returnUrl)
{
var result = await _signInManager.PasswordlessSignInAsync(email, token, true);
if (!result.Succeeded)
{
return RedirectToAction("Index", new
{
error = 2
});
}
if (!string.IsNullOrWhiteSpace(returnUrl) && Url.IsLocalUrl(returnUrl))
{
return Redirect(returnUrl);
}
return RedirectToAction("Index", "Home");
}
[HttpPost]
[ValidateAntiForgeryToken]
public async Task<IActionResult> Logout()
{
await _signInManager.SignOutAsync();
return RedirectToAction("Index", new return RedirectToAction("Index", new
{ {
success = 1 success = 3
}); });
} }
private string GetMessage(int? messageCode) return View(model);
}
public async Task<IActionResult> Confirm(string email, string token, string returnUrl)
{
var result = await _signInManager.PasswordlessSignInAsync(email, token, true);
if (!result.Succeeded)
{ {
return messageCode switch return RedirectToAction("Index", new
{ {
1 => "You have been logged out.", error = 2
2 => "This login confirmation link is invalid. Try logging in again.", });
3 => "If a valid admin user with this email address exists, " +
"we've sent you an email with a secure link to log in.",
4 => "Access denied. Please log in.",
_ => null,
};
} }
if (!string.IsNullOrWhiteSpace(returnUrl) && Url.IsLocalUrl(returnUrl))
{
return Redirect(returnUrl);
}
return RedirectToAction("Index", "Home");
}
[HttpPost]
[ValidateAntiForgeryToken]
public async Task<IActionResult> Logout()
{
await _signInManager.SignOutAsync();
return RedirectToAction("Index", new
{
success = 1
});
}
private string GetMessage(int? messageCode)
{
return messageCode switch
{
1 => "You have been logged out.",
2 => "This login confirmation link is invalid. Try logging in again.",
3 => "If a valid admin user with this email address exists, " +
"we've sent you an email with a secure link to log in.",
4 => "Access denied. Please log in.",
_ => null,
};
} }
} }

View File

@ -7,87 +7,86 @@ using Microsoft.Azure.Cosmos;
using Microsoft.Azure.Cosmos.Linq; using Microsoft.Azure.Cosmos.Linq;
using Serilog.Events; using Serilog.Events;
namespace Bit.Admin.Controllers namespace Bit.Admin.Controllers;
[Authorize]
[SelfHosted(NotSelfHostedOnly = true)]
public class LogsController : Controller
{ {
[Authorize] private const string Database = "Diagnostics";
[SelfHosted(NotSelfHostedOnly = true)] private const string Container = "Logs";
public class LogsController : Controller
private readonly GlobalSettings _globalSettings;
public LogsController(GlobalSettings globalSettings)
{ {
private const string Database = "Diagnostics"; _globalSettings = globalSettings;
private const string Container = "Logs"; }
private readonly GlobalSettings _globalSettings; public async Task<IActionResult> Index(string cursor = null, int count = 50,
LogEventLevel? level = null, string project = null, DateTime? start = null, DateTime? end = null)
public LogsController(GlobalSettings globalSettings) {
using (var client = new CosmosClient(_globalSettings.DocumentDb.Uri,
_globalSettings.DocumentDb.Key))
{ {
_globalSettings = globalSettings; var cosmosContainer = client.GetContainer(Database, Container);
} var query = cosmosContainer.GetItemLinqQueryable<LogModel>(
requestOptions: new QueryRequestOptions()
{
MaxItemCount = count
},
continuationToken: cursor
).AsQueryable();
public async Task<IActionResult> Index(string cursor = null, int count = 50, if (level.HasValue)
LogEventLevel? level = null, string project = null, DateTime? start = null, DateTime? end = null)
{
using (var client = new CosmosClient(_globalSettings.DocumentDb.Uri,
_globalSettings.DocumentDb.Key))
{ {
var cosmosContainer = client.GetContainer(Database, Container); query = query.Where(l => l.Level == level.Value.ToString());
var query = cosmosContainer.GetItemLinqQueryable<LogModel>(
requestOptions: new QueryRequestOptions()
{
MaxItemCount = count
},
continuationToken: cursor
).AsQueryable();
if (level.HasValue)
{
query = query.Where(l => l.Level == level.Value.ToString());
}
if (!string.IsNullOrWhiteSpace(project))
{
query = query.Where(l => l.Properties != null && l.Properties["Project"] == (object)project);
}
if (start.HasValue)
{
query = query.Where(l => l.Timestamp >= start.Value);
}
if (end.HasValue)
{
query = query.Where(l => l.Timestamp <= end.Value);
}
var feedIterator = query.OrderByDescending(l => l.Timestamp).ToFeedIterator();
var response = await feedIterator.ReadNextAsync();
return View(new LogsModel
{
Level = level,
Project = project,
Start = start,
End = end,
Items = response.ToList(),
Count = count,
Cursor = cursor,
NextCursor = response.ContinuationToken
});
} }
} if (!string.IsNullOrWhiteSpace(project))
public async Task<IActionResult> View(Guid id)
{
using (var client = new CosmosClient(_globalSettings.DocumentDb.Uri,
_globalSettings.DocumentDb.Key))
{ {
var cosmosContainer = client.GetContainer(Database, Container); query = query.Where(l => l.Properties != null && l.Properties["Project"] == (object)project);
var query = cosmosContainer.GetItemLinqQueryable<LogDetailsModel>()
.AsQueryable()
.Where(l => l.Id == id.ToString());
var response = await query.ToFeedIterator().ReadNextAsync();
if (response == null || response.Count == 0)
{
return RedirectToAction("Index");
}
return View(response.First());
} }
if (start.HasValue)
{
query = query.Where(l => l.Timestamp >= start.Value);
}
if (end.HasValue)
{
query = query.Where(l => l.Timestamp <= end.Value);
}
var feedIterator = query.OrderByDescending(l => l.Timestamp).ToFeedIterator();
var response = await feedIterator.ReadNextAsync();
return View(new LogsModel
{
Level = level,
Project = project,
Start = start,
End = end,
Items = response.ToList(),
Count = count,
Cursor = cursor,
NextCursor = response.ContinuationToken
});
}
}
public async Task<IActionResult> View(Guid id)
{
using (var client = new CosmosClient(_globalSettings.DocumentDb.Uri,
_globalSettings.DocumentDb.Key))
{
var cosmosContainer = client.GetContainer(Database, Container);
var query = cosmosContainer.GetItemLinqQueryable<LogDetailsModel>()
.AsQueryable()
.Where(l => l.Id == id.ToString());
var response = await query.ToFeedIterator().ReadNextAsync();
if (response == null || response.Count == 0)
{
return RedirectToAction("Index");
}
return View(response.First());
} }
} }
} }

View File

@ -11,207 +11,206 @@ using Bit.Core.Utilities;
using Microsoft.AspNetCore.Authorization; using Microsoft.AspNetCore.Authorization;
using Microsoft.AspNetCore.Mvc; using Microsoft.AspNetCore.Mvc;
namespace Bit.Admin.Controllers namespace Bit.Admin.Controllers;
[Authorize]
public class OrganizationsController : Controller
{ {
[Authorize] private readonly IOrganizationRepository _organizationRepository;
public class OrganizationsController : Controller private readonly IOrganizationUserRepository _organizationUserRepository;
private readonly IOrganizationConnectionRepository _organizationConnectionRepository;
private readonly ISelfHostedSyncSponsorshipsCommand _syncSponsorshipsCommand;
private readonly ICipherRepository _cipherRepository;
private readonly ICollectionRepository _collectionRepository;
private readonly IGroupRepository _groupRepository;
private readonly IPolicyRepository _policyRepository;
private readonly IPaymentService _paymentService;
private readonly ILicensingService _licensingService;
private readonly IApplicationCacheService _applicationCacheService;
private readonly GlobalSettings _globalSettings;
private readonly IReferenceEventService _referenceEventService;
private readonly IUserService _userService;
private readonly ILogger<OrganizationsController> _logger;
public OrganizationsController(
IOrganizationRepository organizationRepository,
IOrganizationUserRepository organizationUserRepository,
IOrganizationConnectionRepository organizationConnectionRepository,
ISelfHostedSyncSponsorshipsCommand syncSponsorshipsCommand,
ICipherRepository cipherRepository,
ICollectionRepository collectionRepository,
IGroupRepository groupRepository,
IPolicyRepository policyRepository,
IPaymentService paymentService,
ILicensingService licensingService,
IApplicationCacheService applicationCacheService,
GlobalSettings globalSettings,
IReferenceEventService referenceEventService,
IUserService userService,
ILogger<OrganizationsController> logger)
{ {
private readonly IOrganizationRepository _organizationRepository; _organizationRepository = organizationRepository;
private readonly IOrganizationUserRepository _organizationUserRepository; _organizationUserRepository = organizationUserRepository;
private readonly IOrganizationConnectionRepository _organizationConnectionRepository; _organizationConnectionRepository = organizationConnectionRepository;
private readonly ISelfHostedSyncSponsorshipsCommand _syncSponsorshipsCommand; _syncSponsorshipsCommand = syncSponsorshipsCommand;
private readonly ICipherRepository _cipherRepository; _cipherRepository = cipherRepository;
private readonly ICollectionRepository _collectionRepository; _collectionRepository = collectionRepository;
private readonly IGroupRepository _groupRepository; _groupRepository = groupRepository;
private readonly IPolicyRepository _policyRepository; _policyRepository = policyRepository;
private readonly IPaymentService _paymentService; _paymentService = paymentService;
private readonly ILicensingService _licensingService; _licensingService = licensingService;
private readonly IApplicationCacheService _applicationCacheService; _applicationCacheService = applicationCacheService;
private readonly GlobalSettings _globalSettings; _globalSettings = globalSettings;
private readonly IReferenceEventService _referenceEventService; _referenceEventService = referenceEventService;
private readonly IUserService _userService; _userService = userService;
private readonly ILogger<OrganizationsController> _logger; _logger = logger;
public OrganizationsController(
IOrganizationRepository organizationRepository,
IOrganizationUserRepository organizationUserRepository,
IOrganizationConnectionRepository organizationConnectionRepository,
ISelfHostedSyncSponsorshipsCommand syncSponsorshipsCommand,
ICipherRepository cipherRepository,
ICollectionRepository collectionRepository,
IGroupRepository groupRepository,
IPolicyRepository policyRepository,
IPaymentService paymentService,
ILicensingService licensingService,
IApplicationCacheService applicationCacheService,
GlobalSettings globalSettings,
IReferenceEventService referenceEventService,
IUserService userService,
ILogger<OrganizationsController> logger)
{
_organizationRepository = organizationRepository;
_organizationUserRepository = organizationUserRepository;
_organizationConnectionRepository = organizationConnectionRepository;
_syncSponsorshipsCommand = syncSponsorshipsCommand;
_cipherRepository = cipherRepository;
_collectionRepository = collectionRepository;
_groupRepository = groupRepository;
_policyRepository = policyRepository;
_paymentService = paymentService;
_licensingService = licensingService;
_applicationCacheService = applicationCacheService;
_globalSettings = globalSettings;
_referenceEventService = referenceEventService;
_userService = userService;
_logger = logger;
}
public async Task<IActionResult> Index(string name = null, string userEmail = null, bool? paid = null,
int page = 1, int count = 25)
{
if (page < 1)
{
page = 1;
}
if (count < 1)
{
count = 1;
}
var skip = (page - 1) * count;
var organizations = await _organizationRepository.SearchAsync(name, userEmail, paid, skip, count);
return View(new OrganizationsModel
{
Items = organizations as List<Organization>,
Name = string.IsNullOrWhiteSpace(name) ? null : name,
UserEmail = string.IsNullOrWhiteSpace(userEmail) ? null : userEmail,
Paid = paid,
Page = page,
Count = count,
Action = _globalSettings.SelfHosted ? "View" : "Edit",
SelfHosted = _globalSettings.SelfHosted
});
}
public async Task<IActionResult> View(Guid id)
{
var organization = await _organizationRepository.GetByIdAsync(id);
if (organization == null)
{
return RedirectToAction("Index");
}
var ciphers = await _cipherRepository.GetManyByOrganizationIdAsync(id);
var collections = await _collectionRepository.GetManyByOrganizationIdAsync(id);
IEnumerable<Group> groups = null;
if (organization.UseGroups)
{
groups = await _groupRepository.GetManyByOrganizationIdAsync(id);
}
IEnumerable<Policy> policies = null;
if (organization.UsePolicies)
{
policies = await _policyRepository.GetManyByOrganizationIdAsync(id);
}
var users = await _organizationUserRepository.GetManyDetailsByOrganizationAsync(id);
var billingSyncConnection = _globalSettings.EnableCloudCommunication ? await _organizationConnectionRepository.GetByOrganizationIdTypeAsync(id, OrganizationConnectionType.CloudBillingSync) : null;
return View(new OrganizationViewModel(organization, billingSyncConnection, users, ciphers, collections, groups, policies));
}
[SelfHosted(NotSelfHostedOnly = true)]
public async Task<IActionResult> Edit(Guid id)
{
var organization = await _organizationRepository.GetByIdAsync(id);
if (organization == null)
{
return RedirectToAction("Index");
}
var ciphers = await _cipherRepository.GetManyByOrganizationIdAsync(id);
var collections = await _collectionRepository.GetManyByOrganizationIdAsync(id);
IEnumerable<Group> groups = null;
if (organization.UseGroups)
{
groups = await _groupRepository.GetManyByOrganizationIdAsync(id);
}
IEnumerable<Policy> policies = null;
if (organization.UsePolicies)
{
policies = await _policyRepository.GetManyByOrganizationIdAsync(id);
}
var users = await _organizationUserRepository.GetManyDetailsByOrganizationAsync(id);
var billingInfo = await _paymentService.GetBillingAsync(organization);
var billingSyncConnection = _globalSettings.EnableCloudCommunication ? await _organizationConnectionRepository.GetByOrganizationIdTypeAsync(id, OrganizationConnectionType.CloudBillingSync) : null;
return View(new OrganizationEditModel(organization, users, ciphers, collections, groups, policies,
billingInfo, billingSyncConnection, _globalSettings));
}
[HttpPost]
[ValidateAntiForgeryToken]
[SelfHosted(NotSelfHostedOnly = true)]
public async Task<IActionResult> Edit(Guid id, OrganizationEditModel model)
{
var organization = await _organizationRepository.GetByIdAsync(id);
model.ToOrganization(organization);
await _organizationRepository.ReplaceAsync(organization);
await _applicationCacheService.UpsertOrganizationAbilityAsync(organization);
await _referenceEventService.RaiseEventAsync(new ReferenceEvent(ReferenceEventType.OrganizationEditedByAdmin, organization)
{
EventRaisedByUser = _userService.GetUserName(User),
SalesAssistedTrialStarted = model.SalesAssistedTrialStarted,
});
return RedirectToAction("Edit", new { id });
}
[HttpPost]
[ValidateAntiForgeryToken]
public async Task<IActionResult> Delete(Guid id)
{
var organization = await _organizationRepository.GetByIdAsync(id);
if (organization != null)
{
await _organizationRepository.DeleteAsync(organization);
await _applicationCacheService.DeleteOrganizationAbilityAsync(organization.Id);
}
return RedirectToAction("Index");
}
public async Task<IActionResult> TriggerBillingSync(Guid id)
{
var organization = await _organizationRepository.GetByIdAsync(id);
if (organization == null)
{
return RedirectToAction("Index");
}
var connection = (await _organizationConnectionRepository.GetEnabledByOrganizationIdTypeAsync(id, OrganizationConnectionType.CloudBillingSync)).FirstOrDefault();
if (connection != null)
{
try
{
var config = connection.GetConfig<BillingSyncConfig>();
await _syncSponsorshipsCommand.SyncOrganization(id, config.CloudOrganizationId, connection);
TempData["ConnectionActivated"] = id;
TempData["ConnectionError"] = null;
}
catch (Exception ex)
{
TempData["ConnectionError"] = ex.Message;
_logger.LogWarning(ex, "Error while attempting to do billing sync for organization with id '{OrganizationId}'", id);
}
if (_globalSettings.SelfHosted)
{
return RedirectToAction("View", new { id });
}
else
{
return RedirectToAction("Edit", new { id });
}
}
return RedirectToAction("Index");
}
} }
public async Task<IActionResult> Index(string name = null, string userEmail = null, bool? paid = null,
int page = 1, int count = 25)
{
if (page < 1)
{
page = 1;
}
if (count < 1)
{
count = 1;
}
var skip = (page - 1) * count;
var organizations = await _organizationRepository.SearchAsync(name, userEmail, paid, skip, count);
return View(new OrganizationsModel
{
Items = organizations as List<Organization>,
Name = string.IsNullOrWhiteSpace(name) ? null : name,
UserEmail = string.IsNullOrWhiteSpace(userEmail) ? null : userEmail,
Paid = paid,
Page = page,
Count = count,
Action = _globalSettings.SelfHosted ? "View" : "Edit",
SelfHosted = _globalSettings.SelfHosted
});
}
public async Task<IActionResult> View(Guid id)
{
var organization = await _organizationRepository.GetByIdAsync(id);
if (organization == null)
{
return RedirectToAction("Index");
}
var ciphers = await _cipherRepository.GetManyByOrganizationIdAsync(id);
var collections = await _collectionRepository.GetManyByOrganizationIdAsync(id);
IEnumerable<Group> groups = null;
if (organization.UseGroups)
{
groups = await _groupRepository.GetManyByOrganizationIdAsync(id);
}
IEnumerable<Policy> policies = null;
if (organization.UsePolicies)
{
policies = await _policyRepository.GetManyByOrganizationIdAsync(id);
}
var users = await _organizationUserRepository.GetManyDetailsByOrganizationAsync(id);
var billingSyncConnection = _globalSettings.EnableCloudCommunication ? await _organizationConnectionRepository.GetByOrganizationIdTypeAsync(id, OrganizationConnectionType.CloudBillingSync) : null;
return View(new OrganizationViewModel(organization, billingSyncConnection, users, ciphers, collections, groups, policies));
}
[SelfHosted(NotSelfHostedOnly = true)]
public async Task<IActionResult> Edit(Guid id)
{
var organization = await _organizationRepository.GetByIdAsync(id);
if (organization == null)
{
return RedirectToAction("Index");
}
var ciphers = await _cipherRepository.GetManyByOrganizationIdAsync(id);
var collections = await _collectionRepository.GetManyByOrganizationIdAsync(id);
IEnumerable<Group> groups = null;
if (organization.UseGroups)
{
groups = await _groupRepository.GetManyByOrganizationIdAsync(id);
}
IEnumerable<Policy> policies = null;
if (organization.UsePolicies)
{
policies = await _policyRepository.GetManyByOrganizationIdAsync(id);
}
var users = await _organizationUserRepository.GetManyDetailsByOrganizationAsync(id);
var billingInfo = await _paymentService.GetBillingAsync(organization);
var billingSyncConnection = _globalSettings.EnableCloudCommunication ? await _organizationConnectionRepository.GetByOrganizationIdTypeAsync(id, OrganizationConnectionType.CloudBillingSync) : null;
return View(new OrganizationEditModel(organization, users, ciphers, collections, groups, policies,
billingInfo, billingSyncConnection, _globalSettings));
}
[HttpPost]
[ValidateAntiForgeryToken]
[SelfHosted(NotSelfHostedOnly = true)]
public async Task<IActionResult> Edit(Guid id, OrganizationEditModel model)
{
var organization = await _organizationRepository.GetByIdAsync(id);
model.ToOrganization(organization);
await _organizationRepository.ReplaceAsync(organization);
await _applicationCacheService.UpsertOrganizationAbilityAsync(organization);
await _referenceEventService.RaiseEventAsync(new ReferenceEvent(ReferenceEventType.OrganizationEditedByAdmin, organization)
{
EventRaisedByUser = _userService.GetUserName(User),
SalesAssistedTrialStarted = model.SalesAssistedTrialStarted,
});
return RedirectToAction("Edit", new { id });
}
[HttpPost]
[ValidateAntiForgeryToken]
public async Task<IActionResult> Delete(Guid id)
{
var organization = await _organizationRepository.GetByIdAsync(id);
if (organization != null)
{
await _organizationRepository.DeleteAsync(organization);
await _applicationCacheService.DeleteOrganizationAbilityAsync(organization.Id);
}
return RedirectToAction("Index");
}
public async Task<IActionResult> TriggerBillingSync(Guid id)
{
var organization = await _organizationRepository.GetByIdAsync(id);
if (organization == null)
{
return RedirectToAction("Index");
}
var connection = (await _organizationConnectionRepository.GetEnabledByOrganizationIdTypeAsync(id, OrganizationConnectionType.CloudBillingSync)).FirstOrDefault();
if (connection != null)
{
try
{
var config = connection.GetConfig<BillingSyncConfig>();
await _syncSponsorshipsCommand.SyncOrganization(id, config.CloudOrganizationId, connection);
TempData["ConnectionActivated"] = id;
TempData["ConnectionError"] = null;
}
catch (Exception ex)
{
TempData["ConnectionError"] = ex.Message;
_logger.LogWarning(ex, "Error while attempting to do billing sync for organization with id '{OrganizationId}'", id);
}
if (_globalSettings.SelfHosted)
{
return RedirectToAction("View", new { id });
}
else
{
return RedirectToAction("Edit", new { id });
}
}
return RedirectToAction("Index");
}
} }

View File

@ -7,128 +7,127 @@ using Bit.Core.Utilities;
using Microsoft.AspNetCore.Authorization; using Microsoft.AspNetCore.Authorization;
using Microsoft.AspNetCore.Mvc; using Microsoft.AspNetCore.Mvc;
namespace Bit.Admin.Controllers namespace Bit.Admin.Controllers;
[Authorize]
[SelfHosted(NotSelfHostedOnly = true)]
public class ProvidersController : Controller
{ {
[Authorize] private readonly IProviderRepository _providerRepository;
[SelfHosted(NotSelfHostedOnly = true)] private readonly IProviderUserRepository _providerUserRepository;
public class ProvidersController : Controller private readonly IProviderOrganizationRepository _providerOrganizationRepository;
private readonly GlobalSettings _globalSettings;
private readonly IApplicationCacheService _applicationCacheService;
private readonly IProviderService _providerService;
public ProvidersController(IProviderRepository providerRepository, IProviderUserRepository providerUserRepository,
IProviderOrganizationRepository providerOrganizationRepository, IProviderService providerService,
GlobalSettings globalSettings, IApplicationCacheService applicationCacheService)
{ {
private readonly IProviderRepository _providerRepository; _providerRepository = providerRepository;
private readonly IProviderUserRepository _providerUserRepository; _providerUserRepository = providerUserRepository;
private readonly IProviderOrganizationRepository _providerOrganizationRepository; _providerOrganizationRepository = providerOrganizationRepository;
private readonly GlobalSettings _globalSettings; _providerService = providerService;
private readonly IApplicationCacheService _applicationCacheService; _globalSettings = globalSettings;
private readonly IProviderService _providerService; _applicationCacheService = applicationCacheService;
}
public ProvidersController(IProviderRepository providerRepository, IProviderUserRepository providerUserRepository, public async Task<IActionResult> Index(string name = null, string userEmail = null, int page = 1, int count = 25)
IProviderOrganizationRepository providerOrganizationRepository, IProviderService providerService, {
GlobalSettings globalSettings, IApplicationCacheService applicationCacheService) if (page < 1)
{ {
_providerRepository = providerRepository; page = 1;
_providerUserRepository = providerUserRepository;
_providerOrganizationRepository = providerOrganizationRepository;
_providerService = providerService;
_globalSettings = globalSettings;
_applicationCacheService = applicationCacheService;
} }
public async Task<IActionResult> Index(string name = null, string userEmail = null, int page = 1, int count = 25) if (count < 1)
{ {
if (page < 1) count = 1;
{
page = 1;
}
if (count < 1)
{
count = 1;
}
var skip = (page - 1) * count;
var providers = await _providerRepository.SearchAsync(name, userEmail, skip, count);
return View(new ProvidersModel
{
Items = providers as List<Provider>,
Name = string.IsNullOrWhiteSpace(name) ? null : name,
UserEmail = string.IsNullOrWhiteSpace(userEmail) ? null : userEmail,
Page = page,
Count = count,
Action = _globalSettings.SelfHosted ? "View" : "Edit",
SelfHosted = _globalSettings.SelfHosted
});
} }
public IActionResult Create(string ownerEmail = null) var skip = (page - 1) * count;
var providers = await _providerRepository.SearchAsync(name, userEmail, skip, count);
return View(new ProvidersModel
{ {
return View(new CreateProviderModel Items = providers as List<Provider>,
{ Name = string.IsNullOrWhiteSpace(name) ? null : name,
OwnerEmail = ownerEmail UserEmail = string.IsNullOrWhiteSpace(userEmail) ? null : userEmail,
}); Page = page,
Count = count,
Action = _globalSettings.SelfHosted ? "View" : "Edit",
SelfHosted = _globalSettings.SelfHosted
});
}
public IActionResult Create(string ownerEmail = null)
{
return View(new CreateProviderModel
{
OwnerEmail = ownerEmail
});
}
[HttpPost]
[ValidateAntiForgeryToken]
public async Task<IActionResult> Create(CreateProviderModel model)
{
if (!ModelState.IsValid)
{
return View(model);
} }
[HttpPost] await _providerService.CreateAsync(model.OwnerEmail);
[ValidateAntiForgeryToken]
public async Task<IActionResult> Create(CreateProviderModel model) return RedirectToAction("Index");
}
public async Task<IActionResult> View(Guid id)
{
var provider = await _providerRepository.GetByIdAsync(id);
if (provider == null)
{ {
if (!ModelState.IsValid)
{
return View(model);
}
await _providerService.CreateAsync(model.OwnerEmail);
return RedirectToAction("Index"); return RedirectToAction("Index");
} }
public async Task<IActionResult> View(Guid id) var users = await _providerUserRepository.GetManyDetailsByProviderAsync(id);
{ var providerOrganizations = await _providerOrganizationRepository.GetManyDetailsByProviderAsync(id);
var provider = await _providerRepository.GetByIdAsync(id); return View(new ProviderViewModel(provider, users, providerOrganizations));
if (provider == null) }
{
return RedirectToAction("Index");
}
var users = await _providerUserRepository.GetManyDetailsByProviderAsync(id); [SelfHosted(NotSelfHostedOnly = true)]
var providerOrganizations = await _providerOrganizationRepository.GetManyDetailsByProviderAsync(id); public async Task<IActionResult> Edit(Guid id)
return View(new ProviderViewModel(provider, users, providerOrganizations)); {
var provider = await _providerRepository.GetByIdAsync(id);
if (provider == null)
{
return RedirectToAction("Index");
} }
[SelfHosted(NotSelfHostedOnly = true)] var users = await _providerUserRepository.GetManyDetailsByProviderAsync(id);
public async Task<IActionResult> Edit(Guid id) var providerOrganizations = await _providerOrganizationRepository.GetManyDetailsByProviderAsync(id);
{ return View(new ProviderEditModel(provider, users, providerOrganizations));
var provider = await _providerRepository.GetByIdAsync(id); }
if (provider == null)
{
return RedirectToAction("Index");
}
var users = await _providerUserRepository.GetManyDetailsByProviderAsync(id); [HttpPost]
var providerOrganizations = await _providerOrganizationRepository.GetManyDetailsByProviderAsync(id); [ValidateAntiForgeryToken]
return View(new ProviderEditModel(provider, users, providerOrganizations)); [SelfHosted(NotSelfHostedOnly = true)]
public async Task<IActionResult> Edit(Guid id, ProviderEditModel model)
{
var provider = await _providerRepository.GetByIdAsync(id);
if (provider == null)
{
return RedirectToAction("Index");
} }
[HttpPost] model.ToProvider(provider);
[ValidateAntiForgeryToken] await _providerRepository.ReplaceAsync(provider);
[SelfHosted(NotSelfHostedOnly = true)] await _applicationCacheService.UpsertProviderAbilityAsync(provider);
public async Task<IActionResult> Edit(Guid id, ProviderEditModel model) return RedirectToAction("Edit", new { id });
{ }
var provider = await _providerRepository.GetByIdAsync(id);
if (provider == null)
{
return RedirectToAction("Index");
}
model.ToProvider(provider); public async Task<IActionResult> ResendInvite(Guid ownerId, Guid providerId)
await _providerRepository.ReplaceAsync(provider); {
await _applicationCacheService.UpsertProviderAbilityAsync(provider); await _providerService.ResendProviderSetupInviteEmailAsync(providerId, ownerId);
return RedirectToAction("Edit", new { id }); TempData["InviteResentTo"] = ownerId;
} return RedirectToAction("Edit", new { id = providerId });
public async Task<IActionResult> ResendInvite(Guid ownerId, Guid providerId)
{
await _providerService.ResendProviderSetupInviteEmailAsync(providerId, ownerId);
TempData["InviteResentTo"] = ownerId;
return RedirectToAction("Edit", new { id = providerId });
}
} }
} }

File diff suppressed because it is too large Load Diff

View File

@ -7,105 +7,104 @@ using Bit.Core.Utilities;
using Microsoft.AspNetCore.Authorization; using Microsoft.AspNetCore.Authorization;
using Microsoft.AspNetCore.Mvc; using Microsoft.AspNetCore.Mvc;
namespace Bit.Admin.Controllers namespace Bit.Admin.Controllers;
[Authorize]
public class UsersController : Controller
{ {
[Authorize] private readonly IUserRepository _userRepository;
public class UsersController : Controller private readonly ICipherRepository _cipherRepository;
private readonly IPaymentService _paymentService;
private readonly GlobalSettings _globalSettings;
public UsersController(
IUserRepository userRepository,
ICipherRepository cipherRepository,
IPaymentService paymentService,
GlobalSettings globalSettings)
{ {
private readonly IUserRepository _userRepository; _userRepository = userRepository;
private readonly ICipherRepository _cipherRepository; _cipherRepository = cipherRepository;
private readonly IPaymentService _paymentService; _paymentService = paymentService;
private readonly GlobalSettings _globalSettings; _globalSettings = globalSettings;
}
public UsersController( public async Task<IActionResult> Index(string email, int page = 1, int count = 25)
IUserRepository userRepository, {
ICipherRepository cipherRepository, if (page < 1)
IPaymentService paymentService,
GlobalSettings globalSettings)
{ {
_userRepository = userRepository; page = 1;
_cipherRepository = cipherRepository;
_paymentService = paymentService;
_globalSettings = globalSettings;
} }
public async Task<IActionResult> Index(string email, int page = 1, int count = 25) if (count < 1)
{ {
if (page < 1) count = 1;
{
page = 1;
}
if (count < 1)
{
count = 1;
}
var skip = (page - 1) * count;
var users = await _userRepository.SearchAsync(email, skip, count);
return View(new UsersModel
{
Items = users as List<User>,
Email = string.IsNullOrWhiteSpace(email) ? null : email,
Page = page,
Count = count,
Action = _globalSettings.SelfHosted ? "View" : "Edit"
});
} }
public async Task<IActionResult> View(Guid id) var skip = (page - 1) * count;
var users = await _userRepository.SearchAsync(email, skip, count);
return View(new UsersModel
{ {
var user = await _userRepository.GetByIdAsync(id); Items = users as List<User>,
if (user == null) Email = string.IsNullOrWhiteSpace(email) ? null : email,
{ Page = page,
return RedirectToAction("Index"); Count = count,
} Action = _globalSettings.SelfHosted ? "View" : "Edit"
});
}
var ciphers = await _cipherRepository.GetManyByUserIdAsync(id); public async Task<IActionResult> View(Guid id)
return View(new UserViewModel(user, ciphers)); {
} var user = await _userRepository.GetByIdAsync(id);
if (user == null)
[SelfHosted(NotSelfHostedOnly = true)]
public async Task<IActionResult> Edit(Guid id)
{ {
var user = await _userRepository.GetByIdAsync(id);
if (user == null)
{
return RedirectToAction("Index");
}
var ciphers = await _cipherRepository.GetManyByUserIdAsync(id);
var billingInfo = await _paymentService.GetBillingAsync(user);
return View(new UserEditModel(user, ciphers, billingInfo, _globalSettings));
}
[HttpPost]
[ValidateAntiForgeryToken]
[SelfHosted(NotSelfHostedOnly = true)]
public async Task<IActionResult> Edit(Guid id, UserEditModel model)
{
var user = await _userRepository.GetByIdAsync(id);
if (user == null)
{
return RedirectToAction("Index");
}
model.ToUser(user);
await _userRepository.ReplaceAsync(user);
return RedirectToAction("Edit", new { id });
}
[HttpPost]
[ValidateAntiForgeryToken]
public async Task<IActionResult> Delete(Guid id)
{
var user = await _userRepository.GetByIdAsync(id);
if (user != null)
{
await _userRepository.DeleteAsync(user);
}
return RedirectToAction("Index"); return RedirectToAction("Index");
} }
var ciphers = await _cipherRepository.GetManyByUserIdAsync(id);
return View(new UserViewModel(user, ciphers));
}
[SelfHosted(NotSelfHostedOnly = true)]
public async Task<IActionResult> Edit(Guid id)
{
var user = await _userRepository.GetByIdAsync(id);
if (user == null)
{
return RedirectToAction("Index");
}
var ciphers = await _cipherRepository.GetManyByUserIdAsync(id);
var billingInfo = await _paymentService.GetBillingAsync(user);
return View(new UserEditModel(user, ciphers, billingInfo, _globalSettings));
}
[HttpPost]
[ValidateAntiForgeryToken]
[SelfHosted(NotSelfHostedOnly = true)]
public async Task<IActionResult> Edit(Guid id, UserEditModel model)
{
var user = await _userRepository.GetByIdAsync(id);
if (user == null)
{
return RedirectToAction("Index");
}
model.ToUser(user);
await _userRepository.ReplaceAsync(user);
return RedirectToAction("Edit", new { id });
}
[HttpPost]
[ValidateAntiForgeryToken]
public async Task<IActionResult> Delete(Guid id)
{
var user = await _userRepository.GetByIdAsync(id);
if (user != null)
{
await _userRepository.DeleteAsync(user);
}
return RedirectToAction("Index");
} }
} }

View File

@ -4,81 +4,80 @@ using Amazon.SQS.Model;
using Bit.Core.Settings; using Bit.Core.Settings;
using Microsoft.Extensions.Options; using Microsoft.Extensions.Options;
namespace Bit.Admin.HostedServices namespace Bit.Admin.HostedServices;
public class AmazonSqsBlockIpHostedService : BlockIpHostedService
{ {
public class AmazonSqsBlockIpHostedService : BlockIpHostedService private AmazonSQSClient _client;
public AmazonSqsBlockIpHostedService(
ILogger<AmazonSqsBlockIpHostedService> logger,
IOptions<AdminSettings> adminSettings,
GlobalSettings globalSettings)
: base(logger, adminSettings, globalSettings)
{ }
public override void Dispose()
{ {
private AmazonSQSClient _client; _client?.Dispose();
}
public AmazonSqsBlockIpHostedService( protected override async Task ExecuteAsync(CancellationToken cancellationToken)
ILogger<AmazonSqsBlockIpHostedService> logger, {
IOptions<AdminSettings> adminSettings, _client = new AmazonSQSClient(_globalSettings.Amazon.AccessKeyId,
GlobalSettings globalSettings) _globalSettings.Amazon.AccessKeySecret, RegionEndpoint.GetBySystemName(_globalSettings.Amazon.Region));
: base(logger, adminSettings, globalSettings) var blockIpQueue = await _client.GetQueueUrlAsync("block-ip", cancellationToken);
{ } var blockIpQueueUrl = blockIpQueue.QueueUrl;
var unblockIpQueue = await _client.GetQueueUrlAsync("unblock-ip", cancellationToken);
var unblockIpQueueUrl = unblockIpQueue.QueueUrl;
public override void Dispose() while (!cancellationToken.IsCancellationRequested)
{ {
_client?.Dispose(); var blockMessageResponse = await _client.ReceiveMessageAsync(new ReceiveMessageRequest
}
protected override async Task ExecuteAsync(CancellationToken cancellationToken)
{
_client = new AmazonSQSClient(_globalSettings.Amazon.AccessKeyId,
_globalSettings.Amazon.AccessKeySecret, RegionEndpoint.GetBySystemName(_globalSettings.Amazon.Region));
var blockIpQueue = await _client.GetQueueUrlAsync("block-ip", cancellationToken);
var blockIpQueueUrl = blockIpQueue.QueueUrl;
var unblockIpQueue = await _client.GetQueueUrlAsync("unblock-ip", cancellationToken);
var unblockIpQueueUrl = unblockIpQueue.QueueUrl;
while (!cancellationToken.IsCancellationRequested)
{ {
var blockMessageResponse = await _client.ReceiveMessageAsync(new ReceiveMessageRequest QueueUrl = blockIpQueueUrl,
MaxNumberOfMessages = 10,
WaitTimeSeconds = 15
}, cancellationToken);
if (blockMessageResponse.Messages.Any())
{
foreach (var message in blockMessageResponse.Messages)
{ {
QueueUrl = blockIpQueueUrl, try
MaxNumberOfMessages = 10,
WaitTimeSeconds = 15
}, cancellationToken);
if (blockMessageResponse.Messages.Any())
{
foreach (var message in blockMessageResponse.Messages)
{ {
try await BlockIpAsync(message.Body, cancellationToken);
{
await BlockIpAsync(message.Body, cancellationToken);
}
catch (Exception e)
{
_logger.LogError(e, "Failed to block IP.");
}
await _client.DeleteMessageAsync(blockIpQueueUrl, message.ReceiptHandle, cancellationToken);
} }
} catch (Exception e)
var unblockMessageResponse = await _client.ReceiveMessageAsync(new ReceiveMessageRequest
{
QueueUrl = unblockIpQueueUrl,
MaxNumberOfMessages = 10,
WaitTimeSeconds = 15
}, cancellationToken);
if (unblockMessageResponse.Messages.Any())
{
foreach (var message in unblockMessageResponse.Messages)
{ {
try _logger.LogError(e, "Failed to block IP.");
{
await UnblockIpAsync(message.Body, cancellationToken);
}
catch (Exception e)
{
_logger.LogError(e, "Failed to unblock IP.");
}
await _client.DeleteMessageAsync(unblockIpQueueUrl, message.ReceiptHandle, cancellationToken);
} }
await _client.DeleteMessageAsync(blockIpQueueUrl, message.ReceiptHandle, cancellationToken);
} }
await Task.Delay(TimeSpan.FromSeconds(15));
} }
var unblockMessageResponse = await _client.ReceiveMessageAsync(new ReceiveMessageRequest
{
QueueUrl = unblockIpQueueUrl,
MaxNumberOfMessages = 10,
WaitTimeSeconds = 15
}, cancellationToken);
if (unblockMessageResponse.Messages.Any())
{
foreach (var message in unblockMessageResponse.Messages)
{
try
{
await UnblockIpAsync(message.Body, cancellationToken);
}
catch (Exception e)
{
_logger.LogError(e, "Failed to unblock IP.");
}
await _client.DeleteMessageAsync(unblockIpQueueUrl, message.ReceiptHandle, cancellationToken);
}
}
await Task.Delay(TimeSpan.FromSeconds(15));
} }
} }
} }

View File

@ -2,63 +2,62 @@
using Bit.Core.Settings; using Bit.Core.Settings;
using Microsoft.Extensions.Options; using Microsoft.Extensions.Options;
namespace Bit.Admin.HostedServices namespace Bit.Admin.HostedServices;
public class AzureQueueBlockIpHostedService : BlockIpHostedService
{ {
public class AzureQueueBlockIpHostedService : BlockIpHostedService private QueueClient _blockIpQueueClient;
private QueueClient _unblockIpQueueClient;
public AzureQueueBlockIpHostedService(
ILogger<AzureQueueBlockIpHostedService> logger,
IOptions<AdminSettings> adminSettings,
GlobalSettings globalSettings)
: base(logger, adminSettings, globalSettings)
{ }
protected override async Task ExecuteAsync(CancellationToken cancellationToken)
{ {
private QueueClient _blockIpQueueClient; _blockIpQueueClient = new QueueClient(_globalSettings.Storage.ConnectionString, "blockip");
private QueueClient _unblockIpQueueClient; _unblockIpQueueClient = new QueueClient(_globalSettings.Storage.ConnectionString, "unblockip");
public AzureQueueBlockIpHostedService( while (!cancellationToken.IsCancellationRequested)
ILogger<AzureQueueBlockIpHostedService> logger,
IOptions<AdminSettings> adminSettings,
GlobalSettings globalSettings)
: base(logger, adminSettings, globalSettings)
{ }
protected override async Task ExecuteAsync(CancellationToken cancellationToken)
{ {
_blockIpQueueClient = new QueueClient(_globalSettings.Storage.ConnectionString, "blockip"); var blockMessages = await _blockIpQueueClient.ReceiveMessagesAsync(maxMessages: 32);
_unblockIpQueueClient = new QueueClient(_globalSettings.Storage.ConnectionString, "unblockip"); if (blockMessages.Value?.Any() ?? false)
while (!cancellationToken.IsCancellationRequested)
{ {
var blockMessages = await _blockIpQueueClient.ReceiveMessagesAsync(maxMessages: 32); foreach (var message in blockMessages.Value)
if (blockMessages.Value?.Any() ?? false)
{ {
foreach (var message in blockMessages.Value) try
{ {
try await BlockIpAsync(message.MessageText, cancellationToken);
{
await BlockIpAsync(message.MessageText, cancellationToken);
}
catch (Exception e)
{
_logger.LogError(e, "Failed to block IP.");
}
await _blockIpQueueClient.DeleteMessageAsync(message.MessageId, message.PopReceipt);
} }
} catch (Exception e)
var unblockMessages = await _unblockIpQueueClient.ReceiveMessagesAsync(maxMessages: 32);
if (unblockMessages.Value?.Any() ?? false)
{
foreach (var message in unblockMessages.Value)
{ {
try _logger.LogError(e, "Failed to block IP.");
{
await UnblockIpAsync(message.MessageText, cancellationToken);
}
catch (Exception e)
{
_logger.LogError(e, "Failed to unblock IP.");
}
await _unblockIpQueueClient.DeleteMessageAsync(message.MessageId, message.PopReceipt);
} }
await _blockIpQueueClient.DeleteMessageAsync(message.MessageId, message.PopReceipt);
} }
await Task.Delay(TimeSpan.FromSeconds(15));
} }
var unblockMessages = await _unblockIpQueueClient.ReceiveMessagesAsync(maxMessages: 32);
if (unblockMessages.Value?.Any() ?? false)
{
foreach (var message in unblockMessages.Value)
{
try
{
await UnblockIpAsync(message.MessageText, cancellationToken);
}
catch (Exception e)
{
_logger.LogError(e, "Failed to unblock IP.");
}
await _unblockIpQueueClient.DeleteMessageAsync(message.MessageId, message.PopReceipt);
}
}
await Task.Delay(TimeSpan.FromSeconds(15));
} }
} }
} }

View File

@ -6,97 +6,96 @@ using Bit.Core.Services;
using Bit.Core.Settings; using Bit.Core.Settings;
using Bit.Core.Utilities; using Bit.Core.Utilities;
namespace Bit.Admin.HostedServices namespace Bit.Admin.HostedServices;
public class AzureQueueMailHostedService : IHostedService
{ {
public class AzureQueueMailHostedService : IHostedService private readonly ILogger<AzureQueueMailHostedService> _logger;
private readonly GlobalSettings _globalSettings;
private readonly IMailService _mailService;
private CancellationTokenSource _cts;
private Task _executingTask;
private QueueClient _mailQueueClient;
public AzureQueueMailHostedService(
ILogger<AzureQueueMailHostedService> logger,
IMailService mailService,
GlobalSettings globalSettings)
{ {
private readonly ILogger<AzureQueueMailHostedService> _logger; _logger = logger;
private readonly GlobalSettings _globalSettings; _mailService = mailService;
private readonly IMailService _mailService; _globalSettings = globalSettings;
private CancellationTokenSource _cts; }
private Task _executingTask;
private QueueClient _mailQueueClient; public Task StartAsync(CancellationToken cancellationToken)
{
_cts = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken);
_executingTask = ExecuteAsync(_cts.Token);
return _executingTask.IsCompleted ? _executingTask : Task.CompletedTask;
}
public AzureQueueMailHostedService( public async Task StopAsync(CancellationToken cancellationToken)
ILogger<AzureQueueMailHostedService> logger, {
IMailService mailService, if (_executingTask == null)
GlobalSettings globalSettings)
{ {
_logger = logger; return;
_mailService = mailService;
_globalSettings = globalSettings;
} }
_cts.Cancel();
await Task.WhenAny(_executingTask, Task.Delay(-1, cancellationToken));
cancellationToken.ThrowIfCancellationRequested();
}
public Task StartAsync(CancellationToken cancellationToken) private async Task ExecuteAsync(CancellationToken cancellationToken)
{ {
_cts = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken); _mailQueueClient = new QueueClient(_globalSettings.Mail.ConnectionString, "mail");
_executingTask = ExecuteAsync(_cts.Token);
return _executingTask.IsCompleted ? _executingTask : Task.CompletedTask;
}
public async Task StopAsync(CancellationToken cancellationToken) QueueMessage[] mailMessages;
while (!cancellationToken.IsCancellationRequested)
{ {
if (_executingTask == null) if (!(mailMessages = await RetrieveMessagesAsync()).Any())
{ {
return; await Task.Delay(TimeSpan.FromSeconds(15));
} }
_cts.Cancel();
await Task.WhenAny(_executingTask, Task.Delay(-1, cancellationToken));
cancellationToken.ThrowIfCancellationRequested();
}
private async Task ExecuteAsync(CancellationToken cancellationToken) foreach (var message in mailMessages)
{
_mailQueueClient = new QueueClient(_globalSettings.Mail.ConnectionString, "mail");
QueueMessage[] mailMessages;
while (!cancellationToken.IsCancellationRequested)
{ {
if (!(mailMessages = await RetrieveMessagesAsync()).Any()) try
{ {
await Task.Delay(TimeSpan.FromSeconds(15)); using var document = JsonDocument.Parse(message.DecodeMessageText());
} var root = document.RootElement;
foreach (var message in mailMessages) if (root.ValueKind == JsonValueKind.Array)
{
try
{ {
using var document = JsonDocument.Parse(message.DecodeMessageText()); foreach (var mailQueueMessage in root.ToObject<List<MailQueueMessage>>())
var root = document.RootElement;
if (root.ValueKind == JsonValueKind.Array)
{ {
foreach (var mailQueueMessage in root.ToObject<List<MailQueueMessage>>())
{
await _mailService.SendEnqueuedMailMessageAsync(mailQueueMessage);
}
}
else if (root.ValueKind == JsonValueKind.Object)
{
var mailQueueMessage = root.ToObject<MailQueueMessage>();
await _mailService.SendEnqueuedMailMessageAsync(mailQueueMessage); await _mailService.SendEnqueuedMailMessageAsync(mailQueueMessage);
} }
} }
catch (Exception e) else if (root.ValueKind == JsonValueKind.Object)
{ {
_logger.LogError(e, "Failed to send email"); var mailQueueMessage = root.ToObject<MailQueueMessage>();
// TODO: retries? await _mailService.SendEnqueuedMailMessageAsync(mailQueueMessage);
} }
}
catch (Exception e)
{
_logger.LogError(e, "Failed to send email");
// TODO: retries?
}
await _mailQueueClient.DeleteMessageAsync(message.MessageId, message.PopReceipt); await _mailQueueClient.DeleteMessageAsync(message.MessageId, message.PopReceipt);
if (cancellationToken.IsCancellationRequested) if (cancellationToken.IsCancellationRequested)
{ {
break; break;
}
} }
} }
} }
}
private async Task<QueueMessage[]> RetrieveMessagesAsync() private async Task<QueueMessage[]> RetrieveMessagesAsync()
{ {
return (await _mailQueueClient.ReceiveMessagesAsync(maxMessages: 32))?.Value ?? new QueueMessage[] { }; return (await _mailQueueClient.ReceiveMessagesAsync(maxMessages: 32))?.Value ?? new QueueMessage[] { };
}
} }
} }

View File

@ -1,71 +1,105 @@
using Bit.Core.Settings; using Bit.Core.Settings;
using Microsoft.Extensions.Options; using Microsoft.Extensions.Options;
namespace Bit.Admin.HostedServices namespace Bit.Admin.HostedServices;
public abstract class BlockIpHostedService : IHostedService, IDisposable
{ {
public abstract class BlockIpHostedService : IHostedService, IDisposable protected readonly ILogger<BlockIpHostedService> _logger;
protected readonly GlobalSettings _globalSettings;
private readonly AdminSettings _adminSettings;
private Task _executingTask;
private CancellationTokenSource _cts;
private HttpClient _httpClient = new HttpClient();
public BlockIpHostedService(
ILogger<BlockIpHostedService> logger,
IOptions<AdminSettings> adminSettings,
GlobalSettings globalSettings)
{ {
protected readonly ILogger<BlockIpHostedService> _logger; _logger = logger;
protected readonly GlobalSettings _globalSettings; _globalSettings = globalSettings;
private readonly AdminSettings _adminSettings; _adminSettings = adminSettings?.Value;
}
private Task _executingTask; public Task StartAsync(CancellationToken cancellationToken)
private CancellationTokenSource _cts; {
private HttpClient _httpClient = new HttpClient(); _cts = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken);
_executingTask = ExecuteAsync(_cts.Token);
return _executingTask.IsCompleted ? _executingTask : Task.CompletedTask;
}
public BlockIpHostedService( public async Task StopAsync(CancellationToken cancellationToken)
ILogger<BlockIpHostedService> logger, {
IOptions<AdminSettings> adminSettings, if (_executingTask == null)
GlobalSettings globalSettings)
{ {
_logger = logger; return;
_globalSettings = globalSettings;
_adminSettings = adminSettings?.Value;
} }
_cts.Cancel();
await Task.WhenAny(_executingTask, Task.Delay(-1, cancellationToken));
cancellationToken.ThrowIfCancellationRequested();
}
public Task StartAsync(CancellationToken cancellationToken) public virtual void Dispose()
{ { }
_cts = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken);
_executingTask = ExecuteAsync(_cts.Token);
return _executingTask.IsCompleted ? _executingTask : Task.CompletedTask;
}
public async Task StopAsync(CancellationToken cancellationToken) protected abstract Task ExecuteAsync(CancellationToken cancellationToken);
protected async Task BlockIpAsync(string message, CancellationToken cancellationToken)
{
var request = new HttpRequestMessage();
request.Headers.Accept.Clear();
request.Headers.Add("X-Auth-Email", _adminSettings.Cloudflare.AuthEmail);
request.Headers.Add("X-Auth-Key", _adminSettings.Cloudflare.AuthKey);
request.Method = HttpMethod.Post;
request.RequestUri = new Uri("https://api.cloudflare.com/" +
$"client/v4/zones/{_adminSettings.Cloudflare.ZoneId}/firewall/access_rules/rules");
request.Content = JsonContent.Create(new
{ {
if (_executingTask == null) mode = "block",
configuration = new
{ {
return; target = "ip",
} value = message
_cts.Cancel(); },
await Task.WhenAny(_executingTask, Task.Delay(-1, cancellationToken)); notes = $"Rate limit abuse on {DateTime.UtcNow.ToString()}."
cancellationToken.ThrowIfCancellationRequested(); });
var response = await _httpClient.SendAsync(request, cancellationToken);
if (!response.IsSuccessStatusCode)
{
return;
} }
public virtual void Dispose() var accessRuleResponse = await response.Content.ReadFromJsonAsync<AccessRuleResponse>(cancellationToken: cancellationToken);
{ } if (!accessRuleResponse.Success)
protected abstract Task ExecuteAsync(CancellationToken cancellationToken);
protected async Task BlockIpAsync(string message, CancellationToken cancellationToken)
{ {
return;
}
// TODO: Send `accessRuleResponse.Result?.Id` message to unblock queue
}
protected async Task UnblockIpAsync(string message, CancellationToken cancellationToken)
{
if (string.IsNullOrWhiteSpace(message))
{
return;
}
if (message.Contains(".") || message.Contains(":"))
{
// IP address messages
var request = new HttpRequestMessage(); var request = new HttpRequestMessage();
request.Headers.Accept.Clear(); request.Headers.Accept.Clear();
request.Headers.Add("X-Auth-Email", _adminSettings.Cloudflare.AuthEmail); request.Headers.Add("X-Auth-Email", _adminSettings.Cloudflare.AuthEmail);
request.Headers.Add("X-Auth-Key", _adminSettings.Cloudflare.AuthKey); request.Headers.Add("X-Auth-Key", _adminSettings.Cloudflare.AuthKey);
request.Method = HttpMethod.Post; request.Method = HttpMethod.Get;
request.RequestUri = new Uri("https://api.cloudflare.com/" + request.RequestUri = new Uri("https://api.cloudflare.com/" +
$"client/v4/zones/{_adminSettings.Cloudflare.ZoneId}/firewall/access_rules/rules"); $"client/v4/zones/{_adminSettings.Cloudflare.ZoneId}/firewall/access_rules/rules?" +
$"configuration_target=ip&configuration_value={message}");
request.Content = JsonContent.Create(new
{
mode = "block",
configuration = new
{
target = "ip",
value = message
},
notes = $"Rate limit abuse on {DateTime.UtcNow.ToString()}."
});
var response = await _httpClient.SendAsync(request, cancellationToken); var response = await _httpClient.SendAsync(request, cancellationToken);
if (!response.IsSuccessStatusCode) if (!response.IsSuccessStatusCode)
@ -73,93 +107,58 @@ namespace Bit.Admin.HostedServices
return; return;
} }
var accessRuleResponse = await response.Content.ReadFromJsonAsync<AccessRuleResponse>(cancellationToken: cancellationToken); var listResponse = await response.Content.ReadFromJsonAsync<ListResponse>(cancellationToken: cancellationToken);
if (!accessRuleResponse.Success) if (!listResponse.Success)
{ {
return; return;
} }
// TODO: Send `accessRuleResponse.Result?.Id` message to unblock queue foreach (var rule in listResponse.Result)
}
protected async Task UnblockIpAsync(string message, CancellationToken cancellationToken)
{
if (string.IsNullOrWhiteSpace(message))
{ {
return; await DeleteAccessRuleAsync(rule.Id, cancellationToken);
}
if (message.Contains(".") || message.Contains(":"))
{
// IP address messages
var request = new HttpRequestMessage();
request.Headers.Accept.Clear();
request.Headers.Add("X-Auth-Email", _adminSettings.Cloudflare.AuthEmail);
request.Headers.Add("X-Auth-Key", _adminSettings.Cloudflare.AuthKey);
request.Method = HttpMethod.Get;
request.RequestUri = new Uri("https://api.cloudflare.com/" +
$"client/v4/zones/{_adminSettings.Cloudflare.ZoneId}/firewall/access_rules/rules?" +
$"configuration_target=ip&configuration_value={message}");
var response = await _httpClient.SendAsync(request, cancellationToken);
if (!response.IsSuccessStatusCode)
{
return;
}
var listResponse = await response.Content.ReadFromJsonAsync<ListResponse>(cancellationToken: cancellationToken);
if (!listResponse.Success)
{
return;
}
foreach (var rule in listResponse.Result)
{
await DeleteAccessRuleAsync(rule.Id, cancellationToken);
}
}
else
{
// Rule Id messages
await DeleteAccessRuleAsync(message, cancellationToken);
} }
} }
else
protected async Task DeleteAccessRuleAsync(string ruleId, CancellationToken cancellationToken)
{ {
var request = new HttpRequestMessage(); // Rule Id messages
request.Headers.Accept.Clear(); await DeleteAccessRuleAsync(message, cancellationToken);
request.Headers.Add("X-Auth-Email", _adminSettings.Cloudflare.AuthEmail);
request.Headers.Add("X-Auth-Key", _adminSettings.Cloudflare.AuthKey);
request.Method = HttpMethod.Delete;
request.RequestUri = new Uri("https://api.cloudflare.com/" +
$"client/v4/zones/{_adminSettings.Cloudflare.ZoneId}/firewall/access_rules/rules/{ruleId}");
await _httpClient.SendAsync(request, cancellationToken);
} }
}
public class ListResponse protected async Task DeleteAccessRuleAsync(string ruleId, CancellationToken cancellationToken)
{
var request = new HttpRequestMessage();
request.Headers.Accept.Clear();
request.Headers.Add("X-Auth-Email", _adminSettings.Cloudflare.AuthEmail);
request.Headers.Add("X-Auth-Key", _adminSettings.Cloudflare.AuthKey);
request.Method = HttpMethod.Delete;
request.RequestUri = new Uri("https://api.cloudflare.com/" +
$"client/v4/zones/{_adminSettings.Cloudflare.ZoneId}/firewall/access_rules/rules/{ruleId}");
await _httpClient.SendAsync(request, cancellationToken);
}
public class ListResponse
{
public bool Success { get; set; }
public List<AccessRuleResultResponse> Result { get; set; }
}
public class AccessRuleResponse
{
public bool Success { get; set; }
public AccessRuleResultResponse Result { get; set; }
}
public class AccessRuleResultResponse
{
public string Id { get; set; }
public string Notes { get; set; }
public ConfigurationResponse Configuration { get; set; }
public class ConfigurationResponse
{ {
public bool Success { get; set; } public string Target { get; set; }
public List<AccessRuleResultResponse> Result { get; set; } public string Value { get; set; }
}
public class AccessRuleResponse
{
public bool Success { get; set; }
public AccessRuleResultResponse Result { get; set; }
}
public class AccessRuleResultResponse
{
public string Id { get; set; }
public string Notes { get; set; }
public ConfigurationResponse Configuration { get; set; }
public class ConfigurationResponse
{
public string Target { get; set; }
public string Value { get; set; }
}
} }
} }
} }

View File

@ -3,62 +3,61 @@ using Bit.Core.Jobs;
using Bit.Core.Settings; using Bit.Core.Settings;
using Bit.Migrator; using Bit.Migrator;
namespace Bit.Admin.HostedServices namespace Bit.Admin.HostedServices;
public class DatabaseMigrationHostedService : IHostedService, IDisposable
{ {
public class DatabaseMigrationHostedService : IHostedService, IDisposable private readonly GlobalSettings _globalSettings;
private readonly ILogger<DatabaseMigrationHostedService> _logger;
private readonly DbMigrator _dbMigrator;
public DatabaseMigrationHostedService(
GlobalSettings globalSettings,
ILogger<DatabaseMigrationHostedService> logger,
ILogger<DbMigrator> migratorLogger,
ILogger<JobListener> listenerLogger)
{ {
private readonly GlobalSettings _globalSettings; _globalSettings = globalSettings;
private readonly ILogger<DatabaseMigrationHostedService> _logger; _logger = logger;
private readonly DbMigrator _dbMigrator; _dbMigrator = new DbMigrator(globalSettings.SqlServer.ConnectionString, migratorLogger);
}
public DatabaseMigrationHostedService( public virtual async Task StartAsync(CancellationToken cancellationToken)
GlobalSettings globalSettings, {
ILogger<DatabaseMigrationHostedService> logger, // Wait 20 seconds to allow database to come online
ILogger<DbMigrator> migratorLogger, await Task.Delay(20000);
ILogger<JobListener> listenerLogger)
var maxMigrationAttempts = 10;
for (var i = 1; i <= maxMigrationAttempts; i++)
{ {
_globalSettings = globalSettings; try
_logger = logger;
_dbMigrator = new DbMigrator(globalSettings.SqlServer.ConnectionString, migratorLogger);
}
public virtual async Task StartAsync(CancellationToken cancellationToken)
{
// Wait 20 seconds to allow database to come online
await Task.Delay(20000);
var maxMigrationAttempts = 10;
for (var i = 1; i <= maxMigrationAttempts; i++)
{ {
try _dbMigrator.MigrateMsSqlDatabase(true, cancellationToken);
// TODO: Maybe flip a flag somewhere to indicate migration is complete??
break;
}
catch (SqlException e)
{
if (i >= maxMigrationAttempts)
{ {
_dbMigrator.MigrateMsSqlDatabase(true, cancellationToken); _logger.LogError(e, "Database failed to migrate.");
// TODO: Maybe flip a flag somewhere to indicate migration is complete?? throw;
break;
} }
catch (SqlException e) else
{ {
if (i >= maxMigrationAttempts) _logger.LogError(e,
{ "Database unavailable for migration. Trying again (attempt #{0})...", i + 1);
_logger.LogError(e, "Database failed to migrate."); await Task.Delay(20000);
throw;
}
else
{
_logger.LogError(e,
"Database unavailable for migration. Trying again (attempt #{0})...", i + 1);
await Task.Delay(20000);
}
} }
} }
} }
public virtual Task StopAsync(CancellationToken cancellationToken)
{
return Task.FromResult(0);
}
public virtual void Dispose()
{ }
} }
public virtual Task StopAsync(CancellationToken cancellationToken)
{
return Task.FromResult(0);
}
public virtual void Dispose()
{ }
} }

View File

@ -3,27 +3,26 @@ using Bit.Core.Jobs;
using Bit.Core.Settings; using Bit.Core.Settings;
using Quartz; using Quartz;
namespace Bit.Admin.Jobs namespace Bit.Admin.Jobs;
public class AliveJob : BaseJob
{ {
public class AliveJob : BaseJob private readonly GlobalSettings _globalSettings;
private HttpClient _httpClient = new HttpClient();
public AliveJob(
GlobalSettings globalSettings,
ILogger<AliveJob> logger)
: base(logger)
{ {
private readonly GlobalSettings _globalSettings; _globalSettings = globalSettings;
private HttpClient _httpClient = new HttpClient(); }
public AliveJob( protected async override Task ExecuteJobAsync(IJobExecutionContext context)
GlobalSettings globalSettings, {
ILogger<AliveJob> logger) _logger.LogInformation(Constants.BypassFiltersEventId, "Execute job task: Keep alive");
: base(logger) var response = await _httpClient.GetAsync(_globalSettings.BaseServiceUri.Admin);
{ _logger.LogInformation(Constants.BypassFiltersEventId, "Finished job task: Keep alive, " +
_globalSettings = globalSettings; response.StatusCode);
}
protected async override Task ExecuteJobAsync(IJobExecutionContext context)
{
_logger.LogInformation(Constants.BypassFiltersEventId, "Execute job task: Keep alive");
var response = await _httpClient.GetAsync(_globalSettings.BaseServiceUri.Admin);
_logger.LogInformation(Constants.BypassFiltersEventId, "Finished job task: Keep alive, " +
response.StatusCode);
}
} }
} }

View File

@ -3,25 +3,24 @@ using Bit.Core.Jobs;
using Bit.Core.Repositories; using Bit.Core.Repositories;
using Quartz; using Quartz;
namespace Bit.Admin.Jobs namespace Bit.Admin.Jobs;
public class DatabaseExpiredGrantsJob : BaseJob
{ {
public class DatabaseExpiredGrantsJob : BaseJob private readonly IMaintenanceRepository _maintenanceRepository;
public DatabaseExpiredGrantsJob(
IMaintenanceRepository maintenanceRepository,
ILogger<DatabaseExpiredGrantsJob> logger)
: base(logger)
{ {
private readonly IMaintenanceRepository _maintenanceRepository; _maintenanceRepository = maintenanceRepository;
}
public DatabaseExpiredGrantsJob( protected async override Task ExecuteJobAsync(IJobExecutionContext context)
IMaintenanceRepository maintenanceRepository, {
ILogger<DatabaseExpiredGrantsJob> logger) _logger.LogInformation(Constants.BypassFiltersEventId, "Execute job task: DeleteExpiredGrantsAsync");
: base(logger) await _maintenanceRepository.DeleteExpiredGrantsAsync();
{ _logger.LogInformation(Constants.BypassFiltersEventId, "Finished job task: DeleteExpiredGrantsAsync");
_maintenanceRepository = maintenanceRepository;
}
protected async override Task ExecuteJobAsync(IJobExecutionContext context)
{
_logger.LogInformation(Constants.BypassFiltersEventId, "Execute job task: DeleteExpiredGrantsAsync");
await _maintenanceRepository.DeleteExpiredGrantsAsync();
_logger.LogInformation(Constants.BypassFiltersEventId, "Finished job task: DeleteExpiredGrantsAsync");
}
} }
} }

View File

@ -4,36 +4,35 @@ using Bit.Core.Repositories;
using Bit.Core.Settings; using Bit.Core.Settings;
using Quartz; using Quartz;
namespace Bit.Admin.Jobs namespace Bit.Admin.Jobs;
public class DatabaseExpiredSponsorshipsJob : BaseJob
{ {
public class DatabaseExpiredSponsorshipsJob : BaseJob private GlobalSettings _globalSettings;
private readonly IMaintenanceRepository _maintenanceRepository;
public DatabaseExpiredSponsorshipsJob(
IMaintenanceRepository maintenanceRepository,
ILogger<DatabaseExpiredSponsorshipsJob> logger,
GlobalSettings globalSettings)
: base(logger)
{ {
private GlobalSettings _globalSettings; _maintenanceRepository = maintenanceRepository;
private readonly IMaintenanceRepository _maintenanceRepository; _globalSettings = globalSettings;
}
public DatabaseExpiredSponsorshipsJob( protected override async Task ExecuteJobAsync(IJobExecutionContext context)
IMaintenanceRepository maintenanceRepository, {
ILogger<DatabaseExpiredSponsorshipsJob> logger, if (_globalSettings.SelfHosted && !_globalSettings.EnableCloudCommunication)
GlobalSettings globalSettings)
: base(logger)
{ {
_maintenanceRepository = maintenanceRepository; return;
_globalSettings = globalSettings;
} }
_logger.LogInformation(Constants.BypassFiltersEventId, "Execute job task: DeleteExpiredSponsorshipsAsync");
protected override async Task ExecuteJobAsync(IJobExecutionContext context) // allow a 90 day grace period before deleting
{ var deleteDate = DateTime.UtcNow.AddDays(-90);
if (_globalSettings.SelfHosted && !_globalSettings.EnableCloudCommunication)
{
return;
}
_logger.LogInformation(Constants.BypassFiltersEventId, "Execute job task: DeleteExpiredSponsorshipsAsync");
// allow a 90 day grace period before deleting await _maintenanceRepository.DeleteExpiredSponsorshipsAsync(deleteDate);
var deleteDate = DateTime.UtcNow.AddDays(-90); _logger.LogInformation(Constants.BypassFiltersEventId, "Finished job task: DeleteExpiredSponsorshipsAsync");
await _maintenanceRepository.DeleteExpiredSponsorshipsAsync(deleteDate);
_logger.LogInformation(Constants.BypassFiltersEventId, "Finished job task: DeleteExpiredSponsorshipsAsync");
}
} }
} }

View File

@ -3,25 +3,24 @@ using Bit.Core.Jobs;
using Bit.Core.Repositories; using Bit.Core.Repositories;
using Quartz; using Quartz;
namespace Bit.Admin.Jobs namespace Bit.Admin.Jobs;
public class DatabaseRebuildlIndexesJob : BaseJob
{ {
public class DatabaseRebuildlIndexesJob : BaseJob private readonly IMaintenanceRepository _maintenanceRepository;
public DatabaseRebuildlIndexesJob(
IMaintenanceRepository maintenanceRepository,
ILogger<DatabaseRebuildlIndexesJob> logger)
: base(logger)
{ {
private readonly IMaintenanceRepository _maintenanceRepository; _maintenanceRepository = maintenanceRepository;
}
public DatabaseRebuildlIndexesJob( protected async override Task ExecuteJobAsync(IJobExecutionContext context)
IMaintenanceRepository maintenanceRepository, {
ILogger<DatabaseRebuildlIndexesJob> logger) _logger.LogInformation(Constants.BypassFiltersEventId, "Execute job task: RebuildIndexesAsync");
: base(logger) await _maintenanceRepository.RebuildIndexesAsync();
{ _logger.LogInformation(Constants.BypassFiltersEventId, "Finished job task: RebuildIndexesAsync");
_maintenanceRepository = maintenanceRepository;
}
protected async override Task ExecuteJobAsync(IJobExecutionContext context)
{
_logger.LogInformation(Constants.BypassFiltersEventId, "Execute job task: RebuildIndexesAsync");
await _maintenanceRepository.RebuildIndexesAsync();
_logger.LogInformation(Constants.BypassFiltersEventId, "Finished job task: RebuildIndexesAsync");
}
} }
} }

View File

@ -3,28 +3,27 @@ using Bit.Core.Jobs;
using Bit.Core.Repositories; using Bit.Core.Repositories;
using Quartz; using Quartz;
namespace Bit.Admin.Jobs namespace Bit.Admin.Jobs;
public class DatabaseUpdateStatisticsJob : BaseJob
{ {
public class DatabaseUpdateStatisticsJob : BaseJob private readonly IMaintenanceRepository _maintenanceRepository;
public DatabaseUpdateStatisticsJob(
IMaintenanceRepository maintenanceRepository,
ILogger<DatabaseUpdateStatisticsJob> logger)
: base(logger)
{ {
private readonly IMaintenanceRepository _maintenanceRepository; _maintenanceRepository = maintenanceRepository;
}
public DatabaseUpdateStatisticsJob( protected async override Task ExecuteJobAsync(IJobExecutionContext context)
IMaintenanceRepository maintenanceRepository, {
ILogger<DatabaseUpdateStatisticsJob> logger) _logger.LogInformation(Constants.BypassFiltersEventId, "Execute job task: UpdateStatisticsAsync");
: base(logger) await _maintenanceRepository.UpdateStatisticsAsync();
{ _logger.LogInformation(Constants.BypassFiltersEventId, "Finished job task: UpdateStatisticsAsync");
_maintenanceRepository = maintenanceRepository; _logger.LogInformation(Constants.BypassFiltersEventId, "Execute job task: DisableCipherAutoStatsAsync");
} await _maintenanceRepository.DisableCipherAutoStatsAsync();
_logger.LogInformation(Constants.BypassFiltersEventId, "Finished job task: DisableCipherAutoStatsAsync");
protected async override Task ExecuteJobAsync(IJobExecutionContext context)
{
_logger.LogInformation(Constants.BypassFiltersEventId, "Execute job task: UpdateStatisticsAsync");
await _maintenanceRepository.UpdateStatisticsAsync();
_logger.LogInformation(Constants.BypassFiltersEventId, "Finished job task: UpdateStatisticsAsync");
_logger.LogInformation(Constants.BypassFiltersEventId, "Execute job task: DisableCipherAutoStatsAsync");
await _maintenanceRepository.DisableCipherAutoStatsAsync();
_logger.LogInformation(Constants.BypassFiltersEventId, "Finished job task: DisableCipherAutoStatsAsync");
}
} }
} }

View File

@ -4,34 +4,33 @@ using Bit.Core.Repositories;
using Microsoft.Extensions.Options; using Microsoft.Extensions.Options;
using Quartz; using Quartz;
namespace Bit.Admin.Jobs namespace Bit.Admin.Jobs;
public class DeleteCiphersJob : BaseJob
{ {
public class DeleteCiphersJob : BaseJob private readonly ICipherRepository _cipherRepository;
private readonly AdminSettings _adminSettings;
public DeleteCiphersJob(
ICipherRepository cipherRepository,
IOptions<AdminSettings> adminSettings,
ILogger<DeleteCiphersJob> logger)
: base(logger)
{ {
private readonly ICipherRepository _cipherRepository; _cipherRepository = cipherRepository;
private readonly AdminSettings _adminSettings; _adminSettings = adminSettings?.Value;
}
public DeleteCiphersJob( protected async override Task ExecuteJobAsync(IJobExecutionContext context)
ICipherRepository cipherRepository, {
IOptions<AdminSettings> adminSettings, _logger.LogInformation(Constants.BypassFiltersEventId, "Execute job task: DeleteDeletedAsync");
ILogger<DeleteCiphersJob> logger) var deleteDate = DateTime.UtcNow.AddDays(-30);
: base(logger) var daysAgoSetting = (_adminSettings?.DeleteTrashDaysAgo).GetValueOrDefault();
if (daysAgoSetting > 0)
{ {
_cipherRepository = cipherRepository; deleteDate = DateTime.UtcNow.AddDays(-1 * daysAgoSetting);
_adminSettings = adminSettings?.Value;
}
protected async override Task ExecuteJobAsync(IJobExecutionContext context)
{
_logger.LogInformation(Constants.BypassFiltersEventId, "Execute job task: DeleteDeletedAsync");
var deleteDate = DateTime.UtcNow.AddDays(-30);
var daysAgoSetting = (_adminSettings?.DeleteTrashDaysAgo).GetValueOrDefault();
if (daysAgoSetting > 0)
{
deleteDate = DateTime.UtcNow.AddDays(-1 * daysAgoSetting);
}
await _cipherRepository.DeleteDeletedAsync(deleteDate);
_logger.LogInformation(Constants.BypassFiltersEventId, "Finished job task: DeleteDeletedAsync");
} }
await _cipherRepository.DeleteDeletedAsync(deleteDate);
_logger.LogInformation(Constants.BypassFiltersEventId, "Finished job task: DeleteDeletedAsync");
} }
} }

View File

@ -4,38 +4,37 @@ using Bit.Core.Repositories;
using Bit.Core.Services; using Bit.Core.Services;
using Quartz; using Quartz;
namespace Bit.Admin.Jobs namespace Bit.Admin.Jobs;
public class DeleteSendsJob : BaseJob
{ {
public class DeleteSendsJob : BaseJob private readonly ISendRepository _sendRepository;
private readonly IServiceProvider _serviceProvider;
public DeleteSendsJob(
ISendRepository sendRepository,
IServiceProvider serviceProvider,
ILogger<DatabaseExpiredGrantsJob> logger)
: base(logger)
{ {
private readonly ISendRepository _sendRepository; _sendRepository = sendRepository;
private readonly IServiceProvider _serviceProvider; _serviceProvider = serviceProvider;
}
public DeleteSendsJob( protected async override Task ExecuteJobAsync(IJobExecutionContext context)
ISendRepository sendRepository, {
IServiceProvider serviceProvider, var sends = await _sendRepository.GetManyByDeletionDateAsync(DateTime.UtcNow);
ILogger<DatabaseExpiredGrantsJob> logger) _logger.LogInformation(Constants.BypassFiltersEventId, "Deleting {0} sends.", sends.Count);
: base(logger) if (!sends.Any())
{ {
_sendRepository = sendRepository; return;
_serviceProvider = serviceProvider;
} }
using (var scope = _serviceProvider.CreateScope())
protected async override Task ExecuteJobAsync(IJobExecutionContext context)
{ {
var sends = await _sendRepository.GetManyByDeletionDateAsync(DateTime.UtcNow); var sendService = scope.ServiceProvider.GetRequiredService<ISendService>();
_logger.LogInformation(Constants.BypassFiltersEventId, "Deleting {0} sends.", sends.Count); foreach (var send in sends)
if (!sends.Any())
{ {
return; await sendService.DeleteSendAsync(send);
}
using (var scope = _serviceProvider.CreateScope())
{
var sendService = scope.ServiceProvider.GetRequiredService<ISendService>();
foreach (var send in sends)
{
await sendService.DeleteSendAsync(send);
}
} }
} }
} }

View File

@ -3,94 +3,93 @@ using Bit.Core.Jobs;
using Bit.Core.Settings; using Bit.Core.Settings;
using Quartz; using Quartz;
namespace Bit.Admin.Jobs namespace Bit.Admin.Jobs;
public class JobsHostedService : BaseJobsHostedService
{ {
public class JobsHostedService : BaseJobsHostedService public JobsHostedService(
GlobalSettings globalSettings,
IServiceProvider serviceProvider,
ILogger<JobsHostedService> logger,
ILogger<JobListener> listenerLogger)
: base(globalSettings, serviceProvider, logger, listenerLogger) { }
public override async Task StartAsync(CancellationToken cancellationToken)
{ {
public JobsHostedService( var timeZone = RuntimeInformation.IsOSPlatform(OSPlatform.Windows) ?
GlobalSettings globalSettings, TimeZoneInfo.FindSystemTimeZoneById("Eastern Standard Time") :
IServiceProvider serviceProvider, TimeZoneInfo.FindSystemTimeZoneById("America/New_York");
ILogger<JobsHostedService> logger, if (_globalSettings.SelfHosted)
ILogger<JobListener> listenerLogger)
: base(globalSettings, serviceProvider, logger, listenerLogger) { }
public override async Task StartAsync(CancellationToken cancellationToken)
{ {
var timeZone = RuntimeInformation.IsOSPlatform(OSPlatform.Windows) ? timeZone = TimeZoneInfo.Local;
TimeZoneInfo.FindSystemTimeZoneById("Eastern Standard Time") :
TimeZoneInfo.FindSystemTimeZoneById("America/New_York");
if (_globalSettings.SelfHosted)
{
timeZone = TimeZoneInfo.Local;
}
var everyTopOfTheHourTrigger = TriggerBuilder.Create()
.WithIdentity("EveryTopOfTheHourTrigger")
.StartNow()
.WithCronSchedule("0 0 * * * ?")
.Build();
var everyFiveMinutesTrigger = TriggerBuilder.Create()
.WithIdentity("EveryFiveMinutesTrigger")
.StartNow()
.WithCronSchedule("0 */5 * * * ?")
.Build();
var everyFridayAt10pmTrigger = TriggerBuilder.Create()
.WithIdentity("EveryFridayAt10pmTrigger")
.StartNow()
.WithCronSchedule("0 0 22 ? * FRI", x => x.InTimeZone(timeZone))
.Build();
var everySaturdayAtMidnightTrigger = TriggerBuilder.Create()
.WithIdentity("EverySaturdayAtMidnightTrigger")
.StartNow()
.WithCronSchedule("0 0 0 ? * SAT", x => x.InTimeZone(timeZone))
.Build();
var everySundayAtMidnightTrigger = TriggerBuilder.Create()
.WithIdentity("EverySundayAtMidnightTrigger")
.StartNow()
.WithCronSchedule("0 0 0 ? * SUN", x => x.InTimeZone(timeZone))
.Build();
var everyMondayAtMidnightTrigger = TriggerBuilder.Create()
.WithIdentity("EveryMondayAtMidnightTrigger")
.StartNow()
.WithCronSchedule("0 0 0 ? * MON", x => x.InTimeZone(timeZone))
.Build();
var everyDayAtMidnightUtc = TriggerBuilder.Create()
.WithIdentity("EveryDayAtMidnightUtc")
.StartNow()
.WithCronSchedule("0 0 0 * * ?")
.Build();
var jobs = new List<Tuple<Type, ITrigger>>
{
new Tuple<Type, ITrigger>(typeof(DeleteSendsJob), everyFiveMinutesTrigger),
new Tuple<Type, ITrigger>(typeof(DatabaseExpiredGrantsJob), everyFridayAt10pmTrigger),
new Tuple<Type, ITrigger>(typeof(DatabaseUpdateStatisticsJob), everySaturdayAtMidnightTrigger),
new Tuple<Type, ITrigger>(typeof(DatabaseRebuildlIndexesJob), everySundayAtMidnightTrigger),
new Tuple<Type, ITrigger>(typeof(DeleteCiphersJob), everyDayAtMidnightUtc),
new Tuple<Type, ITrigger>(typeof(DatabaseExpiredSponsorshipsJob), everyMondayAtMidnightTrigger)
};
if (!_globalSettings.SelfHosted)
{
jobs.Add(new Tuple<Type, ITrigger>(typeof(AliveJob), everyTopOfTheHourTrigger));
}
Jobs = jobs;
await base.StartAsync(cancellationToken);
} }
public static void AddJobsServices(IServiceCollection services, bool selfHosted) var everyTopOfTheHourTrigger = TriggerBuilder.Create()
.WithIdentity("EveryTopOfTheHourTrigger")
.StartNow()
.WithCronSchedule("0 0 * * * ?")
.Build();
var everyFiveMinutesTrigger = TriggerBuilder.Create()
.WithIdentity("EveryFiveMinutesTrigger")
.StartNow()
.WithCronSchedule("0 */5 * * * ?")
.Build();
var everyFridayAt10pmTrigger = TriggerBuilder.Create()
.WithIdentity("EveryFridayAt10pmTrigger")
.StartNow()
.WithCronSchedule("0 0 22 ? * FRI", x => x.InTimeZone(timeZone))
.Build();
var everySaturdayAtMidnightTrigger = TriggerBuilder.Create()
.WithIdentity("EverySaturdayAtMidnightTrigger")
.StartNow()
.WithCronSchedule("0 0 0 ? * SAT", x => x.InTimeZone(timeZone))
.Build();
var everySundayAtMidnightTrigger = TriggerBuilder.Create()
.WithIdentity("EverySundayAtMidnightTrigger")
.StartNow()
.WithCronSchedule("0 0 0 ? * SUN", x => x.InTimeZone(timeZone))
.Build();
var everyMondayAtMidnightTrigger = TriggerBuilder.Create()
.WithIdentity("EveryMondayAtMidnightTrigger")
.StartNow()
.WithCronSchedule("0 0 0 ? * MON", x => x.InTimeZone(timeZone))
.Build();
var everyDayAtMidnightUtc = TriggerBuilder.Create()
.WithIdentity("EveryDayAtMidnightUtc")
.StartNow()
.WithCronSchedule("0 0 0 * * ?")
.Build();
var jobs = new List<Tuple<Type, ITrigger>>
{ {
if (!selfHosted) new Tuple<Type, ITrigger>(typeof(DeleteSendsJob), everyFiveMinutesTrigger),
{ new Tuple<Type, ITrigger>(typeof(DatabaseExpiredGrantsJob), everyFridayAt10pmTrigger),
services.AddTransient<AliveJob>(); new Tuple<Type, ITrigger>(typeof(DatabaseUpdateStatisticsJob), everySaturdayAtMidnightTrigger),
} new Tuple<Type, ITrigger>(typeof(DatabaseRebuildlIndexesJob), everySundayAtMidnightTrigger),
services.AddTransient<DatabaseUpdateStatisticsJob>(); new Tuple<Type, ITrigger>(typeof(DeleteCiphersJob), everyDayAtMidnightUtc),
services.AddTransient<DatabaseRebuildlIndexesJob>(); new Tuple<Type, ITrigger>(typeof(DatabaseExpiredSponsorshipsJob), everyMondayAtMidnightTrigger)
services.AddTransient<DatabaseExpiredGrantsJob>(); };
services.AddTransient<DatabaseExpiredSponsorshipsJob>();
services.AddTransient<DeleteSendsJob>(); if (!_globalSettings.SelfHosted)
services.AddTransient<DeleteCiphersJob>(); {
jobs.Add(new Tuple<Type, ITrigger>(typeof(AliveJob), everyTopOfTheHourTrigger));
} }
Jobs = jobs;
await base.StartAsync(cancellationToken);
}
public static void AddJobsServices(IServiceCollection services, bool selfHosted)
{
if (!selfHosted)
{
services.AddTransient<AliveJob>();
}
services.AddTransient<DatabaseUpdateStatisticsJob>();
services.AddTransient<DatabaseRebuildlIndexesJob>();
services.AddTransient<DatabaseExpiredGrantsJob>();
services.AddTransient<DatabaseExpiredSponsorshipsJob>();
services.AddTransient<DeleteSendsJob>();
services.AddTransient<DeleteCiphersJob>();
} }
} }

View File

@ -1,11 +1,10 @@
using Bit.Core.Models.Business; using Bit.Core.Models.Business;
namespace Bit.Admin.Models namespace Bit.Admin.Models;
public class BillingInformationModel
{ {
public class BillingInformationModel public BillingInfo BillingInfo { get; set; }
{ public Guid? UserId { get; set; }
public BillingInfo BillingInfo { get; set; } public Guid? OrganizationId { get; set; }
public Guid? UserId { get; set; }
public Guid? OrganizationId { get; set; }
}
} }

View File

@ -1,27 +1,26 @@
using System.ComponentModel.DataAnnotations; using System.ComponentModel.DataAnnotations;
namespace Bit.Admin.Models namespace Bit.Admin.Models;
{
public class ChargeBraintreeModel : IValidatableObject
{
[Required]
[Display(Name = "Braintree Customer Id")]
public string Id { get; set; }
[Required]
[Display(Name = "Amount")]
public decimal? Amount { get; set; }
public string TransactionId { get; set; }
public string PayPalTransactionId { get; set; }
public IEnumerable<ValidationResult> Validate(ValidationContext validationContext) public class ChargeBraintreeModel : IValidatableObject
{
[Required]
[Display(Name = "Braintree Customer Id")]
public string Id { get; set; }
[Required]
[Display(Name = "Amount")]
public decimal? Amount { get; set; }
public string TransactionId { get; set; }
public string PayPalTransactionId { get; set; }
public IEnumerable<ValidationResult> Validate(ValidationContext validationContext)
{
if (Id != null)
{ {
if (Id != null) if (Id.Length != 36 || (Id[0] != 'o' && Id[0] != 'u') ||
!Guid.TryParse(Id.Substring(1, 32), out var guid))
{ {
if (Id.Length != 36 || (Id[0] != 'o' && Id[0] != 'u') || yield return new ValidationResult("Customer Id is not a valid format.");
!Guid.TryParse(Id.Substring(1, 32), out var guid))
{
yield return new ValidationResult("Customer Id is not a valid format.");
}
} }
} }
} }

View File

@ -1,13 +1,12 @@
using System.ComponentModel.DataAnnotations; using System.ComponentModel.DataAnnotations;
namespace Bit.Admin.Models namespace Bit.Admin.Models;
{
public class CreateProviderModel
{
public CreateProviderModel() { }
[Display(Name = "Owner Email")] public class CreateProviderModel
[Required] {
public string OwnerEmail { get; set; } public CreateProviderModel() { }
}
[Display(Name = "Owner Email")]
[Required]
public string OwnerEmail { get; set; }
} }

View File

@ -2,77 +2,76 @@
using Bit.Core.Entities; using Bit.Core.Entities;
using Bit.Core.Enums; using Bit.Core.Enums;
namespace Bit.Admin.Models namespace Bit.Admin.Models;
public class CreateUpdateTransactionModel : IValidatableObject
{ {
public class CreateUpdateTransactionModel : IValidatableObject public CreateUpdateTransactionModel() { }
public CreateUpdateTransactionModel(Transaction transaction)
{ {
public CreateUpdateTransactionModel() { } Edit = true;
UserId = transaction.UserId;
OrganizationId = transaction.OrganizationId;
Amount = transaction.Amount;
RefundedAmount = transaction.RefundedAmount;
Refunded = transaction.Refunded.GetValueOrDefault();
Details = transaction.Details;
Date = transaction.CreationDate;
PaymentMethod = transaction.PaymentMethodType;
Gateway = transaction.Gateway;
GatewayId = transaction.GatewayId;
Type = transaction.Type;
}
public CreateUpdateTransactionModel(Transaction transaction) public bool Edit { get; set; }
[Display(Name = "User Id")]
public Guid? UserId { get; set; }
[Display(Name = "Organization Id")]
public Guid? OrganizationId { get; set; }
[Required]
public decimal? Amount { get; set; }
[Display(Name = "Refunded Amount")]
public decimal? RefundedAmount { get; set; }
public bool Refunded { get; set; }
[Required]
public string Details { get; set; }
[Required]
public DateTime? Date { get; set; }
[Display(Name = "Payment Method")]
public PaymentMethodType? PaymentMethod { get; set; }
public GatewayType? Gateway { get; set; }
[Display(Name = "Gateway Id")]
public string GatewayId { get; set; }
[Required]
public TransactionType? Type { get; set; }
public IEnumerable<ValidationResult> Validate(ValidationContext validationContext)
{
if ((!UserId.HasValue && !OrganizationId.HasValue) || (UserId.HasValue && OrganizationId.HasValue))
{ {
Edit = true; yield return new ValidationResult("Must provide either User Id, or Organization Id.");
UserId = transaction.UserId;
OrganizationId = transaction.OrganizationId;
Amount = transaction.Amount;
RefundedAmount = transaction.RefundedAmount;
Refunded = transaction.Refunded.GetValueOrDefault();
Details = transaction.Details;
Date = transaction.CreationDate;
PaymentMethod = transaction.PaymentMethodType;
Gateway = transaction.Gateway;
GatewayId = transaction.GatewayId;
Type = transaction.Type;
}
public bool Edit { get; set; }
[Display(Name = "User Id")]
public Guid? UserId { get; set; }
[Display(Name = "Organization Id")]
public Guid? OrganizationId { get; set; }
[Required]
public decimal? Amount { get; set; }
[Display(Name = "Refunded Amount")]
public decimal? RefundedAmount { get; set; }
public bool Refunded { get; set; }
[Required]
public string Details { get; set; }
[Required]
public DateTime? Date { get; set; }
[Display(Name = "Payment Method")]
public PaymentMethodType? PaymentMethod { get; set; }
public GatewayType? Gateway { get; set; }
[Display(Name = "Gateway Id")]
public string GatewayId { get; set; }
[Required]
public TransactionType? Type { get; set; }
public IEnumerable<ValidationResult> Validate(ValidationContext validationContext)
{
if ((!UserId.HasValue && !OrganizationId.HasValue) || (UserId.HasValue && OrganizationId.HasValue))
{
yield return new ValidationResult("Must provide either User Id, or Organization Id.");
}
}
public Transaction ToTransaction(Guid? id = null)
{
return new Transaction
{
Id = id.GetValueOrDefault(),
UserId = UserId,
OrganizationId = OrganizationId,
Amount = Amount.Value,
RefundedAmount = RefundedAmount,
Refunded = Refunded ? true : (bool?)null,
Details = Details,
CreationDate = Date.Value,
PaymentMethodType = PaymentMethod,
Gateway = Gateway,
GatewayId = GatewayId,
Type = Type.Value
};
} }
} }
public Transaction ToTransaction(Guid? id = null)
{
return new Transaction
{
Id = id.GetValueOrDefault(),
UserId = UserId,
OrganizationId = OrganizationId,
Amount = Amount.Value,
RefundedAmount = RefundedAmount,
Refunded = Refunded ? true : (bool?)null,
Details = Details,
CreationDate = Date.Value,
PaymentMethodType = PaymentMethod,
Gateway = Gateway,
GatewayId = GatewayId,
Type = Type.Value
};
}
} }

View File

@ -1,10 +1,9 @@
namespace Bit.Admin.Models namespace Bit.Admin.Models;
public class CursorPagedModel<T>
{ {
public class CursorPagedModel<T> public List<T> Items { get; set; }
{ public int Count { get; set; }
public List<T> Items { get; set; } public string Cursor { get; set; }
public int Count { get; set; } public string NextCursor { get; set; }
public string Cursor { get; set; }
public string NextCursor { get; set; }
}
} }

View File

@ -1,9 +1,8 @@
namespace Bit.Admin.Models namespace Bit.Admin.Models;
{
public class ErrorViewModel
{
public string RequestId { get; set; }
public bool ShowRequestId => !string.IsNullOrEmpty(RequestId); public class ErrorViewModel
} {
public string RequestId { get; set; }
public bool ShowRequestId => !string.IsNullOrEmpty(RequestId);
} }

View File

@ -1,10 +1,9 @@
using Bit.Core.Settings; using Bit.Core.Settings;
namespace Bit.Admin.Models namespace Bit.Admin.Models;
public class HomeModel
{ {
public class HomeModel public string CurrentVersion { get; set; }
{ public GlobalSettings GlobalSettings { get; set; }
public string CurrentVersion { get; set; }
public GlobalSettings GlobalSettings { get; set; }
}
} }

View File

@ -1,35 +1,34 @@
using System.ComponentModel.DataAnnotations; using System.ComponentModel.DataAnnotations;
namespace Bit.Admin.Models namespace Bit.Admin.Models;
public class LicenseModel : IValidatableObject
{ {
public class LicenseModel : IValidatableObject [Display(Name = "User Id")]
public Guid? UserId { get; set; }
[Display(Name = "Organization Id")]
public Guid? OrganizationId { get; set; }
[Display(Name = "Installation Id")]
public Guid? InstallationId { get; set; }
[Required]
[Display(Name = "Version")]
public int Version { get; set; }
public IEnumerable<ValidationResult> Validate(ValidationContext validationContext)
{ {
[Display(Name = "User Id")] if (UserId.HasValue && OrganizationId.HasValue)
public Guid? UserId { get; set; }
[Display(Name = "Organization Id")]
public Guid? OrganizationId { get; set; }
[Display(Name = "Installation Id")]
public Guid? InstallationId { get; set; }
[Required]
[Display(Name = "Version")]
public int Version { get; set; }
public IEnumerable<ValidationResult> Validate(ValidationContext validationContext)
{ {
if (UserId.HasValue && OrganizationId.HasValue) yield return new ValidationResult("Use either User Id or Organization Id. Not both.");
{ }
yield return new ValidationResult("Use either User Id or Organization Id. Not both.");
}
if (!UserId.HasValue && !OrganizationId.HasValue) if (!UserId.HasValue && !OrganizationId.HasValue)
{ {
yield return new ValidationResult("User Id or Organization Id is required."); yield return new ValidationResult("User Id or Organization Id is required.");
} }
if (OrganizationId.HasValue && !InstallationId.HasValue) if (OrganizationId.HasValue && !InstallationId.HasValue)
{ {
yield return new ValidationResult("Installation Id is required for organization licenses."); yield return new ValidationResult("Installation Id is required for organization licenses.");
}
} }
} }
} }

View File

@ -1,55 +1,54 @@
using Microsoft.Azure.Documents; using Microsoft.Azure.Documents;
using Newtonsoft.Json.Linq; using Newtonsoft.Json.Linq;
namespace Bit.Admin.Models namespace Bit.Admin.Models;
public class LogModel : Resource
{ {
public class LogModel : Resource public long EventIdHash { get; set; }
{ public string Level { get; set; }
public long EventIdHash { get; set; } public string Message { get; set; }
public string Level { get; set; } public string MessageTruncated => Message.Length > 200 ? $"{Message.Substring(0, 200)}..." : Message;
public string Message { get; set; } public string MessageTemplate { get; set; }
public string MessageTruncated => Message.Length > 200 ? $"{Message.Substring(0, 200)}..." : Message; public IDictionary<string, object> Properties { get; set; }
public string MessageTemplate { get; set; } public string Project => Properties?.ContainsKey("Project") ?? false ? Properties["Project"].ToString() : null;
public IDictionary<string, object> Properties { get; set; } }
public string Project => Properties?.ContainsKey("Project") ?? false ? Properties["Project"].ToString() : null;
}
public class LogDetailsModel : LogModel public class LogDetailsModel : LogModel
{ {
public JObject Exception { get; set; } public JObject Exception { get; set; }
public string ExceptionToString(JObject e) public string ExceptionToString(JObject e)
{
if (e == null)
{ {
if (e == null) return null;
{
return null;
}
var val = string.Empty;
if (e["Message"] != null && e["Message"].ToObject<string>() != null)
{
val += "Message:\n";
val += e["Message"] + "\n";
}
if (e["StackTrace"] != null && e["StackTrace"].ToObject<string>() != null)
{
val += "\nStack Trace:\n";
val += e["StackTrace"];
}
else if (e["StackTraceString"] != null && e["StackTraceString"].ToObject<string>() != null)
{
val += "\nStack Trace String:\n";
val += e["StackTraceString"];
}
if (e["InnerException"] != null && e["InnerException"].ToObject<JObject>() != null)
{
val += "\n\n=== Inner Exception ===\n\n";
val += ExceptionToString(e["InnerException"].ToObject<JObject>());
}
return val;
} }
var val = string.Empty;
if (e["Message"] != null && e["Message"].ToObject<string>() != null)
{
val += "Message:\n";
val += e["Message"] + "\n";
}
if (e["StackTrace"] != null && e["StackTrace"].ToObject<string>() != null)
{
val += "\nStack Trace:\n";
val += e["StackTrace"];
}
else if (e["StackTraceString"] != null && e["StackTraceString"].ToObject<string>() != null)
{
val += "\nStack Trace String:\n";
val += e["StackTraceString"];
}
if (e["InnerException"] != null && e["InnerException"].ToObject<JObject>() != null)
{
val += "\n\n=== Inner Exception ===\n\n";
val += ExceptionToString(e["InnerException"].ToObject<JObject>());
}
return val;
} }
} }

View File

@ -1,14 +1,13 @@
using System.ComponentModel.DataAnnotations; using System.ComponentModel.DataAnnotations;
namespace Bit.Admin.Models namespace Bit.Admin.Models;
public class LoginModel
{ {
public class LoginModel [Required]
{ [EmailAddress]
[Required] public string Email { get; set; }
[EmailAddress] public string ReturnUrl { get; set; }
public string Email { get; set; } public string Error { get; set; }
public string ReturnUrl { get; set; } public string Success { get; set; }
public string Error { get; set; }
public string Success { get; set; }
}
} }

View File

@ -1,12 +1,11 @@
using Serilog.Events; using Serilog.Events;
namespace Bit.Admin.Models namespace Bit.Admin.Models;
public class LogsModel : CursorPagedModel<LogModel>
{ {
public class LogsModel : CursorPagedModel<LogModel> public LogEventLevel? Level { get; set; }
{ public string Project { get; set; }
public LogEventLevel? Level { get; set; } public DateTime? Start { get; set; }
public string Project { get; set; } public DateTime? End { get; set; }
public DateTime? Start { get; set; }
public DateTime? End { get; set; }
}
} }

View File

@ -6,148 +6,147 @@ using Bit.Core.Models.Data.Organizations.OrganizationUsers;
using Bit.Core.Settings; using Bit.Core.Settings;
using Bit.Core.Utilities; using Bit.Core.Utilities;
namespace Bit.Admin.Models namespace Bit.Admin.Models;
public class OrganizationEditModel : OrganizationViewModel
{ {
public class OrganizationEditModel : OrganizationViewModel public OrganizationEditModel() { }
public OrganizationEditModel(Organization org, IEnumerable<OrganizationUserUserDetails> orgUsers,
IEnumerable<Cipher> ciphers, IEnumerable<Collection> collections, IEnumerable<Group> groups,
IEnumerable<Policy> policies, BillingInfo billingInfo, IEnumerable<OrganizationConnection> connections,
GlobalSettings globalSettings)
: base(org, connections, orgUsers, ciphers, collections, groups, policies)
{ {
public OrganizationEditModel() { } BillingInfo = billingInfo;
BraintreeMerchantId = globalSettings.Braintree.MerchantId;
public OrganizationEditModel(Organization org, IEnumerable<OrganizationUserUserDetails> orgUsers, Name = org.Name;
IEnumerable<Cipher> ciphers, IEnumerable<Collection> collections, IEnumerable<Group> groups, BusinessName = org.BusinessName;
IEnumerable<Policy> policies, BillingInfo billingInfo, IEnumerable<OrganizationConnection> connections, BillingEmail = org.BillingEmail;
GlobalSettings globalSettings) PlanType = org.PlanType;
: base(org, connections, orgUsers, ciphers, collections, groups, policies) Plan = org.Plan;
{ Seats = org.Seats;
BillingInfo = billingInfo; MaxAutoscaleSeats = org.MaxAutoscaleSeats;
BraintreeMerchantId = globalSettings.Braintree.MerchantId; MaxCollections = org.MaxCollections;
UsePolicies = org.UsePolicies;
UseSso = org.UseSso;
UseKeyConnector = org.UseKeyConnector;
UseScim = org.UseScim;
UseGroups = org.UseGroups;
UseDirectory = org.UseDirectory;
UseEvents = org.UseEvents;
UseTotp = org.UseTotp;
Use2fa = org.Use2fa;
UseApi = org.UseApi;
UseResetPassword = org.UseResetPassword;
SelfHost = org.SelfHost;
UsersGetPremium = org.UsersGetPremium;
MaxStorageGb = org.MaxStorageGb;
Gateway = org.Gateway;
GatewayCustomerId = org.GatewayCustomerId;
GatewaySubscriptionId = org.GatewaySubscriptionId;
Enabled = org.Enabled;
LicenseKey = org.LicenseKey;
ExpirationDate = org.ExpirationDate;
}
Name = org.Name; public BillingInfo BillingInfo { get; set; }
BusinessName = org.BusinessName; public string RandomLicenseKey => CoreHelpers.SecureRandomString(20);
BillingEmail = org.BillingEmail; public string FourteenDayExpirationDate => DateTime.Now.AddDays(14).ToString("yyyy-MM-ddTHH:mm");
PlanType = org.PlanType; public string BraintreeMerchantId { get; set; }
Plan = org.Plan;
Seats = org.Seats;
MaxAutoscaleSeats = org.MaxAutoscaleSeats;
MaxCollections = org.MaxCollections;
UsePolicies = org.UsePolicies;
UseSso = org.UseSso;
UseKeyConnector = org.UseKeyConnector;
UseScim = org.UseScim;
UseGroups = org.UseGroups;
UseDirectory = org.UseDirectory;
UseEvents = org.UseEvents;
UseTotp = org.UseTotp;
Use2fa = org.Use2fa;
UseApi = org.UseApi;
UseResetPassword = org.UseResetPassword;
SelfHost = org.SelfHost;
UsersGetPremium = org.UsersGetPremium;
MaxStorageGb = org.MaxStorageGb;
Gateway = org.Gateway;
GatewayCustomerId = org.GatewayCustomerId;
GatewaySubscriptionId = org.GatewaySubscriptionId;
Enabled = org.Enabled;
LicenseKey = org.LicenseKey;
ExpirationDate = org.ExpirationDate;
}
public BillingInfo BillingInfo { get; set; } [Required]
public string RandomLicenseKey => CoreHelpers.SecureRandomString(20); [Display(Name = "Name")]
public string FourteenDayExpirationDate => DateTime.Now.AddDays(14).ToString("yyyy-MM-ddTHH:mm"); public string Name { get; set; }
public string BraintreeMerchantId { get; set; } [Display(Name = "Business Name")]
public string BusinessName { get; set; }
[Display(Name = "Billing Email")]
public string BillingEmail { get; set; }
[Required]
[Display(Name = "Plan")]
public PlanType? PlanType { get; set; }
[Required]
[Display(Name = "Plan Name")]
public string Plan { get; set; }
[Display(Name = "Seats")]
public int? Seats { get; set; }
[Display(Name = "Max. Autoscale Seats")]
public int? MaxAutoscaleSeats { get; set; }
[Display(Name = "Max. Collections")]
public short? MaxCollections { get; set; }
[Display(Name = "Policies")]
public bool UsePolicies { get; set; }
[Display(Name = "SSO")]
public bool UseSso { get; set; }
[Display(Name = "Key Connector with Customer Encryption")]
public bool UseKeyConnector { get; set; }
[Display(Name = "Groups")]
public bool UseGroups { get; set; }
[Display(Name = "Directory")]
public bool UseDirectory { get; set; }
[Display(Name = "Events")]
public bool UseEvents { get; set; }
[Display(Name = "TOTP")]
public bool UseTotp { get; set; }
[Display(Name = "2FA")]
public bool Use2fa { get; set; }
[Display(Name = "API")]
public bool UseApi { get; set; }
[Display(Name = "Reset Password")]
public bool UseResetPassword { get; set; }
[Display(Name = "SCIM")]
public bool UseScim { get; set; }
[Display(Name = "Self Host")]
public bool SelfHost { get; set; }
[Display(Name = "Users Get Premium")]
public bool UsersGetPremium { get; set; }
[Display(Name = "Max. Storage GB")]
public short? MaxStorageGb { get; set; }
[Display(Name = "Gateway")]
public GatewayType? Gateway { get; set; }
[Display(Name = "Gateway Customer Id")]
public string GatewayCustomerId { get; set; }
[Display(Name = "Gateway Subscription Id")]
public string GatewaySubscriptionId { get; set; }
[Display(Name = "Enabled")]
public bool Enabled { get; set; }
[Display(Name = "License Key")]
public string LicenseKey { get; set; }
[Display(Name = "Expiration Date")]
public DateTime? ExpirationDate { get; set; }
public bool SalesAssistedTrialStarted { get; set; }
[Required] public Organization ToOrganization(Organization existingOrganization)
[Display(Name = "Name")] {
public string Name { get; set; } existingOrganization.Name = Name;
[Display(Name = "Business Name")] existingOrganization.BusinessName = BusinessName;
public string BusinessName { get; set; } existingOrganization.BillingEmail = BillingEmail?.ToLowerInvariant()?.Trim();
[Display(Name = "Billing Email")] existingOrganization.PlanType = PlanType.Value;
public string BillingEmail { get; set; } existingOrganization.Plan = Plan;
[Required] existingOrganization.Seats = Seats;
[Display(Name = "Plan")] existingOrganization.MaxCollections = MaxCollections;
public PlanType? PlanType { get; set; } existingOrganization.UsePolicies = UsePolicies;
[Required] existingOrganization.UseSso = UseSso;
[Display(Name = "Plan Name")] existingOrganization.UseKeyConnector = UseKeyConnector;
public string Plan { get; set; } existingOrganization.UseScim = UseScim;
[Display(Name = "Seats")] existingOrganization.UseGroups = UseGroups;
public int? Seats { get; set; } existingOrganization.UseDirectory = UseDirectory;
[Display(Name = "Max. Autoscale Seats")] existingOrganization.UseEvents = UseEvents;
public int? MaxAutoscaleSeats { get; set; } existingOrganization.UseTotp = UseTotp;
[Display(Name = "Max. Collections")] existingOrganization.Use2fa = Use2fa;
public short? MaxCollections { get; set; } existingOrganization.UseApi = UseApi;
[Display(Name = "Policies")] existingOrganization.UseResetPassword = UseResetPassword;
public bool UsePolicies { get; set; } existingOrganization.SelfHost = SelfHost;
[Display(Name = "SSO")] existingOrganization.UsersGetPremium = UsersGetPremium;
public bool UseSso { get; set; } existingOrganization.MaxStorageGb = MaxStorageGb;
[Display(Name = "Key Connector with Customer Encryption")] existingOrganization.Gateway = Gateway;
public bool UseKeyConnector { get; set; } existingOrganization.GatewayCustomerId = GatewayCustomerId;
[Display(Name = "Groups")] existingOrganization.GatewaySubscriptionId = GatewaySubscriptionId;
public bool UseGroups { get; set; } existingOrganization.Enabled = Enabled;
[Display(Name = "Directory")] existingOrganization.LicenseKey = LicenseKey;
public bool UseDirectory { get; set; } existingOrganization.ExpirationDate = ExpirationDate;
[Display(Name = "Events")] existingOrganization.MaxAutoscaleSeats = MaxAutoscaleSeats;
public bool UseEvents { get; set; } return existingOrganization;
[Display(Name = "TOTP")]
public bool UseTotp { get; set; }
[Display(Name = "2FA")]
public bool Use2fa { get; set; }
[Display(Name = "API")]
public bool UseApi { get; set; }
[Display(Name = "Reset Password")]
public bool UseResetPassword { get; set; }
[Display(Name = "SCIM")]
public bool UseScim { get; set; }
[Display(Name = "Self Host")]
public bool SelfHost { get; set; }
[Display(Name = "Users Get Premium")]
public bool UsersGetPremium { get; set; }
[Display(Name = "Max. Storage GB")]
public short? MaxStorageGb { get; set; }
[Display(Name = "Gateway")]
public GatewayType? Gateway { get; set; }
[Display(Name = "Gateway Customer Id")]
public string GatewayCustomerId { get; set; }
[Display(Name = "Gateway Subscription Id")]
public string GatewaySubscriptionId { get; set; }
[Display(Name = "Enabled")]
public bool Enabled { get; set; }
[Display(Name = "License Key")]
public string LicenseKey { get; set; }
[Display(Name = "Expiration Date")]
public DateTime? ExpirationDate { get; set; }
public bool SalesAssistedTrialStarted { get; set; }
public Organization ToOrganization(Organization existingOrganization)
{
existingOrganization.Name = Name;
existingOrganization.BusinessName = BusinessName;
existingOrganization.BillingEmail = BillingEmail?.ToLowerInvariant()?.Trim();
existingOrganization.PlanType = PlanType.Value;
existingOrganization.Plan = Plan;
existingOrganization.Seats = Seats;
existingOrganization.MaxCollections = MaxCollections;
existingOrganization.UsePolicies = UsePolicies;
existingOrganization.UseSso = UseSso;
existingOrganization.UseKeyConnector = UseKeyConnector;
existingOrganization.UseScim = UseScim;
existingOrganization.UseGroups = UseGroups;
existingOrganization.UseDirectory = UseDirectory;
existingOrganization.UseEvents = UseEvents;
existingOrganization.UseTotp = UseTotp;
existingOrganization.Use2fa = Use2fa;
existingOrganization.UseApi = UseApi;
existingOrganization.UseResetPassword = UseResetPassword;
existingOrganization.SelfHost = SelfHost;
existingOrganization.UsersGetPremium = UsersGetPremium;
existingOrganization.MaxStorageGb = MaxStorageGb;
existingOrganization.Gateway = Gateway;
existingOrganization.GatewayCustomerId = GatewayCustomerId;
existingOrganization.GatewaySubscriptionId = GatewaySubscriptionId;
existingOrganization.Enabled = Enabled;
existingOrganization.LicenseKey = LicenseKey;
existingOrganization.ExpirationDate = ExpirationDate;
existingOrganization.MaxAutoscaleSeats = MaxAutoscaleSeats;
return existingOrganization;
}
} }
} }

View File

@ -2,49 +2,48 @@
using Bit.Core.Enums; using Bit.Core.Enums;
using Bit.Core.Models.Data.Organizations.OrganizationUsers; using Bit.Core.Models.Data.Organizations.OrganizationUsers;
namespace Bit.Admin.Models namespace Bit.Admin.Models;
public class OrganizationViewModel
{ {
public class OrganizationViewModel public OrganizationViewModel() { }
public OrganizationViewModel(Organization org, IEnumerable<OrganizationConnection> connections,
IEnumerable<OrganizationUserUserDetails> orgUsers, IEnumerable<Cipher> ciphers, IEnumerable<Collection> collections,
IEnumerable<Group> groups, IEnumerable<Policy> policies)
{ {
public OrganizationViewModel() { } Organization = org;
Connections = connections ?? Enumerable.Empty<OrganizationConnection>();
public OrganizationViewModel(Organization org, IEnumerable<OrganizationConnection> connections, HasPublicPrivateKeys = org.PublicKey != null && org.PrivateKey != null;
IEnumerable<OrganizationUserUserDetails> orgUsers, IEnumerable<Cipher> ciphers, IEnumerable<Collection> collections, UserInvitedCount = orgUsers.Count(u => u.Status == OrganizationUserStatusType.Invited);
IEnumerable<Group> groups, IEnumerable<Policy> policies) UserAcceptedCount = orgUsers.Count(u => u.Status == OrganizationUserStatusType.Accepted);
{ UserConfirmedCount = orgUsers.Count(u => u.Status == OrganizationUserStatusType.Confirmed);
Organization = org; UserCount = orgUsers.Count();
Connections = connections ?? Enumerable.Empty<OrganizationConnection>(); CipherCount = ciphers.Count();
HasPublicPrivateKeys = org.PublicKey != null && org.PrivateKey != null; CollectionCount = collections.Count();
UserInvitedCount = orgUsers.Count(u => u.Status == OrganizationUserStatusType.Invited); GroupCount = groups?.Count() ?? 0;
UserAcceptedCount = orgUsers.Count(u => u.Status == OrganizationUserStatusType.Accepted); PolicyCount = policies?.Count() ?? 0;
UserConfirmedCount = orgUsers.Count(u => u.Status == OrganizationUserStatusType.Confirmed); Owners = string.Join(", ",
UserCount = orgUsers.Count(); orgUsers
CipherCount = ciphers.Count(); .Where(u => u.Type == OrganizationUserType.Owner && u.Status == OrganizationUserStatusType.Confirmed)
CollectionCount = collections.Count(); .Select(u => u.Email));
GroupCount = groups?.Count() ?? 0; Admins = string.Join(", ",
PolicyCount = policies?.Count() ?? 0; orgUsers
Owners = string.Join(", ", .Where(u => u.Type == OrganizationUserType.Admin && u.Status == OrganizationUserStatusType.Confirmed)
orgUsers .Select(u => u.Email));
.Where(u => u.Type == OrganizationUserType.Owner && u.Status == OrganizationUserStatusType.Confirmed)
.Select(u => u.Email));
Admins = string.Join(", ",
orgUsers
.Where(u => u.Type == OrganizationUserType.Admin && u.Status == OrganizationUserStatusType.Confirmed)
.Select(u => u.Email));
}
public Organization Organization { get; set; }
public IEnumerable<OrganizationConnection> Connections { get; set; }
public string Owners { get; set; }
public string Admins { get; set; }
public int UserInvitedCount { get; set; }
public int UserConfirmedCount { get; set; }
public int UserAcceptedCount { get; set; }
public int UserCount { get; set; }
public int CipherCount { get; set; }
public int CollectionCount { get; set; }
public int GroupCount { get; set; }
public int PolicyCount { get; set; }
public bool HasPublicPrivateKeys { get; set; }
} }
public Organization Organization { get; set; }
public IEnumerable<OrganizationConnection> Connections { get; set; }
public string Owners { get; set; }
public string Admins { get; set; }
public int UserInvitedCount { get; set; }
public int UserConfirmedCount { get; set; }
public int UserAcceptedCount { get; set; }
public int UserCount { get; set; }
public int CipherCount { get; set; }
public int CollectionCount { get; set; }
public int GroupCount { get; set; }
public int PolicyCount { get; set; }
public bool HasPublicPrivateKeys { get; set; }
} }

View File

@ -1,13 +1,12 @@
using Bit.Core.Entities; using Bit.Core.Entities;
namespace Bit.Admin.Models namespace Bit.Admin.Models;
public class OrganizationsModel : PagedModel<Organization>
{ {
public class OrganizationsModel : PagedModel<Organization> public string Name { get; set; }
{ public string UserEmail { get; set; }
public string Name { get; set; } public bool? Paid { get; set; }
public string UserEmail { get; set; } public string Action { get; set; }
public bool? Paid { get; set; } public bool SelfHosted { get; set; }
public string Action { get; set; }
public bool SelfHosted { get; set; }
}
} }

View File

@ -1,11 +1,10 @@
namespace Bit.Admin.Models namespace Bit.Admin.Models;
public abstract class PagedModel<T>
{ {
public abstract class PagedModel<T> public List<T> Items { get; set; }
{ public int Page { get; set; }
public List<T> Items { get; set; } public int Count { get; set; }
public int Page { get; set; } public int? PreviousPage => Page < 2 ? (int?)null : Page - 1;
public int Count { get; set; } public int? NextPage => Items.Count < Count ? (int?)null : Page + 1;
public int? PreviousPage => Page < 2 ? (int?)null : Page - 1;
public int? NextPage => Items.Count < Count ? (int?)null : Page + 1;
}
} }

View File

@ -1,14 +1,13 @@
using System.ComponentModel.DataAnnotations; using System.ComponentModel.DataAnnotations;
namespace Bit.Admin.Models namespace Bit.Admin.Models;
public class PromoteAdminModel
{ {
public class PromoteAdminModel [Required]
{ [Display(Name = "Admin User Id")]
[Required] public Guid? UserId { get; set; }
[Display(Name = "Admin User Id")] [Required]
public Guid? UserId { get; set; } [Display(Name = "Organization Id")]
[Required] public Guid? OrganizationId { get; set; }
[Display(Name = "Organization Id")]
public Guid? OrganizationId { get; set; }
}
} }

View File

@ -2,33 +2,32 @@
using Bit.Core.Entities.Provider; using Bit.Core.Entities.Provider;
using Bit.Core.Models.Data; using Bit.Core.Models.Data;
namespace Bit.Admin.Models namespace Bit.Admin.Models;
public class ProviderEditModel : ProviderViewModel
{ {
public class ProviderEditModel : ProviderViewModel public ProviderEditModel() { }
public ProviderEditModel(Provider provider, IEnumerable<ProviderUserUserDetails> providerUsers, IEnumerable<ProviderOrganizationOrganizationDetails> organizations)
: base(provider, providerUsers, organizations)
{ {
public ProviderEditModel() { } Name = provider.Name;
BusinessName = provider.BusinessName;
BillingEmail = provider.BillingEmail;
}
public ProviderEditModel(Provider provider, IEnumerable<ProviderUserUserDetails> providerUsers, IEnumerable<ProviderOrganizationOrganizationDetails> organizations) [Display(Name = "Billing Email")]
: base(provider, providerUsers, organizations) public string BillingEmail { get; set; }
{ [Display(Name = "Business Name")]
Name = provider.Name; public string BusinessName { get; set; }
BusinessName = provider.BusinessName; public string Name { get; set; }
BillingEmail = provider.BillingEmail; [Display(Name = "Events")]
}
[Display(Name = "Billing Email")] public Provider ToProvider(Provider existingProvider)
public string BillingEmail { get; set; } {
[Display(Name = "Business Name")] existingProvider.Name = Name;
public string BusinessName { get; set; } existingProvider.BusinessName = BusinessName;
public string Name { get; set; } existingProvider.BillingEmail = BillingEmail?.ToLowerInvariant()?.Trim();
[Display(Name = "Events")] return existingProvider;
public Provider ToProvider(Provider existingProvider)
{
existingProvider.Name = Name;
existingProvider.BusinessName = BusinessName;
existingProvider.BillingEmail = BillingEmail?.ToLowerInvariant()?.Trim();
return existingProvider;
}
} }
} }

View File

@ -2,24 +2,23 @@
using Bit.Core.Enums.Provider; using Bit.Core.Enums.Provider;
using Bit.Core.Models.Data; using Bit.Core.Models.Data;
namespace Bit.Admin.Models namespace Bit.Admin.Models;
public class ProviderViewModel
{ {
public class ProviderViewModel public ProviderViewModel() { }
public ProviderViewModel(Provider provider, IEnumerable<ProviderUserUserDetails> providerUsers, IEnumerable<ProviderOrganizationOrganizationDetails> organizations)
{ {
public ProviderViewModel() { } Provider = provider;
UserCount = providerUsers.Count();
ProviderAdmins = providerUsers.Where(u => u.Type == ProviderUserType.ProviderAdmin);
public ProviderViewModel(Provider provider, IEnumerable<ProviderUserUserDetails> providerUsers, IEnumerable<ProviderOrganizationOrganizationDetails> organizations) ProviderOrganizations = organizations.Where(o => o.ProviderId == provider.Id);
{
Provider = provider;
UserCount = providerUsers.Count();
ProviderAdmins = providerUsers.Where(u => u.Type == ProviderUserType.ProviderAdmin);
ProviderOrganizations = organizations.Where(o => o.ProviderId == provider.Id);
}
public int UserCount { get; set; }
public Provider Provider { get; set; }
public IEnumerable<ProviderUserUserDetails> ProviderAdmins { get; set; }
public IEnumerable<ProviderOrganizationOrganizationDetails> ProviderOrganizations { get; set; }
} }
public int UserCount { get; set; }
public Provider Provider { get; set; }
public IEnumerable<ProviderUserUserDetails> ProviderAdmins { get; set; }
public IEnumerable<ProviderOrganizationOrganizationDetails> ProviderOrganizations { get; set; }
} }

View File

@ -1,13 +1,12 @@
using Bit.Core.Entities.Provider; using Bit.Core.Entities.Provider;
namespace Bit.Admin.Models namespace Bit.Admin.Models;
public class ProvidersModel : PagedModel<Provider>
{ {
public class ProvidersModel : PagedModel<Provider> public string Name { get; set; }
{ public string UserEmail { get; set; }
public string Name { get; set; } public bool? Paid { get; set; }
public string UserEmail { get; set; } public string Action { get; set; }
public bool? Paid { get; set; } public bool SelfHosted { get; set; }
public string Action { get; set; }
public bool SelfHosted { get; set; }
}
} }

View File

@ -1,43 +1,42 @@
using System.ComponentModel.DataAnnotations; using System.ComponentModel.DataAnnotations;
using Bit.Core.Models.BitStripe; using Bit.Core.Models.BitStripe;
namespace Bit.Admin.Models namespace Bit.Admin.Models;
public class StripeSubscriptionRowModel
{ {
public class StripeSubscriptionRowModel public Stripe.Subscription Subscription { get; set; }
{ public bool Selected { get; set; }
public Stripe.Subscription Subscription { get; set; }
public bool Selected { get; set; }
public StripeSubscriptionRowModel() { } public StripeSubscriptionRowModel() { }
public StripeSubscriptionRowModel(Stripe.Subscription subscription) public StripeSubscriptionRowModel(Stripe.Subscription subscription)
{ {
Subscription = subscription; Subscription = subscription;
}
} }
}
public enum StripeSubscriptionsAction public enum StripeSubscriptionsAction
{ {
Search, Search,
PreviousPage, PreviousPage,
NextPage, NextPage,
Export, Export,
BulkCancel BulkCancel
} }
public class StripeSubscriptionsModel : IValidatableObject public class StripeSubscriptionsModel : IValidatableObject
{
public List<StripeSubscriptionRowModel> Items { get; set; }
public StripeSubscriptionsAction Action { get; set; } = StripeSubscriptionsAction.Search;
public string Message { get; set; }
public List<Stripe.Price> Prices { get; set; }
public List<Stripe.TestHelpers.TestClock> TestClocks { get; set; }
public StripeSubscriptionListOptions Filter { get; set; } = new StripeSubscriptionListOptions();
public IEnumerable<ValidationResult> Validate(ValidationContext validationContext)
{ {
public List<StripeSubscriptionRowModel> Items { get; set; } if (Action == StripeSubscriptionsAction.BulkCancel && Filter.Status != "unpaid")
public StripeSubscriptionsAction Action { get; set; } = StripeSubscriptionsAction.Search;
public string Message { get; set; }
public List<Stripe.Price> Prices { get; set; }
public List<Stripe.TestHelpers.TestClock> TestClocks { get; set; }
public StripeSubscriptionListOptions Filter { get; set; } = new StripeSubscriptionListOptions();
public IEnumerable<ValidationResult> Validate(ValidationContext validationContext)
{ {
if (Action == StripeSubscriptionsAction.BulkCancel && Filter.Status != "unpaid") yield return new ValidationResult("Bulk cancel is currently only supported for unpaid subscriptions");
{
yield return new ValidationResult("Bulk cancel is currently only supported for unpaid subscriptions");
}
} }
} }
} }

View File

@ -1,11 +1,10 @@
namespace Bit.Admin.Models namespace Bit.Admin.Models;
public class TaxRateAddEditModel
{ {
public class TaxRateAddEditModel public string StripeTaxRateId { get; set; }
{ public string Country { get; set; }
public string StripeTaxRateId { get; set; } public string State { get; set; }
public string Country { get; set; } public string PostalCode { get; set; }
public string State { get; set; } public decimal Rate { get; set; }
public string PostalCode { get; set; }
public decimal Rate { get; set; }
}
} }

View File

@ -1,9 +1,8 @@
using Bit.Core.Entities; using Bit.Core.Entities;
namespace Bit.Admin.Models namespace Bit.Admin.Models;
public class TaxRatesModel : PagedModel<TaxRate>
{ {
public class TaxRatesModel : PagedModel<TaxRate> public string Message { get; set; }
{
public string Message { get; set; }
}
} }

View File

@ -4,71 +4,70 @@ using Bit.Core.Models.Business;
using Bit.Core.Settings; using Bit.Core.Settings;
using Bit.Core.Utilities; using Bit.Core.Utilities;
namespace Bit.Admin.Models namespace Bit.Admin.Models;
public class UserEditModel : UserViewModel
{ {
public class UserEditModel : UserViewModel public UserEditModel() { }
public UserEditModel(User user, IEnumerable<Cipher> ciphers, BillingInfo billingInfo,
GlobalSettings globalSettings)
: base(user, ciphers)
{ {
public UserEditModel() { } BillingInfo = billingInfo;
BraintreeMerchantId = globalSettings.Braintree.MerchantId;
public UserEditModel(User user, IEnumerable<Cipher> ciphers, BillingInfo billingInfo, Name = user.Name;
GlobalSettings globalSettings) Email = user.Email;
: base(user, ciphers) EmailVerified = user.EmailVerified;
{ Premium = user.Premium;
BillingInfo = billingInfo; MaxStorageGb = user.MaxStorageGb;
BraintreeMerchantId = globalSettings.Braintree.MerchantId; Gateway = user.Gateway;
GatewayCustomerId = user.GatewayCustomerId;
GatewaySubscriptionId = user.GatewaySubscriptionId;
LicenseKey = user.LicenseKey;
PremiumExpirationDate = user.PremiumExpirationDate;
}
Name = user.Name; public BillingInfo BillingInfo { get; set; }
Email = user.Email; public string RandomLicenseKey => CoreHelpers.SecureRandomString(20);
EmailVerified = user.EmailVerified; public string OneYearExpirationDate => DateTime.Now.AddYears(1).ToString("yyyy-MM-ddTHH:mm");
Premium = user.Premium; public string BraintreeMerchantId { get; set; }
MaxStorageGb = user.MaxStorageGb;
Gateway = user.Gateway;
GatewayCustomerId = user.GatewayCustomerId;
GatewaySubscriptionId = user.GatewaySubscriptionId;
LicenseKey = user.LicenseKey;
PremiumExpirationDate = user.PremiumExpirationDate;
}
public BillingInfo BillingInfo { get; set; } [Display(Name = "Name")]
public string RandomLicenseKey => CoreHelpers.SecureRandomString(20); public string Name { get; set; }
public string OneYearExpirationDate => DateTime.Now.AddYears(1).ToString("yyyy-MM-ddTHH:mm"); [Required]
public string BraintreeMerchantId { get; set; } [Display(Name = "Email")]
public string Email { get; set; }
[Display(Name = "Email Verified")]
public bool EmailVerified { get; set; }
[Display(Name = "Premium")]
public bool Premium { get; set; }
[Display(Name = "Max. Storage GB")]
public short? MaxStorageGb { get; set; }
[Display(Name = "Gateway")]
public Core.Enums.GatewayType? Gateway { get; set; }
[Display(Name = "Gateway Customer Id")]
public string GatewayCustomerId { get; set; }
[Display(Name = "Gateway Subscription Id")]
public string GatewaySubscriptionId { get; set; }
[Display(Name = "License Key")]
public string LicenseKey { get; set; }
[Display(Name = "Premium Expiration Date")]
public DateTime? PremiumExpirationDate { get; set; }
[Display(Name = "Name")] public User ToUser(User existingUser)
public string Name { get; set; } {
[Required] existingUser.Name = Name;
[Display(Name = "Email")] existingUser.Email = Email;
public string Email { get; set; } existingUser.EmailVerified = EmailVerified;
[Display(Name = "Email Verified")] existingUser.Premium = Premium;
public bool EmailVerified { get; set; } existingUser.MaxStorageGb = MaxStorageGb;
[Display(Name = "Premium")] existingUser.Gateway = Gateway;
public bool Premium { get; set; } existingUser.GatewayCustomerId = GatewayCustomerId;
[Display(Name = "Max. Storage GB")] existingUser.GatewaySubscriptionId = GatewaySubscriptionId;
public short? MaxStorageGb { get; set; } existingUser.LicenseKey = LicenseKey;
[Display(Name = "Gateway")] existingUser.PremiumExpirationDate = PremiumExpirationDate;
public Core.Enums.GatewayType? Gateway { get; set; } return existingUser;
[Display(Name = "Gateway Customer Id")]
public string GatewayCustomerId { get; set; }
[Display(Name = "Gateway Subscription Id")]
public string GatewaySubscriptionId { get; set; }
[Display(Name = "License Key")]
public string LicenseKey { get; set; }
[Display(Name = "Premium Expiration Date")]
public DateTime? PremiumExpirationDate { get; set; }
public User ToUser(User existingUser)
{
existingUser.Name = Name;
existingUser.Email = Email;
existingUser.EmailVerified = EmailVerified;
existingUser.Premium = Premium;
existingUser.MaxStorageGb = MaxStorageGb;
existingUser.Gateway = Gateway;
existingUser.GatewayCustomerId = GatewayCustomerId;
existingUser.GatewaySubscriptionId = GatewaySubscriptionId;
existingUser.LicenseKey = LicenseKey;
existingUser.PremiumExpirationDate = PremiumExpirationDate;
return existingUser;
}
} }
} }

View File

@ -1,18 +1,17 @@
using Bit.Core.Entities; using Bit.Core.Entities;
namespace Bit.Admin.Models namespace Bit.Admin.Models;
public class UserViewModel
{ {
public class UserViewModel public UserViewModel() { }
public UserViewModel(User user, IEnumerable<Cipher> ciphers)
{ {
public UserViewModel() { } User = user;
CipherCount = ciphers.Count();
public UserViewModel(User user, IEnumerable<Cipher> ciphers)
{
User = user;
CipherCount = ciphers.Count();
}
public User User { get; set; }
public int CipherCount { get; set; }
} }
public User User { get; set; }
public int CipherCount { get; set; }
} }

Some files were not shown because too many files have changed in this diff Show More