1
0
mirror of https://github.com/bitwarden/server.git synced 2024-11-25 12:45:18 +01:00

PM-10600: Sending to specific client types for other clients

This commit is contained in:
Maciej Zieniuk 2024-10-22 11:11:55 +01:00
parent 7020565770
commit 6296c1fb1f
No known key found for this signature in database
GPG Key ID: 9CACE59F1272ACD9
7 changed files with 96 additions and 16 deletions

View File

@ -51,6 +51,7 @@ public class SyncNotificationPushNotification
public bool Global { get; set; }
public Guid? UserId { get; set; }
public Guid? OrganizationId { get; set; }
public ClientType ClientType { get; set; }
public DateTime RevisionDate { get; set; }
}

View File

@ -174,6 +174,7 @@ public class AzureQueuePushNotificationService : IPushNotificationService
Global = notification.Global,
UserId = notification.Id,
OrganizationId = notification.Id,
ClientType = notification.ClientType,
RevisionDate = notification.RevisionDate
};

View File

@ -202,6 +202,7 @@ public class NotificationHubPushNotificationService : IPushNotificationService
Global = notification.Global,
UserId = notification.Id,
OrganizationId = notification.Id,
ClientType = notification.ClientType,
RevisionDate = notification.RevisionDate
};

View File

@ -181,6 +181,7 @@ public class NotificationsApiPushNotificationService : BaseIdentityClientService
Global = notification.Global,
UserId = notification.Id,
OrganizationId = notification.Id,
ClientType = notification.ClientType,
RevisionDate = notification.RevisionDate
};

View File

@ -197,6 +197,7 @@ public class RelayPushNotificationService : BaseIdentityClientService, IPushNoti
Global = notification.Global,
UserId = notification.Id,
OrganizationId = notification.Id,
ClientType = notification.ClientType,
RevisionDate = notification.RevisionDate
};

View File

@ -10,6 +10,8 @@ public static class HubHelpers
private static JsonSerializerOptions _deserializerOptions =
new JsonSerializerOptions { PropertyNameCaseInsensitive = true };
private static readonly string _receiveMessageMethod = "ReceiveMessage";
public static async Task SendNotificationToHubAsync(
string notificationJson,
IHubContext<NotificationsHub> hubContext,
@ -33,13 +35,13 @@ public static class HubHelpers
if (cipherNotification.Payload.UserId.HasValue)
{
await hubContext.Clients.User(cipherNotification.Payload.UserId.ToString())
.SendAsync("ReceiveMessage", cipherNotification, cancellationToken);
.SendAsync(_receiveMessageMethod, cipherNotification, cancellationToken);
}
else if (cipherNotification.Payload.OrganizationId.HasValue)
{
await hubContext.Clients.Group(
$"Organization_{cipherNotification.Payload.OrganizationId}")
.SendAsync("ReceiveMessage", cipherNotification, cancellationToken);
await hubContext.Clients
.Group(NotificationsHub.GetOrganizationGroup(cipherNotification.Payload.OrganizationId.Value))
.SendAsync(_receiveMessageMethod, cipherNotification, cancellationToken);
}
break;
@ -50,7 +52,7 @@ public static class HubHelpers
JsonSerializer.Deserialize<PushNotificationData<SyncFolderPushNotification>>(
notificationJson, _deserializerOptions);
await hubContext.Clients.User(folderNotification.Payload.UserId.ToString())
.SendAsync("ReceiveMessage", folderNotification, cancellationToken);
.SendAsync(_receiveMessageMethod, folderNotification, cancellationToken);
break;
case PushType.SyncCiphers:
case PushType.SyncVault:
@ -62,7 +64,7 @@ public static class HubHelpers
JsonSerializer.Deserialize<PushNotificationData<UserPushNotification>>(
notificationJson, _deserializerOptions);
await hubContext.Clients.User(userNotification.Payload.UserId.ToString())
.SendAsync("ReceiveMessage", userNotification, cancellationToken);
.SendAsync(_receiveMessageMethod, userNotification, cancellationToken);
break;
case PushType.SyncSendCreate:
case PushType.SyncSendUpdate:
@ -71,7 +73,7 @@ public static class HubHelpers
JsonSerializer.Deserialize<PushNotificationData<SyncSendPushNotification>>(
notificationJson, _deserializerOptions);
await hubContext.Clients.User(sendNotification.Payload.UserId.ToString())
.SendAsync("ReceiveMessage", sendNotification, cancellationToken);
.SendAsync(_receiveMessageMethod, sendNotification, cancellationToken);
break;
case PushType.AuthRequestResponse:
var authRequestResponseNotification =
@ -85,7 +87,7 @@ public static class HubHelpers
JsonSerializer.Deserialize<PushNotificationData<AuthRequestPushNotification>>(
notificationJson, _deserializerOptions);
await hubContext.Clients.User(authRequestNotification.Payload.UserId.ToString())
.SendAsync("ReceiveMessage", authRequestNotification, cancellationToken);
.SendAsync(_receiveMessageMethod, authRequestNotification, cancellationToken);
break;
case PushType.SyncNotification:
var syncNotification =
@ -93,19 +95,39 @@ public static class HubHelpers
notificationJson, _deserializerOptions);
if (syncNotification.Payload.Global)
{
await hubContext.Clients.All.SendAsync("ReceiveMessage", syncNotification, cancellationToken);
if (syncNotification.Payload.ClientType == ClientType.All)
{
await hubContext.Clients.All.SendAsync(_receiveMessageMethod, syncNotification,
cancellationToken);
}
else
{
await hubContext.Clients
.Group(NotificationsHub.GetGlobalGroup(syncNotification.Payload.ClientType))
.SendAsync(_receiveMessageMethod, syncNotification, cancellationToken);
}
}
else if (syncNotification.Payload.UserId.HasValue)
{
await hubContext.Clients.User(syncNotification.Payload.UserId.ToString())
.SendAsync("ReceiveMessage", syncNotification, cancellationToken);
if (syncNotification.Payload.ClientType == ClientType.All)
{
await hubContext.Clients.User(syncNotification.Payload.UserId.ToString())
.SendAsync(_receiveMessageMethod, syncNotification, cancellationToken);
}
else
{
await hubContext.Clients.Group(NotificationsHub.GetUserGroup(
syncNotification.Payload.UserId.Value, syncNotification.Payload.ClientType))
.SendAsync(_receiveMessageMethod, syncNotification, cancellationToken);
}
}
else if (syncNotification.Payload.OrganizationId.HasValue)
{
await hubContext.Clients.Group(
$"Organization_{syncNotification.Payload.OrganizationId}")
.SendAsync("ReceiveMessage", syncNotification, cancellationToken);
await hubContext.Clients.Group(NotificationsHub.GetOrganizationGroup(
syncNotification.Payload.OrganizationId.Value, syncNotification.Payload.ClientType))
.SendAsync(_receiveMessageMethod, syncNotification, cancellationToken);
}
break;
default:
break;

