1
0
mirror of https://github.com/bitwarden/server.git synced 2024-11-21 12:05:42 +01:00

[PM-6339] Shard notification hub clients across multiple accounts (#3812)

* WIP registration updates

* fix deviceHubs

* addHub inline in ctor

* adjust setttings for hub reg

* send to all clients

* fix multiservice push

* use notification hub type

* feedback

---------

Co-authored-by: Matt Bishop <mbishop@bitwarden.com>
This commit is contained in:
Kyle Spearrin 2024-04-08 15:39:44 -04:00 committed by GitHub
parent de8b7b14b8
commit 40221f578f
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
14 changed files with 208 additions and 70 deletions

View File

@ -42,11 +42,11 @@ public class PushController : Controller
Prefix(model.UserId), Prefix(model.Identifier), model.Type);
}
[HttpDelete("{id}")]
public async Task Delete(string id)
[HttpPost("delete")]
public async Task PostDelete([FromBody] PushDeviceRequestModel model)
{
CheckUsage();
await _pushRegistrationService.DeleteRegistrationAsync(Prefix(id));
await _pushRegistrationService.DeleteRegistrationAsync(Prefix(model.Id), model.Type);
}
[HttpPut("add-organization")]
@ -54,7 +54,8 @@ public class PushController : Controller
{
CheckUsage();
await _pushRegistrationService.AddUserRegistrationOrganizationAsync(
model.DeviceIds.Select(d => Prefix(d)), Prefix(model.OrganizationId));
model.Devices.Select(d => new KeyValuePair<string, Core.Enums.DeviceType>(Prefix(d.Id), d.Type)),
Prefix(model.OrganizationId));
}
[HttpPut("delete-organization")]
@ -62,7 +63,8 @@ public class PushController : Controller
{
CheckUsage();
await _pushRegistrationService.DeleteUserRegistrationOrganizationAsync(
model.DeviceIds.Select(d => Prefix(d)), Prefix(model.OrganizationId));
model.Devices.Select(d => new KeyValuePair<string, Core.Enums.DeviceType>(Prefix(d.Id), d.Type)),
Prefix(model.OrganizationId));
}
[HttpPost("send")]

View File

@ -681,8 +681,8 @@ public class OrganizationService : IOrganizationService
await _organizationUserRepository.CreateAsync(orgUser);
var deviceIds = await GetUserDeviceIdsAsync(orgUser.UserId.Value);
await _pushRegistrationService.AddUserRegistrationOrganizationAsync(deviceIds,
var devices = await GetUserDeviceIdsAsync(orgUser.UserId.Value);
await _pushRegistrationService.AddUserRegistrationOrganizationAsync(devices,
organization.Id.ToString());
await _pushNotificationService.PushSyncOrgKeysAsync(ownerId);
}
@ -1932,17 +1932,19 @@ public class OrganizationService : IOrganizationService
private async Task DeleteAndPushUserRegistrationAsync(Guid organizationId, Guid userId)
{
var deviceIds = await GetUserDeviceIdsAsync(userId);
await _pushRegistrationService.DeleteUserRegistrationOrganizationAsync(deviceIds,
var devices = await GetUserDeviceIdsAsync(userId);
await _pushRegistrationService.DeleteUserRegistrationOrganizationAsync(devices,
organizationId.ToString());
await _pushNotificationService.PushSyncOrgKeysAsync(userId);
}
private async Task<IEnumerable<string>> GetUserDeviceIdsAsync(Guid userId)
private async Task<IEnumerable<KeyValuePair<string, DeviceType>>> GetUserDeviceIdsAsync(Guid userId)
{
var devices = await _deviceRepository.GetManyByUserIdAsync(userId);
return devices.Where(d => !string.IsNullOrWhiteSpace(d.PushToken)).Select(d => d.Id.ToString());
return devices
.Where(d => !string.IsNullOrWhiteSpace(d.PushToken))
.Select(d => new KeyValuePair<string, DeviceType>(d.Id.ToString(), d.Type));
}
public async Task ReplaceAndUpdateCacheAsync(Organization org, EventType? orgEvent = null)

View File

@ -0,0 +1,11 @@
namespace Bit.Core.Enums;
public enum NotificationHubType
{
General = 0,
Android = 1,
iOS = 2,
GeneralWeb = 3,
GeneralBrowserExtension = 4,
GeneralDesktop = 5
}

View File

@ -0,0 +1,12 @@
using System.ComponentModel.DataAnnotations;
using Bit.Core.Enums;
namespace Bit.Core.Models.Api;
public class PushDeviceRequestModel
{
[Required]
public string Id { get; set; }
[Required]
public DeviceType Type { get; set; }
}

View File

@ -1,4 +1,5 @@
using System.ComponentModel.DataAnnotations;
using Bit.Core.Enums;
namespace Bit.Core.Models.Api;
@ -7,14 +8,14 @@ public class PushUpdateRequestModel
public PushUpdateRequestModel()
{ }
public PushUpdateRequestModel(IEnumerable<string> deviceIds, string organizationId)
public PushUpdateRequestModel(IEnumerable<KeyValuePair<string, DeviceType>> devices, string organizationId)
{
DeviceIds = deviceIds;
Devices = devices.Select(d => new PushDeviceRequestModel { Id = d.Key, Type = d.Value });
OrganizationId = organizationId;
}
[Required]
public IEnumerable<string> DeviceIds { get; set; }
public IEnumerable<PushDeviceRequestModel> Devices { get; set; }
[Required]
public string OrganizationId { get; set; }
}

