1
0
mirror of https://github.com/bitwarden/server.git synced 2025-02-19 02:21:21 +01:00

Merge remote-tracking branch 'origin/main' into web-push-poc

This commit is contained in:
Matt Gibson 2025-02-12 15:40:03 -08:00
commit 99528841e6
No known key found for this signature in database
GPG Key ID: 7CBCA182C13B0912
44 changed files with 1356 additions and 372 deletions

View File

@ -44,7 +44,7 @@ public class PushController : Controller
{ {
CheckUsage(); CheckUsage();
await _pushRegistrationService.CreateOrUpdateRegistrationAsync(new PushRegistrationData(model.PushToken), Prefix(model.DeviceId), await _pushRegistrationService.CreateOrUpdateRegistrationAsync(new PushRegistrationData(model.PushToken), Prefix(model.DeviceId),
Prefix(model.UserId), Prefix(model.Identifier), model.Type); Prefix(model.UserId), Prefix(model.Identifier), model.Type, model.OrganizationIds.Select(Prefix));
} }
[HttpPost("delete")] [HttpPost("delete")]
@ -80,12 +80,12 @@ public class PushController : Controller
if (!string.IsNullOrWhiteSpace(model.UserId)) if (!string.IsNullOrWhiteSpace(model.UserId))
{ {
await _pushNotificationService.SendPayloadToUserAsync(Prefix(model.UserId), await _pushNotificationService.SendPayloadToUserAsync(Prefix(model.UserId),
model.Type.Value, model.Payload, Prefix(model.Identifier), Prefix(model.DeviceId)); model.Type, model.Payload, Prefix(model.Identifier), Prefix(model.DeviceId), model.ClientType);
} }
else if (!string.IsNullOrWhiteSpace(model.OrganizationId)) else if (!string.IsNullOrWhiteSpace(model.OrganizationId))
{ {
await _pushNotificationService.SendPayloadToOrganizationAsync(Prefix(model.OrganizationId), await _pushNotificationService.SendPayloadToOrganizationAsync(Prefix(model.OrganizationId),
model.Type.Value, model.Payload, Prefix(model.Identifier), Prefix(model.DeviceId)); model.Type, model.Payload, Prefix(model.Identifier), Prefix(model.DeviceId), model.ClientType);
} }
} }

View File

@ -1,6 +1,7 @@
using System.Globalization; using System.Globalization;
using Bit.Billing.Models; using Bit.Billing.Models;
using Bit.Core.AdminConsole.Repositories; using Bit.Core.AdminConsole.Repositories;
using Bit.Core.Billing.Services;
using Bit.Core.Entities; using Bit.Core.Entities;
using Bit.Core.Enums; using Bit.Core.Enums;
using Bit.Core.Repositories; using Bit.Core.Repositories;
@ -25,6 +26,7 @@ public class BitPayController : Controller
private readonly IMailService _mailService; private readonly IMailService _mailService;
private readonly IPaymentService _paymentService; private readonly IPaymentService _paymentService;
private readonly ILogger<BitPayController> _logger; private readonly ILogger<BitPayController> _logger;
private readonly IPremiumUserBillingService _premiumUserBillingService;
public BitPayController( public BitPayController(
IOptions<BillingSettings> billingSettings, IOptions<BillingSettings> billingSettings,
@ -35,7 +37,8 @@ public class BitPayController : Controller
IProviderRepository providerRepository, IProviderRepository providerRepository,
IMailService mailService, IMailService mailService,
IPaymentService paymentService, IPaymentService paymentService,
ILogger<BitPayController> logger) ILogger<BitPayController> logger,
IPremiumUserBillingService premiumUserBillingService)
{ {
_billingSettings = billingSettings?.Value; _billingSettings = billingSettings?.Value;
_bitPayClient = bitPayClient; _bitPayClient = bitPayClient;
@ -46,6 +49,7 @@ public class BitPayController : Controller
_mailService = mailService; _mailService = mailService;
_paymentService = paymentService; _paymentService = paymentService;
_logger = logger; _logger = logger;
_premiumUserBillingService = premiumUserBillingService;
} }
[HttpPost("ipn")] [HttpPost("ipn")]
@ -145,10 +149,7 @@ public class BitPayController : Controller
if (user != null) if (user != null)
{ {
billingEmail = user.BillingEmailAddress(); billingEmail = user.BillingEmailAddress();
if (await _paymentService.CreditAccountAsync(user, tx.Amount)) await _premiumUserBillingService.Credit(user, tx.Amount);
{
await _userRepository.ReplaceAsync(user);
}
} }
} }
else if (tx.ProviderId.HasValue) else if (tx.ProviderId.HasValue)

View File

@ -1,6 +1,7 @@
using System.Text; using System.Text;
using Bit.Billing.Models; using Bit.Billing.Models;
using Bit.Core.AdminConsole.Repositories; using Bit.Core.AdminConsole.Repositories;
using Bit.Core.Billing.Services;
using Bit.Core.Entities; using Bit.Core.Entities;
using Bit.Core.Enums; using Bit.Core.Enums;
using Bit.Core.Repositories; using Bit.Core.Repositories;
@ -23,6 +24,7 @@ public class PayPalController : Controller
private readonly ITransactionRepository _transactionRepository; private readonly ITransactionRepository _transactionRepository;
private readonly IUserRepository _userRepository; private readonly IUserRepository _userRepository;
private readonly IProviderRepository _providerRepository; private readonly IProviderRepository _providerRepository;
private readonly IPremiumUserBillingService _premiumUserBillingService;
public PayPalController( public PayPalController(
IOptions<BillingSettings> billingSettings, IOptions<BillingSettings> billingSettings,
@ -32,7 +34,8 @@ public class PayPalController : Controller
IPaymentService paymentService, IPaymentService paymentService,
ITransactionRepository transactionRepository, ITransactionRepository transactionRepository,
IUserRepository userRepository, IUserRepository userRepository,
IProviderRepository providerRepository) IProviderRepository providerRepository,
IPremiumUserBillingService premiumUserBillingService)
{ {
_billingSettings = billingSettings?.Value; _billingSettings = billingSettings?.Value;
_logger = logger; _logger = logger;
@ -42,6 +45,7 @@ public class PayPalController : Controller
_transactionRepository = transactionRepository; _transactionRepository = transactionRepository;
_userRepository = userRepository; _userRepository = userRepository;
_providerRepository = providerRepository; _providerRepository = providerRepository;
_premiumUserBillingService = premiumUserBillingService;
} }
[HttpPost("ipn")] [HttpPost("ipn")]
@ -257,10 +261,9 @@ public class PayPalController : Controller
{ {
var user = await _userRepository.GetByIdAsync(transaction.UserId.Value); var user = await _userRepository.GetByIdAsync(transaction.UserId.Value);
if (await _paymentService.CreditAccountAsync(user, transaction.Amount)) if (user != null)
{ {
await _userRepository.ReplaceAsync(user); await _premiumUserBillingService.Credit(user, transaction.Amount);
billingEmail = user.BillingEmailAddress(); billingEmail = user.BillingEmailAddress();
} }
} }

View File

