1
0
mirror of https://github.com/bitwarden/server.git synced 2024-11-21 12:05:42 +01:00

[send.key] Update send.key when account encryption key is rotated (#1417)

* Rotate send.key with account encryption key

* Update tests

* Improve and refactor style, fix typo

* Use null instead of empty lists

* Revert "Use null instead of empty lists"

This reverts commit 775a52ca56.

* Fix style (use AddRange instead of reassignment)
This commit is contained in:
Thomas Rittson 2021-07-02 06:27:03 +10:00 committed by GitHub
parent 30ea8b728d
commit 86a12efa76
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
9 changed files with 162 additions and 27 deletions

View File

@ -35,6 +35,8 @@ namespace Bit.Api.Controllers
private readonly IPaymentService _paymentService;
private readonly IUserRepository _userRepository;
private readonly IUserService _userService;
private readonly ISendRepository _sendRepository;
private readonly ISendService _sendService;
public AccountsController(
GlobalSettings globalSettings,
@ -46,7 +48,9 @@ namespace Bit.Api.Controllers
IPaymentService paymentService,
ISsoUserRepository ssoUserRepository,
IUserRepository userRepository,
IUserService userService)
IUserService userService,
ISendRepository sendRepository,
ISendService sendService)
{
_cipherRepository = cipherRepository;
_folderRepository = folderRepository;
@ -57,6 +61,8 @@ namespace Bit.Api.Controllers
_paymentService = paymentService;
_userRepository = userRepository;
_userService = userService;
_sendRepository = sendRepository;
_sendService = sendService;
}
[HttpPost("prelogin")]
@ -283,26 +289,28 @@ namespace Bit.Api.Controllers
throw new UnauthorizedAccessException();
}
var existingCiphers = await _cipherRepository.GetManyByUserIdAsync(user.Id);
var ciphersDict = model.Ciphers?.ToDictionary(c => c.Id.Value);
var ciphers = new List<Cipher>();
if (existingCiphers.Any() && ciphersDict != null)
if (model.Ciphers.Any())
{
foreach (var cipher in existingCiphers.Where(c => ciphersDict.ContainsKey(c.Id)))
{
ciphers.Add(ciphersDict[cipher.Id].ToCipher(cipher));
}
var existingCiphers = await _cipherRepository.GetManyByUserIdAsync(user.Id);
ciphers.AddRange(existingCiphers
.Join(model.Ciphers, c => c.Id, c => c.Id, (existing, c) => c.ToCipher(existing)));
}
var existingFolders = await _folderRepository.GetManyByUserIdAsync(user.Id);
var foldersDict = model.Folders?.ToDictionary(f => f.Id);
var folders = new List<Folder>();
if (existingFolders.Any() && foldersDict != null)
if (model.Folders.Any())
{
foreach (var folder in existingFolders.Where(f => foldersDict.ContainsKey(f.Id)))
{
folders.Add(foldersDict[folder.Id].ToFolder(folder));
}
var existingFolders = await _folderRepository.GetManyByUserIdAsync(user.Id);
folders.AddRange(existingFolders
.Join(model.Folders, f => f.Id, f => f.Id, (existing, f) => f.ToFolder(existing)));
}
var sends = new List<Send>();
if (model.Sends?.Any() == true)
{
var existingSends = await _sendRepository.GetManyByUserIdAsync(user.Id);
sends.AddRange(existingSends
.Join(model.Sends, s => s.Id, s => s.Id, (existing, s) => s.ToSend(existing, _sendService)));
}
var result = await _userService.UpdateKeyAsync(
@ -311,7 +319,8 @@ namespace Bit.Api.Controllers
model.Key,
model.PrivateKey,
ciphers,
folders);
folders,
sends);
if (result.Succeeded)
{

View File

@ -12,6 +12,7 @@ namespace Bit.Core.Models.Api
public IEnumerable<CipherWithIdRequestModel> Ciphers { get; set; }
[Required]
public IEnumerable<FolderWithIdRequestModel> Folders { get; set; }
public IEnumerable<SendWithIdRequestModel> Sends { get; set; }
[Required]
public string PrivateKey { get; set; }
[Required]

View File

@ -130,4 +130,10 @@ namespace Bit.Core.Models.Api
return existingSend;
}
}
public class SendWithIdRequestModel : SendRequestModel
{
[Required]
public Guid? Id { get; set; }
}
}

