From 9459cda304c906188879c24625446fa10e436711 Mon Sep 17 00:00:00 2001 From: Matt Gibson Date: Mon, 26 Aug 2024 17:44:08 -0700 Subject: [PATCH] Pm-10953/add-user-context-to-sync-replaces (#10627) * Require userId for setting masterKeyEncryptedUserKey * Replace folders for specified user * Require userId for collection replace * Cipher Replace requires userId * Require UserId to update equivalent domains * Require userId for policy replace * sync state updates between fake state for better testing * Revert to public observable tests Since they now sync, we can test single-user updates impacting active user observables * Do not init fake states through sync Do not sync initial null values, that might wipe out already existing data. * Require userId for Send replace * Include userId for organization replace * Require userId for billing sync data * Require user Id for key connector sync data * Allow decode of token by userId * Require userId for synced key connector updates * Add userId to policy setting during organization invite accept * Fix cli * Handle null userId --------- Co-authored-by: bnagawiecki <107435978+bnagawiecki@users.noreply.github.com> --- apps/cli/src/auth/commands/login.command.ts | 6 +- apps/cli/src/auth/commands/unlock.command.ts | 1 + apps/cli/src/base-program.ts | 20 +- .../convert-to-key-connector.command.ts | 4 +- apps/cli/src/program.ts | 4 +- .../web/src/app/auth/login/login.component.ts | 5 +- .../src/auth/components/login.component.ts | 5 +- .../auth-request-login.strategy.spec.ts | 10 +- .../auth-request-login.strategy.ts | 2 +- .../common/login-strategies/login.strategy.ts | 6 +- .../user-api-login.strategy.spec.ts | 5 +- .../user-api-login.strategy.ts | 2 +- libs/common/spec/fake-state-provider.ts | 159 ++++++--- libs/common/spec/fake-state.ts | 60 +++- .../policy/policy.service.abstraction.ts | 2 +- .../services/policy/policy.service.spec.ts | 12 +- .../services/policy/policy.service.ts | 4 +- .../abstractions/key-connector.service.ts | 14 +- .../src/auth/abstractions/token.service.ts | 8 +- .../services/key-connector.service.spec.ts | 8 +- .../auth/services/key-connector.service.ts | 34 +- .../src/auth/services/token.service.spec.ts | 319 ++++++++---------- .../common/src/auth/services/token.service.ts | 14 +- .../services/domain-settings.service.ts | 7 +- .../billing-account-profile-state.service.ts | 3 + ...ling-account-profile-state.service.spec.ts | 15 +- .../billing-account-profile-state.service.ts | 6 +- .../platform/abstractions/crypto.service.ts | 2 +- .../biometric-state.service.spec.ts | 2 +- .../platform/services/crypto.service.spec.ts | 22 +- .../src/platform/services/crypto.service.ts | 2 +- .../default-state.provider.spec.ts | 2 +- .../src/platform/sync/default-sync.service.ts | 53 +-- .../send-state.provider.abstraction.ts | 10 +- .../send/services/send-state.provider.spec.ts | 4 +- .../send/services/send-state.provider.ts | 13 +- .../send/services/send.service.abstraction.ts | 2 +- .../tools/send/services/send.service.spec.ts | 240 ++++++++----- .../src/tools/send/services/send.service.ts | 26 +- .../src/vault/abstractions/cipher.service.ts | 2 +- .../vault/abstractions/collection.service.ts | 4 +- .../folder/folder.service.abstraction.ts | 2 +- .../src/vault/services/cipher.service.ts | 17 +- .../src/vault/services/collection.service.ts | 6 +- .../services/folder/folder.service.spec.ts | 2 +- .../vault/services/folder/folder.service.ts | 4 +- 46 files changed, 666 insertions(+), 484 deletions(-) diff --git a/apps/cli/src/auth/commands/login.command.ts b/apps/cli/src/auth/commands/login.command.ts index 3b67f95540..9a69bcc3c0 100644 --- a/apps/cli/src/auth/commands/login.command.ts +++ b/apps/cli/src/auth/commands/login.command.ts @@ -342,7 +342,7 @@ export class LoginCommand { } } - return await this.handleSuccessResponse(); + return await this.handleSuccessResponse(response); } catch (e) { return Response.error(e); } @@ -353,8 +353,8 @@ export class LoginCommand { process.env.BW_SESSION = Utils.fromBufferToB64(key); } - private async handleSuccessResponse(): Promise { - const usesKeyConnector = await this.keyConnectorService.getUsesKeyConnector(); + private async handleSuccessResponse(response: AuthResult): Promise { + const usesKeyConnector = await this.keyConnectorService.getUsesKeyConnector(response.userId); if ( (this.options.sso != null || this.options.apikey != null) && diff --git a/apps/cli/src/auth/commands/unlock.command.ts b/apps/cli/src/auth/commands/unlock.command.ts index f4486ff966..bebaa94604 100644 --- a/apps/cli/src/auth/commands/unlock.command.ts +++ b/apps/cli/src/auth/commands/unlock.command.ts @@ -73,6 +73,7 @@ export class UnlockCommand { if (await this.keyConnectorService.getConvertAccountRequired()) { const convertToKeyConnectorCommand = new ConvertToKeyConnectorCommand( + userId, this.keyConnectorService, this.environmentService, this.syncService, diff --git a/apps/cli/src/base-program.ts b/apps/cli/src/base-program.ts index f308bdc2de..e4340b68e2 100644 --- a/apps/cli/src/base-program.ts +++ b/apps/cli/src/base-program.ts @@ -116,20 +116,30 @@ export abstract class BaseProgram { } } + /** + * Exist if no user is authenticated + * @returns the userId of the active account + */ protected async exitIfNotAuthed() { - const authed = await this.serviceContainer.stateService.getIsAuthenticated(); - if (!authed) { - this.processResponse(Response.error("You are not logged in."), true); + const fail = () => this.processResponse(Response.error("You are not logged in."), true); + const userId = (await firstValueFrom(this.serviceContainer.accountService.activeAccount$))?.id; + if (!userId) { + fail(); } + const authed = await this.serviceContainer.stateService.getIsAuthenticated({ userId }); + if (!authed) { + fail(); + } + return userId; } protected async exitIfLocked() { - await this.exitIfNotAuthed(); + const userId = await this.exitIfNotAuthed(); if (await this.serviceContainer.cryptoService.hasUserKey()) { return; } else if (process.env.BW_NOINTERACTION !== "true") { // must unlock - if (await this.serviceContainer.keyConnectorService.getUsesKeyConnector()) { + if (await this.serviceContainer.keyConnectorService.getUsesKeyConnector(userId)) { const response = Response.error( "Your vault is locked. You must unlock your vault using your session key.\n" + "If you do not have your session key, you can get a new one by logging out and logging in again.", diff --git a/apps/cli/src/commands/convert-to-key-connector.command.ts b/apps/cli/src/commands/convert-to-key-connector.command.ts index 654606dc06..0dbdbb4325 100644 --- a/apps/cli/src/commands/convert-to-key-connector.command.ts +++ b/apps/cli/src/commands/convert-to-key-connector.command.ts @@ -7,6 +7,7 @@ import { EnvironmentService, Region, } from "@bitwarden/common/platform/abstractions/environment.service"; +import { UserId } from "@bitwarden/common/types/guid"; import { SyncService } from "@bitwarden/common/vault/abstractions/sync/sync.service.abstraction"; import { Response } from "../models/response"; @@ -14,6 +15,7 @@ import { MessageResponse } from "../models/response/message.response"; export class ConvertToKeyConnectorCommand { constructor( + private readonly userId: UserId, private keyConnectorService: KeyConnectorService, private environmentService: EnvironmentService, private syncService: SyncService, @@ -68,7 +70,7 @@ export class ConvertToKeyConnectorCommand { } await this.keyConnectorService.removeConvertAccountRequired(); - await this.keyConnectorService.setUsesKeyConnector(true); + await this.keyConnectorService.setUsesKeyConnector(true, this.userId); // Update environment URL - required for api key login const env = await firstValueFrom(this.environmentService.environment$); diff --git a/apps/cli/src/program.ts b/apps/cli/src/program.ts index 51c4b39e98..6ecdb24931 100644 --- a/apps/cli/src/program.ts +++ b/apps/cli/src/program.ts @@ -206,9 +206,9 @@ export class Program extends BaseProgram { writeLn("", true); }) .action(async (cmd) => { - await this.exitIfNotAuthed(); + const userId = await this.exitIfNotAuthed(); - if (await this.serviceContainer.keyConnectorService.getUsesKeyConnector()) { + if (await this.serviceContainer.keyConnectorService.getUsesKeyConnector(userId)) { const logoutCommand = new LogoutCommand( this.serviceContainer.authService, this.serviceContainer.i18nService, diff --git a/apps/web/src/app/auth/login/login.component.ts b/apps/web/src/app/auth/login/login.component.ts index d0a4376556..145d766627 100644 --- a/apps/web/src/app/auth/login/login.component.ts +++ b/apps/web/src/app/auth/login/login.component.ts @@ -28,6 +28,7 @@ import { LogService } from "@bitwarden/common/platform/abstractions/log.service" import { PlatformUtilsService } from "@bitwarden/common/platform/abstractions/platform-utils.service"; import { StateService } from "@bitwarden/common/platform/abstractions/state.service"; import { PasswordStrengthServiceAbstraction } from "@bitwarden/common/tools/password-strength"; +import { UserId } from "@bitwarden/common/types/guid"; import { PasswordGenerationServiceAbstraction } from "@bitwarden/generator-legacy"; import { flagEnabled } from "../../../utils/flags"; @@ -129,7 +130,7 @@ export class LoginComponent extends BaseLoginComponent implements OnInit { } } - async goAfterLogIn() { + async goAfterLogIn(userId: UserId) { const masterPassword = this.formGroup.value.masterPassword; // Check master password against policy @@ -150,7 +151,7 @@ export class LoginComponent extends BaseLoginComponent implements OnInit { ) { const policiesData: { [id: string]: PolicyData } = {}; this.policies.map((p) => (policiesData[p.id] = PolicyData.fromPolicy(p))); - await this.policyService.replace(policiesData); + await this.policyService.replace(policiesData, userId); await this.router.navigate(["update-password"]); return; } diff --git a/libs/angular/src/auth/components/login.component.ts b/libs/angular/src/auth/components/login.component.ts index 057d67b152..40880b514a 100644 --- a/libs/angular/src/auth/components/login.component.ts +++ b/libs/angular/src/auth/components/login.component.ts @@ -23,6 +23,7 @@ import { LogService } from "@bitwarden/common/platform/abstractions/log.service" import { PlatformUtilsService } from "@bitwarden/common/platform/abstractions/platform-utils.service"; import { StateService } from "@bitwarden/common/platform/abstractions/state.service"; import { Utils } from "@bitwarden/common/platform/misc/utils"; +import { UserId } from "@bitwarden/common/types/guid"; import { PasswordGenerationServiceAbstraction } from "@bitwarden/generator-legacy"; import { @@ -39,7 +40,7 @@ export class LoginComponent extends CaptchaProtectedComponent implements OnInit, showPassword = false; formPromise: Promise; onSuccessfulLogin: () => Promise; - onSuccessfulLoginNavigate: () => Promise; + onSuccessfulLoginNavigate: (userId: UserId) => Promise; onSuccessfulLoginTwoFactorNavigate: () => Promise; onSuccessfulLoginForceResetNavigate: () => Promise; showLoginWithDevice: boolean; @@ -185,7 +186,7 @@ export class LoginComponent extends CaptchaProtectedComponent implements OnInit, if (this.onSuccessfulLoginNavigate != null) { // FIXME: Verify that this floating promise is intentional. If it is, add an explanatory comment and ensure there is proper error handling. // eslint-disable-next-line @typescript-eslint/no-floating-promises - this.onSuccessfulLoginNavigate(); + this.onSuccessfulLoginNavigate(response.userId); } else { this.loginEmailService.clearValues(); // FIXME: Verify that this floating promise is intentional. If it is, add an explanatory comment and ensure there is proper error handling. diff --git a/libs/auth/src/common/login-strategies/auth-request-login.strategy.spec.ts b/libs/auth/src/common/login-strategies/auth-request-login.strategy.spec.ts index 9e9efa12ba..b112e5aa2a 100644 --- a/libs/auth/src/common/login-strategies/auth-request-login.strategy.spec.ts +++ b/libs/auth/src/common/login-strategies/auth-request-login.strategy.spec.ts @@ -158,7 +158,10 @@ describe("AuthRequestLoginStrategy", () => { decMasterKeyHash, mockUserId, ); - expect(cryptoService.setMasterKeyEncryptedUserKey).toHaveBeenCalledWith(tokenResponse.key); + expect(cryptoService.setMasterKeyEncryptedUserKey).toHaveBeenCalledWith( + tokenResponse.key, + mockUserId, + ); expect(cryptoService.setUserKey).toHaveBeenCalledWith(userKey, mockUserId); expect(deviceTrustService.trustDeviceIfRequired).toHaveBeenCalled(); expect(cryptoService.setPrivateKey).toHaveBeenCalledWith(tokenResponse.privateKey, mockUserId); @@ -183,7 +186,10 @@ describe("AuthRequestLoginStrategy", () => { expect(masterPasswordService.mock.setMasterKeyHash).not.toHaveBeenCalled(); // setMasterKeyEncryptedUserKey, setUserKey, and setPrivateKey should still be called - expect(cryptoService.setMasterKeyEncryptedUserKey).toHaveBeenCalledWith(tokenResponse.key); + expect(cryptoService.setMasterKeyEncryptedUserKey).toHaveBeenCalledWith( + tokenResponse.key, + mockUserId, + ); expect(cryptoService.setUserKey).toHaveBeenCalledWith(decUserKey, mockUserId); expect(cryptoService.setPrivateKey).toHaveBeenCalledWith(tokenResponse.privateKey, mockUserId); diff --git a/libs/auth/src/common/login-strategies/auth-request-login.strategy.ts b/libs/auth/src/common/login-strategies/auth-request-login.strategy.ts index 9998abb30d..ae0024d218 100644 --- a/libs/auth/src/common/login-strategies/auth-request-login.strategy.ts +++ b/libs/auth/src/common/login-strategies/auth-request-login.strategy.ts @@ -99,7 +99,7 @@ export class AuthRequestLoginStrategy extends LoginStrategy { const authRequestCredentials = this.cache.value.authRequestCredentials; // User now may or may not have a master password // but set the master key encrypted user key if it exists regardless - await this.cryptoService.setMasterKeyEncryptedUserKey(response.key); + await this.cryptoService.setMasterKeyEncryptedUserKey(response.key, userId); if (authRequestCredentials.decryptedUserKey) { await this.cryptoService.setUserKey(authRequestCredentials.decryptedUserKey, userId); diff --git a/libs/auth/src/common/login-strategies/login.strategy.ts b/libs/auth/src/common/login-strategies/login.strategy.ts index 2065f898be..ff6bf07af7 100644 --- a/libs/auth/src/common/login-strategies/login.strategy.ts +++ b/libs/auth/src/common/login-strategies/login.strategy.ts @@ -222,7 +222,11 @@ export abstract class LoginStrategy { ), ); - await this.billingAccountProfileStateService.setHasPremium(accountInformation.premium, false); + await this.billingAccountProfileStateService.setHasPremium( + accountInformation.premium, + false, + userId, + ); return userId; } diff --git a/libs/auth/src/common/login-strategies/user-api-login.strategy.spec.ts b/libs/auth/src/common/login-strategies/user-api-login.strategy.spec.ts index 6b9cddd99c..1661449796 100644 --- a/libs/auth/src/common/login-strategies/user-api-login.strategy.spec.ts +++ b/libs/auth/src/common/login-strategies/user-api-login.strategy.spec.ts @@ -172,7 +172,10 @@ describe("UserApiLoginStrategy", () => { await apiLogInStrategy.logIn(credentials); - expect(cryptoService.setMasterKeyEncryptedUserKey).toHaveBeenCalledWith(tokenResponse.key); + expect(cryptoService.setMasterKeyEncryptedUserKey).toHaveBeenCalledWith( + tokenResponse.key, + userId, + ); expect(cryptoService.setPrivateKey).toHaveBeenCalledWith(tokenResponse.privateKey, userId); }); diff --git a/libs/auth/src/common/login-strategies/user-api-login.strategy.ts b/libs/auth/src/common/login-strategies/user-api-login.strategy.ts index 1faac3f6c7..3b112c79a0 100644 --- a/libs/auth/src/common/login-strategies/user-api-login.strategy.ts +++ b/libs/auth/src/common/login-strategies/user-api-login.strategy.ts @@ -64,7 +64,7 @@ export class UserApiLoginStrategy extends LoginStrategy { response: IdentityTokenResponse, userId: UserId, ): Promise { - await this.cryptoService.setMasterKeyEncryptedUserKey(response.key); + await this.cryptoService.setMasterKeyEncryptedUserKey(response.key, userId); if (response.apiUseKeyConnector) { const masterKey = await firstValueFrom(this.masterPasswordService.masterKey$(userId)); diff --git a/libs/common/spec/fake-state-provider.ts b/libs/common/spec/fake-state-provider.ts index cd868931f2..666487ecf0 100644 --- a/libs/common/spec/fake-state-provider.ts +++ b/libs/common/spec/fake-state-provider.ts @@ -32,7 +32,7 @@ export class FakeGlobalStateProvider implements GlobalStateProvider { states: Map> = new Map(); get(keyDefinition: KeyDefinition): GlobalState { this.mock.get(keyDefinition); - const cacheKey = `${keyDefinition.fullName}_${keyDefinition.stateDefinition.defaultStorageLocation}`; + const cacheKey = this.cacheKey(keyDefinition); let result = this.states.get(cacheKey); if (result == null) { @@ -53,94 +53,143 @@ export class FakeGlobalStateProvider implements GlobalStateProvider { return result as GlobalState; } + private cacheKey(keyDefinition: KeyDefinition) { + return `${keyDefinition.fullName}_${keyDefinition.stateDefinition.defaultStorageLocation}`; + } + getFake(keyDefinition: KeyDefinition): FakeGlobalState { return this.get(keyDefinition) as FakeGlobalState; } - mockFor(keyDefinitionKey: string, initialValue?: T): FakeGlobalState { - if (!this.establishedMocks.has(keyDefinitionKey)) { - this.establishedMocks.set(keyDefinitionKey, new FakeGlobalState(initialValue)); + mockFor(keyDefinition: KeyDefinition, initialValue?: T): FakeGlobalState { + const cacheKey = this.cacheKey(keyDefinition); + if (!this.states.has(cacheKey)) { + this.states.set(cacheKey, new FakeGlobalState(initialValue)); } - return this.establishedMocks.get(keyDefinitionKey) as FakeGlobalState; + return this.states.get(cacheKey) as FakeGlobalState; } } export class FakeSingleUserStateProvider implements SingleUserStateProvider { mock = mock(); - establishedMocks: Map> = new Map(); states: Map> = new Map(); + + constructor( + readonly updateSyncCallback?: ( + key: UserKeyDefinition, + userId: UserId, + newValue: unknown, + ) => Promise, + ) {} + get(userId: UserId, userKeyDefinition: UserKeyDefinition): SingleUserState { this.mock.get(userId, userKeyDefinition); - const cacheKey = `${userKeyDefinition.fullName}_${userKeyDefinition.stateDefinition.defaultStorageLocation}_${userId}`; + const cacheKey = this.cacheKey(userId, userKeyDefinition); let result = this.states.get(cacheKey); if (result == null) { - let fake: FakeSingleUserState; - // Look for established mock - if (this.establishedMocks.has(userKeyDefinition.key)) { - fake = this.establishedMocks.get(userKeyDefinition.key) as FakeSingleUserState; - } else { - fake = new FakeSingleUserState(userId); - } - fake.keyDefinition = userKeyDefinition; - result = fake; + result = this.buildFakeState(userId, userKeyDefinition); this.states.set(cacheKey, result); } return result as SingleUserState; } - getFake(userId: UserId, userKeyDefinition: UserKeyDefinition): FakeSingleUserState { + getFake( + userId: UserId, + userKeyDefinition: UserKeyDefinition, + { allowInit }: { allowInit: boolean } = { allowInit: true }, + ): FakeSingleUserState { + if (!allowInit && this.states.get(this.cacheKey(userId, userKeyDefinition)) == null) { + return null; + } + return this.get(userId, userKeyDefinition) as FakeSingleUserState; } - mockFor(userId: UserId, keyDefinitionKey: string, initialValue?: T): FakeSingleUserState { - if (!this.establishedMocks.has(keyDefinitionKey)) { - this.establishedMocks.set(keyDefinitionKey, new FakeSingleUserState(userId, initialValue)); + mockFor( + userId: UserId, + userKeyDefinition: UserKeyDefinition, + initialValue?: T, + ): FakeSingleUserState { + const cacheKey = this.cacheKey(userId, userKeyDefinition); + if (!this.states.has(cacheKey)) { + this.states.set(cacheKey, this.buildFakeState(userId, userKeyDefinition, initialValue)); } - return this.establishedMocks.get(keyDefinitionKey) as FakeSingleUserState; + return this.states.get(cacheKey) as FakeSingleUserState; + } + + private buildFakeState( + userId: UserId, + userKeyDefinition: UserKeyDefinition, + initialValue?: T, + ) { + const state = new FakeSingleUserState(userId, initialValue, async (...args) => { + await this.updateSyncCallback?.(userKeyDefinition, ...args); + }); + state.keyDefinition = userKeyDefinition; + return state; + } + + private cacheKey(userId: UserId, userKeyDefinition: UserKeyDefinition) { + return `${userKeyDefinitionCacheKey(userKeyDefinition)}_${userId}`; } } export class FakeActiveUserStateProvider implements ActiveUserStateProvider { activeUserId$: Observable; - establishedMocks: Map> = new Map(); - states: Map> = new Map(); - constructor(public accountService: FakeAccountService) { + constructor( + public accountService: FakeAccountService, + readonly updateSyncCallback?: ( + key: UserKeyDefinition, + userId: UserId, + newValue: unknown, + ) => Promise, + ) { this.activeUserId$ = accountService.activeAccountSubject.asObservable().pipe(map((a) => a?.id)); } get(userKeyDefinition: UserKeyDefinition): ActiveUserState { - const cacheKey = `${userKeyDefinition.fullName}_${userKeyDefinition.stateDefinition.defaultStorageLocation}`; + const cacheKey = userKeyDefinitionCacheKey(userKeyDefinition); let result = this.states.get(cacheKey); if (result == null) { - // Look for established mock - if (this.establishedMocks.has(userKeyDefinition.key)) { - result = this.establishedMocks.get(userKeyDefinition.key); - } else { - result = new FakeActiveUserState(this.accountService); - } - result.keyDefinition = userKeyDefinition; + result = this.buildFakeState(userKeyDefinition); this.states.set(cacheKey, result); } return result as ActiveUserState; } - getFake(userKeyDefinition: UserKeyDefinition): FakeActiveUserState { + getFake( + userKeyDefinition: UserKeyDefinition, + { allowInit }: { allowInit: boolean } = { allowInit: true }, + ): FakeActiveUserState { + if (!allowInit && this.states.get(userKeyDefinitionCacheKey(userKeyDefinition)) == null) { + return null; + } return this.get(userKeyDefinition) as FakeActiveUserState; } - mockFor(keyDefinitionKey: string, initialValue?: T): FakeActiveUserState { - if (!this.establishedMocks.has(keyDefinitionKey)) { - this.establishedMocks.set( - keyDefinitionKey, - new FakeActiveUserState(this.accountService, initialValue), - ); + mockFor(userKeyDefinition: UserKeyDefinition, initialValue?: T): FakeActiveUserState { + const cacheKey = userKeyDefinitionCacheKey(userKeyDefinition); + if (!this.states.has(cacheKey)) { + this.states.set(cacheKey, this.buildFakeState(userKeyDefinition, initialValue)); } - return this.establishedMocks.get(keyDefinitionKey) as FakeActiveUserState; + return this.states.get(cacheKey) as FakeActiveUserState; } + + private buildFakeState(userKeyDefinition: UserKeyDefinition, initialValue?: T) { + const state = new FakeActiveUserState(this.accountService, initialValue, async (...args) => { + await this.updateSyncCallback?.(userKeyDefinition, ...args); + }); + state.keyDefinition = userKeyDefinition; + return state; + } +} + +function userKeyDefinitionCacheKey(userKeyDefinition: UserKeyDefinition) { + return `${userKeyDefinition.fullName}_${userKeyDefinition.stateDefinition.defaultStorageLocation}`; } export class FakeStateProvider implements StateProvider { @@ -207,9 +256,35 @@ export class FakeStateProvider implements StateProvider { constructor(public accountService: FakeAccountService) {} + private distributeSingleUserUpdate( + key: UserKeyDefinition, + userId: UserId, + newState: unknown, + ) { + if (this.activeUser.accountService.activeUserId === userId) { + const state = this.activeUser.getFake(key, { allowInit: false }); + state?.nextState(newState, { syncValue: false }); + } + } + + private distributeActiveUserUpdate( + key: UserKeyDefinition, + userId: UserId, + newState: unknown, + ) { + this.singleUser + .getFake(userId, key, { allowInit: false }) + ?.nextState(newState, { syncValue: false }); + } + global: FakeGlobalStateProvider = new FakeGlobalStateProvider(); - singleUser: FakeSingleUserStateProvider = new FakeSingleUserStateProvider(); - activeUser: FakeActiveUserStateProvider = new FakeActiveUserStateProvider(this.accountService); + singleUser: FakeSingleUserStateProvider = new FakeSingleUserStateProvider( + this.distributeSingleUserUpdate.bind(this), + ); + activeUser: FakeActiveUserStateProvider = new FakeActiveUserStateProvider( + this.accountService, + this.distributeActiveUserUpdate.bind(this), + ); derived: FakeDerivedStateProvider = new FakeDerivedStateProvider(); activeUserId$: Observable = this.activeUser.activeUserId$; } diff --git a/libs/common/spec/fake-state.ts b/libs/common/spec/fake-state.ts index 0f2a09d9c1..2400e470d4 100644 --- a/libs/common/spec/fake-state.ts +++ b/libs/common/spec/fake-state.ts @@ -1,4 +1,4 @@ -import { Observable, ReplaySubject, concatMap, firstValueFrom, map, timeout } from "rxjs"; +import { Observable, ReplaySubject, concatMap, filter, firstValueFrom, map, timeout } from "rxjs"; import { DerivedState, @@ -41,6 +41,10 @@ export class FakeGlobalState implements GlobalState { this.stateSubject.next(initialValue ?? null); } + nextState(state: T) { + this.stateSubject.next(state); + } + async update( configureState: (state: T, dependency: TCombine) => T, options?: StateUpdateOptions, @@ -89,7 +93,10 @@ export class FakeGlobalState implements GlobalState { export class FakeSingleUserState implements SingleUserState { // eslint-disable-next-line rxjs/no-exposed-subjects -- exposed for testing setup - stateSubject = new ReplaySubject>(1); + stateSubject = new ReplaySubject<{ + syncValue: boolean; + combinedState: CombinedState; + }>(1); state$: Observable; combinedState$: Observable>; @@ -97,15 +104,28 @@ export class FakeSingleUserState implements SingleUserState { constructor( readonly userId: UserId, initialValue?: T, + updateSyncCallback?: (userId: UserId, newValue: T) => Promise, ) { - this.stateSubject.next([userId, initialValue ?? null]); + // Inform the state provider of updates to keep active user states in sync + this.stateSubject + .pipe( + filter((next) => next.syncValue), + concatMap(async ({ combinedState }) => { + await updateSyncCallback?.(...combinedState); + }), + ) + .subscribe(); + this.nextState(initialValue ?? null, { syncValue: initialValue != null }); - this.combinedState$ = this.stateSubject.asObservable(); + this.combinedState$ = this.stateSubject.pipe(map((v) => v.combinedState)); this.state$ = this.combinedState$.pipe(map(([_userId, state]) => state)); } - nextState(state: T) { - this.stateSubject.next([this.userId, state]); + nextState(state: T, { syncValue }: { syncValue: boolean } = { syncValue: true }) { + this.stateSubject.next({ + syncValue, + combinedState: [this.userId, state], + }); } async update( @@ -122,7 +142,7 @@ export class FakeSingleUserState implements SingleUserState { return current; } const newState = configureState(current, combinedDependencies); - this.stateSubject.next([this.userId, newState]); + this.nextState(newState); this.nextMock(newState); return newState; } @@ -146,7 +166,10 @@ export class FakeActiveUserState implements ActiveUserState { [activeMarker]: true; // eslint-disable-next-line rxjs/no-exposed-subjects -- exposed for testing setup - stateSubject = new ReplaySubject>(1); + stateSubject = new ReplaySubject<{ + syncValue: boolean; + combinedState: CombinedState; + }>(1); state$: Observable; combinedState$: Observable>; @@ -154,10 +177,18 @@ export class FakeActiveUserState implements ActiveUserState { constructor( private accountService: FakeAccountService, initialValue?: T, + updateSyncCallback?: (userId: UserId, newValue: T) => Promise, ) { - this.stateSubject.next([accountService.activeUserId, initialValue ?? null]); + // Inform the state provider of updates to keep single user states in sync + this.stateSubject.pipe( + filter((next) => next.syncValue), + concatMap(async ({ combinedState }) => { + await updateSyncCallback?.(...combinedState); + }), + ); + this.nextState(initialValue ?? null, { syncValue: initialValue != null }); - this.combinedState$ = this.stateSubject.asObservable(); + this.combinedState$ = this.stateSubject.pipe(map((v) => v.combinedState)); this.state$ = this.combinedState$.pipe(map(([_userId, state]) => state)); } @@ -165,8 +196,11 @@ export class FakeActiveUserState implements ActiveUserState { return this.accountService.activeUserId; } - nextState(state: T) { - this.stateSubject.next([this.userId, state]); + nextState(state: T, { syncValue }: { syncValue: boolean } = { syncValue: true }) { + this.stateSubject.next({ + syncValue, + combinedState: [this.userId, state], + }); } async update( @@ -183,7 +217,7 @@ export class FakeActiveUserState implements ActiveUserState { return [this.userId, current]; } const newState = configureState(current, combinedDependencies); - this.stateSubject.next([this.userId, newState]); + this.nextState(newState); this.nextMock([this.userId, newState]); return [this.userId, newState]; } diff --git a/libs/common/src/admin-console/abstractions/policy/policy.service.abstraction.ts b/libs/common/src/admin-console/abstractions/policy/policy.service.abstraction.ts index 21669f78ad..1067c24234 100644 --- a/libs/common/src/admin-console/abstractions/policy/policy.service.abstraction.ts +++ b/libs/common/src/admin-console/abstractions/policy/policy.service.abstraction.ts @@ -77,5 +77,5 @@ export abstract class PolicyService { export abstract class InternalPolicyService extends PolicyService { upsert: (policy: PolicyData) => Promise; - replace: (policies: { [id: string]: PolicyData }) => Promise; + replace: (policies: { [id: string]: PolicyData }, userId: UserId) => Promise; } diff --git a/libs/common/src/admin-console/services/policy/policy.service.spec.ts b/libs/common/src/admin-console/services/policy/policy.service.spec.ts index 88264d1c3b..d9802db9e3 100644 --- a/libs/common/src/admin-console/services/policy/policy.service.spec.ts +++ b/libs/common/src/admin-console/services/policy/policy.service.spec.ts @@ -20,6 +20,7 @@ import { POLICIES, PolicyService } from "../../../admin-console/services/policy/ import { PolicyId, UserId } from "../../../types/guid"; describe("PolicyService", () => { + const userId = "userId" as UserId; let stateProvider: FakeStateProvider; let organizationService: MockProxy; let activeUserState: FakeActiveUserState>; @@ -27,7 +28,7 @@ describe("PolicyService", () => { let policyService: PolicyService; beforeEach(() => { - const accountService = mockAccountServiceWith("userId" as UserId); + const accountService = mockAccountServiceWith(userId); stateProvider = new FakeStateProvider(accountService); organizationService = mock(); @@ -95,9 +96,12 @@ describe("PolicyService", () => { ]), ); - await policyService.replace({ - "2": policyData("2", "test-organization", PolicyType.DisableSend, true), - }); + await policyService.replace( + { + "2": policyData("2", "test-organization", PolicyType.DisableSend, true), + }, + userId, + ); expect(await firstValueFrom(policyService.policies$)).toEqual([ { diff --git a/libs/common/src/admin-console/services/policy/policy.service.ts b/libs/common/src/admin-console/services/policy/policy.service.ts index 2287ef9b4f..f52d061ad9 100644 --- a/libs/common/src/admin-console/services/policy/policy.service.ts +++ b/libs/common/src/admin-console/services/policy/policy.service.ts @@ -219,8 +219,8 @@ export class PolicyService implements InternalPolicyServiceAbstraction { }); } - async replace(policies: { [id: string]: PolicyData }): Promise { - await this.activeUserPolicyState.update(() => policies); + async replace(policies: { [id: string]: PolicyData }, userId: UserId): Promise { + await this.stateProvider.setUserState(POLICIES, policies, userId); } /** diff --git a/libs/common/src/auth/abstractions/key-connector.service.ts b/libs/common/src/auth/abstractions/key-connector.service.ts index b1b6727cd1..26335ced48 100644 --- a/libs/common/src/auth/abstractions/key-connector.service.ts +++ b/libs/common/src/auth/abstractions/key-connector.service.ts @@ -4,17 +4,17 @@ import { IdentityTokenResponse } from "../models/response/identity-token.respons export abstract class KeyConnectorService { setMasterKeyFromUrl: (url: string, userId: UserId) => Promise; - getManagingOrganization: () => Promise; - getUsesKeyConnector: () => Promise; - migrateUser: () => Promise; - userNeedsMigration: () => Promise; + getManagingOrganization: (userId?: UserId) => Promise; + getUsesKeyConnector: (userId: UserId) => Promise; + migrateUser: (userId?: UserId) => Promise; + userNeedsMigration: (userId: UserId) => Promise; convertNewSsoUserToKeyConnector: ( tokenResponse: IdentityTokenResponse, orgId: string, userId: UserId, ) => Promise; - setUsesKeyConnector: (enabled: boolean) => Promise; - setConvertAccountRequired: (status: boolean) => Promise; + setUsesKeyConnector: (enabled: boolean, userId: UserId) => Promise; + setConvertAccountRequired: (status: boolean, userId?: UserId) => Promise; getConvertAccountRequired: () => Promise; - removeConvertAccountRequired: () => Promise; + removeConvertAccountRequired: (userId?: UserId) => Promise; } diff --git a/libs/common/src/auth/abstractions/token.service.ts b/libs/common/src/auth/abstractions/token.service.ts index c86b5f1ee3..9239a0db54 100644 --- a/libs/common/src/auth/abstractions/token.service.ts +++ b/libs/common/src/auth/abstractions/token.service.ts @@ -148,10 +148,11 @@ export abstract class TokenService { /** * Decodes the access token. - * @param token The access token to decode. + * @param tokenOrUserId The access token to decode or the user id to retrieve the access token for, and then decode. + * If null, the currently active user's token is used. * @returns A promise that resolves with the decoded access token. */ - decodeAccessToken: (token?: string) => Promise; + decodeAccessToken: (tokenOrUserId?: string | UserId) => Promise; /** * Gets the expiration date for the access token. Returns if token can't be decoded or has no expiration @@ -212,9 +213,10 @@ export abstract class TokenService { /** * Gets whether or not the user authenticated via an external mechanism. + * @param userId The optional user id to check for external authN status; if not provided, the active user is used. * @returns A promise that resolves with a boolean representing the user's external authN status. */ - getIsExternal: () => Promise; + getIsExternal: (userId: UserId) => Promise; /** Gets the active or passed in user's security stamp */ getSecurityStamp: (userId?: UserId) => Promise; diff --git a/libs/common/src/auth/services/key-connector.service.spec.ts b/libs/common/src/auth/services/key-connector.service.spec.ts index 0fc0267a53..5d1aff45f6 100644 --- a/libs/common/src/auth/services/key-connector.service.spec.ts +++ b/libs/common/src/auth/services/key-connector.service.spec.ts @@ -78,9 +78,9 @@ describe("KeyConnectorService", () => { const newValue = true; - await keyConnectorService.setUsesKeyConnector(newValue); + await keyConnectorService.setUsesKeyConnector(newValue, mockUserId); - expect(await keyConnectorService.getUsesKeyConnector()).toBe(newValue); + expect(await keyConnectorService.getUsesKeyConnector(mockUserId)).toBe(newValue); }); }); @@ -185,7 +185,7 @@ describe("KeyConnectorService", () => { const state = stateProvider.activeUser.getFake(USES_KEY_CONNECTOR); state.nextState(false); - const result = await keyConnectorService.userNeedsMigration(); + const result = await keyConnectorService.userNeedsMigration(mockUserId); expect(result).toBe(true); }); @@ -197,7 +197,7 @@ describe("KeyConnectorService", () => { const state = stateProvider.activeUser.getFake(USES_KEY_CONNECTOR); state.nextState(true); - const result = await keyConnectorService.userNeedsMigration(); + const result = await keyConnectorService.userNeedsMigration(mockUserId); expect(result).toBe(false); }); diff --git a/libs/common/src/auth/services/key-connector.service.ts b/libs/common/src/auth/services/key-connector.service.ts index 8f204e557e..ad9b7081cd 100644 --- a/libs/common/src/auth/services/key-connector.service.ts +++ b/libs/common/src/auth/services/key-connector.service.ts @@ -69,25 +69,25 @@ export class KeyConnectorService implements KeyConnectorServiceAbstraction { ); } - async setUsesKeyConnector(usesKeyConnector: boolean) { - await this.usesKeyConnectorState.update(() => usesKeyConnector); + async setUsesKeyConnector(usesKeyConnector: boolean, userId: UserId) { + await this.stateProvider.getUser(userId, USES_KEY_CONNECTOR).update(() => usesKeyConnector); } - getUsesKeyConnector(): Promise { - return firstValueFrom(this.usesKeyConnectorState.state$); + getUsesKeyConnector(userId: UserId): Promise { + return firstValueFrom(this.stateProvider.getUserState$(USES_KEY_CONNECTOR, userId)); } - async userNeedsMigration() { - const loggedInUsingSso = await this.tokenService.getIsExternal(); - const requiredByOrganization = (await this.getManagingOrganization()) != null; - const userIsNotUsingKeyConnector = !(await this.getUsesKeyConnector()); + async userNeedsMigration(userId: UserId) { + const loggedInUsingSso = await this.tokenService.getIsExternal(userId); + const requiredByOrganization = (await this.getManagingOrganization(userId)) != null; + const userIsNotUsingKeyConnector = !(await this.getUsesKeyConnector(userId)); return loggedInUsingSso && requiredByOrganization && userIsNotUsingKeyConnector; } - async migrateUser() { - const organization = await this.getManagingOrganization(); - const userId = (await firstValueFrom(this.accountService.activeAccount$))?.id; + async migrateUser(userId?: UserId) { + userId ??= (await firstValueFrom(this.accountService.activeAccount$))?.id; + const organization = await this.getManagingOrganization(userId); const masterKey = await firstValueFrom(this.masterPasswordService.masterKey$(userId)); const keyConnectorRequest = new KeyConnectorUserKeyRequest(masterKey.encKeyB64); @@ -115,8 +115,8 @@ export class KeyConnectorService implements KeyConnectorServiceAbstraction { } } - async getManagingOrganization(): Promise { - const orgs = await this.organizationService.getAll(); + async getManagingOrganization(userId?: UserId): Promise { + const orgs = await this.organizationService.getAll(userId); return orgs.find( (o) => o.keyConnectorEnabled && @@ -178,16 +178,16 @@ export class KeyConnectorService implements KeyConnectorServiceAbstraction { await this.apiService.postSetKeyConnectorKey(setPasswordRequest); } - async setConvertAccountRequired(status: boolean) { - await this.convertAccountToKeyConnectorState.update(() => status); + async setConvertAccountRequired(status: boolean, userId?: UserId) { + await this.stateProvider.setUserState(CONVERT_ACCOUNT_TO_KEY_CONNECTOR, status, userId); } getConvertAccountRequired(): Promise { return firstValueFrom(this.convertAccountToKeyConnectorState.state$); } - async removeConvertAccountRequired() { - await this.setConvertAccountRequired(null); + async removeConvertAccountRequired(userId?: UserId) { + await this.setConvertAccountRequired(null, userId); } private handleKeyConnectorError(e: any) { diff --git a/libs/common/src/auth/services/token.service.spec.ts b/libs/common/src/auth/services/token.service.spec.ts index 4be945de5f..f8882e1b11 100644 --- a/libs/common/src/auth/services/token.service.spec.ts +++ b/libs/common/src/auth/services/token.service.spec.ts @@ -126,7 +126,7 @@ describe("TokenService", () => { // Arrange singleUserStateProvider .getFake(userIdFromAccessToken, ACCESS_TOKEN_MEMORY) - .stateSubject.next([userIdFromAccessToken, accessTokenJwt]); + .nextState(accessTokenJwt); // Act const result = await firstValueFrom(tokenService.hasAccessToken$(userIdFromAccessToken)); @@ -139,11 +139,11 @@ describe("TokenService", () => { // Arrange singleUserStateProvider .getFake(userIdFromAccessToken, ACCESS_TOKEN_MEMORY) - .stateSubject.next([userIdFromAccessToken, undefined]); + .nextState(undefined); singleUserStateProvider .getFake(userIdFromAccessToken, ACCESS_TOKEN_DISK) - .stateSubject.next([userIdFromAccessToken, accessTokenJwt]); + .nextState(accessTokenJwt); // Act const result = await firstValueFrom(tokenService.hasAccessToken$(userIdFromAccessToken)); @@ -156,7 +156,7 @@ describe("TokenService", () => { // Arrange singleUserStateProvider .getFake(userIdFromAccessToken, ACCESS_TOKEN_DISK) - .stateSubject.next([userIdFromAccessToken, "encryptedAccessToken"]); + .nextState("encryptedAccessToken"); secureStorageService.get.mockResolvedValue(accessTokenKeyB64); @@ -282,7 +282,7 @@ describe("TokenService", () => { // For testing purposes, let's assume that the access token is already in memory singleUserStateProvider .getFake(userIdFromAccessToken, ACCESS_TOKEN_MEMORY) - .stateSubject.next([userIdFromAccessToken, accessTokenJwt]); + .nextState(accessTokenJwt); keyGenerationService.createKey.mockResolvedValue(accessTokenKey); @@ -411,9 +411,7 @@ describe("TokenService", () => { it("returns null when no access token is found in memory, disk, or secure storage", async () => { // Arrange - globalStateProvider - .getFake(ACCOUNT_ACTIVE_ACCOUNT_ID) - .stateSubject.next(userIdFromAccessToken); + globalStateProvider.getFake(ACCOUNT_ACTIVE_ACCOUNT_ID).nextState(userIdFromAccessToken); // Act const result = await tokenService.getAccessToken(); @@ -429,18 +427,16 @@ describe("TokenService", () => { // Arrange singleUserStateProvider .getFake(userIdFromAccessToken, ACCESS_TOKEN_MEMORY) - .stateSubject.next([userIdFromAccessToken, accessTokenJwt]); + .nextState(accessTokenJwt); // set disk to undefined singleUserStateProvider .getFake(userIdFromAccessToken, ACCESS_TOKEN_DISK) - .stateSubject.next([userIdFromAccessToken, undefined]); + .nextState(undefined); // Need to have global active id set to the user id if (!userId) { - globalStateProvider - .getFake(ACCOUNT_ACTIVE_ACCOUNT_ID) - .stateSubject.next(userIdFromAccessToken); + globalStateProvider.getFake(ACCOUNT_ACTIVE_ACCOUNT_ID).nextState(userIdFromAccessToken); } // Act @@ -459,17 +455,15 @@ describe("TokenService", () => { // Arrange singleUserStateProvider .getFake(userIdFromAccessToken, ACCESS_TOKEN_MEMORY) - .stateSubject.next([userIdFromAccessToken, undefined]); + .nextState(undefined); singleUserStateProvider .getFake(userIdFromAccessToken, ACCESS_TOKEN_DISK) - .stateSubject.next([userIdFromAccessToken, accessTokenJwt]); + .nextState(accessTokenJwt); // Need to have global active id set to the user id if (!userId) { - globalStateProvider - .getFake(ACCOUNT_ACTIVE_ACCOUNT_ID) - .stateSubject.next(userIdFromAccessToken); + globalStateProvider.getFake(ACCOUNT_ACTIVE_ACCOUNT_ID).nextState(userIdFromAccessToken); } // Act @@ -498,20 +492,18 @@ describe("TokenService", () => { // Arrange singleUserStateProvider .getFake(userIdFromAccessToken, ACCESS_TOKEN_MEMORY) - .stateSubject.next([userIdFromAccessToken, undefined]); + .nextState(undefined); singleUserStateProvider .getFake(userIdFromAccessToken, ACCESS_TOKEN_DISK) - .stateSubject.next([userIdFromAccessToken, "encryptedAccessToken"]); + .nextState("encryptedAccessToken"); secureStorageService.get.mockResolvedValue(accessTokenKeyB64); encryptService.decryptToUtf8.mockResolvedValue("decryptedAccessToken"); // Need to have global active id set to the user id if (!userId) { - globalStateProvider - .getFake(ACCOUNT_ACTIVE_ACCOUNT_ID) - .stateSubject.next(userIdFromAccessToken); + globalStateProvider.getFake(ACCOUNT_ACTIVE_ACCOUNT_ID).nextState(userIdFromAccessToken); } // Act @@ -534,17 +526,15 @@ describe("TokenService", () => { // Arrange singleUserStateProvider .getFake(userIdFromAccessToken, ACCESS_TOKEN_MEMORY) - .stateSubject.next([userIdFromAccessToken, undefined]); + .nextState(undefined); singleUserStateProvider .getFake(userIdFromAccessToken, ACCESS_TOKEN_DISK) - .stateSubject.next([userIdFromAccessToken, accessTokenJwt]); + .nextState(accessTokenJwt); // Need to have global active id set to the user id if (!userId) { - globalStateProvider - .getFake(ACCOUNT_ACTIVE_ACCOUNT_ID) - .stateSubject.next(userIdFromAccessToken); + globalStateProvider.getFake(ACCOUNT_ACTIVE_ACCOUNT_ID).nextState(userIdFromAccessToken); } // No access token key set @@ -564,11 +554,11 @@ describe("TokenService", () => { // Arrange singleUserStateProvider .getFake(userIdFromAccessToken, ACCESS_TOKEN_MEMORY) - .stateSubject.next([userIdFromAccessToken, undefined]); + .nextState(undefined); singleUserStateProvider .getFake(userIdFromAccessToken, ACCESS_TOKEN_DISK) - .stateSubject.next([userIdFromAccessToken, encryptedAccessToken]); + .nextState(encryptedAccessToken); // No access token key set @@ -596,11 +586,11 @@ describe("TokenService", () => { // Arrange singleUserStateProvider .getFake(userIdFromAccessToken, ACCESS_TOKEN_MEMORY) - .stateSubject.next([userIdFromAccessToken, undefined]); + .nextState(undefined); singleUserStateProvider .getFake(userIdFromAccessToken, ACCESS_TOKEN_DISK) - .stateSubject.next([userIdFromAccessToken, encryptedAccessToken]); + .nextState(encryptedAccessToken); // Mock linux secure storage error const secureStorageError = "Secure storage error"; @@ -655,17 +645,15 @@ describe("TokenService", () => { // Arrange singleUserStateProvider .getFake(userIdFromAccessToken, ACCESS_TOKEN_MEMORY) - .stateSubject.next([userIdFromAccessToken, accessTokenJwt]); + .nextState(accessTokenJwt); singleUserStateProvider .getFake(userIdFromAccessToken, ACCESS_TOKEN_DISK) - .stateSubject.next([userIdFromAccessToken, accessTokenJwt]); + .nextState(accessTokenJwt); // Need to have global active id set to the user id if (!userId) { - globalStateProvider - .getFake(ACCOUNT_ACTIVE_ACCOUNT_ID) - .stateSubject.next(userIdFromAccessToken); + globalStateProvider.getFake(ACCOUNT_ACTIVE_ACCOUNT_ID).nextState(userIdFromAccessToken); } // Act @@ -688,8 +676,32 @@ describe("TokenService", () => { }); describe("decodeAccessToken", () => { + it("retrieves the requested user's token when the passed in parameter is a Guid", async () => { + // Arrange + tokenService.getAccessToken = jest.fn().mockResolvedValue(accessTokenJwt); + + // Act + const result = await tokenService.decodeAccessToken(userIdFromAccessToken); + + // Assert + expect(result).toEqual(accessTokenDecoded); + expect(tokenService.getAccessToken).toHaveBeenCalledWith(userIdFromAccessToken); + }); + + it("decodes the given token when a string is passed in that is not a Guid", async () => { + // Arrange + tokenService.getAccessToken = jest.fn(); + + // Act + const result = await tokenService.decodeAccessToken(accessTokenJwt); + + // Assert + expect(result).toEqual(accessTokenDecoded); + expect(tokenService.getAccessToken).not.toHaveBeenCalled(); + }); + it("throws an error when no access token is provided or retrievable from state", async () => { - // Access + // Arrange tokenService.getAccessToken = jest.fn().mockResolvedValue(null); // Act @@ -1194,7 +1206,7 @@ describe("TokenService", () => { // Act // note: don't await here because we want to test the error - const result = tokenService.getIsExternal(); + const result = tokenService.getIsExternal(null); // Assert await expect(result).rejects.toThrow("Failed to decode access token: Mock error"); }); @@ -1210,7 +1222,7 @@ describe("TokenService", () => { .mockResolvedValue(accessTokenDecodedWithoutExternalAmr); // Act - const result = await tokenService.getIsExternal(); + const result = await tokenService.getIsExternal(null); // Assert expect(result).toEqual(false); @@ -1227,11 +1239,22 @@ describe("TokenService", () => { .mockResolvedValue(accessTokenDecodedWithExternalAmr); // Act - const result = await tokenService.getIsExternal(); + const result = await tokenService.getIsExternal(null); // Assert expect(result).toEqual(true); }); + + it("passes the requested userId to decode", async () => { + // Arrange + tokenService.decodeAccessToken = jest.fn().mockResolvedValue(accessTokenDecoded); + + // Act + await tokenService.getIsExternal(userIdFromAccessToken); + + // Assert + expect(tokenService.decodeAccessToken).toHaveBeenCalledWith(userIdFromAccessToken); + }); }); }); }); @@ -1326,11 +1349,11 @@ describe("TokenService", () => { // For testing purposes, let's assume that the token is already in disk and memory singleUserStateProvider .getFake(userIdFromAccessToken, REFRESH_TOKEN_DISK) - .stateSubject.next([userIdFromAccessToken, refreshToken]); + .nextState(refreshToken); singleUserStateProvider .getFake(userIdFromAccessToken, REFRESH_TOKEN_MEMORY) - .stateSubject.next([userIdFromAccessToken, refreshToken]); + .nextState(refreshToken); // We immediately call to get the refresh token from secure storage after setting it to ensure it was set. secureStorageService.get.mockResolvedValue(refreshToken); @@ -1423,11 +1446,11 @@ describe("TokenService", () => { // Arrange singleUserStateProvider .getFake(userIdFromAccessToken, ACCESS_TOKEN_MEMORY) - .stateSubject.next([userIdFromAccessToken, undefined]); + .nextState(undefined); singleUserStateProvider .getFake(userIdFromAccessToken, ACCESS_TOKEN_DISK) - .stateSubject.next([userIdFromAccessToken, accessTokenJwt]); + .nextState(accessTokenJwt); // Mock linux secure storage error const secureStorageError = "Secure storage error"; @@ -1480,11 +1503,11 @@ describe("TokenService", () => { // Arrange singleUserStateProvider .getFake(userIdFromAccessToken, ACCESS_TOKEN_MEMORY) - .stateSubject.next([userIdFromAccessToken, undefined]); + .nextState(undefined); singleUserStateProvider .getFake(userIdFromAccessToken, ACCESS_TOKEN_DISK) - .stateSubject.next([userIdFromAccessToken, encryptedAccessToken]); + .nextState(encryptedAccessToken); secureStorageService.get.mockResolvedValue(accessTokenKeyB64); encryptService.decryptToUtf8.mockRejectedValue(new Error("Decryption error")); @@ -1520,9 +1543,7 @@ describe("TokenService", () => { it("returns null when no refresh token is found in memory, disk, or secure storage", async () => { // Arrange - globalStateProvider - .getFake(ACCOUNT_ACTIVE_ACCOUNT_ID) - .stateSubject.next(userIdFromAccessToken); + globalStateProvider.getFake(ACCOUNT_ACTIVE_ACCOUNT_ID).nextState(userIdFromAccessToken); // Act const result = await (tokenService as any).getRefreshToken(); @@ -1535,16 +1556,14 @@ describe("TokenService", () => { // Arrange singleUserStateProvider .getFake(userIdFromAccessToken, REFRESH_TOKEN_MEMORY) - .stateSubject.next([userIdFromAccessToken, refreshToken]); + .nextState(refreshToken); singleUserStateProvider .getFake(userIdFromAccessToken, REFRESH_TOKEN_DISK) - .stateSubject.next([userIdFromAccessToken, undefined]); + .nextState(undefined); // Need to have global active id set to the user id - globalStateProvider - .getFake(ACCOUNT_ACTIVE_ACCOUNT_ID) - .stateSubject.next(userIdFromAccessToken); + globalStateProvider.getFake(ACCOUNT_ACTIVE_ACCOUNT_ID).nextState(userIdFromAccessToken); // Act const result = await tokenService.getRefreshToken(); @@ -1557,11 +1576,11 @@ describe("TokenService", () => { // Arrange singleUserStateProvider .getFake(userIdFromAccessToken, REFRESH_TOKEN_MEMORY) - .stateSubject.next([userIdFromAccessToken, refreshToken]); + .nextState(refreshToken); singleUserStateProvider .getFake(userIdFromAccessToken, REFRESH_TOKEN_DISK) - .stateSubject.next([userIdFromAccessToken, undefined]); + .nextState(undefined); // Act const result = await tokenService.getRefreshToken(userIdFromAccessToken); @@ -1575,16 +1594,14 @@ describe("TokenService", () => { // Arrange singleUserStateProvider .getFake(userIdFromAccessToken, REFRESH_TOKEN_MEMORY) - .stateSubject.next([userIdFromAccessToken, undefined]); + .nextState(undefined); singleUserStateProvider .getFake(userIdFromAccessToken, REFRESH_TOKEN_DISK) - .stateSubject.next([userIdFromAccessToken, refreshToken]); + .nextState(refreshToken); // Need to have global active id set to the user id - globalStateProvider - .getFake(ACCOUNT_ACTIVE_ACCOUNT_ID) - .stateSubject.next(userIdFromAccessToken); + globalStateProvider.getFake(ACCOUNT_ACTIVE_ACCOUNT_ID).nextState(userIdFromAccessToken); // Act const result = await tokenService.getRefreshToken(); @@ -1596,11 +1613,11 @@ describe("TokenService", () => { // Arrange singleUserStateProvider .getFake(userIdFromAccessToken, REFRESH_TOKEN_MEMORY) - .stateSubject.next([userIdFromAccessToken, undefined]); + .nextState(undefined); singleUserStateProvider .getFake(userIdFromAccessToken, REFRESH_TOKEN_DISK) - .stateSubject.next([userIdFromAccessToken, refreshToken]); + .nextState(refreshToken); // Act const result = await tokenService.getRefreshToken(userIdFromAccessToken); @@ -1619,18 +1636,16 @@ describe("TokenService", () => { // Arrange singleUserStateProvider .getFake(userIdFromAccessToken, REFRESH_TOKEN_MEMORY) - .stateSubject.next([userIdFromAccessToken, undefined]); + .nextState(undefined); singleUserStateProvider .getFake(userIdFromAccessToken, REFRESH_TOKEN_DISK) - .stateSubject.next([userIdFromAccessToken, undefined]); + .nextState(undefined); secureStorageService.get.mockResolvedValue(refreshToken); // Need to have global active id set to the user id - globalStateProvider - .getFake(ACCOUNT_ACTIVE_ACCOUNT_ID) - .stateSubject.next(userIdFromAccessToken); + globalStateProvider.getFake(ACCOUNT_ACTIVE_ACCOUNT_ID).nextState(userIdFromAccessToken); // Act const result = await tokenService.getRefreshToken(); @@ -1643,11 +1658,11 @@ describe("TokenService", () => { singleUserStateProvider .getFake(userIdFromAccessToken, REFRESH_TOKEN_MEMORY) - .stateSubject.next([userIdFromAccessToken, undefined]); + .nextState(undefined); singleUserStateProvider .getFake(userIdFromAccessToken, REFRESH_TOKEN_DISK) - .stateSubject.next([userIdFromAccessToken, undefined]); + .nextState(undefined); secureStorageService.get.mockResolvedValue(refreshToken); @@ -1661,11 +1676,11 @@ describe("TokenService", () => { // Arrange singleUserStateProvider .getFake(userIdFromAccessToken, REFRESH_TOKEN_MEMORY) - .stateSubject.next([userIdFromAccessToken, undefined]); + .nextState(undefined); singleUserStateProvider .getFake(userIdFromAccessToken, REFRESH_TOKEN_DISK) - .stateSubject.next([userIdFromAccessToken, refreshToken]); + .nextState(refreshToken); // Act const result = await tokenService.getRefreshToken(userIdFromAccessToken); @@ -1681,16 +1696,14 @@ describe("TokenService", () => { // Arrange singleUserStateProvider .getFake(userIdFromAccessToken, REFRESH_TOKEN_MEMORY) - .stateSubject.next([userIdFromAccessToken, undefined]); + .nextState(undefined); singleUserStateProvider .getFake(userIdFromAccessToken, REFRESH_TOKEN_DISK) - .stateSubject.next([userIdFromAccessToken, refreshToken]); + .nextState(refreshToken); // Need to have global active id set to the user id - globalStateProvider - .getFake(ACCOUNT_ACTIVE_ACCOUNT_ID) - .stateSubject.next(userIdFromAccessToken); + globalStateProvider.getFake(ACCOUNT_ACTIVE_ACCOUNT_ID).nextState(userIdFromAccessToken); // Act const result = await tokenService.getRefreshToken(); @@ -1719,11 +1732,11 @@ describe("TokenService", () => { // Arrange singleUserStateProvider .getFake(userIdFromAccessToken, REFRESH_TOKEN_MEMORY) - .stateSubject.next([userIdFromAccessToken, undefined]); + .nextState(undefined); singleUserStateProvider .getFake(userIdFromAccessToken, REFRESH_TOKEN_DISK) - .stateSubject.next([userIdFromAccessToken, undefined]); + .nextState(undefined); secureStorageService.get.mockResolvedValue(null); @@ -1743,11 +1756,11 @@ describe("TokenService", () => { // Arrange singleUserStateProvider .getFake(userIdFromAccessToken, REFRESH_TOKEN_MEMORY) - .stateSubject.next([userIdFromAccessToken, undefined]); + .nextState(undefined); singleUserStateProvider .getFake(userIdFromAccessToken, REFRESH_TOKEN_DISK) - .stateSubject.next([userIdFromAccessToken, undefined]); + .nextState(undefined); const secureStorageSvcMockErrorMsg = "Secure storage retrieval error"; @@ -1792,11 +1805,11 @@ describe("TokenService", () => { // Arrange singleUserStateProvider .getFake(userIdFromAccessToken, REFRESH_TOKEN_MEMORY) - .stateSubject.next([userIdFromAccessToken, refreshToken]); + .nextState(refreshToken); singleUserStateProvider .getFake(userIdFromAccessToken, REFRESH_TOKEN_DISK) - .stateSubject.next([userIdFromAccessToken, refreshToken]); + .nextState(refreshToken); // Act await (tokenService as any).clearRefreshToken(userIdFromAccessToken); @@ -1833,9 +1846,7 @@ describe("TokenService", () => { it("should throw an error if the vault timeout is missing", async () => { // Arrange - globalStateProvider - .getFake(ACCOUNT_ACTIVE_ACCOUNT_ID) - .stateSubject.next(userIdFromAccessToken); + globalStateProvider.getFake(ACCOUNT_ACTIVE_ACCOUNT_ID).nextState(userIdFromAccessToken); // Act const result = tokenService.setClientId(clientId, VaultTimeoutAction.Lock, null); @@ -1847,9 +1858,7 @@ describe("TokenService", () => { it("should throw an error if the vault timeout action is missing", async () => { // Arrange - globalStateProvider - .getFake(ACCOUNT_ACTIVE_ACCOUNT_ID) - .stateSubject.next(userIdFromAccessToken); + globalStateProvider.getFake(ACCOUNT_ACTIVE_ACCOUNT_ID).nextState(userIdFromAccessToken); // Act const result = tokenService.setClientId(clientId, null, VaultTimeoutStringType.Never); @@ -1861,9 +1870,7 @@ describe("TokenService", () => { describe("Memory storage tests", () => { it("sets the client id in memory when there is an active user in global state", async () => { // Arrange - globalStateProvider - .getFake(ACCOUNT_ACTIVE_ACCOUNT_ID) - .stateSubject.next(userIdFromAccessToken); + globalStateProvider.getFake(ACCOUNT_ACTIVE_ACCOUNT_ID).nextState(userIdFromAccessToken); // Act await tokenService.setClientId(clientId, memoryVaultTimeoutAction, memoryVaultTimeout); @@ -1895,9 +1902,7 @@ describe("TokenService", () => { describe("Disk storage tests", () => { it("sets the client id in disk when there is an active user in global state", async () => { // Arrange - globalStateProvider - .getFake(ACCOUNT_ACTIVE_ACCOUNT_ID) - .stateSubject.next(userIdFromAccessToken); + globalStateProvider.getFake(ACCOUNT_ACTIVE_ACCOUNT_ID).nextState(userIdFromAccessToken); // Act await tokenService.setClientId(clientId, diskVaultTimeoutAction, diskVaultTimeout); @@ -1935,9 +1940,7 @@ describe("TokenService", () => { it("returns null when no client id is found in memory or disk", async () => { // Arrange - globalStateProvider - .getFake(ACCOUNT_ACTIVE_ACCOUNT_ID) - .stateSubject.next(userIdFromAccessToken); + globalStateProvider.getFake(ACCOUNT_ACTIVE_ACCOUNT_ID).nextState(userIdFromAccessToken); // Act const result = await tokenService.getClientId(); @@ -1950,17 +1953,15 @@ describe("TokenService", () => { // Arrange singleUserStateProvider .getFake(userIdFromAccessToken, API_KEY_CLIENT_ID_MEMORY) - .stateSubject.next([userIdFromAccessToken, clientId]); + .nextState(clientId); // set disk to undefined singleUserStateProvider .getFake(userIdFromAccessToken, API_KEY_CLIENT_ID_DISK) - .stateSubject.next([userIdFromAccessToken, undefined]); + .nextState(undefined); // Need to have global active id set to the user id - globalStateProvider - .getFake(ACCOUNT_ACTIVE_ACCOUNT_ID) - .stateSubject.next(userIdFromAccessToken); + globalStateProvider.getFake(ACCOUNT_ACTIVE_ACCOUNT_ID).nextState(userIdFromAccessToken); // Act const result = await tokenService.getClientId(); @@ -1973,12 +1974,12 @@ describe("TokenService", () => { // Arrange singleUserStateProvider .getFake(userIdFromAccessToken, API_KEY_CLIENT_ID_MEMORY) - .stateSubject.next([userIdFromAccessToken, clientId]); + .nextState(clientId); // set disk to undefined singleUserStateProvider .getFake(userIdFromAccessToken, API_KEY_CLIENT_ID_DISK) - .stateSubject.next([userIdFromAccessToken, undefined]); + .nextState(undefined); // Act const result = await tokenService.getClientId(userIdFromAccessToken); @@ -1992,16 +1993,14 @@ describe("TokenService", () => { // Arrange singleUserStateProvider .getFake(userIdFromAccessToken, API_KEY_CLIENT_ID_MEMORY) - .stateSubject.next([userIdFromAccessToken, undefined]); + .nextState(undefined); singleUserStateProvider .getFake(userIdFromAccessToken, API_KEY_CLIENT_ID_DISK) - .stateSubject.next([userIdFromAccessToken, clientId]); + .nextState(clientId); // Need to have global active id set to the user id - globalStateProvider - .getFake(ACCOUNT_ACTIVE_ACCOUNT_ID) - .stateSubject.next(userIdFromAccessToken); + globalStateProvider.getFake(ACCOUNT_ACTIVE_ACCOUNT_ID).nextState(userIdFromAccessToken); // Act const result = await tokenService.getClientId(); @@ -2013,11 +2012,11 @@ describe("TokenService", () => { // Arrange singleUserStateProvider .getFake(userIdFromAccessToken, API_KEY_CLIENT_ID_MEMORY) - .stateSubject.next([userIdFromAccessToken, undefined]); + .nextState(undefined); singleUserStateProvider .getFake(userIdFromAccessToken, API_KEY_CLIENT_ID_DISK) - .stateSubject.next([userIdFromAccessToken, clientId]); + .nextState(clientId); // Act const result = await tokenService.getClientId(userIdFromAccessToken); @@ -2040,11 +2039,11 @@ describe("TokenService", () => { // Arrange singleUserStateProvider .getFake(userIdFromAccessToken, API_KEY_CLIENT_ID_MEMORY) - .stateSubject.next([userIdFromAccessToken, clientId]); + .nextState(clientId); singleUserStateProvider .getFake(userIdFromAccessToken, API_KEY_CLIENT_ID_DISK) - .stateSubject.next([userIdFromAccessToken, clientId]); + .nextState(clientId); // Act await (tokenService as any).clearClientId(userIdFromAccessToken); @@ -2062,16 +2061,14 @@ describe("TokenService", () => { // Arrange singleUserStateProvider .getFake(userIdFromAccessToken, API_KEY_CLIENT_ID_MEMORY) - .stateSubject.next([userIdFromAccessToken, clientId]); + .nextState(clientId); singleUserStateProvider .getFake(userIdFromAccessToken, API_KEY_CLIENT_ID_DISK) - .stateSubject.next([userIdFromAccessToken, clientId]); + .nextState(clientId); // Need to have global active id set to the user id - globalStateProvider - .getFake(ACCOUNT_ACTIVE_ACCOUNT_ID) - .stateSubject.next(userIdFromAccessToken); + globalStateProvider.getFake(ACCOUNT_ACTIVE_ACCOUNT_ID).nextState(userIdFromAccessToken); // Act await (tokenService as any).clearClientId(); @@ -2106,9 +2103,7 @@ describe("TokenService", () => { it("should throw an error if the vault timeout is missing", async () => { // Arrange - globalStateProvider - .getFake(ACCOUNT_ACTIVE_ACCOUNT_ID) - .stateSubject.next(userIdFromAccessToken); + globalStateProvider.getFake(ACCOUNT_ACTIVE_ACCOUNT_ID).nextState(userIdFromAccessToken); // Act const result = tokenService.setClientSecret(clientSecret, VaultTimeoutAction.Lock, null); @@ -2120,9 +2115,7 @@ describe("TokenService", () => { it("should throw an error if the vault timeout action is missing", async () => { // Arrange - globalStateProvider - .getFake(ACCOUNT_ACTIVE_ACCOUNT_ID) - .stateSubject.next(userIdFromAccessToken); + globalStateProvider.getFake(ACCOUNT_ACTIVE_ACCOUNT_ID).nextState(userIdFromAccessToken); // Act const result = tokenService.setClientSecret( @@ -2138,9 +2131,7 @@ describe("TokenService", () => { describe("Memory storage tests", () => { it("sets the client secret in memory when there is an active user in global state", async () => { // Arrange - globalStateProvider - .getFake(ACCOUNT_ACTIVE_ACCOUNT_ID) - .stateSubject.next(userIdFromAccessToken); + globalStateProvider.getFake(ACCOUNT_ACTIVE_ACCOUNT_ID).nextState(userIdFromAccessToken); // Act await tokenService.setClientSecret( @@ -2176,9 +2167,7 @@ describe("TokenService", () => { describe("Disk storage tests", () => { it("sets the client secret on disk when there is an active user in global state", async () => { // Arrange - globalStateProvider - .getFake(ACCOUNT_ACTIVE_ACCOUNT_ID) - .stateSubject.next(userIdFromAccessToken); + globalStateProvider.getFake(ACCOUNT_ACTIVE_ACCOUNT_ID).nextState(userIdFromAccessToken); // Act await tokenService.setClientSecret( @@ -2222,9 +2211,7 @@ describe("TokenService", () => { it("returns null when no client secret is found in memory or disk", async () => { // Arrange - globalStateProvider - .getFake(ACCOUNT_ACTIVE_ACCOUNT_ID) - .stateSubject.next(userIdFromAccessToken); + globalStateProvider.getFake(ACCOUNT_ACTIVE_ACCOUNT_ID).nextState(userIdFromAccessToken); // Act const result = await tokenService.getClientSecret(); @@ -2237,17 +2224,15 @@ describe("TokenService", () => { // Arrange singleUserStateProvider .getFake(userIdFromAccessToken, API_KEY_CLIENT_SECRET_MEMORY) - .stateSubject.next([userIdFromAccessToken, clientSecret]); + .nextState(clientSecret); // set disk to undefined singleUserStateProvider .getFake(userIdFromAccessToken, API_KEY_CLIENT_SECRET_DISK) - .stateSubject.next([userIdFromAccessToken, undefined]); + .nextState(undefined); // Need to have global active id set to the user id - globalStateProvider - .getFake(ACCOUNT_ACTIVE_ACCOUNT_ID) - .stateSubject.next(userIdFromAccessToken); + globalStateProvider.getFake(ACCOUNT_ACTIVE_ACCOUNT_ID).nextState(userIdFromAccessToken); // Act const result = await tokenService.getClientSecret(); @@ -2260,12 +2245,12 @@ describe("TokenService", () => { // Arrange singleUserStateProvider .getFake(userIdFromAccessToken, API_KEY_CLIENT_SECRET_MEMORY) - .stateSubject.next([userIdFromAccessToken, clientSecret]); + .nextState(clientSecret); // set disk to undefined singleUserStateProvider .getFake(userIdFromAccessToken, API_KEY_CLIENT_SECRET_DISK) - .stateSubject.next([userIdFromAccessToken, undefined]); + .nextState(undefined); // Act const result = await tokenService.getClientSecret(userIdFromAccessToken); @@ -2279,16 +2264,14 @@ describe("TokenService", () => { // Arrange singleUserStateProvider .getFake(userIdFromAccessToken, API_KEY_CLIENT_SECRET_MEMORY) - .stateSubject.next([userIdFromAccessToken, undefined]); + .nextState(undefined); singleUserStateProvider .getFake(userIdFromAccessToken, API_KEY_CLIENT_SECRET_DISK) - .stateSubject.next([userIdFromAccessToken, clientSecret]); + .nextState(clientSecret); // Need to have global active id set to the user id - globalStateProvider - .getFake(ACCOUNT_ACTIVE_ACCOUNT_ID) - .stateSubject.next(userIdFromAccessToken); + globalStateProvider.getFake(ACCOUNT_ACTIVE_ACCOUNT_ID).nextState(userIdFromAccessToken); // Act const result = await tokenService.getClientSecret(); @@ -2300,11 +2283,11 @@ describe("TokenService", () => { // Arrange singleUserStateProvider .getFake(userIdFromAccessToken, API_KEY_CLIENT_SECRET_MEMORY) - .stateSubject.next([userIdFromAccessToken, undefined]); + .nextState(undefined); singleUserStateProvider .getFake(userIdFromAccessToken, API_KEY_CLIENT_SECRET_DISK) - .stateSubject.next([userIdFromAccessToken, clientSecret]); + .nextState(clientSecret); // Act const result = await tokenService.getClientSecret(userIdFromAccessToken); @@ -2327,11 +2310,11 @@ describe("TokenService", () => { // Arrange singleUserStateProvider .getFake(userIdFromAccessToken, API_KEY_CLIENT_SECRET_MEMORY) - .stateSubject.next([userIdFromAccessToken, clientSecret]); + .nextState(clientSecret); singleUserStateProvider .getFake(userIdFromAccessToken, API_KEY_CLIENT_SECRET_DISK) - .stateSubject.next([userIdFromAccessToken, clientSecret]); + .nextState(clientSecret); // Act await (tokenService as any).clearClientSecret(userIdFromAccessToken); @@ -2351,16 +2334,14 @@ describe("TokenService", () => { // Arrange singleUserStateProvider .getFake(userIdFromAccessToken, API_KEY_CLIENT_SECRET_MEMORY) - .stateSubject.next([userIdFromAccessToken, clientSecret]); + .nextState(clientSecret); singleUserStateProvider .getFake(userIdFromAccessToken, API_KEY_CLIENT_SECRET_DISK) - .stateSubject.next([userIdFromAccessToken, clientSecret]); + .nextState(clientSecret); // Need to have global active id set to the user id - globalStateProvider - .getFake(ACCOUNT_ACTIVE_ACCOUNT_ID) - .stateSubject.next(userIdFromAccessToken); + globalStateProvider.getFake(ACCOUNT_ACTIVE_ACCOUNT_ID).nextState(userIdFromAccessToken); // Act await (tokenService as any).clearClientSecret(); @@ -2634,7 +2615,7 @@ describe("TokenService", () => { // Arrange const userId = "userId" as UserId; - globalStateProvider.getFake(ACCOUNT_ACTIVE_ACCOUNT_ID).stateSubject.next(userId); + globalStateProvider.getFake(ACCOUNT_ACTIVE_ACCOUNT_ID).nextState(userId); tokenService.clearAccessToken = jest.fn(); (tokenService as any).clearRefreshToken = jest.fn(); @@ -2693,7 +2674,7 @@ describe("TokenService", () => { globalStateProvider .getFake(EMAIL_TWO_FACTOR_TOKEN_RECORD_DISK_LOCAL) - .stateSubject.next(initialTwoFactorTokenRecord); + .nextState(initialTwoFactorTokenRecord); // Act await tokenService.setTwoFactorToken(email, twoFactorToken); @@ -2716,7 +2697,7 @@ describe("TokenService", () => { globalStateProvider .getFake(EMAIL_TWO_FACTOR_TOKEN_RECORD_DISK_LOCAL) - .stateSubject.next(initialTwoFactorTokenRecord); + .nextState(initialTwoFactorTokenRecord); // Act const result = await tokenService.getTwoFactorToken(email); @@ -2734,7 +2715,7 @@ describe("TokenService", () => { globalStateProvider .getFake(EMAIL_TWO_FACTOR_TOKEN_RECORD_DISK_LOCAL) - .stateSubject.next(initialTwoFactorTokenRecord); + .nextState(initialTwoFactorTokenRecord); // Act const result = await tokenService.getTwoFactorToken(email); @@ -2745,9 +2726,7 @@ describe("TokenService", () => { it("returns null when there is no two factor token record", async () => { // Arrange - globalStateProvider - .getFake(EMAIL_TWO_FACTOR_TOKEN_RECORD_DISK_LOCAL) - .stateSubject.next(null); + globalStateProvider.getFake(EMAIL_TWO_FACTOR_TOKEN_RECORD_DISK_LOCAL).nextState(null); // Act const result = await tokenService.getTwoFactorToken("testUser"); @@ -2768,7 +2747,7 @@ describe("TokenService", () => { globalStateProvider .getFake(EMAIL_TWO_FACTOR_TOKEN_RECORD_DISK_LOCAL) - .stateSubject.next(initialTwoFactorTokenRecord); + .nextState(initialTwoFactorTokenRecord); // Act await tokenService.clearTwoFactorToken(email); @@ -2808,9 +2787,7 @@ describe("TokenService", () => { it("sets the security stamp in memory when there is an active user in global state", async () => { // Arrange - globalStateProvider - .getFake(ACCOUNT_ACTIVE_ACCOUNT_ID) - .stateSubject.next(userIdFromAccessToken); + globalStateProvider.getFake(ACCOUNT_ACTIVE_ACCOUNT_ID).nextState(userIdFromAccessToken); // Act await tokenService.setSecurityStamp(mockSecurityStamp); @@ -2843,13 +2820,11 @@ describe("TokenService", () => { it("returns the security stamp from memory when no user id is specified (uses global active user)", async () => { // Arrange - globalStateProvider - .getFake(ACCOUNT_ACTIVE_ACCOUNT_ID) - .stateSubject.next(userIdFromAccessToken); + globalStateProvider.getFake(ACCOUNT_ACTIVE_ACCOUNT_ID).nextState(userIdFromAccessToken); singleUserStateProvider .getFake(userIdFromAccessToken, SECURITY_STAMP_MEMORY) - .stateSubject.next([userIdFromAccessToken, mockSecurityStamp]); + .nextState(mockSecurityStamp); // Act const result = await tokenService.getSecurityStamp(); @@ -2862,7 +2837,7 @@ describe("TokenService", () => { // Arrange singleUserStateProvider .getFake(userIdFromAccessToken, SECURITY_STAMP_MEMORY) - .stateSubject.next([userIdFromAccessToken, mockSecurityStamp]); + .nextState(mockSecurityStamp); // Act const result = await tokenService.getSecurityStamp(userIdFromAccessToken); diff --git a/libs/common/src/auth/services/token.service.ts b/libs/common/src/auth/services/token.service.ts index ef7f23cb05..c2150bc5c5 100644 --- a/libs/common/src/auth/services/token.service.ts +++ b/libs/common/src/auth/services/token.service.ts @@ -9,6 +9,7 @@ import { KeyGenerationService } from "../../platform/abstractions/key-generation import { LogService } from "../../platform/abstractions/log.service"; import { AbstractStorageService } from "../../platform/abstractions/storage.service"; import { StorageLocation } from "../../platform/enums"; +import { Utils } from "../../platform/misc/utils"; import { EncString, EncryptedString } from "../../platform/models/domain/enc-string"; import { StorageOptions } from "../../platform/models/domain/storage-options"; import { SymmetricCryptoKey } from "../../platform/models/domain/symmetric-crypto-key"; @@ -875,8 +876,13 @@ export class TokenService implements TokenServiceAbstraction { // jwthelper methods // ref https://github.com/auth0/angular-jwt/blob/master/src/angularJwt/services/jwt.js - async decodeAccessToken(token?: string): Promise { - token = token ?? (await this.getAccessToken()); + async decodeAccessToken(tokenOrUserId?: string | UserId): Promise { + let token = tokenOrUserId as string; + if (Utils.isGuid(tokenOrUserId)) { + token = await this.getAccessToken(tokenOrUserId as UserId); + } else { + token ??= await this.getAccessToken(); + } if (token == null) { throw new Error("Access token not found."); @@ -1012,10 +1018,10 @@ export class TokenService implements TokenServiceAbstraction { return decoded.iss; } - async getIsExternal(): Promise { + async getIsExternal(userId: UserId): Promise { let decoded: DecodedAccessToken; try { - decoded = await this.decodeAccessToken(); + decoded = await this.decodeAccessToken(userId); } catch (error) { throw new Error("Failed to decode access token: " + error.message); } diff --git a/libs/common/src/autofill/services/domain-settings.service.ts b/libs/common/src/autofill/services/domain-settings.service.ts index 4b36e8d2bf..7f2e8c3150 100644 --- a/libs/common/src/autofill/services/domain-settings.service.ts +++ b/libs/common/src/autofill/services/domain-settings.service.ts @@ -15,6 +15,7 @@ import { StateProvider, UserKeyDefinition, } from "../../platform/state"; +import { UserId } from "../../types/guid"; const SHOW_FAVICONS = new KeyDefinition(DOMAIN_SETTINGS_DISK, "showFavicons", { deserializer: (value: boolean) => value ?? true, @@ -44,7 +45,7 @@ export abstract class DomainSettingsService { neverDomains$: Observable; setNeverDomains: (newValue: NeverDomains) => Promise; equivalentDomains$: Observable; - setEquivalentDomains: (newValue: EquivalentDomains) => Promise; + setEquivalentDomains: (newValue: EquivalentDomains, userId: UserId) => Promise; defaultUriMatchStrategy$: Observable; setDefaultUriMatchStrategy: (newValue: UriMatchStrategySetting) => Promise; getUrlEquivalentDomains: (url: string) => Observable>; @@ -87,8 +88,8 @@ export class DefaultDomainSettingsService implements DomainSettingsService { await this.neverDomainsState.update(() => newValue); } - async setEquivalentDomains(newValue: EquivalentDomains): Promise { - await this.equivalentDomainsState.update(() => newValue); + async setEquivalentDomains(newValue: EquivalentDomains, userId: UserId): Promise { + await this.stateProvider.getUser(userId, EQUIVALENT_DOMAINS).update(() => newValue); } async setDefaultUriMatchStrategy(newValue: UriMatchStrategySetting): Promise { diff --git a/libs/common/src/billing/abstractions/account/billing-account-profile-state.service.ts b/libs/common/src/billing/abstractions/account/billing-account-profile-state.service.ts index e07dec3cf9..080c61e9ff 100644 --- a/libs/common/src/billing/abstractions/account/billing-account-profile-state.service.ts +++ b/libs/common/src/billing/abstractions/account/billing-account-profile-state.service.ts @@ -1,5 +1,7 @@ import { Observable } from "rxjs"; +import { UserId } from "../../../types/guid"; + export type BillingAccountProfile = { hasPremiumPersonally: boolean; hasPremiumFromAnyOrganization: boolean; @@ -32,5 +34,6 @@ export abstract class BillingAccountProfileStateService { abstract setHasPremium( hasPremiumPersonally: boolean, hasPremiumFromAnyOrganization: boolean, + userId: UserId, ): Promise; } diff --git a/libs/common/src/billing/services/account/billing-account-profile-state.service.spec.ts b/libs/common/src/billing/services/account/billing-account-profile-state.service.spec.ts index 7f0f218a23..7e0dee0eed 100644 --- a/libs/common/src/billing/services/account/billing-account-profile-state.service.spec.ts +++ b/libs/common/src/billing/services/account/billing-account-profile-state.service.spec.ts @@ -3,7 +3,6 @@ import { firstValueFrom } from "rxjs"; import { FakeAccountService, mockAccountServiceWith, - FakeActiveUserState, FakeStateProvider, FakeSingleUserState, } from "../../../../spec"; @@ -18,7 +17,6 @@ import { describe("BillingAccountProfileStateService", () => { let stateProvider: FakeStateProvider; let sut: DefaultBillingAccountProfileStateService; - let billingAccountProfileState: FakeActiveUserState; let userBillingAccountProfileState: FakeSingleUserState; let accountService: FakeAccountService; @@ -30,10 +28,6 @@ describe("BillingAccountProfileStateService", () => { sut = new DefaultBillingAccountProfileStateService(stateProvider); - billingAccountProfileState = stateProvider.activeUser.getFake( - BILLING_ACCOUNT_PROFILE_KEY_DEFINITION, - ); - userBillingAccountProfileState = stateProvider.singleUser.getFake( userId, BILLING_ACCOUNT_PROFILE_KEY_DEFINITION, @@ -133,12 +127,11 @@ describe("BillingAccountProfileStateService", () => { describe("setHasPremium", () => { it("should update the active users state when called", async () => { - await sut.setHasPremium(true, false); + await sut.setHasPremium(true, false, userId); - expect(billingAccountProfileState.nextMock).toHaveBeenCalledWith([ - userId, - { hasPremiumPersonally: true, hasPremiumFromAnyOrganization: false }, - ]); + expect(await firstValueFrom(sut.hasPremiumFromAnyOrganization$)).toBe(false); + expect(await firstValueFrom(sut.hasPremiumPersonally$)).toBe(true); + expect(await firstValueFrom(sut.hasPremiumFromAnySource$)).toBe(true); }); }); }); diff --git a/libs/common/src/billing/services/account/billing-account-profile-state.service.ts b/libs/common/src/billing/services/account/billing-account-profile-state.service.ts index cf05df2f22..7d256da971 100644 --- a/libs/common/src/billing/services/account/billing-account-profile-state.service.ts +++ b/libs/common/src/billing/services/account/billing-account-profile-state.service.ts @@ -6,6 +6,7 @@ import { StateProvider, UserKeyDefinition, } from "../../../platform/state"; +import { UserId } from "../../../types/guid"; import { BillingAccountProfile, BillingAccountProfileStateService, @@ -27,7 +28,7 @@ export class DefaultBillingAccountProfileStateService implements BillingAccountP hasPremiumPersonally$: Observable; hasPremiumFromAnySource$: Observable; - constructor(stateProvider: StateProvider) { + constructor(private readonly stateProvider: StateProvider) { this.billingAccountProfileState = stateProvider.getActive( BILLING_ACCOUNT_PROFILE_KEY_DEFINITION, ); @@ -62,8 +63,9 @@ export class DefaultBillingAccountProfileStateService implements BillingAccountP async setHasPremium( hasPremiumPersonally: boolean, hasPremiumFromAnyOrganization: boolean, + userId: UserId, ): Promise { - await this.billingAccountProfileState.update((billingAccountProfile) => { + await this.stateProvider.getUser(userId, BILLING_ACCOUNT_PROFILE_KEY_DEFINITION).update((_) => { return { hasPremiumPersonally: hasPremiumPersonally, hasPremiumFromAnyOrganization: hasPremiumFromAnyOrganization, diff --git a/libs/common/src/platform/abstractions/crypto.service.ts b/libs/common/src/platform/abstractions/crypto.service.ts index b9499c8fd5..ed18204e9e 100644 --- a/libs/common/src/platform/abstractions/crypto.service.ts +++ b/libs/common/src/platform/abstractions/crypto.service.ts @@ -143,7 +143,7 @@ export abstract class CryptoService { * @param userKeyMasterKey The master key encrypted user key to set * @param userId The desired user */ - abstract setMasterKeyEncryptedUserKey(UserKeyMasterKey: string, userId?: string): Promise; + abstract setMasterKeyEncryptedUserKey(UserKeyMasterKey: string, userId: string): Promise; /** * @param password The user's master password that will be used to derive a master key if one isn't found * @param userId The desired user diff --git a/libs/common/src/platform/biometrics/biometric-state.service.spec.ts b/libs/common/src/platform/biometrics/biometric-state.service.spec.ts index 097428e16a..56e9cb164f 100644 --- a/libs/common/src/platform/biometrics/biometric-state.service.spec.ts +++ b/libs/common/src/platform/biometrics/biometric-state.service.spec.ts @@ -119,7 +119,7 @@ describe("BiometricStateService", () => { describe("getRequirePasswordOnStart", () => { it("returns the requirePasswordOnStart state value", async () => { - stateProvider.singleUser.mockFor(userId, REQUIRE_PASSWORD_ON_START.key, true); + stateProvider.singleUser.mockFor(userId, REQUIRE_PASSWORD_ON_START, true); expect(await sut.getRequirePasswordOnStart(userId)).toBe(true); }); diff --git a/libs/common/src/platform/services/crypto.service.spec.ts b/libs/common/src/platform/services/crypto.service.spec.ts index 2386ad1371..dfa244ff2a 100644 --- a/libs/common/src/platform/services/crypto.service.spec.ts +++ b/libs/common/src/platform/services/crypto.service.spec.ts @@ -365,9 +365,9 @@ describe("cryptoService", () => { const userKeyState = stateProvider.singleUser.getFake(mockUserId, USER_KEY); const fakeMasterKey = makeMasterKey ? makeSymmetricCryptoKey(64) : null; masterPasswordService.masterKeySubject.next(fakeMasterKey); - userKeyState.stateSubject.next([mockUserId, null]); + userKeyState.nextState(null); const fakeUserKey = makeUserKey ? makeSymmetricCryptoKey(64) : null; - userKeyState.stateSubject.next([mockUserId, fakeUserKey]); + userKeyState.nextState(fakeUserKey); return [fakeUserKey, fakeMasterKey]; } @@ -384,10 +384,7 @@ describe("cryptoService", () => { const fakeEncryptedUserPrivateKey = makeEncString("1"); - userEncryptedPrivateKeyState.stateSubject.next([ - mockUserId, - fakeEncryptedUserPrivateKey.encryptedString, - ]); + userEncryptedPrivateKeyState.nextState(fakeEncryptedUserPrivateKey.encryptedString); // Decryption of the user private key const fakeDecryptedUserPrivateKey = makeStaticByteArray(10, 1); @@ -423,7 +420,7 @@ describe("cryptoService", () => { mockUserId, USER_ENCRYPTED_PRIVATE_KEY, ); - encryptedUserPrivateKeyState.stateSubject.next([mockUserId, null]); + encryptedUserPrivateKeyState.nextState(null); const userPrivateKey = await firstValueFrom(cryptoService.userPrivateKey$(mockUserId)); expect(userPrivateKey).toBeFalsy(); @@ -463,7 +460,7 @@ describe("cryptoService", () => { function updateKeys(keys: Partial = {}) { if ("userKey" in keys) { const userKeyState = stateProvider.singleUser.getFake(mockUserId, USER_KEY); - userKeyState.stateSubject.next([mockUserId, keys.userKey]); + userKeyState.nextState(keys.userKey); } if ("encryptedPrivateKey" in keys) { @@ -471,10 +468,7 @@ describe("cryptoService", () => { mockUserId, USER_ENCRYPTED_PRIVATE_KEY, ); - userEncryptedPrivateKey.stateSubject.next([ - mockUserId, - keys.encryptedPrivateKey.encryptedString, - ]); + userEncryptedPrivateKey.nextState(keys.encryptedPrivateKey.encryptedString); } if ("orgKeys" in keys) { @@ -482,7 +476,7 @@ describe("cryptoService", () => { mockUserId, USER_ENCRYPTED_ORGANIZATION_KEYS, ); - orgKeysState.stateSubject.next([mockUserId, keys.orgKeys]); + orgKeysState.nextState(keys.orgKeys); } if ("providerKeys" in keys) { @@ -490,7 +484,7 @@ describe("cryptoService", () => { mockUserId, USER_ENCRYPTED_PROVIDER_KEYS, ); - providerKeysState.stateSubject.next([mockUserId, keys.providerKeys]); + providerKeysState.nextState(keys.providerKeys); } encryptService.decryptToBytes.mockImplementation((encryptedPrivateKey, userKey) => { diff --git a/libs/common/src/platform/services/crypto.service.ts b/libs/common/src/platform/services/crypto.service.ts index 6d99f92082..6183051313 100644 --- a/libs/common/src/platform/services/crypto.service.ts +++ b/libs/common/src/platform/services/crypto.service.ts @@ -225,7 +225,7 @@ export class CryptoService implements CryptoServiceAbstraction { } } - async setMasterKeyEncryptedUserKey(userKeyMasterKey: string, userId?: UserId): Promise { + async setMasterKeyEncryptedUserKey(userKeyMasterKey: string, userId: UserId): Promise { userId ??= await firstValueFrom(this.stateProvider.activeUserId$); await this.masterPasswordService.setMasterKeyEncryptedUserKey( new EncString(userKeyMasterKey), diff --git a/libs/common/src/platform/state/implementations/default-state.provider.spec.ts b/libs/common/src/platform/state/implementations/default-state.provider.spec.ts index 5b8b2d1bfe..b3190bd532 100644 --- a/libs/common/src/platform/state/implementations/default-state.provider.spec.ts +++ b/libs/common/src/platform/state/implementations/default-state.provider.spec.ts @@ -143,7 +143,7 @@ describe("DefaultStateProvider", () => { it("should not emit any values until a truthy user id is supplied", async () => { accountService.activeAccountSubject.next(null); const state = singleUserStateProvider.getFake(userId, keyDefinition); - state.stateSubject.next([userId, "value"]); + state.nextState("value"); const emissions = trackEmissions(sut.getUserState$(keyDefinition)); diff --git a/libs/common/src/platform/sync/default-sync.service.ts b/libs/common/src/platform/sync/default-sync.service.ts index e48ab0618c..322687ce6a 100644 --- a/libs/common/src/platform/sync/default-sync.service.ts +++ b/libs/common/src/platform/sync/default-sync.service.ts @@ -124,12 +124,12 @@ export class DefaultSyncService extends CoreSyncService { const response = await this.apiService.getSync(); await this.syncProfile(response.profile); - await this.syncFolders(response.folders); - await this.syncCollections(response.collections); - await this.syncCiphers(response.ciphers); - await this.syncSends(response.sends); - await this.syncSettings(response.domains); - await this.syncPolicies(response.policies); + await this.syncFolders(response.folders, response.profile.id); + await this.syncCollections(response.collections, response.profile.id); + await this.syncCiphers(response.ciphers, response.profile.id); + await this.syncSends(response.sends, response.profile.id); + await this.syncSettings(response.domains, response.profile.id); + await this.syncPolicies(response.policies, response.profile.id); await this.setLastSync(now, userId); return this.syncCompleted(true); @@ -190,8 +190,9 @@ export class DefaultSyncService extends CoreSyncService { await this.billingAccountProfileStateService.setHasPremium( response.premiumPersonally, response.premiumFromOrganization, + response.id, ); - await this.keyConnectorService.setUsesKeyConnector(response.usesKeyConnector); + await this.keyConnectorService.setUsesKeyConnector(response.usesKeyConnector, response.id); await this.setForceSetPasswordReasonIfNeeded(response); @@ -200,17 +201,17 @@ export class DefaultSyncService extends CoreSyncService { providers[p.id] = new ProviderData(p); }); - await this.providerService.save(providers); + await this.providerService.save(providers, response.id); - await this.syncProfileOrganizations(response); + await this.syncProfileOrganizations(response, response.id); - if (await this.keyConnectorService.userNeedsMigration()) { - await this.keyConnectorService.setConvertAccountRequired(true); + if (await this.keyConnectorService.userNeedsMigration(response.id)) { + await this.keyConnectorService.setConvertAccountRequired(true, response.id); this.messageSender.send("convertAccountToKeyConnector"); } else { // FIXME: Verify that this floating promise is intentional. If it is, add an explanatory comment and ensure there is proper error handling. // eslint-disable-next-line @typescript-eslint/no-floating-promises - this.keyConnectorService.removeConvertAccountRequired(); + this.keyConnectorService.removeConvertAccountRequired(response.id); } } @@ -261,7 +262,7 @@ export class DefaultSyncService extends CoreSyncService { } } - private async syncProfileOrganizations(response: ProfileResponse) { + private async syncProfileOrganizations(response: ProfileResponse, userId: UserId) { const organizations: { [id: string]: OrganizationData } = {}; response.organizations.forEach((o) => { organizations[o.id] = new OrganizationData(o, { @@ -281,42 +282,42 @@ export class DefaultSyncService extends CoreSyncService { } }); - await this.organizationService.replace(organizations); + await this.organizationService.replace(organizations, userId); } - private async syncFolders(response: FolderResponse[]) { + private async syncFolders(response: FolderResponse[], userId: UserId) { const folders: { [id: string]: FolderData } = {}; response.forEach((f) => { folders[f.id] = new FolderData(f); }); - return await this.folderService.replace(folders); + return await this.folderService.replace(folders, userId); } - private async syncCollections(response: CollectionDetailsResponse[]) { + private async syncCollections(response: CollectionDetailsResponse[], userId: UserId) { const collections: { [id: string]: CollectionData } = {}; response.forEach((c) => { collections[c.id] = new CollectionData(c); }); - return await this.collectionService.replace(collections); + return await this.collectionService.replace(collections, userId); } - private async syncCiphers(response: CipherResponse[]) { + private async syncCiphers(response: CipherResponse[], userId: UserId) { const ciphers: { [id: string]: CipherData } = {}; response.forEach((c) => { ciphers[c.id] = new CipherData(c); }); - return await this.cipherService.replace(ciphers); + return await this.cipherService.replace(ciphers, userId); } - private async syncSends(response: SendResponse[]) { + private async syncSends(response: SendResponse[], userId: UserId) { const sends: { [id: string]: SendData } = {}; response.forEach((s) => { sends[s.id] = new SendData(s); }); - return await this.sendService.replace(sends); + return await this.sendService.replace(sends, userId); } - private async syncSettings(response: DomainsResponse) { + private async syncSettings(response: DomainsResponse, userId: UserId) { let eqDomains: string[][] = []; if (response != null && response.equivalentDomains != null) { eqDomains = eqDomains.concat(response.equivalentDomains); @@ -330,16 +331,16 @@ export class DefaultSyncService extends CoreSyncService { }); } - return this.domainSettingsService.setEquivalentDomains(eqDomains); + return this.domainSettingsService.setEquivalentDomains(eqDomains, userId); } - private async syncPolicies(response: PolicyResponse[]) { + private async syncPolicies(response: PolicyResponse[], userId: UserId) { const policies: { [id: string]: PolicyData } = {}; if (response != null) { response.forEach((p) => { policies[p.id] = new PolicyData(p); }); } - return await this.policyService.replace(policies); + return await this.policyService.replace(policies, userId); } } diff --git a/libs/common/src/tools/send/services/send-state.provider.abstraction.ts b/libs/common/src/tools/send/services/send-state.provider.abstraction.ts index 7a35506b56..c16d06fb92 100644 --- a/libs/common/src/tools/send/services/send-state.provider.abstraction.ts +++ b/libs/common/src/tools/send/services/send-state.provider.abstraction.ts @@ -1,15 +1,19 @@ import { Observable } from "rxjs"; +import type { Simplify } from "type-fest"; +import { CombinedState } from "../../../platform/state"; +import { UserId } from "../../../types/guid"; import { SendData } from "../models/data/send.data"; import { SendView } from "../models/view/send.view"; +type EncryptedSendState = Simplify>>; export abstract class SendStateProvider { - encryptedState$: Observable>; + encryptedState$: Observable; decryptedState$: Observable; - getEncryptedSends: () => Promise<{ [id: string]: SendData }>; + getEncryptedSends: () => Promise; - setEncryptedSends: (value: { [id: string]: SendData }) => Promise; + setEncryptedSends: (value: { [id: string]: SendData }, userId: UserId) => Promise; getDecryptedSends: () => Promise; diff --git a/libs/common/src/tools/send/services/send-state.provider.spec.ts b/libs/common/src/tools/send/services/send-state.provider.spec.ts index 069e0d8069..abca614d11 100644 --- a/libs/common/src/tools/send/services/send-state.provider.spec.ts +++ b/libs/common/src/tools/send/services/send-state.provider.spec.ts @@ -27,11 +27,11 @@ describe("Send State Provider", () => { describe("Encrypted Sends", () => { it("should return SendData", async () => { const sendData = { "1": testSendData("1", "Test Send Data") }; - await sendStateProvider.setEncryptedSends(sendData); + await sendStateProvider.setEncryptedSends(sendData, mockUserId); await awaitAsync(); const actual = await sendStateProvider.getEncryptedSends(); - expect(actual).toStrictEqual(sendData); + expect(actual).toStrictEqual([mockUserId, sendData]); }); }); diff --git a/libs/common/src/tools/send/services/send-state.provider.ts b/libs/common/src/tools/send/services/send-state.provider.ts index 1e9397b7a9..66989a7054 100644 --- a/libs/common/src/tools/send/services/send-state.provider.ts +++ b/libs/common/src/tools/send/services/send-state.provider.ts @@ -1,6 +1,7 @@ import { Observable, firstValueFrom } from "rxjs"; -import { ActiveUserState, StateProvider } from "../../../platform/state"; +import { ActiveUserState, CombinedState, StateProvider } from "../../../platform/state"; +import { UserId } from "../../../types/guid"; import { SendData } from "../models/data/send.data"; import { SendView } from "../models/view/send.view"; @@ -10,7 +11,7 @@ import { SendStateProvider as SendStateProviderAbstraction } from "./send-state. /** State provider for sends */ export class SendStateProvider implements SendStateProviderAbstraction { /** Observable for the encrypted sends for an active user */ - encryptedState$: Observable>; + encryptedState$: Observable>>; /** Observable with the decrypted sends for an active user */ decryptedState$: Observable; @@ -19,20 +20,20 @@ export class SendStateProvider implements SendStateProviderAbstraction { constructor(protected stateProvider: StateProvider) { this.activeUserEncryptedState = this.stateProvider.getActive(SEND_USER_ENCRYPTED); - this.encryptedState$ = this.activeUserEncryptedState.state$; + this.encryptedState$ = this.activeUserEncryptedState.combinedState$; this.activeUserDecryptedState = this.stateProvider.getActive(SEND_USER_DECRYPTED); this.decryptedState$ = this.activeUserDecryptedState.state$; } /** Gets the encrypted sends from state for an active user */ - async getEncryptedSends(): Promise<{ [id: string]: SendData }> { + async getEncryptedSends(): Promise> { return await firstValueFrom(this.encryptedState$); } /** Sets the encrypted send state for an active user */ - async setEncryptedSends(value: { [id: string]: SendData }): Promise { - await this.activeUserEncryptedState.update(() => value); + async setEncryptedSends(value: { [id: string]: SendData }, userId: UserId): Promise { + await this.stateProvider.getUser(userId, SEND_USER_ENCRYPTED).update(() => value); } /** Gets the decrypted sends from state for the active user */ diff --git a/libs/common/src/tools/send/services/send.service.abstraction.ts b/libs/common/src/tools/send/services/send.service.abstraction.ts index 6033c9c6cb..4fa927942c 100644 --- a/libs/common/src/tools/send/services/send.service.abstraction.ts +++ b/libs/common/src/tools/send/services/send.service.abstraction.ts @@ -55,6 +55,6 @@ export abstract class SendService implements UserKeyRotationDataProvider Promise; - replace: (sends: { [id: string]: SendData }) => Promise; + replace: (sends: { [id: string]: SendData }, userId: UserId) => Promise; delete: (id: string | string[]) => Promise; } diff --git a/libs/common/src/tools/send/services/send.service.spec.ts b/libs/common/src/tools/send/services/send.service.spec.ts index 5d04127192..5743eff481 100644 --- a/libs/common/src/tools/send/services/send.service.spec.ts +++ b/libs/common/src/tools/send/services/send.service.spec.ts @@ -110,9 +110,12 @@ describe("SendService", () => { const result = await firstValueFrom(singleSendObservable); expect(result).toEqual(testSend("1", "Test Send")); - await sendService.replace({ - "1": testSendData("1", "Test Send Updated"), - }); + await sendService.replace( + { + "1": testSendData("1", "Test Send Updated"), + }, + mockUserId, + ); const result2 = await firstValueFrom(singleSendObservable); expect(result2).toEqual(testSend("1", "Test Send Updated")); @@ -127,10 +130,13 @@ describe("SendService", () => { //it is immediately called when subscribed, we need to reset the value changed = false; - await sendService.replace({ - "1": sendDataObject, - "2": testSendData("2", "Test Send 2"), - }); + await sendService.replace( + { + "1": sendDataObject, + "2": testSendData("2", "Test Send 2"), + }, + mockUserId, + ); expect(changed).toEqual(true); }); @@ -138,10 +144,13 @@ describe("SendService", () => { it("reports a change when notes changes on a new send", async () => { const sendDataObject = createSendData() as SendData; - await sendService.replace({ - "1": sendDataObject, - "2": testSendData("2", "Test Send 2"), - }); + await sendService.replace( + { + "1": sendDataObject, + "2": testSendData("2", "Test Send 2"), + }, + mockUserId, + ); let changed = false; sendService.get$("1").subscribe(() => { @@ -152,10 +161,13 @@ describe("SendService", () => { //it is immediately called when subscribed, we need to reset the value changed = false; - await sendService.replace({ - "1": sendDataObject, - "2": testSendData("2", "Test Send 2"), - }); + await sendService.replace( + { + "1": sendDataObject, + "2": testSendData("2", "Test Send 2"), + }, + mockUserId, + ); expect(changed).toEqual(true); }); @@ -163,10 +175,13 @@ describe("SendService", () => { it("reports a change when Text changes on a new send", async () => { const sendDataObject = createSendData() as SendData; - await sendService.replace({ - "1": sendDataObject, - "2": testSendData("2", "Test Send 2"), - }); + await sendService.replace( + { + "1": sendDataObject, + "2": testSendData("2", "Test Send 2"), + }, + mockUserId, + ); let changed = false; sendService.get$("1").subscribe(() => { @@ -177,10 +192,13 @@ describe("SendService", () => { changed = false; sendDataObject.text.text = "new text"; - await sendService.replace({ - "1": sendDataObject, - "2": testSendData("2", "Test Send 2"), - }); + await sendService.replace( + { + "1": sendDataObject, + "2": testSendData("2", "Test Send 2"), + }, + mockUserId, + ); expect(changed).toEqual(true); }); @@ -188,10 +206,13 @@ describe("SendService", () => { it("reports a change when Text is set as null on a new send", async () => { const sendDataObject = createSendData() as SendData; - await sendService.replace({ - "1": sendDataObject, - "2": testSendData("2", "Test Send 2"), - }); + await sendService.replace( + { + "1": sendDataObject, + "2": testSendData("2", "Test Send 2"), + }, + mockUserId, + ); let changed = false; sendService.get$("1").subscribe(() => { @@ -202,10 +223,13 @@ describe("SendService", () => { changed = false; sendDataObject.text = null; - await sendService.replace({ - "1": sendDataObject, - "2": testSendData("2", "Test Send 2"), - }); + await sendService.replace( + { + "1": sendDataObject, + "2": testSendData("2", "Test Send 2"), + }, + mockUserId, + ); expect(changed).toEqual(true); }); @@ -215,10 +239,13 @@ describe("SendService", () => { type: SendType.File, file: new SendFileData(new SendFileApi({ FileName: "name of file" })), }) as SendData; - await sendService.replace({ - "1": sendDataObject, - "2": testSendData("2", "Test Send 2"), - }); + await sendService.replace( + { + "1": sendDataObject, + "2": testSendData("2", "Test Send 2"), + }, + mockUserId, + ); sendDataObject.file = new SendFileData(new SendFileApi({ FileName: "updated name of file" })); let changed = false; @@ -229,10 +256,13 @@ describe("SendService", () => { //it is immediately called when subscribed, we need to reset the value changed = false; - await sendService.replace({ - "1": sendDataObject, - "2": testSendData("2", "Test Send 2"), - }); + await sendService.replace( + { + "1": sendDataObject, + "2": testSendData("2", "Test Send 2"), + }, + mockUserId, + ); expect(changed).toEqual(false); }); @@ -240,10 +270,13 @@ describe("SendService", () => { it("reports a change when key changes on a new send", async () => { const sendDataObject = createSendData() as SendData; - await sendService.replace({ - "1": sendDataObject, - "2": testSendData("2", "Test Send 2"), - }); + await sendService.replace( + { + "1": sendDataObject, + "2": testSendData("2", "Test Send 2"), + }, + mockUserId, + ); let changed = false; sendService.get$("1").subscribe(() => { @@ -254,10 +287,13 @@ describe("SendService", () => { changed = false; sendDataObject.key = "newKey"; - await sendService.replace({ - "1": sendDataObject, - "2": testSendData("2", "Test Send 2"), - }); + await sendService.replace( + { + "1": sendDataObject, + "2": testSendData("2", "Test Send 2"), + }, + mockUserId, + ); expect(changed).toEqual(true); }); @@ -265,10 +301,13 @@ describe("SendService", () => { it("reports a change when revisionDate changes on a new send", async () => { const sendDataObject = createSendData() as SendData; - await sendService.replace({ - "1": sendDataObject, - "2": testSendData("2", "Test Send 2"), - }); + await sendService.replace( + { + "1": sendDataObject, + "2": testSendData("2", "Test Send 2"), + }, + mockUserId, + ); let changed = false; sendService.get$("1").subscribe(() => { @@ -279,10 +318,13 @@ describe("SendService", () => { changed = false; sendDataObject.revisionDate = "2025-04-05"; - await sendService.replace({ - "1": sendDataObject, - "2": testSendData("2", "Test Send 2"), - }); + await sendService.replace( + { + "1": sendDataObject, + "2": testSendData("2", "Test Send 2"), + }, + mockUserId, + ); expect(changed).toEqual(true); }); @@ -290,10 +332,13 @@ describe("SendService", () => { it("reports a change when a property is set as null on a new send", async () => { const sendDataObject = createSendData() as SendData; - await sendService.replace({ - "1": sendDataObject, - "2": testSendData("2", "Test Send 2"), - }); + await sendService.replace( + { + "1": sendDataObject, + "2": testSendData("2", "Test Send 2"), + }, + mockUserId, + ); let changed = false; sendService.get$("1").subscribe(() => { @@ -304,10 +349,13 @@ describe("SendService", () => { changed = false; sendDataObject.name = null; - await sendService.replace({ - "1": sendDataObject, - "2": testSendData("2", "Test Send 2"), - }); + await sendService.replace( + { + "1": sendDataObject, + "2": testSendData("2", "Test Send 2"), + }, + mockUserId, + ); expect(changed).toEqual(true); }); @@ -317,10 +365,13 @@ describe("SendService", () => { text: new SendTextData(new SendTextApi({ Text: null })), }) as SendData; - await sendService.replace({ - "1": sendDataObject, - "2": testSendData("2", "Test Send 2"), - }); + await sendService.replace( + { + "1": sendDataObject, + "2": testSendData("2", "Test Send 2"), + }, + mockUserId, + ); let changed = false; sendService.get$("1").subscribe(() => { @@ -330,23 +381,29 @@ describe("SendService", () => { //it is immediately called when subscribed, we need to reset the value changed = false; - await sendService.replace({ - "1": sendDataObject, - "2": testSendData("2", "Test Send 2"), - }); + await sendService.replace( + { + "1": sendDataObject, + "2": testSendData("2", "Test Send 2"), + }, + mockUserId, + ); expect(changed).toEqual(false); sendDataObject.text.text = "Asdf"; - await sendService.replace({ - "1": sendDataObject, - "2": testSendData("2", "Test Send 2"), - }); + await sendService.replace( + { + "1": sendDataObject, + "2": testSendData("2", "Test Send 2"), + }, + mockUserId, + ); expect(changed).toEqual(true); }); - it("do not reports a change when nothing changes on the observed send", async () => { + it("do not report a change when nothing changes on the observed send", async () => { let changed = false; sendService.get$("1").subscribe(() => { changed = true; @@ -357,10 +414,13 @@ describe("SendService", () => { //it is immediately called when subscribed, we need to reset the value changed = false; - await sendService.replace({ - "1": sendDataObject, - "2": testSendData("3", "Test Send 3"), - }); + await sendService.replace( + { + "1": sendDataObject, + "2": testSendData("3", "Test Send 3"), + }, + mockUserId, + ); expect(changed).toEqual(false); }); @@ -373,9 +433,12 @@ describe("SendService", () => { //it is immediately called when subscribed, we need to reset the value changed = false; - await sendService.replace({ - "2": testSendData("2", "Test Send 2"), - }); + await sendService.replace( + { + "2": testSendData("2", "Test Send 2"), + }, + mockUserId, + ); expect(changed).toEqual(true); }); @@ -426,7 +489,7 @@ describe("SendService", () => { }); it("returns empty array if there are no sends", async () => { - await sendService.replace(null); + await sendService.replace(null, mockUserId); await awaitAsync(); @@ -461,16 +524,11 @@ describe("SendService", () => { }); it("replace", async () => { - await sendService.replace({ "2": testSendData("2", "test 2") }); + await sendService.replace({ "2": testSendData("2", "test 2") }, mockUserId); expect(await firstValueFrom(sendService.sends$)).toEqual([testSend("2", "test 2")]); }); - it("clear", async () => { - await sendService.clear(); - await awaitAsync(); - expect(await firstValueFrom(sendService.sends$)).toEqual([]); - }); describe("Delete", () => { it("Sends count should decrease after delete", async () => { const sendsBeforeDelete = await firstValueFrom(sendService.sends$); @@ -488,7 +546,7 @@ describe("SendService", () => { }); it("Deleting on an empty sends array should not throw", async () => { - sendStateProvider.getEncryptedSends = jest.fn().mockResolvedValue(null); + stateProvider.activeUser.getFake(SEND_USER_ENCRYPTED).nextState(null); await expect(sendService.delete("2")).resolves.not.toThrow(); }); diff --git a/libs/common/src/tools/send/services/send.service.ts b/libs/common/src/tools/send/services/send.service.ts index 7048cf5a37..63c07e862f 100644 --- a/libs/common/src/tools/send/services/send.service.ts +++ b/libs/common/src/tools/send/services/send.service.ts @@ -28,10 +28,10 @@ export class SendService implements InternalSendServiceAbstraction { readonly sendKeyPurpose = "send"; sends$ = this.stateProvider.encryptedState$.pipe( - map((record) => Object.values(record || {}).map((data) => new Send(data))), + map(([, record]) => Object.values(record || {}).map((data) => new Send(data))), ); sendViews$ = this.stateProvider.encryptedState$.pipe( - concatMap((record) => + concatMap(([, record]) => this.decryptSends(Object.values(record || {}).map((data) => new Send(data))), ), ); @@ -167,7 +167,7 @@ export class SendService implements InternalSendServiceAbstraction { } async getFromState(id: string): Promise { - const sends = await this.stateProvider.getEncryptedSends(); + const [, sends] = await this.stateProvider.getEncryptedSends(); // eslint-disable-next-line if (sends == null || !sends.hasOwnProperty(id)) { return null; @@ -177,7 +177,7 @@ export class SendService implements InternalSendServiceAbstraction { } async getAll(): Promise { - const sends = await this.stateProvider.getEncryptedSends(); + const [, sends] = await this.stateProvider.getEncryptedSends(); const response: Send[] = []; for (const id in sends) { // eslint-disable-next-line @@ -214,7 +214,8 @@ export class SendService implements InternalSendServiceAbstraction { } async upsert(send: SendData | SendData[]): Promise { - let sends = await this.stateProvider.getEncryptedSends(); + const [userId, currentSends] = await this.stateProvider.getEncryptedSends(); + let sends = currentSends; if (sends == null) { sends = {}; } @@ -227,16 +228,11 @@ export class SendService implements InternalSendServiceAbstraction { }); } - await this.replace(sends); - } - - async clear(userId?: string): Promise { - await this.stateProvider.setDecryptedSends(null); - await this.stateProvider.setEncryptedSends(null); + await this.replace(sends, userId); } async delete(id: string | string[]): Promise { - const sends = await this.stateProvider.getEncryptedSends(); + const [userId, sends] = await this.stateProvider.getEncryptedSends(); if (sends == null) { return; } @@ -252,11 +248,11 @@ export class SendService implements InternalSendServiceAbstraction { }); } - await this.replace(sends); + await this.replace(sends, userId); } - async replace(sends: { [id: string]: SendData }): Promise { - await this.stateProvider.setEncryptedSends(sends); + async replace(sends: { [id: string]: SendData }, userId: UserId): Promise { + await this.stateProvider.setEncryptedSends(sends, userId); } async getRotatedData( diff --git a/libs/common/src/vault/abstractions/cipher.service.ts b/libs/common/src/vault/abstractions/cipher.service.ts index c95ae27f61..061bd5cedb 100644 --- a/libs/common/src/vault/abstractions/cipher.service.ts +++ b/libs/common/src/vault/abstractions/cipher.service.ts @@ -133,7 +133,7 @@ export abstract class CipherService implements UserKeyRotationDataProvider Promise>; - replace: (ciphers: { [id: string]: CipherData }) => Promise; + replace: (ciphers: { [id: string]: CipherData }, userId: UserId) => Promise; clear: (userId?: string) => Promise; moveManyWithServer: (ids: string[], folderId: string) => Promise; delete: (id: string | string[]) => Promise; diff --git a/libs/common/src/vault/abstractions/collection.service.ts b/libs/common/src/vault/abstractions/collection.service.ts index 0c20613963..084aa3a808 100644 --- a/libs/common/src/vault/abstractions/collection.service.ts +++ b/libs/common/src/vault/abstractions/collection.service.ts @@ -1,6 +1,6 @@ import { Observable } from "rxjs"; -import { CollectionId } from "../../types/guid"; +import { CollectionId, UserId } from "../../types/guid"; import { CollectionData } from "../models/data/collection.data"; import { Collection } from "../models/domain/collection"; import { TreeNode } from "../models/domain/tree-node"; @@ -22,7 +22,7 @@ export abstract class CollectionService { getAllNested: (collections?: CollectionView[]) => Promise[]>; getNested: (id: string) => Promise>; upsert: (collection: CollectionData | CollectionData[]) => Promise; - replace: (collections: { [id: string]: CollectionData }) => Promise; + replace: (collections: { [id: string]: CollectionData }, userId: UserId) => Promise; clear: (userId?: string) => Promise; delete: (id: string | string[]) => Promise; } diff --git a/libs/common/src/vault/abstractions/folder/folder.service.abstraction.ts b/libs/common/src/vault/abstractions/folder/folder.service.abstraction.ts index 71b8089fa6..3480a8aca0 100644 --- a/libs/common/src/vault/abstractions/folder/folder.service.abstraction.ts +++ b/libs/common/src/vault/abstractions/folder/folder.service.abstraction.ts @@ -45,7 +45,7 @@ export abstract class FolderService implements UserKeyRotationDataProvider Promise; - replace: (folders: { [id: string]: FolderData }) => Promise; + replace: (folders: { [id: string]: FolderData }, userId: UserId) => Promise; clear: (userId?: string) => Promise; delete: (id: string | string[]) => Promise; } diff --git a/libs/common/src/vault/services/cipher.service.ts b/libs/common/src/vault/services/cipher.service.ts index 92676aea97..cb72d413c8 100644 --- a/libs/common/src/vault/services/cipher.service.ts +++ b/libs/common/src/vault/services/cipher.service.ts @@ -913,8 +913,8 @@ export class CipherService implements CipherServiceAbstraction { }); } - async replace(ciphers: { [id: string]: CipherData }): Promise { - await this.updateEncryptedCipherState(() => ciphers); + async replace(ciphers: { [id: string]: CipherData }, userId: UserId): Promise { + await this.updateEncryptedCipherState(() => ciphers, userId); } /** @@ -924,15 +924,18 @@ export class CipherService implements CipherServiceAbstraction { */ private async updateEncryptedCipherState( update: (current: Record) => Record, + userId: UserId = null, ): Promise> { - const userId = await firstValueFrom(this.stateProvider.activeUserId$); + userId ||= await firstValueFrom(this.stateProvider.activeUserId$); // Store that we should wait for an update to return any ciphers await this.ciphersExpectingUpdate.forceValue(true); await this.clearDecryptedCiphersState(userId); - const [, updatedCiphers] = await this.encryptedCiphersState.update((current) => { - const result = update(current ?? {}); - return result; - }); + const updatedCiphers = await this.stateProvider + .getUser(userId, ENCRYPTED_CIPHERS) + .update((current) => { + const result = update(current ?? {}); + return result; + }); return updatedCiphers; } diff --git a/libs/common/src/vault/services/collection.service.ts b/libs/common/src/vault/services/collection.service.ts index 47063aa29d..e9ad09a483 100644 --- a/libs/common/src/vault/services/collection.service.ts +++ b/libs/common/src/vault/services/collection.service.ts @@ -184,8 +184,10 @@ export class CollectionService implements CollectionServiceAbstraction { }); } - async replace(collections: Record): Promise { - await this.encryptedCollectionDataState.update(() => collections); + async replace(collections: Record, userId: UserId): Promise { + await this.stateProvider + .getUser(userId, ENCRYPTED_COLLECTION_DATA_KEY) + .update(() => collections); } async clear(userId?: UserId): Promise { diff --git a/libs/common/src/vault/services/folder/folder.service.spec.ts b/libs/common/src/vault/services/folder/folder.service.spec.ts index 6f181cf882..c27ea7646b 100644 --- a/libs/common/src/vault/services/folder/folder.service.spec.ts +++ b/libs/common/src/vault/services/folder/folder.service.spec.ts @@ -120,7 +120,7 @@ describe("Folder Service", () => { }); it("replace", async () => { - await folderService.replace({ "2": folderData("2", "test 2") }); + await folderService.replace({ "2": folderData("2", "test 2") }, mockUserId); expect(await firstValueFrom(folderService.folders$)).toEqual([ { diff --git a/libs/common/src/vault/services/folder/folder.service.ts b/libs/common/src/vault/services/folder/folder.service.ts index 7de7222edc..0c17d7178b 100644 --- a/libs/common/src/vault/services/folder/folder.service.ts +++ b/libs/common/src/vault/services/folder/folder.service.ts @@ -111,12 +111,12 @@ export class FolderService implements InternalFolderServiceAbstraction { }); } - async replace(folders: { [id: string]: FolderData }): Promise { + async replace(folders: { [id: string]: FolderData }, userId: UserId): Promise { if (!folders) { return; } - await this.encryptedFoldersState.update(() => { + await this.stateProvider.getUser(userId, FOLDER_ENCRYPTED_FOLDERS).update(() => { const newFolders: Record = { ...folders }; return newFolders; });