@ -6,6 +6,8 @@ namespace Bit.Core.Billing.Services;
public interface IPremiumUserBillingService public interface IPremiumUserBillingService
{ {
Task Credit(User user, decimal amount);
/// <summary> /// <summary>
/// <para>Establishes the Stripe entities necessary for a Bitwarden <see cref="User"/> using the provided <paramref name="sale"/>.</para> /// <para>Establishes the Stripe entities necessary for a Bitwarden <see cref="User"/> using the provided <paramref name="sale"/>.</para>
/// <para> /// <para>

View File

@ -27,6 +27,57 @@ public class PremiumUserBillingService(
ISubscriberService subscriberService, ISubscriberService subscriberService,
IUserRepository userRepository) : IPremiumUserBillingService IUserRepository userRepository) : IPremiumUserBillingService
{ {
public async Task Credit(User user, decimal amount)
{
var customer = await subscriberService.GetCustomer(user);
// Negative credit represents a balance and all Stripe denomination is in cents.
var credit = (long)amount * -100;
if (customer == null)
{
var options = new CustomerCreateOptions
{
Balance = credit,
Description = user.Name,
Email = user.Email,
InvoiceSettings = new CustomerInvoiceSettingsOptions
{
CustomFields =
[
new CustomerInvoiceSettingsCustomFieldOptions
{
Name = user.SubscriberType(),
Value = user.SubscriberName().Length <= 30
? user.SubscriberName()
: user.SubscriberName()[..30]
}
]
},
Metadata = new Dictionary<string, string>
{
["region"] = globalSettings.BaseServiceUri.CloudRegion,
["userId"] = user.Id.ToString()
}
};
customer = await stripeAdapter.CustomerCreateAsync(options);
user.Gateway = GatewayType.Stripe;
user.GatewayCustomerId = customer.Id;
await userRepository.ReplaceAsync(user);
}
else
{
var options = new CustomerUpdateOptions
{
Balance = customer.Balance + credit
};
await stripeAdapter.CustomerUpdateAsync(customer.Id, options);
}
}
public async Task Finalize(PremiumUserSale sale) public async Task Finalize(PremiumUserSale sale)
{ {
var (user, customerSetup, storage) = sale; var (user, customerSetup, storage) = sale;
@ -37,6 +88,12 @@ public class PremiumUserBillingService(
? await CreateCustomerAsync(user, customerSetup) ? await CreateCustomerAsync(user, customerSetup)
: await subscriberService.GetCustomerOrThrow(user, new CustomerGetOptions { Expand = expand }); : await subscriberService.GetCustomerOrThrow(user, new CustomerGetOptions { Expand = expand });
/*
* If the customer was previously set up with credit, which does not require a billing location,
* we need to update the customer on the fly before we start the subscription.
*/
customer = await ReconcileBillingLocationAsync(customer, customerSetup.TaxInformation);
var subscription = await CreateSubscriptionAsync(user.Id, customer, storage); var subscription = await CreateSubscriptionAsync(user.Id, customer, storage);
switch (customerSetup.TokenizedPaymentSource) switch (customerSetup.TokenizedPaymentSource)
@ -85,6 +142,11 @@ public class PremiumUserBillingService(
User user, User user,
CustomerSetup customerSetup) CustomerSetup customerSetup)
{ {
/*
* Creating a Customer via the adding of a payment method or the purchasing of a subscription requires
* an actual payment source. The only time this is not the case is when the Customer is created when the
* User purchases credit.
*/
if (customerSetup.TokenizedPaymentSource is not if (customerSetup.TokenizedPaymentSource is not
{ {
Type: PaymentMethodType.BankAccount or PaymentMethodType.Card or PaymentMethodType.PayPal, Type: PaymentMethodType.BankAccount or PaymentMethodType.Card or PaymentMethodType.PayPal,
@ -285,4 +347,34 @@ public class PremiumUserBillingService(
return subscription; return subscription;
} }
private async Task<Customer> ReconcileBillingLocationAsync(
Customer customer,
TaxInformation taxInformation)
{
if (customer is { Address: { Country: not null and not "", PostalCode: not null and not "" } })
{
return customer;
}
var options = new CustomerUpdateOptions
{
Address = new AddressOptions
{
Line1 = taxInformation.Line1,
Line2 = taxInformation.Line2,
City = taxInformation.City,
PostalCode = taxInformation.PostalCode,
State = taxInformation.State,
Country = taxInformation.Country,
},
Expand = ["tax"],
Tax = new CustomerTaxOptions
{
ValidateLocation = StripeConstants.ValidateTaxLocationTiming.Immediately
}
};
return await stripeAdapter.CustomerUpdateAsync(customer.Id, options);
}
} }

View File

@ -159,21 +159,18 @@ public static class FeatureFlagKeys
public const string InlineMenuTotp = "inline-menu-totp"; public const string InlineMenuTotp = "inline-menu-totp";
public const string SelfHostLicenseRefactor = "pm-11516-self-host-license-refactor"; public const string SelfHostLicenseRefactor = "pm-11516-self-host-license-refactor";
public const string PrivateKeyRegeneration = "pm-12241-private-key-regeneration"; public const string PrivateKeyRegeneration = "pm-12241-private-key-regeneration";
public const string AuthenticatorSynciOS = "enable-authenticator-sync-ios";
public const string AuthenticatorSyncAndroid = "enable-authenticator-sync-android";
public const string AppReviewPrompt = "app-review-prompt"; public const string AppReviewPrompt = "app-review-prompt";
public const string ResellerManagedOrgAlert = "PM-15814-alert-owners-of-reseller-managed-orgs"; public const string ResellerManagedOrgAlert = "PM-15814-alert-owners-of-reseller-managed-orgs";
public const string Argon2Default = "argon2-default"; public const string Argon2Default = "argon2-default";
public const string UsePricingService = "use-pricing-service"; public const string UsePricingService = "use-pricing-service";
public const string RecordInstallationLastActivityDate = "installation-last-activity-date"; public const string RecordInstallationLastActivityDate = "installation-last-activity-date";
public const string EnablePasswordManagerSyncAndroid = "enable-password-manager-sync-android";
public const string EnablePasswordManagerSynciOS = "enable-password-manager-sync-ios";
public const string AccountDeprovisioningBanner = "pm-17120-account-deprovisioning-admin-console-banner"; public const string AccountDeprovisioningBanner = "pm-17120-account-deprovisioning-admin-console-banner";
public const string SingleTapPasskeyCreation = "single-tap-passkey-creation"; public const string SingleTapPasskeyCreation = "single-tap-passkey-creation";
public const string SingleTapPasskeyAuthentication = "single-tap-passkey-authentication"; public const string SingleTapPasskeyAuthentication = "single-tap-passkey-authentication";
public const string EnablePMAuthenticatorSync = "enable-pm-bwa-sync"; public const string EnablePMAuthenticatorSync = "enable-pm-bwa-sync";
public const string P15179_AddExistingOrgsFromProviderPortal = "pm-15179-add-existing-orgs-from-provider-portal"; public const string P15179_AddExistingOrgsFromProviderPortal = "pm-15179-add-existing-orgs-from-provider-portal";
public const string AndroidMutualTls = "mutual-tls"; public const string AndroidMutualTls = "mutual-tls";
public const string PM3503_MobileAnonAddySelfHostAlias = "anon-addy-self-host-alias";
public const string WebPush = "web-push"; public const string WebPush = "web-push";
public static List<string> GetAllKeys() public static List<string> GetAllKeys()

View File

@ -169,6 +169,11 @@ public class CurrentContext : ICurrentContext
DeviceIdentifier = GetClaimValue(claimsDict, Claims.Device); DeviceIdentifier = GetClaimValue(claimsDict, Claims.Device);
if (Enum.TryParse(GetClaimValue(claimsDict, Claims.DeviceType), out DeviceType deviceType))
{
DeviceType = deviceType;
}
Organizations = GetOrganizations(claimsDict, orgApi); Organizations = GetOrganizations(claimsDict, orgApi);
Providers = GetProviders(claimsDict); Providers = GetProviders(claimsDict);

View File

@ -27,4 +27,6 @@ public enum PushType : byte
SyncOrganizations = 17, SyncOrganizations = 17,
SyncOrganizationStatusChanged = 18, SyncOrganizationStatusChanged = 18,
SyncOrganizationCollectionSettingChanged = 19, SyncOrganizationCollectionSettingChanged = 19,
SyncNotification = 20,
} }

View File

@ -6,6 +6,7 @@ public static class Claims
public const string SecurityStamp = "sstamp"; public const string SecurityStamp = "sstamp";
public const string Premium = "premium"; public const string Premium = "premium";
public const string Device = "device"; public const string Device = "device";
public const string DeviceType = "devicetype";
public const string OrganizationOwner = "orgowner"; public const string OrganizationOwner = "orgowner";
public const string OrganizationAdmin = "orgadmin"; public const string OrganizationAdmin = "orgadmin";

View File

@ -15,4 +15,5 @@ public class PushRegistrationRequestModel
public DeviceType Type { get; set; } public DeviceType Type { get; set; }
[Required] [Required]
public string Identifier { get; set; } public string Identifier { get; set; }
public IEnumerable<string> OrganizationIds { get; set; }
} }

View File

@ -1,18 +1,18 @@
using System.ComponentModel.DataAnnotations; #nullable enable
using System.ComponentModel.DataAnnotations;
using Bit.Core.Enums; using Bit.Core.Enums;
namespace Bit.Core.Models.Api; namespace Bit.Core.Models.Api;
public class PushSendRequestModel : IValidatableObject public class PushSendRequestModel : IValidatableObject
{ {
public string UserId { get; set; } public string? UserId { get; set; }
public string OrganizationId { get; set; } public string? OrganizationId { get; set; }
public string DeviceId { get; set; } public string? DeviceId { get; set; }
public string Identifier { get; set; } public string? Identifier { get; set; }
[Required] public required PushType Type { get; set; }
public PushType? Type { get; set; } public required object Payload { get; set; }
[Required] public ClientType? ClientType { get; set; }
public object Payload { get; set; }
public IEnumerable<ValidationResult> Validate(ValidationContext validationContext) public IEnumerable<ValidationResult> Validate(ValidationContext validationContext)
{ {

View File

@ -1,4 +1,5 @@
using Bit.Core.Enums; using Bit.Core.Enums;
using Bit.Core.NotificationCenter.Enums;
namespace Bit.Core.Models; namespace Bit.Core.Models;
@ -45,6 +46,22 @@ public class SyncSendPushNotification
public DateTime RevisionDate { get; set; } public DateTime RevisionDate { get; set; }
} }
#nullable enable
public class NotificationPushNotification
{
public Guid Id { get; set; }
public Priority Priority { get; set; }
public bool Global { get; set; }
public ClientType ClientType { get; set; }
public Guid? UserId { get; set; }
public Guid? OrganizationId { get; set; }
public string? Title { get; set; }
public string? Body { get; set; }
public DateTime CreationDate { get; set; }
public DateTime RevisionDate { get; set; }
}
#nullable disable
public class AuthRequestPushNotification public class AuthRequestPushNotification
{ {
public Guid UserId { get; set; } public Guid UserId { get; set; }

View File

@ -4,6 +4,7 @@ using Bit.Core.NotificationCenter.Authorization;
using Bit.Core.NotificationCenter.Commands.Interfaces; using Bit.Core.NotificationCenter.Commands.Interfaces;
using Bit.Core.NotificationCenter.Entities; using Bit.Core.NotificationCenter.Entities;
using Bit.Core.NotificationCenter.Repositories; using Bit.Core.NotificationCenter.Repositories;
using Bit.Core.Platform.Push;
using Bit.Core.Utilities; using Bit.Core.Utilities;
using Microsoft.AspNetCore.Authorization; using Microsoft.AspNetCore.Authorization;
@ -14,14 +15,17 @@ public class CreateNotificationCommand : ICreateNotificationCommand
private readonly ICurrentContext _currentContext; private readonly ICurrentContext _currentContext;
private readonly IAuthorizationService _authorizationService; private readonly IAuthorizationService _authorizationService;
private readonly INotificationRepository _notificationRepository; private readonly INotificationRepository _notificationRepository;
private readonly IPushNotificationService _pushNotificationService;
public CreateNotificationCommand(ICurrentContext currentContext, public CreateNotificationCommand(ICurrentContext currentContext,
IAuthorizationService authorizationService, IAuthorizationService authorizationService,
INotificationRepository notificationRepository) INotificationRepository notificationRepository,
IPushNotificationService pushNotificationService)
{ {
_currentContext = currentContext; _currentContext = currentContext;
_authorizationService = authorizationService; _authorizationService = authorizationService;
_notificationRepository = notificationRepository; _notificationRepository = notificationRepository;
_pushNotificationService = pushNotificationService;
} }
public async Task<Notification> CreateAsync(Notification notification) public async Task<Notification> CreateAsync(Notification notification)
@ -31,6 +35,10 @@ public class CreateNotificationCommand : ICreateNotificationCommand
await _authorizationService.AuthorizeOrThrowAsync(_currentContext.HttpContext.User, notification, await _authorizationService.AuthorizeOrThrowAsync(_currentContext.HttpContext.User, notification,
NotificationOperations.Create); NotificationOperations.Create);
return await _notificationRepository.CreateAsync(notification); var newNotification = await _notificationRepository.CreateAsync(notification);
await _pushNotificationService.PushNotificationAsync(newNotification);
return newNotification;
} }
} }

View File

@ -15,9 +15,8 @@ public class Notification : ITableObject<Guid>
public ClientType ClientType { get; set; } public ClientType ClientType { get; set; }
public Guid? UserId { get; set; } public Guid? UserId { get; set; }
public Guid? OrganizationId { get; set; } public Guid? OrganizationId { get; set; }
[MaxLength(256)] [MaxLength(256)] public string? Title { get; set; }
public string? Title { get; set; } [MaxLength(3000)] public string? Body { get; set; }
public string? Body { get; set; }
public DateTime CreationDate { get; set; } public DateTime CreationDate { get; set; }
public DateTime RevisionDate { get; set; } public DateTime RevisionDate { get; set; }
public Guid? TaskId { get; set; } public Guid? TaskId { get; set; }

View File

@ -5,6 +5,6 @@ namespace Bit.Core.NotificationHub;
public interface INotificationHubPool public interface INotificationHubPool
{ {
NotificationHubConnection ConnectionFor(Guid comb); NotificationHubConnection ConnectionFor(Guid comb);
NotificationHubClient ClientFor(Guid comb); INotificationHubClient ClientFor(Guid comb);
INotificationHubProxy AllClients { get; } INotificationHubProxy AllClients { get; }
} }

View File

@ -43,7 +43,7 @@ public class NotificationHubPool : INotificationHubPool
/// <param name="comb"></param> /// <param name="comb"></param>
/// <returns></returns> /// <returns></returns>
/// <exception cref="InvalidOperationException">Thrown when no notification hub is found for a given comb.</exception> /// <exception cref="InvalidOperationException">Thrown when no notification hub is found for a given comb.</exception>
public NotificationHubClient ClientFor(Guid comb) public INotificationHubClient ClientFor(Guid comb)
{ {
var resolvedConnection = ConnectionFor(comb); var resolvedConnection = ConnectionFor(comb);
return resolvedConnection.HubClient; return resolvedConnection.HubClient;

View File

@ -12,6 +12,7 @@ using Bit.Core.Tools.Entities;
using Bit.Core.Vault.Entities; using Bit.Core.Vault.Entities;
using Microsoft.AspNetCore.Http; using Microsoft.AspNetCore.Http;
using Microsoft.Extensions.Logging; using Microsoft.Extensions.Logging;
using Notification = Bit.Core.NotificationCenter.Entities.Notification;
namespace Bit.Core.NotificationHub; namespace Bit.Core.NotificationHub;
@ -135,11 +136,7 @@ public class NotificationHubPushNotificationService : IPushNotificationService
private async Task PushUserAsync(Guid userId, PushType type, bool excludeCurrentContext = false) private async Task PushUserAsync(Guid userId, PushType type, bool excludeCurrentContext = false)
{ {
var message = new UserPushNotification var message = new UserPushNotification { UserId = userId, Date = DateTime.UtcNow };
{
UserId = userId,
Date = DateTime.UtcNow
};
await SendPayloadToUserAsync(userId, type, message, excludeCurrentContext); await SendPayloadToUserAsync(userId, type, message, excludeCurrentContext);
} }
@ -184,31 +181,59 @@ public class NotificationHubPushNotificationService : IPushNotificationService
await PushAuthRequestAsync(authRequest, PushType.AuthRequestResponse); await PushAuthRequestAsync(authRequest, PushType.AuthRequestResponse);
} }
public async Task PushNotificationAsync(Notification notification)
{
var message = new NotificationPushNotification
{
Id = notification.Id,
Priority = notification.Priority,
Global = notification.Global,
ClientType = notification.ClientType,
UserId = notification.UserId,
OrganizationId = notification.OrganizationId,
Title = notification.Title,
Body = notification.Body,
CreationDate = notification.CreationDate,
RevisionDate = notification.RevisionDate
};
if (notification.UserId.HasValue)
{
await SendPayloadToUserAsync(notification.UserId.Value, PushType.SyncNotification, message, true,
notification.ClientType);
}
else if (notification.OrganizationId.HasValue)
{
await SendPayloadToOrganizationAsync(notification.OrganizationId.Value, PushType.SyncNotification, message,
true, notification.ClientType);
}
}
private async Task PushAuthRequestAsync(AuthRequest authRequest, PushType type) private async Task PushAuthRequestAsync(AuthRequest authRequest, PushType type)
{ {
var message = new AuthRequestPushNotification var message = new AuthRequestPushNotification { Id = authRequest.Id, UserId = authRequest.UserId };
{
Id = authRequest.Id,
UserId = authRequest.UserId
};
await SendPayloadToUserAsync(authRequest.UserId, type, message, true); await SendPayloadToUserAsync(authRequest.UserId, type, message, true);
} }
private async Task SendPayloadToUserAsync(Guid userId, PushType type, object payload, bool excludeCurrentContext) private async Task SendPayloadToUserAsync(Guid userId, PushType type, object payload, bool excludeCurrentContext,
ClientType? clientType = null)
{ {
await SendPayloadToUserAsync(userId.ToString(), type, payload, GetContextIdentifier(excludeCurrentContext)); await SendPayloadToUserAsync(userId.ToString(), type, payload, GetContextIdentifier(excludeCurrentContext),
clientType: clientType);
} }
private async Task SendPayloadToOrganizationAsync(Guid orgId, PushType type, object payload, bool excludeCurrentContext) private async Task SendPayloadToOrganizationAsync(Guid orgId, PushType type, object payload,
bool excludeCurrentContext, ClientType? clientType = null)
{ {
await SendPayloadToUserAsync(orgId.ToString(), type, payload, GetContextIdentifier(excludeCurrentContext)); await SendPayloadToOrganizationAsync(orgId.ToString(), type, payload,
GetContextIdentifier(excludeCurrentContext), clientType: clientType);
} }
public async Task SendPayloadToUserAsync(string userId, PushType type, object payload, string identifier, public async Task SendPayloadToUserAsync(string userId, PushType type, object payload, string identifier,
string deviceId = null) string deviceId = null, ClientType? clientType = null)
{ {
var tag = BuildTag($"template:payload_userId:{SanitizeTagInput(userId)}", identifier); var tag = BuildTag($"template:payload_userId:{SanitizeTagInput(userId)}", identifier, clientType);
await SendPayloadAsync(tag, type, payload); await SendPayloadAsync(tag, type, payload);
if (InstallationDeviceEntity.IsInstallationDeviceId(deviceId)) if (InstallationDeviceEntity.IsInstallationDeviceId(deviceId))
{ {
@ -217,9 +242,9 @@ public class NotificationHubPushNotificationService : IPushNotificationService
} }
public async Task SendPayloadToOrganizationAsync(string orgId, PushType type, object payload, string identifier, public async Task SendPayloadToOrganizationAsync(string orgId, PushType type, object payload, string identifier,
string deviceId = null) string deviceId = null, ClientType? clientType = null)
{ {
var tag = BuildTag($"template:payload && organizationId:{SanitizeTagInput(orgId)}", identifier); var tag = BuildTag($"template:payload && organizationId:{SanitizeTagInput(orgId)}", identifier, clientType);
await SendPayloadAsync(tag, type, payload); await SendPayloadAsync(tag, type, payload);
if (InstallationDeviceEntity.IsInstallationDeviceId(deviceId)) if (InstallationDeviceEntity.IsInstallationDeviceId(deviceId))
{ {
@ -259,18 +284,23 @@ public class NotificationHubPushNotificationService : IPushNotificationService
return null; return null;
} }
var currentContext = _httpContextAccessor?.HttpContext?. var currentContext =
RequestServices.GetService(typeof(ICurrentContext)) as ICurrentContext; _httpContextAccessor?.HttpContext?.RequestServices.GetService(typeof(ICurrentContext)) as ICurrentContext;
return currentContext?.DeviceIdentifier; return currentContext?.DeviceIdentifier;
} }
private string BuildTag(string tag, string identifier) private string BuildTag(string tag, string identifier, ClientType? clientType)
{ {
if (!string.IsNullOrWhiteSpace(identifier)) if (!string.IsNullOrWhiteSpace(identifier))
{ {
tag += $" && !deviceIdentifier:{SanitizeTagInput(identifier)}"; tag += $" && !deviceIdentifier:{SanitizeTagInput(identifier)}";
} }
if (clientType.HasValue && clientType.Value != ClientType.All)
{
tag += $" && clientType:{clientType}";
}
return $"({tag})"; return $"({tag})";
} }
@ -279,8 +309,7 @@ public class NotificationHubPushNotificationService : IPushNotificationService
var results = await _notificationHubPool.AllClients.SendTemplateNotificationAsync( var results = await _notificationHubPool.AllClients.SendTemplateNotificationAsync(
new Dictionary<string, string> new Dictionary<string, string>
{ {
{ "type", ((byte)type).ToString() }, { "type", ((byte)type).ToString() }, { "payload", JsonSerializer.Serialize(payload) }
{ "payload", JsonSerializer.Serialize(payload) }
}, tag); }, tag);
if (_enableTracing) if (_enableTracing)
@ -291,7 +320,9 @@ public class NotificationHubPushNotificationService : IPushNotificationService
{ {
continue; continue;
} }
_logger.LogInformation("Azure Notification Hub Tracking ID: {Id} | {Type} push notification with {Success} successes and {Failure} failures with a payload of {@Payload} and result of {@Results}",
_logger.LogInformation(
"Azure Notification Hub Tracking ID: {Id} | {Type} push notification with {Success} successes and {Failure} failures with a payload of {@Payload} and result of {@Results}",
outcome.TrackingId, type, outcome.Success, outcome.Failure, payload, outcome.Results); outcome.TrackingId, type, outcome.Success, outcome.Failure, payload, outcome.Results);
} }
} }

View File

@ -7,6 +7,7 @@ using Bit.Core.Enums;
using Bit.Core.Models.Data; using Bit.Core.Models.Data;
using Bit.Core.Platform.Push; using Bit.Core.Platform.Push;
using Bit.Core.Repositories; using Bit.Core.Repositories;
using Bit.Core.Utilities;
using Microsoft.Azure.NotificationHubs; using Microsoft.Azure.NotificationHubs;
using Microsoft.Extensions.Logging; using Microsoft.Extensions.Logging;
@ -36,16 +37,19 @@ public class NotificationHubPushRegistrationService : IPushRegistrationService
} }
public async Task CreateOrUpdateRegistrationAsync(PushRegistrationData data, string deviceId, string userId, public async Task CreateOrUpdateRegistrationAsync(PushRegistrationData data, string deviceId, string userId,
string identifier, DeviceType type) string identifier, DeviceType type, IEnumerable<string> organizationIds)
{ {
var orgIds = organizationIds.ToList();
var clientType = DeviceTypes.ToClientType(type);
var installation = new Installation var installation = new Installation
{ {
InstallationId = deviceId, InstallationId = deviceId,
PushChannel = data.Token, PushChannel = data.Token,
Tags = new List<string> Tags = new List<string>
{ {
$"userId:{userId}" $"userId:{userId}",
}, $"clientType:{clientType}"
}.Concat(orgIds.Select(organizationId => $"organizationId:{organizationId}")).ToList(),
Templates = new Dictionary<string, InstallationTemplate>() Templates = new Dictionary<string, InstallationTemplate>()
}; };
@ -56,11 +60,11 @@ public class NotificationHubPushRegistrationService : IPushRegistrationService
if (data.Token != null) if (data.Token != null)
{ {
await CreateOrUpdateMobileRegistrationAsync(installation, userId, identifier, type); await CreateOrUpdateMobileRegistrationAsync(installation, userId, identifier, clientType, orgIds, type);
} }
else if (data.WebPush != null) else if (data.WebPush != null)
{ {
await CreateOrUpdateWebRegistrationAsync(data.WebPush.Value.Endpoint, data.WebPush.Value.P256dh, data.WebPush.Value.Auth, installation, userId, identifier, type); await CreateOrUpdateWebRegistrationAsync(data.WebPush.Value.Endpoint, data.WebPush.Value.P256dh, data.WebPush.Value.Auth, installation, userId, identifier, clientType, orgIds);
} }
if (InstallationDeviceEntity.IsInstallationDeviceId(deviceId)) if (InstallationDeviceEntity.IsInstallationDeviceId(deviceId))
@ -70,7 +74,7 @@ public class NotificationHubPushRegistrationService : IPushRegistrationService
} }
private async Task CreateOrUpdateMobileRegistrationAsync(Installation installation, string userId, private async Task CreateOrUpdateMobileRegistrationAsync(Installation installation, string userId,
string identifier, DeviceType type) string identifier, ClientType clientType, List<string> organizationIds, DeviceType type)
{ {
if (string.IsNullOrWhiteSpace(installation.PushChannel)) if (string.IsNullOrWhiteSpace(installation.PushChannel))
{ {
@ -82,41 +86,41 @@ public class NotificationHubPushRegistrationService : IPushRegistrationService
case DeviceType.Android: case DeviceType.Android:
installation.Templates.Add(BuildInstallationTemplate("payload", installation.Templates.Add(BuildInstallationTemplate("payload",
"{\"data\":{\"data\":{\"type\":\"#(type)\",\"payload\":\"$(payload)\"}}}", "{\"data\":{\"data\":{\"type\":\"#(type)\",\"payload\":\"$(payload)\"}}}",
userId, identifier)); userId, identifier, clientType, organizationIds));
installation.Templates.Add(BuildInstallationTemplate("message", installation.Templates.Add(BuildInstallationTemplate("message",
"{\"data\":{\"data\":{\"type\":\"#(type)\"}," + "{\"data\":{\"data\":{\"type\":\"#(type)\"}," +
"\"notification\":{\"title\":\"$(title)\",\"body\":\"$(message)\"}}}", "\"notification\":{\"title\":\"$(title)\",\"body\":\"$(message)\"}}}",
userId, identifier)); userId, identifier, clientType, organizationIds));
installation.Templates.Add(BuildInstallationTemplate("badgeMessage", installation.Templates.Add(BuildInstallationTemplate("badgeMessage",
"{\"data\":{\"data\":{\"type\":\"#(type)\"}," + "{\"data\":{\"data\":{\"type\":\"#(type)\"}," +
"\"notification\":{\"title\":\"$(title)\",\"body\":\"$(message)\"}}}", "\"notification\":{\"title\":\"$(title)\",\"body\":\"$(message)\"}}}",
userId, identifier)); userId, identifier, clientType, organizationIds));
installation.Platform = NotificationPlatform.FcmV1; installation.Platform = NotificationPlatform.FcmV1;
break; break;
case DeviceType.iOS: case DeviceType.iOS:
installation.Templates.Add(BuildInstallationTemplate("payload", installation.Templates.Add(BuildInstallationTemplate("payload",
"{\"data\":{\"type\":\"#(type)\",\"payload\":\"$(payload)\"}," + "{\"data\":{\"type\":\"#(type)\",\"payload\":\"$(payload)\"}," +
"\"aps\":{\"content-available\":1}}", "\"aps\":{\"content-available\":1}}",
userId, identifier)); userId, identifier, clientType, organizationIds));
installation.Templates.Add(BuildInstallationTemplate("message", installation.Templates.Add(BuildInstallationTemplate("message",
"{\"data\":{\"type\":\"#(type)\"}," + "{\"data\":{\"type\":\"#(type)\"}," +
"\"aps\":{\"alert\":\"$(message)\",\"badge\":null,\"content-available\":1}}", userId, identifier)); "\"aps\":{\"alert\":\"$(message)\",\"badge\":null,\"content-available\":1}}", userId, identifier, clientType, organizationIds));
installation.Templates.Add(BuildInstallationTemplate("badgeMessage", installation.Templates.Add(BuildInstallationTemplate("badgeMessage",
"{\"data\":{\"type\":\"#(type)\"}," + "{\"data\":{\"type\":\"#(type)\"}," +
"\"aps\":{\"alert\":\"$(message)\",\"badge\":\"#(badge)\",\"content-available\":1}}", "\"aps\":{\"alert\":\"$(message)\",\"badge\":\"#(badge)\",\"content-available\":1}}",
userId, identifier)); userId, identifier, clientType, organizationIds));
installation.Platform = NotificationPlatform.Apns; installation.Platform = NotificationPlatform.Apns;
break; break;
case DeviceType.AndroidAmazon: case DeviceType.AndroidAmazon:
installation.Templates.Add(BuildInstallationTemplate("payload", installation.Templates.Add(BuildInstallationTemplate("payload",
"{\"data\":{\"type\":\"#(type)\",\"payload\":\"$(payload)\"}}", "{\"data\":{\"type\":\"#(type)\",\"payload\":\"$(payload)\"}}",
userId, identifier)); userId, identifier, clientType, organizationIds));
installation.Templates.Add(BuildInstallationTemplate("message", installation.Templates.Add(BuildInstallationTemplate("message",
"{\"data\":{\"type\":\"#(type)\",\"message\":\"$(message)\"}}", "{\"data\":{\"type\":\"#(type)\",\"message\":\"$(message)\"}}",
userId, identifier)); userId, identifier, clientType, organizationIds));
installation.Templates.Add(BuildInstallationTemplate("badgeMessage", installation.Templates.Add(BuildInstallationTemplate("badgeMessage",
"{\"data\":{\"type\":\"#(type)\",\"message\":\"$(message)\"}}", "{\"data\":{\"type\":\"#(type)\",\"message\":\"$(message)\"}}",
userId, identifier)); userId, identifier, clientType, organizationIds));
installation.Platform = NotificationPlatform.Adm; installation.Platform = NotificationPlatform.Adm;
break; break;
@ -128,7 +132,7 @@ public class NotificationHubPushRegistrationService : IPushRegistrationService
} }
private async Task CreateOrUpdateWebRegistrationAsync(string endpoint, string p256dh, string auth, Installation installation, string userId, private async Task CreateOrUpdateWebRegistrationAsync(string endpoint, string p256dh, string auth, Installation installation, string userId,
string identifier, DeviceType type) string identifier, ClientType clientType, List<string> organizationIds)
{ {
// The Azure SDK is currently lacking support for web push registrations. // The Azure SDK is currently lacking support for web push registrations.
// We need to use the REST API directly. // We need to use the REST API directly.
@ -140,13 +144,13 @@ public class NotificationHubPushRegistrationService : IPushRegistrationService
installation.Templates.Add(BuildInstallationTemplate("payload", installation.Templates.Add(BuildInstallationTemplate("payload",
"{\"data\":{\"type\":\"#(type)\",\"payload\":\"$(payload)\"}}", "{\"data\":{\"type\":\"#(type)\",\"payload\":\"$(payload)\"}}",
userId, identifier)); userId, identifier, clientType, organizationIds));
installation.Templates.Add(BuildInstallationTemplate("message", installation.Templates.Add(BuildInstallationTemplate("message",
"{\"data\":{\"type\":\"#(type)\",\"message\":\"$(message)\"}}", "{\"data\":{\"type\":\"#(type)\",\"message\":\"$(message)\"}}",
userId, identifier)); userId, identifier, clientType, organizationIds));
installation.Templates.Add(BuildInstallationTemplate("badgeMessage", installation.Templates.Add(BuildInstallationTemplate("badgeMessage",
"{\"data\":{\"type\":\"#(type)\",\"message\":\"$(message)\"}}", "{\"data\":{\"type\":\"#(type)\",\"message\":\"$(message)\"}}",
userId, identifier)); userId, identifier, clientType, organizationIds));
var content = new var content = new
{ {
@ -178,7 +182,7 @@ public class NotificationHubPushRegistrationService : IPushRegistrationService
} }
private KeyValuePair<string, InstallationTemplate> BuildInstallationTemplate(string templateId, [StringSyntax(StringSyntaxAttribute.Json)] string templateBody, private KeyValuePair<string, InstallationTemplate> BuildInstallationTemplate(string templateId, [StringSyntax(StringSyntaxAttribute.Json)] string templateBody,
string userId, string identifier) string userId, string identifier, ClientType clientType, List<string> organizationIds)
{ {
var fullTemplateId = $"template:{templateId}"; var fullTemplateId = $"template:{templateId}";
@ -187,8 +191,7 @@ public class NotificationHubPushRegistrationService : IPushRegistrationService
Body = templateBody, Body = templateBody,
Tags = new List<string> Tags = new List<string>
{ {
fullTemplateId, fullTemplateId, $"{fullTemplateId}_userId:{userId}", $"clientType:{clientType}"
$"{fullTemplateId}_userId:{userId}"
} }
}; };
@ -197,6 +200,11 @@ public class NotificationHubPushRegistrationService : IPushRegistrationService
template.Tags.Add($"{fullTemplateId}_deviceIdentifier:{identifier}"); template.Tags.Add($"{fullTemplateId}_deviceIdentifier:{identifier}");
} }
foreach (var organizationId in organizationIds)
{
template.Tags.Add($"organizationId:{organizationId}");
}
return new KeyValuePair<string, InstallationTemplate>(fullTemplateId, template); return new KeyValuePair<string, InstallationTemplate>(fullTemplateId, template);
} }
@ -273,7 +281,7 @@ public class NotificationHubPushRegistrationService : IPushRegistrationService
} }
} }
private NotificationHubClient ClientFor(Guid deviceId) private INotificationHubClient ClientFor(Guid deviceId)
{ {
return _notificationHubPool.ClientFor(deviceId); return _notificationHubPool.ClientFor(deviceId);
} }

