diff --git a/cmd/generate/main-generate.go b/cmd/generate/main-generate.go index dd37f9efa..449fcef03 100644 --- a/cmd/generate/main-generate.go +++ b/cmd/generate/main-generate.go @@ -8,11 +8,12 @@ import ( "os" "reflect" "sort" + "strings" "github.com/wavetermdev/thenextwave/pkg/service" "github.com/wavetermdev/thenextwave/pkg/tsgen" "github.com/wavetermdev/thenextwave/pkg/util/utilfn" - "github.com/wavetermdev/thenextwave/pkg/wshrpc/wshserver" + "github.com/wavetermdev/thenextwave/pkg/wshrpc" ) func generateTypesFile(tsTypesMap map[reflect.Type]string) error { @@ -43,6 +44,10 @@ func generateTypesFile(tsTypesMap map[reflect.Type]string) error { return iname < jname }) for _, key := range keys { + // don't output generic types + if strings.Index(key.Name(), "[") != -1 { + continue + } tsCode := tsTypesMap[key] istr := utilfn.IndentString(" ", tsCode) fmt.Fprint(fd, istr) @@ -79,16 +84,17 @@ func generateWshServerFile(tsTypeMap map[reflect.Type]string) error { return err } defer fd.Close() + declMap := wshrpc.GenerateWshCommandDeclMap() fmt.Fprintf(os.Stderr, "generating wshserver file to %s\n", fd.Name()) fmt.Fprintf(fd, "// Copyright 2024, Command Line Inc.\n") fmt.Fprintf(fd, "// SPDX-License-Identifier: Apache-2.0\n\n") fmt.Fprintf(fd, "// generated by cmd/generate/main-generate.go\n\n") fmt.Fprintf(fd, "import * as WOS from \"./wos\";\n\n") - orderedKeys := utilfn.GetOrderedMapKeys(wshserver.WshServerCommandToDeclMap) + orderedKeys := utilfn.GetOrderedMapKeys(declMap) fmt.Fprintf(fd, "// WshServerCommandToDeclMap\n") fmt.Fprintf(fd, "class WshServerType {\n") for _, methodDecl := range orderedKeys { - methodDecl := wshserver.WshServerCommandToDeclMap[methodDecl] + methodDecl := declMap[methodDecl] methodStr := tsgen.GenerateWshServerMethod(methodDecl, tsTypeMap) fmt.Fprint(fd, methodStr) fmt.Fprintf(fd, "\n") diff --git a/cmd/generatewshclient/main-generatewshclient.go b/cmd/generatewshclient/main-generatewshclient.go index c63036e8c..5c6ee5d9b 100644 --- a/cmd/generatewshclient/main-generatewshclient.go +++ b/cmd/generatewshclient/main-generatewshclient.go @@ -8,11 +8,10 @@ import ( "os" "github.com/wavetermdev/thenextwave/pkg/util/utilfn" - "github.com/wavetermdev/thenextwave/pkg/wshrpc/wshserver" - "github.com/wavetermdev/thenextwave/pkg/wshutil" + "github.com/wavetermdev/thenextwave/pkg/wshrpc" ) -func genMethod_ResponseStream(fd *os.File, methodDecl *wshserver.WshServerMethodDecl) { +func genMethod_ResponseStream(fd *os.File, methodDecl *wshrpc.WshRpcMethodDecl) { fmt.Fprintf(fd, "// command %q, wshserver.%s\n", methodDecl.Command, methodDecl.MethodName) var dataType string dataVarName := "nil" @@ -29,7 +28,7 @@ func genMethod_ResponseStream(fd *os.File, methodDecl *wshserver.WshServerMethod fmt.Fprintf(fd, "}\n\n") } -func genMethod_Call(fd *os.File, methodDecl *wshserver.WshServerMethodDecl) { +func genMethod_Call(fd *os.File, methodDecl *wshrpc.WshRpcMethodDecl) { fmt.Fprintf(fd, "// command %q, wshserver.%s\n", methodDecl.Command, methodDecl.MethodName) var dataType string dataVarName := "nil" @@ -70,14 +69,14 @@ func main() { fmt.Fprintf(fd, " \"github.com/wavetermdev/thenextwave/pkg/wshutil\"\n") fmt.Fprintf(fd, " \"github.com/wavetermdev/thenextwave/pkg/wshrpc\"\n") fmt.Fprintf(fd, " \"github.com/wavetermdev/thenextwave/pkg/waveobj\"\n") - fmt.Fprintf(fd, " \"github.com/wavetermdev/thenextwave/pkg/waveai\"\n") fmt.Fprintf(fd, ")\n\n") - for _, key := range utilfn.GetOrderedMapKeys(wshserver.WshServerCommandToDeclMap) { - methodDecl := wshserver.WshServerCommandToDeclMap[key] - if methodDecl.CommandType == wshutil.RpcType_ResponseStream { + wshDeclMap := wshrpc.GenerateWshCommandDeclMap() + for _, key := range utilfn.GetOrderedMapKeys(wshDeclMap) { + methodDecl := wshDeclMap[key] + if methodDecl.CommandType == wshrpc.RpcType_ResponseStream { genMethod_ResponseStream(fd, methodDecl) - } else if methodDecl.CommandType == wshutil.RpcType_Call { + } else if methodDecl.CommandType == wshrpc.RpcType_Call { genMethod_Call(fd, methodDecl) } else { panic("unsupported command type " + methodDecl.CommandType) diff --git a/cmd/wsh/cmd/wshcmd-readfile.go b/cmd/wsh/cmd/wshcmd-readfile.go index 6c6129eb7..f04beedc7 100644 --- a/cmd/wsh/cmd/wshcmd-readfile.go +++ b/cmd/wsh/cmd/wshcmd-readfile.go @@ -43,7 +43,7 @@ func runReadFile(cmd *cobra.Command, args []string) { fmt.Fprintf(os.Stderr, "error resolving oref: %v\r\n", err) return } - resp64, err := wshclient.ReadFile(RpcClient, wshrpc.CommandFileData{ZoneId: fullORef.OID, FileName: args[1]}, &wshrpc.WshRpcCommandOpts{Timeout: 5000}) + resp64, err := wshclient.FileReadCommand(RpcClient, wshrpc.CommandFileData{ZoneId: fullORef.OID, FileName: args[1]}, &wshrpc.WshRpcCommandOpts{Timeout: 5000}) if err != nil { fmt.Fprintf(os.Stderr, "error reading file: %v\r\n", err) return diff --git a/emain/emain.ts b/emain/emain.ts index 0568383a1..31d0944db 100644 --- a/emain/emain.ts +++ b/emain/emain.ts @@ -34,6 +34,7 @@ const waveSrvReady: Promise = new Promise((resolve, _) => { }); let globalIsQuitting = false; let globalIsStarting = true; +let globalIsRelaunching = false; const isDev = !electronApp.isPackaged; const isDevVite = isDev && process.env.ELECTRON_RENDERER_URL; @@ -214,7 +215,8 @@ async function handleWSEvent(evtMsg: WSEventType) { return; } const clientData = await services.ClientService.GetClientData(); - const newWin = createBrowserWindow(clientData.oid, windowData); + const settings = await services.FileService.GetSettingsConfig(); + const newWin = createBrowserWindow(clientData.oid, windowData, settings); await newWin.readyPromise; newWin.show(); } else if (evtMsg.eventtype == "electron:closewindow") { @@ -290,7 +292,11 @@ function shFrameNavHandler(event: Electron.Event void; (bwin as any).readyPromise = new Promise((resolve, _) => { @@ -519,7 +531,8 @@ electron.ipcMain.on("getEnv", (event, varName) => { async function createNewWaveWindow() { const clientData = await services.ClientService.GetClientData(); const newWindow = await services.ClientService.MakeWindow(); - const newBrowserWindow = createBrowserWindow(clientData.oid, newWindow); + const settings = await services.FileService.GetSettingsConfig(); + const newBrowserWindow = createBrowserWindow(clientData.oid, newWindow, settings); newBrowserWindow.show(); } @@ -616,6 +629,12 @@ function makeAppMenu() { { role: "forceReload", }, + { + label: "Relaunch All Windows", + click: () => { + relaunchBrowserWindows(); + }, + }, { role: "toggleDevTools", }, @@ -663,6 +682,9 @@ function makeAppMenu() { } electronApp.on("window-all-closed", () => { + if (globalIsRelaunching) { + return; + } if (unamePlatform !== "darwin") { electronApp.quit(); } @@ -857,6 +879,36 @@ async function configureAutoUpdater() { } // ====== AUTO-UPDATER ====== // +async function relaunchBrowserWindows() { + globalIsRelaunching = true; + const windows = electron.BrowserWindow.getAllWindows(); + for (const window of windows) { + window.removeAllListeners(); + window.close(); + } + globalIsRelaunching = false; + + const clientData = await services.ClientService.GetClientData(); + const settings = await services.FileService.GetSettingsConfig(); + const wins: WaveBrowserWindow[] = []; + for (const windowId of clientData.windowids.slice().reverse()) { + const windowData: WaveWindow = (await services.ObjectService.GetObject("window:" + windowId)) as WaveWindow; + if (windowData == null) { + services.WindowService.CloseWindow(windowId).catch((e) => { + /* ignore */ + }); + continue; + } + const win = createBrowserWindow(clientData.oid, windowData, settings); + wins.push(win); + } + for (const win of wins) { + await win.readyPromise; + console.log("show", win.waveWindowId); + win.show(); + } +} + async function appMain() { const startTs = Date.now(); const instanceLock = electronApp.requestSingleInstanceLock(); @@ -877,27 +929,8 @@ async function appMain() { } const ready = await waveSrvReady; console.log("wavesrv ready signal received", ready, Date.now() - startTs, "ms"); - console.log("get client data"); - const clientData = await services.ClientService.GetClientData(); - console.log("client data ready"); await electronApp.whenReady(); - const wins: WaveBrowserWindow[] = []; - for (const windowId of clientData.windowids.slice().reverse()) { - const windowData: WaveWindow = (await services.ObjectService.GetObject("window:" + windowId)) as WaveWindow; - if (windowData == null) { - services.WindowService.CloseWindow(windowId).catch((e) => { - /* ignore */ - }); - continue; - } - const win = createBrowserWindow(clientData.oid, windowData); - wins.push(win); - } - for (const win of wins) { - await win.readyPromise; - console.log("show", win.waveWindowId); - win.show(); - } + relaunchBrowserWindows(); configureAutoUpdater(); globalIsStarting = false; diff --git a/frontend/app/app.less b/frontend/app/app.less index 3f1439bbb..eee11170a 100644 --- a/frontend/app/app.less +++ b/frontend/app/app.less @@ -14,6 +14,8 @@ body { font: var(--base-font); overflow: hidden; -webkit-font-smoothing: auto; + backface-visibility: hidden; + transform: translateZ(0); } *::-webkit-scrollbar { diff --git a/frontend/app/app.tsx b/frontend/app/app.tsx index 55809b92a..6a98c550f 100644 --- a/frontend/app/app.tsx +++ b/frontend/app/app.tsx @@ -16,6 +16,7 @@ import { HTML5Backend } from "react-dnd-html5-backend"; import { CenteredDiv } from "./element/quickelems"; import clsx from "clsx"; +import Color from "color"; import "overlayscrollbars/overlayscrollbars.css"; import "./app.less"; @@ -200,6 +201,31 @@ function switchBlock(tabId: string, offsetX: number, offsetY: number) { } } +function AppSettingsUpdater() { + const settings = jotai.useAtomValue(atoms.settingsConfigAtom); + React.useEffect(() => { + let isTransparent = settings?.window?.transparent ?? true; + let opacity = util.boundNumber(settings?.window?.opacity ?? 0.8, 0, 1); + let baseBgColor = settings?.window?.bgcolor; + console.log("window settings", settings.window); + + if (isTransparent) { + document.body.classList.add("is-transparent"); + const rootStyles = getComputedStyle(document.documentElement); + if (baseBgColor == null) { + baseBgColor = rootStyles.getPropertyValue("--main-bg-color").trim(); + } + const color = new Color(baseBgColor); + const rgbaColor = color.alpha(opacity).string(); + document.body.style.backgroundColor = rgbaColor; + } else { + document.body.classList.remove("is-transparent"); + document.body.style.opacity = null; + } + }, [settings?.window]); + return null; +} + const AppInner = () => { const client = jotai.useAtomValue(atoms.client); const windowData = jotai.useAtomValue(atoms.waveWindow); @@ -251,6 +277,7 @@ const AppInner = () => { const isFullScreen = jotai.useAtomValue(atoms.isFullScreen); return (
+ diff --git a/frontend/app/block/block.less b/frontend/app/block/block.less index a99ec2445..a13a6f0fb 100644 --- a/frontend/app/block/block.less +++ b/frontend/app/block/block.less @@ -59,7 +59,7 @@ padding: 2px; .block-frame-default-inner { - background-color: rgba(0, 0, 0, 0.5); + background-color: var(--block-bg-color); width: 100%; height: 100%; border-radius: 8px; diff --git a/frontend/app/store/wshserver.ts b/frontend/app/store/wshserver.ts index d2c399c29..6f4a6f48b 100644 --- a/frontend/app/store/wshserver.ts +++ b/frontend/app/store/wshserver.ts @@ -7,14 +7,19 @@ import * as WOS from "./wos"; // WshServerCommandToDeclMap class WshServerType { - // command "controller:input" [call] - BlockInputCommand(data: CommandBlockInputData, opts?: WshRpcCommandOpts): Promise { - return WOS.wshServerRpcHelper_call("controller:input", data, opts); + // command "authenticate" [call] + AuthenticateCommand(data: string, opts?: WshRpcCommandOpts): Promise { + return WOS.wshServerRpcHelper_call("authenticate", data, opts); } - // command "controller:restart" [call] - BlockRestartCommand(data: CommandBlockRestartData, opts?: WshRpcCommandOpts): Promise { - return WOS.wshServerRpcHelper_call("controller:restart", data, opts); + // command "controllerinput" [call] + ControllerInputCommand(data: CommandBlockInputData, opts?: WshRpcCommandOpts): Promise { + return WOS.wshServerRpcHelper_call("controllerinput", data, opts); + } + + // command "controllerrestart" [call] + ControllerRestartCommand(data: CommandBlockRestartData, opts?: WshRpcCommandOpts): Promise { + return WOS.wshServerRpcHelper_call("controllerrestart", data, opts); } // command "createblock" [call] @@ -27,24 +32,49 @@ class WshServerType { return WOS.wshServerRpcHelper_call("deleteblock", data, opts); } - // command "file:append" [call] - AppendFileCommand(data: CommandFileData, opts?: WshRpcCommandOpts): Promise { - return WOS.wshServerRpcHelper_call("file:append", data, opts); + // command "eventpublish" [call] + EventPublishCommand(data: WaveEvent, opts?: WshRpcCommandOpts): Promise { + return WOS.wshServerRpcHelper_call("eventpublish", data, opts); } - // command "file:appendijson" [call] - AppendIJsonCommand(data: CommandAppendIJsonData, opts?: WshRpcCommandOpts): Promise { - return WOS.wshServerRpcHelper_call("file:appendijson", data, opts); + // command "eventrecv" [call] + EventRecvCommand(data: WaveEvent, opts?: WshRpcCommandOpts): Promise { + return WOS.wshServerRpcHelper_call("eventrecv", data, opts); } - // command "file:read" [call] - ReadFile(data: CommandFileData, opts?: WshRpcCommandOpts): Promise { - return WOS.wshServerRpcHelper_call("file:read", data, opts); + // command "eventsub" [call] + EventSubCommand(data: SubscriptionRequest, opts?: WshRpcCommandOpts): Promise { + return WOS.wshServerRpcHelper_call("eventsub", data, opts); } - // command "file:write" [call] - WriteFile(data: CommandFileData, opts?: WshRpcCommandOpts): Promise { - return WOS.wshServerRpcHelper_call("file:write", data, opts); + // command "eventunsub" [call] + EventUnsubCommand(data: SubscriptionRequest, opts?: WshRpcCommandOpts): Promise { + return WOS.wshServerRpcHelper_call("eventunsub", data, opts); + } + + // command "eventunsuball" [call] + EventUnsubAllCommand(opts?: WshRpcCommandOpts): Promise { + return WOS.wshServerRpcHelper_call("eventunsuball", null, opts); + } + + // command "fileappend" [call] + FileAppendCommand(data: CommandFileData, opts?: WshRpcCommandOpts): Promise { + return WOS.wshServerRpcHelper_call("fileappend", data, opts); + } + + // command "fileappendijson" [call] + FileAppendIJsonCommand(data: CommandAppendIJsonData, opts?: WshRpcCommandOpts): Promise { + return WOS.wshServerRpcHelper_call("fileappendijson", data, opts); + } + + // command "fileread" [call] + FileReadCommand(data: CommandFileData, opts?: WshRpcCommandOpts): Promise { + return WOS.wshServerRpcHelper_call("fileread", data, opts); + } + + // command "filewrite" [call] + FileWriteCommand(data: CommandFileData, opts?: WshRpcCommandOpts): Promise { + return WOS.wshServerRpcHelper_call("filewrite", data, opts); } // command "getmeta" [call] @@ -68,18 +98,18 @@ class WshServerType { } // command "setview" [call] - BlockSetViewCommand(data: CommandBlockSetViewData, opts?: WshRpcCommandOpts): Promise { + SetViewCommand(data: CommandBlockSetViewData, opts?: WshRpcCommandOpts): Promise { return WOS.wshServerRpcHelper_call("setview", data, opts); } - // command "stream:waveai" [responsestream] - RespStreamWaveAi(data: OpenAiStreamRequest, opts?: WshRpcCommandOpts): AsyncGenerator { - return WOS.wshServerRpcHelper_responsestream("stream:waveai", data, opts); + // command "streamtest" [responsestream] + StreamTestCommand(opts?: WshRpcCommandOpts): AsyncGenerator { + return WOS.wshServerRpcHelper_responsestream("streamtest", null, opts); } - // command "streamtest" [responsestream] - RespStreamTest(opts?: WshRpcCommandOpts): AsyncGenerator { - return WOS.wshServerRpcHelper_responsestream("streamtest", null, opts); + // command "streamwaveai" [responsestream] + StreamWaveAiCommand(data: OpenAiStreamRequest, opts?: WshRpcCommandOpts): AsyncGenerator { + return WOS.wshServerRpcHelper_responsestream("streamwaveai", data, opts); } } diff --git a/frontend/app/theme.less b/frontend/app/theme.less index 2ea471881..629ee9168 100644 --- a/frontend/app/theme.less +++ b/frontend/app/theme.less @@ -6,7 +6,7 @@ --title-font-size: 18px; --secondary-text-color: rgb(195, 200, 194); --grey-text-color: #666; - --main-bg-color: #454444; + --main-bg-color: rgb(34, 34, 34); --border-color: #333333; --base-font: normal 14px / normal "Inter", sans-serif; --fixed-font: normal 12px / normal "Hack", monospace; @@ -19,7 +19,7 @@ --warning-color: rgb(224, 185, 86); --success-color: rgb(78, 154, 6); --hover-bg-color: rgba(255, 255, 255, 0.1); - --block-bg-color: rgba(255, 255, 255, 0.05); + --block-bg-color: rgba(0, 0, 0, 0.5); /* scrollbar colors */ --scrollbar-background-color: transparent; diff --git a/frontend/app/view/term/term.tsx b/frontend/app/view/term/term.tsx index 0351b74aa..ef11249d0 100644 --- a/frontend/app/view/term/term.tsx +++ b/frontend/app/view/term/term.tsx @@ -214,7 +214,7 @@ const TerminalView = ({ blockId, model }: TerminalViewProps) => { } if (shellProcStatusRef.current != "running" && keyutil.checkKeyPressed(waveEvent, "Enter")) { // restart - WshServer.BlockRestartCommand({ blockid: blockId }); + WshServer.ControllerRestartCommand({ blockid: blockId }); return false; } } @@ -263,7 +263,7 @@ const TerminalView = ({ blockId, model }: TerminalViewProps) => { return false; } const b64data = btoa(asciiVal); - WshServer.BlockInputCommand({ blockid: blockId, inputdata64: b64data }); + WshServer.ControllerInputCommand({ blockid: blockId, inputdata64: b64data }); return true; }; diff --git a/frontend/app/view/term/termwrap.ts b/frontend/app/view/term/termwrap.ts index 060be56b9..eb5ebfa01 100644 --- a/frontend/app/view/term/termwrap.ts +++ b/frontend/app/view/term/termwrap.ts @@ -76,7 +76,7 @@ export class TermWrap { handleTermData(data: string) { const b64data = btoa(data); - WshServer.BlockInputCommand({ blockid: this.blockId, inputdata64: b64data }); + WshServer.ControllerInputCommand({ blockid: this.blockId, inputdata64: b64data }); } addFocusListener(focusFn: () => void) { diff --git a/frontend/app/view/waveai.tsx b/frontend/app/view/waveai.tsx index b6797ef9a..e22ff372b 100644 --- a/frontend/app/view/waveai.tsx +++ b/frontend/app/view/waveai.tsx @@ -137,7 +137,7 @@ export class WaveAiModel implements ViewModel { opts: opts, prompt: prompt, }; - const aiGen = WshServer.RespStreamWaveAi(beMsg); + const aiGen = WshServer.StreamWaveAiCommand(beMsg); let temp = async () => { let fullMsg = ""; for await (const msg of aiGen) { diff --git a/frontend/types/gotypes.d.ts b/frontend/types/gotypes.d.ts index 95b25e619..6d0702d3a 100644 --- a/frontend/types/gotypes.d.ts +++ b/frontend/types/gotypes.d.ts @@ -187,7 +187,7 @@ declare global { // waveobj.ORef type ORef = string; - // waveai.OpenAIOptsType + // wshrpc.OpenAIOptsType type OpenAIOptsType = { model: string; apitoken: string; @@ -197,7 +197,7 @@ declare global { timeout?: number; }; - // waveai.OpenAIPacketType + // wshrpc.OpenAIPacketType type OpenAIPacketType = { type: string; model?: string; @@ -209,21 +209,21 @@ declare global { error?: string; }; - // waveai.OpenAIPromptMessageType + // wshrpc.OpenAIPromptMessageType type OpenAIPromptMessageType = { role: string; content: string; name?: string; }; - // waveai.OpenAIUsageType + // wshrpc.OpenAIUsageType type OpenAIUsageType = { prompt_tokens?: number; completion_tokens?: number; total_tokens?: number; }; - // waveai.OpenAiStreamRequest + // wshrpc.OpenAiStreamRequest type OpenAiStreamRequest = { clientid?: string; opts: OpenAIOptsType; @@ -270,6 +270,7 @@ declare global { blockheader: BlockHeaderOpts; autoupdate: AutoUpdateOpts; termthemes: {[key: string]: TermThemeType}; + window: WindowSettingsType; }; // wstore.StickerClickOptsType @@ -293,6 +294,13 @@ declare global { display: StickerDisplayOptsType; }; + // wshrpc.SubscriptionRequest + type SubscriptionRequest = { + event: string; + scopes?: string[]; + allscopes?: boolean; + }; + // wstore.Tab type Tab = WaveObj & { name: string; @@ -428,6 +436,14 @@ declare global { error: string; }; + // wshrpc.WaveEvent + type WaveEvent = { + event: string; + scopes?: string[]; + sender?: string; + data?: any; + }; + // filestore.WaveFile type WaveFile = { zoneid: string; @@ -497,6 +513,13 @@ declare global { height: number; }; + // wconfig.WindowSettingsType + type WindowSettingsType = { + transparent: boolean; + opacity: number; + bgcolor: string; + }; + // wstore.Workspace type Workspace = WaveObj & { name: string; diff --git a/frontend/util/util.ts b/frontend/util/util.ts index fc669aa6b..01c025478 100644 --- a/frontend/util/util.ts +++ b/frontend/util/util.ts @@ -34,6 +34,10 @@ function base64ToArray(b64: string): Uint8Array { return rtnArr; } +function boundNumber(num: number, min: number, max: number): number { + return Math.min(Math.max(num, min), max); +} + // works for json-like objects (arrays, objects, strings, numbers, booleans) function jsonDeepEqual(v1: any, v2: any): boolean { if (v1 === v2) { @@ -193,6 +197,7 @@ function getCrypto() { export { base64ToArray, base64ToString, + boundNumber, fireAndForget, getCrypto, getPromiseState, diff --git a/frontend/wave.ts b/frontend/wave.ts index 8dd1c7577..3492c3658 100644 --- a/frontend/wave.ts +++ b/frontend/wave.ts @@ -27,6 +27,7 @@ loadFonts(); (window as any).globalWS = globalWS; (window as any).WOS = WOS; (window as any).globalStore = globalStore; +(window as any).globalAtoms = atoms; (window as any).WshServer = WshServer; (window as any).isFullScreen = false; diff --git a/package.json b/package.json index f28f6236c..3b7ea3c3f 100644 --- a/package.json +++ b/package.json @@ -76,11 +76,13 @@ "@table-nav/core": "^0.0.7", "@table-nav/react": "^0.0.7", "@tanstack/react-table": "^8.19.3", + "@types/color": "^3.0.6", "@xterm/addon-fit": "^0.10.0", "@xterm/addon-serialize": "^0.13.0", "@xterm/xterm": "^5.5.0", "base64-js": "^1.5.1", "clsx": "^2.1.1", + "color": "^4.2.3", "dayjs": "^1.11.12", "electron-updater": "6.3.1", "html-to-image": "^1.11.11", diff --git a/pkg/blockcontroller/blockcontroller.go b/pkg/blockcontroller/blockcontroller.go index 40c71390f..c99be69c0 100644 --- a/pkg/blockcontroller/blockcontroller.go +++ b/pkg/blockcontroller/blockcontroller.go @@ -6,7 +6,9 @@ package blockcontroller import ( "bytes" "context" + "crypto/rand" "encoding/base64" + "encoding/hex" "encoding/json" "fmt" "io" @@ -21,12 +23,13 @@ import ( "github.com/wavetermdev/thenextwave/pkg/shellexec" "github.com/wavetermdev/thenextwave/pkg/wavebase" "github.com/wavetermdev/thenextwave/pkg/waveobj" + "github.com/wavetermdev/thenextwave/pkg/wshrpc" "github.com/wavetermdev/thenextwave/pkg/wshutil" "github.com/wavetermdev/thenextwave/pkg/wstore" ) // set by main-server.go (for dependency inversion) -var WshServerFactoryFn func(inputCh chan []byte, outputCh chan []byte, initialCtx wshutil.RpcContext) = nil +var WshServerFactoryFn func(inputCh chan []byte, outputCh chan []byte, initialCtx wshrpc.RpcContext) = nil const ( BlockController_Shell = "shell" @@ -205,6 +208,46 @@ func (bc *BlockController) resetTerminalState() { } } +func getMetaBool(meta map[string]any, key string, def bool) bool { + val, found := meta[key] + if !found { + return def + } + if val == nil { + return def + } + if bval, ok := val.(bool); ok { + return bval + } + return def +} + +func getMetaStr(meta map[string]any, key string, def string) string { + val, found := meta[key] + if !found { + return def + } + if val == nil { + return def + } + if sval, ok := val.(string); ok { + return sval + } + return def +} + +// every byte is 4-bits of randomness +func randomHexString(numHexDigits int) (string, error) { + numBytes := (numHexDigits + 1) / 2 // Calculate the number of bytes needed + bytes := make([]byte, numBytes) + if _, err := rand.Read(bytes); err != nil { + return "", err + } + + hexStr := hex.EncodeToString(bytes) + return hexStr[:numHexDigits], nil // Return the exact number of hex digits +} + func (bc *BlockController) DoRunShellCommand(rc *RunShellOpts, blockMeta map[string]any) error { // create a circular blockfile for the output ctx, cancelFn := context.WithTimeout(context.Background(), 2*time.Second) @@ -232,12 +275,35 @@ func (bc *BlockController) DoRunShellCommand(rc *RunShellOpts, blockMeta map[str if shellProcErr != nil { return shellProcErr } + var remoteDomainSocketName string + remoteName := getMetaStr(blockMeta, "connection", "") + isRemote := remoteName != "" + if isRemote { + randStr, err := randomHexString(16) // 64-bits of randomness + if err != nil { + return fmt.Errorf("error generating random string: %w", err) + } + remoteDomainSocketName = fmt.Sprintf("/tmp/waveterm-%s.sock", randStr) + } var cmdStr string cmdOpts := shellexec.CommandOptsType{ Env: make(map[string]string), } - // temporary for blockid (will switch to a JWT at some point) - cmdOpts.Env["LC_WAVETERM_BLOCKID"] = bc.BlockId + if !getMetaBool(blockMeta, "nowsh", false) { + if isRemote { + jwtStr, err := wshutil.MakeClientJWTToken(wshrpc.RpcContext{TabId: bc.TabId, BlockId: bc.BlockId}, remoteDomainSocketName) + if err != nil { + return fmt.Errorf("error making jwt token: %w", err) + } + cmdOpts.Env["WAVETERM_JWT"] = jwtStr + } else { + jwtStr, err := wshutil.MakeClientJWTToken(wshrpc.RpcContext{TabId: bc.TabId, BlockId: bc.BlockId}, wavebase.GetDomainSocketName()) + if err != nil { + return fmt.Errorf("error making jwt token: %w", err) + } + cmdOpts.Env["WAVETERM_JWT"] = jwtStr + } + } if bc.ControllerType == BlockController_Shell { cmdOpts.Interactive = true cmdOpts.Login = true @@ -284,11 +350,8 @@ func (bc *BlockController) DoRunShellCommand(rc *RunShellOpts, blockMeta map[str } else { return fmt.Errorf("unknown controller type %q", bc.ControllerType) } - // pty buffer equivalent for ssh? i think if i have the ecmd or session i can manage it with output - // pty write needs stdin, so if i provide that, i might be able to write that way - // need a way to handle setsize??? var shellProc *shellexec.ShellProc - if remoteName, ok := blockMeta["connection"].(string); ok && remoteName != "" { + if remoteName != "" { shellProc, err = shellexec.StartRemoteShellProc(rc.TermSize, cmdStr, cmdOpts, remoteName) if err != nil { return err @@ -309,7 +372,7 @@ func (bc *BlockController) DoRunShellCommand(rc *RunShellOpts, blockMeta map[str messageCh := make(chan []byte, 32) ptyBuffer := wshutil.MakePtyBuffer(wshutil.WaveOSCPrefix, bc.ShellProc.Pty, messageCh) outputCh := make(chan []byte, 32) - WshServerFactoryFn(messageCh, outputCh, wshutil.RpcContext{BlockId: bc.BlockId, TabId: bc.TabId}) + WshServerFactoryFn(messageCh, outputCh, wshrpc.RpcContext{BlockId: bc.BlockId, TabId: bc.TabId}) go func() { // handles regular output from the pty (goes to the blockfile and xterm) defer func() { diff --git a/pkg/tsgen/tsgen.go b/pkg/tsgen/tsgen.go index 7381e892d..6f45cc930 100644 --- a/pkg/tsgen/tsgen.go +++ b/pkg/tsgen/tsgen.go @@ -20,7 +20,6 @@ import ( "github.com/wavetermdev/thenextwave/pkg/wconfig" "github.com/wavetermdev/thenextwave/pkg/web/webcmd" "github.com/wavetermdev/thenextwave/pkg/wshrpc" - "github.com/wavetermdev/thenextwave/pkg/wshrpc/wshserver" "github.com/wavetermdev/thenextwave/pkg/wshutil" "github.com/wavetermdev/thenextwave/pkg/wstore" ) @@ -61,10 +60,13 @@ var uiContextRType = reflect.TypeOf((*wstore.UIContext)(nil)).Elem() var waveObjRType = reflect.TypeOf((*waveobj.WaveObj)(nil)).Elem() var updatesRtnRType = reflect.TypeOf(wstore.UpdatesRtnType{}) var orefRType = reflect.TypeOf((*waveobj.ORef)(nil)).Elem() +var wshRpcInterfaceRType = reflect.TypeOf((*wshrpc.WshRpcInterface)(nil)).Elem() -func generateTSMethodTypes(method reflect.Method, tsTypesMap map[reflect.Type]string) error { - for idx := 1; idx < method.Type.NumIn(); idx++ { - // skip receiver +func generateTSMethodTypes(method reflect.Method, tsTypesMap map[reflect.Type]string, skipFirstArg bool) error { + for idx := 0; idx < method.Type.NumIn(); idx++ { + if skipFirstArg && idx == 0 { + continue + } inType := method.Type.In(idx) GenerateTSType(inType, tsTypesMap) } @@ -159,14 +161,13 @@ var tsRenameMap = map[string]string{ func generateTSTypeInternal(rtype reflect.Type, tsTypesMap map[reflect.Type]string) (string, []reflect.Type) { var buf bytes.Buffer - waveObjType := reflect.TypeOf((*waveobj.WaveObj)(nil)).Elem() tsTypeName := rtype.Name() if tsRename, ok := tsRenameMap[tsTypeName]; ok { tsTypeName = tsRename } var isWaveObj bool buf.WriteString(fmt.Sprintf("// %s\n", rtype.String())) - if rtype.Implements(waveObjType) || reflect.PointerTo(rtype).Implements(waveObjType) { + if rtype.Implements(waveObjRType) || reflect.PointerTo(rtype).Implements(waveObjRType) { isWaveObj = true buf.WriteString(fmt.Sprintf("type %s = WaveObj & {\n", tsTypeName)) } else { @@ -253,6 +254,9 @@ func GenerateTSType(rtype reflect.Type, tsTypesMap map[reflect.Type]string) { if rtype == nil { return } + if rtype.Kind() == reflect.Chan { + rtype = rtype.Elem() + } if rtype == metaRType { tsTypesMap[metaRType] = GenerateMetaType() return @@ -397,17 +401,17 @@ func GenerateServiceClass(serviceName string, serviceObj any, tsTypesMap map[ref return sb.String() } -func GenerateWshServerMethod(methodDecl *wshserver.WshServerMethodDecl, tsTypesMap map[reflect.Type]string) string { - if methodDecl.CommandType == wshutil.RpcType_ResponseStream { +func GenerateWshServerMethod(methodDecl *wshrpc.WshRpcMethodDecl, tsTypesMap map[reflect.Type]string) string { + if methodDecl.CommandType == wshrpc.RpcType_ResponseStream { return GenerateWshServerMethod_ResponseStream(methodDecl, tsTypesMap) - } else if methodDecl.CommandType == wshutil.RpcType_Call { + } else if methodDecl.CommandType == wshrpc.RpcType_Call { return GenerateWshServerMethod_Call(methodDecl, tsTypesMap) } else { panic(fmt.Sprintf("cannot generate wshserver commandtype %q", methodDecl.CommandType)) } } -func GenerateWshServerMethod_ResponseStream(methodDecl *wshserver.WshServerMethodDecl, tsTypesMap map[reflect.Type]string) string { +func GenerateWshServerMethod_ResponseStream(methodDecl *wshrpc.WshRpcMethodDecl, tsTypesMap map[reflect.Type]string) string { var sb strings.Builder sb.WriteString(fmt.Sprintf(" // command %q [%s]\n", methodDecl.Command, methodDecl.CommandType)) respType := "any" @@ -429,7 +433,7 @@ func GenerateWshServerMethod_ResponseStream(methodDecl *wshserver.WshServerMetho return sb.String() } -func GenerateWshServerMethod_Call(methodDecl *wshserver.WshServerMethodDecl, tsTypesMap map[reflect.Type]string) string { +func GenerateWshServerMethod_Call(methodDecl *wshrpc.WshRpcMethodDecl, tsTypesMap map[reflect.Type]string) string { var sb strings.Builder sb.WriteString(fmt.Sprintf(" // command %q [%s]\n", methodDecl.Command, methodDecl.CommandType)) rtnType := "Promise" @@ -469,7 +473,7 @@ func GenerateServiceTypes(tsTypesMap map[reflect.Type]string) error { serviceType := reflect.TypeOf(serviceObj) for midx := 0; midx < serviceType.NumMethod(); midx++ { method := serviceType.Method(midx) - err := generateTSMethodTypes(method, tsTypesMap) + err := generateTSMethodTypes(method, tsTypesMap, true) if err != nil { return fmt.Errorf("error generating TS method types for %s.%s: %v", serviceType, method.Name, err) } @@ -480,16 +484,12 @@ func GenerateServiceTypes(tsTypesMap map[reflect.Type]string) error { func GenerateWshServerTypes(tsTypesMap map[reflect.Type]string) error { GenerateTSType(reflect.TypeOf(wshrpc.WshRpcCommandOpts{}), tsTypesMap) - for _, methodDecl := range wshserver.WshServerCommandToDeclMap { - GenerateTSType(methodDecl.CommandDataType, tsTypesMap) - if methodDecl.DefaultResponseDataType != nil { - GenerateTSType(methodDecl.DefaultResponseDataType, tsTypesMap) - } - for _, rtype := range methodDecl.RequestDataTypes { - GenerateTSType(rtype, tsTypesMap) - } - for _, rtype := range methodDecl.ResponseDataTypes { - GenerateTSType(rtype, tsTypesMap) + rtype := wshRpcInterfaceRType + for midx := 0; midx < rtype.NumMethod(); midx++ { + method := rtype.Method(midx) + err := generateTSMethodTypes(method, tsTypesMap, false) + if err != nil { + return fmt.Errorf("error generating TS method types for %s.%s: %v", rtype, method.Name, err) } } return nil diff --git a/pkg/util/utilfn/utilfn.go b/pkg/util/utilfn/utilfn.go index 9cd6c118a..b3067f51b 100644 --- a/pkg/util/utilfn/utilfn.go +++ b/pkg/util/utilfn/utilfn.go @@ -800,3 +800,29 @@ func MoveSliceIdxToFront[T any](arr []T, idx int) []T { rtn = append(rtn, arr[idx+1:]...) return rtn } + +// matches a delimited string with a pattern string +// the pattern string can contain "*" to match a single part, or "**" to match the rest of the string +// note that "**" may only appear at the end of the string +func StarMatchString(pattern string, s string, delimiter string) bool { + patternParts := strings.Split(pattern, delimiter) + stringParts := strings.Split(s, delimiter) + pLen, sLen := len(patternParts), len(stringParts) + + for i := 0; i < pLen; i++ { + if patternParts[i] == "**" { + // '**' must be at the end to be valid + return i == pLen-1 + } + if i >= sLen { + // If string is exhausted but pattern is not + return false + } + if patternParts[i] != "*" && patternParts[i] != stringParts[i] { + // If current parts don't match and pattern part is not '*' + return false + } + } + // Check if both pattern and string are fully matched + return pLen == sLen +} diff --git a/pkg/waveai/waveai.go b/pkg/waveai/waveai.go index 671877414..1a49e0c50 100644 --- a/pkg/waveai/waveai.go +++ b/pkg/waveai/waveai.go @@ -23,12 +23,6 @@ const OpenAIPacketStr = "openai" const OpenAICloudReqStr = "openai-cloudreq" const PacketEOFStr = "EOF" -type OpenAIUsageType struct { - PromptTokens int `json:"prompt_tokens,omitempty"` - CompletionTokens int `json:"completion_tokens,omitempty"` - TotalTokens int `json:"total_tokens,omitempty"` -} - type OpenAICmdInfoPacketOutputType struct { Model string `json:"model,omitempty"` Created int64 `json:"created,omitempty"` @@ -37,19 +31,8 @@ type OpenAICmdInfoPacketOutputType struct { Error string `json:"error,omitempty"` } -type OpenAIPacketType struct { - Type string `json:"type"` - Model string `json:"model,omitempty"` - Created int64 `json:"created,omitempty"` - FinishReason string `json:"finish_reason,omitempty"` - Usage *OpenAIUsageType `json:"usage,omitempty"` - Index int `json:"index,omitempty"` - Text string `json:"text,omitempty"` - Error string `json:"error,omitempty"` -} - -func MakeOpenAIPacket() *OpenAIPacketType { - return &OpenAIPacketType{Type: OpenAIPacketStr} +func MakeOpenAIPacket() *wshrpc.OpenAIPacketType { + return &wshrpc.OpenAIPacketType{Type: OpenAIPacketStr} } type OpenAICmdInfoChatMessage struct { @@ -60,27 +43,12 @@ type OpenAICmdInfoChatMessage struct { UserEngineeredQuery string `json:"userengineeredquery,omitempty"` } -type OpenAIPromptMessageType struct { - Role string `json:"role"` - Content string `json:"content"` - Name string `json:"name,omitempty"` -} - type OpenAICloudReqPacketType struct { - Type string `json:"type"` - ClientId string `json:"clientid"` - Prompt []OpenAIPromptMessageType `json:"prompt"` - MaxTokens int `json:"maxtokens,omitempty"` - MaxChoices int `json:"maxchoices,omitempty"` -} - -type OpenAIOptsType struct { - Model string `json:"model"` - APIToken string `json:"apitoken"` - BaseURL string `json:"baseurl,omitempty"` - MaxTokens int `json:"maxtokens,omitempty"` - MaxChoices int `json:"maxchoices,omitempty"` - Timeout int `json:"timeout,omitempty"` + Type string `json:"type"` + ClientId string `json:"clientid"` + Prompt []wshrpc.OpenAIPromptMessageType `json:"prompt"` + MaxTokens int `json:"maxtokens,omitempty"` + MaxChoices int `json:"maxchoices,omitempty"` } func MakeOpenAICloudReqPacket() *OpenAICloudReqPacketType { @@ -89,12 +57,6 @@ func MakeOpenAICloudReqPacket() *OpenAICloudReqPacketType { } } -type OpenAiStreamRequest struct { - ClientId string `json:"clientid,omitempty"` - Opts *OpenAIOptsType `json:"opts"` - Prompt []OpenAIPromptMessageType `json:"prompt"` -} - func GetWSEndpoint() string { return PCloudWSEndpoint if !wavebase.IsDevMode() { @@ -116,18 +78,18 @@ const PCloudWSEndpointVarName = "PCLOUD_WS_ENDPOINT" const CloudWebsocketConnectTimeout = 1 * time.Minute -func convertUsage(resp openaiapi.ChatCompletionResponse) *OpenAIUsageType { +func convertUsage(resp openaiapi.ChatCompletionResponse) *wshrpc.OpenAIUsageType { if resp.Usage.TotalTokens == 0 { return nil } - return &OpenAIUsageType{ + return &wshrpc.OpenAIUsageType{ PromptTokens: resp.Usage.PromptTokens, CompletionTokens: resp.Usage.CompletionTokens, TotalTokens: resp.Usage.TotalTokens, } } -func ConvertPrompt(prompt []OpenAIPromptMessageType) []openaiapi.ChatCompletionMessage { +func ConvertPrompt(prompt []wshrpc.OpenAIPromptMessageType) []openaiapi.ChatCompletionMessage { var rtn []openaiapi.ChatCompletionMessage for _, p := range prompt { msg := openaiapi.ChatCompletionMessage{Role: p.Role, Content: p.Content, Name: p.Name} @@ -136,31 +98,31 @@ func ConvertPrompt(prompt []OpenAIPromptMessageType) []openaiapi.ChatCompletionM return rtn } -func RunCloudCompletionStream(ctx context.Context, request OpenAiStreamRequest) chan wshrpc.RespOrErrorUnion[OpenAIPacketType] { - rtn := make(chan wshrpc.RespOrErrorUnion[OpenAIPacketType]) +func RunCloudCompletionStream(ctx context.Context, request wshrpc.OpenAiStreamRequest) chan wshrpc.RespOrErrorUnion[wshrpc.OpenAIPacketType] { + rtn := make(chan wshrpc.RespOrErrorUnion[wshrpc.OpenAIPacketType]) go func() { log.Printf("start: %v", request) defer close(rtn) if request.Opts == nil { - rtn <- wshrpc.RespOrErrorUnion[OpenAIPacketType]{Error: fmt.Errorf("no openai opts found")} + rtn <- wshrpc.RespOrErrorUnion[wshrpc.OpenAIPacketType]{Error: fmt.Errorf("no openai opts found")} return } websocketContext, dialCancelFn := context.WithTimeout(context.Background(), CloudWebsocketConnectTimeout) defer dialCancelFn() conn, _, err := websocket.DefaultDialer.DialContext(websocketContext, GetWSEndpoint(), nil) + if err == context.DeadlineExceeded { + rtn <- wshrpc.RespOrErrorUnion[wshrpc.OpenAIPacketType]{Error: fmt.Errorf("OpenAI request, timed out connecting to cloud server: %v", err)} + return + } else if err != nil { + rtn <- wshrpc.RespOrErrorUnion[wshrpc.OpenAIPacketType]{Error: fmt.Errorf("OpenAI request, websocket connect error: %v", err)} + return + } defer func() { err = conn.Close() if err != nil { - rtn <- wshrpc.RespOrErrorUnion[OpenAIPacketType]{Error: fmt.Errorf("unable to close openai channel: %v", err)} + rtn <- wshrpc.RespOrErrorUnion[wshrpc.OpenAIPacketType]{Error: fmt.Errorf("unable to close openai channel: %v", err)} } }() - if err == context.DeadlineExceeded { - rtn <- wshrpc.RespOrErrorUnion[OpenAIPacketType]{Error: fmt.Errorf("OpenAI request, timed out connecting to cloud server: %v", err)} - return - } else if err != nil { - rtn <- wshrpc.RespOrErrorUnion[OpenAIPacketType]{Error: fmt.Errorf("OpenAI request, websocket connect error: %v", err)} - return - } reqPk := MakeOpenAICloudReqPacket() reqPk.ClientId = request.ClientId reqPk.Prompt = request.Prompt @@ -168,12 +130,12 @@ func RunCloudCompletionStream(ctx context.Context, request OpenAiStreamRequest) reqPk.MaxChoices = request.Opts.MaxChoices configMessageBuf, err := json.Marshal(reqPk) if err != nil { - rtn <- wshrpc.RespOrErrorUnion[OpenAIPacketType]{Error: fmt.Errorf("OpenAI request, packet marshal error: %v", err)} + rtn <- wshrpc.RespOrErrorUnion[wshrpc.OpenAIPacketType]{Error: fmt.Errorf("OpenAI request, packet marshal error: %v", err)} return } err = conn.WriteMessage(websocket.TextMessage, configMessageBuf) if err != nil { - rtn <- wshrpc.RespOrErrorUnion[OpenAIPacketType]{Error: fmt.Errorf("OpenAI request, websocket write config error: %v", err)} + rtn <- wshrpc.RespOrErrorUnion[wshrpc.OpenAIPacketType]{Error: fmt.Errorf("OpenAI request, websocket write config error: %v", err)} return } for { @@ -184,14 +146,14 @@ func RunCloudCompletionStream(ctx context.Context, request OpenAiStreamRequest) } if err != nil { log.Printf("err received: %v", err) - rtn <- wshrpc.RespOrErrorUnion[OpenAIPacketType]{Error: fmt.Errorf("OpenAI request, websocket error reading message: %v", err)} + rtn <- wshrpc.RespOrErrorUnion[wshrpc.OpenAIPacketType]{Error: fmt.Errorf("OpenAI request, websocket error reading message: %v", err)} break } - var streamResp *OpenAIPacketType + var streamResp *wshrpc.OpenAIPacketType err = json.Unmarshal(socketMessage, &streamResp) log.Printf("ai resp: %v", streamResp) if err != nil { - rtn <- wshrpc.RespOrErrorUnion[OpenAIPacketType]{Error: fmt.Errorf("OpenAI request, websocket response json decode error: %v", err)} + rtn <- wshrpc.RespOrErrorUnion[wshrpc.OpenAIPacketType]{Error: fmt.Errorf("OpenAI request, websocket response json decode error: %v", err)} break } if streamResp.Error == PacketEOFStr { @@ -199,30 +161,30 @@ func RunCloudCompletionStream(ctx context.Context, request OpenAiStreamRequest) break } else if streamResp.Error != "" { // use error from server directly - rtn <- wshrpc.RespOrErrorUnion[OpenAIPacketType]{Error: fmt.Errorf("%v", streamResp.Error)} + rtn <- wshrpc.RespOrErrorUnion[wshrpc.OpenAIPacketType]{Error: fmt.Errorf("%v", streamResp.Error)} break } - rtn <- wshrpc.RespOrErrorUnion[OpenAIPacketType]{Response: *streamResp} + rtn <- wshrpc.RespOrErrorUnion[wshrpc.OpenAIPacketType]{Response: *streamResp} } }() return rtn } -func RunLocalCompletionStream(ctx context.Context, request OpenAiStreamRequest) chan wshrpc.RespOrErrorUnion[OpenAIPacketType] { - rtn := make(chan wshrpc.RespOrErrorUnion[OpenAIPacketType]) +func RunLocalCompletionStream(ctx context.Context, request wshrpc.OpenAiStreamRequest) chan wshrpc.RespOrErrorUnion[wshrpc.OpenAIPacketType] { + rtn := make(chan wshrpc.RespOrErrorUnion[wshrpc.OpenAIPacketType]) go func() { log.Printf("start2: %v", request) defer close(rtn) if request.Opts == nil { - rtn <- wshrpc.RespOrErrorUnion[OpenAIPacketType]{Error: fmt.Errorf("no openai opts found")} + rtn <- wshrpc.RespOrErrorUnion[wshrpc.OpenAIPacketType]{Error: fmt.Errorf("no openai opts found")} return } if request.Opts.Model == "" { - rtn <- wshrpc.RespOrErrorUnion[OpenAIPacketType]{Error: fmt.Errorf("no openai model specified")} + rtn <- wshrpc.RespOrErrorUnion[wshrpc.OpenAIPacketType]{Error: fmt.Errorf("no openai model specified")} return } if request.Opts.BaseURL == "" && request.Opts.APIToken == "" { - rtn <- wshrpc.RespOrErrorUnion[OpenAIPacketType]{Error: fmt.Errorf("no api token")} + rtn <- wshrpc.RespOrErrorUnion[wshrpc.OpenAIPacketType]{Error: fmt.Errorf("no api token")} return } clientConfig := openaiapi.DefaultConfig(request.Opts.APIToken) @@ -241,7 +203,7 @@ func RunLocalCompletionStream(ctx context.Context, request OpenAiStreamRequest) } apiResp, err := client.CreateChatCompletionStream(ctx, req) if err != nil { - rtn <- wshrpc.RespOrErrorUnion[OpenAIPacketType]{Error: fmt.Errorf("error calling openai API: %v", err)} + rtn <- wshrpc.RespOrErrorUnion[wshrpc.OpenAIPacketType]{Error: fmt.Errorf("error calling openai API: %v", err)} return } sentHeader := false @@ -253,14 +215,14 @@ func RunLocalCompletionStream(ctx context.Context, request OpenAiStreamRequest) } if err != nil { log.Printf("err received2: %v", err) - rtn <- wshrpc.RespOrErrorUnion[OpenAIPacketType]{Error: fmt.Errorf("OpenAI request, websocket error reading message: %v", err)} + rtn <- wshrpc.RespOrErrorUnion[wshrpc.OpenAIPacketType]{Error: fmt.Errorf("OpenAI request, websocket error reading message: %v", err)} break } if streamResp.Model != "" && !sentHeader { pk := MakeOpenAIPacket() pk.Model = streamResp.Model pk.Created = streamResp.Created - rtn <- wshrpc.RespOrErrorUnion[OpenAIPacketType]{Response: *pk} + rtn <- wshrpc.RespOrErrorUnion[wshrpc.OpenAIPacketType]{Response: *pk} sentHeader = true } for _, choice := range streamResp.Choices { @@ -268,15 +230,15 @@ func RunLocalCompletionStream(ctx context.Context, request OpenAiStreamRequest) pk.Index = choice.Index pk.Text = choice.Delta.Content pk.FinishReason = string(choice.FinishReason) - rtn <- wshrpc.RespOrErrorUnion[OpenAIPacketType]{Response: *pk} + rtn <- wshrpc.RespOrErrorUnion[wshrpc.OpenAIPacketType]{Response: *pk} } } }() return rtn } -func marshalResponse(resp openaiapi.ChatCompletionResponse) []*OpenAIPacketType { - var rtn []*OpenAIPacketType +func marshalResponse(resp openaiapi.ChatCompletionResponse) []*wshrpc.OpenAIPacketType { + var rtn []*wshrpc.OpenAIPacketType headerPk := MakeOpenAIPacket() headerPk.Model = resp.Model headerPk.Created = resp.Created @@ -292,14 +254,14 @@ func marshalResponse(resp openaiapi.ChatCompletionResponse) []*OpenAIPacketType return rtn } -func CreateErrorPacket(errStr string) *OpenAIPacketType { +func CreateErrorPacket(errStr string) *wshrpc.OpenAIPacketType { errPk := MakeOpenAIPacket() errPk.FinishReason = "error" errPk.Error = errStr return errPk } -func CreateTextPacket(text string) *OpenAIPacketType { +func CreateTextPacket(text string) *wshrpc.OpenAIPacketType { pk := MakeOpenAIPacket() pk.Text = text return pk diff --git a/pkg/wconfig/settingsconfig.go b/pkg/wconfig/settingsconfig.go index 6777868e4..28a5ef9de 100644 --- a/pkg/wconfig/settingsconfig.go +++ b/pkg/wconfig/settingsconfig.go @@ -72,13 +72,21 @@ type TermThemesConfigType map[string]TermThemeType // TODO add default term theme settings +// note we pointers so we preserve nulls +type WindowSettingsType struct { + Transparent *bool `json:"transparent"` + Opacity *float64 `json:"opacity"` + BgColor *string `json:"bgcolor"` +} + type SettingsConfigType struct { - MimeTypes map[string]MimeTypeConfigType `json:"mimetypes"` - Term TerminalConfigType `json:"term"` - Widgets []WidgetsConfigType `json:"widgets"` - BlockHeader BlockHeaderOpts `json:"blockheader"` - AutoUpdate *AutoUpdateOpts `json:"autoupdate"` - TermThemes TermThemesConfigType `json:"termthemes"` + MimeTypes map[string]MimeTypeConfigType `json:"mimetypes"` + Term TerminalConfigType `json:"term"` + Widgets []WidgetsConfigType `json:"widgets"` + BlockHeader BlockHeaderOpts `json:"blockheader"` + AutoUpdate *AutoUpdateOpts `json:"autoupdate"` + TermThemes TermThemesConfigType `json:"termthemes"` + WindowSettings WindowSettingsType `json:"window"` } var DefaultTermDarkTheme = TermThemeType{ diff --git a/pkg/web/ws.go b/pkg/web/ws.go index 62594861f..cddbdda9e 100644 --- a/pkg/web/ws.go +++ b/pkg/web/ws.go @@ -23,7 +23,7 @@ import ( ) // set by main-server.go (for dependency inversion) -var WshServerFactoryFn func(inputCh chan []byte, outputCh chan []byte, initialCtx wshutil.RpcContext) = nil +var WshServerFactoryFn func(inputCh chan []byte, outputCh chan []byte, initialCtx wshrpc.RpcContext) = nil const wsReadWaitTimeout = 15 * time.Second const wsWriteWaitTimeout = 10 * time.Second @@ -104,7 +104,7 @@ func processWSCommand(jmsg map[string]any, outputCh chan any, rpcInputCh chan [] TermSize: &cmd.TermSize, } rpcMsg := wshutil.RpcMessage{ - Command: wshrpc.Command_BlockInput, + Command: wshrpc.Command_ControllerInput, Data: data, } msgBytes, err := json.Marshal(rpcMsg) @@ -121,7 +121,7 @@ func processWSCommand(jmsg map[string]any, outputCh chan any, rpcInputCh chan [] InputData64: cmd.InputData64, } rpcMsg := wshutil.RpcMessage{ - Command: wshrpc.Command_BlockInput, + Command: wshrpc.Command_ControllerInput, Data: data, } msgBytes, err := json.Marshal(rpcMsg) @@ -281,7 +281,7 @@ func HandleWsInternal(w http.ResponseWriter, r *http.Request) error { rpcOutputCh := make(chan []byte, 32) eventbus.RegisterWSChannel(wsConnId, windowId, outputCh) defer eventbus.UnregisterWSChannel(wsConnId) - WshServerFactoryFn(rpcInputCh, rpcOutputCh, wshutil.RpcContext{WindowId: windowId}) + WshServerFactoryFn(rpcInputCh, rpcOutputCh, wshrpc.RpcContext{WindowId: windowId}) wg := &sync.WaitGroup{} wg.Add(2) go func() { diff --git a/pkg/wps/wps.go b/pkg/wps/wps.go index a380d8d85..437a640d5 100644 --- a/pkg/wps/wps.go +++ b/pkg/wps/wps.go @@ -5,44 +5,50 @@ package wps import ( + "strings" "sync" "github.com/wavetermdev/thenextwave/pkg/util/utilfn" + "github.com/wavetermdev/thenextwave/pkg/wshrpc" ) // this broker interface is mostly generic // strong typing and event types can be defined elsewhere -type WaveEvent struct { - Event string `json:"event"` - Scopes []string `json:"scopes,omitempty"` - Sender string `json:"sender,omitempty"` - Data any `json:"data,omitempty"` -} - -type SubscriptionRequest struct { - Event string `json:"event"` - Scopes []string `json:"scopes,omitempty"` - AllScopes bool `json:"allscopes,omitempty"` -} - type Client interface { ClientId() string - SendEvent(event WaveEvent) + SendEvent(event wshrpc.WaveEvent) } type BrokerSubscription struct { AllSubs []string // clientids of client subscribed to "all" events ScopeSubs map[string][]string // clientids of client subscribed to specific scopes + StarSubs map[string][]string // clientids of client subscribed to star scope (scopes with "*" or "**" in them) } -type Broker struct { +type BrokerType struct { Lock *sync.Mutex ClientMap map[string]Client SubMap map[string]*BrokerSubscription } -func (b *Broker) Subscribe(subscriber Client, sub SubscriptionRequest) { +var Broker = &BrokerType{ + Lock: &sync.Mutex{}, + ClientMap: make(map[string]Client), + SubMap: make(map[string]*BrokerSubscription), +} + +func scopeHasStarMatch(scope string) bool { + parts := strings.Split(scope, ":") + for _, part := range parts { + if part == "*" || part == "**" { + return true + } + } + return false +} + +func (b *BrokerType) Subscribe(subscriber Client, sub wshrpc.SubscriptionRequest) { b.Lock.Lock() defer b.Lock.Unlock() clientId := subscriber.ClientId() @@ -51,6 +57,7 @@ func (b *Broker) Subscribe(subscriber Client, sub SubscriptionRequest) { bs = &BrokerSubscription{ AllSubs: []string{}, ScopeSubs: make(map[string][]string), + StarSubs: make(map[string][]string), } b.SubMap[sub.Event] = bs } @@ -58,17 +65,47 @@ func (b *Broker) Subscribe(subscriber Client, sub SubscriptionRequest) { bs.AllSubs = utilfn.AddElemToSliceUniq(bs.AllSubs, clientId) } for _, scope := range sub.Scopes { - scopeSubs := bs.ScopeSubs[scope] - scopeSubs = utilfn.AddElemToSliceUniq(scopeSubs, clientId) - bs.ScopeSubs[scope] = scopeSubs + starMatch := scopeHasStarMatch(scope) + if starMatch { + addStrToScopeMap(bs.StarSubs, scope, clientId) + } else { + addStrToScopeMap(bs.ScopeSubs, scope, clientId) + } } } func (bs *BrokerSubscription) IsEmpty() bool { - return len(bs.AllSubs) == 0 && len(bs.ScopeSubs) == 0 + return len(bs.AllSubs) == 0 && len(bs.ScopeSubs) == 0 && len(bs.StarSubs) == 0 } -func (b *Broker) Unsubscribe(subscriber Client, sub SubscriptionRequest) { +func removeStrFromScopeMap(scopeMap map[string][]string, scope string, clientId string) { + scopeSubs := scopeMap[scope] + scopeSubs = utilfn.RemoveElemFromSlice(scopeSubs, clientId) + if len(scopeSubs) == 0 { + delete(scopeMap, scope) + } else { + scopeMap[scope] = scopeSubs + } +} + +func removeStrFromScopeMapAll(scopeMap map[string][]string, clientId string) { + for scope, scopeSubs := range scopeMap { + scopeSubs = utilfn.RemoveElemFromSlice(scopeSubs, clientId) + if len(scopeSubs) == 0 { + delete(scopeMap, scope) + } else { + scopeMap[scope] = scopeSubs + } + } +} + +func addStrToScopeMap(scopeMap map[string][]string, scope string, clientId string) { + scopeSubs := scopeMap[scope] + scopeSubs = utilfn.AddElemToSliceUniq(scopeSubs, clientId) + scopeMap[scope] = scopeSubs +} + +func (b *BrokerType) Unsubscribe(subscriber Client, sub wshrpc.SubscriptionRequest) { b.Lock.Lock() defer b.Lock.Unlock() clientId := subscriber.ClientId() @@ -80,12 +117,11 @@ func (b *Broker) Unsubscribe(subscriber Client, sub SubscriptionRequest) { bs.AllSubs = utilfn.RemoveElemFromSlice(bs.AllSubs, clientId) } for _, scope := range sub.Scopes { - scopeSubs := bs.ScopeSubs[scope] - scopeSubs = utilfn.RemoveElemFromSlice(scopeSubs, clientId) - if len(scopeSubs) == 0 { - delete(bs.ScopeSubs, scope) + starMatch := scopeHasStarMatch(scope) + if starMatch { + removeStrFromScopeMap(bs.StarSubs, scope, clientId) } else { - bs.ScopeSubs[scope] = scopeSubs + removeStrFromScopeMap(bs.ScopeSubs, scope, clientId) } } if bs.IsEmpty() { @@ -93,28 +129,22 @@ func (b *Broker) Unsubscribe(subscriber Client, sub SubscriptionRequest) { } } -func (b *Broker) UnsubscribeAll(subscriber Client) { +func (b *BrokerType) UnsubscribeAll(subscriber Client) { b.Lock.Lock() defer b.Lock.Unlock() clientId := subscriber.ClientId() delete(b.ClientMap, clientId) for eventType, bs := range b.SubMap { bs.AllSubs = utilfn.RemoveElemFromSlice(bs.AllSubs, clientId) - for scope, scopeSubs := range bs.ScopeSubs { - scopeSubs = utilfn.RemoveElemFromSlice(scopeSubs, clientId) - if len(scopeSubs) == 0 { - delete(bs.ScopeSubs, scope) - } else { - bs.ScopeSubs[scope] = scopeSubs - } - } + removeStrFromScopeMapAll(bs.StarSubs, clientId) + removeStrFromScopeMapAll(bs.ScopeSubs, clientId) if bs.IsEmpty() { delete(b.SubMap, eventType) } } } -func (b *Broker) Publish(subscriber Client, event WaveEvent) { +func (b *BrokerType) Publish(event wshrpc.WaveEvent) { clientIds := b.getMatchingClientIds(event) for _, clientId := range clientIds { client := b.ClientMap[clientId] @@ -124,7 +154,7 @@ func (b *Broker) Publish(subscriber Client, event WaveEvent) { } } -func (b *Broker) getMatchingClientIds(event WaveEvent) []string { +func (b *BrokerType) getMatchingClientIds(event wshrpc.WaveEvent) []string { b.Lock.Lock() defer b.Lock.Unlock() bs := b.SubMap[event.Event] @@ -139,6 +169,13 @@ func (b *Broker) getMatchingClientIds(event WaveEvent) []string { for _, clientId := range bs.ScopeSubs[scope] { clientIds[clientId] = true } + for starScope := range bs.StarSubs { + if utilfn.StarMatchString(starScope, scope, ":") { + for _, clientId := range bs.StarSubs[starScope] { + clientIds[clientId] = true + } + } + } } var rtn []string for clientId := range clientIds { diff --git a/pkg/wshrpc/wshclient/wshclient.go b/pkg/wshrpc/wshclient/wshclient.go index 0e4e9e8d4..e9cc46f62 100644 --- a/pkg/wshrpc/wshclient/wshclient.go +++ b/pkg/wshrpc/wshclient/wshclient.go @@ -9,24 +9,29 @@ import ( "github.com/wavetermdev/thenextwave/pkg/wshutil" "github.com/wavetermdev/thenextwave/pkg/wshrpc" "github.com/wavetermdev/thenextwave/pkg/waveobj" - "github.com/wavetermdev/thenextwave/pkg/waveai" ) -// command "controller:input", wshserver.BlockInputCommand -func BlockInputCommand(w *wshutil.WshRpc, data wshrpc.CommandBlockInputData, opts *wshrpc.WshRpcCommandOpts) error { - _, err := sendRpcRequestCallHelper[any](w, "controller:input", data, opts) +// command "authenticate", wshserver.AuthenticateCommand +func AuthenticateCommand(w *wshutil.WshRpc, data string, opts *wshrpc.WshRpcCommandOpts) error { + _, err := sendRpcRequestCallHelper[any](w, "authenticate", data, opts) return err } -// command "controller:restart", wshserver.BlockRestartCommand -func BlockRestartCommand(w *wshutil.WshRpc, data wshrpc.CommandBlockRestartData, opts *wshrpc.WshRpcCommandOpts) error { - _, err := sendRpcRequestCallHelper[any](w, "controller:restart", data, opts) +// command "controllerinput", wshserver.ControllerInputCommand +func ControllerInputCommand(w *wshutil.WshRpc, data wshrpc.CommandBlockInputData, opts *wshrpc.WshRpcCommandOpts) error { + _, err := sendRpcRequestCallHelper[any](w, "controllerinput", data, opts) + return err +} + +// command "controllerrestart", wshserver.ControllerRestartCommand +func ControllerRestartCommand(w *wshutil.WshRpc, data wshrpc.CommandBlockRestartData, opts *wshrpc.WshRpcCommandOpts) error { + _, err := sendRpcRequestCallHelper[any](w, "controllerrestart", data, opts) return err } // command "createblock", wshserver.CreateBlockCommand -func CreateBlockCommand(w *wshutil.WshRpc, data wshrpc.CommandCreateBlockData, opts *wshrpc.WshRpcCommandOpts) (*waveobj.ORef, error) { - resp, err := sendRpcRequestCallHelper[*waveobj.ORef](w, "createblock", data, opts) +func CreateBlockCommand(w *wshutil.WshRpc, data wshrpc.CommandCreateBlockData, opts *wshrpc.WshRpcCommandOpts) (waveobj.ORef, error) { + resp, err := sendRpcRequestCallHelper[waveobj.ORef](w, "createblock", data, opts) return resp, err } @@ -36,27 +41,57 @@ func DeleteBlockCommand(w *wshutil.WshRpc, data wshrpc.CommandDeleteBlockData, o return err } -// command "file:append", wshserver.AppendFileCommand -func AppendFileCommand(w *wshutil.WshRpc, data wshrpc.CommandFileData, opts *wshrpc.WshRpcCommandOpts) error { - _, err := sendRpcRequestCallHelper[any](w, "file:append", data, opts) +// command "eventpublish", wshserver.EventPublishCommand +func EventPublishCommand(w *wshutil.WshRpc, data wshrpc.WaveEvent, opts *wshrpc.WshRpcCommandOpts) error { + _, err := sendRpcRequestCallHelper[any](w, "eventpublish", data, opts) return err } -// command "file:appendijson", wshserver.AppendIJsonCommand -func AppendIJsonCommand(w *wshutil.WshRpc, data wshrpc.CommandAppendIJsonData, opts *wshrpc.WshRpcCommandOpts) error { - _, err := sendRpcRequestCallHelper[any](w, "file:appendijson", data, opts) +// command "eventrecv", wshserver.EventRecvCommand +func EventRecvCommand(w *wshutil.WshRpc, data wshrpc.WaveEvent, opts *wshrpc.WshRpcCommandOpts) error { + _, err := sendRpcRequestCallHelper[any](w, "eventrecv", data, opts) return err } -// command "file:read", wshserver.ReadFile -func ReadFile(w *wshutil.WshRpc, data wshrpc.CommandFileData, opts *wshrpc.WshRpcCommandOpts) (string, error) { - resp, err := sendRpcRequestCallHelper[string](w, "file:read", data, opts) +// command "eventsub", wshserver.EventSubCommand +func EventSubCommand(w *wshutil.WshRpc, data wshrpc.SubscriptionRequest, opts *wshrpc.WshRpcCommandOpts) error { + _, err := sendRpcRequestCallHelper[any](w, "eventsub", data, opts) + return err +} + +// command "eventunsub", wshserver.EventUnsubCommand +func EventUnsubCommand(w *wshutil.WshRpc, data wshrpc.SubscriptionRequest, opts *wshrpc.WshRpcCommandOpts) error { + _, err := sendRpcRequestCallHelper[any](w, "eventunsub", data, opts) + return err +} + +// command "eventunsuball", wshserver.EventUnsubAllCommand +func EventUnsubAllCommand(w *wshutil.WshRpc, opts *wshrpc.WshRpcCommandOpts) error { + _, err := sendRpcRequestCallHelper[any](w, "eventunsuball", nil, opts) + return err +} + +// command "fileappend", wshserver.FileAppendCommand +func FileAppendCommand(w *wshutil.WshRpc, data wshrpc.CommandFileData, opts *wshrpc.WshRpcCommandOpts) error { + _, err := sendRpcRequestCallHelper[any](w, "fileappend", data, opts) + return err +} + +// command "fileappendijson", wshserver.FileAppendIJsonCommand +func FileAppendIJsonCommand(w *wshutil.WshRpc, data wshrpc.CommandAppendIJsonData, opts *wshrpc.WshRpcCommandOpts) error { + _, err := sendRpcRequestCallHelper[any](w, "fileappendijson", data, opts) + return err +} + +// command "fileread", wshserver.FileReadCommand +func FileReadCommand(w *wshutil.WshRpc, data wshrpc.CommandFileData, opts *wshrpc.WshRpcCommandOpts) (string, error) { + resp, err := sendRpcRequestCallHelper[string](w, "fileread", data, opts) return resp, err } -// command "file:write", wshserver.WriteFile -func WriteFile(w *wshutil.WshRpc, data wshrpc.CommandFileData, opts *wshrpc.WshRpcCommandOpts) error { - _, err := sendRpcRequestCallHelper[any](w, "file:write", data, opts) +// command "filewrite", wshserver.FileWriteCommand +func FileWriteCommand(w *wshutil.WshRpc, data wshrpc.CommandFileData, opts *wshrpc.WshRpcCommandOpts) error { + _, err := sendRpcRequestCallHelper[any](w, "filewrite", data, opts) return err } @@ -84,20 +119,20 @@ func SetMetaCommand(w *wshutil.WshRpc, data wshrpc.CommandSetMetaData, opts *wsh return err } -// command "setview", wshserver.BlockSetViewCommand -func BlockSetViewCommand(w *wshutil.WshRpc, data wshrpc.CommandBlockSetViewData, opts *wshrpc.WshRpcCommandOpts) error { +// command "setview", wshserver.SetViewCommand +func SetViewCommand(w *wshutil.WshRpc, data wshrpc.CommandBlockSetViewData, opts *wshrpc.WshRpcCommandOpts) error { _, err := sendRpcRequestCallHelper[any](w, "setview", data, opts) return err } -// command "stream:waveai", wshserver.RespStreamWaveAi -func RespStreamWaveAi(w *wshutil.WshRpc, data waveai.OpenAiStreamRequest, opts *wshrpc.WshRpcCommandOpts) chan wshrpc.RespOrErrorUnion[waveai.OpenAIPacketType] { - return sendRpcRequestResponseStreamHelper[waveai.OpenAIPacketType](w, "stream:waveai", data, opts) -} - -// command "streamtest", wshserver.RespStreamTest -func RespStreamTest(w *wshutil.WshRpc, opts *wshrpc.WshRpcCommandOpts) chan wshrpc.RespOrErrorUnion[int] { +// command "streamtest", wshserver.StreamTestCommand +func StreamTestCommand(w *wshutil.WshRpc, opts *wshrpc.WshRpcCommandOpts) chan wshrpc.RespOrErrorUnion[int] { return sendRpcRequestResponseStreamHelper[int](w, "streamtest", nil, opts) } +// command "streamwaveai", wshserver.StreamWaveAiCommand +func StreamWaveAiCommand(w *wshutil.WshRpc, data wshrpc.OpenAiStreamRequest, opts *wshrpc.WshRpcCommandOpts) chan wshrpc.RespOrErrorUnion[wshrpc.OpenAIPacketType] { + return sendRpcRequestResponseStreamHelper[wshrpc.OpenAIPacketType](w, "streamwaveai", data, opts) +} + diff --git a/pkg/wshrpc/wshrpcmeta.go b/pkg/wshrpc/wshrpcmeta.go new file mode 100644 index 000000000..09568ea09 --- /dev/null +++ b/pkg/wshrpc/wshrpcmeta.go @@ -0,0 +1,116 @@ +// Copyright 2024, Command Line Inc. +// SPDX-License-Identifier: Apache-2.0 + +package wshrpc + +import ( + "context" + "fmt" + "log" + "reflect" + "strings" +) + +type WshRpcMethodDecl struct { + Command string + CommandType string + MethodName string + CommandDataType reflect.Type + DefaultResponseDataType reflect.Type +} + +var contextRType = reflect.TypeOf((*context.Context)(nil)).Elem() +var wshRpcInterfaceRType = reflect.TypeOf((*WshRpcInterface)(nil)).Elem() + +func getWshCommandType(method reflect.Method) string { + if method.Type.NumOut() == 1 { + outType := method.Type.Out(0) + if outType.Kind() == reflect.Chan { + return RpcType_ResponseStream + } + } + return RpcType_Call +} + +func getWshMethodResponseType(commandType string, method reflect.Method) reflect.Type { + switch commandType { + case RpcType_ResponseStream: + if method.Type.NumOut() != 1 { + panic(fmt.Sprintf("method %q has invalid number of return values for response stream", method.Name)) + } + outType := method.Type.Out(0) + if outType.Kind() != reflect.Chan { + panic(fmt.Sprintf("method %q has invalid return type %s for response stream", method.Name, outType)) + } + elemType := outType.Elem() + if !strings.HasPrefix(elemType.Name(), "RespOrErrorUnion") { + panic(fmt.Sprintf("method %q has invalid return element type %s for response stream (should be RespOrErrorUnion)", method.Name, elemType)) + } + respField, found := elemType.FieldByName("Response") + if !found { + panic(fmt.Sprintf("method %q has invalid return element type %s for response stream (missing Response field)", method.Name, elemType)) + } + return respField.Type + case RpcType_Call: + if method.Type.NumOut() > 1 { + return method.Type.Out(0) + } + return nil + default: + panic(fmt.Sprintf("unsupported command type %q", commandType)) + } +} + +func generateWshCommandDecl(method reflect.Method) *WshRpcMethodDecl { + if method.Type.NumIn() == 0 || method.Type.In(0) != contextRType { + panic(fmt.Sprintf("method %q does not have context as first argument", method.Name)) + } + cmdStr := method.Name + decl := &WshRpcMethodDecl{} + // remove Command suffix + if !strings.HasSuffix(cmdStr, "Command") { + panic(fmt.Sprintf("method %q does not have Command suffix", cmdStr)) + } + cmdStr = cmdStr[:len(cmdStr)-len("Command")] + decl.Command = strings.ToLower(cmdStr) + decl.CommandType = getWshCommandType(method) + decl.MethodName = method.Name + var cdataType reflect.Type + if method.Type.NumIn() > 1 { + cdataType = method.Type.In(1) + } + decl.CommandDataType = cdataType + decl.DefaultResponseDataType = getWshMethodResponseType(decl.CommandType, method) + return decl +} + +func MakeMethodMapForImpl(impl any, declMap map[string]*WshRpcMethodDecl) map[string]reflect.Method { + rtype := reflect.TypeOf(impl) + rtnMap := make(map[string]reflect.Method) + for midx := 0; midx < rtype.NumMethod(); midx++ { + method := rtype.Method(midx) + if !strings.HasSuffix(method.Name, "Command") { + continue + } + commandName := strings.ToLower(method.Name[:len(method.Name)-len("Command")]) + decl := declMap[commandName] + if decl == nil { + log.Printf("WARNING: method %q does not match a command method", method.Name) + continue + } + rtnMap[commandName] = method + } + return rtnMap + +} + +func GenerateWshCommandDeclMap() map[string]*WshRpcMethodDecl { + rtype := wshRpcInterfaceRType + rtnMap := make(map[string]*WshRpcMethodDecl) + for midx := 0; midx < rtype.NumMethod(); midx++ { + method := rtype.Method(midx) + decl := generateWshCommandDecl(method) + rtnMap[decl.Command] = decl + } + return rtnMap +} diff --git a/pkg/wshrpc/wshrpctypes.go b/pkg/wshrpc/wshrpctypes.go index 1679b95c9..627ee557a 100644 --- a/pkg/wshrpc/wshrpctypes.go +++ b/pkg/wshrpc/wshrpctypes.go @@ -5,45 +5,77 @@ package wshrpc import ( + "context" "reflect" "github.com/wavetermdev/thenextwave/pkg/ijson" "github.com/wavetermdev/thenextwave/pkg/shellexec" "github.com/wavetermdev/thenextwave/pkg/waveobj" - "github.com/wavetermdev/thenextwave/pkg/wshutil" "github.com/wavetermdev/thenextwave/pkg/wstore" ) const ( - Command_Message = "message" - Command_SetView = "setview" - Command_SetMeta = "setmeta" - Command_GetMeta = "getmeta" - Command_BlockInput = "controller:input" - Command_Restart = "controller:restart" - Command_AppendFile = "file:append" - Command_AppendIJson = "file:appendijson" - Command_ResolveIds = "resolveids" - Command_CreateBlock = "createblock" - Command_DeleteBlock = "deleteblock" - Command_WriteFile = "file:write" - Command_ReadFile = "file:read" - Command_StreamWaveAi = "stream:waveai" + RpcType_Call = "call" // single response (regular rpc) + RpcType_ResponseStream = "responsestream" // stream of responses (streaming rpc) + RpcType_StreamingRequest = "streamingrequest" // streaming request + RpcType_Complex = "complex" // streaming request/response +) + +const ( + Command_Authenticate = "authenticate" + Command_Message = "message" + Command_GetMeta = "getmeta" + Command_SetMeta = "setmeta" + Command_SetView = "setview" + Command_ControllerInput = "controllerinput" + Command_ControllerRestart = "controllerrestart" + Command_FileAppend = "fileappend" + Command_FileAppendIJson = "fileappendijson" + Command_ResolveIds = "resolveids" + Command_CreateBlock = "createblock" + Command_DeleteBlock = "deleteblock" + Command_FileWrite = "filewrite" + Command_FileRead = "fileread" + Command_EventPublish = "eventpublish" + Command_EventRecv = "eventrecv" + Command_EventSub = "eventsub" + Command_EventUnsub = "eventunsub" + Command_EventUnsubAll = "eventunsuball" + Command_StreamTest = "streamtest" + Command_StreamWaveAi = "streamwaveai" ) type MetaDataType = map[string]any -var DataTypeMap = map[string]reflect.Type{ - "meta": reflect.TypeOf(MetaDataType{}), - "resolveidsrtn": reflect.TypeOf(CommandResolveIdsRtnData{}), - "oref": reflect.TypeOf(waveobj.ORef{}), -} - type RespOrErrorUnion[T any] struct { Response T Error error } +type WshRpcInterface interface { + AuthenticateCommand(ctx context.Context, data string) error + MessageCommand(ctx context.Context, data CommandMessageData) error + GetMetaCommand(ctx context.Context, data CommandGetMetaData) (MetaDataType, error) + SetMetaCommand(ctx context.Context, data CommandSetMetaData) error + SetViewCommand(ctx context.Context, data CommandBlockSetViewData) error + ControllerInputCommand(ctx context.Context, data CommandBlockInputData) error + ControllerRestartCommand(ctx context.Context, data CommandBlockRestartData) error + FileAppendCommand(ctx context.Context, data CommandFileData) error + FileAppendIJsonCommand(ctx context.Context, data CommandAppendIJsonData) error + ResolveIdsCommand(ctx context.Context, data CommandResolveIdsData) (CommandResolveIdsRtnData, error) + CreateBlockCommand(ctx context.Context, data CommandCreateBlockData) (waveobj.ORef, error) + DeleteBlockCommand(ctx context.Context, data CommandDeleteBlockData) error + FileWriteCommand(ctx context.Context, data CommandFileData) error + FileReadCommand(ctx context.Context, data CommandFileData) (string, error) + EventPublishCommand(ctx context.Context, data WaveEvent) error + EventRecvCommand(ctx context.Context, data WaveEvent) error + EventSubCommand(ctx context.Context, data SubscriptionRequest) error + EventUnsubCommand(ctx context.Context, data SubscriptionRequest) error + EventUnsubAllCommand(ctx context.Context) error + StreamTestCommand(ctx context.Context) chan RespOrErrorUnion[int] + StreamWaveAiCommand(ctx context.Context, request OpenAiStreamRequest) chan RespOrErrorUnion[OpenAIPacketType] +} + // for frontend type WshServerCommandMeta struct { CommandType string `json:"commandtype"` @@ -54,7 +86,13 @@ type WshRpcCommandOpts struct { NoResponse bool `json:"noresponse"` } -func HackRpcContextIntoData(dataPtr any, rpcContext wshutil.RpcContext) { +type RpcContext struct { + BlockId string `json:"blockid,omitempty"` + TabId string `json:"tabid,omitempty"` + WindowId string `json:"windowid,omitempty"` +} + +func HackRpcContextIntoData(dataPtr any, rpcContext RpcContext) { dataVal := reflect.ValueOf(dataPtr).Elem() dataType := dataVal.Type() for i := 0; i < dataVal.NumField(); i++ { @@ -141,3 +179,54 @@ type CommandAppendIJsonData struct { type CommandDeleteBlockData struct { BlockId string `json:"blockid" wshcontext:"BlockId"` } + +type WaveEvent struct { + Event string `json:"event"` + Scopes []string `json:"scopes,omitempty"` + Sender string `json:"sender,omitempty"` + Data any `json:"data,omitempty"` +} + +type SubscriptionRequest struct { + Event string `json:"event"` + Scopes []string `json:"scopes,omitempty"` + AllScopes bool `json:"allscopes,omitempty"` +} + +type OpenAiStreamRequest struct { + ClientId string `json:"clientid,omitempty"` + Opts *OpenAIOptsType `json:"opts"` + Prompt []OpenAIPromptMessageType `json:"prompt"` +} + +type OpenAIPromptMessageType struct { + Role string `json:"role"` + Content string `json:"content"` + Name string `json:"name,omitempty"` +} + +type OpenAIOptsType struct { + Model string `json:"model"` + APIToken string `json:"apitoken"` + BaseURL string `json:"baseurl,omitempty"` + MaxTokens int `json:"maxtokens,omitempty"` + MaxChoices int `json:"maxchoices,omitempty"` + Timeout int `json:"timeout,omitempty"` +} + +type OpenAIPacketType struct { + Type string `json:"type"` + Model string `json:"model,omitempty"` + Created int64 `json:"created,omitempty"` + FinishReason string `json:"finish_reason,omitempty"` + Usage *OpenAIUsageType `json:"usage,omitempty"` + Index int `json:"index,omitempty"` + Text string `json:"text,omitempty"` + Error string `json:"error,omitempty"` +} + +type OpenAIUsageType struct { + PromptTokens int `json:"prompt_tokens,omitempty"` + CompletionTokens int `json:"completion_tokens,omitempty"` + TotalTokens int `json:"total_tokens,omitempty"` +} diff --git a/pkg/wshrpc/wshserver/wshserver.go b/pkg/wshrpc/wshserver/wshserver.go index 07fbf0e66..0b9f22124 100644 --- a/pkg/wshrpc/wshserver/wshserver.go +++ b/pkg/wshrpc/wshserver/wshserver.go @@ -11,7 +11,6 @@ import ( "fmt" "io/fs" "log" - "reflect" "strings" "time" @@ -20,45 +19,26 @@ import ( "github.com/wavetermdev/thenextwave/pkg/filestore" "github.com/wavetermdev/thenextwave/pkg/waveai" "github.com/wavetermdev/thenextwave/pkg/waveobj" + "github.com/wavetermdev/thenextwave/pkg/wps" "github.com/wavetermdev/thenextwave/pkg/wshrpc" "github.com/wavetermdev/thenextwave/pkg/wshutil" "github.com/wavetermdev/thenextwave/pkg/wstore" ) -var RespStreamTest_MethodDecl = &WshServerMethodDecl{ - Command: "streamtest", - CommandType: wshutil.RpcType_ResponseStream, - MethodName: "RespStreamTest", - Method: reflect.ValueOf(WshServerImpl.RespStreamTest), - CommandDataType: nil, - DefaultResponseDataType: reflect.TypeOf((int)(0)), -} - -var RespStreamWaveAi_MethodDecl = &WshServerMethodDecl{ - Command: wshrpc.Command_StreamWaveAi, - CommandType: wshutil.RpcType_ResponseStream, - MethodName: "RespStreamWaveAi", - Method: reflect.ValueOf(WshServerImpl.RespStreamWaveAi), - CommandDataType: reflect.TypeOf(waveai.OpenAiStreamRequest{}), - DefaultResponseDataType: reflect.TypeOf(waveai.OpenAIPacketType{}), -} - -var WshServerCommandToDeclMap = map[string]*WshServerMethodDecl{ - wshrpc.Command_Message: GetWshServerMethod(wshrpc.Command_Message, wshutil.RpcType_Call, "MessageCommand", WshServerImpl.MessageCommand), - wshrpc.Command_SetView: GetWshServerMethod(wshrpc.Command_SetView, wshutil.RpcType_Call, "BlockSetViewCommand", WshServerImpl.BlockSetViewCommand), - wshrpc.Command_SetMeta: GetWshServerMethod(wshrpc.Command_SetMeta, wshutil.RpcType_Call, "SetMetaCommand", WshServerImpl.SetMetaCommand), - wshrpc.Command_GetMeta: GetWshServerMethod(wshrpc.Command_GetMeta, wshutil.RpcType_Call, "GetMetaCommand", WshServerImpl.GetMetaCommand), - wshrpc.Command_ResolveIds: GetWshServerMethod(wshrpc.Command_ResolveIds, wshutil.RpcType_Call, "ResolveIdsCommand", WshServerImpl.ResolveIdsCommand), - wshrpc.Command_CreateBlock: GetWshServerMethod(wshrpc.Command_CreateBlock, wshutil.RpcType_Call, "CreateBlockCommand", WshServerImpl.CreateBlockCommand), - wshrpc.Command_Restart: GetWshServerMethod(wshrpc.Command_Restart, wshutil.RpcType_Call, "BlockRestartCommand", WshServerImpl.BlockRestartCommand), - wshrpc.Command_BlockInput: GetWshServerMethod(wshrpc.Command_BlockInput, wshutil.RpcType_Call, "BlockInputCommand", WshServerImpl.BlockInputCommand), - wshrpc.Command_AppendFile: GetWshServerMethod(wshrpc.Command_AppendFile, wshutil.RpcType_Call, "AppendFileCommand", WshServerImpl.AppendFileCommand), - wshrpc.Command_AppendIJson: GetWshServerMethod(wshrpc.Command_AppendIJson, wshutil.RpcType_Call, "AppendIJsonCommand", WshServerImpl.AppendIJsonCommand), - wshrpc.Command_DeleteBlock: GetWshServerMethod(wshrpc.Command_DeleteBlock, wshutil.RpcType_Call, "DeleteBlockCommand", WshServerImpl.DeleteBlockCommand), - wshrpc.Command_WriteFile: GetWshServerMethod(wshrpc.Command_WriteFile, wshutil.RpcType_Call, "WriteFile", WshServerImpl.WriteFile), - wshrpc.Command_ReadFile: GetWshServerMethod(wshrpc.Command_ReadFile, wshutil.RpcType_Call, "ReadFile", WshServerImpl.ReadFile), - wshrpc.Command_StreamWaveAi: RespStreamWaveAi_MethodDecl, - "streamtest": RespStreamTest_MethodDecl, +func (ws *WshServer) AuthenticateCommand(ctx context.Context, data string) error { + w := wshutil.GetWshRpcFromContext(ctx) + if w == nil { + return fmt.Errorf("no wshrpc in context") + } + newCtx, err := wshutil.ValidateAndExtractRpcContextFromToken(data) + if err != nil { + return fmt.Errorf("error validating token: %w", err) + } + if newCtx == nil { + return fmt.Errorf("no context found in jwt token") + } + w.SetRpcContext(*newCtx) + return nil } // for testing @@ -68,7 +48,7 @@ func (ws *WshServer) MessageCommand(ctx context.Context, data wshrpc.CommandMess } // for testing -func (ws *WshServer) RespStreamTest(ctx context.Context) chan wshrpc.RespOrErrorUnion[int] { +func (ws *WshServer) StreamTestCommand(ctx context.Context) chan wshrpc.RespOrErrorUnion[int] { rtn := make(chan wshrpc.RespOrErrorUnion[int]) go func() { for i := 1; i <= 5; i++ { @@ -80,7 +60,7 @@ func (ws *WshServer) RespStreamTest(ctx context.Context) chan wshrpc.RespOrError return rtn } -func (ws *WshServer) RespStreamWaveAi(ctx context.Context, request waveai.OpenAiStreamRequest) chan wshrpc.RespOrErrorUnion[waveai.OpenAIPacketType] { +func (ws *WshServer) StreamWaveAiCommand(ctx context.Context, request wshrpc.OpenAiStreamRequest) chan wshrpc.RespOrErrorUnion[wshrpc.OpenAIPacketType] { if request.Opts.BaseURL == "" && request.Opts.APIToken == "" { return waveai.RunCloudCompletionStream(ctx, request) } @@ -224,7 +204,7 @@ func (ws *WshServer) CreateBlockCommand(ctx context.Context, data wshrpc.Command return &waveobj.ORef{OType: wstore.OType_Block, OID: blockData.OID}, nil } -func (ws *WshServer) BlockSetViewCommand(ctx context.Context, data wshrpc.CommandBlockSetViewData) error { +func (ws *WshServer) SetViewCommand(ctx context.Context, data wshrpc.CommandBlockSetViewData) error { log.Printf("SETVIEW: %s | %q\n", data.BlockId, data.View) ctx = wstore.ContextWithUpdates(ctx) block, err := wstore.DBGet[*wstore.Block](ctx, data.BlockId) @@ -241,7 +221,7 @@ func (ws *WshServer) BlockSetViewCommand(ctx context.Context, data wshrpc.Comman return nil } -func (ws *WshServer) BlockRestartCommand(ctx context.Context, data wshrpc.CommandBlockRestartData) error { +func (ws *WshServer) ControllerRestartCommand(ctx context.Context, data wshrpc.CommandBlockRestartData) error { bc := blockcontroller.GetBlockController(data.BlockId) if bc == nil { return fmt.Errorf("block controller not found for block %q", data.BlockId) @@ -249,7 +229,7 @@ func (ws *WshServer) BlockRestartCommand(ctx context.Context, data wshrpc.Comman return bc.RestartController() } -func (ws *WshServer) BlockInputCommand(ctx context.Context, data wshrpc.CommandBlockInputData) error { +func (ws *WshServer) ControllerInputCommand(ctx context.Context, data wshrpc.CommandBlockInputData) error { bc := blockcontroller.GetBlockController(data.BlockId) if bc == nil { return fmt.Errorf("block controller not found for block %q", data.BlockId) @@ -269,7 +249,7 @@ func (ws *WshServer) BlockInputCommand(ctx context.Context, data wshrpc.CommandB return bc.SendInput(inputUnion) } -func (ws *WshServer) WriteFile(ctx context.Context, data wshrpc.CommandFileData) error { +func (ws *WshServer) FileWriteCommand(ctx context.Context, data wshrpc.CommandFileData) error { dataBuf, err := base64.StdEncoding.DecodeString(data.Data64) if err != nil { return fmt.Errorf("error decoding data64: %w", err) @@ -290,7 +270,7 @@ func (ws *WshServer) WriteFile(ctx context.Context, data wshrpc.CommandFileData) return nil } -func (ws *WshServer) ReadFile(ctx context.Context, data wshrpc.CommandFileData) (string, error) { +func (ws *WshServer) FileReadCommand(ctx context.Context, data wshrpc.CommandFileData) (string, error) { _, dataBuf, err := filestore.WFS.ReadFile(ctx, data.ZoneId, data.FileName) if err != nil { return "", fmt.Errorf("error reading blockfile: %w", err) @@ -298,7 +278,7 @@ func (ws *WshServer) ReadFile(ctx context.Context, data wshrpc.CommandFileData) return base64.StdEncoding.EncodeToString(dataBuf), nil } -func (ws *WshServer) AppendFileCommand(ctx context.Context, data wshrpc.CommandFileData) error { +func (ws *WshServer) FileAppendCommand(ctx context.Context, data wshrpc.CommandFileData) error { dataBuf, err := base64.StdEncoding.DecodeString(data.Data64) if err != nil { return fmt.Errorf("error decoding data64: %w", err) @@ -320,7 +300,7 @@ func (ws *WshServer) AppendFileCommand(ctx context.Context, data wshrpc.CommandF return nil } -func (ws *WshServer) AppendIJsonCommand(ctx context.Context, data wshrpc.CommandAppendIJsonData) error { +func (ws *WshServer) FileAppendIJsonCommand(ctx context.Context, data wshrpc.CommandAppendIJsonData) error { tryCreate := true if data.FileName == blockcontroller.BlockFile_Html && tryCreate { err := filestore.WFS.MakeFile(ctx, data.ZoneId, data.FileName, nil, filestore.FileOptsType{MaxSize: blockcontroller.DefaultHtmlMaxFileSize, IJson: true}) @@ -378,3 +358,46 @@ func (ws *WshServer) DeleteBlockCommand(ctx context.Context, data wshrpc.Command sendWStoreUpdatesToEventBus(updates) return nil } + +func (ws *WshServer) EventRecvCommand(ctx context.Context, data wshrpc.WaveEvent) error { + return nil +} + +func (ws *WshServer) EventPublishCommand(ctx context.Context, data wshrpc.WaveEvent) error { + wrpc := wshutil.GetWshRpcFromContext(ctx) + if wrpc == nil { + return fmt.Errorf("no wshrpc in context") + } + if data.Sender == "" { + data.Sender = wrpc.ClientId() + } + wps.Broker.Publish(data) + return nil +} + +func (ws *WshServer) EventSubCommand(ctx context.Context, data wshrpc.SubscriptionRequest) error { + wrpc := wshutil.GetWshRpcFromContext(ctx) + if wrpc == nil { + return fmt.Errorf("no wshrpc in context") + } + wps.Broker.Subscribe(wrpc, data) + return nil +} + +func (ws *WshServer) EventUnsubCommand(ctx context.Context, data wshrpc.SubscriptionRequest) error { + wrpc := wshutil.GetWshRpcFromContext(ctx) + if wrpc == nil { + return fmt.Errorf("no wshrpc in context") + } + wps.Broker.Unsubscribe(wrpc, data) + return nil +} + +func (ws *WshServer) EventUnsubAllCommand(ctx context.Context) error { + wrpc := wshutil.GetWshRpcFromContext(ctx) + if wrpc == nil { + return fmt.Errorf("no wshrpc in context") + } + wps.Broker.UnsubscribeAll(wrpc) + return nil +} diff --git a/pkg/wshrpc/wshserver/wshserverutil.go b/pkg/wshrpc/wshserver/wshserverutil.go index 461c1a33b..dc876f73d 100644 --- a/pkg/wshrpc/wshserver/wshserverutil.go +++ b/pkg/wshrpc/wshserver/wshserverutil.go @@ -10,9 +10,7 @@ import ( "net" "os" "reflect" - "time" - "github.com/golang-jwt/jwt/v5" "github.com/wavetermdev/thenextwave/pkg/util/utilfn" "github.com/wavetermdev/thenextwave/pkg/wavebase" "github.com/wavetermdev/thenextwave/pkg/wshrpc" @@ -41,6 +39,7 @@ type WshServerMethodDecl struct { var WshServerImpl = WshServer{} var contextRType = reflect.TypeOf((*context.Context)(nil)).Elem() +var wshCommandDeclMap = wshrpc.GenerateWshCommandDeclMap() func GetWshServerMethod(command string, commandType string, methodName string, methodFunc any) *WshServerMethodDecl { methodVal := reflect.ValueOf(methodFunc) @@ -55,12 +54,16 @@ func GetWshServerMethod(command string, commandType string, methodName string, m if methodType.NumOut() > 1 { defResponseType = methodType.Out(0) } + var cdataType reflect.Type + if methodType.NumIn() > 1 { + cdataType = methodType.In(1) + } rtn := &WshServerMethodDecl{ Command: command, CommandType: commandType, MethodName: methodName, Method: methodVal, - CommandDataType: methodType.In(1), + CommandDataType: cdataType, DefaultResponseDataType: defResponseType, } return rtn @@ -89,7 +92,7 @@ func decodeRtnVals(rtnVals []reflect.Value) (any, error) { func mainWshServerHandler(handler *wshutil.RpcResponseHandler) bool { command := handler.GetCommand() - methodDecl := WshServerCommandToDeclMap[command] + methodDecl := wshCommandDeclMap[command] if methodDecl == nil { handler.SendResponseError(fmt.Errorf("command %q not found", command)) return true @@ -106,8 +109,18 @@ func mainWshServerHandler(handler *wshutil.RpcResponseHandler) bool { wshrpc.HackRpcContextIntoData(commandData, handler.GetRpcContext()) callParams = append(callParams, reflect.ValueOf(commandData).Elem()) } - if methodDecl.CommandType == wshutil.RpcType_Call { - rtnVals := methodDecl.Method.Call(callParams) + implVal := reflect.ValueOf(&WshServerImpl) + implMethod := implVal.MethodByName(methodDecl.MethodName) + if !implMethod.IsValid() { + if !handler.NeedsResponse() { + // we also send an out of band message here since this is likely unexpected and will require debugging + handler.SendMessage(fmt.Sprintf("command %q method %q not found", handler.GetCommand(), methodDecl.MethodName)) + } + handler.SendResponseError(fmt.Errorf("method %q not found", methodDecl.MethodName)) + return true + } + if methodDecl.CommandType == wshrpc.RpcType_Call { + rtnVals := implMethod.Call(callParams) rtnData, rtnErr := decodeRtnVals(rtnVals) if rtnErr != nil { handler.SendResponseError(rtnErr) @@ -115,8 +128,8 @@ func mainWshServerHandler(handler *wshutil.RpcResponseHandler) bool { } handler.SendResponse(rtnData, true) return true - } else if methodDecl.CommandType == wshutil.RpcType_ResponseStream { - rtnVals := methodDecl.Method.Call(callParams) + } else if methodDecl.CommandType == wshrpc.RpcType_ResponseStream { + rtnVals := implMethod.Call(callParams) rtnChVal := rtnVals[0] if rtnChVal.IsNil() { handler.SendResponse(nil, true) @@ -163,7 +176,7 @@ func runWshRpcWithStream(conn net.Conn) { outputCh := make(chan []byte, DefaultOutputChSize) go wshutil.AdaptMsgChToStream(outputCh, conn) go wshutil.AdaptStreamToMsgCh(conn, inputCh) - wshutil.MakeWshRpc(inputCh, outputCh, wshutil.RpcContext{}, mainWshServerHandler) + wshutil.MakeWshRpc(inputCh, outputCh, wshrpc.RpcContext{}, mainWshServerHandler) } func RunWshRpcOverListener(listener net.Listener) { @@ -179,82 +192,6 @@ func RunWshRpcOverListener(listener net.Listener) { }() } -func MakeClientJWTToken(rpcCtx wshutil.RpcContext, sockName string) (string, error) { - claims := jwt.MapClaims{} - claims["iat"] = time.Now().Unix() - claims["iss"] = "waveterm" - claims["sock"] = sockName - claims["exp"] = time.Now().Add(time.Hour * 24 * 365).Unix() - if rpcCtx.BlockId != "" { - claims["blockid"] = rpcCtx.BlockId - } - if rpcCtx.TabId != "" { - claims["tabid"] = rpcCtx.TabId - } - if rpcCtx.WindowId != "" { - claims["windowid"] = rpcCtx.WindowId - } - token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims) - tokenStr, err := token.SignedString([]byte(wavebase.JwtSecret)) - if err != nil { - return "", fmt.Errorf("error signing token: %w", err) - } - return tokenStr, nil -} - -func ValidateAndExtractRpcContextFromToken(tokenStr string) (wshutil.RpcContext, error) { - parser := jwt.NewParser(jwt.WithValidMethods([]string{jwt.SigningMethodHS256.Name})) - token, err := parser.Parse(tokenStr, func(token *jwt.Token) (interface{}, error) { - return []byte(wavebase.JwtSecret), nil - }) - if err != nil { - return wshutil.RpcContext{}, fmt.Errorf("error parsing token: %w", err) - } - claims, ok := token.Claims.(jwt.MapClaims) - if !ok { - return wshutil.RpcContext{}, fmt.Errorf("error getting claims from token") - } - // validate "exp" claim - if exp, ok := claims["exp"].(float64); ok { - if int64(exp) < time.Now().Unix() { - return wshutil.RpcContext{}, fmt.Errorf("token has expired") - } - } else { - return wshutil.RpcContext{}, fmt.Errorf("exp claim is missing or invalid") - } - // validate "iss" claim - if iss, ok := claims["iss"].(string); ok { - if iss != "waveterm" { - return wshutil.RpcContext{}, fmt.Errorf("unexpected issuer: %s", iss) - } - } else { - return wshutil.RpcContext{}, fmt.Errorf("iss claim is missing or invalid") - } - rpcCtx := wshutil.RpcContext{} - rpcCtx.BlockId = claims["blockid"].(string) - rpcCtx.TabId = claims["tabid"].(string) - rpcCtx.WindowId = claims["windowid"].(string) - return rpcCtx, nil -} - -func ExtractUnverifiedSocketName(tokenStr string) (string, error) { - // this happens on the client who does not have access to the secret key - // we want to read the claims without validating the signature - token, _, err := new(jwt.Parser).ParseUnverified(tokenStr, jwt.MapClaims{}) - if err != nil { - return "", fmt.Errorf("error parsing token: %w", err) - } - claims, ok := token.Claims.(jwt.MapClaims) - if !ok { - return "", fmt.Errorf("error getting claims from token") - } - sockName, ok := claims["sock"].(string) - if !ok { - return "", fmt.Errorf("sock claim is missing or invalid") - } - return sockName, nil -} - func RunDomainSocketWshServer() error { sockName := wavebase.GetDomainSocketName() listener, err := MakeUnixListener(sockName) @@ -266,6 +203,6 @@ func RunDomainSocketWshServer() error { return nil } -func MakeWshServer(inputCh chan []byte, outputCh chan []byte, initialCtx wshutil.RpcContext) { +func MakeWshServer(inputCh chan []byte, outputCh chan []byte, initialCtx wshrpc.RpcContext) { wshutil.MakeWshRpc(inputCh, outputCh, initialCtx, mainWshServerHandler) } diff --git a/pkg/wshutil/wshrpc.go b/pkg/wshutil/wshrpc.go index 0ad9b0f4c..da677325a 100644 --- a/pkg/wshutil/wshrpc.go +++ b/pkg/wshutil/wshrpc.go @@ -15,19 +15,13 @@ import ( "time" "github.com/google/uuid" + "github.com/wavetermdev/thenextwave/pkg/wshrpc" ) const DefaultTimeoutMs = 5000 const RespChSize = 32 const DefaultMessageChSize = 32 -const ( - RpcType_Call = "call" // single response (regular rpc) - RpcType_ResponseStream = "responsestream" // stream of responses (streaming rpc) - RpcType_StreamingRequest = "streamingrequest" // streaming request - RpcType_Complex = "complex" // streaming request/response -) - type ResponseFnType = func(any) error // returns true if handler is complete, false for an async handler @@ -115,17 +109,12 @@ func (r *RpcMessage) Validate() error { return fmt.Errorf("invalid packet: must have command, reqid, or resid set") } -type RpcContext struct { - BlockId string `json:"blockid,omitempty"` - TabId string `json:"tabid,omitempty"` - WindowId string `json:"windowid,omitempty"` -} - type WshRpc struct { Lock *sync.Mutex + clientId string InputCh chan []byte OutputCh chan []byte - RpcContext *atomic.Pointer[RpcContext] + RpcContext *atomic.Pointer[wshrpc.RpcContext] RpcMap map[string]*rpcData HandlerFn CommandHandlerFnType @@ -139,13 +128,14 @@ type rpcData struct { // oscEsc is the OSC escape sequence to use for *sending* messages // closes outputCh when inputCh is closed/done -func MakeWshRpc(inputCh chan []byte, outputCh chan []byte, rpcCtx RpcContext, commandHandlerFn CommandHandlerFnType) *WshRpc { +func MakeWshRpc(inputCh chan []byte, outputCh chan []byte, rpcCtx wshrpc.RpcContext, commandHandlerFn CommandHandlerFnType) *WshRpc { rtn := &WshRpc{ Lock: &sync.Mutex{}, + clientId: uuid.New().String(), InputCh: inputCh, OutputCh: outputCh, RpcMap: make(map[string]*rpcData), - RpcContext: &atomic.Pointer[RpcContext]{}, + RpcContext: &atomic.Pointer[wshrpc.RpcContext]{}, HandlerFn: commandHandlerFn, ResponseHandlerMap: make(map[string]*RpcResponseHandler), } @@ -154,12 +144,30 @@ func MakeWshRpc(inputCh chan []byte, outputCh chan []byte, rpcCtx RpcContext, co return rtn } -func (w *WshRpc) GetRpcContext() RpcContext { +func (w *WshRpc) ClientId() string { + return w.clientId +} + +func (w *WshRpc) SendEvent(event wshrpc.WaveEvent) { + // for wps compatibility + msg := &RpcMessage{ + Command: wshrpc.Command_EventPublish, + Data: event, + } + barr, err := json.Marshal(msg) + if err != nil { + log.Printf("error marshalling event: %v\n", err) + return + } + w.OutputCh <- barr +} + +func (w *WshRpc) GetRpcContext() wshrpc.RpcContext { rtnPtr := w.RpcContext.Load() return *rtnPtr } -func (w *WshRpc) SetRpcContext(ctx RpcContext) { +func (w *WshRpc) SetRpcContext(ctx wshrpc.RpcContext) { w.RpcContext.Store(&ctx) } @@ -399,7 +407,7 @@ type RpcResponseHandler struct { reqId string command string commandData any - rpcCtx RpcContext + rpcCtx wshrpc.RpcContext canceled *atomic.Bool // canceled by requestor done *atomic.Bool } @@ -416,10 +424,25 @@ func (handler *RpcResponseHandler) GetCommandRawData() any { return handler.commandData } -func (handler *RpcResponseHandler) GetRpcContext() RpcContext { +func (handler *RpcResponseHandler) GetRpcContext() wshrpc.RpcContext { return handler.rpcCtx } +func (handler *RpcResponseHandler) NeedsResponse() bool { + return handler.reqId != "" +} + +func (handler *RpcResponseHandler) SendMessage(msg string) { + rpcMsg := &RpcMessage{ + Command: wshrpc.Command_Message, + Data: wshrpc.CommandMessageData{ + Message: msg, + }, + } + msgBytes, _ := json.Marshal(rpcMsg) // will never fail + handler.w.OutputCh <- msgBytes +} + func (handler *RpcResponseHandler) SendResponse(data any, done bool) error { if handler.reqId == "" { return nil // no response expected diff --git a/pkg/wshutil/wshutil.go b/pkg/wshutil/wshutil.go index adf8b5cb8..e1919e5d8 100644 --- a/pkg/wshutil/wshutil.go +++ b/pkg/wshutil/wshutil.go @@ -14,7 +14,11 @@ import ( "sync" "sync/atomic" "syscall" + "time" + "github.com/golang-jwt/jwt/v5" + "github.com/wavetermdev/thenextwave/pkg/wavebase" + "github.com/wavetermdev/thenextwave/pkg/wshrpc" "golang.org/x/term" ) @@ -180,7 +184,7 @@ func SetupTerminalRpcClient(handlerFn func(*RpcResponseHandler) bool) (*WshRpc, messageCh := make(chan []byte, 32) outputCh := make(chan []byte, 32) ptyBuf := MakePtyBuffer(WaveServerOSCPrefix, os.Stdin, messageCh) - rpcClient := MakeWshRpc(messageCh, outputCh, RpcContext{}, handlerFn) + rpcClient := MakeWshRpc(messageCh, outputCh, wshrpc.RpcContext{}, handlerFn) go func() { for msg := range outputCh { barr := EncodeWaveOSCBytes(WaveOSC, msg) @@ -189,3 +193,79 @@ func SetupTerminalRpcClient(handlerFn func(*RpcResponseHandler) bool) (*WshRpc, }() return rpcClient, ptyBuf } + +func MakeClientJWTToken(rpcCtx wshrpc.RpcContext, sockName string) (string, error) { + claims := jwt.MapClaims{} + claims["iat"] = time.Now().Unix() + claims["iss"] = "waveterm" + claims["sock"] = sockName + claims["exp"] = time.Now().Add(time.Hour * 24 * 365).Unix() + if rpcCtx.BlockId != "" { + claims["blockid"] = rpcCtx.BlockId + } + if rpcCtx.TabId != "" { + claims["tabid"] = rpcCtx.TabId + } + if rpcCtx.WindowId != "" { + claims["windowid"] = rpcCtx.WindowId + } + token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims) + tokenStr, err := token.SignedString([]byte(wavebase.JwtSecret)) + if err != nil { + return "", fmt.Errorf("error signing token: %w", err) + } + return tokenStr, nil +} + +func ValidateAndExtractRpcContextFromToken(tokenStr string) (*wshrpc.RpcContext, error) { + parser := jwt.NewParser(jwt.WithValidMethods([]string{jwt.SigningMethodHS256.Name})) + token, err := parser.Parse(tokenStr, func(token *jwt.Token) (interface{}, error) { + return []byte(wavebase.JwtSecret), nil + }) + if err != nil { + return nil, fmt.Errorf("error parsing token: %w", err) + } + claims, ok := token.Claims.(jwt.MapClaims) + if !ok { + return nil, fmt.Errorf("error getting claims from token") + } + // validate "exp" claim + if exp, ok := claims["exp"].(float64); ok { + if int64(exp) < time.Now().Unix() { + return nil, fmt.Errorf("token has expired") + } + } else { + return nil, fmt.Errorf("exp claim is missing or invalid") + } + // validate "iss" claim + if iss, ok := claims["iss"].(string); ok { + if iss != "waveterm" { + return nil, fmt.Errorf("unexpected issuer: %s", iss) + } + } else { + return nil, fmt.Errorf("iss claim is missing or invalid") + } + rpcCtx := &wshrpc.RpcContext{} + rpcCtx.BlockId = claims["blockid"].(string) + rpcCtx.TabId = claims["tabid"].(string) + rpcCtx.WindowId = claims["windowid"].(string) + return rpcCtx, nil +} + +func ExtractUnverifiedSocketName(tokenStr string) (string, error) { + // this happens on the client who does not have access to the secret key + // we want to read the claims without validating the signature + token, _, err := new(jwt.Parser).ParseUnverified(tokenStr, jwt.MapClaims{}) + if err != nil { + return "", fmt.Errorf("error parsing token: %w", err) + } + claims, ok := token.Claims.(jwt.MapClaims) + if !ok { + return "", fmt.Errorf("error getting claims from token") + } + sockName, ok := claims["sock"].(string) + if !ok { + return "", fmt.Errorf("sock claim is missing or invalid") + } + return sockName, nil +} diff --git a/yarn.lock b/yarn.lock index 2b7daee74..c426d29c9 100644 --- a/yarn.lock +++ b/yarn.lock @@ -3546,6 +3546,31 @@ __metadata: languageName: node linkType: hard +"@types/color-convert@npm:*": + version: 2.0.3 + resolution: "@types/color-convert@npm:2.0.3" + dependencies: + "@types/color-name": "npm:*" + checksum: 10c0/a5870547660f426cddd76b54e942703e29c3b43fc26b1ba567e10b9707d144b7d8863e0af7affd9c3391815c06582571f43835c71ede270a6c58949155d18b77 + languageName: node + linkType: hard + +"@types/color-name@npm:*": + version: 1.1.4 + resolution: "@types/color-name@npm:1.1.4" + checksum: 10c0/11a5b67408a53a972fa98e4bbe2b0ff4cb74a3b3abb5f250cb5ec7b055a45aa8e00ddaf39b8327ef683ede9b2ff9b3ee9d25cd708d12b1b6a9aee5e8e6002920 + languageName: node + linkType: hard + +"@types/color@npm:^3.0.6": + version: 3.0.6 + resolution: "@types/color@npm:3.0.6" + dependencies: + "@types/color-convert": "npm:*" + checksum: 10c0/79267eeb67f9d11761aecee36bb1503fb8daa699b9ae7e036fc23a74380e5b130c5c0f6d7adafabba89256e46f36ee4d3e28e0ac7e107e8258550eae7d091acf + languageName: node + linkType: hard + "@types/connect@npm:*": version: 3.4.38 resolution: "@types/connect@npm:3.4.38" @@ -5258,7 +5283,7 @@ __metadata: languageName: node linkType: hard -"color-string@npm:^1.6.0": +"color-string@npm:^1.6.0, color-string@npm:^1.9.0": version: 1.9.1 resolution: "color-string@npm:1.9.1" dependencies: @@ -5278,6 +5303,16 @@ __metadata: languageName: node linkType: hard +"color@npm:^4.2.3": + version: 4.2.3 + resolution: "color@npm:4.2.3" + dependencies: + color-convert: "npm:^2.0.1" + color-string: "npm:^1.9.0" + checksum: 10c0/7fbe7cfb811054c808349de19fb380252e5e34e61d7d168ec3353e9e9aacb1802674bddc657682e4e9730c2786592a4de6f8283e7e0d3870b829bb0b7b2f6118 + languageName: node + linkType: hard + "colorspace@npm:1.1.x": version: 1.1.4 resolution: "colorspace@npm:1.1.4" @@ -11768,6 +11803,7 @@ __metadata: "@table-nav/core": "npm:^0.0.7" "@table-nav/react": "npm:^0.0.7" "@tanstack/react-table": "npm:^8.19.3" + "@types/color": "npm:^3.0.6" "@types/electron": "npm:^1.6.10" "@types/node": "npm:^20.14.12" "@types/papaparse": "npm:^5" @@ -11783,6 +11819,7 @@ __metadata: "@xterm/xterm": "npm:^5.5.0" base64-js: "npm:^1.5.1" clsx: "npm:^2.1.1" + color: "npm:^4.2.3" dayjs: "npm:^1.11.12" electron: "npm:^31.3.0" electron-builder: "npm:^24.13.3"