diff --git a/apps/browser/src/auth/background/service-factories/auth-service.factory.ts b/apps/browser/src/auth/background/service-factories/auth-service.factory.ts index bc4e621bc6..f600efa18d 100644 --- a/apps/browser/src/auth/background/service-factories/auth-service.factory.ts +++ b/apps/browser/src/auth/background/service-factories/auth-service.factory.ts @@ -24,6 +24,7 @@ import { } from "../../../platform/background/service-factories/state-service.factory"; import { AccountServiceInitOptions, accountServiceFactory } from "./account-service.factory"; +import { TokenServiceInitOptions, tokenServiceFactory } from "./token-service.factory"; type AuthServiceFactoryOptions = FactoryOptions; @@ -32,7 +33,8 @@ export type AuthServiceInitOptions = AuthServiceFactoryOptions & MessagingServiceInitOptions & CryptoServiceInitOptions & ApiServiceInitOptions & - StateServiceInitOptions; + StateServiceInitOptions & + TokenServiceInitOptions; export function authServiceFactory( cache: { authService?: AbstractAuthService } & CachedServices, @@ -49,6 +51,7 @@ export function authServiceFactory( await cryptoServiceFactory(cache, opts), await apiServiceFactory(cache, opts), await stateServiceFactory(cache, opts), + await tokenServiceFactory(cache, opts), ), ); } diff --git a/apps/browser/src/background/main.background.ts b/apps/browser/src/background/main.background.ts index ee17a7f1f0..49b4b96249 100644 --- a/apps/browser/src/background/main.background.ts +++ b/apps/browser/src/background/main.background.ts @@ -579,6 +579,7 @@ export default class MainBackground { this.cryptoService, this.apiService, this.stateService, + this.tokenService, ); this.billingAccountProfileStateService = new DefaultBillingAccountProfileStateService( diff --git a/apps/cli/src/bw.ts b/apps/cli/src/bw.ts index 7f23e6f2d0..d1105427f6 100644 --- a/apps/cli/src/bw.ts +++ b/apps/cli/src/bw.ts @@ -503,6 +503,7 @@ export class Main { this.cryptoService, this.apiService, this.stateService, + this.tokenService, ); this.configApiService = new ConfigApiService(this.apiService, this.tokenService); diff --git a/libs/angular/src/services/jslib-services.module.ts b/libs/angular/src/services/jslib-services.module.ts index a31d5141c4..b08c53ec06 100644 --- a/libs/angular/src/services/jslib-services.module.ts +++ b/libs/angular/src/services/jslib-services.module.ts @@ -349,6 +349,7 @@ const safeProviders: SafeProvider[] = [ CryptoServiceAbstraction, ApiServiceAbstraction, StateServiceAbstraction, + TokenService, ], }), safeProvider({ diff --git a/libs/common/src/auth/abstractions/auth.service.ts b/libs/common/src/auth/abstractions/auth.service.ts index 9e4fd3cd0b..de08dbd4e9 100644 --- a/libs/common/src/auth/abstractions/auth.service.ts +++ b/libs/common/src/auth/abstractions/auth.service.ts @@ -1,10 +1,17 @@ import { Observable } from "rxjs"; +import { UserId } from "../../types/guid"; import { AuthenticationStatus } from "../enums/authentication-status"; export abstract class AuthService { /** Authentication status for the active user */ abstract activeAccountStatus$: Observable; + /** + * Returns an observable authentication status for the given user id. + * @note userId is a required parameter, null values will always return `AuthenticationStatus.LoggedOut` + * @param userId The user id to check for an access token. + */ + abstract authStatusFor$(userId: UserId): Observable; /** @deprecated use {@link activeAccountStatus$} instead */ abstract getAuthStatus: (userId?: string) => Promise; abstract logOut: (callback: () => void) => void; diff --git a/libs/common/src/auth/abstractions/token.service.ts b/libs/common/src/auth/abstractions/token.service.ts index 18366c5f1b..75bb383882 100644 --- a/libs/common/src/auth/abstractions/token.service.ts +++ b/libs/common/src/auth/abstractions/token.service.ts @@ -1,8 +1,15 @@ +import { Observable } from "rxjs"; + import { VaultTimeoutAction } from "../../enums/vault-timeout-action.enum"; import { UserId } from "../../types/guid"; import { DecodedAccessToken } from "../services/token.service"; export abstract class TokenService { + /** + * Returns an observable that emits a boolean indicating whether the user has an access token. + * @param userId The user id to check for an access token. + */ + abstract hasAccessToken$(userId: UserId): Observable; /** * Sets the access token, refresh token, API Key Client ID, and API Key Client Secret in memory or disk * based on the given vaultTimeoutAction and vaultTimeout and the derived access token user id. diff --git a/libs/common/src/auth/services/auth.service.spec.ts b/libs/common/src/auth/services/auth.service.spec.ts index dd4daf8cfa..07e38def4b 100644 --- a/libs/common/src/auth/services/auth.service.spec.ts +++ b/libs/common/src/auth/services/auth.service.spec.ts @@ -1,13 +1,21 @@ import { MockProxy, mock } from "jest-mock-extended"; -import { firstValueFrom } from "rxjs"; +import { firstValueFrom, of } from "rxjs"; -import { FakeAccountService, mockAccountServiceWith } from "../../../spec"; +import { + FakeAccountService, + makeStaticByteArray, + mockAccountServiceWith, + trackEmissions, +} from "../../../spec"; import { ApiService } from "../../abstractions/api.service"; import { CryptoService } from "../../platform/abstractions/crypto.service"; import { MessagingService } from "../../platform/abstractions/messaging.service"; import { StateService } from "../../platform/abstractions/state.service"; import { Utils } from "../../platform/misc/utils"; +import { SymmetricCryptoKey } from "../../platform/models/domain/symmetric-crypto-key"; import { UserId } from "../../types/guid"; +import { UserKey } from "../../types/key"; +import { TokenService } from "../abstractions/token.service"; import { AuthenticationStatus } from "../enums/authentication-status"; import { AuthService } from "./auth.service"; @@ -20,15 +28,18 @@ describe("AuthService", () => { let cryptoService: MockProxy; let apiService: MockProxy; let stateService: MockProxy; + let tokenService: MockProxy; const userId = Utils.newGuid() as UserId; + const userKey = new SymmetricCryptoKey(makeStaticByteArray(32) as Uint8Array) as UserKey; beforeEach(() => { accountService = mockAccountServiceWith(userId); - messagingService = mock(); - cryptoService = mock(); - apiService = mock(); - stateService = mock(); + messagingService = mock(); + cryptoService = mock(); + apiService = mock(); + stateService = mock(); + tokenService = mock(); sut = new AuthService( accountService, @@ -36,26 +47,115 @@ describe("AuthService", () => { cryptoService, apiService, stateService, + tokenService, ); }); describe("activeAccountStatus$", () => { - test.each([ - AuthenticationStatus.LoggedOut, - AuthenticationStatus.Locked, - AuthenticationStatus.Unlocked, - ])( - `should emit %p when activeAccount$ emits an account with %p auth status`, - async (status) => { - accountService.activeAccountSubject.next({ - id: userId, - email: "email", - name: "name", - status, - }); + const accountInfo = { + status: AuthenticationStatus.Unlocked, + id: userId, + email: "email", + name: "name", + }; - expect(await firstValueFrom(sut.activeAccountStatus$)).toEqual(status); - }, - ); + beforeEach(() => { + accountService.activeAccountSubject.next(accountInfo); + tokenService.hasAccessToken$.mockReturnValue(of(true)); + cryptoService.getInMemoryUserKeyFor$.mockReturnValue(of(undefined)); + }); + + it("emits LoggedOut when there is no active account", async () => { + accountService.activeAccountSubject.next(undefined); + + expect(await firstValueFrom(sut.activeAccountStatus$)).toEqual( + AuthenticationStatus.LoggedOut, + ); + }); + + it("emits LoggedOut when there is no access token", async () => { + tokenService.hasAccessToken$.mockReturnValue(of(false)); + + expect(await firstValueFrom(sut.activeAccountStatus$)).toEqual( + AuthenticationStatus.LoggedOut, + ); + }); + + it("emits LoggedOut when there is no access token but has a user key", async () => { + tokenService.hasAccessToken$.mockReturnValue(of(false)); + cryptoService.getInMemoryUserKeyFor$.mockReturnValue(of(userKey)); + + expect(await firstValueFrom(sut.activeAccountStatus$)).toEqual( + AuthenticationStatus.LoggedOut, + ); + }); + + it("emits Locked when there is an access token and no user key", async () => { + tokenService.hasAccessToken$.mockReturnValue(of(true)); + cryptoService.getInMemoryUserKeyFor$.mockReturnValue(of(undefined)); + + expect(await firstValueFrom(sut.activeAccountStatus$)).toEqual(AuthenticationStatus.Locked); + }); + + it("emits Unlocked when there is an access token and user key", async () => { + tokenService.hasAccessToken$.mockReturnValue(of(true)); + cryptoService.getInMemoryUserKeyFor$.mockReturnValue(of(userKey)); + + expect(await firstValueFrom(sut.activeAccountStatus$)).toEqual(AuthenticationStatus.Unlocked); + }); + + it("follows the current active user", async () => { + const accountInfo2 = { + status: AuthenticationStatus.Unlocked, + id: Utils.newGuid() as UserId, + email: "email2", + name: "name2", + }; + + const emissions = trackEmissions(sut.activeAccountStatus$); + + tokenService.hasAccessToken$.mockReturnValue(of(true)); + cryptoService.getInMemoryUserKeyFor$.mockReturnValue(of(userKey)); + accountService.activeAccountSubject.next(accountInfo2); + + expect(emissions).toEqual([AuthenticationStatus.Locked, AuthenticationStatus.Unlocked]); + }); + }); + + describe("authStatusFor$", () => { + beforeEach(() => { + tokenService.hasAccessToken$.mockReturnValue(of(true)); + cryptoService.getInMemoryUserKeyFor$.mockReturnValue(of(undefined)); + }); + + it("emits LoggedOut when userId is null", async () => { + expect(await firstValueFrom(sut.authStatusFor$(null))).toEqual( + AuthenticationStatus.LoggedOut, + ); + }); + + it("emits LoggedOut when there is no access token", async () => { + tokenService.hasAccessToken$.mockReturnValue(of(false)); + + expect(await firstValueFrom(sut.authStatusFor$(userId))).toEqual( + AuthenticationStatus.LoggedOut, + ); + }); + + it("emits Locked when there is an access token and no user key", async () => { + tokenService.hasAccessToken$.mockReturnValue(of(true)); + cryptoService.getInMemoryUserKeyFor$.mockReturnValue(of(undefined)); + + expect(await firstValueFrom(sut.authStatusFor$(userId))).toEqual(AuthenticationStatus.Locked); + }); + + it("emits Unlocked when there is an access token and user key", async () => { + tokenService.hasAccessToken$.mockReturnValue(of(true)); + cryptoService.getInMemoryUserKeyFor$.mockReturnValue(of(userKey)); + + expect(await firstValueFrom(sut.authStatusFor$(userId))).toEqual( + AuthenticationStatus.Unlocked, + ); + }); }); }); diff --git a/libs/common/src/auth/services/auth.service.ts b/libs/common/src/auth/services/auth.service.ts index ae5dd30a36..de5eb66c06 100644 --- a/libs/common/src/auth/services/auth.service.ts +++ b/libs/common/src/auth/services/auth.service.ts @@ -1,12 +1,22 @@ -import { Observable, distinctUntilChanged, map, shareReplay } from "rxjs"; +import { + Observable, + combineLatest, + distinctUntilChanged, + map, + of, + shareReplay, + switchMap, +} from "rxjs"; import { ApiService } from "../../abstractions/api.service"; import { CryptoService } from "../../platform/abstractions/crypto.service"; import { MessagingService } from "../../platform/abstractions/messaging.service"; import { StateService } from "../../platform/abstractions/state.service"; import { KeySuffixOptions } from "../../platform/enums"; +import { UserId } from "../../types/guid"; import { AccountService } from "../abstractions/account.service"; import { AuthService as AuthServiceAbstraction } from "../abstractions/auth.service"; +import { TokenService } from "../abstractions/token.service"; import { AuthenticationStatus } from "../enums/authentication-status"; export class AuthService implements AuthServiceAbstraction { @@ -18,9 +28,36 @@ export class AuthService implements AuthServiceAbstraction { protected cryptoService: CryptoService, protected apiService: ApiService, protected stateService: StateService, + private tokenService: TokenService, ) { this.activeAccountStatus$ = this.accountService.activeAccount$.pipe( - map((account) => account.status), + map((account) => account?.id), + switchMap((userId) => { + return this.authStatusFor$(userId); + }), + ); + } + + authStatusFor$(userId: UserId): Observable { + if (userId == null) { + return of(AuthenticationStatus.LoggedOut); + } + + return combineLatest([ + this.cryptoService.getInMemoryUserKeyFor$(userId), + this.tokenService.hasAccessToken$(userId), + ]).pipe( + map(([userKey, hasAccessToken]) => { + if (!hasAccessToken) { + return AuthenticationStatus.LoggedOut; + } + + if (!userKey) { + return AuthenticationStatus.Locked; + } + + return AuthenticationStatus.Unlocked; + }), distinctUntilChanged(), shareReplay({ bufferSize: 1, refCount: false }), ); diff --git a/libs/common/src/auth/services/token.service.spec.ts b/libs/common/src/auth/services/token.service.spec.ts index 8e8ed08853..c409263209 100644 --- a/libs/common/src/auth/services/token.service.spec.ts +++ b/libs/common/src/auth/services/token.service.spec.ts @@ -1,4 +1,5 @@ import { MockProxy, mock } from "jest-mock-extended"; +import { firstValueFrom } from "rxjs"; import { FakeSingleUserStateProvider, FakeGlobalStateProvider } from "../../../spec"; import { VaultTimeoutAction } from "../../enums/vault-timeout-action.enum"; @@ -104,6 +105,61 @@ describe("TokenService", () => { const accessTokenKeyPartialSecureStorageKey = `_accessTokenKey`; const accessTokenKeySecureStorageKey = `${userIdFromAccessToken}${accessTokenKeyPartialSecureStorageKey}`; + describe("hasAccessToken$", () => { + it("returns true when an access token exists in memory", async () => { + // Arrange + singleUserStateProvider + .getFake(userIdFromAccessToken, ACCESS_TOKEN_MEMORY) + .stateSubject.next([userIdFromAccessToken, accessTokenJwt]); + + // Act + const result = await firstValueFrom(tokenService.hasAccessToken$(userIdFromAccessToken)); + + // Assert + expect(result).toEqual(true); + }); + + it("returns true when an access token exists in disk", async () => { + // Arrange + singleUserStateProvider + .getFake(userIdFromAccessToken, ACCESS_TOKEN_MEMORY) + .stateSubject.next([userIdFromAccessToken, undefined]); + + singleUserStateProvider + .getFake(userIdFromAccessToken, ACCESS_TOKEN_DISK) + .stateSubject.next([userIdFromAccessToken, accessTokenJwt]); + + // Act + const result = await firstValueFrom(tokenService.hasAccessToken$(userIdFromAccessToken)); + + // Assert + expect(result).toEqual(true); + }); + + it("returns true when an access token exists in secure storage", async () => { + // Arrange + singleUserStateProvider + .getFake(userIdFromAccessToken, ACCESS_TOKEN_DISK) + .stateSubject.next([userIdFromAccessToken, "encryptedAccessToken"]); + + secureStorageService.get.mockResolvedValue(accessTokenKeyB64); + + // Act + const result = await firstValueFrom(tokenService.hasAccessToken$(userIdFromAccessToken)); + + // Assert + expect(result).toEqual(true); + }); + + it("should return false if no access token exists in memory, disk, or secure storage", async () => { + // Act + const result = await firstValueFrom(tokenService.hasAccessToken$(userIdFromAccessToken)); + + // Assert + expect(result).toEqual(false); + }); + }); + describe("setAccessToken", () => { it("should throw an error if the access token is null", async () => { // Act diff --git a/libs/common/src/auth/services/token.service.ts b/libs/common/src/auth/services/token.service.ts index dd011eb40b..fb13c21870 100644 --- a/libs/common/src/auth/services/token.service.ts +++ b/libs/common/src/auth/services/token.service.ts @@ -1,4 +1,4 @@ -import { firstValueFrom } from "rxjs"; +import { Observable, combineLatest, firstValueFrom, map } from "rxjs"; import { Opaque } from "type-fest"; import { decodeJwtTokenToJson } from "@bitwarden/auth/common"; @@ -135,6 +135,15 @@ export class TokenService implements TokenServiceAbstraction { this.initializeState(); } + hasAccessToken$(userId: UserId): Observable { + // FIXME Once once vault timeout action is observable, we can use it to determine storage location + // and avoid the need to check both disk and memory. + return combineLatest([ + this.singleUserStateProvider.get(userId, ACCESS_TOKEN_DISK).state$, + this.singleUserStateProvider.get(userId, ACCESS_TOKEN_MEMORY).state$, + ]).pipe(map(([disk, memory]) => Boolean(disk || memory))); + } + // pivoting to an approach where we create a symmetric key we store in secure storage // which is used to protect the data before persisting to disk. // We will also use the same symmetric key to decrypt the data when reading from disk. diff --git a/libs/common/src/platform/abstractions/crypto.service.ts b/libs/common/src/platform/abstractions/crypto.service.ts index 44ff521680..85b2bfe82e 100644 --- a/libs/common/src/platform/abstractions/crypto.service.ts +++ b/libs/common/src/platform/abstractions/crypto.service.ts @@ -13,6 +13,14 @@ import { SymmetricCryptoKey } from "../models/domain/symmetric-crypto-key"; export abstract class CryptoService { abstract activeUserKey$: Observable; + + /** + * Returns the an observable key for the given user id. + * + * @note this observable represents only user keys stored in memory. A null value does not indicate that we cannot load a user key from storage. + * @param userId The desired user + */ + abstract getInMemoryUserKeyFor$(userId: UserId): Observable; /** * Sets the provided user key and stores * any other necessary versions (such as auto, biometrics, diff --git a/libs/common/src/platform/services/crypto.service.ts b/libs/common/src/platform/services/crypto.service.ts index fbb6a85293..dd3c497470 100644 --- a/libs/common/src/platform/services/crypto.service.ts +++ b/libs/common/src/platform/services/crypto.service.ts @@ -160,6 +160,10 @@ export class CryptoService implements CryptoServiceAbstraction { await this.setUserKey(key); } + getInMemoryUserKeyFor$(userId: UserId): Observable { + return this.stateProvider.getUserState$(USER_KEY, userId); + } + async getUserKey(userId?: UserId): Promise { let userKey = await firstValueFrom(this.stateProvider.getUserState$(USER_KEY, userId)); if (userKey) {