View File

@ -5,26 +5,25 @@ using Bit.Core.Auth.Entities;
using Bit.Core.Context; using Bit.Core.Context;
using Bit.Core.Enums; using Bit.Core.Enums;
using Bit.Core.Models; using Bit.Core.Models;
using Bit.Core.Settings; using Bit.Core.NotificationCenter.Entities;
using Bit.Core.Tools.Entities; using Bit.Core.Tools.Entities;
using Bit.Core.Utilities; using Bit.Core.Utilities;
using Bit.Core.Vault.Entities; using Bit.Core.Vault.Entities;
using Microsoft.AspNetCore.Http; using Microsoft.AspNetCore.Http;
using Microsoft.Extensions.DependencyInjection;
namespace Bit.Core.Platform.Push.Internal; namespace Bit.Core.Platform.Push.Internal;
public class AzureQueuePushNotificationService : IPushNotificationService public class AzureQueuePushNotificationService : IPushNotificationService
{ {
private readonly QueueClient _queueClient; private readonly QueueClient _queueClient;
private readonly GlobalSettings _globalSettings;
private readonly IHttpContextAccessor _httpContextAccessor; private readonly IHttpContextAccessor _httpContextAccessor;
public AzureQueuePushNotificationService( public AzureQueuePushNotificationService(
GlobalSettings globalSettings, [FromKeyedServices("notifications")] QueueClient queueClient,
IHttpContextAccessor httpContextAccessor) IHttpContextAccessor httpContextAccessor)
{ {
_queueClient = new QueueClient(globalSettings.Notifications.ConnectionString, "notifications"); _queueClient = queueClient;
_globalSettings = globalSettings;
_httpContextAccessor = httpContextAccessor; _httpContextAccessor = httpContextAccessor;
} }
@ -129,11 +128,7 @@ public class AzureQueuePushNotificationService : IPushNotificationService
private async Task PushUserAsync(Guid userId, PushType type, bool excludeCurrentContext = false) private async Task PushUserAsync(Guid userId, PushType type, bool excludeCurrentContext = false)
{ {
var message = new UserPushNotification var message = new UserPushNotification { UserId = userId, Date = DateTime.UtcNow };
{
UserId = userId,
Date = DateTime.UtcNow
};
await SendMessageAsync(type, message, excludeCurrentContext); await SendMessageAsync(type, message, excludeCurrentContext);
} }
@ -150,11 +145,7 @@ public class AzureQueuePushNotificationService : IPushNotificationService
private async Task PushAuthRequestAsync(AuthRequest authRequest, PushType type) private async Task PushAuthRequestAsync(AuthRequest authRequest, PushType type)
{ {
var message = new AuthRequestPushNotification var message = new AuthRequestPushNotification { Id = authRequest.Id, UserId = authRequest.UserId };
{
Id = authRequest.Id,
UserId = authRequest.UserId
};
await SendMessageAsync(type, message, true); await SendMessageAsync(type, message, true);
} }
@ -174,6 +165,25 @@ public class AzureQueuePushNotificationService : IPushNotificationService
await PushSendAsync(send, PushType.SyncSendDelete); await PushSendAsync(send, PushType.SyncSendDelete);
} }
public async Task PushNotificationAsync(Notification notification)
{
var message = new NotificationPushNotification
{
Id = notification.Id,
Priority = notification.Priority,
Global = notification.Global,
ClientType = notification.ClientType,
UserId = notification.UserId,
OrganizationId = notification.OrganizationId,
Title = notification.Title,
Body = notification.Body,
CreationDate = notification.CreationDate,
RevisionDate = notification.RevisionDate
};
await SendMessageAsync(PushType.SyncNotification, message, true);
}
private async Task PushSendAsync(Send send, PushType type) private async Task PushSendAsync(Send send, PushType type)
{ {
if (send.UserId.HasValue) if (send.UserId.HasValue)
@ -204,20 +214,20 @@ public class AzureQueuePushNotificationService : IPushNotificationService
return null; return null;
} }
var currentContext = _httpContextAccessor?.HttpContext?. var currentContext =
RequestServices.GetService(typeof(ICurrentContext)) as ICurrentContext; _httpContextAccessor?.HttpContext?.RequestServices.GetService(typeof(ICurrentContext)) as ICurrentContext;
return currentContext?.DeviceIdentifier; return currentContext?.DeviceIdentifier;
} }
public Task SendPayloadToUserAsync(string userId, PushType type, object payload, string identifier, public Task SendPayloadToUserAsync(string userId, PushType type, object payload, string identifier,
string deviceId = null) string deviceId = null, ClientType? clientType = null)
{ {
// Noop // Noop
return Task.FromResult(0); return Task.FromResult(0);
} }
public Task SendPayloadToOrganizationAsync(string orgId, PushType type, object payload, string identifier, public Task SendPayloadToOrganizationAsync(string orgId, PushType type, object payload, string identifier,
string deviceId = null) string deviceId = null, ClientType? clientType = null)
{ {
// Noop // Noop
return Task.FromResult(0); return Task.FromResult(0);

View File

@ -1,6 +1,7 @@
using Bit.Core.AdminConsole.Entities; using Bit.Core.AdminConsole.Entities;
using Bit.Core.Auth.Entities; using Bit.Core.Auth.Entities;
using Bit.Core.Enums; using Bit.Core.Enums;
using Bit.Core.NotificationCenter.Entities;
using Bit.Core.Tools.Entities; using Bit.Core.Tools.Entities;
using Bit.Core.Vault.Entities; using Bit.Core.Vault.Entities;
@ -23,11 +24,13 @@ public interface IPushNotificationService
Task PushSyncSendCreateAsync(Send send); Task PushSyncSendCreateAsync(Send send);
Task PushSyncSendUpdateAsync(Send send); Task PushSyncSendUpdateAsync(Send send);
Task PushSyncSendDeleteAsync(Send send); Task PushSyncSendDeleteAsync(Send send);
Task PushNotificationAsync(Notification notification);
Task PushAuthRequestAsync(AuthRequest authRequest); Task PushAuthRequestAsync(AuthRequest authRequest);
Task PushAuthRequestResponseAsync(AuthRequest authRequest); Task PushAuthRequestResponseAsync(AuthRequest authRequest);
Task SendPayloadToUserAsync(string userId, PushType type, object payload, string identifier, string deviceId = null);
Task SendPayloadToOrganizationAsync(string orgId, PushType type, object payload, string identifier,
string deviceId = null);
Task PushSyncOrganizationStatusAsync(Organization organization); Task PushSyncOrganizationStatusAsync(Organization organization);
Task PushSyncOrganizationCollectionManagementSettingsAsync(Organization organization); Task PushSyncOrganizationCollectionManagementSettingsAsync(Organization organization);
Task SendPayloadToUserAsync(string userId, PushType type, object payload, string identifier,
string deviceId = null, ClientType? clientType = null);
Task SendPayloadToOrganizationAsync(string orgId, PushType type, object payload, string identifier,
string deviceId = null, ClientType? clientType = null);
} }

View File

