diff --git a/perf/MicroBenchmarks/Identity/IdentityServer/RedisPersistedGrantStoreTests.cs b/perf/MicroBenchmarks/Identity/IdentityServer/RedisPersistedGrantStoreTests.cs index 0ca09992f..cbb11acbb 100644 --- a/perf/MicroBenchmarks/Identity/IdentityServer/RedisPersistedGrantStoreTests.cs +++ b/perf/MicroBenchmarks/Identity/IdentityServer/RedisPersistedGrantStoreTests.cs @@ -13,8 +13,11 @@ public class RedisPersistedGrantStoreTests { const string SQL = nameof(SQL); const string Redis = nameof(Redis); + const string Cosmos = nameof(Cosmos); + private readonly IPersistedGrantStore _redisGrantStore; private readonly IPersistedGrantStore _sqlGrantStore; + private readonly IPersistedGrantStore _cosmosGrantStore; private readonly PersistedGrant _updateGrant; private IPersistedGrantStore _grantStore = null!; @@ -45,12 +48,18 @@ public class RedisPersistedGrantStoreTests ); var sqlConnectionString = "YOUR CONNECTION STRING HERE"; - _sqlGrantStore = new PersistedGrantStore( new GrantRepository( sqlConnectionString, sqlConnectionString - ) + ), + g => new Bit.Core.Auth.Entities.Grant(g) + ); + + var cosmosConnectionString = "YOUR CONNECTION STRING HERE"; + _cosmosGrantStore = new PersistedGrantStore( + new Bit.Core.Auth.Repositories.Cosmos.GrantRepository(cosmosConnectionString), + g => new Bit.Core.Auth.Models.Data.GrantItem(g) ); var creationTime = new DateTime(638350407400000000, DateTimeKind.Utc); @@ -69,7 +78,7 @@ public class RedisPersistedGrantStoreTests }; } - [Params(Redis, SQL)] + [Params(Redis, SQL, Cosmos)] public string StoreType { get; set; } = null!; [GlobalSetup] @@ -83,6 +92,10 @@ public class RedisPersistedGrantStoreTests { _grantStore = _sqlGrantStore; } + else if (StoreType == Cosmos) + { + _grantStore = _cosmosGrantStore; + } else { throw new InvalidProgramException(); diff --git a/src/Core/Auth/Entities/Grant.cs b/src/Core/Auth/Entities/Grant.cs index 7b2522fc7..a0e56e35d 100644 --- a/src/Core/Auth/Entities/Grant.cs +++ b/src/Core/Auth/Entities/Grant.cs @@ -1,10 +1,28 @@ #nullable enable using System.ComponentModel.DataAnnotations; +using Bit.Core.Auth.Models.Data; +using Duende.IdentityServer.Models; namespace Bit.Core.Auth.Entities; -public class Grant +public class Grant : IGrant { + public Grant() { } + + public Grant(PersistedGrant pGrant) + { + Key = pGrant.Key; + Type = pGrant.Type; + SubjectId = pGrant.SubjectId; + SessionId = pGrant.SessionId; + ClientId = pGrant.ClientId; + Description = pGrant.Description; + CreationDate = pGrant.CreationTime; + ExpirationDate = pGrant.Expiration; + ConsumedDate = pGrant.ConsumedTime; + Data = pGrant.Data; + } + public int Id { get; set; } [MaxLength(200)] public string Key { get; set; } = null!; diff --git a/src/Core/Auth/Models/Data/GrantItem.cs b/src/Core/Auth/Models/Data/GrantItem.cs new file mode 100644 index 000000000..de856904d --- /dev/null +++ b/src/Core/Auth/Models/Data/GrantItem.cs @@ -0,0 +1,77 @@ +using System.Text.Json.Serialization; +using Bit.Core.Auth.Repositories.Cosmos; +using Duende.IdentityServer.Models; + +namespace Bit.Core.Auth.Models.Data; + +public class GrantItem : IGrant +{ + public GrantItem() { } + + public GrantItem(PersistedGrant pGrant) + { + Key = pGrant.Key; + Type = pGrant.Type; + SubjectId = pGrant.SubjectId; + SessionId = pGrant.SessionId; + ClientId = pGrant.ClientId; + Description = pGrant.Description; + CreationDate = pGrant.CreationTime; + ExpirationDate = pGrant.Expiration; + ConsumedDate = pGrant.ConsumedTime; + Data = pGrant.Data; + SetTtl(); + } + + public GrantItem(IGrant g) + { + Key = g.Key; + Type = g.Type; + SubjectId = g.SubjectId; + SessionId = g.SessionId; + ClientId = g.ClientId; + Description = g.Description; + CreationDate = g.CreationDate; + ExpirationDate = g.ExpirationDate; + ConsumedDate = g.ConsumedDate; + Data = g.Data; + SetTtl(); + } + + [JsonPropertyName("id")] + [JsonConverter(typeof(Base64IdStringConverter))] + public string Key { get; set; } + [JsonPropertyName("typ")] + public string Type { get; set; } + [JsonPropertyName("sub")] + public string SubjectId { get; set; } + [JsonPropertyName("sid")] + public string SessionId { get; set; } + [JsonPropertyName("cid")] + public string ClientId { get; set; } + [JsonPropertyName("des")] + public string Description { get; set; } + [JsonPropertyName("cre")] + public DateTime CreationDate { get; set; } = DateTime.UtcNow; + [JsonPropertyName("exp")] + public DateTime? ExpirationDate { get; set; } + [JsonPropertyName("con")] + public DateTime? ConsumedDate { get; set; } + [JsonPropertyName("data")] + public string Data { get; set; } + // https://learn.microsoft.com/en-us/azure/cosmos-db/nosql/how-to-time-to-live?tabs=dotnet-sdk-v3#set-time-to-live-on-an-item-using-an-sdk + [JsonPropertyName("ttl")] + public int Ttl { get; set; } = -1; + + public void SetTtl() + { + if (ExpirationDate != null) + { + var sec = (ExpirationDate.Value - DateTime.UtcNow).TotalSeconds; + if (sec > 0) + { + Ttl = (int)sec; + } + } + } +} diff --git a/src/Core/Auth/Models/Data/IGrant.cs b/src/Core/Auth/Models/Data/IGrant.cs new file mode 100644 index 000000000..5f1463153 --- /dev/null +++ b/src/Core/Auth/Models/Data/IGrant.cs @@ -0,0 +1,15 @@ +namespace Bit.Core.Auth.Models.Data; + +public interface IGrant +{ + string Key { get; set; } + string Type { get; set; } + string SubjectId { get; set; } + string SessionId { get; set; } + string ClientId { get; set; } + string Description { get; set; } + DateTime CreationDate { get; set; } + DateTime? ExpirationDate { get; set; } + DateTime? ConsumedDate { get; set; } + string Data { get; set; } +} diff --git a/src/Core/Auth/Repositories/Cosmos/Base64IdStringConverter.cs b/src/Core/Auth/Repositories/Cosmos/Base64IdStringConverter.cs new file mode 100644 index 000000000..5ec53100f --- /dev/null +++ b/src/Core/Auth/Repositories/Cosmos/Base64IdStringConverter.cs @@ -0,0 +1,32 @@ +using System.Text.Json; +using System.Text.Json.Serialization; +using Bit.Core.Utilities; + +namespace Bit.Core.Auth.Repositories.Cosmos; + +public class Base64IdStringConverter : JsonConverter +{ + public override string Read(ref Utf8JsonReader reader, Type typeToConvert, JsonSerializerOptions options) => + ToKey(reader.GetString()); + + public override void Write(Utf8JsonWriter writer, string value, JsonSerializerOptions options) => + writer.WriteStringValue(ToId(value)); + + public static string ToId(string key) + { + if (key == null) + { + return null; + } + return CoreHelpers.TransformToBase64Url(key); + } + + public static string ToKey(string id) + { + if (id == null) + { + return null; + } + return CoreHelpers.TransformFromBase64Url(id); + } +} diff --git a/src/Core/Auth/Repositories/Cosmos/GrantRepository.cs b/src/Core/Auth/Repositories/Cosmos/GrantRepository.cs new file mode 100644 index 000000000..66fc3fb79 --- /dev/null +++ b/src/Core/Auth/Repositories/Cosmos/GrantRepository.cs @@ -0,0 +1,81 @@ +using System.Net; +using System.Text.Json; +using System.Text.Json.Serialization; +using Bit.Core.Auth.Models.Data; +using Bit.Core.Settings; +using Bit.Core.Utilities; +using Microsoft.Azure.Cosmos; + +namespace Bit.Core.Auth.Repositories.Cosmos; + +public class GrantRepository : IGrantRepository +{ + private readonly CosmosClient _client; + private readonly Database _database; + private readonly Container _container; + + public GrantRepository(GlobalSettings globalSettings) + : this(globalSettings.IdentityServer.CosmosConnectionString) + { } + + public GrantRepository(string cosmosConnectionString) + { + var options = new CosmosClientOptions + { + Serializer = new SystemTextJsonCosmosSerializer(new JsonSerializerOptions + { + DefaultIgnoreCondition = JsonIgnoreCondition.WhenWritingNull, + WriteIndented = false + }) + }; + // TODO: Perhaps we want to evaluate moving this to DI as a keyed service singleton in .NET 8 + _client = new CosmosClient(cosmosConnectionString, options); + _database = _client.GetDatabase("identity"); + _container = _database.GetContainer("grant"); + } + + public async Task GetByKeyAsync(string key) + { + var id = Base64IdStringConverter.ToId(key); + try + { + var response = await _container.ReadItemAsync(id, new PartitionKey(id)); + return response.Resource; + } + catch (CosmosException e) + { + if (e.StatusCode == HttpStatusCode.NotFound) + { + return null; + } + throw; + } + } + + public Task> GetManyAsync(string subjectId, string sessionId, string clientId, string type) + => throw new NotImplementedException(); + + public async Task SaveAsync(IGrant obj) + { + if (obj is not GrantItem item) + { + item = new GrantItem(obj); + } + item.SetTtl(); + var id = Base64IdStringConverter.ToId(item.Key); + await _container.UpsertItemAsync(item, new PartitionKey(id), new ItemRequestOptions + { + // ref: https://learn.microsoft.com/en-us/azure/cosmos-db/nosql/best-practice-dotnet#best-practices-for-write-heavy-workloads + EnableContentResponseOnWrite = false + }); + } + + public async Task DeleteByKeyAsync(string key) + { + var id = Base64IdStringConverter.ToId(key); + await _container.DeleteItemAsync(id, new PartitionKey(id)); + } + + public Task DeleteManyAsync(string subjectId, string sessionId, string clientId, string type) + => throw new NotImplementedException(); +} diff --git a/src/Core/Auth/Repositories/IGrantRepository.cs b/src/Core/Auth/Repositories/IGrantRepository.cs index feb29681f..2304385be 100644 --- a/src/Core/Auth/Repositories/IGrantRepository.cs +++ b/src/Core/Auth/Repositories/IGrantRepository.cs @@ -1,12 +1,12 @@ -using Bit.Core.Auth.Entities; +using Bit.Core.Auth.Models.Data; namespace Bit.Core.Auth.Repositories; public interface IGrantRepository { - Task GetByKeyAsync(string key); - Task> GetManyAsync(string subjectId, string sessionId, string clientId, string type); - Task SaveAsync(Grant obj); + Task GetByKeyAsync(string key); + Task> GetManyAsync(string subjectId, string sessionId, string clientId, string type); + Task SaveAsync(IGrant obj); Task DeleteByKeyAsync(string key); Task DeleteManyAsync(string subjectId, string sessionId, string clientId, string type); } diff --git a/src/Core/Core.csproj b/src/Core/Core.csproj index 570378e47..9b664a788 100644 --- a/src/Core/Core.csproj +++ b/src/Core/Core.csproj @@ -37,7 +37,7 @@ - + diff --git a/src/Core/Settings/GlobalSettings.cs b/src/Core/Settings/GlobalSettings.cs index 6d128fb61..533a8bfb5 100644 --- a/src/Core/Settings/GlobalSettings.cs +++ b/src/Core/Settings/GlobalSettings.cs @@ -327,6 +327,7 @@ public class GlobalSettings : IGlobalSettings public string CertificateThumbprint { get; set; } public string CertificatePassword { get; set; } public string RedisConnectionString { get; set; } + public string CosmosConnectionString { get; set; } public string LicenseKey { get; set; } = "eyJhbGciOiJQUzI1NiIsImtpZCI6IklkZW50aXR5U2VydmVyTGljZW5zZWtleS83Y2VhZGJiNzgxMzA0NjllODgwNjg5MTAyNTQxNGYxNiIsInR5cCI6ImxpY2Vuc2Urand0In0.eyJpc3MiOiJodHRwczovL2R1ZW5kZXNvZnR3YXJlLmNvbSIsImF1ZCI6IklkZW50aXR5U2VydmVyIiwiaWF0IjoxNzAxODIwODAwLCJleHAiOjE3MzM0NDMyMDAsImNvbXBhbnlfbmFtZSI6IkJpdHdhcmRlbiBJbmMuIiwiY29udGFjdF9pbmZvIjoiY29udGFjdEBkdWVuZGVzb2Z0d2FyZS5jb20iLCJlZGl0aW9uIjoiU3RhcnRlciIsImlkIjoiNDMxOSIsImZlYXR1cmUiOlsiaXN2IiwidW5saW1pdGVkX2NsaWVudHMiXSwicHJvZHVjdCI6IkJpdHdhcmRlbiJ9.iLA771PffgIh0ClRS8OWHbg2cAgjhgOkUjRRkLNr9dpQXhYZkVKdpUn-Gw9T7grsGcAx0f4p-TQmtcCpbN9EJCF5jlF0-NfsRTp_gmCgQ5eXyiE4DzJp2OCrz_3STf07N1dILwhD3nk9rzcA6SRQ4_kja8wAMHKnD5LisW98r5DfRDBecRs16KS5HUhg99DRMR5fd9ntfydVMTC_E23eEOHVLsR4YhiSXaEINPjFDG1czyOBClJItDW8g9X8qlClZegr630UjnKKg06A4usoL25VFHHn8Ew3v-_-XdlWoWsIpMMVvacwZT8rwkxjIesFNsXG6yzuROIhaxAvB1297A"; } diff --git a/src/Core/Utilities/CoreHelpers.cs b/src/Core/Utilities/CoreHelpers.cs index b44283f6e..ea3082e84 100644 --- a/src/Core/Utilities/CoreHelpers.cs +++ b/src/Core/Utilities/CoreHelpers.cs @@ -338,16 +338,50 @@ public static class CoreHelpers return Encoding.UTF8.GetString(Base64UrlDecode(input)); } + /// + /// Encodes a Base64 URL formatted string. + /// + /// Byte data + /// Base64 URL formatted string public static string Base64UrlEncode(byte[] input) { - var output = Convert.ToBase64String(input) + // Standard base64 encoder + var standardB64 = Convert.ToBase64String(input); + return TransformToBase64Url(standardB64); + } + + /// + /// Transforms a Base64 standard formatted string to a Base64 URL formatted string. + /// + /// Base64 standard formatted string + /// Base64 URL formatted string + public static string TransformToBase64Url(string input) + { + var output = input .Replace('+', '-') .Replace('/', '_') .Replace("=", string.Empty); return output; } + /// + /// Decodes a Base64 URL formatted string. + /// + /// Base64 URL formatted string + /// Data as bytes public static byte[] Base64UrlDecode(string input) + { + var standardB64 = TransformFromBase64Url(input); + // Standard base64 decoder + return Convert.FromBase64String(standardB64); + } + + /// + /// Transforms a Base64 URL formatted string to a Base64 standard formatted string. + /// + /// Base64 URL formatted string + /// Base64 standard formatted string + public static string TransformFromBase64Url(string input) { var output = input; // 62nd char of encoding @@ -370,8 +404,8 @@ public static class CoreHelpers throw new InvalidOperationException("Illegal base64url string!"); } - // Standard base64 decoder - return Convert.FromBase64String(output); + // Standard base64 string output + return output; } public static string PunyEncode(string text) diff --git a/src/Core/Utilities/SystemTextJsonCosmosSerializer.cs b/src/Core/Utilities/SystemTextJsonCosmosSerializer.cs new file mode 100644 index 000000000..8b2b8684e --- /dev/null +++ b/src/Core/Utilities/SystemTextJsonCosmosSerializer.cs @@ -0,0 +1,40 @@ +using System.Text.Json; +using Azure.Core.Serialization; +using Microsoft.Azure.Cosmos; + +namespace Bit.Core.Utilities; + +// ref: https://github.com/Azure/azure-cosmos-dotnet-v3/blob/master/Microsoft.Azure.Cosmos.Samples/Usage/SystemTextJson/CosmosSystemTextJsonSerializer.cs +public class SystemTextJsonCosmosSerializer : CosmosSerializer +{ + private readonly JsonObjectSerializer _systemTextJsonSerializer; + + public SystemTextJsonCosmosSerializer(JsonSerializerOptions jsonSerializerOptions) + { + _systemTextJsonSerializer = new JsonObjectSerializer(jsonSerializerOptions); + } + + public override T FromStream(Stream stream) + { + using (stream) + { + if (stream.CanSeek && stream.Length == 0) + { + return default; + } + if (typeof(Stream).IsAssignableFrom(typeof(T))) + { + return (T)(object)stream; + } + return (T)_systemTextJsonSerializer.Deserialize(stream, typeof(T), default); + } + } + + public override Stream ToStream(T input) + { + var streamPayload = new MemoryStream(); + _systemTextJsonSerializer.Serialize(streamPayload, input, input.GetType(), default); + streamPayload.Position = 0; + return streamPayload; + } +} diff --git a/src/Identity/IdentityServer/PersistedGrantStore.cs b/src/Identity/IdentityServer/PersistedGrantStore.cs index 9d8ebffd0..91503a47a 100644 --- a/src/Identity/IdentityServer/PersistedGrantStore.cs +++ b/src/Identity/IdentityServer/PersistedGrantStore.cs @@ -1,18 +1,24 @@ -using Bit.Core.Auth.Repositories; +using Bit.Core.Auth.Models.Data; +using Bit.Core.Auth.Repositories; using Duende.IdentityServer.Models; using Duende.IdentityServer.Stores; -using Grant = Bit.Core.Auth.Entities.Grant; namespace Bit.Identity.IdentityServer; public class PersistedGrantStore : IPersistedGrantStore { private readonly IGrantRepository _grantRepository; + private readonly Func _toGrant; + private readonly IPersistedGrantStore _fallbackGrantStore; public PersistedGrantStore( - IGrantRepository grantRepository) + IGrantRepository grantRepository, + Func toGrant, + IPersistedGrantStore fallbackGrantStore = null) { _grantRepository = grantRepository; + _toGrant = toGrant; + _fallbackGrantStore = fallbackGrantStore; } public async Task GetAsync(string key) @@ -20,6 +26,11 @@ public class PersistedGrantStore : IPersistedGrantStore var grant = await _grantRepository.GetByKeyAsync(key); if (grant == null) { + if (_fallbackGrantStore != null) + { + // It wasn't found, there is a chance is was instead stored in the fallback store + return await _fallbackGrantStore.GetAsync(key); + } return null; } @@ -47,28 +58,11 @@ public class PersistedGrantStore : IPersistedGrantStore public async Task StoreAsync(PersistedGrant pGrant) { - var grant = ToGrant(pGrant); + var grant = _toGrant(pGrant); await _grantRepository.SaveAsync(grant); } - private Grant ToGrant(PersistedGrant pGrant) - { - return new Grant - { - Key = pGrant.Key, - Type = pGrant.Type, - SubjectId = pGrant.SubjectId, - SessionId = pGrant.SessionId, - ClientId = pGrant.ClientId, - Description = pGrant.Description, - CreationDate = pGrant.CreationTime, - ExpirationDate = pGrant.Expiration, - ConsumedDate = pGrant.ConsumedTime, - Data = pGrant.Data - }; - } - - private PersistedGrant ToPersistedGrant(Grant grant) + private PersistedGrant ToPersistedGrant(IGrant grant) { return new PersistedGrant { diff --git a/src/Identity/Utilities/ServiceCollectionExtensions.cs b/src/Identity/Utilities/ServiceCollectionExtensions.cs index 07d9ef32d..4f3f37f8a 100644 --- a/src/Identity/Utilities/ServiceCollectionExtensions.cs +++ b/src/Identity/Utilities/ServiceCollectionExtensions.cs @@ -1,4 +1,5 @@ -using Bit.Core.IdentityServer; +using Bit.Core.Auth.Repositories; +using Bit.Core.IdentityServer; using Bit.Core.Settings; using Bit.Core.Utilities; using Bit.Identity.IdentityServer; @@ -51,31 +52,58 @@ public static class ServiceCollectionExtensions .AddIdentityServerCertificate(env, globalSettings) .AddExtensionGrantValidator(); - if (CoreHelpers.SettingHasValue(globalSettings.IdentityServer.RedisConnectionString)) + if (CoreHelpers.SettingHasValue(globalSettings.IdentityServer.CosmosConnectionString)) { - // If we have redis, prefer it - - // Add the original persisted grant store via it's implementation type - // so we can inject it right after. - services.AddSingleton(); - - services.AddSingleton(sp => - { - return new RedisPersistedGrantStore( - // TODO: .NET 8 create a keyed service for this connection multiplexer and even PersistedGrantStore - ConnectionMultiplexer.Connect(globalSettings.IdentityServer.RedisConnectionString), - sp.GetRequiredService>(), - sp.GetRequiredService() // Fallback grant store - ); - }); + services.AddSingleton(sp => BuildCosmosGrantStore(sp, globalSettings)); + } + else if (CoreHelpers.SettingHasValue(globalSettings.IdentityServer.RedisConnectionString)) + { + services.AddSingleton(sp => BuildRedisGrantStore(sp, globalSettings)); } else { - // Use the original grant store - identityServerBuilder.AddPersistedGrantStore(); + services.AddTransient(sp => BuildSqlGrantStore(sp)); } services.AddTransient(); return identityServerBuilder; } + + private static PersistedGrantStore BuildCosmosGrantStore(IServiceProvider sp, GlobalSettings globalSettings) + { + if (!CoreHelpers.SettingHasValue(globalSettings.IdentityServer.CosmosConnectionString)) + { + throw new ArgumentException("No cosmos config string available."); + } + return new PersistedGrantStore( + // TODO: Perhaps we want to evaluate moving this repo to DI as a keyed service singleton in .NET 8 + new Core.Auth.Repositories.Cosmos.GrantRepository(globalSettings), + g => new Core.Auth.Models.Data.GrantItem(g), + fallbackGrantStore: BuildRedisGrantStore(sp, globalSettings, true)); + } + + private static RedisPersistedGrantStore BuildRedisGrantStore(IServiceProvider sp, + GlobalSettings globalSettings, bool allowNull = false) + { + if (!CoreHelpers.SettingHasValue(globalSettings.IdentityServer.RedisConnectionString)) + { + if (allowNull) + { + return null; + } + throw new ArgumentException("No redis config string available."); + } + + return new RedisPersistedGrantStore( + // TODO: .NET 8 create a keyed service for this connection multiplexer and even PersistedGrantStore + ConnectionMultiplexer.Connect(globalSettings.IdentityServer.RedisConnectionString), + sp.GetRequiredService>(), + fallbackGrantStore: BuildSqlGrantStore(sp)); + } + + private static PersistedGrantStore BuildSqlGrantStore(IServiceProvider sp) + { + return new PersistedGrantStore(sp.GetRequiredService(), + g => new Core.Auth.Entities.Grant(g)); + } } diff --git a/src/Infrastructure.Dapper/Auth/Repositories/GrantRepository.cs b/src/Infrastructure.Dapper/Auth/Repositories/GrantRepository.cs index 250aec8e1..6d004b534 100644 --- a/src/Infrastructure.Dapper/Auth/Repositories/GrantRepository.cs +++ b/src/Infrastructure.Dapper/Auth/Repositories/GrantRepository.cs @@ -1,5 +1,6 @@ using System.Data; using Bit.Core.Auth.Entities; +using Bit.Core.Auth.Models.Data; using Bit.Core.Auth.Repositories; using Bit.Core.Settings; using Bit.Infrastructure.Dapper.Repositories; @@ -18,11 +19,11 @@ public class GrantRepository : BaseRepository, IGrantRepository : base(connectionString, readOnlyConnectionString) { } - public async Task GetByKeyAsync(string key) + public async Task GetByKeyAsync(string key) { using (var connection = new SqlConnection(ConnectionString)) { - var results = await connection.QueryAsync( + var results = await connection.QueryAsync( "[dbo].[Grant_ReadByKey]", new { Key = key }, commandType: CommandType.StoredProcedure); @@ -31,12 +32,12 @@ public class GrantRepository : BaseRepository, IGrantRepository } } - public async Task> GetManyAsync(string subjectId, string sessionId, + public async Task> GetManyAsync(string subjectId, string sessionId, string clientId, string type) { using (var connection = new SqlConnection(ConnectionString)) { - var results = await connection.QueryAsync( + var results = await connection.QueryAsync( "[dbo].[Grant_Read]", new { SubjectId = subjectId, SessionId = sessionId, ClientId = clientId, Type = type }, commandType: CommandType.StoredProcedure); @@ -45,8 +46,13 @@ public class GrantRepository : BaseRepository, IGrantRepository } } - public async Task SaveAsync(Grant obj) + public async Task SaveAsync(IGrant obj) { + if (obj is not Grant gObj) + { + throw new ArgumentException(null, nameof(obj)); + } + using (var connection = new SqlConnection(ConnectionString)) { var results = await connection.ExecuteAsync( diff --git a/src/Infrastructure.EntityFramework/Auth/Repositories/GrantRepository.cs b/src/Infrastructure.EntityFramework/Auth/Repositories/GrantRepository.cs index 671e36c40..f22384afb 100644 --- a/src/Infrastructure.EntityFramework/Auth/Repositories/GrantRepository.cs +++ b/src/Infrastructure.EntityFramework/Auth/Repositories/GrantRepository.cs @@ -1,4 +1,5 @@ using AutoMapper; +using Bit.Core.Auth.Models.Data; using Bit.Core.Auth.Repositories; using Bit.Infrastructure.EntityFramework.Auth.Models; using Bit.Infrastructure.EntityFramework.Repositories; @@ -42,7 +43,7 @@ public class GrantRepository : BaseEntityFrameworkRepository, IGrantRepository } } - public async Task GetByKeyAsync(string key) + public async Task GetByKeyAsync(string key) { using (var scope = ServiceScopeFactory.CreateScope()) { @@ -55,7 +56,7 @@ public class GrantRepository : BaseEntityFrameworkRepository, IGrantRepository } } - public async Task> GetManyAsync(string subjectId, string sessionId, string clientId, string type) + public async Task> GetManyAsync(string subjectId, string sessionId, string clientId, string type) { using (var scope = ServiceScopeFactory.CreateScope()) { @@ -67,26 +68,31 @@ public class GrantRepository : BaseEntityFrameworkRepository, IGrantRepository g.Type == type select g; var grants = await query.ToListAsync(); - return (ICollection)grants; + return (ICollection)grants; } } - public async Task SaveAsync(Core.Auth.Entities.Grant obj) + public async Task SaveAsync(IGrant obj) { + if (obj is not Core.Auth.Entities.Grant gObj) + { + throw new ArgumentException(null, nameof(obj)); + } + using (var scope = ServiceScopeFactory.CreateScope()) { var dbContext = GetDatabaseContext(scope); var existingGrant = await (from g in dbContext.Grants - where g.Key == obj.Key + where g.Key == gObj.Key select g).FirstOrDefaultAsync(); if (existingGrant != null) { - obj.Id = existingGrant.Id; - dbContext.Entry(existingGrant).CurrentValues.SetValues(obj); + gObj.Id = existingGrant.Id; + dbContext.Entry(existingGrant).CurrentValues.SetValues(gObj); } else { - var entity = Mapper.Map(obj); + var entity = Mapper.Map(gObj); await dbContext.AddAsync(entity); await dbContext.SaveChangesAsync(); }