View File

@ -6,7 +6,7 @@ public interface IPushRegistrationService
{
Task CreateOrUpdateRegistrationAsync(string pushToken, string deviceId, string userId,
string identifier, DeviceType type);
Task DeleteRegistrationAsync(string deviceId);
Task AddUserRegistrationOrganizationAsync(IEnumerable<string> deviceIds, string organizationId);
Task DeleteUserRegistrationOrganizationAsync(IEnumerable<string> deviceIds, string organizationId);
Task DeleteRegistrationAsync(string deviceId, DeviceType type);
Task AddUserRegistrationOrganizationAsync(IEnumerable<KeyValuePair<string, DeviceType>> devices, string organizationId);
Task DeleteUserRegistrationOrganizationAsync(IEnumerable<KeyValuePair<string, DeviceType>> devices, string organizationId);
}

View File

@ -38,13 +38,13 @@ public class DeviceService : IDeviceService
public async Task ClearTokenAsync(Device device)
{
await _deviceRepository.ClearPushTokenAsync(device.Id);
await _pushRegistrationService.DeleteRegistrationAsync(device.Id.ToString());
await _pushRegistrationService.DeleteRegistrationAsync(device.Id.ToString(), device.Type);
}
public async Task DeleteAsync(Device device)
{
await _deviceRepository.DeleteAsync(device);
await _pushRegistrationService.DeleteRegistrationAsync(device.Id.ToString());
await _pushRegistrationService.DeleteRegistrationAsync(device.Id.ToString(), device.Type);
}
public async Task UpdateDevicesTrustAsync(string currentDeviceIdentifier,

View File

@ -43,7 +43,8 @@ public class MultiServicePushNotificationService : IPushNotificationService
}
else
{
if (CoreHelpers.SettingHasValue(globalSettings.NotificationHub.ConnectionString))
var generalHub = globalSettings.NotificationHubs?.FirstOrDefault(h => h.HubType == NotificationHubType.General);
if (CoreHelpers.SettingHasValue(generalHub?.ConnectionString))
{
_services.Add(new NotificationHubPushNotificationService(installationDeviceRepository,
globalSettings, httpContextAccessor, hubLogger));

View File

@ -20,8 +20,9 @@ public class NotificationHubPushNotificationService : IPushNotificationService
private readonly IInstallationDeviceRepository _installationDeviceRepository;
private readonly GlobalSettings _globalSettings;
private readonly IHttpContextAccessor _httpContextAccessor;
private NotificationHubClient _client = null;
private ILogger _logger;
private readonly List<NotificationHubClient> _clients = [];
private readonly bool _enableTracing = false;
private readonly ILogger _logger;
public NotificationHubPushNotificationService(
IInstallationDeviceRepository installationDeviceRepository,
@ -32,10 +33,18 @@ public class NotificationHubPushNotificationService : IPushNotificationService
_installationDeviceRepository = installationDeviceRepository;
_globalSettings = globalSettings;
_httpContextAccessor = httpContextAccessor;
_client = NotificationHubClient.CreateClientFromConnectionString(
_globalSettings.NotificationHub.ConnectionString,
_globalSettings.NotificationHub.HubName,
_globalSettings.NotificationHub.EnableSendTracing);
foreach (var hub in globalSettings.NotificationHubs)
{
var client = NotificationHubClient.CreateClientFromConnectionString(
hub.ConnectionString,
hub.HubName,
hub.EnableSendTracing);
_clients.Add(client);
_enableTracing = _enableTracing || hub.EnableSendTracing;
}
_logger = logger;
}
@ -255,16 +264,31 @@ public class NotificationHubPushNotificationService : IPushNotificationService
private async Task SendPayloadAsync(string tag, PushType type, object payload)
{
var outcome = await _client.SendTemplateNotificationAsync(
new Dictionary<string, string>
{
{ "type", ((byte)type).ToString() },
{ "payload", JsonSerializer.Serialize(payload) }
}, tag);
if (_globalSettings.NotificationHub.EnableSendTracing)
var tasks = new List<Task<NotificationOutcome>>();
foreach (var client in _clients)
{
_logger.LogInformation("Azure Notification Hub Tracking ID: {id} | {type} push notification with {success} successes and {failure} failures with a payload of {@payload} and result of {@results}",
outcome.TrackingId, type, outcome.Success, outcome.Failure, payload, outcome.Results);
var task = client.SendTemplateNotificationAsync(
new Dictionary<string, string>
{
{ "type", ((byte)type).ToString() },
{ "payload", JsonSerializer.Serialize(payload) }
}, tag);
tasks.Add(task);
}
await Task.WhenAll(tasks);
if (_enableTracing)
{
for (var i = 0; i < tasks.Count; i++)
{
if (_clients[i].EnableTestSend)
{
var outcome = await tasks[i];
_logger.LogInformation("Azure Notification Hub Tracking ID: {id} | {type} push notification with {success} successes and {failure} failures with a payload of {@payload} and result of {@results}",
outcome.TrackingId, type, outcome.Success, outcome.Failure, payload, outcome.Results);
}
}
}
}

View File

@ -3,6 +3,7 @@ using Bit.Core.Models.Data;
using Bit.Core.Repositories;
using Bit.Core.Settings;
using Microsoft.Azure.NotificationHubs;
using Microsoft.Extensions.Logging;
namespace Bit.Core.Services;
@ -10,18 +11,36 @@ public class NotificationHubPushRegistrationService : IPushRegistrationService
{
private readonly IInstallationDeviceRepository _installationDeviceRepository;
private readonly GlobalSettings _globalSettings;
private NotificationHubClient _client = null;
private readonly ILogger<NotificationHubPushRegistrationService> _logger;
private Dictionary<NotificationHubType, NotificationHubClient> _clients = [];
public NotificationHubPushRegistrationService(
IInstallationDeviceRepository installationDeviceRepository,
GlobalSettings globalSettings)
GlobalSettings globalSettings,
ILogger<NotificationHubPushRegistrationService> logger)
{
_installationDeviceRepository = installationDeviceRepository;
_globalSettings = globalSettings;
_client = NotificationHubClient.CreateClientFromConnectionString(
_globalSettings.NotificationHub.ConnectionString,
_globalSettings.NotificationHub.HubName);
_logger = logger;
// Is this dirty to do in the ctor?
void addHub(NotificationHubType type)
{
var hubRegistration = globalSettings.NotificationHubs.FirstOrDefault(
h => h.HubType == type && h.EnableRegistration);
if (hubRegistration != null)
{
var client = NotificationHubClient.CreateClientFromConnectionString(
hubRegistration.ConnectionString,
hubRegistration.HubName,
hubRegistration.EnableSendTracing);
_clients.Add(type, client);
}
}
addHub(NotificationHubType.General);
addHub(NotificationHubType.iOS);
addHub(NotificationHubType.Android);
}
public async Task CreateOrUpdateRegistrationAsync(string pushToken, string deviceId, string userId,
@ -84,7 +103,7 @@ public class NotificationHubPushRegistrationService : IPushRegistrationService
BuildInstallationTemplate(installation, "badgeMessage", badgeMessageTemplate ?? messageTemplate,
userId, identifier);
await _client.CreateOrUpdateInstallationAsync(installation);
await GetClient(type).CreateOrUpdateInstallationAsync(installation);
if (InstallationDeviceEntity.IsInstallationDeviceId(deviceId))
{
await _installationDeviceRepository.UpsertAsync(new InstallationDeviceEntity(deviceId));
@ -119,11 +138,11 @@ public class NotificationHubPushRegistrationService : IPushRegistrationService
installation.Templates.Add(fullTemplateId, template);
}
public async Task DeleteRegistrationAsync(string deviceId)
public async Task DeleteRegistrationAsync(string deviceId, DeviceType deviceType)
{
try
{
await _client.DeleteInstallationAsync(deviceId);
await GetClient(deviceType).DeleteInstallationAsync(deviceId);
if (InstallationDeviceEntity.IsInstallationDeviceId(deviceId))
{
await _installationDeviceRepository.DeleteAsync(new InstallationDeviceEntity(deviceId));
@ -135,31 +154,31 @@ public class NotificationHubPushRegistrationService : IPushRegistrationService
}
}
public async Task AddUserRegistrationOrganizationAsync(IEnumerable<string> deviceIds, string organizationId)
public async Task AddUserRegistrationOrganizationAsync(IEnumerable<KeyValuePair<string, DeviceType>> devices, string organizationId)
{
await PatchTagsForUserDevicesAsync(deviceIds, UpdateOperationType.Add, $"organizationId:{organizationId}");
if (deviceIds.Any() && InstallationDeviceEntity.IsInstallationDeviceId(deviceIds.First()))
await PatchTagsForUserDevicesAsync(devices, UpdateOperationType.Add, $"organizationId:{organizationId}");
if (devices.Any() && InstallationDeviceEntity.IsInstallationDeviceId(devices.First().Key))
{
var entities = deviceIds.Select(e => new InstallationDeviceEntity(e));
var entities = devices.Select(e => new InstallationDeviceEntity(e.Key));
await _installationDeviceRepository.UpsertManyAsync(entities.ToList());
}
}
public async Task DeleteUserRegistrationOrganizationAsync(IEnumerable<string> deviceIds, string organizationId)
public async Task DeleteUserRegistrationOrganizationAsync(IEnumerable<KeyValuePair<string, DeviceType>> devices, string organizationId)
{
await PatchTagsForUserDevicesAsync(deviceIds, UpdateOperationType.Remove,
await PatchTagsForUserDevicesAsync(devices, UpdateOperationType.Remove,
$"organizationId:{organizationId}");
if (deviceIds.Any() && InstallationDeviceEntity.IsInstallationDeviceId(deviceIds.First()))
if (devices.Any() && InstallationDeviceEntity.IsInstallationDeviceId(devices.First().Key))
{
var entities = deviceIds.Select(e => new InstallationDeviceEntity(e));
var entities = devices.Select(e => new InstallationDeviceEntity(e.Key));
await _installationDeviceRepository.UpsertManyAsync(entities.ToList());
}
}
private async Task PatchTagsForUserDevicesAsync(IEnumerable<string> deviceIds, UpdateOperationType op,
private async Task PatchTagsForUserDevicesAsync(IEnumerable<KeyValuePair<string, DeviceType>> devices, UpdateOperationType op,
string tag)
{
if (!deviceIds.Any())
if (!devices.Any())
{
return;
}
@ -179,11 +198,11 @@ public class NotificationHubPushRegistrationService : IPushRegistrationService
operation.Path += $"/{tag}";
}
foreach (var id in deviceIds)
foreach (var device in devices)
{
try
{
await _client.PatchInstallationAsync(id, new List<PartialUpdateOperation> { operation });
await GetClient(device.Value).PatchInstallationAsync(device.Key, new List<PartialUpdateOperation> { operation });
}
catch (Exception e) when (e.InnerException == null || !e.InnerException.Message.Contains("(404) Not Found"))
{
@ -191,4 +210,54 @@ public class NotificationHubPushRegistrationService : IPushRegistrationService
}
}
}
private NotificationHubClient GetClient(DeviceType deviceType)
{
var hubType = NotificationHubType.General;
switch (deviceType)
{
case DeviceType.Android:
hubType = NotificationHubType.Android;
break;
case DeviceType.iOS:
hubType = NotificationHubType.iOS;
break;
case DeviceType.ChromeExtension:
case DeviceType.FirefoxExtension:
case DeviceType.OperaExtension:
case DeviceType.EdgeExtension:
case DeviceType.VivaldiExtension:
case DeviceType.SafariExtension:
hubType = NotificationHubType.GeneralBrowserExtension;
break;
case DeviceType.WindowsDesktop:
case DeviceType.MacOsDesktop:
case DeviceType.LinuxDesktop:
hubType = NotificationHubType.GeneralDesktop;
break;
case DeviceType.ChromeBrowser:
case DeviceType.FirefoxBrowser:
case DeviceType.OperaBrowser:
case DeviceType.EdgeBrowser:
case DeviceType.IEBrowser:
case DeviceType.UnknownBrowser:
case DeviceType.SafariBrowser:
case DeviceType.VivaldiBrowser:
hubType = NotificationHubType.GeneralWeb;
break;
default:
break;
}
if (!_clients.ContainsKey(hubType))
{
_logger.LogWarning("No hub client for '{0}'. Using general hub instead.", hubType);
hubType = NotificationHubType.General;
if (!_clients.ContainsKey(hubType))
{
throw new Exception("No general hub client found.");
}
}
return _clients[hubType];
}
}

View File

@ -38,30 +38,37 @@ public class RelayPushRegistrationService : BaseIdentityClientService, IPushRegi
await SendAsync(HttpMethod.Post, "push/register", requestModel);
}
public async Task DeleteRegistrationAsync(string deviceId)
public async Task DeleteRegistrationAsync(string deviceId, DeviceType type)
{
await SendAsync(HttpMethod.Delete, string.Concat("push/", deviceId));
var requestModel = new PushDeviceRequestModel
{
Id = deviceId,
Type = type,
};
await SendAsync(HttpMethod.Post, "push/delete", requestModel);
}
public async Task AddUserRegistrationOrganizationAsync(IEnumerable<string> deviceIds, string organizationId)
public async Task AddUserRegistrationOrganizationAsync(
IEnumerable<KeyValuePair<string, DeviceType>> devices, string organizationId)
{
if (!deviceIds.Any())
if (!devices.Any())
{
return;
}
var requestModel = new PushUpdateRequestModel(deviceIds, organizationId);
var requestModel = new PushUpdateRequestModel(devices, organizationId);
await SendAsync(HttpMethod.Put, "push/add-organization", requestModel);
}
public async Task DeleteUserRegistrationOrganizationAsync(IEnumerable<string> deviceIds, string organizationId)
public async Task DeleteUserRegistrationOrganizationAsync(
IEnumerable<KeyValuePair<string, DeviceType>> devices, string organizationId)
{
if (!deviceIds.Any())
if (!devices.Any())
{
return;
}
var requestModel = new PushUpdateRequestModel(deviceIds, organizationId);
var requestModel = new PushUpdateRequestModel(devices, organizationId);
await SendAsync(HttpMethod.Put, "push/delete-organization", requestModel);
}
}

View File

@ -4,7 +4,7 @@ namespace Bit.Core.Services;
public class NoopPushRegistrationService : IPushRegistrationService
{
public Task AddUserRegistrationOrganizationAsync(IEnumerable<string> deviceIds, string organizationId)
public Task AddUserRegistrationOrganizationAsync(IEnumerable<KeyValuePair<string, DeviceType>> devices, string organizationId)
{
return Task.FromResult(0);
}
@ -15,12 +15,12 @@ public class NoopPushRegistrationService : IPushRegistrationService
return Task.FromResult(0);
}
public Task DeleteRegistrationAsync(string deviceId)
public Task DeleteRegistrationAsync(string deviceId, DeviceType deviceType)
{
return Task.FromResult(0);
}
public Task DeleteUserRegistrationOrganizationAsync(IEnumerable<string> deviceIds, string organizationId)
public Task DeleteUserRegistrationOrganizationAsync(IEnumerable<KeyValuePair<string, DeviceType>> devices, string organizationId)
{
return Task.FromResult(0);
}

View File

@ -1,4 +1,5 @@
using Bit.Core.Auth.Settings;
using Bit.Core.Enums;
using Bit.Core.Settings.LoggingSettings;
namespace Bit.Core.Settings;
@ -64,7 +65,7 @@ public class GlobalSettings : IGlobalSettings
public virtual SentrySettings Sentry { get; set; } = new SentrySettings();
public virtual SyslogSettings Syslog { get; set; } = new SyslogSettings();
public virtual ILogLevelSettings MinLogLevel { get; set; } = new LogLevelSettings();
public virtual NotificationHubSettings NotificationHub { get; set; } = new NotificationHubSettings();
public virtual List<NotificationHubSettings> NotificationHubs { get; set; } = new();
public virtual YubicoSettings Yubico { get; set; } = new YubicoSettings();
public virtual DuoSettings Duo { get; set; } = new DuoSettings();
public virtual BraintreeSettings Braintree { get; set; } = new BraintreeSettings();
@ -416,12 +417,16 @@ public class GlobalSettings : IGlobalSettings
set => _connectionString = value.Trim('"');
}
public string HubName { get; set; }
/// <summary>
/// Enables TestSend on the Azure Notification Hub, which allows tracing of the request through the hub and to the platform-specific push notification service (PNS).
/// Enabling this will result in delayed responses because the Hub must wait on delivery to the PNS. This should ONLY be enabled in a non-production environment, as results are throttled.
/// </summary>
public bool EnableSendTracing { get; set; } = false;
/// <summary>
/// At least one hub configuration should have registration enabled, preferably the General hub as a safety net.
/// </summary>
public bool EnableRegistration { get; set; }
public NotificationHubType HubType { get; set; }
}
public class YubicoSettings

View File

@ -1,6 +1,7 @@
using Bit.Core.Repositories;
using Bit.Core.Services;
using Bit.Core.Settings;
using Microsoft.Extensions.Logging;
using NSubstitute;
using Xunit;
@ -11,16 +12,19 @@ public class NotificationHubPushRegistrationServiceTests
private readonly NotificationHubPushRegistrationService _sut;
private readonly IInstallationDeviceRepository _installationDeviceRepository;
private readonly ILogger<NotificationHubPushRegistrationService> _logger;
private readonly GlobalSettings _globalSettings;
public NotificationHubPushRegistrationServiceTests()
{
_installationDeviceRepository = Substitute.For<IInstallationDeviceRepository>();
_logger = Substitute.For<ILogger<NotificationHubPushRegistrationService>>();
_globalSettings = new GlobalSettings();
_sut = new NotificationHubPushRegistrationService(
_installationDeviceRepository,
_globalSettings
_globalSettings,
_logger
);
}