@ -5,7 +5,7 @@ namespace Bit.Core.Platform.Push;
public interface IPushRegistrationService public interface IPushRegistrationService
{ {
Task CreateOrUpdateRegistrationAsync(PushRegistrationData data, string deviceId, string userId, string identifier, DeviceType type); Task CreateOrUpdateRegistrationAsync(PushRegistrationData data, string deviceId, string userId, string identifier, DeviceType type, IEnumerable<string> organizationIds);
Task DeleteRegistrationAsync(string deviceId); Task DeleteRegistrationAsync(string deviceId);
Task AddUserRegistrationOrganizationAsync(IEnumerable<string> deviceIds, string organizationId); Task AddUserRegistrationOrganizationAsync(IEnumerable<string> deviceIds, string organizationId);
Task DeleteUserRegistrationOrganizationAsync(IEnumerable<string> deviceIds, string organizationId); Task DeleteUserRegistrationOrganizationAsync(IEnumerable<string> deviceIds, string organizationId);

View File

@ -1,6 +1,7 @@
using Bit.Core.AdminConsole.Entities; using Bit.Core.AdminConsole.Entities;
using Bit.Core.Auth.Entities; using Bit.Core.Auth.Entities;
using Bit.Core.Enums; using Bit.Core.Enums;
using Bit.Core.NotificationCenter.Entities;
using Bit.Core.Settings; using Bit.Core.Settings;
using Bit.Core.Tools.Entities; using Bit.Core.Tools.Entities;
using Bit.Core.Vault.Entities; using Bit.Core.Vault.Entities;
@ -131,20 +132,6 @@ public class MultiServicePushNotificationService : IPushNotificationService
return Task.FromResult(0); return Task.FromResult(0);
} }
public Task SendPayloadToUserAsync(string userId, PushType type, object payload, string identifier,
string deviceId = null)
{
PushToServices((s) => s.SendPayloadToUserAsync(userId, type, payload, identifier, deviceId));
return Task.FromResult(0);
}
public Task SendPayloadToOrganizationAsync(string orgId, PushType type, object payload, string identifier,
string deviceId = null)
{
PushToServices((s) => s.SendPayloadToOrganizationAsync(orgId, type, payload, identifier, deviceId));
return Task.FromResult(0);
}
public Task PushSyncOrganizationStatusAsync(Organization organization) public Task PushSyncOrganizationStatusAsync(Organization organization)
{ {
PushToServices((s) => s.PushSyncOrganizationStatusAsync(organization)); PushToServices((s) => s.PushSyncOrganizationStatusAsync(organization));
@ -157,6 +144,26 @@ public class MultiServicePushNotificationService : IPushNotificationService
return Task.CompletedTask; return Task.CompletedTask;
} }
public Task PushNotificationAsync(Notification notification)
{
PushToServices((s) => s.PushNotificationAsync(notification));
return Task.CompletedTask;
}
public Task SendPayloadToUserAsync(string userId, PushType type, object payload, string identifier,
string deviceId = null, ClientType? clientType = null)
{
PushToServices((s) => s.SendPayloadToUserAsync(userId, type, payload, identifier, deviceId, clientType));
return Task.FromResult(0);
}
public Task SendPayloadToOrganizationAsync(string orgId, PushType type, object payload, string identifier,
string deviceId = null, ClientType? clientType = null)
{
PushToServices((s) => s.SendPayloadToOrganizationAsync(orgId, type, payload, identifier, deviceId, clientType));
return Task.FromResult(0);
}
private void PushToServices(Func<IPushNotificationService, Task> pushFunc) private void PushToServices(Func<IPushNotificationService, Task> pushFunc)
{ {
if (_services != null) if (_services != null)

View File

@ -1,6 +1,7 @@
using Bit.Core.AdminConsole.Entities; using Bit.Core.AdminConsole.Entities;
using Bit.Core.Auth.Entities; using Bit.Core.Auth.Entities;
using Bit.Core.Enums; using Bit.Core.Enums;
using Bit.Core.NotificationCenter.Entities;
using Bit.Core.Tools.Entities; using Bit.Core.Tools.Entities;
using Bit.Core.Vault.Entities; using Bit.Core.Vault.Entities;
@ -84,7 +85,7 @@ public class NoopPushNotificationService : IPushNotificationService
} }
public Task SendPayloadToOrganizationAsync(string orgId, PushType type, object payload, string identifier, public Task SendPayloadToOrganizationAsync(string orgId, PushType type, object payload, string identifier,
string deviceId = null) string deviceId = null, ClientType? clientType = null)
{ {
return Task.FromResult(0); return Task.FromResult(0);
} }
@ -107,8 +108,10 @@ public class NoopPushNotificationService : IPushNotificationService
} }
public Task SendPayloadToUserAsync(string userId, PushType type, object payload, string identifier, public Task SendPayloadToUserAsync(string userId, PushType type, object payload, string identifier,
string deviceId = null) string deviceId = null, ClientType? clientType = null)
{ {
return Task.FromResult(0); return Task.FromResult(0);
} }
public Task PushNotificationAsync(Notification notification) => Task.CompletedTask;
} }

View File

@ -11,7 +11,7 @@ public class NoopPushRegistrationService : IPushRegistrationService
} }
public Task CreateOrUpdateRegistrationAsync(PushRegistrationData pushRegistrationData, string deviceId, string userId, public Task CreateOrUpdateRegistrationAsync(PushRegistrationData pushRegistrationData, string deviceId, string userId,
string identifier, DeviceType type) string identifier, DeviceType type, IEnumerable<string> organizationIds)
{ {
return Task.FromResult(0); return Task.FromResult(0);
} }

View File

@ -3,6 +3,7 @@ using Bit.Core.Auth.Entities;
using Bit.Core.Context; using Bit.Core.Context;
using Bit.Core.Enums; using Bit.Core.Enums;
using Bit.Core.Models; using Bit.Core.Models;
using Bit.Core.NotificationCenter.Entities;
using Bit.Core.Services; using Bit.Core.Services;
using Bit.Core.Settings; using Bit.Core.Settings;
using Bit.Core.Tools.Entities; using Bit.Core.Tools.Entities;
@ -183,6 +184,25 @@ public class NotificationsApiPushNotificationService : BaseIdentityClientService
await PushSendAsync(send, PushType.SyncSendDelete); await PushSendAsync(send, PushType.SyncSendDelete);
} }
public async Task PushNotificationAsync(Notification notification)
{
var message = new NotificationPushNotification
{
Id = notification.Id,
Priority = notification.Priority,
Global = notification.Global,
ClientType = notification.ClientType,
UserId = notification.UserId,
OrganizationId = notification.OrganizationId,
Title = notification.Title,
Body = notification.Body,
CreationDate = notification.CreationDate,
RevisionDate = notification.RevisionDate
};
await SendMessageAsync(PushType.SyncNotification, message, true);
}
private async Task PushSendAsync(Send send, PushType type) private async Task PushSendAsync(Send send, PushType type)
{ {
if (send.UserId.HasValue) if (send.UserId.HasValue)
@ -212,20 +232,20 @@ public class NotificationsApiPushNotificationService : BaseIdentityClientService
return null; return null;
} }
var currentContext = _httpContextAccessor?.HttpContext?. var currentContext =
RequestServices.GetService(typeof(ICurrentContext)) as ICurrentContext; _httpContextAccessor?.HttpContext?.RequestServices.GetService(typeof(ICurrentContext)) as ICurrentContext;
return currentContext?.DeviceIdentifier; return currentContext?.DeviceIdentifier;
} }
public Task SendPayloadToUserAsync(string userId, PushType type, object payload, string identifier, public Task SendPayloadToUserAsync(string userId, PushType type, object payload, string identifier,
string deviceId = null) string deviceId = null, ClientType? clientType = null)
{ {
// Noop // Noop
return Task.FromResult(0); return Task.FromResult(0);
} }
public Task SendPayloadToOrganizationAsync(string orgId, PushType type, object payload, string identifier, public Task SendPayloadToOrganizationAsync(string orgId, PushType type, object payload, string identifier,
string deviceId = null) string deviceId = null, ClientType? clientType = null)
{ {
// Noop // Noop
return Task.FromResult(0); return Task.FromResult(0);

View File

@ -5,6 +5,7 @@ using Bit.Core.Enums;
using Bit.Core.IdentityServer; using Bit.Core.IdentityServer;
using Bit.Core.Models; using Bit.Core.Models;
using Bit.Core.Models.Api; using Bit.Core.Models.Api;
using Bit.Core.NotificationCenter.Entities;
using Bit.Core.Repositories; using Bit.Core.Repositories;
using Bit.Core.Services; using Bit.Core.Services;
using Bit.Core.Settings; using Bit.Core.Settings;
@ -138,11 +139,7 @@ public class RelayPushNotificationService : BaseIdentityClientService, IPushNoti
private async Task PushUserAsync(Guid userId, PushType type, bool excludeCurrentContext = false) private async Task PushUserAsync(Guid userId, PushType type, bool excludeCurrentContext = false)
{ {
var message = new UserPushNotification var message = new UserPushNotification { UserId = userId, Date = DateTime.UtcNow };
{
UserId = userId,
Date = DateTime.UtcNow
};
await SendPayloadToUserAsync(userId, type, message, excludeCurrentContext); await SendPayloadToUserAsync(userId, type, message, excludeCurrentContext);
} }
@ -189,69 +186,37 @@ public class RelayPushNotificationService : BaseIdentityClientService, IPushNoti
private async Task PushAuthRequestAsync(AuthRequest authRequest, PushType type) private async Task PushAuthRequestAsync(AuthRequest authRequest, PushType type)
{ {
var message = new AuthRequestPushNotification var message = new AuthRequestPushNotification { Id = authRequest.Id, UserId = authRequest.UserId };
{
Id = authRequest.Id,
UserId = authRequest.UserId
};
await SendPayloadToUserAsync(authRequest.UserId, type, message, true); await SendPayloadToUserAsync(authRequest.UserId, type, message, true);
} }
private async Task SendPayloadToUserAsync(Guid userId, PushType type, object payload, bool excludeCurrentContext) public async Task PushNotificationAsync(Notification notification)
{ {
var request = new PushSendRequestModel var message = new NotificationPushNotification
{ {
UserId = userId.ToString(), Id = notification.Id,
Type = type, Priority = notification.Priority,
Payload = payload Global = notification.Global,
ClientType = notification.ClientType,
UserId = notification.UserId,
OrganizationId = notification.OrganizationId,
Title = notification.Title,
Body = notification.Body,
CreationDate = notification.CreationDate,
RevisionDate = notification.RevisionDate
}; };
await AddCurrentContextAsync(request, excludeCurrentContext); if (notification.UserId.HasValue)
await SendAsync(HttpMethod.Post, "push/send", request); {
await SendPayloadToUserAsync(notification.UserId.Value, PushType.SyncNotification, message, true,
notification.ClientType);
} }
else if (notification.OrganizationId.HasValue)
private async Task SendPayloadToOrganizationAsync(Guid orgId, PushType type, object payload, bool excludeCurrentContext)
{ {
var request = new PushSendRequestModel await SendPayloadToOrganizationAsync(notification.OrganizationId.Value, PushType.SyncNotification, message,
{ true, notification.ClientType);
OrganizationId = orgId.ToString(),
Type = type,
Payload = payload
};
await AddCurrentContextAsync(request, excludeCurrentContext);
await SendAsync(HttpMethod.Post, "push/send", request);
} }
private async Task AddCurrentContextAsync(PushSendRequestModel request, bool addIdentifier)
{
var currentContext = _httpContextAccessor?.HttpContext?.
RequestServices.GetService(typeof(ICurrentContext)) as ICurrentContext;
if (!string.IsNullOrWhiteSpace(currentContext?.DeviceIdentifier))
{
var device = await _deviceRepository.GetByIdentifierAsync(currentContext.DeviceIdentifier);
if (device != null)
{
request.DeviceId = device.Id.ToString();
}
if (addIdentifier)
{
request.Identifier = currentContext.DeviceIdentifier;
}
}
}
public Task SendPayloadToUserAsync(string userId, PushType type, object payload, string identifier,
string deviceId = null)
{
throw new NotImplementedException();
}
public Task SendPayloadToOrganizationAsync(string orgId, PushType type, object payload, string identifier,
string deviceId = null)
{
throw new NotImplementedException();
} }
public async Task PushSyncOrganizationStatusAsync(Organization organization) public async Task PushSyncOrganizationStatusAsync(Organization organization)
@ -278,4 +243,65 @@ public class RelayPushNotificationService : BaseIdentityClientService, IPushNoti
}, },
false false
); );
private async Task SendPayloadToUserAsync(Guid userId, PushType type, object payload, bool excludeCurrentContext,
ClientType? clientType = null)
{
var request = new PushSendRequestModel
{
UserId = userId.ToString(),
Type = type,
Payload = payload,
ClientType = clientType
};
await AddCurrentContextAsync(request, excludeCurrentContext);
await SendAsync(HttpMethod.Post, "push/send", request);
}
private async Task SendPayloadToOrganizationAsync(Guid orgId, PushType type, object payload,
bool excludeCurrentContext, ClientType? clientType = null)
{
var request = new PushSendRequestModel
{
OrganizationId = orgId.ToString(),
Type = type,
Payload = payload,
ClientType = clientType
};
await AddCurrentContextAsync(request, excludeCurrentContext);
await SendAsync(HttpMethod.Post, "push/send", request);
}
private async Task AddCurrentContextAsync(PushSendRequestModel request, bool addIdentifier)
{
var currentContext =
_httpContextAccessor?.HttpContext?.RequestServices.GetService(typeof(ICurrentContext)) as ICurrentContext;
if (!string.IsNullOrWhiteSpace(currentContext?.DeviceIdentifier))
{
var device = await _deviceRepository.GetByIdentifierAsync(currentContext.DeviceIdentifier);
if (device != null)
{
request.DeviceId = device.Id.ToString();
}
if (addIdentifier)
{
request.Identifier = currentContext.DeviceIdentifier;
}
}
}
public Task SendPayloadToUserAsync(string userId, PushType type, object payload, string identifier,
string deviceId = null, ClientType? clientType = null)
{
throw new NotImplementedException();
}
public Task SendPayloadToOrganizationAsync(string orgId, PushType type, object payload, string identifier,
string deviceId = null, ClientType? clientType = null)
{
throw new NotImplementedException();
}
} }

View File

@ -26,7 +26,7 @@ public class RelayPushRegistrationService : BaseIdentityClientService, IPushRegi
} }
public async Task CreateOrUpdateRegistrationAsync(PushRegistrationData pushData, string deviceId, string userId, public async Task CreateOrUpdateRegistrationAsync(PushRegistrationData pushData, string deviceId, string userId,
string identifier, DeviceType type) string identifier, DeviceType type, IEnumerable<string> organizationIds)
{ {
var requestModel = new PushRegistrationRequestModel var requestModel = new PushRegistrationRequestModel
{ {
@ -34,7 +34,8 @@ public class RelayPushRegistrationService : BaseIdentityClientService, IPushRegi
Identifier = identifier, Identifier = identifier,
PushToken = pushData.Token, PushToken = pushData.Token,
Type = type, Type = type,
UserId = userId UserId = userId,
OrganizationIds = organizationIds
}; };
await SendAsync(HttpMethod.Post, "push/register", requestModel); await SendAsync(HttpMethod.Post, "push/register", requestModel);
} }

View File

@ -1,6 +1,7 @@
using Bit.Core.Auth.Models.Api.Request; using Bit.Core.Auth.Models.Api.Request;
using Bit.Core.Auth.Utilities; using Bit.Core.Auth.Utilities;
using Bit.Core.Entities; using Bit.Core.Entities;
using Bit.Core.Enums;
using Bit.Core.Exceptions; using Bit.Core.Exceptions;
using Bit.Core.NotificationHub; using Bit.Core.NotificationHub;
using Bit.Core.Platform.Push; using Bit.Core.Platform.Push;
@ -12,34 +13,31 @@ public class DeviceService : IDeviceService
{ {
private readonly IDeviceRepository _deviceRepository; private readonly IDeviceRepository _deviceRepository;
private readonly IPushRegistrationService _pushRegistrationService; private readonly IPushRegistrationService _pushRegistrationService;
private readonly IOrganizationUserRepository _organizationUserRepository;
public DeviceService( public DeviceService(
IDeviceRepository deviceRepository, IDeviceRepository deviceRepository,
IPushRegistrationService pushRegistrationService) IPushRegistrationService pushRegistrationService,
IOrganizationUserRepository organizationUserRepository)
{ {
_deviceRepository = deviceRepository; _deviceRepository = deviceRepository;
_pushRegistrationService = pushRegistrationService; _pushRegistrationService = pushRegistrationService;
_organizationUserRepository = organizationUserRepository;
} }
public async Task SaveAsync(WebPushRegistrationData webPush, Device device) public async Task SaveAsync(WebPushRegistrationData webPush, Device device)
{ {
if (device.Id == default(Guid)) await SaveAsync(new PushRegistrationData(webPush.Endpoint, webPush.P256dh, webPush.Auth), device);
{
await _deviceRepository.CreateAsync(device);
}
else
{
device.RevisionDate = DateTime.UtcNow;
await _deviceRepository.ReplaceAsync(device);
}
await _pushRegistrationService.CreateOrUpdateRegistrationAsync(new NotificationHub.PushRegistrationData(webPush.Endpoint, webPush.P256dh, webPush.Auth), device.Id.ToString(),
device.UserId.ToString(), device.Identifier, device.Type);
} }
public async Task SaveAsync(Device device) public async Task SaveAsync(Device device)
{ {
if (device.Id == default(Guid)) await SaveAsync(new PushRegistrationData(device.PushToken), device);
}
private async Task SaveAsync(PushRegistrationData data, Device device)
{
if (device.Id == default)
{ {
await _deviceRepository.CreateAsync(device); await _deviceRepository.CreateAsync(device);
} }
@ -49,8 +47,14 @@ public class DeviceService : IDeviceService
await _deviceRepository.ReplaceAsync(device); await _deviceRepository.ReplaceAsync(device);
} }
await _pushRegistrationService.CreateOrUpdateRegistrationAsync(new NotificationHub.PushRegistrationData(device.PushToken), device.Id.ToString(), var organizationIdsString =
device.UserId.ToString(), device.Identifier, device.Type); (await _organizationUserRepository.GetManyDetailsByUserAsync(device.UserId,
OrganizationUserStatusType.Confirmed))
.Select(ou => ou.OrganizationId.ToString());
await _pushRegistrationService.CreateOrUpdateRegistrationAsync(data, device.Id.ToString(),
device.UserId.ToString(), device.Identifier, device.Type, organizationIdsString);
} }
public async Task ClearTokenAsync(Device device) public async Task ClearTokenAsync(Device device)

