diff --git a/src/Api/Controllers/PushController.cs b/src/Api/Controllers/PushController.cs index c83eb200b..383980510 100644 --- a/src/Api/Controllers/PushController.cs +++ b/src/Api/Controllers/PushController.cs @@ -46,7 +46,7 @@ public class PushController : Controller public async Task PostDelete([FromBody] PushDeviceRequestModel model) { CheckUsage(); - await _pushRegistrationService.DeleteRegistrationAsync(Prefix(model.Id), model.Type); + await _pushRegistrationService.DeleteRegistrationAsync(Prefix(model.Id)); } [HttpPut("add-organization")] @@ -54,7 +54,7 @@ public class PushController : Controller { CheckUsage(); await _pushRegistrationService.AddUserRegistrationOrganizationAsync( - model.Devices.Select(d => new KeyValuePair(Prefix(d.Id), d.Type)), + model.Devices.Select(d => Prefix(d.Id)), Prefix(model.OrganizationId)); } @@ -63,7 +63,7 @@ public class PushController : Controller { CheckUsage(); await _pushRegistrationService.DeleteUserRegistrationOrganizationAsync( - model.Devices.Select(d => new KeyValuePair(Prefix(d.Id), d.Type)), + model.Devices.Select(d => Prefix(d.Id)), Prefix(model.OrganizationId)); } diff --git a/src/Core/AdminConsole/OrganizationFeatures/OrganizationUsers/RemoveOrganizationUserCommand.cs b/src/Core/AdminConsole/OrganizationFeatures/OrganizationUsers/RemoveOrganizationUserCommand.cs index 09444306e..e6d56ea87 100644 --- a/src/Core/AdminConsole/OrganizationFeatures/OrganizationUsers/RemoveOrganizationUserCommand.cs +++ b/src/Core/AdminConsole/OrganizationFeatures/OrganizationUsers/RemoveOrganizationUserCommand.cs @@ -162,12 +162,12 @@ public class RemoveOrganizationUserCommand : IRemoveOrganizationUserCommand } } - private async Task>> GetUserDeviceIdsAsync(Guid userId) + private async Task> GetUserDeviceIdsAsync(Guid userId) { var devices = await _deviceRepository.GetManyByUserIdAsync(userId); return devices .Where(d => !string.IsNullOrWhiteSpace(d.PushToken)) - .Select(d => new KeyValuePair(d.Id.ToString(), d.Type)); + .Select(d => d.Id.ToString()); } private async Task DeleteAndPushUserRegistrationAsync(Guid organizationId, Guid userId) diff --git a/src/Core/AdminConsole/Services/Implementations/OrganizationService.cs b/src/Core/AdminConsole/Services/Implementations/OrganizationService.cs index 50a2ed84e..f44ce686f 100644 --- a/src/Core/AdminConsole/Services/Implementations/OrganizationService.cs +++ b/src/Core/AdminConsole/Services/Implementations/OrganizationService.cs @@ -1838,12 +1838,12 @@ public class OrganizationService : IOrganizationService } - private async Task>> GetUserDeviceIdsAsync(Guid userId) + private async Task> GetUserDeviceIdsAsync(Guid userId) { var devices = await _deviceRepository.GetManyByUserIdAsync(userId); return devices .Where(d => !string.IsNullOrWhiteSpace(d.PushToken)) - .Select(d => new KeyValuePair(d.Id.ToString(), d.Type)); + .Select(d => d.Id.ToString()); } public async Task ReplaceAndUpdateCacheAsync(Organization org, EventType? orgEvent = null) diff --git a/src/Core/Models/Api/Request/PushDeviceRequestModel.cs b/src/Core/Models/Api/Request/PushDeviceRequestModel.cs index e1866b6f2..8b97dcc36 100644 --- a/src/Core/Models/Api/Request/PushDeviceRequestModel.cs +++ b/src/Core/Models/Api/Request/PushDeviceRequestModel.cs @@ -1,5 +1,4 @@ using System.ComponentModel.DataAnnotations; -using Bit.Core.Enums; namespace Bit.Core.Models.Api; @@ -7,6 +6,4 @@ public class PushDeviceRequestModel { [Required] public string Id { get; set; } - [Required] - public DeviceType Type { get; set; } } diff --git a/src/Core/Models/Api/Request/PushUpdateRequestModel.cs b/src/Core/Models/Api/Request/PushUpdateRequestModel.cs index 9f7ed5f28..f8c2d296f 100644 --- a/src/Core/Models/Api/Request/PushUpdateRequestModel.cs +++ b/src/Core/Models/Api/Request/PushUpdateRequestModel.cs @@ -1,5 +1,4 @@ using System.ComponentModel.DataAnnotations; -using Bit.Core.Enums; namespace Bit.Core.Models.Api; @@ -8,9 +7,9 @@ public class PushUpdateRequestModel public PushUpdateRequestModel() { } - public PushUpdateRequestModel(IEnumerable> devices, string organizationId) + public PushUpdateRequestModel(IEnumerable deviceIds, string organizationId) { - Devices = devices.Select(d => new PushDeviceRequestModel { Id = d.Key, Type = d.Value }); + Devices = deviceIds.Select(d => new PushDeviceRequestModel { Id = d }); OrganizationId = organizationId; } diff --git a/src/Core/Models/Data/InstallationDeviceEntity.cs b/src/Core/Models/Data/InstallationDeviceEntity.cs index 3186efc66..a3d960b24 100644 --- a/src/Core/Models/Data/InstallationDeviceEntity.cs +++ b/src/Core/Models/Data/InstallationDeviceEntity.cs @@ -37,4 +37,25 @@ public class InstallationDeviceEntity : ITableEntity { return deviceId != null && deviceId.Length == 73 && deviceId[36] == '_'; } + public static bool TryParse(string deviceId, out InstallationDeviceEntity installationDeviceEntity) + { + installationDeviceEntity = null; + var installationId = Guid.Empty; + var deviceIdGuid = Guid.Empty; + if (!IsInstallationDeviceId(deviceId)) + { + return false; + } + var parts = deviceId.Split("_"); + if (parts.Length < 2) + { + return false; + } + if (!Guid.TryParse(parts[0], out installationId) || !Guid.TryParse(parts[1], out deviceIdGuid)) + { + return false; + } + installationDeviceEntity = new InstallationDeviceEntity(installationId, deviceIdGuid); + return true; + } } diff --git a/src/Core/NotificationHub/INotificationHubClientProxy.cs b/src/Core/NotificationHub/INotificationHubClientProxy.cs new file mode 100644 index 000000000..82b4d3959 --- /dev/null +++ b/src/Core/NotificationHub/INotificationHubClientProxy.cs @@ -0,0 +1,8 @@ +using Microsoft.Azure.NotificationHubs; + +namespace Bit.Core.NotificationHub; + +public interface INotificationHubProxy +{ + Task<(INotificationHubClient Client, NotificationOutcome Outcome)[]> SendTemplateNotificationAsync(IDictionary properties, string tagExpression); +} diff --git a/src/Core/NotificationHub/INotificationHubPool.cs b/src/Core/NotificationHub/INotificationHubPool.cs new file mode 100644 index 000000000..7c383d7b9 --- /dev/null +++ b/src/Core/NotificationHub/INotificationHubPool.cs @@ -0,0 +1,9 @@ +using Microsoft.Azure.NotificationHubs; + +namespace Bit.Core.NotificationHub; + +public interface INotificationHubPool +{ + NotificationHubClient ClientFor(Guid comb); + INotificationHubProxy AllClients { get; } +} diff --git a/src/Core/NotificationHub/NotificationHubClientProxy.cs b/src/Core/NotificationHub/NotificationHubClientProxy.cs new file mode 100644 index 000000000..815ac8836 --- /dev/null +++ b/src/Core/NotificationHub/NotificationHubClientProxy.cs @@ -0,0 +1,26 @@ +using Microsoft.Azure.NotificationHubs; + +namespace Bit.Core.NotificationHub; + +public class NotificationHubClientProxy : INotificationHubProxy +{ + private readonly IEnumerable _clients; + + public NotificationHubClientProxy(IEnumerable clients) + { + _clients = clients; + } + + private async Task<(INotificationHubClient, T)[]> ApplyToAllClientsAsync(Func> action) + { + var tasks = _clients.Select(async c => (c, await action(c))); + return await Task.WhenAll(tasks); + } + + // partial proxy of INotificationHubClient implementation + // Note: Any other methods that are needed can simply be delegated as done here. + public async Task<(INotificationHubClient Client, NotificationOutcome Outcome)[]> SendTemplateNotificationAsync(IDictionary properties, string tagExpression) + { + return await ApplyToAllClientsAsync(async c => await c.SendTemplateNotificationAsync(properties, tagExpression)); + } +} diff --git a/src/Core/NotificationHub/NotificationHubConnection.cs b/src/Core/NotificationHub/NotificationHubConnection.cs new file mode 100644 index 000000000..3a1437f70 --- /dev/null +++ b/src/Core/NotificationHub/NotificationHubConnection.cs @@ -0,0 +1,128 @@ +using Bit.Core.Settings; +using Bit.Core.Utilities; +using Microsoft.Azure.NotificationHubs; + +class NotificationHubConnection +{ + public string HubName { get; init; } + public string ConnectionString { get; init; } + public bool EnableSendTracing { get; init; } + private NotificationHubClient _hubClient; + /// + /// Gets the NotificationHubClient for this connection. + /// + /// If the client is null, it will be initialized. + /// + /// Exception if the connection is invalid. + /// + public NotificationHubClient HubClient + { + get + { + if (_hubClient == null) + { + if (!IsValid) + { + throw new Exception("Invalid notification hub settings"); + } + Init(); + } + return _hubClient; + } + private set + { + _hubClient = value; + } + } + /// + /// Gets the start date for registration. + /// + /// If null, registration is always disabled. + /// + public DateTime? RegistrationStartDate { get; init; } + /// + /// Gets the end date for registration. + /// + /// If null, registration has no end date. + /// + public DateTime? RegistrationEndDate { get; init; } + /// + /// Gets whether all data needed to generate a connection to Notification Hub is present. + /// + public bool IsValid + { + get + { + { + var invalid = string.IsNullOrWhiteSpace(HubName) || string.IsNullOrWhiteSpace(ConnectionString); + return !invalid; + } + } + } + + public string LogString + { + get + { + return $"HubName: {HubName}, EnableSendTracing: {EnableSendTracing}, RegistrationStartDate: {RegistrationStartDate}, RegistrationEndDate: {RegistrationEndDate}"; + } + } + + /// + /// Gets whether registration is enabled for the given comb ID. + /// This is based off of the generation time encoded in the comb ID. + /// + /// + /// + public bool RegistrationEnabled(Guid comb) + { + var combTime = CoreHelpers.DateFromComb(comb); + return RegistrationEnabled(combTime); + } + + /// + /// Gets whether registration is enabled for the given time. + /// + /// The time to check + /// + public bool RegistrationEnabled(DateTime queryTime) + { + if (queryTime >= RegistrationEndDate || RegistrationStartDate == null) + { + return false; + } + + return RegistrationStartDate < queryTime; + } + + private NotificationHubConnection() { } + + /// + /// Creates a new NotificationHubConnection from the given settings. + /// + /// + /// + public static NotificationHubConnection From(GlobalSettings.NotificationHubSettings settings) + { + return new() + { + HubName = settings.HubName, + ConnectionString = settings.ConnectionString, + EnableSendTracing = settings.EnableSendTracing, + // Comb time is not precise enough for millisecond accuracy + RegistrationStartDate = settings.RegistrationStartDate.HasValue ? Truncate(settings.RegistrationStartDate.Value, TimeSpan.FromMilliseconds(10)) : null, + RegistrationEndDate = settings.RegistrationEndDate + }; + } + + private NotificationHubConnection Init() + { + HubClient = NotificationHubClient.CreateClientFromConnectionString(ConnectionString, HubName, EnableSendTracing); + return this; + } + + private static DateTime Truncate(DateTime dateTime, TimeSpan resolution) + { + return dateTime.AddTicks(-(dateTime.Ticks % resolution.Ticks)); + } +} diff --git a/src/Core/NotificationHub/NotificationHubPool.cs b/src/Core/NotificationHub/NotificationHubPool.cs new file mode 100644 index 000000000..7448aad5b --- /dev/null +++ b/src/Core/NotificationHub/NotificationHubPool.cs @@ -0,0 +1,62 @@ +using Bit.Core.Settings; +using Bit.Core.Utilities; +using Microsoft.Azure.NotificationHubs; +using Microsoft.Extensions.Logging; + +namespace Bit.Core.NotificationHub; + +public class NotificationHubPool : INotificationHubPool +{ + private List _connections { get; } + private readonly IEnumerable _clients; + private readonly ILogger _logger; + public NotificationHubPool(ILogger logger, GlobalSettings globalSettings) + { + _logger = logger; + _connections = FilterInvalidHubs(globalSettings.NotificationHubPool.NotificationHubs); + _clients = _connections.GroupBy(c => c.ConnectionString).Select(g => g.First().HubClient); + } + + private List FilterInvalidHubs(IEnumerable hubs) + { + List result = new(); + _logger.LogDebug("Filtering {HubCount} notification hubs", hubs.Count()); + foreach (var hub in hubs) + { + var connection = NotificationHubConnection.From(hub); + if (!connection.IsValid) + { + _logger.LogWarning("Invalid notification hub settings: {HubName}", hub.HubName ?? "hub name missing"); + continue; + } + _logger.LogDebug("Adding notification hub: {ConnectionLogString}", connection.LogString); + result.Add(connection); + } + + return result; + } + + + /// + /// Gets the NotificationHubClient for the given comb ID. + /// + /// + /// + /// Thrown when no notification hub is found for a given comb. + public NotificationHubClient ClientFor(Guid comb) + { + var possibleConnections = _connections.Where(c => c.RegistrationEnabled(comb)).ToArray(); + if (possibleConnections.Length == 0) + { + throw new InvalidOperationException($"No valid notification hubs are available for the given comb ({comb}).\n" + + $"The comb's datetime is {CoreHelpers.DateFromComb(comb)}." + + $"Hub start and end times are configured as follows:\n" + + string.Join("\n", _connections.Select(c => $"Hub {c.HubName} - Start: {c.RegistrationStartDate}, End: {c.RegistrationEndDate}"))); + } + var resolvedConnection = possibleConnections[CoreHelpers.BinForComb(comb, possibleConnections.Length)]; + _logger.LogTrace("Resolved notification hub for comb {Comb} out of {HubCount} hubs.\n{ConnectionInfo}", comb, possibleConnections.Length, resolvedConnection.LogString); + return resolvedConnection.HubClient; + } + + public INotificationHubProxy AllClients { get { return new NotificationHubClientProxy(_clients); } } +} diff --git a/src/Core/Services/Implementations/NotificationHubPushNotificationService.cs b/src/Core/NotificationHub/NotificationHubPushNotificationService.cs similarity index 84% rename from src/Core/Services/Implementations/NotificationHubPushNotificationService.cs rename to src/Core/NotificationHub/NotificationHubPushNotificationService.cs index 480f0dfa9..6143676de 100644 --- a/src/Core/Services/Implementations/NotificationHubPushNotificationService.cs +++ b/src/Core/NotificationHub/NotificationHubPushNotificationService.cs @@ -6,45 +6,31 @@ using Bit.Core.Enums; using Bit.Core.Models; using Bit.Core.Models.Data; using Bit.Core.Repositories; -using Bit.Core.Settings; +using Bit.Core.Services; using Bit.Core.Tools.Entities; using Bit.Core.Vault.Entities; using Microsoft.AspNetCore.Http; -using Microsoft.Azure.NotificationHubs; using Microsoft.Extensions.Logging; -namespace Bit.Core.Services; +namespace Bit.Core.NotificationHub; public class NotificationHubPushNotificationService : IPushNotificationService { private readonly IInstallationDeviceRepository _installationDeviceRepository; - private readonly GlobalSettings _globalSettings; private readonly IHttpContextAccessor _httpContextAccessor; - private readonly List _clients = []; private readonly bool _enableTracing = false; + private readonly INotificationHubPool _notificationHubPool; private readonly ILogger _logger; public NotificationHubPushNotificationService( IInstallationDeviceRepository installationDeviceRepository, - GlobalSettings globalSettings, + INotificationHubPool notificationHubPool, IHttpContextAccessor httpContextAccessor, ILogger logger) { _installationDeviceRepository = installationDeviceRepository; - _globalSettings = globalSettings; _httpContextAccessor = httpContextAccessor; - - foreach (var hub in globalSettings.NotificationHubs) - { - var client = NotificationHubClient.CreateClientFromConnectionString( - hub.ConnectionString, - hub.HubName, - hub.EnableSendTracing); - _clients.Add(client); - - _enableTracing = _enableTracing || hub.EnableSendTracing; - } - + _notificationHubPool = notificationHubPool; _logger = logger; } @@ -264,30 +250,23 @@ public class NotificationHubPushNotificationService : IPushNotificationService private async Task SendPayloadAsync(string tag, PushType type, object payload) { - var tasks = new List>(); - foreach (var client in _clients) - { - var task = client.SendTemplateNotificationAsync( - new Dictionary - { - { "type", ((byte)type).ToString() }, - { "payload", JsonSerializer.Serialize(payload) } - }, tag); - tasks.Add(task); - } - - await Task.WhenAll(tasks); + var results = await _notificationHubPool.AllClients.SendTemplateNotificationAsync( + new Dictionary + { + { "type", ((byte)type).ToString() }, + { "payload", JsonSerializer.Serialize(payload) } + }, tag); if (_enableTracing) { - for (var i = 0; i < tasks.Count; i++) + foreach (var (client, outcome) in results) { - if (_clients[i].EnableTestSend) + if (!client.EnableTestSend) { - var outcome = await tasks[i]; - _logger.LogInformation("Azure Notification Hub Tracking ID: {id} | {type} push notification with {success} successes and {failure} failures with a payload of {@payload} and result of {@results}", - outcome.TrackingId, type, outcome.Success, outcome.Failure, payload, outcome.Results); + continue; } + _logger.LogInformation("Azure Notification Hub Tracking ID: {Id} | {Type} push notification with {Success} successes and {Failure} failures with a payload of {@Payload} and result of {@Results}", + outcome.TrackingId, type, outcome.Success, outcome.Failure, payload, outcome.Results); } } } diff --git a/src/Core/Services/Implementations/NotificationHubPushRegistrationService.cs b/src/Core/NotificationHub/NotificationHubPushRegistrationService.cs similarity index 64% rename from src/Core/Services/Implementations/NotificationHubPushRegistrationService.cs rename to src/Core/NotificationHub/NotificationHubPushRegistrationService.cs index 87df60e8e..ae32babf4 100644 --- a/src/Core/Services/Implementations/NotificationHubPushRegistrationService.cs +++ b/src/Core/NotificationHub/NotificationHubPushRegistrationService.cs @@ -1,50 +1,34 @@ using Bit.Core.Enums; using Bit.Core.Models.Data; using Bit.Core.Repositories; +using Bit.Core.Services; using Bit.Core.Settings; using Microsoft.Azure.NotificationHubs; using Microsoft.Extensions.DependencyInjection; using Microsoft.Extensions.Logging; -namespace Bit.Core.Services; +namespace Bit.Core.NotificationHub; public class NotificationHubPushRegistrationService : IPushRegistrationService { private readonly IInstallationDeviceRepository _installationDeviceRepository; private readonly GlobalSettings _globalSettings; + private readonly INotificationHubPool _notificationHubPool; private readonly IServiceProvider _serviceProvider; private readonly ILogger _logger; - private Dictionary _clients = []; public NotificationHubPushRegistrationService( IInstallationDeviceRepository installationDeviceRepository, GlobalSettings globalSettings, + INotificationHubPool notificationHubPool, IServiceProvider serviceProvider, ILogger logger) { _installationDeviceRepository = installationDeviceRepository; _globalSettings = globalSettings; + _notificationHubPool = notificationHubPool; _serviceProvider = serviceProvider; _logger = logger; - - // Is this dirty to do in the ctor? - void addHub(NotificationHubType type) - { - var hubRegistration = globalSettings.NotificationHubs.FirstOrDefault( - h => h.HubType == type && h.EnableRegistration); - if (hubRegistration != null) - { - var client = NotificationHubClient.CreateClientFromConnectionString( - hubRegistration.ConnectionString, - hubRegistration.HubName, - hubRegistration.EnableSendTracing); - _clients.Add(type, client); - } - } - - addHub(NotificationHubType.General); - addHub(NotificationHubType.iOS); - addHub(NotificationHubType.Android); } public async Task CreateOrUpdateRegistrationAsync(string pushToken, string deviceId, string userId, @@ -117,7 +101,7 @@ public class NotificationHubPushRegistrationService : IPushRegistrationService BuildInstallationTemplate(installation, "badgeMessage", badgeMessageTemplate ?? messageTemplate, userId, identifier); - await GetClient(type).CreateOrUpdateInstallationAsync(installation); + await ClientFor(GetComb(deviceId)).CreateOrUpdateInstallationAsync(installation); if (InstallationDeviceEntity.IsInstallationDeviceId(deviceId)) { await _installationDeviceRepository.UpsertAsync(new InstallationDeviceEntity(deviceId)); @@ -152,11 +136,11 @@ public class NotificationHubPushRegistrationService : IPushRegistrationService installation.Templates.Add(fullTemplateId, template); } - public async Task DeleteRegistrationAsync(string deviceId, DeviceType deviceType) + public async Task DeleteRegistrationAsync(string deviceId) { try { - await GetClient(deviceType).DeleteInstallationAsync(deviceId); + await ClientFor(GetComb(deviceId)).DeleteInstallationAsync(deviceId); if (InstallationDeviceEntity.IsInstallationDeviceId(deviceId)) { await _installationDeviceRepository.DeleteAsync(new InstallationDeviceEntity(deviceId)); @@ -168,31 +152,31 @@ public class NotificationHubPushRegistrationService : IPushRegistrationService } } - public async Task AddUserRegistrationOrganizationAsync(IEnumerable> devices, string organizationId) + public async Task AddUserRegistrationOrganizationAsync(IEnumerable deviceIds, string organizationId) { - await PatchTagsForUserDevicesAsync(devices, UpdateOperationType.Add, $"organizationId:{organizationId}"); - if (devices.Any() && InstallationDeviceEntity.IsInstallationDeviceId(devices.First().Key)) + await PatchTagsForUserDevicesAsync(deviceIds, UpdateOperationType.Add, $"organizationId:{organizationId}"); + if (deviceIds.Any() && InstallationDeviceEntity.IsInstallationDeviceId(deviceIds.First())) { - var entities = devices.Select(e => new InstallationDeviceEntity(e.Key)); + var entities = deviceIds.Select(e => new InstallationDeviceEntity(e)); await _installationDeviceRepository.UpsertManyAsync(entities.ToList()); } } - public async Task DeleteUserRegistrationOrganizationAsync(IEnumerable> devices, string organizationId) + public async Task DeleteUserRegistrationOrganizationAsync(IEnumerable deviceIds, string organizationId) { - await PatchTagsForUserDevicesAsync(devices, UpdateOperationType.Remove, + await PatchTagsForUserDevicesAsync(deviceIds, UpdateOperationType.Remove, $"organizationId:{organizationId}"); - if (devices.Any() && InstallationDeviceEntity.IsInstallationDeviceId(devices.First().Key)) + if (deviceIds.Any() && InstallationDeviceEntity.IsInstallationDeviceId(deviceIds.First())) { - var entities = devices.Select(e => new InstallationDeviceEntity(e.Key)); + var entities = deviceIds.Select(e => new InstallationDeviceEntity(e)); await _installationDeviceRepository.UpsertManyAsync(entities.ToList()); } } - private async Task PatchTagsForUserDevicesAsync(IEnumerable> devices, UpdateOperationType op, + private async Task PatchTagsForUserDevicesAsync(IEnumerable deviceIds, UpdateOperationType op, string tag) { - if (!devices.Any()) + if (!deviceIds.Any()) { return; } @@ -212,11 +196,11 @@ public class NotificationHubPushRegistrationService : IPushRegistrationService operation.Path += $"/{tag}"; } - foreach (var device in devices) + foreach (var deviceId in deviceIds) { try { - await GetClient(device.Value).PatchInstallationAsync(device.Key, new List { operation }); + await ClientFor(GetComb(deviceId)).PatchInstallationAsync(deviceId, new List { operation }); } catch (Exception e) when (e.InnerException == null || !e.InnerException.Message.Contains("(404) Not Found")) { @@ -225,53 +209,29 @@ public class NotificationHubPushRegistrationService : IPushRegistrationService } } - private NotificationHubClient GetClient(DeviceType deviceType) + private NotificationHubClient ClientFor(Guid deviceId) { - var hubType = NotificationHubType.General; - switch (deviceType) + return _notificationHubPool.ClientFor(deviceId); + } + + private Guid GetComb(string deviceId) + { + var deviceIdString = deviceId; + InstallationDeviceEntity installationDeviceEntity; + Guid deviceIdGuid; + if (InstallationDeviceEntity.TryParse(deviceIdString, out installationDeviceEntity)) { - case DeviceType.Android: - hubType = NotificationHubType.Android; - break; - case DeviceType.iOS: - hubType = NotificationHubType.iOS; - break; - case DeviceType.ChromeExtension: - case DeviceType.FirefoxExtension: - case DeviceType.OperaExtension: - case DeviceType.EdgeExtension: - case DeviceType.VivaldiExtension: - case DeviceType.SafariExtension: - hubType = NotificationHubType.GeneralBrowserExtension; - break; - case DeviceType.WindowsDesktop: - case DeviceType.MacOsDesktop: - case DeviceType.LinuxDesktop: - hubType = NotificationHubType.GeneralDesktop; - break; - case DeviceType.ChromeBrowser: - case DeviceType.FirefoxBrowser: - case DeviceType.OperaBrowser: - case DeviceType.EdgeBrowser: - case DeviceType.IEBrowser: - case DeviceType.UnknownBrowser: - case DeviceType.SafariBrowser: - case DeviceType.VivaldiBrowser: - hubType = NotificationHubType.GeneralWeb; - break; - default: - break; + // Strip off the installation id (PartitionId). RowKey is the ID in the Installation's table. + deviceIdString = installationDeviceEntity.RowKey; } - if (!_clients.ContainsKey(hubType)) + if (Guid.TryParse(deviceIdString, out deviceIdGuid)) { - _logger.LogWarning("No hub client for '{0}'. Using general hub instead.", hubType); - hubType = NotificationHubType.General; - if (!_clients.ContainsKey(hubType)) - { - throw new Exception("No general hub client found."); - } } - return _clients[hubType]; + else + { + throw new Exception($"Invalid device id {deviceId}."); + } + return deviceIdGuid; } } diff --git a/src/Core/Services/IPushRegistrationService.cs b/src/Core/Services/IPushRegistrationService.cs index 83bbed485..985246de0 100644 --- a/src/Core/Services/IPushRegistrationService.cs +++ b/src/Core/Services/IPushRegistrationService.cs @@ -6,7 +6,7 @@ public interface IPushRegistrationService { Task CreateOrUpdateRegistrationAsync(string pushToken, string deviceId, string userId, string identifier, DeviceType type); - Task DeleteRegistrationAsync(string deviceId, DeviceType type); - Task AddUserRegistrationOrganizationAsync(IEnumerable> devices, string organizationId); - Task DeleteUserRegistrationOrganizationAsync(IEnumerable> devices, string organizationId); + Task DeleteRegistrationAsync(string deviceId); + Task AddUserRegistrationOrganizationAsync(IEnumerable deviceIds, string organizationId); + Task DeleteUserRegistrationOrganizationAsync(IEnumerable deviceIds, string organizationId); } diff --git a/src/Core/Services/Implementations/DeviceService.cs b/src/Core/Services/Implementations/DeviceService.cs index 9d8315f69..5b1e4b0f0 100644 --- a/src/Core/Services/Implementations/DeviceService.cs +++ b/src/Core/Services/Implementations/DeviceService.cs @@ -38,13 +38,13 @@ public class DeviceService : IDeviceService public async Task ClearTokenAsync(Device device) { await _deviceRepository.ClearPushTokenAsync(device.Id); - await _pushRegistrationService.DeleteRegistrationAsync(device.Id.ToString(), device.Type); + await _pushRegistrationService.DeleteRegistrationAsync(device.Id.ToString()); } public async Task DeleteAsync(Device device) { await _deviceRepository.DeleteAsync(device); - await _pushRegistrationService.DeleteRegistrationAsync(device.Id.ToString(), device.Type); + await _pushRegistrationService.DeleteRegistrationAsync(device.Id.ToString()); } public async Task UpdateDevicesTrustAsync(string currentDeviceIdentifier, diff --git a/src/Core/Services/Implementations/MultiServicePushNotificationService.cs b/src/Core/Services/Implementations/MultiServicePushNotificationService.cs index 92e29908f..00be72c98 100644 --- a/src/Core/Services/Implementations/MultiServicePushNotificationService.cs +++ b/src/Core/Services/Implementations/MultiServicePushNotificationService.cs @@ -1,61 +1,31 @@ using Bit.Core.Auth.Entities; using Bit.Core.Enums; -using Bit.Core.Repositories; using Bit.Core.Settings; using Bit.Core.Tools.Entities; -using Bit.Core.Utilities; using Bit.Core.Vault.Entities; -using Microsoft.AspNetCore.Http; +using Microsoft.Extensions.DependencyInjection; using Microsoft.Extensions.Logging; namespace Bit.Core.Services; public class MultiServicePushNotificationService : IPushNotificationService { - private readonly List _services = new List(); + private readonly IEnumerable _services; private readonly ILogger _logger; public MultiServicePushNotificationService( - IHttpClientFactory httpFactory, - IDeviceRepository deviceRepository, - IInstallationDeviceRepository installationDeviceRepository, - GlobalSettings globalSettings, - IHttpContextAccessor httpContextAccessor, + [FromKeyedServices("implementation")] IEnumerable services, ILogger logger, - ILogger relayLogger, - ILogger hubLogger) + GlobalSettings globalSettings) { - if (globalSettings.SelfHosted) - { - if (CoreHelpers.SettingHasValue(globalSettings.PushRelayBaseUri) && - globalSettings.Installation?.Id != null && - CoreHelpers.SettingHasValue(globalSettings.Installation?.Key)) - { - _services.Add(new RelayPushNotificationService(httpFactory, deviceRepository, globalSettings, - httpContextAccessor, relayLogger)); - } - if (CoreHelpers.SettingHasValue(globalSettings.InternalIdentityKey) && - CoreHelpers.SettingHasValue(globalSettings.BaseServiceUri.InternalNotifications)) - { - _services.Add(new NotificationsApiPushNotificationService( - httpFactory, globalSettings, httpContextAccessor, hubLogger)); - } - } - else - { - var generalHub = globalSettings.NotificationHubs?.FirstOrDefault(h => h.HubType == NotificationHubType.General); - if (CoreHelpers.SettingHasValue(generalHub?.ConnectionString)) - { - _services.Add(new NotificationHubPushNotificationService(installationDeviceRepository, - globalSettings, httpContextAccessor, hubLogger)); - } - if (CoreHelpers.SettingHasValue(globalSettings.Notifications?.ConnectionString)) - { - _services.Add(new AzureQueuePushNotificationService(globalSettings, httpContextAccessor)); - } - } + _services = services; _logger = logger; + _logger.LogInformation("Hub services: {Services}", _services.Count()); + globalSettings?.NotificationHubPool?.NotificationHubs?.ForEach(hub => + { + _logger.LogInformation("HubName: {HubName}, EnableSendTracing: {EnableSendTracing}, RegistrationStartDate: {RegistrationStartDate}, RegistrationEndDate: {RegistrationEndDate}", hub.HubName, hub.EnableSendTracing, hub.RegistrationStartDate, hub.RegistrationEndDate); + }); } public Task PushSyncCipherCreateAsync(Cipher cipher, IEnumerable collectionIds) diff --git a/src/Core/Services/Implementations/RelayPushRegistrationService.cs b/src/Core/Services/Implementations/RelayPushRegistrationService.cs index d9df7d04d..d0f7736e9 100644 --- a/src/Core/Services/Implementations/RelayPushRegistrationService.cs +++ b/src/Core/Services/Implementations/RelayPushRegistrationService.cs @@ -38,37 +38,36 @@ public class RelayPushRegistrationService : BaseIdentityClientService, IPushRegi await SendAsync(HttpMethod.Post, "push/register", requestModel); } - public async Task DeleteRegistrationAsync(string deviceId, DeviceType type) + public async Task DeleteRegistrationAsync(string deviceId) { var requestModel = new PushDeviceRequestModel { Id = deviceId, - Type = type, }; await SendAsync(HttpMethod.Post, "push/delete", requestModel); } public async Task AddUserRegistrationOrganizationAsync( - IEnumerable> devices, string organizationId) + IEnumerable deviceIds, string organizationId) { - if (!devices.Any()) + if (!deviceIds.Any()) { return; } - var requestModel = new PushUpdateRequestModel(devices, organizationId); + var requestModel = new PushUpdateRequestModel(deviceIds, organizationId); await SendAsync(HttpMethod.Put, "push/add-organization", requestModel); } public async Task DeleteUserRegistrationOrganizationAsync( - IEnumerable> devices, string organizationId) + IEnumerable deviceIds, string organizationId) { - if (!devices.Any()) + if (!deviceIds.Any()) { return; } - var requestModel = new PushUpdateRequestModel(devices, organizationId); + var requestModel = new PushUpdateRequestModel(deviceIds, organizationId); await SendAsync(HttpMethod.Put, "push/delete-organization", requestModel); } } diff --git a/src/Core/Services/NoopImplementations/NoopPushRegistrationService.cs b/src/Core/Services/NoopImplementations/NoopPushRegistrationService.cs index fcd088924..f6279c946 100644 --- a/src/Core/Services/NoopImplementations/NoopPushRegistrationService.cs +++ b/src/Core/Services/NoopImplementations/NoopPushRegistrationService.cs @@ -4,7 +4,7 @@ namespace Bit.Core.Services; public class NoopPushRegistrationService : IPushRegistrationService { - public Task AddUserRegistrationOrganizationAsync(IEnumerable> devices, string organizationId) + public Task AddUserRegistrationOrganizationAsync(IEnumerable deviceIds, string organizationId) { return Task.FromResult(0); } @@ -15,12 +15,12 @@ public class NoopPushRegistrationService : IPushRegistrationService return Task.FromResult(0); } - public Task DeleteRegistrationAsync(string deviceId, DeviceType deviceType) + public Task DeleteRegistrationAsync(string deviceId) { return Task.FromResult(0); } - public Task DeleteUserRegistrationOrganizationAsync(IEnumerable> devices, string organizationId) + public Task DeleteUserRegistrationOrganizationAsync(IEnumerable deviceIds, string organizationId) { return Task.FromResult(0); } diff --git a/src/Core/Settings/GlobalSettings.cs b/src/Core/Settings/GlobalSettings.cs index f99fb3b57..793b6ac1c 100644 --- a/src/Core/Settings/GlobalSettings.cs +++ b/src/Core/Settings/GlobalSettings.cs @@ -1,5 +1,4 @@ using Bit.Core.Auth.Settings; -using Bit.Core.Enums; using Bit.Core.Settings.LoggingSettings; namespace Bit.Core.Settings; @@ -65,7 +64,7 @@ public class GlobalSettings : IGlobalSettings public virtual SentrySettings Sentry { get; set; } = new SentrySettings(); public virtual SyslogSettings Syslog { get; set; } = new SyslogSettings(); public virtual ILogLevelSettings MinLogLevel { get; set; } = new LogLevelSettings(); - public virtual List NotificationHubs { get; set; } = new(); + public virtual NotificationHubPoolSettings NotificationHubPool { get; set; } = new(); public virtual YubicoSettings Yubico { get; set; } = new YubicoSettings(); public virtual DuoSettings Duo { get; set; } = new DuoSettings(); public virtual BraintreeSettings Braintree { get; set; } = new BraintreeSettings(); @@ -424,7 +423,7 @@ public class GlobalSettings : IGlobalSettings public string ConnectionString { get => _connectionString; - set => _connectionString = value.Trim('"'); + set => _connectionString = value?.Trim('"'); } public string HubName { get; set; } /// @@ -433,10 +432,32 @@ public class GlobalSettings : IGlobalSettings /// public bool EnableSendTracing { get; set; } = false; /// - /// At least one hub configuration should have registration enabled, preferably the General hub as a safety net. + /// The date and time at which registration will be enabled. + /// + /// **This value should not be updated once set, as it is used to determine installation location of devices.** + /// + /// If null, registration is disabled. + /// /// - public bool EnableRegistration { get; set; } - public NotificationHubType HubType { get; set; } + public DateTime? RegistrationStartDate { get; set; } + /// + /// The date and time at which registration will be disabled. + /// + /// **This value should not be updated once set, as it is used to determine installation location of devices.** + /// + /// If null, hub registration has no yet known expiry. + /// + public DateTime? RegistrationEndDate { get; set; } + } + + public class NotificationHubPoolSettings + { + /// + /// List of Notification Hub settings to use for sending push notifications. + /// + /// Note that hubs on the same namespace share active device limits, so multiple namespaces should be used to increase capacity. + /// + public List NotificationHubs { get; set; } = new(); } public class YubicoSettings diff --git a/src/Core/Utilities/CoreHelpers.cs b/src/Core/Utilities/CoreHelpers.cs index d900c82e2..af985914c 100644 --- a/src/Core/Utilities/CoreHelpers.cs +++ b/src/Core/Utilities/CoreHelpers.cs @@ -76,6 +76,39 @@ public static class CoreHelpers return new Guid(guidArray); } + internal static DateTime DateFromComb(Guid combGuid) + { + var guidArray = combGuid.ToByteArray(); + var daysArray = new byte[4]; + var msecsArray = new byte[4]; + + Array.Copy(guidArray, guidArray.Length - 6, daysArray, 2, 2); + Array.Copy(guidArray, guidArray.Length - 4, msecsArray, 0, 4); + + Array.Reverse(daysArray); + Array.Reverse(msecsArray); + + var days = BitConverter.ToInt32(daysArray, 0); + var msecs = BitConverter.ToInt32(msecsArray, 0); + + var time = TimeSpan.FromDays(days) + TimeSpan.FromMilliseconds(msecs * 3.333333); + return new DateTime(_baseDateTicks + time.Ticks, DateTimeKind.Utc); + } + + internal static long BinForComb(Guid combGuid, int binCount) + { + // From System.Web.Util.HashCodeCombiner + uint CombineHashCodes(uint h1, byte h2) + { + return (uint)(((h1 << 5) + h1) ^ h2); + } + var guidArray = combGuid.ToByteArray(); + var randomArray = new byte[10]; + Array.Copy(guidArray, 0, randomArray, 0, 10); + var hash = randomArray.Aggregate((uint)randomArray.Length, CombineHashCodes); + return hash % binCount; + } + public static string CleanCertificateThumbprint(string thumbprint) { // Clean possible garbage characters from thumbprint copy/paste diff --git a/src/SharedWeb/Utilities/ServiceCollectionExtensions.cs b/src/SharedWeb/Utilities/ServiceCollectionExtensions.cs index 040a49baa..b0a2c42ea 100644 --- a/src/SharedWeb/Utilities/ServiceCollectionExtensions.cs +++ b/src/SharedWeb/Utilities/ServiceCollectionExtensions.cs @@ -25,6 +25,7 @@ using Bit.Core.Enums; using Bit.Core.HostedServices; using Bit.Core.Identity; using Bit.Core.IdentityServer; +using Bit.Core.NotificationHub; using Bit.Core.OrganizationFeatures; using Bit.Core.Repositories; using Bit.Core.Resources; @@ -264,16 +265,30 @@ public static class ServiceCollectionExtensions } services.AddSingleton(); - if (globalSettings.SelfHosted && - CoreHelpers.SettingHasValue(globalSettings.PushRelayBaseUri) && - globalSettings.Installation?.Id != null && - CoreHelpers.SettingHasValue(globalSettings.Installation?.Key)) + if (globalSettings.SelfHosted) { - services.AddSingleton(); + if (CoreHelpers.SettingHasValue(globalSettings.PushRelayBaseUri) && + globalSettings.Installation?.Id != null && + CoreHelpers.SettingHasValue(globalSettings.Installation?.Key)) + { + services.AddKeyedSingleton("implementation"); + services.AddSingleton(); + } + if (CoreHelpers.SettingHasValue(globalSettings.InternalIdentityKey) && + CoreHelpers.SettingHasValue(globalSettings.BaseServiceUri.InternalNotifications)) + { + services.AddKeyedSingleton("implementation"); + } } else if (!globalSettings.SelfHosted) { + services.AddSingleton(); services.AddSingleton(); + services.AddKeyedSingleton("implementation"); + if (CoreHelpers.SettingHasValue(globalSettings.Notifications?.ConnectionString)) + { + services.AddKeyedSingleton("implementation"); + } } else { diff --git a/test/Common/AutoFixture/ControllerCustomization.cs b/test/Common/AutoFixture/ControllerCustomization.cs index f695f86b5..91fffbf09 100644 --- a/test/Common/AutoFixture/ControllerCustomization.cs +++ b/test/Common/AutoFixture/ControllerCustomization.cs @@ -1,6 +1,5 @@ using AutoFixture; using Microsoft.AspNetCore.Mvc; -using Org.BouncyCastle.Security; namespace Bit.Test.Common.AutoFixture; @@ -15,7 +14,7 @@ public class ControllerCustomization : ICustomization { if (!controllerType.IsAssignableTo(typeof(Controller))) { - throw new InvalidParameterException($"{nameof(controllerType)} must derive from {typeof(Controller).Name}"); + throw new Exception($"{nameof(controllerType)} must derive from {typeof(Controller).Name}"); } _controllerType = controllerType; diff --git a/test/Core.Test/NotificationHub/NotificationHubConnectionTests.cs b/test/Core.Test/NotificationHub/NotificationHubConnectionTests.cs new file mode 100644 index 000000000..0d7382b3c --- /dev/null +++ b/test/Core.Test/NotificationHub/NotificationHubConnectionTests.cs @@ -0,0 +1,205 @@ +using Bit.Core.Settings; +using Bit.Core.Utilities; +using Xunit; + +namespace Bit.Core.Test.NotificationHub; + +public class NotificationHubConnectionTests +{ + [Fact] + public void IsValid_ConnectionStringIsNull_ReturnsFalse() + { + // Arrange + var hub = new GlobalSettings.NotificationHubSettings() + { + ConnectionString = null, + HubName = "hub", + RegistrationStartDate = DateTime.UtcNow, + RegistrationEndDate = DateTime.UtcNow.AddDays(1) + }; + + // Act + var connection = NotificationHubConnection.From(hub); + + // Assert + Assert.False(connection.IsValid); + } + + [Fact] + public void IsValid_HubNameIsNull_ReturnsFalse() + { + // Arrange + var hub = new GlobalSettings.NotificationHubSettings() + { + ConnectionString = "Endpoint=sb://example.servicebus.windows.net/;", + HubName = null, + RegistrationStartDate = DateTime.UtcNow, + RegistrationEndDate = DateTime.UtcNow.AddDays(1) + }; + + // Act + var connection = NotificationHubConnection.From(hub); + + // Assert + Assert.False(connection.IsValid); + } + + [Fact] + public void IsValid_ConnectionStringAndHubNameAreNotNull_ReturnsTrue() + { + // Arrange + var hub = new GlobalSettings.NotificationHubSettings() + { + ConnectionString = "connection", + HubName = "hub", + RegistrationStartDate = DateTime.UtcNow, + RegistrationEndDate = DateTime.UtcNow.AddDays(1) + }; + + // Act + var connection = NotificationHubConnection.From(hub); + + // Assert + Assert.True(connection.IsValid); + } + + [Fact] + public void RegistrationEnabled_QueryTimeIsBeforeStartDate_ReturnsFalse() + { + // Arrange + var hub = new GlobalSettings.NotificationHubSettings() + { + ConnectionString = "connection", + HubName = "hub", + RegistrationStartDate = DateTime.UtcNow.AddDays(1), + RegistrationEndDate = DateTime.UtcNow.AddDays(2) + }; + var connection = NotificationHubConnection.From(hub); + + // Act + var result = connection.RegistrationEnabled(DateTime.UtcNow); + + // Assert + Assert.False(result); + } + + [Fact] + public void RegistrationEnabled_QueryTimeIsAfterEndDate_ReturnsFalse() + { + // Arrange + var hub = new GlobalSettings.NotificationHubSettings() + { + ConnectionString = "connection", + HubName = "hub", + RegistrationStartDate = DateTime.UtcNow, + RegistrationEndDate = DateTime.UtcNow.AddDays(1) + }; + var connection = NotificationHubConnection.From(hub); + + // Act + var result = connection.RegistrationEnabled(DateTime.UtcNow.AddDays(2)); + + // Assert + Assert.False(result); + } + + [Fact] + public void RegistrationEnabled_NullStartDate_ReturnsFalse() + { + // Arrange + var hub = new GlobalSettings.NotificationHubSettings() + { + ConnectionString = "connection", + HubName = "hub", + RegistrationStartDate = null, + RegistrationEndDate = DateTime.UtcNow.AddDays(1) + }; + var connection = NotificationHubConnection.From(hub); + + // Act + var result = connection.RegistrationEnabled(DateTime.UtcNow); + + // Assert + Assert.False(result); + } + + [Fact] + public void RegistrationEnabled_QueryTimeIsBetweenStartDateAndEndDate_ReturnsTrue() + { + // Arrange + var hub = new GlobalSettings.NotificationHubSettings() + { + ConnectionString = "connection", + HubName = "hub", + RegistrationStartDate = DateTime.UtcNow, + RegistrationEndDate = DateTime.UtcNow.AddDays(1) + }; + var connection = NotificationHubConnection.From(hub); + + // Act + var result = connection.RegistrationEnabled(DateTime.UtcNow.AddHours(1)); + + // Assert + Assert.True(result); + } + + [Fact] + public void RegistrationEnabled_CombTimeIsBeforeStartDate_ReturnsFalse() + { + // Arrange + var hub = new GlobalSettings.NotificationHubSettings() + { + ConnectionString = "connection", + HubName = "hub", + RegistrationStartDate = DateTime.UtcNow.AddDays(1), + RegistrationEndDate = DateTime.UtcNow.AddDays(2) + }; + var connection = NotificationHubConnection.From(hub); + + // Act + var result = connection.RegistrationEnabled(CoreHelpers.GenerateComb(Guid.NewGuid(), DateTime.UtcNow)); + + // Assert + Assert.False(result); + } + + [Fact] + public void RegistrationEnabled_CombTimeIsAfterEndDate_ReturnsFalse() + { + // Arrange + var hub = new GlobalSettings.NotificationHubSettings() + { + ConnectionString = "connection", + HubName = "hub", + RegistrationStartDate = DateTime.UtcNow, + RegistrationEndDate = DateTime.UtcNow.AddDays(1) + }; + var connection = NotificationHubConnection.From(hub); + + // Act + var result = connection.RegistrationEnabled(CoreHelpers.GenerateComb(Guid.NewGuid(), DateTime.UtcNow.AddDays(2))); + + // Assert + Assert.False(result); + } + + [Fact] + public void RegistrationEnabled_CombTimeIsBetweenStartDateAndEndDate_ReturnsTrue() + { + // Arrange + var hub = new GlobalSettings.NotificationHubSettings() + { + ConnectionString = "connection", + HubName = "hub", + RegistrationStartDate = DateTime.UtcNow, + RegistrationEndDate = DateTime.UtcNow.AddDays(1) + }; + var connection = NotificationHubConnection.From(hub); + + // Act + var result = connection.RegistrationEnabled(CoreHelpers.GenerateComb(Guid.NewGuid(), DateTime.UtcNow.AddHours(1))); + + // Assert + Assert.True(result); + } +} diff --git a/test/Core.Test/NotificationHub/NotificationHubPoolTests.cs b/test/Core.Test/NotificationHub/NotificationHubPoolTests.cs new file mode 100644 index 000000000..dd9afb867 --- /dev/null +++ b/test/Core.Test/NotificationHub/NotificationHubPoolTests.cs @@ -0,0 +1,156 @@ +using Bit.Core.NotificationHub; +using Bit.Core.Settings; +using Bit.Core.Utilities; +using Microsoft.Extensions.Logging; +using NSubstitute; +using Xunit; +using static Bit.Core.Settings.GlobalSettings; + +namespace Bit.Core.Test.NotificationHub; + +public class NotificationHubPoolTests +{ + [Fact] + public void NotificationHubPool_WarnsOnMissingConnectionString() + { + // Arrange + var globalSettings = new GlobalSettings() + { + NotificationHubPool = new NotificationHubPoolSettings() + { + NotificationHubs = new() { + new() { + ConnectionString = null, + HubName = "hub", + RegistrationStartDate = DateTime.UtcNow, + RegistrationEndDate = DateTime.UtcNow.AddDays(1) + } + } + } + }; + var logger = Substitute.For>(); + + // Act + var sut = new NotificationHubPool(logger, globalSettings); + + // Assert + logger.Received().Log(LogLevel.Warning, Arg.Any(), + Arg.Is(o => o.ToString() == "Invalid notification hub settings: hub"), + null, + Arg.Any>()); + } + + [Fact] + public void NotificationHubPool_WarnsOnMissingHubName() + { + // Arrange + var globalSettings = new GlobalSettings() + { + NotificationHubPool = new NotificationHubPoolSettings() + { + NotificationHubs = new() { + new() { + ConnectionString = "connection", + HubName = null, + RegistrationStartDate = DateTime.UtcNow, + RegistrationEndDate = DateTime.UtcNow.AddDays(1) + } + } + } + }; + var logger = Substitute.For>(); + + // Act + var sut = new NotificationHubPool(logger, globalSettings); + + // Assert + logger.Received().Log(LogLevel.Warning, Arg.Any(), + Arg.Is(o => o.ToString() == "Invalid notification hub settings: hub name missing"), + null, + Arg.Any>()); + } + + [Fact] + public void NotificationHubPool_ClientFor_ThrowsOnNoValidHubs() + { + // Arrange + var globalSettings = new GlobalSettings() + { + NotificationHubPool = new NotificationHubPoolSettings() + { + NotificationHubs = new() { + new() { + ConnectionString = "connection", + HubName = "hub", + RegistrationStartDate = null, + RegistrationEndDate = null, + } + } + } + }; + var logger = Substitute.For>(); + var sut = new NotificationHubPool(logger, globalSettings); + + // Act + Action act = () => sut.ClientFor(Guid.NewGuid()); + + // Assert + Assert.Throws(act); + } + + [Fact] + public void NotificationHubPool_ClientFor_ReturnsClient() + { + // Arrange + var globalSettings = new GlobalSettings() + { + NotificationHubPool = new NotificationHubPoolSettings() + { + NotificationHubs = new() { + new() { + ConnectionString = "Endpoint=sb://example.servicebus.windows.net/;SharedAccessKey=example///example=", + HubName = "hub", + RegistrationStartDate = DateTime.UtcNow.AddMinutes(-1), + RegistrationEndDate = DateTime.UtcNow.AddDays(1), + } + } + } + }; + var logger = Substitute.For>(); + var sut = new NotificationHubPool(logger, globalSettings); + + // Act + var client = sut.ClientFor(CoreHelpers.GenerateComb(Guid.NewGuid(), DateTime.UtcNow)); + + // Assert + Assert.NotNull(client); + } + + [Fact] + public void NotificationHubPool_AllClients_ReturnsProxy() + { + // Arrange + var globalSettings = new GlobalSettings() + { + NotificationHubPool = new NotificationHubPoolSettings() + { + NotificationHubs = new() { + new() { + ConnectionString = "connection", + HubName = "hub", + RegistrationStartDate = DateTime.UtcNow, + RegistrationEndDate = DateTime.UtcNow.AddDays(1), + } + } + } + }; + var logger = Substitute.For>(); + var sut = new NotificationHubPool(logger, globalSettings); + + // Act + var proxy = sut.AllClients; + + // Assert + Assert.NotNull(proxy); + } +} diff --git a/test/Core.Test/NotificationHub/NotificationHubProxyTests.cs b/test/Core.Test/NotificationHub/NotificationHubProxyTests.cs new file mode 100644 index 000000000..b2e9c4f9f --- /dev/null +++ b/test/Core.Test/NotificationHub/NotificationHubProxyTests.cs @@ -0,0 +1,40 @@ +using AutoFixture; +using Bit.Core.NotificationHub; +using Bit.Test.Common.AutoFixture; +using Microsoft.Azure.NotificationHubs; +using NSubstitute; +using Xunit; + +namespace Bit.Core.Test.NotificationHub; + +public class NotificationHubProxyTests +{ + private readonly IEnumerable _clients; + public NotificationHubProxyTests() + { + _clients = new Fixture().WithAutoNSubstitutions().CreateMany(); + } + + public static IEnumerable ClientMethods = + [ + [ + (NotificationHubClientProxy c) => c.SendTemplateNotificationAsync(new Dictionary() { { "key", "value" } }, "tag"), + (INotificationHubClient c) => c.SendTemplateNotificationAsync(Arg.Is>((a) => a.Keys.Count == 1 && a.ContainsKey("key") && a["key"] == "value"), "tag"), + ], + ]; + + [Theory] + [MemberData(nameof(ClientMethods))] + public async void CallsAllClients(Func proxyMethod, Func clientMethod) + { + var clients = _clients.ToArray(); + var proxy = new NotificationHubClientProxy(clients); + + await proxyMethod(proxy); + + foreach (var client in clients) + { + await clientMethod(client.Received()); + } + } +} diff --git a/test/Core.Test/Services/NotificationHubPushNotificationServiceTests.cs b/test/Core.Test/NotificationHub/NotificationHubPushNotificationServiceTests.cs similarity index 81% rename from test/Core.Test/Services/NotificationHubPushNotificationServiceTests.cs rename to test/Core.Test/NotificationHub/NotificationHubPushNotificationServiceTests.cs index 82594445a..ea9ce5413 100644 --- a/test/Core.Test/Services/NotificationHubPushNotificationServiceTests.cs +++ b/test/Core.Test/NotificationHub/NotificationHubPushNotificationServiceTests.cs @@ -1,32 +1,32 @@ -using Bit.Core.Repositories; +using Bit.Core.NotificationHub; +using Bit.Core.Repositories; using Bit.Core.Services; -using Bit.Core.Settings; using Microsoft.AspNetCore.Http; using Microsoft.Extensions.Logging; using NSubstitute; using Xunit; -namespace Bit.Core.Test.Services; +namespace Bit.Core.Test.NotificationHub; public class NotificationHubPushNotificationServiceTests { private readonly NotificationHubPushNotificationService _sut; private readonly IInstallationDeviceRepository _installationDeviceRepository; - private readonly GlobalSettings _globalSettings; + private readonly INotificationHubPool _notificationHubPool; private readonly IHttpContextAccessor _httpContextAccessor; private readonly ILogger _logger; public NotificationHubPushNotificationServiceTests() { _installationDeviceRepository = Substitute.For(); - _globalSettings = new GlobalSettings(); _httpContextAccessor = Substitute.For(); + _notificationHubPool = Substitute.For(); _logger = Substitute.For>(); _sut = new NotificationHubPushNotificationService( _installationDeviceRepository, - _globalSettings, + _notificationHubPool, _httpContextAccessor, _logger ); diff --git a/test/Core.Test/Services/NotificationHubPushRegistrationServiceTests.cs b/test/Core.Test/NotificationHub/NotificationHubPushRegistrationServiceTests.cs similarity index 82% rename from test/Core.Test/Services/NotificationHubPushRegistrationServiceTests.cs rename to test/Core.Test/NotificationHub/NotificationHubPushRegistrationServiceTests.cs index a8dd536b8..c5851f279 100644 --- a/test/Core.Test/Services/NotificationHubPushRegistrationServiceTests.cs +++ b/test/Core.Test/NotificationHub/NotificationHubPushRegistrationServiceTests.cs @@ -1,11 +1,11 @@ -using Bit.Core.Repositories; -using Bit.Core.Services; +using Bit.Core.NotificationHub; +using Bit.Core.Repositories; using Bit.Core.Settings; using Microsoft.Extensions.Logging; using NSubstitute; using Xunit; -namespace Bit.Core.Test.Services; +namespace Bit.Core.Test.NotificationHub; public class NotificationHubPushRegistrationServiceTests { @@ -15,6 +15,7 @@ public class NotificationHubPushRegistrationServiceTests private readonly IServiceProvider _serviceProvider; private readonly ILogger _logger; private readonly GlobalSettings _globalSettings; + private readonly INotificationHubPool _notificationHubPool; public NotificationHubPushRegistrationServiceTests() { @@ -22,10 +23,12 @@ public class NotificationHubPushRegistrationServiceTests _serviceProvider = Substitute.For(); _logger = Substitute.For>(); _globalSettings = new GlobalSettings(); + _notificationHubPool = Substitute.For(); _sut = new NotificationHubPushRegistrationService( _installationDeviceRepository, _globalSettings, + _notificationHubPool, _serviceProvider, _logger ); diff --git a/test/Core.Test/Services/MultiServicePushNotificationServiceTests.cs b/test/Core.Test/Services/MultiServicePushNotificationServiceTests.cs index b1876f1dd..68d6c50a7 100644 --- a/test/Core.Test/Services/MultiServicePushNotificationServiceTests.cs +++ b/test/Core.Test/Services/MultiServicePushNotificationServiceTests.cs @@ -1,10 +1,10 @@ -using Bit.Core.Repositories; +using AutoFixture; using Bit.Core.Services; -using Bit.Core.Settings; -using Microsoft.AspNetCore.Http; +using Bit.Test.Common.AutoFixture; using Microsoft.Extensions.Logging; using NSubstitute; using Xunit; +using GlobalSettingsCustomization = Bit.Test.Common.AutoFixture.GlobalSettings; namespace Bit.Core.Test.Services; @@ -12,35 +12,26 @@ public class MultiServicePushNotificationServiceTests { private readonly MultiServicePushNotificationService _sut; - private readonly IHttpClientFactory _httpFactory; - private readonly IDeviceRepository _deviceRepository; - private readonly IInstallationDeviceRepository _installationDeviceRepository; - private readonly GlobalSettings _globalSettings; - private readonly IHttpContextAccessor _httpContextAccessor; private readonly ILogger _logger; private readonly ILogger _relayLogger; private readonly ILogger _hubLogger; + private readonly IEnumerable _services; + private readonly Settings.GlobalSettings _globalSettings; public MultiServicePushNotificationServiceTests() { - _httpFactory = Substitute.For(); - _deviceRepository = Substitute.For(); - _installationDeviceRepository = Substitute.For(); - _globalSettings = new GlobalSettings(); - _httpContextAccessor = Substitute.For(); _logger = Substitute.For>(); _relayLogger = Substitute.For>(); _hubLogger = Substitute.For>(); + var fixture = new Fixture().WithAutoNSubstitutions().Customize(new GlobalSettingsCustomization()); + _services = fixture.CreateMany(); + _globalSettings = fixture.Create(); + _sut = new MultiServicePushNotificationService( - _httpFactory, - _deviceRepository, - _installationDeviceRepository, - _globalSettings, - _httpContextAccessor, + _services, _logger, - _relayLogger, - _hubLogger + _globalSettings ); } diff --git a/test/Core.Test/Utilities/CoreHelpersTests.cs b/test/Core.Test/Utilities/CoreHelpersTests.cs index af1156798..2cce276fc 100644 --- a/test/Core.Test/Utilities/CoreHelpersTests.cs +++ b/test/Core.Test/Utilities/CoreHelpersTests.cs @@ -34,33 +34,30 @@ public class CoreHelpersTests // the comb are working properly } - public static IEnumerable GenerateCombCases = new[] - { - new object[] - { + public static IEnumerable GuidSeedCases = [ + [ Guid.Parse("a58db474-43d8-42f1-b4ee-0c17647cd0c0"), // Input Guid new DateTime(2022, 3, 12, 12, 12, 0, DateTimeKind.Utc), // Input Time - Guid.Parse("a58db474-43d8-42f1-b4ee-ae5600c90cc1"), // Expected Comb - }, - new object[] - { + ], + [ Guid.Parse("f776e6ee-511f-4352-bb28-88513002bdeb"), new DateTime(2021, 5, 10, 10, 52, 0, DateTimeKind.Utc), - Guid.Parse("f776e6ee-511f-4352-bb28-ad2400b313c1"), - }, - new object[] - { + ], + [ Guid.Parse("51a25fc7-3cad-497d-8e2f-8d77011648a1"), new DateTime(1999, 2, 26, 16, 53, 13, DateTimeKind.Utc), - Guid.Parse("51a25fc7-3cad-497d-8e2f-8d77011649cd"), - }, - new object[] - { + ], + [ Guid.Parse("bfb8f353-3b32-4a9e-bef6-24fe0b54bfb0"), new DateTime(2024, 10, 20, 1, 32, 16, DateTimeKind.Utc), - Guid.Parse("bfb8f353-3b32-4a9e-bef6-b20f00195780"), - } - }; + ] + ]; + public static IEnumerable GenerateCombCases = GuidSeedCases.Zip([ + Guid.Parse("a58db474-43d8-42f1-b4ee-ae5600c90cc1"), // Expected Comb for each Guid Seed case + Guid.Parse("f776e6ee-511f-4352-bb28-ad2400b313c1"), + Guid.Parse("51a25fc7-3cad-497d-8e2f-8d77011649cd"), + Guid.Parse("bfb8f353-3b32-4a9e-bef6-b20f00195780"), + ]).Select((zip) => new object[] { zip.Item1[0], zip.Item1[1], zip.Item2 }); [Theory] [MemberData(nameof(GenerateCombCases))] @@ -71,6 +68,31 @@ public class CoreHelpersTests Assert.Equal(expectedComb, comb); } + [Theory] + [MemberData(nameof(GuidSeedCases))] + public void DateFromComb_WithComb_Success(Guid inputGuid, DateTime inputTime) + { + var comb = CoreHelpers.GenerateComb(inputGuid, inputTime); + var inverseComb = CoreHelpers.DateFromComb(comb); + + Assert.Equal(inputTime, inverseComb, TimeSpan.FromMilliseconds(4)); + } + + [Theory] + [InlineData("00000000-0000-0000-0000-000000000000", 1, 0)] + [InlineData("00000000-0000-0000-0000-000000000001", 1, 0)] + [InlineData("00000000-0000-0000-0000-000000000000", 500, 430)] + [InlineData("00000000-0000-0000-0000-000000000001", 500, 430)] + [InlineData("10000000-0000-0000-0000-000000000001", 500, 454)] + [InlineData("00000000-0000-0100-0000-000000000001", 500, 19)] + public void BinForComb_Success(string guidString, int nbins, int expectedBin) + { + var guid = Guid.Parse(guidString); + var bin = CoreHelpers.BinForComb(guid, nbins); + + Assert.Equal(expectedBin, bin); + } + /* [Fact] public void ToGuidIdArrayTVP_Success()