1
0
mirror of https://github.com/bitwarden/server.git synced 2024-11-22 12:15:36 +01:00

Added the ability to create a JWT on a user license that contains all license properties as claims

This commit is contained in:
Conner Turnbull 2024-10-31 11:09:04 -04:00
parent 02fd8b0f3e
commit a473487e19
No known key found for this signature in database
GPG Key ID: D42CA06D8EB866CC
7 changed files with 78 additions and 7 deletions

View File

@ -1,6 +1,7 @@
using Bit.Core.AdminConsole.Entities;
using Bit.Core.Billing.Licenses.Services;
using Bit.Core.Billing.Licenses.Services.Implementations;
using Bit.Core.Entities;
using Microsoft.Extensions.DependencyInjection;
namespace Bit.Core.Billing.Licenses.Extensions;
@ -10,5 +11,6 @@ public static class LicenseServiceCollectionExtensions
public static void AddLicenseServices(this IServiceCollection services)
{
services.AddTransient<ILicenseClaimsFactory<Organization>, OrganizationLicenseClaimsFactory>();
services.AddTransient<ILicenseClaimsFactory<User>, UserLicenseClaimsFactory>();
}
}

View File

@ -0,0 +1,38 @@
using System.Globalization;
using System.Security.Claims;
using Bit.Core.Billing.Licenses.Models;
using Bit.Core.Entities;
using Bit.Core.Enums;
using Bit.Core.Models.Business;
namespace Bit.Core.Billing.Licenses.Services.Implementations;
public class UserLicenseClaimsFactory : ILicenseClaimsFactory<User>
{
public Task<List<Claim>> GenerateClaims(User entity, LicenseContext licenseContext)
{
var subscriptionInfo = licenseContext.SubscriptionInfo;
var expires = subscriptionInfo.UpcomingInvoice?.Date?.AddDays(7) ?? entity.PremiumExpirationDate?.AddDays(7);
var refresh = subscriptionInfo.UpcomingInvoice?.Date ?? entity.PremiumExpirationDate;
var trial = (subscriptionInfo.Subscription?.TrialEndDate.HasValue ?? false) &&
subscriptionInfo.Subscription.TrialEndDate.Value > DateTime.UtcNow;
var claims = new List<Claim>
{
new(nameof(UserLicense.LicenseType), LicenseType.User.ToString()),
new(nameof(UserLicense.LicenseKey), entity.LicenseKey),
new(nameof(UserLicense.Id), entity.Id.ToString()),
new(nameof(UserLicense.Name), entity.Name),
new(nameof(UserLicense.Email), entity.Email),
new(nameof(UserLicense.Premium), entity.Premium.ToString()),
new(nameof(UserLicense.MaxStorageGb), entity.MaxStorageGb.ToString()),
new(nameof(UserLicense.Issued), DateTime.UtcNow.ToString(CultureInfo.InvariantCulture)),
new(nameof(UserLicense.Expires), expires.ToString()),
new(nameof(UserLicense.Refresh), refresh.ToString()),
new(nameof(UserLicense.Trial), trial.ToString()),
};
return Task.FromResult(claims);
}
}

View File

@ -70,6 +70,7 @@ public class UserLicense : ILicense
public LicenseType? LicenseType { get; set; }
public string Hash { get; set; }
public string Signature { get; set; }
public string Token { get; set; }
[JsonIgnore]
public byte[] SignatureBytes => Convert.FromBase64String(Signature);
@ -84,6 +85,7 @@ public class UserLicense : ILicense
!p.Name.Equals(nameof(Signature)) &&
!p.Name.Equals(nameof(SignatureBytes)) &&
!p.Name.Equals(nameof(LicenseType)) &&
!p.Name.Equals(nameof(Token)) &&
(
!forHash ||
(

View File

@ -18,4 +18,6 @@ public interface ILicensingService
Organization organization,
Guid installationId,
SubscriptionInfo subscriptionInfo);
Task<string> CreateUserTokenAsync(User user, SubscriptionInfo subscriptionInfo);
}

