1
0
mirror of https://github.com/bitwarden/server.git synced 2024-11-22 12:15:36 +01:00

[send.key] Update send.key when account encryption key is rotated (#1417)

* Rotate send.key with account encryption key

* Update tests

* Improve and refactor style, fix typo

* Use null instead of empty lists

* Revert "Use null instead of empty lists"

This reverts commit 775a52ca56.

* Fix style (use AddRange instead of reassignment)
This commit is contained in:
Thomas Rittson 2021-07-02 06:27:03 +10:00 committed by GitHub
parent 30ea8b728d
commit 86a12efa76
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
9 changed files with 162 additions and 27 deletions

View File

@ -35,6 +35,8 @@ namespace Bit.Api.Controllers
private readonly IPaymentService _paymentService; private readonly IPaymentService _paymentService;
private readonly IUserRepository _userRepository; private readonly IUserRepository _userRepository;
private readonly IUserService _userService; private readonly IUserService _userService;
private readonly ISendRepository _sendRepository;
private readonly ISendService _sendService;
public AccountsController( public AccountsController(
GlobalSettings globalSettings, GlobalSettings globalSettings,
@ -46,7 +48,9 @@ namespace Bit.Api.Controllers
IPaymentService paymentService, IPaymentService paymentService,
ISsoUserRepository ssoUserRepository, ISsoUserRepository ssoUserRepository,
IUserRepository userRepository, IUserRepository userRepository,
IUserService userService) IUserService userService,
ISendRepository sendRepository,
ISendService sendService)
{ {
_cipherRepository = cipherRepository; _cipherRepository = cipherRepository;
_folderRepository = folderRepository; _folderRepository = folderRepository;
@ -57,6 +61,8 @@ namespace Bit.Api.Controllers
_paymentService = paymentService; _paymentService = paymentService;
_userRepository = userRepository; _userRepository = userRepository;
_userService = userService; _userService = userService;
_sendRepository = sendRepository;
_sendService = sendService;
} }
[HttpPost("prelogin")] [HttpPost("prelogin")]
@ -283,26 +289,28 @@ namespace Bit.Api.Controllers
throw new UnauthorizedAccessException(); throw new UnauthorizedAccessException();
} }
var existingCiphers = await _cipherRepository.GetManyByUserIdAsync(user.Id);
var ciphersDict = model.Ciphers?.ToDictionary(c => c.Id.Value);
var ciphers = new List<Cipher>(); var ciphers = new List<Cipher>();
if (existingCiphers.Any() && ciphersDict != null) if (model.Ciphers.Any())
{ {
foreach (var cipher in existingCiphers.Where(c => ciphersDict.ContainsKey(c.Id))) var existingCiphers = await _cipherRepository.GetManyByUserIdAsync(user.Id);
{ ciphers.AddRange(existingCiphers
ciphers.Add(ciphersDict[cipher.Id].ToCipher(cipher)); .Join(model.Ciphers, c => c.Id, c => c.Id, (existing, c) => c.ToCipher(existing)));
}
} }
var existingFolders = await _folderRepository.GetManyByUserIdAsync(user.Id);
var foldersDict = model.Folders?.ToDictionary(f => f.Id);
var folders = new List<Folder>(); var folders = new List<Folder>();
if (existingFolders.Any() && foldersDict != null) if (model.Folders.Any())
{ {
foreach (var folder in existingFolders.Where(f => foldersDict.ContainsKey(f.Id))) var existingFolders = await _folderRepository.GetManyByUserIdAsync(user.Id);
{ folders.AddRange(existingFolders
folders.Add(foldersDict[folder.Id].ToFolder(folder)); .Join(model.Folders, f => f.Id, f => f.Id, (existing, f) => f.ToFolder(existing)));
} }
var sends = new List<Send>();
if (model.Sends?.Any() == true)
{
var existingSends = await _sendRepository.GetManyByUserIdAsync(user.Id);
sends.AddRange(existingSends
.Join(model.Sends, s => s.Id, s => s.Id, (existing, s) => s.ToSend(existing, _sendService)));
} }
var result = await _userService.UpdateKeyAsync( var result = await _userService.UpdateKeyAsync(
@ -311,7 +319,8 @@ namespace Bit.Api.Controllers
model.Key, model.Key,
model.PrivateKey, model.PrivateKey,
ciphers, ciphers,
folders); folders,
sends);
if (result.Succeeded) if (result.Succeeded)
{ {

View File

@ -12,6 +12,7 @@ namespace Bit.Core.Models.Api
public IEnumerable<CipherWithIdRequestModel> Ciphers { get; set; } public IEnumerable<CipherWithIdRequestModel> Ciphers { get; set; }
[Required] [Required]
public IEnumerable<FolderWithIdRequestModel> Folders { get; set; } public IEnumerable<FolderWithIdRequestModel> Folders { get; set; }
public IEnumerable<SendWithIdRequestModel> Sends { get; set; }
[Required] [Required]
public string PrivateKey { get; set; } public string PrivateKey { get; set; }
[Required] [Required]

View File

@ -130,4 +130,10 @@ namespace Bit.Core.Models.Api
return existingSend; return existingSend;
} }
} }
public class SendWithIdRequestModel : SendRequestModel
{
[Required]
public Guid? Id { get; set; }
}
} }

