diff --git a/src/Notifications/AzureQueueHostedService.cs b/src/Notifications/AzureQueueHostedService.cs index 977d9a9d1..dcb6e2470 100644 --- a/src/Notifications/AzureQueueHostedService.cs +++ b/src/Notifications/AzureQueueHostedService.cs @@ -1,31 +1,28 @@ -using Azure.Storage.Queues; +#nullable enable +using Azure.Storage.Queues; using Bit.Core.Settings; using Bit.Core.Utilities; -using Microsoft.AspNetCore.SignalR; namespace Bit.Notifications; public class AzureQueueHostedService : IHostedService, IDisposable { private readonly ILogger _logger; - private readonly IHubContext _hubContext; - private readonly IHubContext _anonymousHubContext; + private readonly HubHelpers _hubHelpers; private readonly GlobalSettings _globalSettings; - private Task _executingTask; - private CancellationTokenSource _cts; + private Task? _executingTask; + private CancellationTokenSource? _cts; private QueueClient _queueClient; public AzureQueueHostedService( ILogger logger, - IHubContext hubContext, - IHubContext anonymousHubContext, + HubHelpers hubHelpers, GlobalSettings globalSettings) { _logger = logger; - _hubContext = hubContext; + _hubHelpers = hubHelpers; _globalSettings = globalSettings; - _anonymousHubContext = anonymousHubContext; } public Task StartAsync(CancellationToken cancellationToken) @@ -41,14 +38,16 @@ public class AzureQueueHostedService : IHostedService, IDisposable { return; } + _logger.LogWarning("Stopping service."); - _cts.Cancel(); + _cts?.Cancel(); await Task.WhenAny(_executingTask, Task.Delay(-1, cancellationToken)); cancellationToken.ThrowIfCancellationRequested(); } public void Dispose() - { } + { + } private async Task ExecuteAsync(CancellationToken cancellationToken) { @@ -57,16 +56,17 @@ public class AzureQueueHostedService : IHostedService, IDisposable { try { - var messages = await _queueClient.ReceiveMessagesAsync(32); + var messages = await _queueClient.ReceiveMessagesAsync(32, cancellationToken: cancellationToken); if (messages.Value?.Any() ?? false) { foreach (var message in messages.Value) { try { - await HubHelpers.SendNotificationToHubAsync( - message.DecodeMessageText(), _hubContext, _anonymousHubContext, _logger, cancellationToken); - await _queueClient.DeleteMessageAsync(message.MessageId, message.PopReceipt); + await _hubHelpers.SendNotificationToHubAsync(message.DecodeMessageText(), + cancellationToken); + await _queueClient.DeleteMessageAsync(message.MessageId, message.PopReceipt, + cancellationToken); } catch (Exception e) { @@ -74,7 +74,8 @@ public class AzureQueueHostedService : IHostedService, IDisposable message.MessageId, message.DequeueCount); if (message.DequeueCount > 2) { - await _queueClient.DeleteMessageAsync(message.MessageId, message.PopReceipt); + await _queueClient.DeleteMessageAsync(message.MessageId, message.PopReceipt, + cancellationToken); } } } diff --git a/src/Notifications/Controllers/SendController.cs b/src/Notifications/Controllers/SendController.cs index 7debd51df..c663102b5 100644 --- a/src/Notifications/Controllers/SendController.cs +++ b/src/Notifications/Controllers/SendController.cs @@ -1,36 +1,30 @@ -using System.Text; +#nullable enable +using System.Text; using Bit.Core.Utilities; using Microsoft.AspNetCore.Authorization; using Microsoft.AspNetCore.Mvc; -using Microsoft.AspNetCore.SignalR; -namespace Bit.Notifications; +namespace Bit.Notifications.Controllers; [Authorize("Internal")] public class SendController : Controller { - private readonly IHubContext _hubContext; - private readonly IHubContext _anonymousHubContext; - private readonly ILogger _logger; + private readonly HubHelpers _hubHelpers; - public SendController(IHubContext hubContext, IHubContext anonymousHubContext, ILogger logger) + public SendController(HubHelpers hubHelpers) { - _hubContext = hubContext; - _anonymousHubContext = anonymousHubContext; - _logger = logger; + _hubHelpers = hubHelpers; } [HttpPost("~/send")] [SelfHosted(SelfHostedOnly = true)] - public async Task PostSend() + public async Task PostSendAsync() { - using (var reader = new StreamReader(Request.Body, Encoding.UTF8)) + using var reader = new StreamReader(Request.Body, Encoding.UTF8); + var notificationJson = await reader.ReadToEndAsync(); + if (!string.IsNullOrWhiteSpace(notificationJson)) { - var notificationJson = await reader.ReadToEndAsync(); - if (!string.IsNullOrWhiteSpace(notificationJson)) - { - await HubHelpers.SendNotificationToHubAsync(notificationJson, _hubContext, _anonymousHubContext, _logger); - } + await _hubHelpers.SendNotificationToHubAsync(notificationJson); } } } diff --git a/src/Notifications/HubHelpers.cs b/src/Notifications/HubHelpers.cs index 25f43138e..df8fb077f 100644 --- a/src/Notifications/HubHelpers.cs +++ b/src/Notifications/HubHelpers.cs @@ -5,24 +5,30 @@ using Microsoft.AspNetCore.SignalR; namespace Bit.Notifications; -public static class HubHelpers +public class HubHelpers { - private static JsonSerializerOptions _deserializerOptions = - new JsonSerializerOptions { PropertyNameCaseInsensitive = true }; + private static readonly JsonSerializerOptions _deserializerOptions = new() { PropertyNameCaseInsensitive = true }; private static readonly string _receiveMessageMethod = "ReceiveMessage"; - public static async Task SendNotificationToHubAsync( - string notificationJson, - IHubContext hubContext, + private readonly IHubContext _hubContext; + private readonly IHubContext _anonymousHubContext; + private readonly ILogger _logger; + + public HubHelpers(IHubContext hubContext, IHubContext anonymousHubContext, - ILogger logger, - CancellationToken cancellationToken = default(CancellationToken) - ) + ILogger logger) + { + _hubContext = hubContext; + _anonymousHubContext = anonymousHubContext; + _logger = logger; + } + + public async Task SendNotificationToHubAsync(string notificationJson, CancellationToken cancellationToken = default) { var notification = JsonSerializer.Deserialize>(notificationJson, _deserializerOptions); - logger.LogInformation("Sending notification: {NotificationType}", notification.Type); + _logger.LogInformation("Sending notification: {NotificationType}", notification.Type); switch (notification.Type) { case PushType.SyncCipherUpdate: @@ -34,12 +40,12 @@ public static class HubHelpers notificationJson, _deserializerOptions); if (cipherNotification.Payload.UserId.HasValue) { - await hubContext.Clients.User(cipherNotification.Payload.UserId.ToString()) + await _hubContext.Clients.User(cipherNotification.Payload.UserId.ToString()) .SendAsync(_receiveMessageMethod, cipherNotification, cancellationToken); } else if (cipherNotification.Payload.OrganizationId.HasValue) { - await hubContext.Clients + await _hubContext.Clients .Group(NotificationsHub.GetOrganizationGroup(cipherNotification.Payload.OrganizationId.Value)) .SendAsync(_receiveMessageMethod, cipherNotification, cancellationToken); } @@ -51,7 +57,7 @@ public static class HubHelpers var folderNotification = JsonSerializer.Deserialize>( notificationJson, _deserializerOptions); - await hubContext.Clients.User(folderNotification.Payload.UserId.ToString()) + await _hubContext.Clients.User(folderNotification.Payload.UserId.ToString()) .SendAsync(_receiveMessageMethod, folderNotification, cancellationToken); break; case PushType.SyncCiphers: @@ -63,7 +69,7 @@ public static class HubHelpers var userNotification = JsonSerializer.Deserialize>( notificationJson, _deserializerOptions); - await hubContext.Clients.User(userNotification.Payload.UserId.ToString()) + await _hubContext.Clients.User(userNotification.Payload.UserId.ToString()) .SendAsync(_receiveMessageMethod, userNotification, cancellationToken); break; case PushType.SyncSendCreate: @@ -72,21 +78,21 @@ public static class HubHelpers var sendNotification = JsonSerializer.Deserialize>( notificationJson, _deserializerOptions); - await hubContext.Clients.User(sendNotification.Payload.UserId.ToString()) + await _hubContext.Clients.User(sendNotification.Payload.UserId.ToString()) .SendAsync(_receiveMessageMethod, sendNotification, cancellationToken); break; case PushType.AuthRequestResponse: var authRequestResponseNotification = JsonSerializer.Deserialize>( notificationJson, _deserializerOptions); - await anonymousHubContext.Clients.Group(authRequestResponseNotification.Payload.Id.ToString()) + await _anonymousHubContext.Clients.Group(authRequestResponseNotification.Payload.Id.ToString()) .SendAsync("AuthRequestResponseRecieved", authRequestResponseNotification, cancellationToken); break; case PushType.AuthRequest: var authRequestNotification = JsonSerializer.Deserialize>( notificationJson, _deserializerOptions); - await hubContext.Clients.User(authRequestNotification.Payload.UserId.ToString()) + await _hubContext.Clients.User(authRequestNotification.Payload.UserId.ToString()) .SendAsync(_receiveMessageMethod, authRequestNotification, cancellationToken); break; case PushType.SyncNotification: @@ -97,19 +103,19 @@ public static class HubHelpers { if (syncNotification.Payload.ClientType == ClientType.All) { - await hubContext.Clients.User(syncNotification.Payload.UserId.ToString()) + await _hubContext.Clients.User(syncNotification.Payload.UserId.ToString()) .SendAsync(_receiveMessageMethod, syncNotification, cancellationToken); } else { - await hubContext.Clients.Group(NotificationsHub.GetUserGroup( + await _hubContext.Clients.Group(NotificationsHub.GetUserGroup( syncNotification.Payload.UserId.Value, syncNotification.Payload.ClientType)) .SendAsync(_receiveMessageMethod, syncNotification, cancellationToken); } } else if (syncNotification.Payload.OrganizationId.HasValue) { - await hubContext.Clients.Group(NotificationsHub.GetOrganizationGroup( + await _hubContext.Clients.Group(NotificationsHub.GetOrganizationGroup( syncNotification.Payload.OrganizationId.Value, syncNotification.Payload.ClientType)) .SendAsync(_receiveMessageMethod, syncNotification, cancellationToken); } diff --git a/src/Notifications/Startup.cs b/src/Notifications/Startup.cs index 440808b78..961c47199 100644 --- a/src/Notifications/Startup.cs +++ b/src/Notifications/Startup.cs @@ -61,6 +61,7 @@ public class Startup } services.AddSingleton(); services.AddSingleton(); + services.AddSingleton(); // Mvc services.AddMvc(); diff --git a/test/Notifications.Test/HubHelpersTest.cs b/test/Notifications.Test/HubHelpersTest.cs new file mode 100644 index 000000000..2a30a55d4 --- /dev/null +++ b/test/Notifications.Test/HubHelpersTest.cs @@ -0,0 +1,201 @@ +#nullable enable +using System.Text.Json; +using Bit.Core.Enums; +using Bit.Core.Models; +using Bit.Core.Utilities; +using Bit.Notifications; +using Bit.Test.Common.AutoFixture; +using Bit.Test.Common.AutoFixture.Attributes; +using Microsoft.AspNetCore.SignalR; +using NSubstitute; + +namespace Notifications.Test; + +[SutProviderCustomize] +public class HubHelpersTest +{ + [Theory] + [BitAutoData] + public async Task SendNotificationToHubAsync_SyncNotificationGlobal_NothingSent(SutProvider sutProvider, + ClientType clientType, string contextId, CancellationToken cancellationToke) + { + var message = new SyncNotificationPushNotification + { + Id = Guid.NewGuid(), + UserId = null, + OrganizationId = null, + ClientType = clientType, + RevisionDate = DateTime.UtcNow + }; + + var json = ToNotificationJson(message, PushType.SyncNotification, contextId); + await sutProvider.Sut.SendNotificationToHubAsync(json, cancellationToke); + + sutProvider.GetDependency>().Clients.Received(0).User(Arg.Any()); + sutProvider.GetDependency>().Clients.Received(0).Group(Arg.Any()); + sutProvider.GetDependency>().Clients.Received(0).User(Arg.Any()); + sutProvider.GetDependency>().Clients.Received(0) + .Group(Arg.Any()); + } + + [Theory] + [BitAutoData(false)] + [BitAutoData(true)] + public async Task SendNotificationToHubAsync_SyncNotificationUserIdProvidedClientTypeAll_SentToUser( + bool organizationIdProvided, SutProvider sutProvider, Guid userId, Guid organizationId, + string contextId, CancellationToken cancellationToken) + { + var syncNotification = new SyncNotificationPushNotification + { + Id = Guid.NewGuid(), + UserId = userId, + OrganizationId = organizationIdProvided ? organizationId : null, + ClientType = ClientType.All, + RevisionDate = DateTime.UtcNow + }; + + var json = ToNotificationJson(syncNotification, PushType.SyncNotification, contextId); + await sutProvider.Sut.SendNotificationToHubAsync(json, cancellationToken); + + await sutProvider.GetDependency>().Clients.Received(1) + .User(userId.ToString()) + .Received(1) + .SendCoreAsync("ReceiveMessage", Arg.Is(objects => + objects.Length == 1 && IsSyncNotificationEqual(syncNotification, objects[0], + PushType.SyncNotification, contextId)), + cancellationToken); + sutProvider.GetDependency>().Clients.Received(0).Group(Arg.Any()); + sutProvider.GetDependency>().Clients.Received(0).User(Arg.Any()); + sutProvider.GetDependency>().Clients.Received(0) + .Group(Arg.Any()); + } + + [Theory] + [BitAutoData(false, ClientType.Browser)] + [BitAutoData(false, ClientType.Desktop)] + [BitAutoData(false, ClientType.Mobile)] + [BitAutoData(false, ClientType.Web)] + [BitAutoData(true, ClientType.Browser)] + [BitAutoData(true, ClientType.Desktop)] + [BitAutoData(true, ClientType.Mobile)] + [BitAutoData(true, ClientType.Web)] + public async Task + SendNotificationToHubAsync_SyncNotificationUserIdProvidedClientTypeNotAll_SentToGroupUserClientType( + bool organizationIdProvided, ClientType clientType, SutProvider sutProvider, Guid userId, + string contextId, CancellationToken cancellationToken) + { + var syncNotification = new SyncNotificationPushNotification + { + Id = Guid.NewGuid(), + UserId = userId, + OrganizationId = organizationIdProvided ? Guid.NewGuid() : null, + ClientType = clientType, + RevisionDate = DateTime.UtcNow + }; + + var json = ToNotificationJson(syncNotification, PushType.SyncNotification, contextId); + await sutProvider.Sut.SendNotificationToHubAsync(json, cancellationToken); + + sutProvider.GetDependency>().Clients.Received(0).User(Arg.Any()); + await sutProvider.GetDependency>().Clients.Received(1) + .Group($"UserClientType_{userId}_{clientType}") + .Received(1) + .SendCoreAsync("ReceiveMessage", Arg.Is(objects => + objects.Length == 1 && IsSyncNotificationEqual(syncNotification, objects[0], + PushType.SyncNotification, contextId)), + cancellationToken); + sutProvider.GetDependency>().Clients.Received(0).User(Arg.Any()); + sutProvider.GetDependency>().Clients.Received(0) + .Group(Arg.Any()); + } + + [Theory] + [BitAutoData] + public async Task + SendNotificationToHubAsync_SyncNotificationUserIdNullOrganizationIdProvidedClientTypeAll_SentToGroupOrganization( + SutProvider sutProvider, string contextId, Guid organizationId, + CancellationToken cancellationToken) + { + var syncNotification = new SyncNotificationPushNotification + { + Id = Guid.NewGuid(), + UserId = null, + OrganizationId = organizationId, + ClientType = ClientType.All, + RevisionDate = DateTime.UtcNow + }; + + var json = ToNotificationJson(syncNotification, PushType.SyncNotification, contextId); + await sutProvider.Sut.SendNotificationToHubAsync(json, cancellationToken); + + sutProvider.GetDependency>().Clients.Received(0).User(Arg.Any()); + await sutProvider.GetDependency>().Clients.Received(1) + .Group($"Organization_{organizationId}") + .Received(1) + .SendCoreAsync("ReceiveMessage", Arg.Is(objects => + objects.Length == 1 && IsSyncNotificationEqual(syncNotification, objects[0], + PushType.SyncNotification, contextId)), + cancellationToken); + sutProvider.GetDependency>().Clients.Received(0).User(Arg.Any()); + sutProvider.GetDependency>().Clients.Received(0) + .Group(Arg.Any()); + } + + [Theory] + [BitAutoData(ClientType.Browser)] + [BitAutoData(ClientType.Desktop)] + [BitAutoData(ClientType.Mobile)] + [BitAutoData(ClientType.Web)] + public async Task + SendNotificationToHubAsync_SyncNotificationUserIdNullOrganizationIdProvidedClientTypeNotAll_SentToGroupOrganizationClientType( + ClientType clientType, SutProvider sutProvider, string contextId, Guid organizationId, + CancellationToken cancellationToken) + { + var syncNotification = new SyncNotificationPushNotification + { + Id = Guid.NewGuid(), + UserId = null, + OrganizationId = organizationId, + ClientType = clientType, + RevisionDate = DateTime.UtcNow + }; + + var json = ToNotificationJson(syncNotification, PushType.SyncNotification, contextId); + await sutProvider.Sut.SendNotificationToHubAsync(json, cancellationToken); + + sutProvider.GetDependency>().Clients.Received(0).User(Arg.Any()); + await sutProvider.GetDependency>().Clients.Received(1) + .Group($"OrganizationClientType_{organizationId}_{clientType}") + .Received(1) + .SendCoreAsync("ReceiveMessage", Arg.Is(objects => + objects.Length == 1 && IsSyncNotificationEqual(syncNotification, objects[0], + PushType.SyncNotification, contextId)), + cancellationToken); + sutProvider.GetDependency>().Clients.Received(0).User(Arg.Any()); + sutProvider.GetDependency>().Clients.Received(0) + .Group(Arg.Any()); + } + + private static string ToNotificationJson(object payload, PushType type, string contextId) + { + var notification = new PushNotificationData(type, payload, contextId); + return JsonSerializer.Serialize(notification, JsonHelpers.IgnoreWritingNull); + } + + private static bool IsSyncNotificationEqual(SyncNotificationPushNotification expected, object? actual, + PushType type, string contextId) + { + if (actual is not PushNotificationData pushNotificationData) + { + return false; + } + + return pushNotificationData.Type == type && + pushNotificationData.ContextId == contextId && + expected.Id == pushNotificationData.Payload.Id && + expected.UserId == pushNotificationData.Payload.UserId && + expected.OrganizationId == pushNotificationData.Payload.OrganizationId && + expected.ClientType == pushNotificationData.Payload.ClientType && + expected.RevisionDate == pushNotificationData.Payload.RevisionDate; + } +} diff --git a/test/Notifications.Test/Notifications.Test.csproj b/test/Notifications.Test/Notifications.Test.csproj index 4dd37605c..a4bab9df9 100644 --- a/test/Notifications.Test/Notifications.Test.csproj +++ b/test/Notifications.Test/Notifications.Test.csproj @@ -18,5 +18,7 @@ + +