View File

@ -18,6 +18,7 @@ public class ApiResources
Claims.SecurityStamp, Claims.SecurityStamp,
Claims.Premium, Claims.Premium,
Claims.Device, Claims.Device,
Claims.DeviceType,
Claims.OrganizationOwner, Claims.OrganizationOwner,
Claims.OrganizationAdmin, Claims.OrganizationAdmin,
Claims.OrganizationUser, Claims.OrganizationUser,

View File

@ -210,6 +210,7 @@ public abstract class BaseRequestValidator<T> where T : class
if (device != null) if (device != null)
{ {
claims.Add(new Claim(Claims.Device, device.Identifier)); claims.Add(new Claim(Claims.Device, device.Identifier));
claims.Add(new Claim(Claims.DeviceType, device.Type.ToString()));
} }
var customResponse = new Dictionary<string, object>(); var customResponse = new Dictionary<string, object>();

View File

@ -10,6 +10,8 @@ public static class HubHelpers
private static JsonSerializerOptions _deserializerOptions = private static JsonSerializerOptions _deserializerOptions =
new JsonSerializerOptions { PropertyNameCaseInsensitive = true }; new JsonSerializerOptions { PropertyNameCaseInsensitive = true };
private static readonly string _receiveMessageMethod = "ReceiveMessage";
public static async Task SendNotificationToHubAsync( public static async Task SendNotificationToHubAsync(
string notificationJson, string notificationJson,
IHubContext<NotificationsHub> hubContext, IHubContext<NotificationsHub> hubContext,
@ -18,7 +20,8 @@ public static class HubHelpers
CancellationToken cancellationToken = default(CancellationToken) CancellationToken cancellationToken = default(CancellationToken)
) )
{ {
var notification = JsonSerializer.Deserialize<PushNotificationData<object>>(notificationJson, _deserializerOptions); var notification =
JsonSerializer.Deserialize<PushNotificationData<object>>(notificationJson, _deserializerOptions);
logger.LogInformation("Sending notification: {NotificationType}", notification.Type); logger.LogInformation("Sending notification: {NotificationType}", notification.Type);
switch (notification.Type) switch (notification.Type)
{ {
@ -32,14 +35,15 @@ public static class HubHelpers
if (cipherNotification.Payload.UserId.HasValue) if (cipherNotification.Payload.UserId.HasValue)
{ {
await hubContext.Clients.User(cipherNotification.Payload.UserId.ToString()) await hubContext.Clients.User(cipherNotification.Payload.UserId.ToString())
.SendAsync("ReceiveMessage", cipherNotification, cancellationToken); .SendAsync(_receiveMessageMethod, cipherNotification, cancellationToken);
} }
else if (cipherNotification.Payload.OrganizationId.HasValue) else if (cipherNotification.Payload.OrganizationId.HasValue)
{ {
await hubContext.Clients.Group( await hubContext.Clients
$"Organization_{cipherNotification.Payload.OrganizationId}") .Group(NotificationsHub.GetOrganizationGroup(cipherNotification.Payload.OrganizationId.Value))
.SendAsync("ReceiveMessage", cipherNotification, cancellationToken); .SendAsync(_receiveMessageMethod, cipherNotification, cancellationToken);
} }
break; break;
case PushType.SyncFolderUpdate: case PushType.SyncFolderUpdate:
case PushType.SyncFolderCreate: case PushType.SyncFolderCreate:
@ -48,7 +52,7 @@ public static class HubHelpers
JsonSerializer.Deserialize<PushNotificationData<SyncFolderPushNotification>>( JsonSerializer.Deserialize<PushNotificationData<SyncFolderPushNotification>>(
notificationJson, _deserializerOptions); notificationJson, _deserializerOptions);
await hubContext.Clients.User(folderNotification.Payload.UserId.ToString()) await hubContext.Clients.User(folderNotification.Payload.UserId.ToString())
.SendAsync("ReceiveMessage", folderNotification, cancellationToken); .SendAsync(_receiveMessageMethod, folderNotification, cancellationToken);
break; break;
case PushType.SyncCiphers: case PushType.SyncCiphers:
case PushType.SyncVault: case PushType.SyncVault:
@ -60,7 +64,7 @@ public static class HubHelpers
JsonSerializer.Deserialize<PushNotificationData<UserPushNotification>>( JsonSerializer.Deserialize<PushNotificationData<UserPushNotification>>(
notificationJson, _deserializerOptions); notificationJson, _deserializerOptions);
await hubContext.Clients.User(userNotification.Payload.UserId.ToString()) await hubContext.Clients.User(userNotification.Payload.UserId.ToString())
.SendAsync("ReceiveMessage", userNotification, cancellationToken); .SendAsync(_receiveMessageMethod, userNotification, cancellationToken);
break; break;
case PushType.SyncSendCreate: case PushType.SyncSendCreate:
case PushType.SyncSendUpdate: case PushType.SyncSendUpdate:
@ -69,7 +73,7 @@ public static class HubHelpers
JsonSerializer.Deserialize<PushNotificationData<SyncSendPushNotification>>( JsonSerializer.Deserialize<PushNotificationData<SyncSendPushNotification>>(
notificationJson, _deserializerOptions); notificationJson, _deserializerOptions);
await hubContext.Clients.User(sendNotification.Payload.UserId.ToString()) await hubContext.Clients.User(sendNotification.Payload.UserId.ToString())
.SendAsync("ReceiveMessage", sendNotification, cancellationToken); .SendAsync(_receiveMessageMethod, sendNotification, cancellationToken);
break; break;
case PushType.AuthRequestResponse: case PushType.AuthRequestResponse:
var authRequestResponseNotification = var authRequestResponseNotification =
@ -83,7 +87,7 @@ public static class HubHelpers
JsonSerializer.Deserialize<PushNotificationData<AuthRequestPushNotification>>( JsonSerializer.Deserialize<PushNotificationData<AuthRequestPushNotification>>(
notificationJson, _deserializerOptions); notificationJson, _deserializerOptions);
await hubContext.Clients.User(authRequestNotification.Payload.UserId.ToString()) await hubContext.Clients.User(authRequestNotification.Payload.UserId.ToString())
.SendAsync("ReceiveMessage", authRequestNotification, cancellationToken); .SendAsync(_receiveMessageMethod, authRequestNotification, cancellationToken);
break; break;
case PushType.SyncOrganizationStatusChanged: case PushType.SyncOrganizationStatusChanged:
var orgStatusNotification = var orgStatusNotification =
@ -99,6 +103,32 @@ public static class HubHelpers
await hubContext.Clients.Group($"Organization_{organizationCollectionSettingsChangedNotification.Payload.OrganizationId}") await hubContext.Clients.Group($"Organization_{organizationCollectionSettingsChangedNotification.Payload.OrganizationId}")
.SendAsync("ReceiveMessage", organizationCollectionSettingsChangedNotification, cancellationToken); .SendAsync("ReceiveMessage", organizationCollectionSettingsChangedNotification, cancellationToken);
break; break;
case PushType.SyncNotification:
var syncNotification =
JsonSerializer.Deserialize<PushNotificationData<NotificationPushNotification>>(
notificationJson, _deserializerOptions);
if (syncNotification.Payload.UserId.HasValue)
{
if (syncNotification.Payload.ClientType == ClientType.All)
{
await hubContext.Clients.User(syncNotification.Payload.UserId.ToString())
.SendAsync(_receiveMessageMethod, syncNotification, cancellationToken);
}
else
{
await hubContext.Clients.Group(NotificationsHub.GetUserGroup(
syncNotification.Payload.UserId.Value, syncNotification.Payload.ClientType))
.SendAsync(_receiveMessageMethod, syncNotification, cancellationToken);
}
}
else if (syncNotification.Payload.OrganizationId.HasValue)
{
await hubContext.Clients.Group(NotificationsHub.GetOrganizationGroup(
syncNotification.Payload.OrganizationId.Value, syncNotification.Payload.ClientType))
.SendAsync(_receiveMessageMethod, syncNotification, cancellationToken);
}
break;
default: default:
break; break;
} }

View File

@ -1,5 +1,7 @@
using Bit.Core.Context; using Bit.Core.Context;
using Bit.Core.Enums;
using Bit.Core.Settings; using Bit.Core.Settings;
using Bit.Core.Utilities;
using Microsoft.AspNetCore.Authorization; using Microsoft.AspNetCore.Authorization;
namespace Bit.Notifications; namespace Bit.Notifications;
@ -20,13 +22,25 @@ public class NotificationsHub : Microsoft.AspNetCore.SignalR.Hub
{ {
var currentContext = new CurrentContext(null, null); var currentContext = new CurrentContext(null, null);
await currentContext.BuildAsync(Context.User, _globalSettings); await currentContext.BuildAsync(Context.User, _globalSettings);
var clientType = DeviceTypes.ToClientType(currentContext.DeviceType);
if (clientType != ClientType.All && currentContext.UserId.HasValue)
{
await Groups.AddToGroupAsync(Context.ConnectionId, GetUserGroup(currentContext.UserId.Value, clientType));
}
if (currentContext.Organizations != null) if (currentContext.Organizations != null)
{ {
foreach (var org in currentContext.Organizations) foreach (var org in currentContext.Organizations)
{ {
await Groups.AddToGroupAsync(Context.ConnectionId, $"Organization_{org.Id}"); await Groups.AddToGroupAsync(Context.ConnectionId, GetOrganizationGroup(org.Id));
if (clientType != ClientType.All)
{
await Groups.AddToGroupAsync(Context.ConnectionId, GetOrganizationGroup(org.Id, clientType));
} }
} }
}
_connectionCounter.Increment(); _connectionCounter.Increment();
await base.OnConnectedAsync(); await base.OnConnectedAsync();
} }
@ -35,14 +49,39 @@ public class NotificationsHub : Microsoft.AspNetCore.SignalR.Hub
{ {
var currentContext = new CurrentContext(null, null); var currentContext = new CurrentContext(null, null);
await currentContext.BuildAsync(Context.User, _globalSettings); await currentContext.BuildAsync(Context.User, _globalSettings);
var clientType = DeviceTypes.ToClientType(currentContext.DeviceType);
if (clientType != ClientType.All && currentContext.UserId.HasValue)
{
await Groups.RemoveFromGroupAsync(Context.ConnectionId,
GetUserGroup(currentContext.UserId.Value, clientType));
}
if (currentContext.Organizations != null) if (currentContext.Organizations != null)
{ {
foreach (var org in currentContext.Organizations) foreach (var org in currentContext.Organizations)
{ {
await Groups.RemoveFromGroupAsync(Context.ConnectionId, $"Organization_{org.Id}"); await Groups.RemoveFromGroupAsync(Context.ConnectionId, GetOrganizationGroup(org.Id));
if (clientType != ClientType.All)
{
await Groups.RemoveFromGroupAsync(Context.ConnectionId, GetOrganizationGroup(org.Id, clientType));
} }
} }
}
_connectionCounter.Decrement(); _connectionCounter.Decrement();
await base.OnDisconnectedAsync(exception); await base.OnDisconnectedAsync(exception);
} }
public static string GetUserGroup(Guid userId, ClientType clientType)
{
return $"UserClientType_{userId}_{clientType}";
}
public static string GetOrganizationGroup(Guid organizationId, ClientType? clientType = null)
{
return clientType is null or ClientType.All
? $"Organization_{organizationId}"
: $"OrganizationClientType_{organizationId}_{clientType}";
}
} }

View File

@ -3,6 +3,7 @@ using System.Reflection;
using System.Security.Claims; using System.Security.Claims;
using System.Security.Cryptography.X509Certificates; using System.Security.Cryptography.X509Certificates;
using AspNetCoreRateLimit; using AspNetCoreRateLimit;
using Azure.Storage.Queues;
using Bit.Core.AdminConsole.Models.Business.Tokenables; using Bit.Core.AdminConsole.Models.Business.Tokenables;
using Bit.Core.AdminConsole.OrganizationFeatures.Policies; using Bit.Core.AdminConsole.OrganizationFeatures.Policies;
using Bit.Core.AdminConsole.Services; using Bit.Core.AdminConsole.Services;
@ -306,7 +307,10 @@ public static class ServiceCollectionExtensions
services.AddKeyedSingleton<IPushNotificationService, NotificationHubPushNotificationService>("implementation"); services.AddKeyedSingleton<IPushNotificationService, NotificationHubPushNotificationService>("implementation");
if (CoreHelpers.SettingHasValue(globalSettings.Notifications?.ConnectionString)) if (CoreHelpers.SettingHasValue(globalSettings.Notifications?.ConnectionString))
{ {
services.AddKeyedSingleton<IPushNotificationService, AzureQueuePushNotificationService>("implementation"); services.AddKeyedSingleton("notifications",
(_, _) => new QueueClient(globalSettings.Notifications.ConnectionString, "notifications"));
services.AddKeyedSingleton<IPushNotificationService, AzureQueuePushNotificationService>(
"implementation");
} }
} }

View File

@ -133,12 +133,10 @@ public class NotificationsControllerTests : IClassFixture<ApiApplicationFactory>
[InlineData(null, null, "2", 10)] [InlineData(null, null, "2", 10)]
[InlineData(10, null, "2", 10)] [InlineData(10, null, "2", 10)]
[InlineData(10, 2, "3", 10)] [InlineData(10, 2, "3", 10)]
[InlineData(10, 3, null, 0)] [InlineData(10, 3, null, 4)]
[InlineData(15, null, "2", 15)] [InlineData(24, null, "2", 24)]
[InlineData(15, 2, null, 5)] [InlineData(24, 2, null, 0)]
[InlineData(20, null, "2", 20)] [InlineData(1000, null, null, 24)]
[InlineData(20, 2, null, 0)]
[InlineData(1000, null, null, 20)]
public async Task ListAsync_PaginationFilter_ReturnsNextPageOfNotificationsCorrectOrder( public async Task ListAsync_PaginationFilter_ReturnsNextPageOfNotificationsCorrectOrder(
int? pageSize, int? pageNumber, string? expectedContinuationToken, int expectedCount) int? pageSize, int? pageNumber, string? expectedContinuationToken, int expectedCount)
{ {
@ -505,11 +503,12 @@ public class NotificationsControllerTests : IClassFixture<ApiApplicationFactory>
userPartOrOrganizationNotificationWithStatuses userPartOrOrganizationNotificationWithStatuses
} }
.SelectMany(n => n) .SelectMany(n => n)
.Where(n => n.Item1.ClientType is ClientType.All or ClientType.Web)
.ToList(); .ToList();
} }
private async Task<List<Notification>> CreateNotificationsAsync(Guid? userId = null, Guid? organizationId = null, private async Task<List<Notification>> CreateNotificationsAsync(Guid? userId = null, Guid? organizationId = null,
int numberToCreate = 5) int numberToCreate = 3)
{ {
var priorities = Enum.GetValues<Priority>(); var priorities = Enum.GetValues<Priority>();
var clientTypes = Enum.GetValues<ClientType>(); var clientTypes = Enum.GetValues<ClientType>();
@ -570,13 +569,9 @@ public class NotificationsControllerTests : IClassFixture<ApiApplicationFactory>
DeletedDate = DateTime.UtcNow - TimeSpan.FromMinutes(_random.Next(3600)) DeletedDate = DateTime.UtcNow - TimeSpan.FromMinutes(_random.Next(3600))
}); });
return List<NotificationStatus> statuses =
[ [readDateNotificationStatus, deletedDateNotificationStatus, readDateAndDeletedDateNotificationStatus];
(notifications[0], readDateNotificationStatus),
(notifications[1], deletedDateNotificationStatus), return notifications.Select(n => (n, statuses.Find(s => s.NotificationId == n.Id))).ToList();
(notifications[2], readDateAndDeletedDateNotificationStatus),
(notifications[3], null),
(notifications[4], null)
];
} }
} }

View File