View File

@ -28,7 +28,7 @@ namespace Bit.Core.Repositories
Task MoveAsync(IEnumerable<Guid> ids, Guid? folderId, Guid userId);
Task DeleteByUserIdAsync(Guid userId);
Task DeleteByOrganizationIdAsync(Guid organizationId);
Task UpdateUserKeysAndCiphersAsync(User user, IEnumerable<Cipher> ciphers, IEnumerable<Folder> folders);
Task UpdateUserKeysAndCiphersAsync(User user, IEnumerable<Cipher> ciphers, IEnumerable<Folder> folders, IEnumerable<Send> sends);
Task UpdateCiphersAsync(Guid userId, IEnumerable<Cipher> ciphers);
Task CreateAsync(IEnumerable<Cipher> ciphers, IEnumerable<Folder> folders);
Task CreateAsync(IEnumerable<Cipher> ciphers, IEnumerable<Collection> collections,

View File

@ -282,7 +282,7 @@ namespace Bit.Core.Repositories.SqlServer
}
}
public Task UpdateUserKeysAndCiphersAsync(User user, IEnumerable<Cipher> ciphers, IEnumerable<Folder> folders)
public Task UpdateUserKeysAndCiphersAsync(User user, IEnumerable<Cipher> ciphers, IEnumerable<Folder> folders, IEnumerable<Send> sends)
{
using (var connection = new SqlConnection(ConnectionString))
{
@ -323,7 +323,11 @@ namespace Bit.Core.Repositories.SqlServer
SELECT TOP 0 *
INTO #TempFolder
FROM [dbo].[Folder]";
FROM [dbo].[Folder]
SELECT TOP 0 *
INTO #TempSend
FROM [dbo].[Send]";
using (var cmd = new SqlCommand(sqlCreateTemp, connection, transaction))
{
@ -352,6 +356,16 @@ namespace Bit.Core.Repositories.SqlServer
}
}
if (sends.Any())
{
using (var bulkCopy = new SqlBulkCopy(connection, SqlBulkCopyOptions.KeepIdentity, transaction))
{
bulkCopy.DestinationTableName = "#TempSend";
var dataTable = BuildSendsTable(bulkCopy, sends);
bulkCopy.WriteToServer(dataTable);
}
}
// 4. Insert into real tables from temp tables and clean up.
var sql = string.Empty;
@ -389,9 +403,26 @@ namespace Bit.Core.Repositories.SqlServer
F.[UserId] = @UserId";
}
if (sends.Any())
{
sql += @"
UPDATE
[dbo].[Send]
SET
[Key] = TS.[Key],
[RevisionDate] = TS.[RevisionDate]
FROM
[dbo].[Send] S
INNER JOIN
#TempSend TS ON S.Id = TS.Id
WHERE
S.[UserId] = @UserId";
}
sql += @"
DROP TABLE #TempCipher
DROP TABLE #TempFolder";
DROP TABLE #TempFolder
DROP TABLE #TempSend";
using (var cmd = new SqlCommand(sql, connection, transaction))
{
@ -833,6 +864,82 @@ namespace Bit.Core.Repositories.SqlServer
return collectionCiphersTable;
}
private DataTable BuildSendsTable(SqlBulkCopy bulkCopy, IEnumerable<Send> sends)
{
var s = sends.FirstOrDefault();
if (s == null)
{
throw new ApplicationException("Must have some Sends to bulk import.");
}
var sendsTable = new DataTable("SendsDataTable");
var idColumn = new DataColumn(nameof(s.Id), s.Id.GetType());
sendsTable.Columns.Add(idColumn);
var userIdColumn = new DataColumn(nameof(s.UserId), typeof(Guid));
sendsTable.Columns.Add(userIdColumn);
var organizationIdColumn = new DataColumn(nameof(s.OrganizationId), typeof(Guid));
sendsTable.Columns.Add(organizationIdColumn);
var typeColumn = new DataColumn(nameof(s.Type), s.Type.GetType());
sendsTable.Columns.Add(typeColumn);
var dataColumn = new DataColumn(nameof(s.Data), s.Data.GetType());
sendsTable.Columns.Add(dataColumn);
var keyColumn = new DataColumn(nameof(s.Key), s.Key.GetType());
sendsTable.Columns.Add(keyColumn);
var passwordColumn = new DataColumn(nameof(s.Password), typeof(string));
sendsTable.Columns.Add(passwordColumn);
var maxAccessCountColumn = new DataColumn(nameof(s.MaxAccessCount), typeof(int));
sendsTable.Columns.Add(maxAccessCountColumn);
var accessCountColumn = new DataColumn(nameof(s.AccessCount), s.AccessCount.GetType());
sendsTable.Columns.Add(accessCountColumn);
var creationDateColumn = new DataColumn(nameof(s.CreationDate), s.CreationDate.GetType());
sendsTable.Columns.Add(creationDateColumn);
var revisionDateColumn = new DataColumn(nameof(s.RevisionDate), s.RevisionDate.GetType());
sendsTable.Columns.Add(revisionDateColumn);
var expirationDateColumn = new DataColumn(nameof(s.ExpirationDate), typeof(DateTime));
sendsTable.Columns.Add(expirationDateColumn);
var deletionDateColumn = new DataColumn(nameof(s.DeletionDate), s.DeletionDate.GetType());
sendsTable.Columns.Add(deletionDateColumn);
var disabledColumn = new DataColumn(nameof(s.Disabled), s.Disabled.GetType());
sendsTable.Columns.Add(disabledColumn);
var hideEmailColumn = new DataColumn(nameof(s.HideEmail), typeof(bool));
sendsTable.Columns.Add(hideEmailColumn);
foreach (DataColumn col in sendsTable.Columns)
{
bulkCopy.ColumnMappings.Add(col.ColumnName, col.ColumnName);
}
var keys = new DataColumn[1];
keys[0] = idColumn;
sendsTable.PrimaryKey = keys;
foreach (var send in sends)
{
var row = sendsTable.NewRow();
row[idColumn] = send.Id;
row[userIdColumn] = send.UserId.HasValue ? (object)send.UserId.Value : DBNull.Value;
row[organizationIdColumn] = send.OrganizationId.HasValue ? (object)send.OrganizationId.Value : DBNull.Value;
row[typeColumn] = (short)send.Type;
row[dataColumn] = send.Data;
row[keyColumn] = send.Key;
row[passwordColumn] = send.Password;
row[maxAccessCountColumn] = send.MaxAccessCount.HasValue ? (object)send.MaxAccessCount : DBNull.Value;
row[accessCountColumn] = send.AccessCount;
row[creationDateColumn] = send.CreationDate;
row[revisionDateColumn] = send.RevisionDate;
row[expirationDateColumn] = send.ExpirationDate.HasValue ? (object)send.ExpirationDate : DBNull.Value;
row[deletionDateColumn] = send.DeletionDate;
row[disabledColumn] = send.Disabled;
row[hideEmailColumn] = send.HideEmail.HasValue ? (object)send.HideEmail : DBNull.Value;
sendsTable.Rows.Add(row);
}
return sendsTable;
}
public class CipherDetailsWithCollections : CipherDetails
{
public DataTable CollectionIds { get; set; }

View File

@ -38,7 +38,7 @@ namespace Bit.Core.Services
Task<IdentityResult> ChangeKdfAsync(User user, string masterPassword, string newMasterPassword, string key,
KdfType kdf, int kdfIterations);
Task<IdentityResult> UpdateKeyAsync(User user, string masterPassword, string key, string privateKey,
IEnumerable<Cipher> ciphers, IEnumerable<Folder> folders);
IEnumerable<Cipher> ciphers, IEnumerable<Folder> folders, IEnumerable<Send> sends);
Task<IdentityResult> RefreshSecurityStampAsync(User user, string masterPasswordHash);
Task UpdateTwoFactorProviderAsync(User user, TwoFactorProviderType type, bool setEnabled = true);
Task DisableTwoFactorProviderAsync(User user, TwoFactorProviderType type,

View File

@ -52,6 +52,7 @@ namespace Bit.Core.Services
private readonly ICurrentContext _currentContext;
private readonly GlobalSettings _globalSettings;
private readonly IOrganizationService _organizationService;
private readonly ISendRepository _sendRepository;
public UserService(
IUserRepository userRepository,
@ -79,7 +80,8 @@ namespace Bit.Core.Services
IFido2 fido2,
ICurrentContext currentContext,
GlobalSettings globalSettings,
IOrganizationService organizationService)
IOrganizationService organizationService,
ISendRepository sendRepository)
: base(
store,
optionsAccessor,
@ -113,6 +115,7 @@ namespace Bit.Core.Services
_currentContext = currentContext;
_globalSettings = globalSettings;
_organizationService = organizationService;
_sendRepository = sendRepository;
}
public Guid? GetProperUserId(ClaimsPrincipal principal)
@ -726,7 +729,7 @@ namespace Bit.Core.Services
}
public async Task<IdentityResult> UpdateKeyAsync(User user, string masterPassword, string key, string privateKey,
IEnumerable<Cipher> ciphers, IEnumerable<Folder> folders)
IEnumerable<Cipher> ciphers, IEnumerable<Folder> folders, IEnumerable<Send> sends)
{
if (user == null)
{
@ -739,9 +742,9 @@ namespace Bit.Core.Services
user.SecurityStamp = Guid.NewGuid().ToString();
user.Key = key;
user.PrivateKey = privateKey;
if (ciphers.Any() || folders.Any())
if (ciphers.Any() || folders.Any() || sends.Any())
{
await _cipherRepository.UpdateUserKeysAndCiphersAsync(user, ciphers, folders);
await _cipherRepository.UpdateUserKeysAndCiphersAsync(user, ciphers, folders, sends);
}
else
{

View File

@ -30,6 +30,8 @@ namespace Bit.Api.Test.Controllers
private readonly ISsoUserRepository _ssoUserRepository;
private readonly IUserRepository _userRepository;
private readonly IUserService _userService;
private readonly ISendRepository _sendRepository;
private readonly ISendService _sendService;
private readonly IProviderUserRepository _providerUserRepository;
public AccountsControllerTests()
@ -43,6 +45,8 @@ namespace Bit.Api.Test.Controllers
_providerUserRepository = Substitute.For<IProviderUserRepository>();
_paymentService = Substitute.For<IPaymentService>();
_globalSettings = new GlobalSettings();
_sendRepository = Substitute.For<ISendRepository>();
_sendService = Substitute.For<ISendService>();
_sut = new AccountsController(
_globalSettings,
_cipherRepository,
@ -53,7 +57,9 @@ namespace Bit.Api.Test.Controllers
_paymentService,
_ssoUserRepository,
_userRepository,
_userService
_userService,
_sendRepository,
_sendService
);
}

View File

@ -45,6 +45,7 @@ namespace Bit.Core.Test.Services
private readonly CurrentContext _currentContext;
private readonly GlobalSettings _globalSettings;
private readonly IOrganizationService _organizationService;
private readonly ISendRepository _sendRepository;
public UserServiceTests()
{
@ -74,6 +75,7 @@ namespace Bit.Core.Test.Services
_currentContext = new CurrentContext();
_globalSettings = new GlobalSettings();
_organizationService = Substitute.For<IOrganizationService>();
_sendRepository = Substitute.For<ISendRepository>();
_sut = new UserService(
_userRepository,
@ -101,7 +103,8 @@ namespace Bit.Core.Test.Services
_fido2,
_currentContext,
_globalSettings,
_organizationService
_organizationService,
_sendRepository
);
}