From 844451ea0de4686dcb00327700529b8ef58dd576 Mon Sep 17 00:00:00 2001 From: Mike Sawka Date: Tue, 13 Aug 2024 16:52:35 -0700 Subject: [PATCH] wsh routing + proxy (#224) lots of changes, including: * source/route to rpcmessage * rpcproxy * wshrouter * bug fixing * wps uses routeids not clients --- .../main-generatewshclient.go | 4 +- cmd/server/main-server.go | 11 +- cmd/wsh/cmd/wshcmd-deleteblock.go | 2 +- cmd/wsh/cmd/wshcmd-getmeta.go | 2 +- cmd/wsh/cmd/wshcmd-readfile.go | 2 +- cmd/wsh/cmd/wshcmd-root.go | 8 +- cmd/wsh/cmd/wshcmd-setmeta.go | 2 +- cmd/wsh/cmd/wshcmd-view.go | 2 +- frontend/app/store/wos.ts | 10 +- frontend/app/store/wshserver.ts | 50 ++--- frontend/types/gotypes.d.ts | 16 +- pkg/blockcontroller/blockcontroller.go | 17 +- pkg/tsgen/tsgen.go | 10 +- pkg/web/ws.go | 46 ++-- pkg/wps/wps.go | 105 +++++---- pkg/wshrpc/wshclient/wshclient.go | 50 ++--- pkg/wshrpc/wshclient/wshclientutil.go | 14 +- pkg/wshrpc/wshrpctypes.go | 21 +- pkg/wshrpc/wshserver/wshserver.go | 82 +++---- pkg/wshrpc/wshserver/wshserverutil.go | 107 +++------ pkg/wshutil/wshadapter.go | 34 ++- pkg/wshutil/wshproxy.go | 151 +++++++++++++ pkg/wshutil/wshrouter.go | 212 ++++++++++++++++++ pkg/wshutil/wshrpc.go | 73 +++++- pkg/wshutil/wshutil.go | 8 - 25 files changed, 716 insertions(+), 323 deletions(-) create mode 100644 pkg/wshutil/wshproxy.go create mode 100644 pkg/wshutil/wshrouter.go diff --git a/cmd/generatewshclient/main-generatewshclient.go b/cmd/generatewshclient/main-generatewshclient.go index 5c6ee5d9b..358ef90c2 100644 --- a/cmd/generatewshclient/main-generatewshclient.go +++ b/cmd/generatewshclient/main-generatewshclient.go @@ -23,7 +23,7 @@ func genMethod_ResponseStream(fd *os.File, methodDecl *wshrpc.WshRpcMethodDecl) if methodDecl.DefaultResponseDataType != nil { respType = methodDecl.DefaultResponseDataType.String() } - fmt.Fprintf(fd, "func %s(w *wshutil.WshRpc%s, opts *wshrpc.WshRpcCommandOpts) chan wshrpc.RespOrErrorUnion[%s] {\n", methodDecl.MethodName, dataType, respType) + fmt.Fprintf(fd, "func %s(w *wshutil.WshRpc%s, opts *wshrpc.RpcOpts) chan wshrpc.RespOrErrorUnion[%s] {\n", methodDecl.MethodName, dataType, respType) fmt.Fprintf(fd, " return sendRpcRequestResponseStreamHelper[%s](w, %q, %s, opts)\n", respType, methodDecl.Command, dataVarName) fmt.Fprintf(fd, "}\n\n") } @@ -44,7 +44,7 @@ func genMethod_Call(fd *os.File, methodDecl *wshrpc.WshRpcMethodDecl) { respName = "resp" tParamVal = methodDecl.DefaultResponseDataType.String() } - fmt.Fprintf(fd, "func %s(w *wshutil.WshRpc%s, opts *wshrpc.WshRpcCommandOpts) %s {\n", methodDecl.MethodName, dataType, returnType) + fmt.Fprintf(fd, "func %s(w *wshutil.WshRpc%s, opts *wshrpc.RpcOpts) %s {\n", methodDecl.MethodName, dataType, returnType) fmt.Fprintf(fd, " %s, err := sendRpcRequestCallHelper[%s](w, %q, %s, opts)\n", respName, tParamVal, methodDecl.Command, dataVarName) if methodDecl.DefaultResponseDataType != nil { fmt.Fprintf(fd, " return resp, err\n") diff --git a/cmd/server/main-server.go b/cmd/server/main-server.go index 92c2d77ff..c46461357 100644 --- a/cmd/server/main-server.go +++ b/cmd/server/main-server.go @@ -17,7 +17,6 @@ import ( "syscall" "time" - "github.com/wavetermdev/thenextwave/pkg/blockcontroller" "github.com/wavetermdev/thenextwave/pkg/filestore" "github.com/wavetermdev/thenextwave/pkg/service" "github.com/wavetermdev/thenextwave/pkg/telemetry" @@ -27,6 +26,7 @@ import ( "github.com/wavetermdev/thenextwave/pkg/wconfig" "github.com/wavetermdev/thenextwave/pkg/web" "github.com/wavetermdev/thenextwave/pkg/wshrpc/wshserver" + "github.com/wavetermdev/thenextwave/pkg/wshutil" "github.com/wavetermdev/thenextwave/pkg/wstore" ) @@ -147,11 +147,15 @@ func shutdownActivityUpdate() { } } +func createMainWshClient() { + rpc := wshserver.GetMainRpcClient() + wshutil.DefaultRouter.RegisterRoute("wavesrv", rpc) + wshutil.DefaultRouter.SetDefaultRoute("wavesrv") +} + func main() { log.SetFlags(log.LstdFlags | log.Lmicroseconds) log.SetPrefix("[wavesrv] ") - blockcontroller.WshServerFactoryFn = wshserver.MakeWshServer - web.WshServerFactoryFn = wshserver.MakeWshServer wavebase.WaveVersion = WaveVersion wavebase.BuildTime = BuildTime @@ -200,6 +204,7 @@ func main() { log.Printf("error ensuring initial data: %v\n", err) return } + createMainWshClient() installShutdownSignalHandlers() startupActivityUpdate() go stdinReadWatch() diff --git a/cmd/wsh/cmd/wshcmd-deleteblock.go b/cmd/wsh/cmd/wshcmd-deleteblock.go index cc304920b..c4e72d069 100644 --- a/cmd/wsh/cmd/wshcmd-deleteblock.go +++ b/cmd/wsh/cmd/wshcmd-deleteblock.go @@ -42,7 +42,7 @@ func deleteBlockRun(cmd *cobra.Command, args []string) { deleteBlockData := &wshrpc.CommandDeleteBlockData{ BlockId: fullORef.OID, } - _, err = RpcClient.SendRpcRequest(wshrpc.Command_DeleteBlock, deleteBlockData, 2000) + _, err = RpcClient.SendRpcRequest(wshrpc.Command_DeleteBlock, deleteBlockData, &wshrpc.RpcOpts{Timeout: 2000}) if err != nil { WriteStderr("[error] deleting block: %v\n", err) return diff --git a/cmd/wsh/cmd/wshcmd-getmeta.go b/cmd/wsh/cmd/wshcmd-getmeta.go index 1f5efabaf..522f49ca0 100644 --- a/cmd/wsh/cmd/wshcmd-getmeta.go +++ b/cmd/wsh/cmd/wshcmd-getmeta.go @@ -38,7 +38,7 @@ func getMetaRun(cmd *cobra.Command, args []string) { WriteStderr("[error] resolving oref: %v\n", err) return } - resp, err := wshclient.GetMetaCommand(RpcClient, wshrpc.CommandGetMetaData{ORef: *fullORef}, &wshrpc.WshRpcCommandOpts{Timeout: 2000}) + resp, err := wshclient.GetMetaCommand(RpcClient, wshrpc.CommandGetMetaData{ORef: *fullORef}, &wshrpc.RpcOpts{Timeout: 2000}) if err != nil { WriteStderr("[error] getting metadata: %v\n", err) return diff --git a/cmd/wsh/cmd/wshcmd-readfile.go b/cmd/wsh/cmd/wshcmd-readfile.go index 2d240301d..aaa5cee49 100644 --- a/cmd/wsh/cmd/wshcmd-readfile.go +++ b/cmd/wsh/cmd/wshcmd-readfile.go @@ -38,7 +38,7 @@ func runReadFile(cmd *cobra.Command, args []string) { WriteStderr("error resolving oref: %v\n", err) return } - resp64, err := wshclient.FileReadCommand(RpcClient, wshrpc.CommandFileData{ZoneId: fullORef.OID, FileName: args[1]}, &wshrpc.WshRpcCommandOpts{Timeout: 5000}) + resp64, err := wshclient.FileReadCommand(RpcClient, wshrpc.CommandFileData{ZoneId: fullORef.OID, FileName: args[1]}, &wshrpc.RpcOpts{Timeout: 5000}) if err != nil { WriteStderr("[error] reading file: %v\n", err) return diff --git a/cmd/wsh/cmd/wshcmd-root.go b/cmd/wsh/cmd/wshcmd-root.go index 93c14b8c4..f9880ac5a 100644 --- a/cmd/wsh/cmd/wshcmd-root.go +++ b/cmd/wsh/cmd/wshcmd-root.go @@ -39,7 +39,7 @@ func extraShutdownFn() { cmd := &wshrpc.CommandSetMetaData{ Meta: map[string]any{"term:mode": nil}, } - RpcClient.SendCommand(wshrpc.Command_SetMeta, cmd) + RpcClient.SendCommand(wshrpc.Command_SetMeta, cmd, nil) time.Sleep(10 * time.Millisecond) } } @@ -77,7 +77,7 @@ func setupRpcClient(serverImpl wshutil.ServerImpl) error { if err != nil { return fmt.Errorf("error setting up domain socket rpc client: %v", err) } - wshclient.AuthenticateCommand(RpcClient, jwtToken, &wshrpc.WshRpcCommandOpts{NoResponse: true}) + wshclient.AuthenticateCommand(RpcClient, jwtToken, &wshrpc.RpcOpts{NoResponse: true}) // note we don't modify WrappedStdin here (just use os.Stdin) return nil } @@ -87,7 +87,7 @@ func setTermHtmlMode() { cmd := &wshrpc.CommandSetMetaData{ Meta: map[string]any{"term:mode": "html"}, } - err := RpcClient.SendCommand(wshrpc.Command_SetMeta, cmd) + err := RpcClient.SendCommand(wshrpc.Command_SetMeta, cmd, nil) if err != nil { fmt.Fprintf(os.Stderr, "Error setting html mode: %v\r\n", err) } @@ -136,7 +136,7 @@ func resolveSimpleId(id string) (*waveobj.ORef, error) { } return &orefObj, nil } - rtnData, err := wshclient.ResolveIdsCommand(RpcClient, wshrpc.CommandResolveIdsData{Ids: []string{id}}, &wshrpc.WshRpcCommandOpts{Timeout: 2000}) + rtnData, err := wshclient.ResolveIdsCommand(RpcClient, wshrpc.CommandResolveIdsData{Ids: []string{id}}, &wshrpc.RpcOpts{Timeout: 2000}) if err != nil { return nil, fmt.Errorf("error resolving ids: %v", err) } diff --git a/cmd/wsh/cmd/wshcmd-setmeta.go b/cmd/wsh/cmd/wshcmd-setmeta.go index 9dfe02191..aec41663c 100644 --- a/cmd/wsh/cmd/wshcmd-setmeta.go +++ b/cmd/wsh/cmd/wshcmd-setmeta.go @@ -83,7 +83,7 @@ func setMetaRun(cmd *cobra.Command, args []string) { ORef: *fullORef, Meta: meta, } - _, err = RpcClient.SendRpcRequest(wshrpc.Command_SetMeta, setMetaWshCmd, 2000) + _, err = RpcClient.SendRpcRequest(wshrpc.Command_SetMeta, setMetaWshCmd, &wshrpc.RpcOpts{Timeout: 2000}) if err != nil { WriteStderr("[error] setting metadata: %v\n", err) return diff --git a/cmd/wsh/cmd/wshcmd-view.go b/cmd/wsh/cmd/wshcmd-view.go index bad220168..b3dcc368e 100644 --- a/cmd/wsh/cmd/wshcmd-view.go +++ b/cmd/wsh/cmd/wshcmd-view.go @@ -64,7 +64,7 @@ func viewRun(cmd *cobra.Command, args []string) { }, } } - _, err := RpcClient.SendRpcRequest(wshrpc.Command_CreateBlock, wshCmd, 2000) + _, err := RpcClient.SendRpcRequest(wshrpc.Command_CreateBlock, wshCmd, &wshrpc.RpcOpts{Timeout: 2000}) if err != nil { WriteStderr("[error] running view command: %v\r\n", err) return diff --git a/frontend/app/store/wos.ts b/frontend/app/store/wos.ts index 19febf6a3..8f1d560a9 100644 --- a/frontend/app/store/wos.ts +++ b/frontend/app/store/wos.ts @@ -111,7 +111,7 @@ function callBackendService(service: string, method: string, args: any[], noUICo function wshServerRpcHelper_responsestream( command: string, data: any, - opts: WshRpcCommandOpts + opts: RpcOpts ): AsyncGenerator { if (opts?.noresponse) { throw new Error("noresponse not supported for responsestream calls"); @@ -124,11 +124,14 @@ function wshServerRpcHelper_responsestream( if (opts?.timeout) { msg.timeout = opts.timeout; } + if (opts?.route) { + msg.route = opts.route; + } const rpcGen = sendRpcCommand(msg); return rpcGen; } -function wshServerRpcHelper_call(command: string, data: any, opts: WshRpcCommandOpts): Promise { +function wshServerRpcHelper_call(command: string, data: any, opts: RpcOpts): Promise { const msg: RpcMessage = { command: command, data: data, @@ -139,6 +142,9 @@ function wshServerRpcHelper_call(command: string, data: any, opts: WshRpcCommand if (opts?.timeout) { msg.timeout = opts.timeout; } + if (opts?.route) { + msg.route = opts.route; + } const rpcGen = sendRpcCommand(msg); if (rpcGen == null) { return null; diff --git a/frontend/app/store/wshserver.ts b/frontend/app/store/wshserver.ts index 6e51946e8..9cce214f6 100644 --- a/frontend/app/store/wshserver.ts +++ b/frontend/app/store/wshserver.ts @@ -8,127 +8,127 @@ import * as WOS from "./wos"; // WshServerCommandToDeclMap class WshServerType { // command "authenticate" [call] - AuthenticateCommand(data: string, opts?: WshRpcCommandOpts): Promise { + AuthenticateCommand(data: string, opts?: RpcOpts): Promise { return WOS.wshServerRpcHelper_call("authenticate", data, opts); } // command "controllerinput" [call] - ControllerInputCommand(data: CommandBlockInputData, opts?: WshRpcCommandOpts): Promise { + ControllerInputCommand(data: CommandBlockInputData, opts?: RpcOpts): Promise { return WOS.wshServerRpcHelper_call("controllerinput", data, opts); } // command "controllerrestart" [call] - ControllerRestartCommand(data: CommandBlockRestartData, opts?: WshRpcCommandOpts): Promise { + ControllerRestartCommand(data: CommandBlockRestartData, opts?: RpcOpts): Promise { return WOS.wshServerRpcHelper_call("controllerrestart", data, opts); } // command "createblock" [call] - CreateBlockCommand(data: CommandCreateBlockData, opts?: WshRpcCommandOpts): Promise { + CreateBlockCommand(data: CommandCreateBlockData, opts?: RpcOpts): Promise { return WOS.wshServerRpcHelper_call("createblock", data, opts); } // command "deleteblock" [call] - DeleteBlockCommand(data: CommandDeleteBlockData, opts?: WshRpcCommandOpts): Promise { + DeleteBlockCommand(data: CommandDeleteBlockData, opts?: RpcOpts): Promise { return WOS.wshServerRpcHelper_call("deleteblock", data, opts); } // command "eventpublish" [call] - EventPublishCommand(data: WaveEvent, opts?: WshRpcCommandOpts): Promise { + EventPublishCommand(data: WaveEvent, opts?: RpcOpts): Promise { return WOS.wshServerRpcHelper_call("eventpublish", data, opts); } // command "eventrecv" [call] - EventRecvCommand(data: WaveEvent, opts?: WshRpcCommandOpts): Promise { + EventRecvCommand(data: WaveEvent, opts?: RpcOpts): Promise { return WOS.wshServerRpcHelper_call("eventrecv", data, opts); } // command "eventsub" [call] - EventSubCommand(data: SubscriptionRequest, opts?: WshRpcCommandOpts): Promise { + EventSubCommand(data: SubscriptionRequest, opts?: RpcOpts): Promise { return WOS.wshServerRpcHelper_call("eventsub", data, opts); } // command "eventunsub" [call] - EventUnsubCommand(data: SubscriptionRequest, opts?: WshRpcCommandOpts): Promise { + EventUnsubCommand(data: SubscriptionRequest, opts?: RpcOpts): Promise { return WOS.wshServerRpcHelper_call("eventunsub", data, opts); } // command "eventunsuball" [call] - EventUnsubAllCommand(opts?: WshRpcCommandOpts): Promise { + EventUnsubAllCommand(opts?: RpcOpts): Promise { return WOS.wshServerRpcHelper_call("eventunsuball", null, opts); } // command "fileappend" [call] - FileAppendCommand(data: CommandFileData, opts?: WshRpcCommandOpts): Promise { + FileAppendCommand(data: CommandFileData, opts?: RpcOpts): Promise { return WOS.wshServerRpcHelper_call("fileappend", data, opts); } // command "fileappendijson" [call] - FileAppendIJsonCommand(data: CommandAppendIJsonData, opts?: WshRpcCommandOpts): Promise { + FileAppendIJsonCommand(data: CommandAppendIJsonData, opts?: RpcOpts): Promise { return WOS.wshServerRpcHelper_call("fileappendijson", data, opts); } // command "fileread" [call] - FileReadCommand(data: CommandFileData, opts?: WshRpcCommandOpts): Promise { + FileReadCommand(data: CommandFileData, opts?: RpcOpts): Promise { return WOS.wshServerRpcHelper_call("fileread", data, opts); } // command "filewrite" [call] - FileWriteCommand(data: CommandFileData, opts?: WshRpcCommandOpts): Promise { + FileWriteCommand(data: CommandFileData, opts?: RpcOpts): Promise { return WOS.wshServerRpcHelper_call("filewrite", data, opts); } // command "getmeta" [call] - GetMetaCommand(data: CommandGetMetaData, opts?: WshRpcCommandOpts): Promise { + GetMetaCommand(data: CommandGetMetaData, opts?: RpcOpts): Promise { return WOS.wshServerRpcHelper_call("getmeta", data, opts); } // command "message" [call] - MessageCommand(data: CommandMessageData, opts?: WshRpcCommandOpts): Promise { + MessageCommand(data: CommandMessageData, opts?: RpcOpts): Promise { return WOS.wshServerRpcHelper_call("message", data, opts); } // command "remotefileinfo" [call] - RemoteFileInfoCommand(data: string, opts?: WshRpcCommandOpts): Promise { + RemoteFileInfoCommand(data: string, opts?: RpcOpts): Promise { return WOS.wshServerRpcHelper_call("remotefileinfo", data, opts); } // command "remotestreamfile" [responsestream] - RemoteStreamFileCommand(data: CommandRemoteStreamFileData, opts?: WshRpcCommandOpts): AsyncGenerator { + RemoteStreamFileCommand(data: CommandRemoteStreamFileData, opts?: RpcOpts): AsyncGenerator { return WOS.wshServerRpcHelper_responsestream("remotestreamfile", data, opts); } // command "resolveids" [call] - ResolveIdsCommand(data: CommandResolveIdsData, opts?: WshRpcCommandOpts): Promise { + ResolveIdsCommand(data: CommandResolveIdsData, opts?: RpcOpts): Promise { return WOS.wshServerRpcHelper_call("resolveids", data, opts); } // command "setmeta" [call] - SetMetaCommand(data: CommandSetMetaData, opts?: WshRpcCommandOpts): Promise { + SetMetaCommand(data: CommandSetMetaData, opts?: RpcOpts): Promise { return WOS.wshServerRpcHelper_call("setmeta", data, opts); } // command "setview" [call] - SetViewCommand(data: CommandBlockSetViewData, opts?: WshRpcCommandOpts): Promise { + SetViewCommand(data: CommandBlockSetViewData, opts?: RpcOpts): Promise { return WOS.wshServerRpcHelper_call("setview", data, opts); } // command "streamcpudata" [responsestream] - StreamCpuDataCommand(data: CpuDataRequest, opts?: WshRpcCommandOpts): AsyncGenerator { + StreamCpuDataCommand(data: CpuDataRequest, opts?: RpcOpts): AsyncGenerator { return WOS.wshServerRpcHelper_responsestream("streamcpudata", data, opts); } // command "streamtest" [responsestream] - StreamTestCommand(opts?: WshRpcCommandOpts): AsyncGenerator { + StreamTestCommand(opts?: RpcOpts): AsyncGenerator { return WOS.wshServerRpcHelper_responsestream("streamtest", null, opts); } // command "streamwaveai" [responsestream] - StreamWaveAiCommand(data: OpenAiStreamRequest, opts?: WshRpcCommandOpts): AsyncGenerator { + StreamWaveAiCommand(data: OpenAiStreamRequest, opts?: RpcOpts): AsyncGenerator { return WOS.wshServerRpcHelper_responsestream("streamwaveai", data, opts); } // command "test" [call] - TestCommand(data: string, opts?: WshRpcCommandOpts): Promise { + TestCommand(data: string, opts?: RpcOpts): Promise { return WOS.wshServerRpcHelper_call("test", data, opts); } diff --git a/frontend/types/gotypes.d.ts b/frontend/types/gotypes.d.ts index 1992843a6..e76c1bae7 100644 --- a/frontend/types/gotypes.d.ts +++ b/frontend/types/gotypes.d.ts @@ -130,6 +130,7 @@ declare global { // wshrpc.CommandResolveIdsData type CommandResolveIdsData = { + blockid: string; ids: string[]; }; @@ -305,6 +306,8 @@ declare global { reqid?: string; resid?: string; timeout?: number; + route?: string; + source?: string; cont?: boolean; cancel?: boolean; error?: string; @@ -312,6 +315,13 @@ declare global { data?: any; }; + // wshrpc.RpcOpts + type RpcOpts = { + timeout?: number; + noresponse?: boolean; + route?: string; + }; + // wstore.RuntimeOpts type RuntimeOpts = { termsize?: TermSize; @@ -607,12 +617,6 @@ declare global { tabids: string[]; }; - // wshrpc.WshRpcCommandOpts - type WshRpcCommandOpts = { - timeout: number; - noresponse: boolean; - }; - // wshrpc.WshServerCommandMeta type WshServerCommandMeta = { commandtype: string; diff --git a/pkg/blockcontroller/blockcontroller.go b/pkg/blockcontroller/blockcontroller.go index 36693bdfe..b1657345f 100644 --- a/pkg/blockcontroller/blockcontroller.go +++ b/pkg/blockcontroller/blockcontroller.go @@ -27,9 +27,6 @@ import ( "github.com/wavetermdev/thenextwave/pkg/wstore" ) -// set by main-server.go (for dependency inversion) -var WshServerFactoryFn func(inputCh chan []byte, outputCh chan []byte, initialCtx wshrpc.RpcContext) = nil - const ( BlockController_Shell = "shell" BlockController_Cmd = "cmd" @@ -327,10 +324,13 @@ func (bc *BlockController) DoRunShellCommand(rc *RunShellOpts, blockMeta waveobj }) shellInputCh := make(chan *BlockInputUnion, 32) bc.ShellInputCh = shellInputCh - messageCh := make(chan []byte, 32) - ptyBuffer := wshutil.MakePtyBuffer(wshutil.WaveOSCPrefix, bc.ShellProc.Cmd, messageCh) - outputCh := make(chan []byte, 32) - WshServerFactoryFn(messageCh, outputCh, wshrpc.RpcContext{BlockId: bc.BlockId, TabId: bc.TabId}) + + // make esc sequence wshclient wshProxy + // we don't need to authenticate this wshProxy since it is coming direct + wshProxy := wshutil.MakeRpcProxy() + wshProxy.SetRpcContext(&wshrpc.RpcContext{TabId: bc.TabId, BlockId: bc.BlockId}) + wshutil.DefaultRouter.RegisterRoute("controller:"+bc.BlockId, wshProxy) + ptyBuffer := wshutil.MakePtyBuffer(wshutil.WaveOSCPrefix, bc.ShellProc.Cmd, wshProxy.FromRemoteCh) go func() { // handles regular output from the pty (goes to the blockfile and xterm) defer func() { @@ -380,7 +380,7 @@ func (bc *BlockController) DoRunShellCommand(rc *RunShellOpts, blockMeta waveobj }() go func() { // handles outputCh -> shellInputCh - for msg := range outputCh { + for msg := range wshProxy.ToRemoteCh { encodedMsg := wshutil.EncodeWaveOSCBytes(wshutil.WaveServerOSC, msg) shellInputCh <- &BlockInputUnion{InputData: encodedMsg} } @@ -388,6 +388,7 @@ func (bc *BlockController) DoRunShellCommand(rc *RunShellOpts, blockMeta waveobj go func() { // wait for the shell to finish defer func() { + wshutil.DefaultRouter.UnregisterRoute("controller:" + bc.BlockId) bc.UpdateControllerAndSendUpdate(func() bool { bc.ShellProcStatus = Status_Done return true diff --git a/pkg/tsgen/tsgen.go b/pkg/tsgen/tsgen.go index fc1ff0b60..9f4be1009 100644 --- a/pkg/tsgen/tsgen.go +++ b/pkg/tsgen/tsgen.go @@ -425,9 +425,9 @@ func GenerateWshServerMethod_ResponseStream(methodDecl *wshrpc.WshRpcMethodDecl, } genRespType := fmt.Sprintf("AsyncGenerator<%s, void, boolean>", respType) if methodDecl.CommandDataType != nil { - sb.WriteString(fmt.Sprintf(" %s(data: %s, opts?: WshRpcCommandOpts): %s {\n", methodDecl.MethodName, methodDecl.CommandDataType.Name(), genRespType)) + sb.WriteString(fmt.Sprintf(" %s(data: %s, opts?: RpcOpts): %s {\n", methodDecl.MethodName, methodDecl.CommandDataType.Name(), genRespType)) } else { - sb.WriteString(fmt.Sprintf(" %s(opts?: WshRpcCommandOpts): %s {\n", methodDecl.MethodName, genRespType)) + sb.WriteString(fmt.Sprintf(" %s(opts?: RpcOpts): %s {\n", methodDecl.MethodName, genRespType)) } sb.WriteString(fmt.Sprintf(" return WOS.wshServerRpcHelper_responsestream(%q, %s, opts);\n", methodDecl.Command, dataName)) sb.WriteString(" }\n") @@ -447,9 +447,9 @@ func GenerateWshServerMethod_Call(methodDecl *wshrpc.WshRpcMethodDecl, tsTypesMa dataName = "data" } if methodDecl.CommandDataType != nil { - sb.WriteString(fmt.Sprintf(" %s(data: %s, opts?: WshRpcCommandOpts): %s {\n", methodDecl.MethodName, methodDecl.CommandDataType.Name(), rtnType)) + sb.WriteString(fmt.Sprintf(" %s(data: %s, opts?: RpcOpts): %s {\n", methodDecl.MethodName, methodDecl.CommandDataType.Name(), rtnType)) } else { - sb.WriteString(fmt.Sprintf(" %s(opts?: WshRpcCommandOpts): %s {\n", methodDecl.MethodName, rtnType)) + sb.WriteString(fmt.Sprintf(" %s(opts?: RpcOpts): %s {\n", methodDecl.MethodName, rtnType)) } methodBody := fmt.Sprintf(" return WOS.wshServerRpcHelper_call(%q, %s, opts);\n", methodDecl.Command, dataName) sb.WriteString(methodBody) @@ -484,7 +484,7 @@ func GenerateServiceTypes(tsTypesMap map[reflect.Type]string) error { } func GenerateWshServerTypes(tsTypesMap map[reflect.Type]string) error { - GenerateTSType(reflect.TypeOf(wshrpc.WshRpcCommandOpts{}), tsTypesMap) + GenerateTSType(reflect.TypeOf(wshrpc.RpcOpts{}), tsTypesMap) rtype := wshRpcInterfaceRType for midx := 0; midx < rtype.NumMethod(); midx++ { method := rtype.Method(midx) diff --git a/pkg/web/ws.go b/pkg/web/ws.go index cddbdda9e..324aa3d53 100644 --- a/pkg/web/ws.go +++ b/pkg/web/ws.go @@ -22,9 +22,6 @@ import ( "github.com/wavetermdev/thenextwave/pkg/wshutil" ) -// set by main-server.go (for dependency inversion) -var WshServerFactoryFn func(inputCh chan []byte, outputCh chan []byte, initialCtx wshrpc.RpcContext) = nil - const wsReadWaitTimeout = 15 * time.Second const wsWriteWaitTimeout = 10 * time.Second const wsPingPeriodTickTime = 10 * time.Second @@ -148,31 +145,10 @@ func processWSCommand(jmsg map[string]any, outputCh chan any, rpcInputCh chan [] func processMessage(jmsg map[string]any, outputCh chan any, rpcInputCh chan []byte) { wsCommand := getStringFromMap(jmsg, "wscommand") - if wsCommand != "" { - processWSCommand(jmsg, outputCh, rpcInputCh) + if wsCommand == "" { return } - msgType := getMessageType(jmsg) - if msgType != "rpc" { - return - } - reqId := getStringFromMap(jmsg, "reqid") - var rtnErr error - defer func() { - r := recover() - if r != nil { - rtnErr = fmt.Errorf("panic: %v", r) - log.Printf("panic in processMessage: %v\n", r) - debug.PrintStack() - } - if rtnErr == nil { - return - } - rtn := map[string]any{"type": "rpcresp", "reqid": reqId, "error": rtnErr.Error()} - outputCh <- rtn - }() - method := getStringFromMap(jmsg, "method") - rtnErr = fmt.Errorf("unknown method %q", method) + processWSCommand(jmsg, outputCh, rpcInputCh) } func ReadLoop(conn *websocket.Conn, outputCh chan any, closeCh chan any, rpcInputCh chan []byte) { @@ -277,17 +253,23 @@ func HandleWsInternal(w http.ResponseWriter, r *http.Request) error { log.Printf("New websocket connection: windowid:%s connid:%s\n", windowId, wsConnId) outputCh := make(chan any, 100) closeCh := make(chan any) - rpcInputCh := make(chan []byte, 32) - rpcOutputCh := make(chan []byte, 32) eventbus.RegisterWSChannel(wsConnId, windowId, outputCh) defer eventbus.UnregisterWSChannel(wsConnId) - WshServerFactoryFn(rpcInputCh, rpcOutputCh, wshrpc.RpcContext{WindowId: windowId}) + // we create a wshproxy to handle rpc messages to/from the window + wproxy := wshutil.MakeRpcProxy() + rpcRouteId := "window:" + windowId + wshutil.DefaultRouter.RegisterRoute(rpcRouteId, wproxy) + defer func() { + wshutil.DefaultRouter.UnregisterRoute(rpcRouteId) + close(wproxy.ToRemoteCh) + }() + // WshServerFactoryFn(rpcInputCh, rpcOutputCh, wshrpc.RpcContext{}) wg := &sync.WaitGroup{} wg.Add(2) go func() { // no waitgroup add here // move values from rpcOutputCh to outputCh - for msgBytes := range rpcOutputCh { + for msgBytes := range wproxy.ToRemoteCh { rpcWSMsg := map[string]any{ "eventtype": "rpc", // TODO don't hard code this (but def is in eventbus) "data": json.RawMessage(msgBytes), @@ -298,7 +280,7 @@ func HandleWsInternal(w http.ResponseWriter, r *http.Request) error { go func() { // read loop defer wg.Done() - ReadLoop(conn, outputCh, closeCh, rpcInputCh) + ReadLoop(conn, outputCh, closeCh, wproxy.FromRemoteCh) }() go func() { // write loop @@ -306,6 +288,6 @@ func HandleWsInternal(w http.ResponseWriter, r *http.Request) error { WriteLoop(conn, outputCh, closeCh) }() wg.Wait() - close(rpcInputCh) + close(wproxy.FromRemoteCh) return nil } diff --git a/pkg/wps/wps.go b/pkg/wps/wps.go index 437a640d5..058235eb7 100644 --- a/pkg/wps/wps.go +++ b/pkg/wps/wps.go @@ -16,26 +16,24 @@ import ( // strong typing and event types can be defined elsewhere type Client interface { - ClientId() string - SendEvent(event wshrpc.WaveEvent) + SendEvent(routeId string, event wshrpc.WaveEvent) } type BrokerSubscription struct { - AllSubs []string // clientids of client subscribed to "all" events - ScopeSubs map[string][]string // clientids of client subscribed to specific scopes - StarSubs map[string][]string // clientids of client subscribed to star scope (scopes with "*" or "**" in them) + AllSubs []string // routeids subscribed to "all" events + ScopeSubs map[string][]string // routeids subscribed to specific scopes + StarSubs map[string][]string // routeids subscribed to star scope (scopes with "*" or "**" in them) } type BrokerType struct { - Lock *sync.Mutex - ClientMap map[string]Client - SubMap map[string]*BrokerSubscription + Lock *sync.Mutex + Client Client + SubMap map[string]*BrokerSubscription } var Broker = &BrokerType{ - Lock: &sync.Mutex{}, - ClientMap: make(map[string]Client), - SubMap: make(map[string]*BrokerSubscription), + Lock: &sync.Mutex{}, + SubMap: make(map[string]*BrokerSubscription), } func scopeHasStarMatch(scope string) bool { @@ -48,10 +46,21 @@ func scopeHasStarMatch(scope string) bool { return false } -func (b *BrokerType) Subscribe(subscriber Client, sub wshrpc.SubscriptionRequest) { +func (b *BrokerType) SetClient(client Client) { + b.Lock.Lock() + defer b.Lock.Unlock() + b.Client = client +} + +func (b *BrokerType) GetClient() Client { + b.Lock.Lock() + defer b.Lock.Unlock() + return b.Client +} + +func (b *BrokerType) Subscribe(subRouteId string, sub wshrpc.SubscriptionRequest) { b.Lock.Lock() defer b.Lock.Unlock() - clientId := subscriber.ClientId() bs := b.SubMap[sub.Event] if bs == nil { bs = &BrokerSubscription{ @@ -62,14 +71,14 @@ func (b *BrokerType) Subscribe(subscriber Client, sub wshrpc.SubscriptionRequest b.SubMap[sub.Event] = bs } if sub.AllScopes { - bs.AllSubs = utilfn.AddElemToSliceUniq(bs.AllSubs, clientId) + bs.AllSubs = utilfn.AddElemToSliceUniq(bs.AllSubs, subRouteId) } for _, scope := range sub.Scopes { starMatch := scopeHasStarMatch(scope) if starMatch { - addStrToScopeMap(bs.StarSubs, scope, clientId) + addStrToScopeMap(bs.StarSubs, scope, subRouteId) } else { - addStrToScopeMap(bs.ScopeSubs, scope, clientId) + addStrToScopeMap(bs.ScopeSubs, scope, subRouteId) } } } @@ -78,9 +87,9 @@ func (bs *BrokerSubscription) IsEmpty() bool { return len(bs.AllSubs) == 0 && len(bs.ScopeSubs) == 0 && len(bs.StarSubs) == 0 } -func removeStrFromScopeMap(scopeMap map[string][]string, scope string, clientId string) { +func removeStrFromScopeMap(scopeMap map[string][]string, scope string, routeId string) { scopeSubs := scopeMap[scope] - scopeSubs = utilfn.RemoveElemFromSlice(scopeSubs, clientId) + scopeSubs = utilfn.RemoveElemFromSlice(scopeSubs, routeId) if len(scopeSubs) == 0 { delete(scopeMap, scope) } else { @@ -88,9 +97,9 @@ func removeStrFromScopeMap(scopeMap map[string][]string, scope string, clientId } } -func removeStrFromScopeMapAll(scopeMap map[string][]string, clientId string) { +func removeStrFromScopeMapAll(scopeMap map[string][]string, routeId string) { for scope, scopeSubs := range scopeMap { - scopeSubs = utilfn.RemoveElemFromSlice(scopeSubs, clientId) + scopeSubs = utilfn.RemoveElemFromSlice(scopeSubs, routeId) if len(scopeSubs) == 0 { delete(scopeMap, scope) } else { @@ -99,29 +108,28 @@ func removeStrFromScopeMapAll(scopeMap map[string][]string, clientId string) { } } -func addStrToScopeMap(scopeMap map[string][]string, scope string, clientId string) { +func addStrToScopeMap(scopeMap map[string][]string, scope string, routeId string) { scopeSubs := scopeMap[scope] - scopeSubs = utilfn.AddElemToSliceUniq(scopeSubs, clientId) + scopeSubs = utilfn.AddElemToSliceUniq(scopeSubs, routeId) scopeMap[scope] = scopeSubs } -func (b *BrokerType) Unsubscribe(subscriber Client, sub wshrpc.SubscriptionRequest) { +func (b *BrokerType) Unsubscribe(subRouteId string, sub wshrpc.SubscriptionRequest) { b.Lock.Lock() defer b.Lock.Unlock() - clientId := subscriber.ClientId() bs := b.SubMap[sub.Event] if bs == nil { return } if sub.AllScopes { - bs.AllSubs = utilfn.RemoveElemFromSlice(bs.AllSubs, clientId) + bs.AllSubs = utilfn.RemoveElemFromSlice(bs.AllSubs, subRouteId) } for _, scope := range sub.Scopes { starMatch := scopeHasStarMatch(scope) if starMatch { - removeStrFromScopeMap(bs.StarSubs, scope, clientId) + removeStrFromScopeMap(bs.StarSubs, scope, subRouteId) } else { - removeStrFromScopeMap(bs.ScopeSubs, scope, clientId) + removeStrFromScopeMap(bs.ScopeSubs, scope, subRouteId) } } if bs.IsEmpty() { @@ -129,15 +137,13 @@ func (b *BrokerType) Unsubscribe(subscriber Client, sub wshrpc.SubscriptionReque } } -func (b *BrokerType) UnsubscribeAll(subscriber Client) { +func (b *BrokerType) UnsubscribeAll(subRouteId string) { b.Lock.Lock() defer b.Lock.Unlock() - clientId := subscriber.ClientId() - delete(b.ClientMap, clientId) for eventType, bs := range b.SubMap { - bs.AllSubs = utilfn.RemoveElemFromSlice(bs.AllSubs, clientId) - removeStrFromScopeMapAll(bs.StarSubs, clientId) - removeStrFromScopeMapAll(bs.ScopeSubs, clientId) + bs.AllSubs = utilfn.RemoveElemFromSlice(bs.AllSubs, subRouteId) + removeStrFromScopeMapAll(bs.StarSubs, subRouteId) + removeStrFromScopeMapAll(bs.ScopeSubs, subRouteId) if bs.IsEmpty() { delete(b.SubMap, eventType) } @@ -145,41 +151,42 @@ func (b *BrokerType) UnsubscribeAll(subscriber Client) { } func (b *BrokerType) Publish(event wshrpc.WaveEvent) { - clientIds := b.getMatchingClientIds(event) - for _, clientId := range clientIds { - client := b.ClientMap[clientId] - if client != nil { - client.SendEvent(event) - } + client := b.GetClient() + if client == nil { + return + } + routeIds := b.getMatchingRouteIds(event) + for _, routeId := range routeIds { + client.SendEvent(routeId, event) } } -func (b *BrokerType) getMatchingClientIds(event wshrpc.WaveEvent) []string { +func (b *BrokerType) getMatchingRouteIds(event wshrpc.WaveEvent) []string { b.Lock.Lock() defer b.Lock.Unlock() bs := b.SubMap[event.Event] if bs == nil { return nil } - clientIds := make(map[string]bool) - for _, clientId := range bs.AllSubs { - clientIds[clientId] = true + routeIds := make(map[string]bool) + for _, routeId := range bs.AllSubs { + routeIds[routeId] = true } for _, scope := range event.Scopes { - for _, clientId := range bs.ScopeSubs[scope] { - clientIds[clientId] = true + for _, routeId := range bs.ScopeSubs[scope] { + routeIds[routeId] = true } for starScope := range bs.StarSubs { if utilfn.StarMatchString(starScope, scope, ":") { - for _, clientId := range bs.StarSubs[starScope] { - clientIds[clientId] = true + for _, routeId := range bs.StarSubs[starScope] { + routeIds[routeId] = true } } } } var rtn []string - for clientId := range clientIds { - rtn = append(rtn, clientId) + for routeId := range routeIds { + rtn = append(rtn, routeId) } return rtn } diff --git a/pkg/wshrpc/wshclient/wshclient.go b/pkg/wshrpc/wshclient/wshclient.go index 8481e0482..5f6fb2535 100644 --- a/pkg/wshrpc/wshclient/wshclient.go +++ b/pkg/wshrpc/wshclient/wshclient.go @@ -12,147 +12,147 @@ import ( ) // command "authenticate", wshserver.AuthenticateCommand -func AuthenticateCommand(w *wshutil.WshRpc, data string, opts *wshrpc.WshRpcCommandOpts) error { +func AuthenticateCommand(w *wshutil.WshRpc, data string, opts *wshrpc.RpcOpts) error { _, err := sendRpcRequestCallHelper[any](w, "authenticate", data, opts) return err } // command "controllerinput", wshserver.ControllerInputCommand -func ControllerInputCommand(w *wshutil.WshRpc, data wshrpc.CommandBlockInputData, opts *wshrpc.WshRpcCommandOpts) error { +func ControllerInputCommand(w *wshutil.WshRpc, data wshrpc.CommandBlockInputData, opts *wshrpc.RpcOpts) error { _, err := sendRpcRequestCallHelper[any](w, "controllerinput", data, opts) return err } // command "controllerrestart", wshserver.ControllerRestartCommand -func ControllerRestartCommand(w *wshutil.WshRpc, data wshrpc.CommandBlockRestartData, opts *wshrpc.WshRpcCommandOpts) error { +func ControllerRestartCommand(w *wshutil.WshRpc, data wshrpc.CommandBlockRestartData, opts *wshrpc.RpcOpts) error { _, err := sendRpcRequestCallHelper[any](w, "controllerrestart", data, opts) return err } // command "createblock", wshserver.CreateBlockCommand -func CreateBlockCommand(w *wshutil.WshRpc, data wshrpc.CommandCreateBlockData, opts *wshrpc.WshRpcCommandOpts) (waveobj.ORef, error) { +func CreateBlockCommand(w *wshutil.WshRpc, data wshrpc.CommandCreateBlockData, opts *wshrpc.RpcOpts) (waveobj.ORef, error) { resp, err := sendRpcRequestCallHelper[waveobj.ORef](w, "createblock", data, opts) return resp, err } // command "deleteblock", wshserver.DeleteBlockCommand -func DeleteBlockCommand(w *wshutil.WshRpc, data wshrpc.CommandDeleteBlockData, opts *wshrpc.WshRpcCommandOpts) error { +func DeleteBlockCommand(w *wshutil.WshRpc, data wshrpc.CommandDeleteBlockData, opts *wshrpc.RpcOpts) error { _, err := sendRpcRequestCallHelper[any](w, "deleteblock", data, opts) return err } // command "eventpublish", wshserver.EventPublishCommand -func EventPublishCommand(w *wshutil.WshRpc, data wshrpc.WaveEvent, opts *wshrpc.WshRpcCommandOpts) error { +func EventPublishCommand(w *wshutil.WshRpc, data wshrpc.WaveEvent, opts *wshrpc.RpcOpts) error { _, err := sendRpcRequestCallHelper[any](w, "eventpublish", data, opts) return err } // command "eventrecv", wshserver.EventRecvCommand -func EventRecvCommand(w *wshutil.WshRpc, data wshrpc.WaveEvent, opts *wshrpc.WshRpcCommandOpts) error { +func EventRecvCommand(w *wshutil.WshRpc, data wshrpc.WaveEvent, opts *wshrpc.RpcOpts) error { _, err := sendRpcRequestCallHelper[any](w, "eventrecv", data, opts) return err } // command "eventsub", wshserver.EventSubCommand -func EventSubCommand(w *wshutil.WshRpc, data wshrpc.SubscriptionRequest, opts *wshrpc.WshRpcCommandOpts) error { +func EventSubCommand(w *wshutil.WshRpc, data wshrpc.SubscriptionRequest, opts *wshrpc.RpcOpts) error { _, err := sendRpcRequestCallHelper[any](w, "eventsub", data, opts) return err } // command "eventunsub", wshserver.EventUnsubCommand -func EventUnsubCommand(w *wshutil.WshRpc, data wshrpc.SubscriptionRequest, opts *wshrpc.WshRpcCommandOpts) error { +func EventUnsubCommand(w *wshutil.WshRpc, data wshrpc.SubscriptionRequest, opts *wshrpc.RpcOpts) error { _, err := sendRpcRequestCallHelper[any](w, "eventunsub", data, opts) return err } // command "eventunsuball", wshserver.EventUnsubAllCommand -func EventUnsubAllCommand(w *wshutil.WshRpc, opts *wshrpc.WshRpcCommandOpts) error { +func EventUnsubAllCommand(w *wshutil.WshRpc, opts *wshrpc.RpcOpts) error { _, err := sendRpcRequestCallHelper[any](w, "eventunsuball", nil, opts) return err } // command "fileappend", wshserver.FileAppendCommand -func FileAppendCommand(w *wshutil.WshRpc, data wshrpc.CommandFileData, opts *wshrpc.WshRpcCommandOpts) error { +func FileAppendCommand(w *wshutil.WshRpc, data wshrpc.CommandFileData, opts *wshrpc.RpcOpts) error { _, err := sendRpcRequestCallHelper[any](w, "fileappend", data, opts) return err } // command "fileappendijson", wshserver.FileAppendIJsonCommand -func FileAppendIJsonCommand(w *wshutil.WshRpc, data wshrpc.CommandAppendIJsonData, opts *wshrpc.WshRpcCommandOpts) error { +func FileAppendIJsonCommand(w *wshutil.WshRpc, data wshrpc.CommandAppendIJsonData, opts *wshrpc.RpcOpts) error { _, err := sendRpcRequestCallHelper[any](w, "fileappendijson", data, opts) return err } // command "fileread", wshserver.FileReadCommand -func FileReadCommand(w *wshutil.WshRpc, data wshrpc.CommandFileData, opts *wshrpc.WshRpcCommandOpts) (string, error) { +func FileReadCommand(w *wshutil.WshRpc, data wshrpc.CommandFileData, opts *wshrpc.RpcOpts) (string, error) { resp, err := sendRpcRequestCallHelper[string](w, "fileread", data, opts) return resp, err } // command "filewrite", wshserver.FileWriteCommand -func FileWriteCommand(w *wshutil.WshRpc, data wshrpc.CommandFileData, opts *wshrpc.WshRpcCommandOpts) error { +func FileWriteCommand(w *wshutil.WshRpc, data wshrpc.CommandFileData, opts *wshrpc.RpcOpts) error { _, err := sendRpcRequestCallHelper[any](w, "filewrite", data, opts) return err } // command "getmeta", wshserver.GetMetaCommand -func GetMetaCommand(w *wshutil.WshRpc, data wshrpc.CommandGetMetaData, opts *wshrpc.WshRpcCommandOpts) (waveobj.MetaMapType, error) { +func GetMetaCommand(w *wshutil.WshRpc, data wshrpc.CommandGetMetaData, opts *wshrpc.RpcOpts) (waveobj.MetaMapType, error) { resp, err := sendRpcRequestCallHelper[waveobj.MetaMapType](w, "getmeta", data, opts) return resp, err } // command "message", wshserver.MessageCommand -func MessageCommand(w *wshutil.WshRpc, data wshrpc.CommandMessageData, opts *wshrpc.WshRpcCommandOpts) error { +func MessageCommand(w *wshutil.WshRpc, data wshrpc.CommandMessageData, opts *wshrpc.RpcOpts) error { _, err := sendRpcRequestCallHelper[any](w, "message", data, opts) return err } // command "remotefileinfo", wshserver.RemoteFileInfoCommand -func RemoteFileInfoCommand(w *wshutil.WshRpc, data string, opts *wshrpc.WshRpcCommandOpts) (*wshrpc.FileInfo, error) { +func RemoteFileInfoCommand(w *wshutil.WshRpc, data string, opts *wshrpc.RpcOpts) (*wshrpc.FileInfo, error) { resp, err := sendRpcRequestCallHelper[*wshrpc.FileInfo](w, "remotefileinfo", data, opts) return resp, err } // command "remotestreamfile", wshserver.RemoteStreamFileCommand -func RemoteStreamFileCommand(w *wshutil.WshRpc, data wshrpc.CommandRemoteStreamFileData, opts *wshrpc.WshRpcCommandOpts) chan wshrpc.RespOrErrorUnion[wshrpc.CommandRemoteStreamFileRtnData] { +func RemoteStreamFileCommand(w *wshutil.WshRpc, data wshrpc.CommandRemoteStreamFileData, opts *wshrpc.RpcOpts) chan wshrpc.RespOrErrorUnion[wshrpc.CommandRemoteStreamFileRtnData] { return sendRpcRequestResponseStreamHelper[wshrpc.CommandRemoteStreamFileRtnData](w, "remotestreamfile", data, opts) } // command "resolveids", wshserver.ResolveIdsCommand -func ResolveIdsCommand(w *wshutil.WshRpc, data wshrpc.CommandResolveIdsData, opts *wshrpc.WshRpcCommandOpts) (wshrpc.CommandResolveIdsRtnData, error) { +func ResolveIdsCommand(w *wshutil.WshRpc, data wshrpc.CommandResolveIdsData, opts *wshrpc.RpcOpts) (wshrpc.CommandResolveIdsRtnData, error) { resp, err := sendRpcRequestCallHelper[wshrpc.CommandResolveIdsRtnData](w, "resolveids", data, opts) return resp, err } // command "setmeta", wshserver.SetMetaCommand -func SetMetaCommand(w *wshutil.WshRpc, data wshrpc.CommandSetMetaData, opts *wshrpc.WshRpcCommandOpts) error { +func SetMetaCommand(w *wshutil.WshRpc, data wshrpc.CommandSetMetaData, opts *wshrpc.RpcOpts) error { _, err := sendRpcRequestCallHelper[any](w, "setmeta", data, opts) return err } // command "setview", wshserver.SetViewCommand -func SetViewCommand(w *wshutil.WshRpc, data wshrpc.CommandBlockSetViewData, opts *wshrpc.WshRpcCommandOpts) error { +func SetViewCommand(w *wshutil.WshRpc, data wshrpc.CommandBlockSetViewData, opts *wshrpc.RpcOpts) error { _, err := sendRpcRequestCallHelper[any](w, "setview", data, opts) return err } // command "streamcpudata", wshserver.StreamCpuDataCommand -func StreamCpuDataCommand(w *wshutil.WshRpc, data wshrpc.CpuDataRequest, opts *wshrpc.WshRpcCommandOpts) chan wshrpc.RespOrErrorUnion[wshrpc.CpuDataType] { +func StreamCpuDataCommand(w *wshutil.WshRpc, data wshrpc.CpuDataRequest, opts *wshrpc.RpcOpts) chan wshrpc.RespOrErrorUnion[wshrpc.CpuDataType] { return sendRpcRequestResponseStreamHelper[wshrpc.CpuDataType](w, "streamcpudata", data, opts) } // command "streamtest", wshserver.StreamTestCommand -func StreamTestCommand(w *wshutil.WshRpc, opts *wshrpc.WshRpcCommandOpts) chan wshrpc.RespOrErrorUnion[int] { +func StreamTestCommand(w *wshutil.WshRpc, opts *wshrpc.RpcOpts) chan wshrpc.RespOrErrorUnion[int] { return sendRpcRequestResponseStreamHelper[int](w, "streamtest", nil, opts) } // command "streamwaveai", wshserver.StreamWaveAiCommand -func StreamWaveAiCommand(w *wshutil.WshRpc, data wshrpc.OpenAiStreamRequest, opts *wshrpc.WshRpcCommandOpts) chan wshrpc.RespOrErrorUnion[wshrpc.OpenAIPacketType] { +func StreamWaveAiCommand(w *wshutil.WshRpc, data wshrpc.OpenAiStreamRequest, opts *wshrpc.RpcOpts) chan wshrpc.RespOrErrorUnion[wshrpc.OpenAIPacketType] { return sendRpcRequestResponseStreamHelper[wshrpc.OpenAIPacketType](w, "streamwaveai", data, opts) } // command "test", wshserver.TestCommand -func TestCommand(w *wshutil.WshRpc, data string, opts *wshrpc.WshRpcCommandOpts) error { +func TestCommand(w *wshutil.WshRpc, data string, opts *wshrpc.RpcOpts) error { _, err := sendRpcRequestCallHelper[any](w, "test", data, opts) return err } diff --git a/pkg/wshrpc/wshclient/wshclientutil.go b/pkg/wshrpc/wshclient/wshclientutil.go index eba36e7fd..6ac547a8d 100644 --- a/pkg/wshrpc/wshclient/wshclientutil.go +++ b/pkg/wshrpc/wshclient/wshclientutil.go @@ -9,19 +9,19 @@ import ( "github.com/wavetermdev/thenextwave/pkg/wshutil" ) -func sendRpcRequestCallHelper[T any](w *wshutil.WshRpc, command string, data interface{}, opts *wshrpc.WshRpcCommandOpts) (T, error) { +func sendRpcRequestCallHelper[T any](w *wshutil.WshRpc, command string, data interface{}, opts *wshrpc.RpcOpts) (T, error) { if opts == nil { - opts = &wshrpc.WshRpcCommandOpts{} + opts = &wshrpc.RpcOpts{} } var respData T if opts.NoResponse { - err := w.SendCommand(command, data) + err := w.SendCommand(command, data, opts) if err != nil { return respData, err } return respData, nil } - resp, err := w.SendRpcRequest(command, data, opts.Timeout) + resp, err := w.SendRpcRequest(command, data, opts) if err != nil { return respData, err } @@ -32,12 +32,12 @@ func sendRpcRequestCallHelper[T any](w *wshutil.WshRpc, command string, data int return respData, nil } -func sendRpcRequestResponseStreamHelper[T any](w *wshutil.WshRpc, command string, data interface{}, opts *wshrpc.WshRpcCommandOpts) chan wshrpc.RespOrErrorUnion[T] { +func sendRpcRequestResponseStreamHelper[T any](w *wshutil.WshRpc, command string, data interface{}, opts *wshrpc.RpcOpts) chan wshrpc.RespOrErrorUnion[T] { if opts == nil { - opts = &wshrpc.WshRpcCommandOpts{} + opts = &wshrpc.RpcOpts{} } respChan := make(chan wshrpc.RespOrErrorUnion[T]) - reqHandler, err := w.SendComplexRequest(command, data, true, opts.Timeout) + reqHandler, err := w.SendComplexRequest(command, data, opts) if err != nil { go func() { respChan <- wshrpc.RespOrErrorUnion[T]{Error: err} diff --git a/pkg/wshrpc/wshrpctypes.go b/pkg/wshrpc/wshrpctypes.go index ade1325b2..4720f3f57 100644 --- a/pkg/wshrpc/wshrpctypes.go +++ b/pkg/wshrpc/wshrpctypes.go @@ -6,6 +6,7 @@ package wshrpc import ( "context" + "log" "os" "reflect" @@ -57,6 +58,7 @@ type RespOrErrorUnion[T any] struct { type WshRpcInterface interface { AuthenticateCommand(ctx context.Context, data string) error + MessageCommand(ctx context.Context, data CommandMessageData) error GetMetaCommand(ctx context.Context, data CommandGetMetaData) (wstore.MetaMapType, error) SetMetaCommand(ctx context.Context, data CommandSetMetaData) error @@ -90,15 +92,15 @@ type WshServerCommandMeta struct { CommandType string `json:"commandtype"` } -type WshRpcCommandOpts struct { - Timeout int `json:"timeout"` - NoResponse bool `json:"noresponse"` +type RpcOpts struct { + Timeout int `json:"timeout,omitempty"` + NoResponse bool `json:"noresponse,omitempty"` + Route string `json:"route,omitempty"` } type RpcContext struct { - BlockId string `json:"blockid,omitempty"` - TabId string `json:"tabid,omitempty"` - WindowId string `json:"windowid,omitempty"` + BlockId string `json:"blockid,omitempty"` + TabId string `json:"tabid,omitempty"` } func HackRpcContextIntoData(dataPtr any, rpcContext RpcContext) { @@ -122,12 +124,12 @@ func HackRpcContextIntoData(dataPtr any, rpcContext RpcContext) { field.SetString(rpcContext.BlockId) case "TabId": field.SetString(rpcContext.TabId) - case "WindowId": - field.SetString(rpcContext.WindowId) case "BlockORef": if rpcContext.BlockId != "" { field.Set(reflect.ValueOf(waveobj.MakeORef(wstore.OType_Block, rpcContext.BlockId))) } + default: + log.Printf("invalid wshcontext tag: %q in type(%T)", tag, dataPtr) } } } @@ -147,7 +149,8 @@ type CommandSetMetaData struct { } type CommandResolveIdsData struct { - Ids []string `json:"ids"` + BlockId string `json:"blockid" wshcontext:"BlockId"` + Ids []string `json:"ids"` } type CommandResolveIdsRtnData struct { diff --git a/pkg/wshrpc/wshserver/wshserver.go b/pkg/wshrpc/wshserver/wshserver.go index 9dfd8cdcc..e1176a509 100644 --- a/pkg/wshrpc/wshserver/wshserver.go +++ b/pkg/wshrpc/wshserver/wshserver.go @@ -30,26 +30,33 @@ import ( const SimpleId_This = "this" +type WshServer struct{} + +func (*WshServer) WshServerImpl() {} + +var WshServerImpl = WshServer{} + func (ws *WshServer) TestCommand(ctx context.Context, data string) error { defer func() { if r := recover(); r != nil { log.Printf("panic in TestCommand: %v", r) } }() - rpc := wshutil.GetWshRpcFromContext(ctx) - if rpc == nil { + rpcSource := wshutil.GetRpcSourceFromContext(ctx) + log.Printf("TEST src:%s | %s\n", rpcSource, data) + if rpcSource == "" { return nil } go func() { - wshclient.MessageCommand(rpc, wshrpc.CommandMessageData{Message: "test message"}, &wshrpc.WshRpcCommandOpts{NoResponse: true}) - resp, err := wshclient.RemoteFileInfoCommand(rpc, "~/work/wails/thenextwave/README.md", nil) + mainClient := GetMainRpcClient() + wshclient.MessageCommand(mainClient, wshrpc.CommandMessageData{Message: "test message"}, &wshrpc.RpcOpts{NoResponse: true, Route: rpcSource}) + resp, err := wshclient.RemoteFileInfoCommand(mainClient, "~/work/wails/thenextwave/README.md", &wshrpc.RpcOpts{Route: rpcSource}) if err != nil { log.Printf("error getting remote file info: %v", err) return } log.Printf("remote file info: %#v\n", resp) - - rch := wshclient.RemoteStreamFileCommand(rpc, wshrpc.CommandRemoteStreamFileData{Path: "~/work/wails/thenextwave/README.md"}, nil) + rch := wshclient.RemoteStreamFileCommand(mainClient, wshrpc.CommandRemoteStreamFileData{Path: "~/work/wails/thenextwave/README.md"}, &wshrpc.RpcOpts{Route: rpcSource}) for msg := range rch { if msg.Error != nil { log.Printf("error in stream: %v", msg.Error) @@ -66,22 +73,6 @@ func (ws *WshServer) TestCommand(ctx context.Context, data string) error { return nil } -func (ws *WshServer) AuthenticateCommand(ctx context.Context, data string) error { - w := wshutil.GetWshRpcFromContext(ctx) - if w == nil { - return fmt.Errorf("no wshrpc in context") - } - newCtx, err := wshutil.ValidateAndExtractRpcContextFromToken(data) - if err != nil { - return fmt.Errorf("error validating token: %w", err) - } - if newCtx == nil { - return fmt.Errorf("no context found in jwt token") - } - w.SetRpcContext(*newCtx) - return nil -} - // for testing func (ws *WshServer) MessageCommand(ctx context.Context, data wshrpc.CommandMessageData) error { log.Printf("MESSAGE: %s | %q\n", data.ORef, data.Message) @@ -228,17 +219,12 @@ func sendWaveObjUpdate(oref waveobj.ORef) { }) } -func resolveSimpleId(ctx context.Context, simpleId string) (*waveobj.ORef, error) { +func resolveSimpleId(ctx context.Context, data wshrpc.CommandResolveIdsData, simpleId string) (*waveobj.ORef, error) { if simpleId == SimpleId_This { - wshRpc := wshutil.GetWshRpcFromContext(ctx) - if wshRpc == nil { - return nil, fmt.Errorf("no wshrpc in context") + if data.BlockId == "" { + return nil, fmt.Errorf("no blockid in request") } - rpcCtx := wshRpc.GetRpcContext() - if rpcCtx.BlockId == "" { - return nil, fmt.Errorf("no blockid in rpc context") - } - return &waveobj.ORef{OType: wstore.OType_Block, OID: rpcCtx.BlockId}, nil + return &waveobj.ORef{OType: wstore.OType_Block, OID: data.BlockId}, nil } if strings.Contains(simpleId, ":") { rtn, err := waveobj.ParseORef(simpleId) @@ -254,7 +240,7 @@ func (ws *WshServer) ResolveIdsCommand(ctx context.Context, data wshrpc.CommandR rtn := wshrpc.CommandResolveIdsRtnData{} rtn.ResolvedIds = make(map[string]waveobj.ORef) for _, simpleId := range data.Ids { - oref, err := resolveSimpleId(ctx, simpleId) + oref, err := resolveSimpleId(ctx, data, simpleId) if err != nil || oref == nil { continue } @@ -471,40 +457,40 @@ func (ws *WshServer) EventRecvCommand(ctx context.Context, data wshrpc.WaveEvent } func (ws *WshServer) EventPublishCommand(ctx context.Context, data wshrpc.WaveEvent) error { - wrpc := wshutil.GetWshRpcFromContext(ctx) - if wrpc == nil { - return fmt.Errorf("no wshrpc in context") + rpcSource := wshutil.GetRpcSourceFromContext(ctx) + if rpcSource == "" { + return fmt.Errorf("no rpc source set") } if data.Sender == "" { - data.Sender = wrpc.ClientId() + data.Sender = rpcSource } wps.Broker.Publish(data) return nil } func (ws *WshServer) EventSubCommand(ctx context.Context, data wshrpc.SubscriptionRequest) error { - wrpc := wshutil.GetWshRpcFromContext(ctx) - if wrpc == nil { - return fmt.Errorf("no wshrpc in context") + rpcSource := wshutil.GetRpcSourceFromContext(ctx) + if rpcSource == "" { + return fmt.Errorf("no rpc source set") } - wps.Broker.Subscribe(wrpc, data) + wps.Broker.Subscribe(rpcSource, data) return nil } func (ws *WshServer) EventUnsubCommand(ctx context.Context, data wshrpc.SubscriptionRequest) error { - wrpc := wshutil.GetWshRpcFromContext(ctx) - if wrpc == nil { - return fmt.Errorf("no wshrpc in context") + rpcSource := wshutil.GetRpcSourceFromContext(ctx) + if rpcSource == "" { + return fmt.Errorf("no rpc source set") } - wps.Broker.Unsubscribe(wrpc, data) + wps.Broker.Unsubscribe(rpcSource, data) return nil } func (ws *WshServer) EventUnsubAllCommand(ctx context.Context) error { - wrpc := wshutil.GetWshRpcFromContext(ctx) - if wrpc == nil { - return fmt.Errorf("no wshrpc in context") + rpcSource := wshutil.GetRpcSourceFromContext(ctx) + if rpcSource == "" { + return fmt.Errorf("no rpc source set") } - wps.Broker.UnsubscribeAll(wrpc) + wps.Broker.UnsubscribeAll(rpcSource) return nil } diff --git a/pkg/wshrpc/wshserver/wshserverutil.go b/pkg/wshrpc/wshserver/wshserverutil.go index 565ab6391..64a1da9f5 100644 --- a/pkg/wshrpc/wshserver/wshserverutil.go +++ b/pkg/wshrpc/wshserver/wshserverutil.go @@ -4,89 +4,42 @@ package wshserver import ( - "context" - "fmt" "log" "net" - "reflect" + "sync" "github.com/wavetermdev/thenextwave/pkg/wshrpc" "github.com/wavetermdev/thenextwave/pkg/wshutil" ) -// this file contains the generic types and functions that create and power the WSH server - const ( DefaultOutputChSize = 32 DefaultInputChSize = 32 ) -type WshServer struct{} - -func (*WshServer) WshServerImpl() {} - -type WshServerMethodDecl struct { - Command string - CommandType string - MethodName string - Method reflect.Value - CommandDataType reflect.Type - DefaultResponseDataType reflect.Type - RequestDataTypes []reflect.Type // for streaming requests - ResponseDataTypes []reflect.Type // for streaming responses -} - -var WshServerImpl = WshServer{} -var contextRType = reflect.TypeOf((*context.Context)(nil)).Elem() -var wshCommandDeclMap = wshrpc.GenerateWshCommandDeclMap() - -func GetWshServerMethod(command string, commandType string, methodName string, methodFunc any) *WshServerMethodDecl { - methodVal := reflect.ValueOf(methodFunc) - methodType := methodVal.Type() - if methodType.Kind() != reflect.Func { - panic(fmt.Sprintf("methodVal must be a function got [%v]", methodType)) - } - if methodType.In(0) != contextRType { - panic(fmt.Sprintf("methodVal must have a context as the first argument %v", methodType)) - } - var defResponseType reflect.Type - if methodType.NumOut() > 1 { - defResponseType = methodType.Out(0) - } - var cdataType reflect.Type - if methodType.NumIn() > 1 { - cdataType = methodType.In(1) - } - rtn := &WshServerMethodDecl{ - Command: command, - CommandType: commandType, - MethodName: methodName, - Method: methodVal, - CommandDataType: cdataType, - DefaultResponseDataType: defResponseType, - } - return rtn -} - -func decodeRtnVals(rtnVals []reflect.Value) (any, error) { - switch len(rtnVals) { - case 0: - return nil, nil - case 1: - errIf := rtnVals[0].Interface() - if errIf == nil { - return nil, nil +func handleDomainSocketClient(conn net.Conn) { + proxy := wshutil.MakeRpcProxy() + go func() { + writeErr := wshutil.AdaptOutputChToStream(proxy.ToRemoteCh, conn) + if writeErr != nil { + log.Printf("error writing to domain socket: %v\n", writeErr) } - return nil, errIf.(error) - case 2: - errIf := rtnVals[1].Interface() - if errIf == nil { - return rtnVals[0].Interface(), nil - } - return rtnVals[0].Interface(), errIf.(error) - default: - return nil, fmt.Errorf("too many return values: %d", len(rtnVals)) + }() + go func() { + // when input is closed, close the connection + defer conn.Close() + wshutil.AdaptStreamToMsgCh(conn, proxy.FromRemoteCh) + }() + rpcCtx, err := proxy.HandleAuthentication() + if err != nil { + conn.Close() + log.Printf("error handling authentication: %v\n", err) + return } + // now that we're authenticated, set the ctx and attach to the router + log.Printf("domain socket connection authenticated: %#v\n", rpcCtx) + proxy.SetRpcContext(rpcCtx) + wshutil.DefaultRouter.RegisterRoute("controller:"+rpcCtx.BlockId, proxy) } func RunWshRpcOverListener(listener net.Listener) { @@ -98,11 +51,19 @@ func RunWshRpcOverListener(listener net.Listener) { continue } log.Print("got domain socket connection\n") - // TODO deal with closing connection - go wshutil.SetupConnRpcClient(conn, &WshServerImpl) + go handleDomainSocketClient(conn) } } -func MakeWshServer(inputCh chan []byte, outputCh chan []byte, initialCtx wshrpc.RpcContext) { - wshutil.MakeWshRpc(inputCh, outputCh, initialCtx, &WshServerImpl) +var waveSrvClient_Singleton *wshutil.WshRpc +var waveSrvClient_Once = &sync.Once{} + +// returns the wavesrv main rpc client singleton +func GetMainRpcClient() *wshutil.WshRpc { + waveSrvClient_Once.Do(func() { + inputCh := make(chan []byte, DefaultInputChSize) + outputCh := make(chan []byte, DefaultOutputChSize) + waveSrvClient_Singleton = wshutil.MakeWshRpc(inputCh, outputCh, wshrpc.RpcContext{}, &WshServerImpl) + }) + return waveSrvClient_Singleton } diff --git a/pkg/wshutil/wshadapter.go b/pkg/wshutil/wshadapter.go index b99ef2c11..b1eac1471 100644 --- a/pkg/wshutil/wshadapter.go +++ b/pkg/wshutil/wshadapter.go @@ -52,6 +52,31 @@ func noImplHandler(handler *RpcResponseHandler) bool { return true } +func recodeCommandData(command string, data any, rpcCtx *wshrpc.RpcContext) (any, error) { + // only applies to initial command packet + if command == "" { + return data, nil + } + methodDecl := WshCommandDeclMap[command] + if methodDecl == nil { + return data, fmt.Errorf("command %q not found", command) + } + if methodDecl.CommandDataType == nil { + return data, nil + } + commandDataPtr := reflect.New(methodDecl.CommandDataType).Interface() + if data != nil { + err := utilfn.ReUnmarshal(commandDataPtr, data) + if err != nil { + return data, fmt.Errorf("error re-marshalling command data: %w", err) + } + if rpcCtx != nil { + wshrpc.HackRpcContextIntoData(commandDataPtr, *rpcCtx) + } + } + return reflect.ValueOf(commandDataPtr).Elem().Interface(), nil +} + func serverImplAdapter(impl any) func(*RpcResponseHandler) bool { if impl == nil { return noImplHandler @@ -81,14 +106,13 @@ func serverImplAdapter(impl any) func(*RpcResponseHandler) bool { var callParams []reflect.Value callParams = append(callParams, reflect.ValueOf(handler.Context())) if methodDecl.CommandDataType != nil { - commandData := reflect.New(methodDecl.CommandDataType).Interface() - err := utilfn.ReUnmarshal(commandData, handler.GetCommandRawData()) + rpcCtx := handler.GetRpcContext() + cmdData, err := recodeCommandData(cmd, handler.GetCommandRawData(), &rpcCtx) if err != nil { - handler.SendResponseError(fmt.Errorf("error re-marshalling command data: %w", err)) + handler.SendResponseError(err) return true } - wshrpc.HackRpcContextIntoData(commandData, handler.GetRpcContext()) - callParams = append(callParams, reflect.ValueOf(commandData).Elem()) + callParams = append(callParams, reflect.ValueOf(cmdData)) } if methodDecl.CommandType == wshrpc.RpcType_Call { rtnVals := implMethod.Call(callParams) diff --git a/pkg/wshutil/wshproxy.go b/pkg/wshutil/wshproxy.go new file mode 100644 index 000000000..a990e0f28 --- /dev/null +++ b/pkg/wshutil/wshproxy.go @@ -0,0 +1,151 @@ +// Copyright 2024, Command Line Inc. +// SPDX-License-Identifier: Apache-2.0 + +package wshutil + +import ( + "encoding/json" + "fmt" + "sync" + + "github.com/google/uuid" + "github.com/wavetermdev/thenextwave/pkg/wshrpc" +) + +type WshRpcProxy struct { + Lock *sync.Mutex + RpcContext *wshrpc.RpcContext + ToRemoteCh chan []byte + FromRemoteCh chan []byte +} + +func MakeRpcProxy() *WshRpcProxy { + return &WshRpcProxy{ + Lock: &sync.Mutex{}, + ToRemoteCh: make(chan []byte, DefaultInputChSize), + FromRemoteCh: make(chan []byte, DefaultOutputChSize), + } +} + +func (p *WshRpcProxy) SetRpcContext(rpcCtx *wshrpc.RpcContext) { + p.Lock.Lock() + defer p.Lock.Unlock() + p.RpcContext = rpcCtx +} + +func (p *WshRpcProxy) GetRpcContext() *wshrpc.RpcContext { + p.Lock.Lock() + defer p.Lock.Unlock() + return p.RpcContext +} + +func (p *WshRpcProxy) sendResponseError(msg RpcMessage, sendErr error) { + if msg.ReqId == "" { + // no response needed + return + } + resp := RpcMessage{ + ResId: msg.ReqId, + Route: msg.Source, + Error: sendErr.Error(), + } + respBytes, _ := json.Marshal(resp) + p.SendRpcMessage(respBytes) +} + +func (p *WshRpcProxy) sendResponse(msg RpcMessage) { + if msg.ReqId == "" { + // no response needed + return + } + resp := RpcMessage{ + ResId: msg.ReqId, + Route: msg.Source, + } + respBytes, _ := json.Marshal(resp) + p.SendRpcMessage(respBytes) +} + +func handleAuthenticationCommand(msg RpcMessage) (*wshrpc.RpcContext, error) { + if msg.Data == nil { + return nil, fmt.Errorf("no data in authenticate message") + } + strData, ok := msg.Data.(string) + if !ok { + return nil, fmt.Errorf("data in authenticate message not a string") + } + newCtx, err := ValidateAndExtractRpcContextFromToken(strData) + if err != nil { + return nil, fmt.Errorf("error validating token: %w", err) + } + if newCtx == nil { + return nil, fmt.Errorf("no context found in jwt token") + } + if newCtx.BlockId == "" { + return nil, fmt.Errorf("no blockId found in jwt token") + } + if _, err := uuid.Parse(newCtx.BlockId); err != nil { + return nil, fmt.Errorf("invalid blockId in jwt token") + } + return newCtx, nil +} + +func (p *WshRpcProxy) HandleAuthentication() (*wshrpc.RpcContext, error) { + for { + msgBytes, ok := <-p.FromRemoteCh + if !ok { + return nil, fmt.Errorf("remote closed, not authenticated") + } + var msg RpcMessage + err := json.Unmarshal(msgBytes, &msg) + if err != nil { + // nothing to do, can't even send a response since we don't have Source or ReqId + continue + } + if msg.Command == "" { + // this message is not allowed (protocol error at this point), ignore + continue + } + // we only allow one command "authenticate", everything else returns an error + if msg.Command != wshrpc.Command_Authenticate { + respErr := fmt.Errorf("connection not authenticated") + p.sendResponseError(msg, respErr) + continue + } + newCtx, err := handleAuthenticationCommand(msg) + if err != nil { + p.sendResponseError(msg, err) + continue + } + p.sendResponse(msg) + return newCtx, nil + } +} + +func (p *WshRpcProxy) SendRpcMessage(msg []byte) { + p.ToRemoteCh <- msg +} + +func (p *WshRpcProxy) RecvRpcMessage() ([]byte, bool) { + msgBytes, ok := <-p.FromRemoteCh + if !ok || p.RpcContext == nil { + return msgBytes, ok + } + var msg RpcMessage + err := json.Unmarshal(msgBytes, &msg) + if err != nil { + // nothing to do here -- will error out at another level + return msgBytes, true + } + msg.Data, err = recodeCommandData(msg.Command, msg.Data, p.RpcContext) + if err != nil { + // nothing to do here -- will error out at another level + return msgBytes, true + } + newBytes, err := json.Marshal(msg) + if err != nil { + // nothing to do here + return msgBytes, true + } + return newBytes, true +} diff --git a/pkg/wshutil/wshrouter.go b/pkg/wshutil/wshrouter.go new file mode 100644 index 000000000..d04ff5bd5 --- /dev/null +++ b/pkg/wshutil/wshrouter.go @@ -0,0 +1,212 @@ +// Copyright 2024, Command Line Inc. +// SPDX-License-Identifier: Apache-2.0 + +package wshutil + +import ( + "encoding/json" + "errors" + "fmt" + "strings" + "sync" + + "github.com/wavetermdev/thenextwave/pkg/wshrpc" +) + +type routeInfo struct { + RpcId string + SourceRouteId string + DestRouteId string +} + +type WshRouter struct { + Lock *sync.Mutex + DefaultRoute string + RouteMap map[string]AbstractRpcClient + RpcMap map[string]*routeInfo + InputCh chan []byte +} + +var DefaultRouter = NewWshRouter() + +func NewWshRouter() *WshRouter { + rtn := &WshRouter{ + Lock: &sync.Mutex{}, + RouteMap: make(map[string]AbstractRpcClient), + RpcMap: make(map[string]*routeInfo), + InputCh: make(chan []byte, DefaultInputChSize), + } + go rtn.runServer() + return rtn +} + +func noRouteErr(routeId string) error { + if routeId == "" { + return errors.New("no default route") + } + return fmt.Errorf("no route for %q", routeId) +} + +func (router *WshRouter) handleNoRoute(msg RpcMessage) { + nrErr := noRouteErr(msg.Route) + if msg.ReqId == "" { + if msg.Command == wshrpc.Command_Message { + // to prevent infinite loops + return + } + // no response needed, but send message back to source + respMsg := RpcMessage{Command: wshrpc.Command_Message, Route: msg.Source, Data: wshrpc.CommandMessageData{Message: nrErr.Error()}} + respBytes, _ := json.Marshal(respMsg) + router.InputCh <- respBytes + return + } + // send error response + response := RpcMessage{ + Route: msg.Source, + ResId: msg.ReqId, + Error: nrErr.Error(), + } + respBytes, _ := json.Marshal(response) + router.InputCh <- respBytes +} + +func (router *WshRouter) registerRouteInfo(rpcId string, sourceRouteId string, destRouteId string) { + router.Lock.Lock() + defer router.Lock.Unlock() + router.RpcMap[rpcId] = &routeInfo{RpcId: rpcId, SourceRouteId: sourceRouteId, DestRouteId: destRouteId} +} + +func (router *WshRouter) unregisterRouteInfo(rpcId string) { + router.Lock.Lock() + defer router.Lock.Unlock() + delete(router.RpcMap, rpcId) +} + +func (router *WshRouter) getRouteInfo(rpcId string) *routeInfo { + router.Lock.Lock() + defer router.Lock.Unlock() + return router.RpcMap[rpcId] +} + +func (router *WshRouter) runServer() { + for msgBytes := range router.InputCh { + var msg RpcMessage + err := json.Unmarshal(msgBytes, &msg) + if err != nil { + fmt.Println("error unmarshalling message: ", err) + continue + } + var routeId string + msg.Route, routeId = popRoute(msg.Route) + if msg.Command != "" { + // new comand, setup new rpc + rpc := router.GetRpc(routeId) + if rpc == nil { + router.handleNoRoute(msg) + continue + } + if msg.ReqId != "" { + router.registerRouteInfo(msg.ReqId, msg.Source, routeId) + } + rpc.SendRpcMessage(msgBytes) + continue + } + // look at reqid or resid to route correctly + if msg.ReqId != "" { + routeInfo := router.getRouteInfo(msg.ReqId) + if routeInfo == nil { + // no route info, nothing to do + continue + } + rpc := router.GetRpc(routeInfo.DestRouteId) + if rpc != nil { + rpc.SendRpcMessage(msgBytes) + } + continue + } else if msg.ResId != "" { + routeInfo := router.getRouteInfo(msg.ResId) + if routeInfo == nil { + // no route info, nothing to do + continue + } + rpc := router.GetRpc(routeInfo.SourceRouteId) + if rpc != nil { + rpc.SendRpcMessage(msgBytes) + } + if !msg.Cont { + router.unregisterRouteInfo(msg.ResId) + } + continue + } else { + // this is a bad message (no command, reqid, or resid) + continue + } + } +} + +func addRoute(curRoute string, newRoute string) string { + if curRoute == "" { + return newRoute + } + return curRoute + "," + newRoute +} + +// returns (newRoute, poppedRoute) +func popRoute(curRoute string) (string, string) { + routes := strings.Split(curRoute, ",") + if len(routes) == 1 { + return "", curRoute + } + return strings.Join(routes[:len(routes)-1], ","), routes[len(routes)-1] +} + +// this will also consume the output channel of the abstract client +func (router *WshRouter) RegisterRoute(routeId string, rpc AbstractRpcClient) { + router.Lock.Lock() + defer router.Lock.Unlock() + router.RouteMap[routeId] = rpc + go func() { + for { + msgBytes, ok := rpc.RecvRpcMessage() + if !ok { + break + } + var rpcMsg RpcMessage + err := json.Unmarshal(msgBytes, &rpcMsg) + if err != nil { + continue + } + if rpcMsg.Command != "" { + // new command, add source (for backward routing) + rpcMsg.Source = addRoute(rpcMsg.Source, routeId) + msgBytes, err = json.Marshal(rpcMsg) + if err != nil { + continue + } + } + router.InputCh <- msgBytes + } + }() +} + +func (router *WshRouter) UnregisterRoute(routeId string) { + router.Lock.Lock() + defer router.Lock.Unlock() + delete(router.RouteMap, routeId) +} + +func (router *WshRouter) SetDefaultRoute(routeId string) { + router.Lock.Lock() + defer router.Lock.Unlock() + router.DefaultRoute = routeId +} + +// this may return nil (returns default only for empty routeId) +func (router *WshRouter) GetRpc(routeId string) AbstractRpcClient { + router.Lock.Lock() + defer router.Lock.Unlock() + if routeId == "" { + routeId = router.DefaultRoute + } + return router.RouteMap[routeId] +} diff --git a/pkg/wshutil/wshrpc.go b/pkg/wshutil/wshrpc.go index 2479f047f..b7945728f 100644 --- a/pkg/wshutil/wshrpc.go +++ b/pkg/wshutil/wshrpc.go @@ -32,6 +32,11 @@ type ServerImpl interface { WshServerImpl() } +type AbstractRpcClient interface { + SendRpcMessage(msg []byte) + RecvRpcMessage() ([]byte, bool) // blocking +} + type WshRpc struct { Lock *sync.Mutex clientId string @@ -45,11 +50,16 @@ type WshRpc struct { } type wshRpcContextKey struct{} +type wshRpcRespHandlerContextKey struct{} func withWshRpcContext(ctx context.Context, wshRpc *WshRpc) context.Context { return context.WithValue(ctx, wshRpcContextKey{}, wshRpc) } +func withRespHandler(ctx context.Context, handler *RpcResponseHandler) context.Context { + return context.WithValue(ctx, wshRpcRespHandlerContextKey{}, handler) +} + func GetWshRpcFromContext(ctx context.Context) *WshRpc { rtn := ctx.Value(wshRpcContextKey{}) if rtn == nil { @@ -58,11 +68,38 @@ func GetWshRpcFromContext(ctx context.Context) *WshRpc { return rtn.(*WshRpc) } +func GetRpcSourceFromContext(ctx context.Context) string { + rtn := ctx.Value(wshRpcRespHandlerContextKey{}) + if rtn == nil { + return "" + } + return rtn.(*RpcResponseHandler).GetSource() +} + +func GetRpcResponseHandlerFromContext(ctx context.Context) *RpcResponseHandler { + rtn := ctx.Value(wshRpcRespHandlerContextKey{}) + if rtn == nil { + return nil + } + return rtn.(*RpcResponseHandler) +} + +func (w *WshRpc) SendRpcMessage(msg []byte) { + w.InputCh <- msg +} + +func (w *WshRpc) RecvRpcMessage() ([]byte, bool) { + msg, more := <-w.OutputCh + return msg, more +} + type RpcMessage struct { Command string `json:"command,omitempty"` ReqId string `json:"reqid,omitempty"` ResId string `json:"resid,omitempty"` Timeout int `json:"timeout,omitempty"` + Route string `json:"route,omitempty"` // to route/forward requests to alternate servers + Source string `json:"source,omitempty"` // source route id Cont bool `json:"cont,omitempty"` // flag if additional requests/responses are forthcoming Cancel bool `json:"cancel,omitempty"` // used to cancel a streaming request or response (sent from the side that is not streaming) Error string `json:"error,omitempty"` @@ -141,7 +178,6 @@ func validateServerImpl(serverImpl ServerImpl) { } } -// oscEsc is the OSC escape sequence to use for *sending* messages // closes outputCh when inputCh is closed/done func MakeWshRpc(inputCh chan []byte, outputCh chan []byte, rpcCtx wshrpc.RpcContext, serverImpl ServerImpl) *WshRpc { validateServerImpl(serverImpl) @@ -235,12 +271,14 @@ func (w *WshRpc) handleRequest(req *RpcMessage) { reqId: req.ReqId, command: req.Command, commandData: req.Data, + source: req.Source, done: &atomic.Bool{}, canceled: &atomic.Bool{}, contextCancelFn: &atomic.Pointer[context.CancelFunc]{}, rpcCtx: w.GetRpcContext(), } respHandler.contextCancelFn.Store(&cancelFn) + respHandler.ctx = withRespHandler(ctx, respHandler) w.registerResponseHandler(req.ReqId, respHandler) isAsync := false defer func() { @@ -347,8 +385,14 @@ func (w *WshRpc) unregisterRpc(reqId string, err error) { } // no response -func (w *WshRpc) SendCommand(command string, data any) error { - handler, err := w.SendComplexRequest(command, data, false, 0) +func (w *WshRpc) SendCommand(command string, data any, opts *wshrpc.RpcOpts) error { + var optsCopy wshrpc.RpcOpts + if opts != nil { + optsCopy = *opts + } + optsCopy.NoResponse = true + optsCopy.Timeout = 0 + handler, err := w.SendComplexRequest(command, data, &optsCopy) if err != nil { return err } @@ -357,8 +401,13 @@ func (w *WshRpc) SendCommand(command string, data any) error { } // single response -func (w *WshRpc) SendRpcRequest(command string, data any, timeoutMs int) (any, error) { - handler, err := w.SendComplexRequest(command, data, true, timeoutMs) +func (w *WshRpc) SendRpcRequest(command string, data any, opts *wshrpc.RpcOpts) (any, error) { + var optsCopy wshrpc.RpcOpts + if opts != nil { + optsCopy = *opts + } + optsCopy.NoResponse = false + handler, err := w.SendComplexRequest(command, data, &optsCopy) if err != nil { return nil, err } @@ -444,6 +493,7 @@ type RpcResponseHandler struct { ctx context.Context contextCancelFn *atomic.Pointer[context.CancelFunc] reqId string + source string command string commandData any rpcCtx wshrpc.RpcContext @@ -467,6 +517,10 @@ func (handler *RpcResponseHandler) GetRpcContext() wshrpc.RpcContext { return handler.rpcCtx } +func (handler *RpcResponseHandler) GetSource() string { + return handler.source +} + func (handler *RpcResponseHandler) NeedsResponse() bool { return handler.reqId != "" } @@ -559,7 +613,11 @@ func (handler *RpcResponseHandler) IsDone() bool { return handler.done.Load() } -func (w *WshRpc) SendComplexRequest(command string, data any, expectsResponse bool, timeoutMs int) (rtnHandler *RpcRequestHandler, rtnErr error) { +func (w *WshRpc) SendComplexRequest(command string, data any, opts *wshrpc.RpcOpts) (rtnHandler *RpcRequestHandler, rtnErr error) { + if opts == nil { + opts = &wshrpc.RpcOpts{} + } + timeoutMs := opts.Timeout if timeoutMs <= 0 { timeoutMs = DefaultTimeoutMs } @@ -579,7 +637,7 @@ func (w *WshRpc) SendComplexRequest(command string, data any, expectsResponse bo var cancelFn context.CancelFunc handler.ctx, cancelFn = context.WithTimeout(context.Background(), time.Duration(timeoutMs)*time.Millisecond) handler.ctxCancelFn.Store(&cancelFn) - if expectsResponse { + if !opts.NoResponse { handler.reqId = uuid.New().String() } req := &RpcMessage{ @@ -587,6 +645,7 @@ func (w *WshRpc) SendComplexRequest(command string, data any, expectsResponse bo ReqId: handler.reqId, Data: data, Timeout: timeoutMs, + Route: opts.Route, } barr, err := json.Marshal(req) if err != nil { diff --git a/pkg/wshutil/wshutil.go b/pkg/wshutil/wshutil.go index f2fd882df..cf4f15804 100644 --- a/pkg/wshutil/wshutil.go +++ b/pkg/wshutil/wshutil.go @@ -248,9 +248,6 @@ func MakeClientJWTToken(rpcCtx wshrpc.RpcContext, sockName string) (string, erro if rpcCtx.TabId != "" { claims["tabid"] = rpcCtx.TabId } - if rpcCtx.WindowId != "" { - claims["windowid"] = rpcCtx.WindowId - } token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims) tokenStr, err := token.SignedString([]byte(wavebase.JwtSecret)) if err != nil { @@ -302,11 +299,6 @@ func mapClaimsToRpcContext(claims jwt.MapClaims) *wshrpc.RpcContext { rpcCtx.TabId = tabId } } - if claims["windowid"] != nil { - if windowId, ok := claims["windowid"].(string); ok { - rpcCtx.WindowId = windowId - } - } return rpcCtx }