View File

@ -28,7 +28,7 @@ namespace Bit.Core.Repositories
Task MoveAsync(IEnumerable<Guid> ids, Guid? folderId, Guid userId); Task MoveAsync(IEnumerable<Guid> ids, Guid? folderId, Guid userId);
Task DeleteByUserIdAsync(Guid userId); Task DeleteByUserIdAsync(Guid userId);
Task DeleteByOrganizationIdAsync(Guid organizationId); Task DeleteByOrganizationIdAsync(Guid organizationId);
Task UpdateUserKeysAndCiphersAsync(User user, IEnumerable<Cipher> ciphers, IEnumerable<Folder> folders); Task UpdateUserKeysAndCiphersAsync(User user, IEnumerable<Cipher> ciphers, IEnumerable<Folder> folders, IEnumerable<Send> sends);
Task UpdateCiphersAsync(Guid userId, IEnumerable<Cipher> ciphers); Task UpdateCiphersAsync(Guid userId, IEnumerable<Cipher> ciphers);
Task CreateAsync(IEnumerable<Cipher> ciphers, IEnumerable<Folder> folders); Task CreateAsync(IEnumerable<Cipher> ciphers, IEnumerable<Folder> folders);
Task CreateAsync(IEnumerable<Cipher> ciphers, IEnumerable<Collection> collections, Task CreateAsync(IEnumerable<Cipher> ciphers, IEnumerable<Collection> collections,

View File

@ -282,7 +282,7 @@ namespace Bit.Core.Repositories.SqlServer
} }
} }
public Task UpdateUserKeysAndCiphersAsync(User user, IEnumerable<Cipher> ciphers, IEnumerable<Folder> folders) public Task UpdateUserKeysAndCiphersAsync(User user, IEnumerable<Cipher> ciphers, IEnumerable<Folder> folders, IEnumerable<Send> sends)
{ {
using (var connection = new SqlConnection(ConnectionString)) using (var connection = new SqlConnection(ConnectionString))
{ {
@ -323,7 +323,11 @@ namespace Bit.Core.Repositories.SqlServer
SELECT TOP 0 * SELECT TOP 0 *
INTO #TempFolder INTO #TempFolder
FROM [dbo].[Folder]"; FROM [dbo].[Folder]
SELECT TOP 0 *
INTO #TempSend
FROM [dbo].[Send]";
using (var cmd = new SqlCommand(sqlCreateTemp, connection, transaction)) using (var cmd = new SqlCommand(sqlCreateTemp, connection, transaction))
{ {
@ -352,6 +356,16 @@ namespace Bit.Core.Repositories.SqlServer
} }
} }
if (sends.Any())
{
using (var bulkCopy = new SqlBulkCopy(connection, SqlBulkCopyOptions.KeepIdentity, transaction))
{
bulkCopy.DestinationTableName = "#TempSend";
var dataTable = BuildSendsTable(bulkCopy, sends);
bulkCopy.WriteToServer(dataTable);
}
}
// 4. Insert into real tables from temp tables and clean up. // 4. Insert into real tables from temp tables and clean up.
var sql = string.Empty; var sql = string.Empty;
@ -389,9 +403,26 @@ namespace Bit.Core.Repositories.SqlServer
F.[UserId] = @UserId"; F.[UserId] = @UserId";
} }
if (sends.Any())
{
sql += @"
UPDATE
[dbo].[Send]
SET
[Key] = TS.[Key],
[RevisionDate] = TS.[RevisionDate]
FROM
[dbo].[Send] S
INNER JOIN
#TempSend TS ON S.Id = TS.Id
WHERE
S.[UserId] = @UserId";
}
sql += @" sql += @"
DROP TABLE #TempCipher DROP TABLE #TempCipher
DROP TABLE #TempFolder"; DROP TABLE #TempFolder
DROP TABLE #TempSend";
using (var cmd = new SqlCommand(sql, connection, transaction)) using (var cmd = new SqlCommand(sql, connection, transaction))
{ {
@ -833,6 +864,82 @@ namespace Bit.Core.Repositories.SqlServer
return collectionCiphersTable; return collectionCiphersTable;
} }
private DataTable BuildSendsTable(SqlBulkCopy bulkCopy, IEnumerable<Send> sends)
{
var s = sends.FirstOrDefault();
if (s == null)
{
throw new ApplicationException("Must have some Sends to bulk import.");
}
var sendsTable = new DataTable("SendsDataTable");
var idColumn = new DataColumn(nameof(s.Id), s.Id.GetType());
sendsTable.Columns.Add(idColumn);
var userIdColumn = new DataColumn(nameof(s.UserId), typeof(Guid));
sendsTable.Columns.Add(userIdColumn);
var organizationIdColumn = new DataColumn(nameof(s.OrganizationId), typeof(Guid));
sendsTable.Columns.Add(organizationIdColumn);
var typeColumn = new DataColumn(nameof(s.Type), s.Type.GetType());
sendsTable.Columns.Add(typeColumn);
var dataColumn = new DataColumn(nameof(s.Data), s.Data.GetType());
sendsTable.Columns.Add(dataColumn);
var keyColumn = new DataColumn(nameof(s.Key), s.Key.GetType());
sendsTable.Columns.Add(keyColumn);
var passwordColumn = new DataColumn(nameof(s.Password), typeof(string));
sendsTable.Columns.Add(passwordColumn);
var maxAccessCountColumn = new DataColumn(nameof(s.MaxAccessCount), typeof(int));
sendsTable.Columns.Add(maxAccessCountColumn);
var accessCountColumn = new DataColumn(nameof(s.AccessCount), s.AccessCount.GetType());
sendsTable.Columns.Add(accessCountColumn);
var creationDateColumn = new DataColumn(nameof(s.CreationDate), s.CreationDate.GetType());
sendsTable.Columns.Add(creationDateColumn);
var revisionDateColumn = new DataColumn(nameof(s.RevisionDate), s.RevisionDate.GetType());
sendsTable.Columns.Add(revisionDateColumn);
var expirationDateColumn = new DataColumn(nameof(s.ExpirationDate), typeof(DateTime));
sendsTable.Columns.Add(expirationDateColumn);
var deletionDateColumn = new DataColumn(nameof(s.DeletionDate), s.DeletionDate.GetType());
sendsTable.Columns.Add(deletionDateColumn);
var disabledColumn = new DataColumn(nameof(s.Disabled), s.Disabled.GetType());
sendsTable.Columns.Add(disabledColumn);
var hideEmailColumn = new DataColumn(nameof(s.HideEmail), typeof(bool));
sendsTable.Columns.Add(hideEmailColumn);
foreach (DataColumn col in sendsTable.Columns)
{
bulkCopy.ColumnMappings.Add(col.ColumnName, col.ColumnName);
}
var keys = new DataColumn[1];
keys[0] = idColumn;
sendsTable.PrimaryKey = keys;
foreach (var send in sends)
{
var row = sendsTable.NewRow();
row[idColumn] = send.Id;
row[userIdColumn] = send.UserId.HasValue ? (object)send.UserId.Value : DBNull.Value;
row[organizationIdColumn] = send.OrganizationId.HasValue ? (object)send.OrganizationId.Value : DBNull.Value;
row[typeColumn] = (short)send.Type;
row[dataColumn] = send.Data;
row[keyColumn] = send.Key;
row[passwordColumn] = send.Password;
row[maxAccessCountColumn] = send.MaxAccessCount.HasValue ? (object)send.MaxAccessCount : DBNull.Value;
row[accessCountColumn] = send.AccessCount;
row[creationDateColumn] = send.CreationDate;
row[revisionDateColumn] = send.RevisionDate;
row[expirationDateColumn] = send.ExpirationDate.HasValue ? (object)send.ExpirationDate : DBNull.Value;
row[deletionDateColumn] = send.DeletionDate;
row[disabledColumn] = send.Disabled;
row[hideEmailColumn] = send.HideEmail.HasValue ? (object)send.HideEmail : DBNull.Value;
sendsTable.Rows.Add(row);
}
return sendsTable;
}
public class CipherDetailsWithCollections : CipherDetails public class CipherDetailsWithCollections : CipherDetails
{ {
public DataTable CollectionIds { get; set; } public DataTable CollectionIds { get; set; }

View File

@ -38,7 +38,7 @@ namespace Bit.Core.Services
Task<IdentityResult> ChangeKdfAsync(User user, string masterPassword, string newMasterPassword, string key, Task<IdentityResult> ChangeKdfAsync(User user, string masterPassword, string newMasterPassword, string key,
KdfType kdf, int kdfIterations); KdfType kdf, int kdfIterations);
Task<IdentityResult> UpdateKeyAsync(User user, string masterPassword, string key, string privateKey, Task<IdentityResult> UpdateKeyAsync(User user, string masterPassword, string key, string privateKey,
IEnumerable<Cipher> ciphers, IEnumerable<Folder> folders); IEnumerable<Cipher> ciphers, IEnumerable<Folder> folders, IEnumerable<Send> sends);
Task<IdentityResult> RefreshSecurityStampAsync(User user, string masterPasswordHash); Task<IdentityResult> RefreshSecurityStampAsync(User user, string masterPasswordHash);
Task UpdateTwoFactorProviderAsync(User user, TwoFactorProviderType type, bool setEnabled = true); Task UpdateTwoFactorProviderAsync(User user, TwoFactorProviderType type, bool setEnabled = true);
Task DisableTwoFactorProviderAsync(User user, TwoFactorProviderType type, Task DisableTwoFactorProviderAsync(User user, TwoFactorProviderType type,

View File

@ -52,6 +52,7 @@ namespace Bit.Core.Services
private readonly ICurrentContext _currentContext; private readonly ICurrentContext _currentContext;
private readonly GlobalSettings _globalSettings; private readonly GlobalSettings _globalSettings;
private readonly IOrganizationService _organizationService; private readonly IOrganizationService _organizationService;
private readonly ISendRepository _sendRepository;
public UserService( public UserService(
IUserRepository userRepository, IUserRepository userRepository,
@ -79,7 +80,8 @@ namespace Bit.Core.Services
IFido2 fido2, IFido2 fido2,
ICurrentContext currentContext, ICurrentContext currentContext,
GlobalSettings globalSettings, GlobalSettings globalSettings,
IOrganizationService organizationService) IOrganizationService organizationService,
ISendRepository sendRepository)
: base( : base(
store, store,
optionsAccessor, optionsAccessor,
@ -113,6 +115,7 @@ namespace Bit.Core.Services
_currentContext = currentContext; _currentContext = currentContext;
_globalSettings = globalSettings; _globalSettings = globalSettings;
_organizationService = organizationService; _organizationService = organizationService;
_sendRepository = sendRepository;
} }
public Guid? GetProperUserId(ClaimsPrincipal principal) public Guid? GetProperUserId(ClaimsPrincipal principal)
@ -726,7 +729,7 @@ namespace Bit.Core.Services
} }
public async Task<IdentityResult> UpdateKeyAsync(User user, string masterPassword, string key, string privateKey, public async Task<IdentityResult> UpdateKeyAsync(User user, string masterPassword, string key, string privateKey,
IEnumerable<Cipher> ciphers, IEnumerable<Folder> folders) IEnumerable<Cipher> ciphers, IEnumerable<Folder> folders, IEnumerable<Send> sends)
{ {
if (user == null) if (user == null)
{ {
@ -739,9 +742,9 @@ namespace Bit.Core.Services
user.SecurityStamp = Guid.NewGuid().ToString(); user.SecurityStamp = Guid.NewGuid().ToString();
user.Key = key; user.Key = key;
user.PrivateKey = privateKey; user.PrivateKey = privateKey;
if (ciphers.Any() || folders.Any()) if (ciphers.Any() || folders.Any() || sends.Any())
{ {
await _cipherRepository.UpdateUserKeysAndCiphersAsync(user, ciphers, folders); await _cipherRepository.UpdateUserKeysAndCiphersAsync(user, ciphers, folders, sends);
} }
else else
{ {

View File

@ -30,6 +30,8 @@ namespace Bit.Api.Test.Controllers
private readonly ISsoUserRepository _ssoUserRepository; private readonly ISsoUserRepository _ssoUserRepository;
private readonly IUserRepository _userRepository; private readonly IUserRepository _userRepository;
private readonly IUserService _userService; private readonly IUserService _userService;
private readonly ISendRepository _sendRepository;
private readonly ISendService _sendService;
private readonly IProviderUserRepository _providerUserRepository; private readonly IProviderUserRepository _providerUserRepository;
public AccountsControllerTests() public AccountsControllerTests()
@ -43,6 +45,8 @@ namespace Bit.Api.Test.Controllers
_providerUserRepository = Substitute.For<IProviderUserRepository>(); _providerUserRepository = Substitute.For<IProviderUserRepository>();
_paymentService = Substitute.For<IPaymentService>(); _paymentService = Substitute.For<IPaymentService>();
_globalSettings = new GlobalSettings(); _globalSettings = new GlobalSettings();
_sendRepository = Substitute.For<ISendRepository>();
_sendService = Substitute.For<ISendService>();
_sut = new AccountsController( _sut = new AccountsController(
_globalSettings, _globalSettings,
_cipherRepository, _cipherRepository,
@ -53,7 +57,9 @@ namespace Bit.Api.Test.Controllers
_paymentService, _paymentService,
_ssoUserRepository, _ssoUserRepository,
_userRepository, _userRepository,
_userService _userService,
_sendRepository,
_sendService
); );
} }

View File

@ -45,6 +45,7 @@ namespace Bit.Core.Test.Services
private readonly CurrentContext _currentContext; private readonly CurrentContext _currentContext;
private readonly GlobalSettings _globalSettings; private readonly GlobalSettings _globalSettings;
private readonly IOrganizationService _organizationService; private readonly IOrganizationService _organizationService;
private readonly ISendRepository _sendRepository;
public UserServiceTests() public UserServiceTests()
{ {
@ -74,6 +75,7 @@ namespace Bit.Core.Test.Services
_currentContext = new CurrentContext(); _currentContext = new CurrentContext();
_globalSettings = new GlobalSettings(); _globalSettings = new GlobalSettings();
_organizationService = Substitute.For<IOrganizationService>(); _organizationService = Substitute.For<IOrganizationService>();
_sendRepository = Substitute.For<ISendRepository>();
_sut = new UserService( _sut = new UserService(
_userRepository, _userRepository,
@ -101,7 +103,8 @@ namespace Bit.Core.Test.Services
_fido2, _fido2,
_currentContext, _currentContext,
_globalSettings, _globalSettings,
_organizationService _organizationService,
_sendRepository
); );
} }