View File

@ -26,10 +26,10 @@ public class LicensingService : ILicensingService
private readonly IGlobalSettings _globalSettings;
private readonly IUserRepository _userRepository;
private readonly IOrganizationRepository _organizationRepository;
private readonly IOrganizationUserRepository _organizationUserRepository;
private readonly IMailService _mailService;
private readonly ILogger<LicensingService> _logger;
private readonly ILicenseClaimsFactory<Organization> _organizationLicenseClaimsFactory;
private readonly ILicenseClaimsFactory<User> _userLicenseClaimsFactory;
private readonly IFeatureService _featureService;
private IDictionary<Guid, DateTime> _userCheckCache = new Dictionary<Guid, DateTime>();
@ -37,22 +37,22 @@ public class LicensingService : ILicensingService
public LicensingService(
IUserRepository userRepository,
IOrganizationRepository organizationRepository,
IOrganizationUserRepository organizationUserRepository,
IMailService mailService,
IWebHostEnvironment environment,
ILogger<LicensingService> logger,
IGlobalSettings globalSettings,
ILicenseClaimsFactory<Organization> organizationLicenseClaimsFactory,
IFeatureService featureService)
IFeatureService featureService,
ILicenseClaimsFactory<User> userLicenseClaimsFactory)
{
_userRepository = userRepository;
_organizationRepository = organizationRepository;
_organizationUserRepository = organizationUserRepository;
_mailService = mailService;
_logger = logger;
_globalSettings = globalSettings;
_organizationLicenseClaimsFactory = organizationLicenseClaimsFactory;
_featureService = featureService;
_userLicenseClaimsFactory = userLicenseClaimsFactory;
var certThumbprint = environment.IsDevelopment() ?
"207E64A231E8AA32AAF68A61037C075EBEBD553F" :
@ -305,6 +305,21 @@ public class LicensingService : ILicensingService
return GenerateToken(claims, audience, expires);
}
public async Task<string> CreateUserTokenAsync(User user, SubscriptionInfo subscriptionInfo)
{
if (!_featureService.IsEnabled(FeatureFlagKeys.SelfHostLicenseRefactor))
{
return null;
}
var licenseContext = new LicenseContext { SubscriptionInfo = subscriptionInfo };
var claims = await _userLicenseClaimsFactory.GenerateClaims(user, licenseContext);
var audience = user.Id.ToString();
var expires = user.PremiumExpirationDate ?? DateTime.UtcNow.AddDays(7);
return GenerateToken(claims, audience, expires);
}
private string GenerateToken(List<Claim> claims, string audience, DateTime expires)
{
if (claims.All(claim => claim.Type != JwtClaimTypes.JwtId))

View File

@ -1111,7 +1111,9 @@ public class UserService : UserManager<User>, IUserService, IDisposable
}
}
public async Task<UserLicense> GenerateLicenseAsync(User user, SubscriptionInfo subscriptionInfo = null,
public async Task<UserLicense> GenerateLicenseAsync(
User user,
SubscriptionInfo subscriptionInfo = null,
int? version = null)
{
if (user == null)
@ -1124,8 +1126,13 @@ public class UserService : UserManager<User>, IUserService, IDisposable
subscriptionInfo = await _paymentService.GetSubscriptionAsync(user);
}
return subscriptionInfo == null ? new UserLicense(user, _licenseService) :
new UserLicense(user, subscriptionInfo, _licenseService);
var userLicense = subscriptionInfo == null
? new UserLicense(user, _licenseService)
: new UserLicense(user, subscriptionInfo, _licenseService);
userLicense.Token = await _licenseService.CreateUserTokenAsync(user, subscriptionInfo);
return userLicense;
}
public override async Task<bool> CheckPasswordAsync(User user, string password)

View File

@ -58,4 +58,9 @@ public class NoopLicensingService : ILicensingService
{
return Task.FromResult<string>(null);
}
public Task<string> CreateUserTokenAsync(User user, SubscriptionInfo subscriptionInfo)
{
return Task.FromResult<string>(null);
}
}