@ -3,6 +3,7 @@ using Bit.Billing.Controllers;
using Bit.Billing.Test.Utilities; using Bit.Billing.Test.Utilities;
using Bit.Core.AdminConsole.Entities; using Bit.Core.AdminConsole.Entities;
using Bit.Core.AdminConsole.Repositories; using Bit.Core.AdminConsole.Repositories;
using Bit.Core.Billing.Services;
using Bit.Core.Entities; using Bit.Core.Entities;
using Bit.Core.Enums; using Bit.Core.Enums;
using Bit.Core.Repositories; using Bit.Core.Repositories;
@ -33,6 +34,7 @@ public class PayPalControllerTests
private readonly ITransactionRepository _transactionRepository = Substitute.For<ITransactionRepository>(); private readonly ITransactionRepository _transactionRepository = Substitute.For<ITransactionRepository>();
private readonly IUserRepository _userRepository = Substitute.For<IUserRepository>(); private readonly IUserRepository _userRepository = Substitute.For<IUserRepository>();
private readonly IProviderRepository _providerRepository = Substitute.For<IProviderRepository>(); private readonly IProviderRepository _providerRepository = Substitute.For<IProviderRepository>();
private readonly IPremiumUserBillingService _premiumUserBillingService = Substitute.For<IPremiumUserBillingService>();
private const string _defaultWebhookKey = "webhook-key"; private const string _defaultWebhookKey = "webhook-key";
@ -385,8 +387,6 @@ public class PayPalControllerTests
_userRepository.GetByIdAsync(userId).Returns(user); _userRepository.GetByIdAsync(userId).Returns(user);
_paymentService.CreditAccountAsync(user, 48M).Returns(true);
var controller = ConfigureControllerContextWith(logger, _defaultWebhookKey, ipnBody); var controller = ConfigureControllerContextWith(logger, _defaultWebhookKey, ipnBody);
var result = await controller.PostIpn(); var result = await controller.PostIpn();
@ -398,9 +398,7 @@ public class PayPalControllerTests
transaction.UserId == userId && transaction.UserId == userId &&
transaction.Amount == 48M)); transaction.Amount == 48M));
await _paymentService.Received(1).CreditAccountAsync(user, 48M); await _premiumUserBillingService.Received(1).Credit(user, 48M);
await _userRepository.Received(1).ReplaceAsync(user);
await _mailService.Received(1).SendAddedCreditAsync(billingEmail, 48M); await _mailService.Received(1).SendAddedCreditAsync(billingEmail, 48M);
} }
@ -544,7 +542,8 @@ public class PayPalControllerTests
_paymentService, _paymentService,
_transactionRepository, _transactionRepository,
_userRepository, _userRepository,
_providerRepository); _providerRepository,
_premiumUserBillingService);
var httpContext = new DefaultHttpContext(); var httpContext = new DefaultHttpContext();

View File

@ -0,0 +1,35 @@
#nullable enable
using AutoFixture;
using AutoFixture.Kernel;
using Azure.Storage.Queues;
using Bit.Test.Common.AutoFixture.Attributes;
using NSubstitute;
namespace Bit.Core.Test.AutoFixture;
public class QueueClientBuilder : ISpecimenBuilder
{
public object Create(object request, ISpecimenContext context)
{
var type = request as Type;
if (type == typeof(QueueClient))
{
return Substitute.For<QueueClient>();
}
return new NoSpecimen();
}
}
public class QueueClientCustomizeAttribute : BitCustomizeAttribute
{
public override ICustomization GetCustomization() => new QueueClientFixtures();
}
public class QueueClientFixtures : ICustomization
{
public void Customize(IFixture fixture)
{
fixture.Customizations.Add(new QueueClientBuilder());
}
}

View File

@ -0,0 +1,94 @@
#nullable enable
using System.ComponentModel.DataAnnotations;
using System.Text.Json;
using Bit.Core.Enums;
using Bit.Core.Models.Api;
using Bit.Core.Utilities;
using Bit.Test.Common.AutoFixture.Attributes;
using Xunit;
namespace Bit.Core.Test.Models.Api.Request;
public class PushSendRequestModelTests
{
[Theory]
[InlineData(null, null)]
[InlineData(null, "")]
[InlineData(null, " ")]
[InlineData("", null)]
[InlineData(" ", null)]
[InlineData("", "")]
[InlineData(" ", " ")]
public void Validate_UserIdOrganizationIdNullOrEmpty_Invalid(string? userId, string? organizationId)
{
var model = new PushSendRequestModel
{
UserId = userId,
OrganizationId = organizationId,
Type = PushType.SyncCiphers,
Payload = "test"
};
var results = Validate(model);
Assert.Single(results);
Assert.Contains(results, result => result.ErrorMessage == "UserId or OrganizationId is required.");
}
[Theory]
[BitAutoData("Payload")]
[BitAutoData("Type")]
public void Validate_RequiredFieldNotProvided_Invalid(string requiredField)
{
var model = new PushSendRequestModel
{
UserId = Guid.NewGuid().ToString(),
OrganizationId = Guid.NewGuid().ToString(),
Type = PushType.SyncCiphers,
Payload = "test"
};
var dictionary = new Dictionary<string, object?>();
foreach (var property in model.GetType().GetProperties())
{
if (property.Name == requiredField)
{
continue;
}
dictionary[property.Name] = property.GetValue(model);
}
var serialized = JsonSerializer.Serialize(dictionary, JsonHelpers.IgnoreWritingNull);
var jsonException =
Assert.Throws<JsonException>(() => JsonSerializer.Deserialize<PushSendRequestModel>(serialized));
Assert.Contains($"missing required properties, including the following: {requiredField}",
jsonException.Message);
}
[Fact]
public void Validate_AllFieldsPresent_Valid()
{
var model = new PushSendRequestModel
{
UserId = Guid.NewGuid().ToString(),
OrganizationId = Guid.NewGuid().ToString(),
Type = PushType.SyncCiphers,
Payload = "test payload",
Identifier = Guid.NewGuid().ToString(),
ClientType = ClientType.All,
DeviceId = Guid.NewGuid().ToString()
};
var results = Validate(model);
Assert.Empty(results);
}
private static List<ValidationResult> Validate(PushSendRequestModel model)
{
var results = new List<ValidationResult>();
Validator.TryValidateObject(model, new ValidationContext(model), results, true);
return results;
}
}

View File

@ -5,6 +5,7 @@ using Bit.Core.NotificationCenter.Authorization;
using Bit.Core.NotificationCenter.Commands; using Bit.Core.NotificationCenter.Commands;
using Bit.Core.NotificationCenter.Entities; using Bit.Core.NotificationCenter.Entities;
using Bit.Core.NotificationCenter.Repositories; using Bit.Core.NotificationCenter.Repositories;
using Bit.Core.Platform.Push;
using Bit.Core.Test.NotificationCenter.AutoFixture; using Bit.Core.Test.NotificationCenter.AutoFixture;
using Bit.Test.Common.AutoFixture; using Bit.Test.Common.AutoFixture;
using Bit.Test.Common.AutoFixture.Attributes; using Bit.Test.Common.AutoFixture.Attributes;
@ -55,5 +56,8 @@ public class CreateNotificationCommandTest
Assert.Equal(notification, newNotification); Assert.Equal(notification, newNotification);
Assert.Equal(DateTime.UtcNow, notification.CreationDate, TimeSpan.FromMinutes(1)); Assert.Equal(DateTime.UtcNow, notification.CreationDate, TimeSpan.FromMinutes(1));
Assert.Equal(notification.CreationDate, notification.RevisionDate); Assert.Equal(notification.CreationDate, notification.RevisionDate);
await sutProvider.GetDependency<IPushNotificationService>()
.Received(1)
.PushNotificationAsync(newNotification);
} }
} }

View File

@ -1,42 +1,241 @@
using Bit.Core.NotificationHub; #nullable enable
using Bit.Core.Platform.Push; using System.Text.Json;
using Bit.Core.Enums;
using Bit.Core.Models;
using Bit.Core.Models.Data;
using Bit.Core.NotificationCenter.Entities;
using Bit.Core.NotificationHub;
using Bit.Core.Repositories; using Bit.Core.Repositories;
using Microsoft.AspNetCore.Http; using Bit.Core.Test.NotificationCenter.AutoFixture;
using Microsoft.Extensions.Logging; using Bit.Test.Common.AutoFixture;
using Bit.Test.Common.AutoFixture.Attributes;
using NSubstitute; using NSubstitute;
using Xunit; using Xunit;
namespace Bit.Core.Test.NotificationHub; namespace Bit.Core.Test.NotificationHub;
[SutProviderCustomize]
public class NotificationHubPushNotificationServiceTests public class NotificationHubPushNotificationServiceTests
{ {
private readonly NotificationHubPushNotificationService _sut; [Theory]
[BitAutoData]
private readonly IInstallationDeviceRepository _installationDeviceRepository; [NotificationCustomize]
private readonly INotificationHubPool _notificationHubPool; public async void PushNotificationAsync_Global_NotSent(
private readonly IHttpContextAccessor _httpContextAccessor; SutProvider<NotificationHubPushNotificationService> sutProvider, Notification notification)
private readonly ILogger<NotificationsApiPushNotificationService> _logger;
public NotificationHubPushNotificationServiceTests()
{ {
_installationDeviceRepository = Substitute.For<IInstallationDeviceRepository>(); await sutProvider.Sut.PushNotificationAsync(notification);
_httpContextAccessor = Substitute.For<IHttpContextAccessor>();
_notificationHubPool = Substitute.For<INotificationHubPool>();
_logger = Substitute.For<ILogger<NotificationsApiPushNotificationService>>();
_sut = new NotificationHubPushNotificationService( await sutProvider.GetDependency<INotificationHubPool>()
_installationDeviceRepository, .Received(0)
_notificationHubPool, .AllClients
_httpContextAccessor, .Received(0)
_logger .SendTemplateNotificationAsync(Arg.Any<IDictionary<string, string>>(), Arg.Any<string>());
); await sutProvider.GetDependency<IInstallationDeviceRepository>()
.Received(0)
.UpsertAsync(Arg.Any<InstallationDeviceEntity>());
} }
// Remove this test when we add actual tests. It only proves that [Theory]
// we've properly constructed the system under test. [BitAutoData(false)]
[Fact(Skip = "Needs additional work")] [BitAutoData(true)]
public void ServiceExists() [NotificationCustomize(false)]
public async void PushNotificationAsync_UserIdProvidedClientTypeAll_SentToUser(
bool organizationIdNull, SutProvider<NotificationHubPushNotificationService> sutProvider,
Notification notification)
{ {
Assert.NotNull(_sut); if (organizationIdNull)
{
notification.OrganizationId = null;
}
notification.ClientType = ClientType.All;
var expectedSyncNotification = ToSyncNotificationPushNotification(notification);
await sutProvider.Sut.PushNotificationAsync(notification);
await AssertSendTemplateNotificationAsync(sutProvider, PushType.SyncNotification, expectedSyncNotification,
$"(template:payload_userId:{notification.UserId})");
await sutProvider.GetDependency<IInstallationDeviceRepository>()
.Received(0)
.UpsertAsync(Arg.Any<InstallationDeviceEntity>());
}
[Theory]
[BitAutoData(false, ClientType.Browser)]
[BitAutoData(false, ClientType.Desktop)]
[BitAutoData(false, ClientType.Web)]
[BitAutoData(false, ClientType.Mobile)]
[BitAutoData(true, ClientType.Browser)]
[BitAutoData(true, ClientType.Desktop)]
[BitAutoData(true, ClientType.Web)]
[BitAutoData(true, ClientType.Mobile)]
[NotificationCustomize(false)]
public async void PushNotificationAsync_UserIdProvidedClientTypeNotAll_SentToUser(bool organizationIdNull,
ClientType clientType, SutProvider<NotificationHubPushNotificationService> sutProvider,
Notification notification)
{
if (organizationIdNull)
{
notification.OrganizationId = null;
}
notification.ClientType = clientType;
var expectedSyncNotification = ToSyncNotificationPushNotification(notification);
await sutProvider.Sut.PushNotificationAsync(notification);
await AssertSendTemplateNotificationAsync(sutProvider, PushType.SyncNotification, expectedSyncNotification,
$"(template:payload_userId:{notification.UserId} && clientType:{clientType})");
await sutProvider.GetDependency<IInstallationDeviceRepository>()
.Received(0)
.UpsertAsync(Arg.Any<InstallationDeviceEntity>());
}
[Theory]
[BitAutoData]
[NotificationCustomize(false)]
public async void PushNotificationAsync_UserIdNullOrganizationIdProvidedClientTypeAll_SentToOrganization(
SutProvider<NotificationHubPushNotificationService> sutProvider, Notification notification)
{
notification.UserId = null;
notification.ClientType = ClientType.All;
var expectedSyncNotification = ToSyncNotificationPushNotification(notification);
await sutProvider.Sut.PushNotificationAsync(notification);
await AssertSendTemplateNotificationAsync(sutProvider, PushType.SyncNotification, expectedSyncNotification,
$"(template:payload && organizationId:{notification.OrganizationId})");
await sutProvider.GetDependency<IInstallationDeviceRepository>()
.Received(0)
.UpsertAsync(Arg.Any<InstallationDeviceEntity>());
}
[Theory]
[BitAutoData(ClientType.Browser)]
[BitAutoData(ClientType.Desktop)]
[BitAutoData(ClientType.Web)]
[BitAutoData(ClientType.Mobile)]
[NotificationCustomize(false)]
public async void PushNotificationAsync_UserIdNullOrganizationIdProvidedClientTypeNotAll_SentToOrganization(
ClientType clientType, SutProvider<NotificationHubPushNotificationService> sutProvider,
Notification notification)
{
notification.UserId = null;
notification.ClientType = clientType;
var expectedSyncNotification = ToSyncNotificationPushNotification(notification);
await sutProvider.Sut.PushNotificationAsync(notification);
await AssertSendTemplateNotificationAsync(sutProvider, PushType.SyncNotification, expectedSyncNotification,
$"(template:payload && organizationId:{notification.OrganizationId} && clientType:{clientType})");
await sutProvider.GetDependency<IInstallationDeviceRepository>()
.Received(0)
.UpsertAsync(Arg.Any<InstallationDeviceEntity>());
}
[Theory]
[BitAutoData([null])]
[BitAutoData(ClientType.All)]
public async void SendPayloadToUserAsync_ClientTypeNullOrAll_SentToUser(ClientType? clientType,
SutProvider<NotificationHubPushNotificationService> sutProvider, Guid userId, PushType pushType, string payload,
string identifier)
{
await sutProvider.Sut.SendPayloadToUserAsync(userId.ToString(), pushType, payload, identifier, null,
clientType);
await AssertSendTemplateNotificationAsync(sutProvider, pushType, payload,
$"(template:payload_userId:{userId} && !deviceIdentifier:{identifier})");
await sutProvider.GetDependency<IInstallationDeviceRepository>()
.Received(0)
.UpsertAsync(Arg.Any<InstallationDeviceEntity>());
}
[Theory]
[BitAutoData(ClientType.Browser)]
[BitAutoData(ClientType.Desktop)]
[BitAutoData(ClientType.Mobile)]
[BitAutoData(ClientType.Web)]
public async void SendPayloadToUserAsync_ClientTypeExplicit_SentToUserAndClientType(ClientType clientType,
SutProvider<NotificationHubPushNotificationService> sutProvider, Guid userId, PushType pushType, string payload,
string identifier)
{
await sutProvider.Sut.SendPayloadToUserAsync(userId.ToString(), pushType, payload, identifier, null,
clientType);
await AssertSendTemplateNotificationAsync(sutProvider, pushType, payload,
$"(template:payload_userId:{userId} && !deviceIdentifier:{identifier} && clientType:{clientType})");
await sutProvider.GetDependency<IInstallationDeviceRepository>()
.Received(0)
.UpsertAsync(Arg.Any<InstallationDeviceEntity>());
}
[Theory]
[BitAutoData([null])]
[BitAutoData(ClientType.All)]
public async void SendPayloadToOrganizationAsync_ClientTypeNullOrAll_SentToOrganization(ClientType? clientType,
SutProvider<NotificationHubPushNotificationService> sutProvider, Guid organizationId, PushType pushType,
string payload, string identifier)
{
await sutProvider.Sut.SendPayloadToOrganizationAsync(organizationId.ToString(), pushType, payload, identifier,
null, clientType);
await AssertSendTemplateNotificationAsync(sutProvider, pushType, payload,
$"(template:payload && organizationId:{organizationId} && !deviceIdentifier:{identifier})");
await sutProvider.GetDependency<IInstallationDeviceRepository>()
.Received(0)
.UpsertAsync(Arg.Any<InstallationDeviceEntity>());
}
[Theory]
[BitAutoData(ClientType.Browser)]
[BitAutoData(ClientType.Desktop)]
[BitAutoData(ClientType.Mobile)]
[BitAutoData(ClientType.Web)]
public async void SendPayloadToOrganizationAsync_ClientTypeExplicit_SentToOrganizationAndClientType(
ClientType clientType, SutProvider<NotificationHubPushNotificationService> sutProvider, Guid organizationId,
PushType pushType, string payload, string identifier)
{
await sutProvider.Sut.SendPayloadToOrganizationAsync(organizationId.ToString(), pushType, payload, identifier,
null, clientType);
await AssertSendTemplateNotificationAsync(sutProvider, pushType, payload,
$"(template:payload && organizationId:{organizationId} && !deviceIdentifier:{identifier} && clientType:{clientType})");
await sutProvider.GetDependency<IInstallationDeviceRepository>()
.Received(0)
.UpsertAsync(Arg.Any<InstallationDeviceEntity>());
}
private static NotificationPushNotification ToSyncNotificationPushNotification(Notification notification) =>
new()
{
Id = notification.Id,
Priority = notification.Priority,
Global = notification.Global,
ClientType = notification.ClientType,
UserId = notification.UserId,
OrganizationId = notification.OrganizationId,
Title = notification.Title,
Body = notification.Body,
CreationDate = notification.CreationDate,
RevisionDate = notification.RevisionDate
};
private static async Task AssertSendTemplateNotificationAsync(
SutProvider<NotificationHubPushNotificationService> sutProvider, PushType type, object payload, string tag)
{
await sutProvider.GetDependency<INotificationHubPool>()
.Received(1)
.AllClients
.Received(1)
.SendTemplateNotificationAsync(
Arg.Is<IDictionary<string, string>>(dictionary => MatchingSendPayload(dictionary, type, payload)),
tag);
}
private static bool MatchingSendPayload(IDictionary<string, string> dictionary, PushType type, object payload)
{
return dictionary.ContainsKey("type") && dictionary["type"].Equals(((byte)type).ToString()) &&
dictionary.ContainsKey("payload") && dictionary["payload"].Equals(JsonSerializer.Serialize(payload));
} }
} }

