diff --git a/apps/browser/spec/mock-port.spec-util.ts b/apps/browser/spec/mock-port.spec-util.ts new file mode 100644 index 0000000000..b5f7825d8e --- /dev/null +++ b/apps/browser/spec/mock-port.spec-util.ts @@ -0,0 +1,29 @@ +import { mockDeep } from "jest-mock-extended"; + +/** + * Mocks a chrome.runtime.Port set up to send messages through `postMessage` to `onMessage.addListener` callbacks. + * @param name - The name of the port. + * @param immediateOnConnectExecution - Whether to immediately execute the onConnect callbacks against the new port. + * Defaults to false. If true, the creator of the port will not have had a chance to set up listeners yet. + * @returns a mock chrome.runtime.Port + */ +export function mockPorts() { + // notify listeners of a new port + (chrome.runtime.connect as jest.Mock).mockImplementation((portInfo) => { + const port = mockDeep(); + port.name = portInfo.name; + + // set message broadcast + (port.postMessage as jest.Mock).mockImplementation((message) => { + (port.onMessage.addListener as jest.Mock).mock.calls.forEach(([callbackFn]) => { + callbackFn(message, port); + }); + }); + + (chrome.runtime.onConnect.addListener as jest.Mock).mock.calls.forEach(([callbackFn]) => { + callbackFn(port); + }); + + return port; + }); +} diff --git a/apps/browser/src/background/main.background.ts b/apps/browser/src/background/main.background.ts index 91bc2875de..a8fb5947e4 100644 --- a/apps/browser/src/background/main.background.ts +++ b/apps/browser/src/background/main.background.ts @@ -65,18 +65,17 @@ import { SystemService } from "@bitwarden/common/platform/services/system.servic import { WebCryptoFunctionService } from "@bitwarden/common/platform/services/web-crypto-function.service"; import { ActiveUserStateProvider, + DerivedStateProvider, GlobalStateProvider, SingleUserStateProvider, StateProvider, } from "@bitwarden/common/platform/state"; -// eslint-disable-next-line import/no-restricted-paths -- We need the implementation to inject, but generally this should not be accessed +/* eslint-disable import/no-restricted-paths -- We need the implementation to inject, but generally these should not be accessed */ import { DefaultActiveUserStateProvider } from "@bitwarden/common/platform/state/implementations/default-active-user-state.provider"; -// eslint-disable-next-line import/no-restricted-paths -- We need the implementation to inject, but generally this should not be accessed import { DefaultGlobalStateProvider } from "@bitwarden/common/platform/state/implementations/default-global-state.provider"; -// eslint-disable-next-line import/no-restricted-paths -- We need the implementation to inject, but generally this should not be accessed import { DefaultSingleUserStateProvider } from "@bitwarden/common/platform/state/implementations/default-single-user-state.provider"; -// eslint-disable-next-line import/no-restricted-paths -- We need the implementation to inject, but generally this should not be accessed import { DefaultStateProvider } from "@bitwarden/common/platform/state/implementations/default-state.provider"; +/* eslint-enable import/no-restricted-paths */ import { AvatarUpdateService } from "@bitwarden/common/services/account/avatar-update.service"; import { ApiService } from "@bitwarden/common/services/api.service"; import { AuditService } from "@bitwarden/common/services/audit.service"; @@ -162,6 +161,7 @@ import BrowserPlatformUtilsService from "../platform/services/browser-platform-u import { BrowserStateService } from "../platform/services/browser-state.service"; import { KeyGenerationService } from "../platform/services/key-generation.service"; import { LocalBackedSessionStorageService } from "../platform/services/local-backed-session-storage.service"; +import { BackgroundDerivedStateProvider } from "../platform/state/background-derived-state.provider"; import { BackgroundMemoryStorageService } from "../platform/storage/background-memory-storage.service"; import { BrowserSendService } from "../services/browser-send.service"; import { BrowserSettingsService } from "../services/browser-settings.service"; @@ -248,6 +248,7 @@ export default class MainBackground { globalStateProvider: GlobalStateProvider; singleUserStateProvider: SingleUserStateProvider; activeUserStateProvider: ActiveUserStateProvider; + derivedStateProvider: DerivedStateProvider; stateProvider: StateProvider; fido2Service: Fido2ServiceAbstraction; @@ -335,10 +336,14 @@ export default class MainBackground { this.memoryStorageService as BackgroundMemoryStorageService, this.storageService as BrowserLocalStorageService, ); + this.derivedStateProvider = new BackgroundDerivedStateProvider( + this.memoryStorageService as BackgroundMemoryStorageService, + ); this.stateProvider = new DefaultStateProvider( this.activeUserStateProvider, this.singleUserStateProvider, this.globalStateProvider, + this.derivedStateProvider, ); this.stateService = new BrowserStateService( this.storageService, diff --git a/apps/browser/src/platform/background/service-factories/derived-state-provider.factory.ts b/apps/browser/src/platform/background/service-factories/derived-state-provider.factory.ts new file mode 100644 index 0000000000..4f329c93d5 --- /dev/null +++ b/apps/browser/src/platform/background/service-factories/derived-state-provider.factory.ts @@ -0,0 +1,27 @@ +import { DerivedStateProvider } from "@bitwarden/common/platform/state"; + +import { BackgroundDerivedStateProvider } from "../../state/background-derived-state.provider"; + +import { CachedServices, FactoryOptions, factory } from "./factory-options"; +import { + MemoryStorageServiceInitOptions, + observableMemoryStorageServiceFactory, +} from "./storage-service.factory"; + +type DerivedStateProviderFactoryOptions = FactoryOptions; + +export type DerivedStateProviderInitOptions = DerivedStateProviderFactoryOptions & + MemoryStorageServiceInitOptions; + +export async function derivedStateProviderFactory( + cache: { derivedStateProvider?: DerivedStateProvider } & CachedServices, + opts: DerivedStateProviderInitOptions, +): Promise { + return factory( + cache, + "derivedStateProvider", + opts, + async () => + new BackgroundDerivedStateProvider(await observableMemoryStorageServiceFactory(cache, opts)), + ); +} diff --git a/apps/browser/src/platform/background/service-factories/state-provider.factory.ts b/apps/browser/src/platform/background/service-factories/state-provider.factory.ts index 69c4b9d011..b5ae9c709f 100644 --- a/apps/browser/src/platform/background/service-factories/state-provider.factory.ts +++ b/apps/browser/src/platform/background/service-factories/state-provider.factory.ts @@ -6,6 +6,10 @@ import { ActiveUserStateProviderInitOptions, activeUserStateProviderFactory, } from "./active-user-state-provider.factory"; +import { + DerivedStateProviderInitOptions, + derivedStateProviderFactory, +} from "./derived-state-provider.factory"; import { CachedServices, FactoryOptions, factory } from "./factory-options"; import { GlobalStateProviderInitOptions, @@ -21,7 +25,8 @@ type StateProviderFactoryOptions = FactoryOptions; export type StateProviderInitOptions = StateProviderFactoryOptions & GlobalStateProviderInitOptions & ActiveUserStateProviderInitOptions & - SingleUserStateProviderInitOptions; + SingleUserStateProviderInitOptions & + DerivedStateProviderInitOptions; export async function stateProviderFactory( cache: { stateProvider?: StateProvider } & CachedServices, @@ -36,6 +41,7 @@ export async function stateProviderFactory( await activeUserStateProviderFactory(cache, opts), await singleUserStateProviderFactory(cache, opts), await globalStateProviderFactory(cache, opts), + await derivedStateProviderFactory(cache, opts), ), ); } diff --git a/apps/browser/src/platform/state/background-derived-state.provider.ts b/apps/browser/src/platform/state/background-derived-state.provider.ts new file mode 100644 index 0000000000..8057c3bfcd --- /dev/null +++ b/apps/browser/src/platform/state/background-derived-state.provider.ts @@ -0,0 +1,23 @@ +import { Observable } from "rxjs"; + +import { DeriveDefinition, DerivedState } from "@bitwarden/common/platform/state"; +// eslint-disable-next-line import/no-restricted-paths -- extending this class for this client +import { DefaultDerivedStateProvider } from "@bitwarden/common/platform/state/implementations/default-derived-state.provider"; +import { ShapeToInstances, Type } from "@bitwarden/common/src/types/state"; + +import { BackgroundDerivedState } from "./background-derived-state"; + +export class BackgroundDerivedStateProvider extends DefaultDerivedStateProvider { + override buildDerivedState>>( + parentState$: Observable, + deriveDefinition: DeriveDefinition, + dependencies: ShapeToInstances, + ): DerivedState { + return new BackgroundDerivedState( + parentState$, + deriveDefinition, + this.memoryStorage, + dependencies, + ); + } +} diff --git a/apps/browser/src/platform/state/background-derived-state.ts b/apps/browser/src/platform/state/background-derived-state.ts new file mode 100644 index 0000000000..7e08d1543d --- /dev/null +++ b/apps/browser/src/platform/state/background-derived-state.ts @@ -0,0 +1,131 @@ +import { Observable, Subject, Subscription } from "rxjs"; +import { Jsonify } from "type-fest"; + +import { + AbstractStorageService, + ObservableStorageService, +} from "@bitwarden/common/platform/abstractions/storage.service"; +import { Utils } from "@bitwarden/common/platform/misc/utils"; +import { DeriveDefinition } from "@bitwarden/common/platform/state"; +// eslint-disable-next-line import/no-restricted-paths -- extending this class for this client +import { DefaultDerivedState } from "@bitwarden/common/platform/state/implementations/default-derived-state"; +import { ShapeToInstances, Type } from "@bitwarden/common/types/state"; + +import { BrowserApi } from "../browser/browser-api"; + +export class BackgroundDerivedState< + TFrom, + TTo, + TDeps extends Record>, +> extends DefaultDerivedState { + private portSubscriptions: Map< + chrome.runtime.Port, + { subscription: Subscription; delaySubject: Subject } + > = new Map(); + + constructor( + parentState$: Observable, + deriveDefinition: DeriveDefinition, + memoryStorage: AbstractStorageService & ObservableStorageService, + dependencies: ShapeToInstances, + ) { + super(parentState$, deriveDefinition, memoryStorage, dependencies); + const portName = deriveDefinition.buildCacheKey(); + + // listen for foreground derived states to connect + BrowserApi.addListener(chrome.runtime.onConnect, (port) => { + if (port.name !== portName) { + return; + } + + const listenerCallback = this.onMessageFromForeground.bind(this); + port.onDisconnect.addListener(() => { + const { subscription, delaySubject } = this.portSubscriptions.get(port) ?? { + subscription: null, + delaySubject: null, + }; + subscription?.unsubscribe(); + delaySubject?.complete(); + this.portSubscriptions.delete(port); + port.onMessage.removeListener(listenerCallback); + }); + port.onMessage.addListener(listenerCallback); + + const delaySubject = new Subject(); + const stateSubscription = this.state$.subscribe((state) => { + // delay to allow the foreground to connect. This may just be needed for testing + setTimeout(() => { + this.sendNewMessage( + { + action: "nextState", + data: JSON.stringify(state), + }, + port, + ); + }, 0); + }); + + this.portSubscriptions.set(port, { subscription: stateSubscription, delaySubject }); + }); + } + + private async onMessageFromForeground(message: DerivedStateMessage, port: chrome.runtime.Port) { + if (message.originator === "background") { + return; + } + + switch (message.action) { + case "nextState": { + const dataObj = JSON.parse(message.data) as Jsonify; + const data = this.deriveDefinition.deserialize(dataObj); + await this.forceValue(data); + await this.sendResponse( + message, + { + action: "resolve", + }, + port, + ); + break; + } + } + } + + private async sendNewMessage( + message: Omit, + port: chrome.runtime.Port, + ) { + const id = Utils.newGuid(); + this.sendMessage( + { + ...message, + id: id, + }, + port, + ); + } + + private async sendResponse( + originalMessage: DerivedStateMessage, + response: Omit, + port: chrome.runtime.Port, + ) { + this.sendMessage( + { + ...response, + id: originalMessage.id, + }, + port, + ); + } + + private async sendMessage( + message: Omit, + port: chrome.runtime.Port, + ) { + port.postMessage({ + ...message, + originator: "background", + }); + } +} diff --git a/apps/browser/src/platform/state/derived-state-service-interactions.spec.ts b/apps/browser/src/platform/state/derived-state-service-interactions.spec.ts new file mode 100644 index 0000000000..246507c426 --- /dev/null +++ b/apps/browser/src/platform/state/derived-state-service-interactions.spec.ts @@ -0,0 +1,112 @@ +/** + * need to update test environment so structuredClone works appropriately + * @jest-environment ../../libs/shared/test.environment.ts + */ + +import { FakeStorageService } from "@bitwarden/common/../spec/fake-storage.service"; +import { awaitAsync, trackEmissions } from "@bitwarden/common/../spec/utils"; +import { Subject, firstValueFrom } from "rxjs"; + +import { DeriveDefinition } from "@bitwarden/common/platform/state"; +// eslint-disable-next-line import/no-restricted-paths -- needed to define a derive definition +import { StateDefinition } from "@bitwarden/common/platform/state/state-definition"; +import { Type } from "@bitwarden/common/types/state"; + +import { mockPorts } from "../../../spec/mock-port.spec-util"; + +import { BackgroundDerivedState } from "./background-derived-state"; +import { ForegroundDerivedState } from "./foreground-derived-state"; + +const stateDefinition = new StateDefinition("test", "memory"); +const deriveDefinition = new DeriveDefinition(stateDefinition, "test", { + derive: (dateString: string) => (dateString == null ? null : new Date(dateString)), + deserializer: (dateString: string) => (dateString == null ? null : new Date(dateString)), +}); + +describe("foreground background derived state interactions", () => { + let foreground: ForegroundDerivedState; + let background: BackgroundDerivedState>>; + let parentState$: Subject; + let memoryStorage: FakeStorageService; + const initialParent = "2020-01-01"; + + beforeEach(() => { + mockPorts(); + parentState$ = new Subject(); + memoryStorage = new FakeStorageService(); + + background = new BackgroundDerivedState(parentState$, deriveDefinition, memoryStorage, {}); + foreground = new ForegroundDerivedState(deriveDefinition); + }); + + afterEach(() => { + parentState$.complete(); + jest.resetAllMocks(); + }); + + it("should connect between foreground and background", async () => { + const foregroundEmissions = trackEmissions(foreground.state$); + const backgroundEmissions = trackEmissions(background.state$); + + parentState$.next(initialParent); + await awaitAsync(10); + + expect(foregroundEmissions).toEqual([new Date(initialParent)]); + expect(backgroundEmissions).toEqual([new Date(initialParent)]); + }); + + it("should initialize a late-connected foreground", async () => { + const newForeground = new ForegroundDerivedState(deriveDefinition); + const backgroundEmissions = trackEmissions(background.state$); + parentState$.next(initialParent); + await awaitAsync(); + + const foregroundEmissions = trackEmissions(newForeground.state$); + await awaitAsync(10); + + expect(backgroundEmissions).toEqual([new Date(initialParent)]); + expect(foregroundEmissions).toEqual([new Date(initialParent)]); + }); + + describe("forceValue", () => { + it("should force the value to the background", async () => { + const dateString = "2020-12-12"; + const emissions = trackEmissions(background.state$); + + foreground.forceValue(new Date(dateString)); + await awaitAsync(); + + expect(emissions).toEqual([new Date(dateString)]); + }); + + it("should not create new ports if already connected", async () => { + // establish port with subscription + trackEmissions(foreground.state$); + + const connectMock = chrome.runtime.connect as jest.Mock; + const initialConnectCalls = connectMock.mock.calls.length; + + expect(foreground["port"]).toBeDefined(); + const newDate = new Date(); + foreground.forceValue(newDate); + await awaitAsync(); + + expect(connectMock.mock.calls.length).toBe(initialConnectCalls); + expect(await firstValueFrom(background.state$)).toEqual(newDate); + }); + + it("should create a port if not connected", async () => { + const connectMock = chrome.runtime.connect as jest.Mock; + const initialConnectCalls = connectMock.mock.calls.length; + + expect(foreground["port"]).toBeUndefined(); + const newDate = new Date(); + foreground.forceValue(newDate); + await awaitAsync(); + + expect(connectMock.mock.calls.length).toBe(initialConnectCalls + 1); + expect(foreground["port"]).toBeNull(); + expect(await firstValueFrom(background.state$)).toEqual(newDate); + }); + }); +}); diff --git a/apps/browser/src/platform/state/foreground-derived-state.provider.ts b/apps/browser/src/platform/state/foreground-derived-state.provider.ts new file mode 100644 index 0000000000..bbbf10bdd1 --- /dev/null +++ b/apps/browser/src/platform/state/foreground-derived-state.provider.ts @@ -0,0 +1,18 @@ +import { Observable } from "rxjs"; + +import { DeriveDefinition, DerivedState } from "@bitwarden/common/platform/state"; +// eslint-disable-next-line import/no-restricted-paths -- extending this class for this client +import { DefaultDerivedStateProvider } from "@bitwarden/common/platform/state/implementations/default-derived-state.provider"; +import { ShapeToInstances, Type } from "@bitwarden/common/src/types/state"; + +import { ForegroundDerivedState } from "./foreground-derived-state"; + +export class ForegroundDerivedStateProvider extends DefaultDerivedStateProvider { + override buildDerivedState>>( + _parentState$: Observable, + deriveDefinition: DeriveDefinition, + _dependencies: ShapeToInstances, + ): DerivedState { + return new ForegroundDerivedState(deriveDefinition); + } +} diff --git a/apps/browser/src/platform/state/foreground-derived-state.spec.ts b/apps/browser/src/platform/state/foreground-derived-state.spec.ts new file mode 100644 index 0000000000..965e51e446 --- /dev/null +++ b/apps/browser/src/platform/state/foreground-derived-state.spec.ts @@ -0,0 +1,61 @@ +import { awaitAsync } from "@bitwarden/common/../spec/utils"; + +import { DeriveDefinition } from "@bitwarden/common/platform/state"; +// eslint-disable-next-line import/no-restricted-paths -- needed to define a derive definition +import { StateDefinition } from "@bitwarden/common/platform/state/state-definition"; + +import { mockPorts } from "../../../spec/mock-port.spec-util"; + +import { ForegroundDerivedState } from "./foreground-derived-state"; + +const stateDefinition = new StateDefinition("test", "memory"); +const deriveDefinition = new DeriveDefinition(stateDefinition, "test", { + derive: (dateString: string) => (dateString == null ? null : new Date(dateString)), + deserializer: (dateString: string) => (dateString == null ? null : new Date(dateString)), + cleanupDelayMs: 1, +}); + +describe("ForegroundDerivedState", () => { + let sut: ForegroundDerivedState; + + beforeEach(() => { + mockPorts(); + sut = new ForegroundDerivedState(deriveDefinition); + }); + + afterEach(() => { + jest.resetAllMocks(); + }); + + it("should not connect a port until subscribed", async () => { + expect(sut["port"]).toBeUndefined(); + const subscription = sut.state$.subscribe(); + + expect(sut["port"]).toBeDefined(); + subscription.unsubscribe(); + }); + + it("should disconnect its port when unsubscribed", async () => { + const subscription = sut.state$.subscribe(); + + expect(sut["port"]).toBeDefined(); + const disconnectSpy = jest.spyOn(sut["port"], "disconnect"); + subscription.unsubscribe(); + // wait for the cleanup delay + await awaitAsync(deriveDefinition.cleanupDelayMs * 2); + + expect(disconnectSpy).toHaveBeenCalled(); + expect(sut["port"]).toBeNull(); + }); + + it("should complete its replay subject when torn down", async () => { + const subscription = sut.state$.subscribe(); + + const completeSpy = jest.spyOn(sut["replaySubject"], "complete"); + subscription.unsubscribe(); + // wait for the cleanup delay + await awaitAsync(deriveDefinition.cleanupDelayMs * 2); + + expect(completeSpy).toHaveBeenCalled(); + }); +}); diff --git a/apps/browser/src/platform/state/foreground-derived-state.ts b/apps/browser/src/platform/state/foreground-derived-state.ts new file mode 100644 index 0000000000..3200507b8c --- /dev/null +++ b/apps/browser/src/platform/state/foreground-derived-state.ts @@ -0,0 +1,112 @@ +import { + Observable, + ReplaySubject, + defer, + filter, + firstValueFrom, + map, + share, + tap, + timer, +} from "rxjs"; + +import { Utils } from "@bitwarden/common/platform/misc/utils"; +import { DeriveDefinition, DerivedState } from "@bitwarden/common/platform/state"; +import { Type } from "@bitwarden/common/types/state"; + +import { fromChromeEvent } from "../browser/from-chrome-event"; + +export class ForegroundDerivedState implements DerivedState { + private port: chrome.runtime.Port; + // For testing purposes + private replaySubject: ReplaySubject; + private backgroundResponses$: Observable; + state$: Observable; + + constructor( + private deriveDefinition: DeriveDefinition>>, + ) { + this.state$ = defer(() => this.initializePort()).pipe( + filter((message) => message.action === "nextState"), + map((message) => this.hydrateNext(message.data)), + share({ + connector: () => { + this.replaySubject = new ReplaySubject(1); + return this.replaySubject; + }, + resetOnRefCountZero: () => + timer(this.deriveDefinition.cleanupDelayMs).pipe(tap(() => this.tearDown())), + }), + ); + } + + async forceValue(value: TTo): Promise { + let cleanPort = false; + if (this.port == null) { + this.initializePort(); + cleanPort = true; + } + await this.delegateToBackground("nextState", value); + if (cleanPort) { + this.tearDownPort(); + } + return value; + } + + private initializePort(): Observable { + if (this.port != null) { + return; + } + + this.port = chrome.runtime.connect({ name: this.deriveDefinition.buildCacheKey() }); + + this.backgroundResponses$ = fromChromeEvent(this.port.onMessage).pipe( + map(([message]) => message as DerivedStateMessage), + filter((message) => message.originator === "background"), + ); + return this.backgroundResponses$; + } + + private async delegateToBackground(action: DerivedStateActions, data: TTo): Promise { + const id = Utils.newGuid(); + // listen for response before request + const response = firstValueFrom( + this.backgroundResponses$.pipe(filter((message) => message.id === id)), + ); + + this.sendMessage({ + id, + action, + data: JSON.stringify(data), + }); + + await response; + } + + private sendMessage(message: Omit) { + this.port.postMessage({ + ...message, + originator: "foreground", + }); + } + + private hydrateNext(value: string): TTo { + const jsonObj = JSON.parse(value); + return this.deriveDefinition.deserialize(jsonObj); + } + + private tearDownPort() { + if (this.port == null) { + return; + } + + this.port.disconnect(); + this.port = null; + this.backgroundResponses$ = null; + } + + private tearDown() { + this.tearDownPort(); + this.replaySubject.complete(); + } +} diff --git a/apps/browser/src/platform/state/port-message.d.ts b/apps/browser/src/platform/state/port-message.d.ts new file mode 100644 index 0000000000..3aa23aac39 --- /dev/null +++ b/apps/browser/src/platform/state/port-message.d.ts @@ -0,0 +1,7 @@ +type DerivedStateActions = "nextState" | "resolve"; +type DerivedStateMessage = { + id: string; + action: DerivedStateActions; + data?: string; // Json stringified TTo + originator: "foreground" | "background"; +}; diff --git a/apps/browser/src/platform/storage/memory-storage-service-interactions.spec.ts b/apps/browser/src/platform/storage/memory-storage-service-interactions.spec.ts index 77aa096073..43ffb6a065 100644 --- a/apps/browser/src/platform/storage/memory-storage-service-interactions.spec.ts +++ b/apps/browser/src/platform/storage/memory-storage-service-interactions.spec.ts @@ -5,9 +5,10 @@ import { trackEmissions } from "@bitwarden/common/../spec/utils"; +import { mockPorts } from "../../../spec/mock-port.spec-util"; + import { BackgroundMemoryStorageService } from "./background-memory-storage.service"; import { ForegroundMemoryStorageService } from "./foreground-memory-storage.service"; -import { mockPorts } from "./mock-ports.spec-util"; describe("foreground background memory storage interaction", () => { let foreground: ForegroundMemoryStorageService; diff --git a/apps/browser/src/popup/services/services.module.ts b/apps/browser/src/popup/services/services.module.ts index 1722109c67..f4db8f0d43 100644 --- a/apps/browser/src/popup/services/services.module.ts +++ b/apps/browser/src/popup/services/services.module.ts @@ -67,6 +67,7 @@ import { GlobalState } from "@bitwarden/common/platform/models/domain/global-sta import { ConfigService } from "@bitwarden/common/platform/services/config/config.service"; import { ConsoleLogService } from "@bitwarden/common/platform/services/console-log.service"; import { ContainerService } from "@bitwarden/common/platform/services/container.service"; +import { DerivedStateProvider } from "@bitwarden/common/platform/state"; import { SearchService } from "@bitwarden/common/services/search.service"; import { PasswordGenerationServiceAbstraction } from "@bitwarden/common/tools/generator/password"; import { UsernameGenerationServiceAbstraction } from "@bitwarden/common/tools/generator/username"; @@ -109,6 +110,7 @@ import BrowserLocalStorageService from "../../platform/services/browser-local-st import BrowserMessagingPrivateModePopupService from "../../platform/services/browser-messaging-private-mode-popup.service"; import BrowserMessagingService from "../../platform/services/browser-messaging.service"; import { BrowserStateService } from "../../platform/services/browser-state.service"; +import { ForegroundDerivedStateProvider } from "../../platform/state/foreground-derived-state.provider"; import { ForegroundMemoryStorageService } from "../../platform/storage/foreground-memory-storage.service"; import { BrowserSendService } from "../../services/browser-send.service"; import { BrowserSettingsService } from "../../services/browser-settings.service"; @@ -552,6 +554,11 @@ function getBgService(service: keyof MainBackground) { }, deps: [PlatformUtilsService], }, + { + provide: DerivedStateProvider, + useClass: ForegroundDerivedStateProvider, + deps: [OBSERVABLE_MEMORY_STORAGE], + }, ], }) export class ServicesModule {} diff --git a/apps/cli/src/bw.ts b/apps/cli/src/bw.ts index 557bbe733b..4572c16879 100644 --- a/apps/cli/src/bw.ts +++ b/apps/cli/src/bw.ts @@ -48,18 +48,18 @@ import { NoopMessagingService } from "@bitwarden/common/platform/services/noop-m import { StateService } from "@bitwarden/common/platform/services/state.service"; import { ActiveUserStateProvider, + DerivedStateProvider, GlobalStateProvider, SingleUserStateProvider, StateProvider, } from "@bitwarden/common/platform/state"; -// eslint-disable-next-line import/no-restricted-paths -- We need the implementation to inject, but generally this should not be accessed +/* eslint-disable import/no-restricted-paths -- We need the implementation to inject, but generally these should not be accessed */ import { DefaultActiveUserStateProvider } from "@bitwarden/common/platform/state/implementations/default-active-user-state.provider"; -// eslint-disable-next-line import/no-restricted-paths -- We need the implementation to inject, but generally this should not be accessed +import { DefaultDerivedStateProvider } from "@bitwarden/common/platform/state/implementations/default-derived-state.provider"; import { DefaultGlobalStateProvider } from "@bitwarden/common/platform/state/implementations/default-global-state.provider"; -// eslint-disable-next-line import/no-restricted-paths -- We need the implementation to inject, but generally this should not be accessed import { DefaultSingleUserStateProvider } from "@bitwarden/common/platform/state/implementations/default-single-user-state.provider"; -// eslint-disable-next-line import/no-restricted-paths -- We need the implementation to inject, but generally this should not be accessed import { DefaultStateProvider } from "@bitwarden/common/platform/state/implementations/default-state.provider"; +/* eslint-enable import/no-restricted-paths */ import { AuditService } from "@bitwarden/common/services/audit.service"; import { EventCollectionService } from "@bitwarden/common/services/event/event-collection.service"; import { EventUploadService } from "@bitwarden/common/services/event/event-upload.service"; @@ -179,6 +179,7 @@ export class Main { globalStateProvider: GlobalStateProvider; singleUserStateProvider: SingleUserStateProvider; activeUserStateProvider: ActiveUserStateProvider; + derivedStateProvider: DerivedStateProvider; stateProvider: StateProvider; constructor() { @@ -245,10 +246,13 @@ export class Main { this.storageService, ); + this.derivedStateProvider = new DefaultDerivedStateProvider(this.memoryStorageService); + this.stateProvider = new DefaultStateProvider( this.activeUserStateProvider, this.singleUserStateProvider, this.globalStateProvider, + this.derivedStateProvider, ); this.stateService = new StateService( diff --git a/libs/angular/src/services/jslib-services.module.ts b/libs/angular/src/services/jslib-services.module.ts index 412e3ce5ae..7a58aff953 100644 --- a/libs/angular/src/services/jslib-services.module.ts +++ b/libs/angular/src/services/jslib-services.module.ts @@ -112,9 +112,11 @@ import { GlobalStateProvider, SingleUserStateProvider, StateProvider, + DerivedStateProvider, } from "@bitwarden/common/platform/state"; /* eslint-disable import/no-restricted-paths -- We need the implementations to inject, but generally these should not be accessed */ import { DefaultActiveUserStateProvider } from "@bitwarden/common/platform/state/implementations/default-active-user-state.provider"; +import { DefaultDerivedStateProvider } from "@bitwarden/common/platform/state/implementations/default-derived-state.provider"; import { DefaultGlobalStateProvider } from "@bitwarden/common/platform/state/implementations/default-global-state.provider"; import { DefaultSingleUserStateProvider } from "@bitwarden/common/platform/state/implementations/default-single-user-state.provider"; import { DefaultStateProvider } from "@bitwarden/common/platform/state/implementations/default-state.provider"; @@ -810,10 +812,20 @@ import { ModalService } from "./modal.service"; useClass: DefaultSingleUserStateProvider, deps: [EncryptService, OBSERVABLE_MEMORY_STORAGE, OBSERVABLE_DISK_STORAGE], }, + { + provide: DerivedStateProvider, + useClass: DefaultDerivedStateProvider, + deps: [OBSERVABLE_MEMORY_STORAGE], + }, { provide: StateProvider, useClass: DefaultStateProvider, - deps: [ActiveUserStateProvider, SingleUserStateProvider, GlobalStateProvider], + deps: [ + ActiveUserStateProvider, + SingleUserStateProvider, + GlobalStateProvider, + DerivedStateProvider, + ], }, ], }) diff --git a/libs/common/spec/fake-state-provider.ts b/libs/common/spec/fake-state-provider.ts index bc5a6da6a7..36ffe8e186 100644 --- a/libs/common/spec/fake-state-provider.ts +++ b/libs/common/spec/fake-state-provider.ts @@ -1,3 +1,5 @@ +import { Observable } from "rxjs"; + import { GlobalState, GlobalStateProvider, @@ -7,10 +9,19 @@ import { SingleUserStateProvider, StateProvider, ActiveUserStateProvider, + DerivedState, + DeriveDefinition, + DerivedStateProvider, } from "../src/platform/state"; import { UserId } from "../src/types/guid"; +import { ShapeToInstances, DerivedStateDependencies } from "../src/types/state"; -import { FakeActiveUserState, FakeGlobalState, FakeSingleUserState } from "./fake-state"; +import { + FakeActiveUserState, + FakeDerivedState, + FakeGlobalState, + FakeSingleUserState, +} from "./fake-state"; export class FakeGlobalStateProvider implements GlobalStateProvider { states: Map> = new Map(); @@ -78,7 +89,33 @@ export class FakeStateProvider implements StateProvider { return this.singleUser.get(userId, keyDefinition); } + getDerived( + parentState$: Observable, + deriveDefinition: DeriveDefinition, + dependencies: ShapeToInstances, + ): DerivedState { + return this.derived.get(parentState$, deriveDefinition, dependencies); + } + global: FakeGlobalStateProvider = new FakeGlobalStateProvider(); singleUser: FakeSingleUserStateProvider = new FakeSingleUserStateProvider(); activeUser: FakeActiveUserStateProvider = new FakeActiveUserStateProvider(); + derived: FakeDerivedStateProvider = new FakeDerivedStateProvider(); +} + +export class FakeDerivedStateProvider implements DerivedStateProvider { + states: Map> = new Map(); + get( + parentState$: Observable, + deriveDefinition: DeriveDefinition, + dependencies: ShapeToInstances, + ): DerivedState { + let result = this.states.get(deriveDefinition.buildCacheKey()) as DerivedState; + + if (result == null) { + result = new FakeDerivedState(); + this.states.set(deriveDefinition.buildCacheKey(), result); + } + return result; + } } diff --git a/libs/common/spec/fake-state.ts b/libs/common/spec/fake-state.ts index 2d455fe237..b3ea80d3e4 100644 --- a/libs/common/spec/fake-state.ts +++ b/libs/common/spec/fake-state.ts @@ -1,11 +1,6 @@ import { ReplaySubject, firstValueFrom, timeout } from "rxjs"; -import { - DerivedUserState, - GlobalState, - SingleUserState, - ActiveUserState, -} from "../src/platform/state"; +import { DerivedState, GlobalState, SingleUserState, ActiveUserState } from "../src/platform/state"; // eslint-disable-next-line import/no-restricted-paths -- using unexposed options for clean typing in test class import { StateUpdateOptions } from "../src/platform/state/state-update-options"; // eslint-disable-next-line import/no-restricted-paths -- using unexposed options for clean typing in test class @@ -92,10 +87,6 @@ export class FakeUserState implements UserState { options?: StateUpdateOptions, ) => Promise = jest.fn(); - createDerived: ( - converter: (data: T, context: any) => Promise, - ) => DerivedUserState = jest.fn(); - getFromState: () => Promise = jest.fn(async () => { return await firstValueFrom(this.state$.pipe(timeout(10))); }); @@ -113,3 +104,18 @@ export class FakeSingleUserState extends FakeUserState implements SingleUs export class FakeActiveUserState extends FakeUserState implements ActiveUserState { [activeMarker]: true; } + +export class FakeDerivedState implements DerivedState { + // eslint-disable-next-line rxjs/no-exposed-subjects -- exposed for testing setup + stateSubject = new ReplaySubject(1); + + forceValue(value: T): Promise { + this.stateSubject.next(value); + return Promise.resolve(value); + } + forceValueMock = this.forceValue as jest.MockedFunction; + + get state$() { + return this.stateSubject.asObservable(); + } +} diff --git a/libs/common/src/platform/state/derive-definition.ts b/libs/common/src/platform/state/derive-definition.ts new file mode 100644 index 0000000000..f698f820f7 --- /dev/null +++ b/libs/common/src/platform/state/derive-definition.ts @@ -0,0 +1,131 @@ +import { Jsonify } from "type-fest"; + +import { DerivedStateDependencies, ShapeToInstances, StorageKey } from "../../types/state"; + +import { KeyDefinition } from "./key-definition"; +import { StateDefinition } from "./state-definition"; + +declare const depShapeMarker: unique symbol; +/** + * A set of options for customizing the behavior of a {@link DeriveDefinition} + */ +type DeriveDefinitionOptions = { + /** + * A function to use to convert values from TFrom to TTo. This is called on each emit of the parent state observable + * and the resulting value will be emitted from the derived state observable. + * + * @param from Populated with the latest emission from the parent state observable. + * @param deps Populated with the dependencies passed into the constructor of the derived state. + * These are constant for the lifetime of the derived state. + * @returns The derived state value or a Promise that resolves to the derived state value. + */ + derive: (from: TFrom, deps: ShapeToInstances) => TTo | Promise; + /** + * A function to use to safely convert your type from json to your expected type. + * + * **Important:** Your data may be serialized/deserialized at any time and this + * callback needs to be able to faithfully re-initialize from the JSON object representation of your type. + * + * @param jsonValue The JSON object representation of your state. + * @returns The fully typed version of your state. + */ + deserializer: (serialized: Jsonify) => TTo; + /** + * An object defining the dependencies of the derive function. The keys of the object are the names of the dependencies + * and the values are the types of the dependencies. + * + * for example: + * ``` + * { + * myService: MyService, + * myOtherService: MyOtherService, + * } + * ``` + */ + [depShapeMarker]?: TDeps; + /** + * The number of milliseconds to wait before cleaning up the state after the last subscriber has unsubscribed. + * Defaults to 1000ms. + */ + cleanupDelayMs?: number; +}; + +/** + * DeriveDefinitions describe state derived from another observable, the value type of which is given by `TFrom`. + * + * The StateDefinition is used to describe the domain of the state, and the DeriveDefinition + * sub-divides that domain into specific keys. These keys are used to cache data in memory and enables derived state to + * be calculated once regardless of multiple execution contexts. + */ + +export class DeriveDefinition { + /** + * Creates a new instance of a DeriveDefinition. Derived state is always stored in memory, so the storage location + * defined in @link{StateDefinition} is ignored. + * + * @param stateDefinition The state definition for which this key belongs to. + * @param uniqueDerivationName The name of the key, this should be unique per domain. + * @param options A set of options to customize the behavior of {@link DeriveDefinition}. + * @param options.derive A function to use to convert values from TFrom to TTo. This is called on each emit of the parent state observable + * and the resulting value will be emitted from the derived state observable. + * @param options.cleanupDelayMs The number of milliseconds to wait before cleaning up the state after the last subscriber has unsubscribed. + * Defaults to 1000ms. + * @param options.dependencyShape An object defining the dependencies of the derive function. The keys of the object are the names of the dependencies + * and the values are the types of the dependencies. + * for example: + * ``` + * { + * myService: MyService, + * myOtherService: MyOtherService, + * } + * ``` + * + * @param options.deserializer A function to use to safely convert your type from json to your expected type. + * Your data may be serialized/deserialized at any time and this needs callback needs to be able to faithfully re-initialize + * from the JSON object representation of your type. + */ + constructor( + readonly stateDefinition: StateDefinition, + readonly uniqueDerivationName: string, + readonly options: DeriveDefinitionOptions, + ) {} + + /** + * Factory that produces a {@link DeriveDefinition} from a {@link KeyDefinition} and a set of options. The returned + * definition will have the same key as the given key definition, but will not collide with it in storage, even if + * they both reside in memory. + * @param keyDefinition + * @param options + * @returns + */ + static from( + keyDefinition: KeyDefinition, + options: DeriveDefinitionOptions, + ) { + return new DeriveDefinition(keyDefinition.stateDefinition, keyDefinition.key, options); + } + + get derive() { + return this.options.derive; + } + + deserialize(serialized: Jsonify): TTo { + return this.options.deserializer(serialized); + } + + get cleanupDelayMs() { + return this.options.cleanupDelayMs < 0 ? 0 : this.options.cleanupDelayMs ?? 1000; + } + + buildCacheKey(): string { + return `derived_${this.stateDefinition.name}_${this.uniqueDerivationName}`; + } + + /** + * Creates a {@link StorageKey} that points to the data for the given derived definition. + * @returns A key that is ready to be used in a storage service to get data. + */ + get storageKey(): StorageKey { + return `derived_${this.stateDefinition.name}_${this.uniqueDerivationName}` as StorageKey; + } +} diff --git a/libs/common/src/platform/state/derived-state.provider.ts b/libs/common/src/platform/state/derived-state.provider.ts new file mode 100644 index 0000000000..55db8bf788 --- /dev/null +++ b/libs/common/src/platform/state/derived-state.provider.ts @@ -0,0 +1,25 @@ +import { Observable } from "rxjs"; + +import { ShapeToInstances, DerivedStateDependencies } from "../../types/state"; + +import { DeriveDefinition } from "./derive-definition"; +import { DerivedState } from "./derived-state"; + +/** + * State derived from an observable and a derive function + */ +export abstract class DerivedStateProvider { + /** + * Creates a derived state observable from a parent state observable, a deriveDefinition, and the dependencies + * required by the deriveDefinition + * @param parentState$ The parent state observable + * @param deriveDefinition The deriveDefinition that defines conversion from the parent state to the derived state as + * well as some memory persistent information. + * @param dependencies The dependencies of the derive function + */ + get: ( + parentState$: Observable, + deriveDefinition: DeriveDefinition, + dependencies: ShapeToInstances, + ) => DerivedState; +} diff --git a/libs/common/src/platform/state/derived-state.ts b/libs/common/src/platform/state/derived-state.ts new file mode 100644 index 0000000000..b466c3024f --- /dev/null +++ b/libs/common/src/platform/state/derived-state.ts @@ -0,0 +1,23 @@ +import { Observable } from "rxjs"; + +export type StateConverter, TTo> = (...args: TFrom) => TTo; + +/** + * State derived from an observable and a converter function + * + * Derived state is cached and persisted to memory for sychronization across execution contexts. + * For clients with multiple execution contexts, the derived state will be executed only once in the background process. + */ +export interface DerivedState { + /** + * The derived state observable + */ + state$: Observable; + /** + * Forces the derived state to a given value. + * + * Useful for setting an in-memory value as a side effect of some event, such as emptying state as a result of a lock. + * @param value The value to force the derived state to + */ + forceValue(value: T): Promise; +} diff --git a/libs/common/src/platform/state/derived-user-state.ts b/libs/common/src/platform/state/derived-user-state.ts deleted file mode 100644 index 89e0b6ec76..0000000000 --- a/libs/common/src/platform/state/derived-user-state.ts +++ /dev/null @@ -1,5 +0,0 @@ -import { Observable } from "rxjs"; - -export interface DerivedUserState { - state$: Observable; -} diff --git a/libs/common/src/platform/state/implementations/default-active-user-state.ts b/libs/common/src/platform/state/implementations/default-active-user-state.ts index 02cd53cfb8..51688be338 100644 --- a/libs/common/src/platform/state/implementations/default-active-user-state.ts +++ b/libs/common/src/platform/state/implementations/default-active-user-state.ts @@ -18,12 +18,10 @@ import { AbstractStorageService, ObservableStorageService, } from "../../abstractions/storage.service"; -import { DerivedUserState } from "../derived-user-state"; import { KeyDefinition, userKeyBuilder } from "../key-definition"; import { StateUpdateOptions, populateOptionsWithDefault } from "../state-update-options"; -import { Converter, ActiveUserState, activeMarker } from "../user-state"; +import { ActiveUserState, activeMarker } from "../user-state"; -import { DefaultDerivedUserState } from "./default-derived-state"; import { getStoredValue } from "./util"; const FAKE_DEFAULT = Symbol("fakeDefault"); @@ -91,10 +89,6 @@ export class DefaultActiveUserState implements ActiveUserState { return await getStoredValue(key, this.chosenStorageLocation, this.keyDefinition.deserializer); } - createDerived(converter: Converter): DerivedUserState { - return new DefaultDerivedUserState(converter, this.encryptService, this); - } - private async internalUpdate( configureState: (state: T, dependency: TCombine) => T, options: StateUpdateOptions, diff --git a/libs/common/src/platform/state/implementations/default-derived-state.provider.ts b/libs/common/src/platform/state/implementations/default-derived-state.provider.ts new file mode 100644 index 0000000000..824a81d2cf --- /dev/null +++ b/libs/common/src/platform/state/implementations/default-derived-state.provider.ts @@ -0,0 +1,49 @@ +import { Observable } from "rxjs"; + +import { DerivedStateDependencies, ShapeToInstances } from "../../../types/state"; +import { + AbstractStorageService, + ObservableStorageService, +} from "../../abstractions/storage.service"; +import { DeriveDefinition } from "../derive-definition"; +import { DerivedState } from "../derived-state"; +import { DerivedStateProvider } from "../derived-state.provider"; + +import { DefaultDerivedState } from "./default-derived-state"; + +export class DefaultDerivedStateProvider implements DerivedStateProvider { + private cache: Record> = {}; + + constructor(protected memoryStorage: AbstractStorageService & ObservableStorageService) {} + + get( + parentState$: Observable, + deriveDefinition: DeriveDefinition, + dependencies: ShapeToInstances, + ): DerivedState { + const cacheKey = deriveDefinition.buildCacheKey(); + const existingDerivedState = this.cache[cacheKey]; + if (existingDerivedState != null) { + // I have to cast out of the unknown generic but this should be safe if rules + // around domain token are made + return existingDerivedState as DefaultDerivedState; + } + + const newDerivedState = this.buildDerivedState(parentState$, deriveDefinition, dependencies); + this.cache[cacheKey] = newDerivedState; + return newDerivedState; + } + + protected buildDerivedState( + parentState$: Observable, + deriveDefinition: DeriveDefinition, + dependencies: ShapeToInstances, + ): DerivedState { + return new DefaultDerivedState( + parentState$, + deriveDefinition, + this.memoryStorage, + dependencies, + ); + } +} diff --git a/libs/common/src/platform/state/implementations/default-derived-state.spec.ts b/libs/common/src/platform/state/implementations/default-derived-state.spec.ts new file mode 100644 index 0000000000..b4d6ba03e1 --- /dev/null +++ b/libs/common/src/platform/state/implementations/default-derived-state.spec.ts @@ -0,0 +1,222 @@ +/** + * need to update test environment so trackEmissions works appropriately + * @jest-environment ../shared/test.environment.ts + */ +import { Subject, firstValueFrom } from "rxjs"; + +import { awaitAsync, trackEmissions } from "../../../../spec"; +import { FakeStorageService } from "../../../../spec/fake-storage.service"; +import { DeriveDefinition } from "../derive-definition"; +import { StateDefinition } from "../state-definition"; + +import { DefaultDerivedState } from "./default-derived-state"; + +let callCount = 0; +const cleanupDelayMs = 10; +const stateDefinition = new StateDefinition("test", "memory"); +const deriveDefinition = new DeriveDefinition( + stateDefinition, + "test", + { + derive: (dateString: string) => { + callCount++; + return new Date(dateString); + }, + deserializer: (dateString: string) => new Date(dateString), + cleanupDelayMs, + }, +); + +describe("DefaultDerivedState", () => { + let parentState$: Subject; + let memoryStorage: FakeStorageService; + let sut: DefaultDerivedState; + const deps = { + date: new Date(), + }; + + beforeEach(() => { + callCount = 0; + parentState$ = new Subject(); + memoryStorage = new FakeStorageService(); + sut = new DefaultDerivedState(parentState$, deriveDefinition, memoryStorage, deps); + }); + + afterEach(() => { + parentState$.complete(); + jest.resetAllMocks(); + }); + + it("should derive the state", async () => { + const dateString = "2020-01-01"; + const emissions = trackEmissions(sut.state$); + + parentState$.next(dateString); + await awaitAsync(); + + expect(emissions).toEqual([new Date(dateString)]); + }); + + it("should derive the state once", async () => { + const dateString = "2020-01-01"; + trackEmissions(sut.state$); + + parentState$.next(dateString); + + expect(callCount).toBe(1); + }); + + it("should store the derived state in memory", async () => { + const dateString = "2020-01-01"; + trackEmissions(sut.state$); + parentState$.next(dateString); + await awaitAsync(); + + expect(memoryStorage.internalStore[deriveDefinition.buildCacheKey()]).toEqual( + new Date(dateString), + ); + const calls = memoryStorage.mock.save.mock.calls; + expect(calls.length).toBe(1); + expect(calls[0][0]).toBe(deriveDefinition.buildCacheKey()); + expect(calls[0][1]).toEqual(new Date(dateString)); + }); + + describe("forceValue", () => { + const initialParentValue = "2020-01-01"; + const forced = new Date("2020-02-02"); + let emissions: Date[]; + + describe("without observers", () => { + beforeEach(async () => { + parentState$.next(initialParentValue); + await awaitAsync(); + }); + + it("should store the forced value", async () => { + await sut.forceValue(forced); + expect(memoryStorage.internalStore[deriveDefinition.buildCacheKey()]).toEqual(forced); + }); + }); + + describe("with observers", () => { + beforeEach(async () => { + emissions = trackEmissions(sut.state$); + parentState$.next(initialParentValue); + await awaitAsync(); + }); + + it("should store the forced value", async () => { + await sut.forceValue(forced); + expect(memoryStorage.internalStore[deriveDefinition.buildCacheKey()]).toEqual(forced); + }); + + it("should force the value", async () => { + await sut.forceValue(forced); + expect(emissions).toEqual([new Date(initialParentValue), forced]); + }); + + it("should only force the value once", async () => { + await sut.forceValue(forced); + + parentState$.next(initialParentValue); + await awaitAsync(); + + expect(emissions).toEqual([ + new Date(initialParentValue), + forced, + new Date(initialParentValue), + ]); + }); + }); + }); + + describe("cleanup", () => { + const newDate = "2020-02-02"; + + it("should cleanup after last subscriber", async () => { + const subscription = sut.state$.subscribe(); + await awaitAsync(); + + subscription.unsubscribe(); + // Wait for cleanup + await awaitAsync(cleanupDelayMs * 2); + + expect(parentState$.observed).toBe(false); + }); + + it("should not cleanup if there are still subscribers", async () => { + const subscription1 = sut.state$.subscribe(); + const sub2Emissions: Date[] = []; + const subscription2 = sut.state$.subscribe((v) => sub2Emissions.push(v)); + await awaitAsync(); + + subscription1.unsubscribe(); + + // Wait for cleanup + await awaitAsync(cleanupDelayMs * 2); + + // Still be listening to parent updates + parentState$.next(newDate); + await awaitAsync(); + expect(sub2Emissions).toEqual([new Date(newDate)]); + + subscription2.unsubscribe(); + // Wait for cleanup + await awaitAsync(cleanupDelayMs * 2); + + expect(parentState$.observed).toBe(false); + }); + + it("can re-initialize after cleanup", async () => { + const subscription = sut.state$.subscribe(); + await awaitAsync(); + + subscription.unsubscribe(); + // Wait for cleanup + await awaitAsync(cleanupDelayMs * 2); + + const emissions = trackEmissions(sut.state$); + await awaitAsync(); + + parentState$.next(newDate); + await awaitAsync(); + + expect(emissions).toEqual([new Date(newDate)]); + }); + + it("should not cleanup if a subscriber joins during the cleanup delay", async () => { + const subscription = sut.state$.subscribe(); + await awaitAsync(); + + await parentState$.next(newDate); + await awaitAsync(); + + subscription.unsubscribe(); + // Do not wait long enough for cleanup + await awaitAsync(cleanupDelayMs / 2); + + expect(parentState$.observed).toBe(true); // still listening to parent + + const emissions = trackEmissions(sut.state$); + expect(emissions).toEqual([new Date(newDate)]); // we didn't lose our buffered value + }); + + it("state$ observables are durable to cleanup", async () => { + const observable = sut.state$; + let subscription = observable.subscribe(); + + await parentState$.next(newDate); + await awaitAsync(); + + subscription.unsubscribe(); + // Wait for cleanup + await awaitAsync(cleanupDelayMs * 2); + + subscription = observable.subscribe(); + await parentState$.next(newDate); + await awaitAsync(); + + expect(await firstValueFrom(observable)).toEqual(new Date(newDate)); + }); + }); +}); diff --git a/libs/common/src/platform/state/implementations/default-derived-state.ts b/libs/common/src/platform/state/implementations/default-derived-state.ts index 2c5f75f2fc..bc9101fb3d 100644 --- a/libs/common/src/platform/state/implementations/default-derived-state.ts +++ b/libs/common/src/platform/state/implementations/default-derived-state.ts @@ -1,23 +1,57 @@ -import { Observable, switchMap } from "rxjs"; +import { Observable, ReplaySubject, Subject, concatMap, merge, share, timer } from "rxjs"; -import { EncryptService } from "../../abstractions/encrypt.service"; -import { DerivedUserState } from "../derived-user-state"; -import { Converter, DeriveContext, UserState } from "../user-state"; +import { ShapeToInstances, DerivedStateDependencies } from "../../../types/state"; +import { + AbstractStorageService, + ObservableStorageService, +} from "../../abstractions/storage.service"; +import { DeriveDefinition } from "../derive-definition"; +import { DerivedState } from "../derived-state"; + +/** + * Default derived state + */ +export class DefaultDerivedState + implements DerivedState +{ + private readonly storageKey: string; + private forcedValueSubject = new Subject(); -export class DefaultDerivedUserState implements DerivedUserState { state$: Observable; constructor( - private converter: Converter, - private encryptService: EncryptService, - private userState: UserState, + private parentState$: Observable, + protected deriveDefinition: DeriveDefinition, + private memoryStorage: AbstractStorageService & ObservableStorageService, + private dependencies: ShapeToInstances, ) { - this.state$ = userState.state$.pipe( - switchMap(async (from) => { - // TODO: How do I get the key? - const convertedData = await this.converter(from, new DeriveContext(null, encryptService)); - return convertedData; + this.storageKey = deriveDefinition.storageKey; + + const derivedState$ = this.parentState$.pipe( + concatMap(async (state) => { + let derivedStateOrPromise = this.deriveDefinition.derive(state, this.dependencies); + if (derivedStateOrPromise instanceof Promise) { + derivedStateOrPromise = await derivedStateOrPromise; + } + const derivedState = derivedStateOrPromise; + await this.memoryStorage.save(this.storageKey, derivedState); + return derivedState; + }), + ); + + this.state$ = merge(this.forcedValueSubject, derivedState$).pipe( + share({ + connector: () => { + return new ReplaySubject(1); + }, + resetOnRefCountZero: () => timer(this.deriveDefinition.cleanupDelayMs), }), ); } + + async forceValue(value: TTo) { + await this.memoryStorage.save(this.storageKey, value); + this.forcedValueSubject.next(value); + return value; + } } diff --git a/libs/common/src/platform/state/implementations/default-single-user-state.ts b/libs/common/src/platform/state/implementations/default-single-user-state.ts index 4c7c70d426..84940493e7 100644 --- a/libs/common/src/platform/state/implementations/default-single-user-state.ts +++ b/libs/common/src/platform/state/implementations/default-single-user-state.ts @@ -14,12 +14,10 @@ import { AbstractStorageService, ObservableStorageService, } from "../../abstractions/storage.service"; -import { DerivedUserState } from "../derived-user-state"; import { KeyDefinition, userKeyBuilder } from "../key-definition"; import { StateUpdateOptions, populateOptionsWithDefault } from "../state-update-options"; -import { Converter, SingleUserState } from "../user-state"; +import { SingleUserState } from "../user-state"; -import { DefaultDerivedUserState } from "./default-derived-state"; import { getStoredValue } from "./util"; const FAKE_DEFAULT = Symbol("fakeDefault"); @@ -68,10 +66,6 @@ export class DefaultSingleUserState implements SingleUserState { } } - createDerived(converter: Converter): DerivedUserState { - return new DefaultDerivedUserState(converter, this.encryptService, this); - } - private async internalUpdate( configureState: (state: T, dependency: TCombine) => T, options: StateUpdateOptions, diff --git a/libs/common/src/platform/state/implementations/default-state.provider.spec.ts b/libs/common/src/platform/state/implementations/default-state.provider.spec.ts index 16af060dc1..4c86a1b8fd 100644 --- a/libs/common/src/platform/state/implementations/default-state.provider.spec.ts +++ b/libs/common/src/platform/state/implementations/default-state.provider.spec.ts @@ -1,9 +1,13 @@ +import { of } from "rxjs"; + import { FakeActiveUserStateProvider, + FakeDerivedStateProvider, FakeGlobalStateProvider, FakeSingleUserStateProvider, } from "../../../../spec/fake-state-provider"; import { UserId } from "../../../types/guid"; +import { DeriveDefinition } from "../derive-definition"; import { KeyDefinition } from "../key-definition"; import { StateDefinition } from "../state-definition"; @@ -14,15 +18,18 @@ describe("DefaultStateProvider", () => { let activeUserStateProvider: FakeActiveUserStateProvider; let singleUserStateProvider: FakeSingleUserStateProvider; let globalStateProvider: FakeGlobalStateProvider; + let derivedStateProvider: FakeDerivedStateProvider; beforeEach(() => { activeUserStateProvider = new FakeActiveUserStateProvider(); singleUserStateProvider = new FakeSingleUserStateProvider(); globalStateProvider = new FakeGlobalStateProvider(); + derivedStateProvider = new FakeDerivedStateProvider(); sut = new DefaultStateProvider( activeUserStateProvider, singleUserStateProvider, globalStateProvider, + derivedStateProvider, ); }); @@ -53,4 +60,15 @@ describe("DefaultStateProvider", () => { const actual = sut.getGlobal(keyDefinition); expect(actual).toBe(existing); }); + + it("should bind the derivedStateProvider", () => { + const derivedDefinition = new DeriveDefinition(new StateDefinition("test", "disk"), "test", { + derive: () => null, + deserializer: () => null, + }); + const parentState$ = of(null); + const existing = derivedStateProvider.get(parentState$, derivedDefinition, {}); + const actual = sut.getDerived(parentState$, derivedDefinition, {}); + expect(actual).toBe(existing); + }); }); diff --git a/libs/common/src/platform/state/implementations/default-state.provider.ts b/libs/common/src/platform/state/implementations/default-state.provider.ts index 5641f80291..7962739ecf 100644 --- a/libs/common/src/platform/state/implementations/default-state.provider.ts +++ b/libs/common/src/platform/state/implementations/default-state.provider.ts @@ -1,3 +1,9 @@ +import { Observable } from "rxjs"; + +import { ShapeToInstances, DerivedStateDependencies } from "../../../types/state"; +import { DeriveDefinition } from "../derive-definition"; +import { DerivedState } from "../derived-state"; +import { DerivedStateProvider } from "../derived-state.provider"; import { GlobalStateProvider } from "../global-state.provider"; import { StateProvider } from "../state.provider"; import { ActiveUserStateProvider, SingleUserStateProvider } from "../user-state.provider"; @@ -7,6 +13,7 @@ export class DefaultStateProvider implements StateProvider { private readonly activeUserStateProvider: ActiveUserStateProvider, private readonly singleUserStateProvider: SingleUserStateProvider, private readonly globalStateProvider: GlobalStateProvider, + private readonly derivedStateProvider: DerivedStateProvider, ) {} getActive: InstanceType["get"] = @@ -16,4 +23,9 @@ export class DefaultStateProvider implements StateProvider { getGlobal: InstanceType["get"] = this.globalStateProvider.get.bind( this.globalStateProvider, ); + getDerived: ( + parentState$: Observable, + deriveDefinition: DeriveDefinition, + dependencies: ShapeToInstances, + ) => DerivedState = this.derivedStateProvider.get.bind(this.derivedStateProvider); } diff --git a/libs/common/src/platform/state/index.ts b/libs/common/src/platform/state/index.ts index 4c7347cb1c..377f91395b 100644 --- a/libs/common/src/platform/state/index.ts +++ b/libs/common/src/platform/state/index.ts @@ -1,4 +1,6 @@ -export { DerivedUserState } from "./derived-user-state"; +export { DeriveDefinition } from "./derive-definition"; +export { DerivedStateProvider } from "./derived-state.provider"; +export { DerivedState } from "./derived-state"; export { GlobalState } from "./global-state"; export { StateProvider } from "./state.provider"; export { GlobalStateProvider } from "./global-state.provider"; diff --git a/libs/common/src/platform/state/key-definition.ts b/libs/common/src/platform/state/key-definition.ts index 9989bf37a2..47b83a4e88 100644 --- a/libs/common/src/platform/state/key-definition.ts +++ b/libs/common/src/platform/state/key-definition.ts @@ -1,6 +1,7 @@ -import { Jsonify, Opaque } from "type-fest"; +import { Jsonify } from "type-fest"; import { UserId } from "../../types/guid"; +import { StorageKey } from "../../types/state"; import { Utils } from "../misc/utils"; import { StateDefinition } from "./state-definition"; @@ -159,8 +160,6 @@ export class KeyDefinition { } } -export type StorageKey = Opaque; - /** * Creates a {@link StorageKey} that points to the data at the given key definition for the specified user. * @param userId The userId of the user you want the key to be for. diff --git a/libs/common/src/platform/state/state.provider.ts b/libs/common/src/platform/state/state.provider.ts index 06bc3a9d90..b234c9cde2 100644 --- a/libs/common/src/platform/state/state.provider.ts +++ b/libs/common/src/platform/state/state.provider.ts @@ -1,5 +1,10 @@ -import { UserId } from "../../types/guid"; +import { Observable } from "rxjs"; +import { UserId } from "../../types/guid"; +import { ShapeToInstances, DerivedStateDependencies } from "../../types/state"; + +import { DeriveDefinition } from "./derive-definition"; +import { DerivedState } from "./derived-state"; import { GlobalState } from "./global-state"; // eslint-disable-next-line @typescript-eslint/no-unused-vars -- used in docs import { GlobalStateProvider } from "./global-state.provider"; @@ -18,4 +23,9 @@ export abstract class StateProvider { getUser: (userId: UserId, keyDefinition: KeyDefinition) => SingleUserState; /** @see{@link GlobalStateProvider.get} */ getGlobal: (keyDefinition: KeyDefinition) => GlobalState; + getDerived: ( + parentState$: Observable, + deriveDefinition: DeriveDefinition, + dependencies: ShapeToInstances, + ) => DerivedState; } diff --git a/libs/common/src/platform/state/user-state.ts b/libs/common/src/platform/state/user-state.ts index d6e5ab3109..7b5ab8a2fd 100644 --- a/libs/common/src/platform/state/user-state.ts +++ b/libs/common/src/platform/state/user-state.ts @@ -1,22 +1,9 @@ import { Observable } from "rxjs"; import { UserId } from "../../types/guid"; -import { EncryptService } from "../abstractions/encrypt.service"; -import { UserKey } from "../models/domain/symmetric-crypto-key"; import { StateUpdateOptions } from "./state-update-options"; -import { DerivedUserState } from "."; - -export class DeriveContext { - constructor( - readonly activeUserKey: UserKey, - readonly encryptService: EncryptService, - ) {} -} - -export type Converter = (data: TFrom, context: DeriveContext) => Promise; - /** * A helper object for interacting with state that is scoped to a specific user. */ @@ -37,13 +24,6 @@ export interface UserState { configureState: (state: T, dependencies: TCombine) => T, options?: StateUpdateOptions, ) => Promise; - - /** - * Creates a derives state from the current state. Derived states are always tied to the active user. - * @param converter - * @returns - */ - createDerived: (converter: Converter) => DerivedUserState; } export const activeMarker: unique symbol = Symbol("active"); diff --git a/libs/common/src/types/state.d.ts b/libs/common/src/types/state.d.ts new file mode 100644 index 0000000000..fea0c7fee0 --- /dev/null +++ b/libs/common/src/types/state.d.ts @@ -0,0 +1,17 @@ +import { Opaque } from "type-fest"; + +type StorageKey = Opaque; + +/** + * A helper type defining Constructor types for javascript and `typeof T` types for Typescript + */ +type Type = abstract new (...args: unknown[]) => T; + +type DerivedStateDependencies = Record>; + +/** + * Converts an object of types to an object of instances + */ +type ShapeToInstances = { + [P in keyof T]: T[P] extends Type ? R : never; +};