View File

@ -1,5 +1,7 @@
using Bit.Core.Context;
using Bit.Core.Enums;
using Bit.Core.Settings;
using Bit.Core.Utilities;
using Microsoft.AspNetCore.Authorization;
namespace Bit.Notifications;
@ -20,13 +22,30 @@ public class NotificationsHub : Microsoft.AspNetCore.SignalR.Hub
{
var currentContext = new CurrentContext(null, null);
await currentContext.BuildAsync(Context.User, _globalSettings);
var clientType = DeviceTypes.ToClientType(currentContext.DeviceType);
if (clientType != ClientType.All)
{
await Groups.AddToGroupAsync(Context.ConnectionId, GetGlobalGroup(clientType));
if (currentContext.UserId.HasValue)
{
await Groups.AddToGroupAsync(Context.ConnectionId,
GetUserGroup(currentContext.UserId.Value, clientType));
}
}
if (currentContext.Organizations != null)
{
foreach (var org in currentContext.Organizations)
{
await Groups.AddToGroupAsync(Context.ConnectionId, $"Organization_{org.Id}");
await Groups.AddToGroupAsync(Context.ConnectionId, GetOrganizationGroup(org.Id));
if (clientType != ClientType.All)
{
await Groups.AddToGroupAsync(Context.ConnectionId, GetOrganizationGroup(org.Id, clientType));
}
}
}
_connectionCounter.Increment();
await base.OnConnectedAsync();
}
@ -35,14 +54,48 @@ public class NotificationsHub : Microsoft.AspNetCore.SignalR.Hub
{
var currentContext = new CurrentContext(null, null);
await currentContext.BuildAsync(Context.User, _globalSettings);
var clientType = DeviceTypes.ToClientType(currentContext.DeviceType);
if (clientType != ClientType.All)
{
await Groups.RemoveFromGroupAsync(Context.ConnectionId, GetGlobalGroup(clientType));
if (currentContext.UserId.HasValue)
{
await Groups.RemoveFromGroupAsync(Context.ConnectionId,
GetUserGroup(currentContext.UserId.Value, clientType));
}
}
if (currentContext.Organizations != null)
{
foreach (var org in currentContext.Organizations)
{
await Groups.RemoveFromGroupAsync(Context.ConnectionId, $"Organization_{org.Id}");
await Groups.RemoveFromGroupAsync(Context.ConnectionId, GetOrganizationGroup(org.Id));
if (clientType != ClientType.All)
{
await Groups.RemoveFromGroupAsync(Context.ConnectionId, GetOrganizationGroup(org.Id, clientType));
}
}
}
_connectionCounter.Decrement();
await base.OnDisconnectedAsync(exception);
}
public static string GetGlobalGroup(ClientType clientType)
{
return $"ClientType_{clientType}";
}
public static string GetUserGroup(Guid userId, ClientType clientType)
{
return $"{userId}_{clientType}";
}
public static string GetOrganizationGroup(Guid organizationId, ClientType? clientType = null)
{
return clientType is not ClientType.All
? $"Organization_{organizationId}"
: $"Organization_{organizationId}_{clientType}";
}
}