View File

@ -1,45 +1,290 @@
using Bit.Core.NotificationHub; #nullable enable
using Bit.Core.Repositories; using Bit.Core.Enums;
using Bit.Core.Settings; using Bit.Core.NotificationHub;
using Microsoft.Extensions.Logging; using Bit.Core.Utilities;
using Bit.Test.Common.AutoFixture;
using Bit.Test.Common.AutoFixture.Attributes;
using Microsoft.Azure.NotificationHubs;
using NSubstitute; using NSubstitute;
using Xunit; using Xunit;
namespace Bit.Core.Test.NotificationHub; namespace Bit.Core.Test.NotificationHub;
[SutProviderCustomize]
public class NotificationHubPushRegistrationServiceTests public class NotificationHubPushRegistrationServiceTests
{ {
private readonly NotificationHubPushRegistrationService _sut; [Theory]
[BitAutoData([null])]
private readonly IInstallationDeviceRepository _installationDeviceRepository; [BitAutoData("")]
private readonly IServiceProvider _serviceProvider; [BitAutoData(" ")]
private readonly ILogger<NotificationHubPushRegistrationService> _logger; public async Task CreateOrUpdateRegistrationAsync_PushTokenNullOrEmpty_InstallationNotCreated(string? pushToken,
private readonly GlobalSettings _globalSettings; SutProvider<NotificationHubPushRegistrationService> sutProvider, Guid deviceId, Guid userId, Guid identifier,
private readonly INotificationHubPool _notificationHubPool; Guid organizationId)
public NotificationHubPushRegistrationServiceTests()
{ {
_installationDeviceRepository = Substitute.For<IInstallationDeviceRepository>(); await sutProvider.Sut.CreateOrUpdateRegistrationAsync(pushToken, deviceId.ToString(), userId.ToString(),
_serviceProvider = Substitute.For<IServiceProvider>(); identifier.ToString(), DeviceType.Android, [organizationId.ToString()]);
_logger = Substitute.For<ILogger<NotificationHubPushRegistrationService>>();
_globalSettings = new GlobalSettings();
_notificationHubPool = Substitute.For<INotificationHubPool>();
_sut = new NotificationHubPushRegistrationService( sutProvider.GetDependency<INotificationHubPool>()
_installationDeviceRepository, .Received(0)
_globalSettings, .ClientFor(deviceId);
_notificationHubPool,
_serviceProvider,
Substitute.For<IHttpClientFactory>(),
_logger
);
} }
// Remove this test when we add actual tests. It only proves that [Theory]
// we've properly constructed the system under test. [BitAutoData(false, false)]
[Fact(Skip = "Needs additional work")] [BitAutoData(false, true)]
public void ServiceExists() [BitAutoData(true, false)]
[BitAutoData(true, true)]
public async Task CreateOrUpdateRegistrationAsync_DeviceTypeAndroid_InstallationCreated(bool identifierNull,
bool partOfOrganizationId, SutProvider<NotificationHubPushRegistrationService> sutProvider, Guid deviceId,
Guid userId, Guid? identifier, Guid organizationId)
{ {
Assert.NotNull(_sut); var notificationHubClient = Substitute.For<INotificationHubClient>();
sutProvider.GetDependency<INotificationHubPool>().ClientFor(Arg.Any<Guid>()).Returns(notificationHubClient);
var pushToken = "test push token";
await sutProvider.Sut.CreateOrUpdateRegistrationAsync(pushToken, deviceId.ToString(), userId.ToString(),
identifierNull ? null : identifier.ToString(), DeviceType.Android,
partOfOrganizationId ? [organizationId.ToString()] : []);
sutProvider.GetDependency<INotificationHubPool>()
.Received(1)
.ClientFor(deviceId);
await notificationHubClient
.Received(1)
.CreateOrUpdateInstallationAsync(Arg.Is<Installation>(installation =>
installation.InstallationId == deviceId.ToString() &&
installation.PushChannel == pushToken &&
installation.Platform == NotificationPlatform.FcmV1 &&
installation.Tags.Contains($"userId:{userId}") &&
installation.Tags.Contains("clientType:Mobile") &&
(identifierNull || installation.Tags.Contains($"deviceIdentifier:{identifier}")) &&
(!partOfOrganizationId || installation.Tags.Contains($"organizationId:{organizationId}")) &&
installation.Templates.Count == 3));
await notificationHubClient
.Received(1)
.CreateOrUpdateInstallationAsync(Arg.Is<Installation>(installation => MatchingInstallationTemplate(
installation.Templates, "template:payload",
"{\"message\":{\"data\":{\"type\":\"$(type)\",\"payload\":\"$(payload)\"}}}",
new List<string?>
{
"template:payload",
$"template:payload_userId:{userId}",
"clientType:Mobile",
identifierNull ? null : $"template:payload_deviceIdentifier:{identifier}",
partOfOrganizationId ? $"organizationId:{organizationId}" : null,
})));
await notificationHubClient
.Received(1)
.CreateOrUpdateInstallationAsync(Arg.Is<Installation>(installation => MatchingInstallationTemplate(
installation.Templates, "template:message",
"{\"message\":{\"data\":{\"type\":\"$(type)\"},\"notification\":{\"title\":\"$(title)\",\"body\":\"$(message)\"}}}",
new List<string?>
{
"template:message",
$"template:message_userId:{userId}",
"clientType:Mobile",
identifierNull ? null : $"template:message_deviceIdentifier:{identifier}",
partOfOrganizationId ? $"organizationId:{organizationId}" : null,
})));
await notificationHubClient
.Received(1)
.CreateOrUpdateInstallationAsync(Arg.Is<Installation>(installation => MatchingInstallationTemplate(
installation.Templates, "template:badgeMessage",
"{\"message\":{\"data\":{\"type\":\"$(type)\"},\"notification\":{\"title\":\"$(title)\",\"body\":\"$(message)\"}}}",
new List<string?>
{
"template:badgeMessage",
$"template:badgeMessage_userId:{userId}",
"clientType:Mobile",
identifierNull ? null : $"template:badgeMessage_deviceIdentifier:{identifier}",
partOfOrganizationId ? $"organizationId:{organizationId}" : null,
})));
}
[Theory]
[BitAutoData(false, false)]
[BitAutoData(false, true)]
[BitAutoData(true, false)]
[BitAutoData(true, true)]
public async Task CreateOrUpdateRegistrationAsync_DeviceTypeIOS_InstallationCreated(bool identifierNull,
bool partOfOrganizationId, SutProvider<NotificationHubPushRegistrationService> sutProvider, Guid deviceId,
Guid userId, Guid identifier, Guid organizationId)
{
var notificationHubClient = Substitute.For<INotificationHubClient>();
sutProvider.GetDependency<INotificationHubPool>().ClientFor(Arg.Any<Guid>()).Returns(notificationHubClient);
var pushToken = "test push token";
await sutProvider.Sut.CreateOrUpdateRegistrationAsync(pushToken, deviceId.ToString(), userId.ToString(),
identifierNull ? null : identifier.ToString(), DeviceType.iOS,
partOfOrganizationId ? [organizationId.ToString()] : []);
sutProvider.GetDependency<INotificationHubPool>()
.Received(1)
.ClientFor(deviceId);
await notificationHubClient
.Received(1)
.CreateOrUpdateInstallationAsync(Arg.Is<Installation>(installation =>
installation.InstallationId == deviceId.ToString() &&
installation.PushChannel == pushToken &&
installation.Platform == NotificationPlatform.Apns &&
installation.Tags.Contains($"userId:{userId}") &&
installation.Tags.Contains("clientType:Mobile") &&
(identifierNull || installation.Tags.Contains($"deviceIdentifier:{identifier}")) &&
(!partOfOrganizationId || installation.Tags.Contains($"organizationId:{organizationId}")) &&
installation.Templates.Count == 3));
await notificationHubClient
.Received(1)
.CreateOrUpdateInstallationAsync(Arg.Is<Installation>(installation => MatchingInstallationTemplate(
installation.Templates, "template:payload",
"{\"data\":{\"type\":\"#(type)\",\"payload\":\"$(payload)\"},\"aps\":{\"content-available\":1}}",
new List<string?>
{
"template:payload",
$"template:payload_userId:{userId}",
"clientType:Mobile",
identifierNull ? null : $"template:payload_deviceIdentifier:{identifier}",
partOfOrganizationId ? $"organizationId:{organizationId}" : null,
})));
await notificationHubClient
.Received(1)
.CreateOrUpdateInstallationAsync(Arg.Is<Installation>(installation => MatchingInstallationTemplate(
installation.Templates, "template:message",
"{\"data\":{\"type\":\"#(type)\"},\"aps\":{\"alert\":\"$(message)\",\"badge\":null,\"content-available\":1}}",
new List<string?>
{
"template:message",
$"template:message_userId:{userId}",
"clientType:Mobile",
identifierNull ? null : $"template:message_deviceIdentifier:{identifier}",
partOfOrganizationId ? $"organizationId:{organizationId}" : null,
})));
await notificationHubClient
.Received(1)
.CreateOrUpdateInstallationAsync(Arg.Is<Installation>(installation => MatchingInstallationTemplate(
installation.Templates, "template:badgeMessage",
"{\"data\":{\"type\":\"#(type)\"},\"aps\":{\"alert\":\"$(message)\",\"badge\":\"#(badge)\",\"content-available\":1}}",
new List<string?>
{
"template:badgeMessage",
$"template:badgeMessage_userId:{userId}",
"clientType:Mobile",
identifierNull ? null : $"template:badgeMessage_deviceIdentifier:{identifier}",
partOfOrganizationId ? $"organizationId:{organizationId}" : null,
})));
}
[Theory]
[BitAutoData(false, false)]
[BitAutoData(false, true)]
[BitAutoData(true, false)]
[BitAutoData(true, true)]
public async Task CreateOrUpdateRegistrationAsync_DeviceTypeAndroidAmazon_InstallationCreated(bool identifierNull,
bool partOfOrganizationId, SutProvider<NotificationHubPushRegistrationService> sutProvider, Guid deviceId,
Guid userId, Guid identifier, Guid organizationId)
{
var notificationHubClient = Substitute.For<INotificationHubClient>();
sutProvider.GetDependency<INotificationHubPool>().ClientFor(Arg.Any<Guid>()).Returns(notificationHubClient);
var pushToken = "test push token";
await sutProvider.Sut.CreateOrUpdateRegistrationAsync(pushToken, deviceId.ToString(), userId.ToString(),
identifierNull ? null : identifier.ToString(), DeviceType.AndroidAmazon,
partOfOrganizationId ? [organizationId.ToString()] : []);
sutProvider.GetDependency<INotificationHubPool>()
.Received(1)
.ClientFor(deviceId);
await notificationHubClient
.Received(1)
.CreateOrUpdateInstallationAsync(Arg.Is<Installation>(installation =>
installation.InstallationId == deviceId.ToString() &&
installation.PushChannel == pushToken &&
installation.Platform == NotificationPlatform.Adm &&
installation.Tags.Contains($"userId:{userId}") &&
installation.Tags.Contains("clientType:Mobile") &&
(identifierNull || installation.Tags.Contains($"deviceIdentifier:{identifier}")) &&
(!partOfOrganizationId || installation.Tags.Contains($"organizationId:{organizationId}")) &&
installation.Templates.Count == 3));
await notificationHubClient
.Received(1)
.CreateOrUpdateInstallationAsync(Arg.Is<Installation>(installation => MatchingInstallationTemplate(
installation.Templates, "template:payload",
"{\"data\":{\"type\":\"#(type)\",\"payload\":\"$(payload)\"}}",
new List<string?>
{
"template:payload",
$"template:payload_userId:{userId}",
"clientType:Mobile",
identifierNull ? null : $"template:payload_deviceIdentifier:{identifier}",
partOfOrganizationId ? $"organizationId:{organizationId}" : null,
})));
await notificationHubClient
.Received(1)
.CreateOrUpdateInstallationAsync(Arg.Is<Installation>(installation => MatchingInstallationTemplate(
installation.Templates, "template:message",
"{\"data\":{\"type\":\"#(type)\",\"message\":\"$(message)\"}}",
new List<string?>
{
"template:message",
$"template:message_userId:{userId}",
"clientType:Mobile",
identifierNull ? null : $"template:message_deviceIdentifier:{identifier}",
partOfOrganizationId ? $"organizationId:{organizationId}" : null,
})));
await notificationHubClient
.Received(1)
.CreateOrUpdateInstallationAsync(Arg.Is<Installation>(installation => MatchingInstallationTemplate(
installation.Templates, "template:badgeMessage",
"{\"data\":{\"type\":\"#(type)\",\"message\":\"$(message)\"}}",
new List<string?>
{
"template:badgeMessage",
$"template:badgeMessage_userId:{userId}",
"clientType:Mobile",
identifierNull ? null : $"template:badgeMessage_deviceIdentifier:{identifier}",
partOfOrganizationId ? $"organizationId:{organizationId}" : null,
})));
}
[Theory]
[BitAutoData(DeviceType.ChromeBrowser)]
[BitAutoData(DeviceType.ChromeExtension)]
[BitAutoData(DeviceType.MacOsDesktop)]
public async Task CreateOrUpdateRegistrationAsync_DeviceTypeNotMobile_InstallationCreated(DeviceType deviceType,
SutProvider<NotificationHubPushRegistrationService> sutProvider, Guid deviceId, Guid userId, Guid identifier,
Guid organizationId)
{
var notificationHubClient = Substitute.For<INotificationHubClient>();
sutProvider.GetDependency<INotificationHubPool>().ClientFor(Arg.Any<Guid>()).Returns(notificationHubClient);
var pushToken = "test push token";
await sutProvider.Sut.CreateOrUpdateRegistrationAsync(pushToken, deviceId.ToString(), userId.ToString(),
identifier.ToString(), deviceType, [organizationId.ToString()]);
sutProvider.GetDependency<INotificationHubPool>()
.Received(1)
.ClientFor(deviceId);
await notificationHubClient
.Received(1)
.CreateOrUpdateInstallationAsync(Arg.Is<Installation>(installation =>
installation.InstallationId == deviceId.ToString() &&
installation.PushChannel == pushToken &&
installation.Tags.Contains($"userId:{userId}") &&
installation.Tags.Contains($"clientType:{DeviceTypes.ToClientType(deviceType)}") &&
installation.Tags.Contains($"deviceIdentifier:{identifier}") &&
installation.Tags.Contains($"organizationId:{organizationId}") &&
installation.Templates.Count == 0));
}
private static bool MatchingInstallationTemplate(IDictionary<string, InstallationTemplate> templates, string key,
string body, List<string?> tags)
{
var tagsNoNulls = tags.FindAll(tag => tag != null);
return templates.ContainsKey(key) && templates[key].Body == body &&
templates[key].Tags.Count == tagsNoNulls.Count &&
templates[key].Tags.All(tagsNoNulls.Contains);
} }
} }

View File

