diff --git a/apps/browser/src/content/webauthn/messaging/message.ts b/apps/browser/src/content/webauthn/messaging/message.ts index ffde834678..d77aea0b5b 100644 --- a/apps/browser/src/content/webauthn/messaging/message.ts +++ b/apps/browser/src/content/webauthn/messaging/message.ts @@ -40,10 +40,12 @@ export type CredentialGetResponse = { export type AbortRequest = { type: MessageType.AbortRequest; + abortedRequestId: string; }; export type AbortResponse = { type: MessageType.AbortResponse; + abortedRequestId: string; }; export type Message = diff --git a/apps/browser/src/content/webauthn/messaging/messenger.spec.ts b/apps/browser/src/content/webauthn/messaging/messenger.spec.ts index e0482cad68..855860f5e4 100644 --- a/apps/browser/src/content/webauthn/messaging/messenger.spec.ts +++ b/apps/browser/src/content/webauthn/messaging/messenger.spec.ts @@ -18,8 +18,8 @@ describe("Messenger", () => { handlerA = new TestMessageHandler(); handlerB = new TestMessageHandler(); - messengerA.addHandler(handlerA.handler); - messengerB.addHandler(handlerB.handler); + messengerA.handler = handlerA.handler; + messengerB.handler = handlerB.handler; }); it("should deliver message to B when sending request from A", () => { @@ -43,6 +43,25 @@ describe("Messenger", () => { expect(returned).toMatchObject(response); }); + + it("should deliver abort signal to B when requesting abort", () => { + const abortController = new AbortController(); + messengerA.request(createRequest(), abortController); + abortController.abort(); + + const received = handlerB.recieve(); + + expect(received[0].abortController.signal.aborted).toBe(true); + }); + + it.skip("should abort request and throw error when abort is requested from A", () => { + const abortController = new AbortController(); + const requestPromise = messengerA.request(createRequest(), abortController); + + abortController.abort(); + + expect(requestPromise).toThrow(); + }); }); type TestMessage = Message & { testId: string }; @@ -78,16 +97,23 @@ class TestChannelPair { } class TestMessageHandler { - readonly handler: (message: TestMessage) => Promise; + readonly handler: ( + message: TestMessage, + abortController?: AbortController + ) => Promise; - private recievedMessages: { message: TestMessage; respond: (response: TestMessage) => void }[] = - []; + private recievedMessages: { + message: TestMessage; + respond: (response: TestMessage) => void; + abortController?: AbortController; + }[] = []; constructor() { - this.handler = (message) => + this.handler = (message, abortController) => new Promise((resolve, reject) => { this.recievedMessages.push({ message, + abortController, respond: (response) => resolve(response), }); }); diff --git a/apps/browser/src/content/webauthn/messaging/messenger.ts b/apps/browser/src/content/webauthn/messaging/messenger.ts index 29faf61761..68af7b2995 100644 --- a/apps/browser/src/content/webauthn/messaging/messenger.ts +++ b/apps/browser/src/content/webauthn/messaging/messenger.ts @@ -1,6 +1,6 @@ import { concatMap, filter, firstValueFrom, Observable } from "rxjs"; -import { Message } from "./message"; +import { Message, MessageType } from "./message"; type PostMessageFunction = (message: MessageWithMetadata) => void; @@ -11,6 +11,10 @@ export type Channel = { export type Metadata = { requestId: string }; export type MessageWithMetadata = Message & { metadata: Metadata }; +type Handler = ( + message: Message, + abortController?: AbortController +) => Promise; // TODO: This class probably duplicates functionality but I'm not especially familiar with // the inner workings of the browser extension yet. @@ -32,9 +36,42 @@ export class Messenger { }); } - constructor(private channel: Channel) {} + handler?: Handler; + abortControllers = new Map(); - request(request: Message): Promise { + constructor(private channel: Channel) { + this.channel.messages$ + .pipe( + concatMap(async (message) => { + if (this.handler === undefined) { + return; + } + + const abortController = new AbortController(); + this.abortControllers.set(message.metadata.requestId, abortController); + const handlerResponse = await this.handler(message, abortController); + this.abortControllers.delete(message.metadata.requestId); + + if (handlerResponse === undefined) { + return; + } + + const metadata: Metadata = { requestId: message.metadata.requestId }; + this.channel.postMessage({ ...handlerResponse, metadata }); + }) + ) + .subscribe(); + + this.channel.messages$.subscribe((message) => { + if (message.type !== MessageType.AbortRequest) { + return; + } + + this.abortControllers.get(message.abortedRequestId)?.abort(); + }); + } + + request(request: Message, abortController?: AbortController): Promise { const requestId = Date.now().toString(); const metadata: Metadata = { requestId }; @@ -46,25 +83,18 @@ export class Messenger { ) ); + const abortListener = () => + this.channel.postMessage({ + metadata: { requestId: `${requestId}-abort` }, + type: MessageType.AbortRequest, + abortedRequestId: requestId, + }); + abortController?.signal.addEventListener("abort", abortListener); + this.channel.postMessage({ ...request, metadata }); - return promise; - } - - addHandler(handler: (message: Message) => Promise) { - this.channel.messages$ - .pipe( - concatMap(async (message) => { - const handlerResponse = await handler(message); - - if (handlerResponse === undefined) { - return; - } - - const metadata: Metadata = { requestId: message.metadata.requestId }; - this.channel.postMessage({ ...handlerResponse, metadata }); - }) - ) - .subscribe(); + return promise.finally(() => + abortController?.signal.removeEventListener("abort", abortListener) + ); } }