@ -1,33 +1,72 @@
using Bit.Core.Settings; #nullable enable
using System.Text.Json;
using Azure.Storage.Queues;
using Bit.Core.Context;
using Bit.Core.Enums;
using Bit.Core.Models;
using Bit.Core.NotificationCenter.Entities;
using Bit.Core.Test.AutoFixture;
using Bit.Core.Test.AutoFixture.CurrentContextFixtures;
using Bit.Core.Test.NotificationCenter.AutoFixture;
using Bit.Test.Common.AutoFixture;
using Bit.Test.Common.AutoFixture.Attributes;
using Microsoft.AspNetCore.Http; using Microsoft.AspNetCore.Http;
using NSubstitute; using NSubstitute;
using Xunit; using Xunit;
namespace Bit.Core.Platform.Push.Internal.Test; namespace Bit.Core.Platform.Push.Internal.Test;
[QueueClientCustomize]
[SutProviderCustomize]
public class AzureQueuePushNotificationServiceTests public class AzureQueuePushNotificationServiceTests
{ {
private readonly AzureQueuePushNotificationService _sut; [Theory]
[BitAutoData]
private readonly GlobalSettings _globalSettings; [NotificationCustomize]
private readonly IHttpContextAccessor _httpContextAccessor; [CurrentContextCustomize]
public async void PushNotificationAsync_Notification_Sent(
public AzureQueuePushNotificationServiceTests() SutProvider<AzureQueuePushNotificationService> sutProvider, Notification notification, Guid deviceIdentifier,
ICurrentContext currentContext)
{ {
_globalSettings = new GlobalSettings(); currentContext.DeviceIdentifier.Returns(deviceIdentifier.ToString());
_httpContextAccessor = Substitute.For<IHttpContextAccessor>(); sutProvider.GetDependency<IHttpContextAccessor>().HttpContext!.RequestServices
.GetService(Arg.Any<Type>()).Returns(currentContext);
_sut = new AzureQueuePushNotificationService( await sutProvider.Sut.PushNotificationAsync(notification);
_globalSettings,
_httpContextAccessor await sutProvider.GetDependency<QueueClient>().Received(1)
); .SendMessageAsync(Arg.Is<string>(message =>
MatchMessage(PushType.SyncNotification, message, new SyncNotificationEquals(notification),
deviceIdentifier.ToString())));
} }
// Remove this test when we add actual tests. It only proves that private static bool MatchMessage<T>(PushType pushType, string message, IEquatable<T> expectedPayloadEquatable,
// we've properly constructed the system under test. string contextId)
[Fact(Skip = "Needs additional work")]
public void ServiceExists()
{ {
Assert.NotNull(_sut); var pushNotificationData = JsonSerializer.Deserialize<PushNotificationData<T>>(message);
return pushNotificationData != null &&
pushNotificationData.Type == pushType &&
expectedPayloadEquatable.Equals(pushNotificationData.Payload) &&
pushNotificationData.ContextId == contextId;
}
private class SyncNotificationEquals(Notification notification) : IEquatable<NotificationPushNotification>
{
public bool Equals(NotificationPushNotification? other)
{
return other != null &&
other.Id == notification.Id &&
other.Priority == notification.Priority &&
other.Global == notification.Global &&
other.ClientType == notification.ClientType &&
other.UserId.HasValue == notification.UserId.HasValue &&
other.UserId == notification.UserId &&
other.OrganizationId.HasValue == notification.OrganizationId.HasValue &&
other.OrganizationId == notification.OrganizationId &&
other.Title == notification.Title &&
other.Body == notification.Body &&
other.CreationDate == notification.CreationDate &&
other.RevisionDate == notification.RevisionDate;
}
} }
} }

View File

@ -1,45 +1,62 @@
using AutoFixture; #nullable enable
using Bit.Core.Enums;
using Bit.Core.NotificationCenter.Entities;
using Bit.Core.Test.NotificationCenter.AutoFixture;
using Bit.Test.Common.AutoFixture; using Bit.Test.Common.AutoFixture;
using Microsoft.Extensions.Logging; using Bit.Test.Common.AutoFixture.Attributes;
using NSubstitute; using NSubstitute;
using Xunit; using Xunit;
using GlobalSettingsCustomization = Bit.Test.Common.AutoFixture.GlobalSettings;
namespace Bit.Core.Platform.Push.Internal.Test; namespace Bit.Core.Platform.Push.Internal.Test;
[SutProviderCustomize]
public class MultiServicePushNotificationServiceTests public class MultiServicePushNotificationServiceTests
{ {
private readonly MultiServicePushNotificationService _sut; [Theory]
[BitAutoData]
private readonly ILogger<MultiServicePushNotificationService> _logger; [NotificationCustomize]
private readonly ILogger<RelayPushNotificationService> _relayLogger; public async Task PushNotificationAsync_Notification_Sent(
private readonly ILogger<NotificationsApiPushNotificationService> _hubLogger; SutProvider<MultiServicePushNotificationService> sutProvider, Notification notification)
private readonly IEnumerable<IPushNotificationService> _services;
private readonly Settings.GlobalSettings _globalSettings;
public MultiServicePushNotificationServiceTests()
{ {
_logger = Substitute.For<ILogger<MultiServicePushNotificationService>>(); await sutProvider.Sut.PushNotificationAsync(notification);
_relayLogger = Substitute.For<ILogger<RelayPushNotificationService>>();
_hubLogger = Substitute.For<ILogger<NotificationsApiPushNotificationService>>();
_services = new Fixture().WithAutoNSubstitutions().CreateMany<IPushNotificationService>();
var fixture = new Fixture().WithAutoNSubstitutions().Customize(new GlobalSettingsCustomization()); await sutProvider.GetDependency<IEnumerable<IPushNotificationService>>()
_services = fixture.CreateMany<IPushNotificationService>(); .First()
_globalSettings = fixture.Create<Settings.GlobalSettings>(); .Received(1)
.PushNotificationAsync(notification);
_sut = new MultiServicePushNotificationService(
_services,
_logger,
_globalSettings
);
} }
// Remove this test when we add actual tests. It only proves that [Theory]
// we've properly constructed the system under test. [BitAutoData([null, null])]
[Fact] [BitAutoData(ClientType.All, null)]
public void ServiceExists() [BitAutoData([null, "test device id"])]
[BitAutoData(ClientType.All, "test device id")]
public async Task SendPayloadToUserAsync_Message_Sent(ClientType? clientType, string? deviceId, string userId,
PushType type, object payload, string identifier, SutProvider<MultiServicePushNotificationService> sutProvider)
{ {
Assert.NotNull(_sut); await sutProvider.Sut.SendPayloadToUserAsync(userId, type, payload, identifier, deviceId, clientType);
await sutProvider.GetDependency<IEnumerable<IPushNotificationService>>()
.First()
.Received(1)
.SendPayloadToUserAsync(userId, type, payload, identifier, deviceId, clientType);
}
[Theory]
[BitAutoData([null, null])]
[BitAutoData(ClientType.All, null)]
[BitAutoData([null, "test device id"])]
[BitAutoData(ClientType.All, "test device id")]
public async Task SendPayloadToOrganizationAsync_Message_Sent(ClientType? clientType, string? deviceId,
string organizationId, PushType type, object payload, string identifier,
SutProvider<MultiServicePushNotificationService> sutProvider)
{
await sutProvider.Sut.SendPayloadToOrganizationAsync(organizationId, type, payload, identifier, deviceId,
clientType);
await sutProvider.GetDependency<IEnumerable<IPushNotificationService>>()
.First()
.Received(1)
.SendPayloadToOrganizationAsync(organizationId, type, payload, identifier, deviceId, clientType);
} }
} }

View File

@ -3,6 +3,7 @@ using Bit.Core.Auth.Models.Api.Request;
using Bit.Core.Entities; using Bit.Core.Entities;
using Bit.Core.Enums; using Bit.Core.Enums;
using Bit.Core.Exceptions; using Bit.Core.Exceptions;
using Bit.Core.Models.Data.Organizations.OrganizationUsers;
using Bit.Core.NotificationHub; using Bit.Core.NotificationHub;
using Bit.Core.Platform.Push; using Bit.Core.Platform.Push;
using Bit.Core.Repositories; using Bit.Core.Repositories;
@ -17,15 +18,23 @@ namespace Bit.Core.Test.Services;
[SutProviderCustomize] [SutProviderCustomize]
public class DeviceServiceTests public class DeviceServiceTests
{ {
[Fact] [Theory]
public async Task DeviceSaveShouldUpdateRevisionDateAndPushRegistration() [BitAutoData]
public async Task SaveAsync_IdProvided_UpdatedRevisionDateAndPushRegistration(Guid id, Guid userId,
Guid organizationId1, Guid organizationId2,
OrganizationUserOrganizationDetails organizationUserOrganizationDetails1,
OrganizationUserOrganizationDetails organizationUserOrganizationDetails2)
{ {
organizationUserOrganizationDetails1.OrganizationId = organizationId1;
organizationUserOrganizationDetails2.OrganizationId = organizationId2;
var deviceRepo = Substitute.For<IDeviceRepository>(); var deviceRepo = Substitute.For<IDeviceRepository>();
var pushRepo = Substitute.For<IPushRegistrationService>(); var pushRepo = Substitute.For<IPushRegistrationService>();
var deviceService = new DeviceService(deviceRepo, pushRepo); var organizationUserRepository = Substitute.For<IOrganizationUserRepository>();
organizationUserRepository.GetManyDetailsByUserAsync(Arg.Any<Guid>(), Arg.Any<OrganizationUserStatusType?>())
.Returns([organizationUserOrganizationDetails1, organizationUserOrganizationDetails2]);
var deviceService = new DeviceService(deviceRepo, pushRepo, organizationUserRepository);
var id = Guid.NewGuid();
var userId = Guid.NewGuid();
var device = new Device var device = new Device
{ {
Id = id, Id = id,
@ -38,8 +47,53 @@ public class DeviceServiceTests
await deviceService.SaveAsync(device); await deviceService.SaveAsync(device);
Assert.True(device.RevisionDate - DateTime.UtcNow < TimeSpan.FromSeconds(1)); Assert.True(device.RevisionDate - DateTime.UtcNow < TimeSpan.FromSeconds(1));
await pushRepo.Received().CreateOrUpdateRegistrationAsync(Arg.Is<PushRegistrationData>(v => v.Token == "testToken"), id.ToString(), await pushRepo.Received(1).CreateOrUpdateRegistrationAsync(Arg.Is<PushRegistrationData>(v => v.Token == "testToken"), id.ToString(),
userId.ToString(), "testid", DeviceType.Android); userId.ToString(), "testid", DeviceType.Android,
Arg.Do<IEnumerable<string>>(organizationIds =>
{
var organizationIdsList = organizationIds.ToList();
Assert.Equal(2, organizationIdsList.Count);
Assert.Contains(organizationId1.ToString(), organizationIdsList);
Assert.Contains(organizationId2.ToString(), organizationIdsList);
}));
}
[Theory]
[BitAutoData]
public async Task SaveAsync_IdNotProvided_CreatedAndPushRegistration(Guid userId, Guid organizationId1,
Guid organizationId2,
OrganizationUserOrganizationDetails organizationUserOrganizationDetails1,
OrganizationUserOrganizationDetails organizationUserOrganizationDetails2)
{
organizationUserOrganizationDetails1.OrganizationId = organizationId1;
organizationUserOrganizationDetails2.OrganizationId = organizationId2;
var deviceRepo = Substitute.For<IDeviceRepository>();
var pushRepo = Substitute.For<IPushRegistrationService>();
var organizationUserRepository = Substitute.For<IOrganizationUserRepository>();
organizationUserRepository.GetManyDetailsByUserAsync(Arg.Any<Guid>(), Arg.Any<OrganizationUserStatusType?>())
.Returns([organizationUserOrganizationDetails1, organizationUserOrganizationDetails2]);
var deviceService = new DeviceService(deviceRepo, pushRepo, organizationUserRepository);
var device = new Device
{
Name = "test device",
Type = DeviceType.Android,
UserId = userId,
PushToken = "testtoken",
Identifier = "testid"
};
await deviceService.SaveAsync(device);
await pushRepo.Received(1).CreateOrUpdateRegistrationAsync(Arg.Is<PushRegistrationData>(v => v.Token == "testToken"),
Arg.Do<string>(id => Guid.TryParse(id, out var _)), userId.ToString(), "testid", DeviceType.Android,
Arg.Do<IEnumerable<string>>(organizationIds =>
{
var organizationIdsList = organizationIds.ToList();
Assert.Equal(2, organizationIdsList.Count);
Assert.Contains(organizationId1.ToString(), organizationIdsList);
Assert.Contains(organizationId2.ToString(), organizationIdsList);
}));
} }
/// <summary> /// <summary>
@ -63,12 +117,7 @@ public class DeviceServiceTests
sutProvider.GetDependency<IDeviceRepository>() sutProvider.GetDependency<IDeviceRepository>()
.GetManyByUserIdAsync(currentUserId) .GetManyByUserIdAsync(currentUserId)
.Returns(new List<Device> .Returns(new List<Device> { deviceOne, deviceTwo, deviceThree, });
{
deviceOne,
deviceTwo,
deviceThree,
});
var currentDeviceModel = new DeviceKeysUpdateRequestModel var currentDeviceModel = new DeviceKeysUpdateRequestModel
{ {
@ -86,7 +135,8 @@ public class DeviceServiceTests
}, },
}; };
await sutProvider.Sut.UpdateDevicesTrustAsync("current_device", currentUserId, currentDeviceModel, alteredDeviceModels); await sutProvider.Sut.UpdateDevicesTrustAsync("current_device", currentUserId, currentDeviceModel,
alteredDeviceModels);
// Updating trust, "current" or "other" only needs to change the EncryptedPublicKey & EncryptedUserKey // Updating trust, "current" or "other" only needs to change the EncryptedPublicKey & EncryptedUserKey
await sutProvider.GetDependency<IDeviceRepository>() await sutProvider.GetDependency<IDeviceRepository>()
@ -150,11 +200,7 @@ public class DeviceServiceTests
sutProvider.GetDependency<IDeviceRepository>() sutProvider.GetDependency<IDeviceRepository>()
.GetManyByUserIdAsync(currentUserId) .GetManyByUserIdAsync(currentUserId)
.Returns(new List<Device> .Returns(new List<Device> { deviceOne, deviceTwo, });
{
deviceOne,
deviceTwo,
});
var currentDeviceModel = new DeviceKeysUpdateRequestModel var currentDeviceModel = new DeviceKeysUpdateRequestModel
{ {
@ -172,7 +218,8 @@ public class DeviceServiceTests
}, },
}; };
await sutProvider.Sut.UpdateDevicesTrustAsync("current_device", currentUserId, currentDeviceModel, alteredDeviceModels); await sutProvider.Sut.UpdateDevicesTrustAsync("current_device", currentUserId, currentDeviceModel,
alteredDeviceModels);
// Check that UpsertAsync was called for the trusted device // Check that UpsertAsync was called for the trusted device
await sutProvider.GetDependency<IDeviceRepository>() await sutProvider.GetDependency<IDeviceRepository>()
@ -204,11 +251,7 @@ public class DeviceServiceTests
sutProvider.GetDependency<IDeviceRepository>() sutProvider.GetDependency<IDeviceRepository>()
.GetManyByUserIdAsync(currentUserId) .GetManyByUserIdAsync(currentUserId)
.Returns(new List<Device> .Returns(new List<Device> { deviceOne, deviceTwo, });
{
deviceOne,
deviceTwo,
});
var currentDeviceModel = new DeviceKeysUpdateRequestModel var currentDeviceModel = new DeviceKeysUpdateRequestModel
{ {
@ -238,11 +281,7 @@ public class DeviceServiceTests
sutProvider.GetDependency<IDeviceRepository>() sutProvider.GetDependency<IDeviceRepository>()
.GetManyByUserIdAsync(currentUserId) .GetManyByUserIdAsync(currentUserId)
.Returns(new List<Device> .Returns(new List<Device> { deviceOne, deviceTwo, });
{
deviceOne,
deviceTwo,
});
var currentDeviceModel = new DeviceKeysUpdateRequestModel var currentDeviceModel = new DeviceKeysUpdateRequestModel
{ {
@ -261,6 +300,7 @@ public class DeviceServiceTests
}; };
await Assert.ThrowsAsync<BadRequestException>(() => await Assert.ThrowsAsync<BadRequestException>(() =>
sutProvider.Sut.UpdateDevicesTrustAsync("current_device", currentUserId, currentDeviceModel, alteredDeviceModels)); sutProvider.Sut.UpdateDevicesTrustAsync("current_device", currentUserId, currentDeviceModel,
alteredDeviceModels));
} }
} }

View File

@ -24,6 +24,7 @@
"sstamp", "sstamp",
"premium", "premium",
"device", "device",
"devicetype",
"orgowner", "orgowner",
"orgadmin", "orgadmin",
"orguser", "orguser",