diff --git a/cmd/server/main-server.go b/cmd/server/main-server.go index 9a7a0781c..09497770c 100644 --- a/cmd/server/main-server.go +++ b/cmd/server/main-server.go @@ -231,7 +231,7 @@ func main() { } } }() - go web.RunWebServer(unixListener) + go wshserver.RunWshRpcOverListener(unixListener) web.RunWebServer(webListener) // blocking runtime.KeepAlive(waveLock) } diff --git a/cmd/wsh/cmd/wshcmd-deleteblock.go b/cmd/wsh/cmd/wshcmd-deleteblock.go index a1a52307f..cc304920b 100644 --- a/cmd/wsh/cmd/wshcmd-deleteblock.go +++ b/cmd/wsh/cmd/wshcmd-deleteblock.go @@ -4,11 +4,8 @@ package cmd import ( - "fmt" - "github.com/spf13/cobra" "github.com/wavetermdev/thenextwave/pkg/wshrpc" - "github.com/wavetermdev/thenextwave/pkg/wshutil" ) var deleteBlockCmd = &cobra.Command{ @@ -25,22 +22,21 @@ func init() { func deleteBlockRun(cmd *cobra.Command, args []string) { oref := args[0] if oref == "" { - fmt.Println("oref is required") + WriteStderr("[error] oref is required\n") return } err := validateEasyORef(oref) if err != nil { - fmt.Printf("%v\n", err) + WriteStderr("[error]%v\n", err) return } - wshutil.SetTermRawModeAndInstallShutdownHandlers(true) fullORef, err := resolveSimpleId(oref) if err != nil { - fmt.Printf("error resolving oref: %v\r\n", err) + WriteStderr("[error] resolving oref: %v\n", err) return } if fullORef.OType != "block" { - fmt.Printf("oref is not a block\r\n") + WriteStderr("[error] oref is not a block\n") return } deleteBlockData := &wshrpc.CommandDeleteBlockData{ @@ -48,8 +44,8 @@ func deleteBlockRun(cmd *cobra.Command, args []string) { } _, err = RpcClient.SendRpcRequest(wshrpc.Command_DeleteBlock, deleteBlockData, 2000) if err != nil { - fmt.Printf("error deleting block: %v\r\n", err) + WriteStderr("[error] deleting block: %v\n", err) return } - fmt.Print("block deleted\r\n") + WriteStdout("block deleted\n") } diff --git a/cmd/wsh/cmd/wshcmd-getmeta.go b/cmd/wsh/cmd/wshcmd-getmeta.go index a7a5a892d..1f5efabaf 100644 --- a/cmd/wsh/cmd/wshcmd-getmeta.go +++ b/cmd/wsh/cmd/wshcmd-getmeta.go @@ -5,14 +5,10 @@ package cmd import ( "encoding/json" - "fmt" - "log" - "strings" "github.com/spf13/cobra" "github.com/wavetermdev/thenextwave/pkg/wshrpc" "github.com/wavetermdev/thenextwave/pkg/wshrpc/wshclient" - "github.com/wavetermdev/thenextwave/pkg/wshutil" ) var getMetaCmd = &cobra.Command{ @@ -29,24 +25,22 @@ func init() { func getMetaRun(cmd *cobra.Command, args []string) { oref := args[0] if oref == "" { - fmt.Println("oref is required") + WriteStderr("[error] oref is required") return } err := validateEasyORef(oref) if err != nil { - fmt.Printf("%v\n", err) + WriteStderr("[error] %v\n", err) return } - - wshutil.SetTermRawModeAndInstallShutdownHandlers(true) fullORef, err := resolveSimpleId(oref) if err != nil { - fmt.Printf("error resolving oref: %v\r\n", err) + WriteStderr("[error] resolving oref: %v\n", err) return } resp, err := wshclient.GetMetaCommand(RpcClient, wshrpc.CommandGetMetaData{ORef: *fullORef}, &wshrpc.WshRpcCommandOpts{Timeout: 2000}) if err != nil { - log.Printf("error getting metadata: %v\r\n", err) + WriteStderr("[error] getting metadata: %v\n", err) return } if len(args) > 1 { @@ -56,21 +50,18 @@ func getMetaRun(cmd *cobra.Command, args []string) { } outBArr, err := json.MarshalIndent(val, "", " ") if err != nil { - log.Printf("error formatting metadata: %v\r\n", err) - } - outStr := string(outBArr) - outStr = strings.ReplaceAll(outStr, "\n", "\r\n") - fmt.Print(outStr) - fmt.Print("\r\n") - } else { - outBArr, err := json.MarshalIndent(resp, "", " ") - if err != nil { - log.Printf("error formatting metadata: %v\r\n", err) + WriteStderr("[error] formatting metadata: %v\n", err) return } outStr := string(outBArr) - outStr = strings.ReplaceAll(outStr, "\n", "\r\n") - fmt.Print(outStr) - fmt.Print("\r\n") + WriteStdout(outStr + "\n") + } else { + outBArr, err := json.MarshalIndent(resp, "", " ") + if err != nil { + WriteStderr("[error] formatting metadata: %v\n", err) + return + } + outStr := string(outBArr) + WriteStdout(outStr + "\n") } } diff --git a/cmd/wsh/cmd/wshcmd-readfile.go b/cmd/wsh/cmd/wshcmd-readfile.go index f04beedc7..2d240301d 100644 --- a/cmd/wsh/cmd/wshcmd-readfile.go +++ b/cmd/wsh/cmd/wshcmd-readfile.go @@ -5,13 +5,10 @@ package cmd import ( "encoding/base64" - "fmt" - "os" "github.com/spf13/cobra" "github.com/wavetermdev/thenextwave/pkg/wshrpc" "github.com/wavetermdev/thenextwave/pkg/wshrpc/wshclient" - "github.com/wavetermdev/thenextwave/pkg/wshutil" ) var readFileCmd = &cobra.Command{ @@ -28,30 +25,28 @@ func init() { func runReadFile(cmd *cobra.Command, args []string) { oref := args[0] if oref == "" { - fmt.Fprintf(os.Stderr, "oref is required\r\n") + WriteStderr("[error] oref is required\n") return } err := validateEasyORef(oref) if err != nil { - fmt.Fprintf(os.Stderr, "%v\n", err) + WriteStderr("[error] %v\n", err) return } - - wshutil.SetTermRawModeAndInstallShutdownHandlers(true) fullORef, err := resolveSimpleId(oref) if err != nil { - fmt.Fprintf(os.Stderr, "error resolving oref: %v\r\n", err) + WriteStderr("error resolving oref: %v\n", err) return } resp64, err := wshclient.FileReadCommand(RpcClient, wshrpc.CommandFileData{ZoneId: fullORef.OID, FileName: args[1]}, &wshrpc.WshRpcCommandOpts{Timeout: 5000}) if err != nil { - fmt.Fprintf(os.Stderr, "error reading file: %v\r\n", err) + WriteStderr("[error] reading file: %v\n", err) return } resp, err := base64.StdEncoding.DecodeString(resp64) if err != nil { - fmt.Fprintf(os.Stderr, "error decoding file: %v\r\n", err) + WriteStderr("[error] decoding file: %v\n", err) return } - fmt.Print(string(resp)) + WriteStdout(string(resp)) } diff --git a/cmd/wsh/cmd/wshcmd-root.go b/cmd/wsh/cmd/wshcmd-root.go index 2ca5275a5..82ab1427f 100644 --- a/cmd/wsh/cmd/wshcmd-root.go +++ b/cmd/wsh/cmd/wshcmd-root.go @@ -29,8 +29,9 @@ var ( ) var usingHtmlMode bool -var WrappedStdin io.Reader +var WrappedStdin io.Reader = os.Stdin var RpcClient *wshutil.WshRpc +var UsingTermWshMode bool func extraShutdownFn() { if usingHtmlMode { @@ -42,15 +43,46 @@ func extraShutdownFn() { } } +func WriteStderr(fmtStr string, args ...interface{}) { + output := fmt.Sprintf(fmtStr, args...) + if UsingTermWshMode { + output = strings.ReplaceAll(output, "\n", "\r\n") + } + fmt.Fprint(os.Stderr, output) +} + +func WriteStdout(fmtStr string, args ...interface{}) { + output := fmt.Sprintf(fmtStr, args...) + if UsingTermWshMode { + output = strings.ReplaceAll(output, "\n", "\r\n") + } + fmt.Print(output) +} + // returns the wrapped stdin and a new rpc client (that wraps the stdin input and stdout output) -func setupRpcClient(handlerFn wshutil.CommandHandlerFnType) { - log.Printf("setup rpc client\r\n") - RpcClient, WrappedStdin = wshutil.SetupTerminalRpcClient(handlerFn) +func setupRpcClient(handlerFn wshutil.CommandHandlerFnType) error { + jwtToken := os.Getenv("WAVETERM_JWT") + if jwtToken == "" { + wshutil.SetTermRawModeAndInstallShutdownHandlers(true) + UsingTermWshMode = true + RpcClient, WrappedStdin = wshutil.SetupTerminalRpcClient(handlerFn) + return nil + } + sockName, err := wshutil.ExtractUnverifiedSocketName(jwtToken) + if err != nil { + return fmt.Errorf("error extracting socket name from WAVETERM_JWT: %v", err) + } + RpcClient, err = wshutil.SetupDomainSocketRpcClient(sockName, handlerFn) + if err != nil { + return fmt.Errorf("error setting up domain socket rpc client: %v", err) + } + wshclient.AuthenticateCommand(RpcClient, jwtToken, &wshrpc.WshRpcCommandOpts{NoResponse: true}) + // note we don't modify WrappedStdin here (just use os.Stdin) + return nil } func setTermHtmlMode() { wshutil.SetExtraShutdownFunc(extraShutdownFn) - wshutil.SetTermRawModeAndInstallShutdownHandlers(true) cmd := &wshrpc.CommandSetMetaData{ Meta: map[string]any{"term:mode": "html"}, } @@ -109,8 +141,18 @@ func resolveSimpleId(id string) (*waveobj.ORef, error) { } // Execute executes the root command. -func Execute() error { +func Execute() { defer wshutil.DoShutdown("", 0, false) - setupRpcClient(nil) - return rootCmd.Execute() + err := setupRpcClient(nil) + if err != nil { + log.Printf("[error] %v\n", err) + wshutil.DoShutdown("", 1, true) + return + } + err = rootCmd.Execute() + if err != nil { + log.Printf("[error] %v\n", err) + wshutil.DoShutdown("", 1, true) + return + } } diff --git a/cmd/wsh/cmd/wshcmd-setmeta.go b/cmd/wsh/cmd/wshcmd-setmeta.go index eebcdd293..9dfe02191 100644 --- a/cmd/wsh/cmd/wshcmd-setmeta.go +++ b/cmd/wsh/cmd/wshcmd-setmeta.go @@ -11,7 +11,6 @@ import ( "github.com/spf13/cobra" "github.com/wavetermdev/thenextwave/pkg/wshrpc" - "github.com/wavetermdev/thenextwave/pkg/wshutil" ) var setMetaCmd = &cobra.Command{ @@ -62,23 +61,22 @@ func setMetaRun(cmd *cobra.Command, args []string) { oref := args[0] metaSetsStrs := args[1:] if oref == "" { - fmt.Println("oref is required") + WriteStderr("[error] oref is required\n") return } err := validateEasyORef(oref) if err != nil { - fmt.Printf("%v\n", err) + WriteStderr("[error] %v\n", err) return } meta, err := parseMetaSets(metaSetsStrs) if err != nil { - fmt.Printf("%v\n", err) + WriteStderr("[error] %v\n", err) return } - wshutil.SetTermRawModeAndInstallShutdownHandlers(true) fullORef, err := resolveSimpleId(oref) if err != nil { - fmt.Printf("error resolving oref: %v\n", err) + WriteStderr("[error] resolving oref: %v\n", err) return } setMetaWshCmd := &wshrpc.CommandSetMetaData{ @@ -87,8 +85,8 @@ func setMetaRun(cmd *cobra.Command, args []string) { } _, err = RpcClient.SendRpcRequest(wshrpc.Command_SetMeta, setMetaWshCmd, 2000) if err != nil { - fmt.Printf("error setting metadata: %v\n", err) + WriteStderr("[error] setting metadata: %v\n", err) return } - fmt.Print("metadata set\r\n") + WriteStdout("metadata set\n") } diff --git a/cmd/wsh/cmd/wshcmd-version.go b/cmd/wsh/cmd/wshcmd-version.go index 44784b20b..0145d7c81 100644 --- a/cmd/wsh/cmd/wshcmd-version.go +++ b/cmd/wsh/cmd/wshcmd-version.go @@ -4,8 +4,6 @@ package cmd import ( - "fmt" - "github.com/spf13/cobra" ) @@ -17,6 +15,6 @@ var versionCmd = &cobra.Command{ Use: "version", Short: "Print the version number of wsh", Run: func(cmd *cobra.Command, args []string) { - fmt.Printf("wsh v0.1.0\n") + WriteStdout("wsh v0.1.0\n") }, } diff --git a/cmd/wsh/cmd/wshcmd-view.go b/cmd/wsh/cmd/wshcmd-view.go index 1a5ddc995..0faf2855d 100644 --- a/cmd/wsh/cmd/wshcmd-view.go +++ b/cmd/wsh/cmd/wshcmd-view.go @@ -5,13 +5,11 @@ package cmd import ( "io/fs" - "log" "os" "path/filepath" "github.com/spf13/cobra" "github.com/wavetermdev/thenextwave/pkg/wshrpc" - "github.com/wavetermdev/thenextwave/pkg/wshutil" "github.com/wavetermdev/thenextwave/pkg/wstore" ) @@ -33,18 +31,18 @@ func viewRun(cmd *cobra.Command, args []string) { fileArg := args[0] absFile, err := filepath.Abs(fileArg) if err != nil { - log.Printf("error getting absolute path: %v\n", err) + WriteStderr("[error] getting absolute path: %v\n", err) return } _, err = os.Stat(absFile) if err == fs.ErrNotExist { - log.Printf("file does not exist: %q\n", absFile) + WriteStderr("[error] file does not exist: %q\n", absFile) return } if err != nil { - log.Printf("error getting file info: %v\n", err) + WriteStderr("[error] getting file info: %v\n", err) + return } - wshutil.SetTermRawModeAndInstallShutdownHandlers(true) viewWshCmd := &wshrpc.CommandCreateBlockData{ BlockDef: &wstore.BlockDef{ Meta: map[string]interface{}{ @@ -55,7 +53,7 @@ func viewRun(cmd *cobra.Command, args []string) { } _, err = RpcClient.SendRpcRequest(wshrpc.Command_CreateBlock, viewWshCmd, 2000) if err != nil { - log.Printf("error running view command: %v\r\n", err) + WriteStderr("[error] running view command: %v\r\n", err) return } } diff --git a/cmd/wsh/main-wsh.go b/cmd/wsh/main-wsh.go index 12419fe10..0b6be3c2c 100644 --- a/cmd/wsh/main-wsh.go +++ b/cmd/wsh/main-wsh.go @@ -9,5 +9,4 @@ import ( func main() { cmd.Execute() - } diff --git a/pkg/blockcontroller/blockcontroller.go b/pkg/blockcontroller/blockcontroller.go index c15b5dd83..84a6a97d6 100644 --- a/pkg/blockcontroller/blockcontroller.go +++ b/pkg/blockcontroller/blockcontroller.go @@ -267,13 +267,13 @@ func (bc *BlockController) DoRunShellCommand(rc *RunShellOpts, blockMeta waveobj if err != nil { return fmt.Errorf("error making jwt token: %w", err) } - cmdOpts.Env["WAVETERM_JWT"] = jwtStr + cmdOpts.Env[wshutil.WaveJwtTokenVarName] = jwtStr } else { jwtStr, err := wshutil.MakeClientJWTToken(wshrpc.RpcContext{TabId: bc.TabId, BlockId: bc.BlockId}, wavebase.GetDomainSocketName()) if err != nil { return fmt.Errorf("error making jwt token: %w", err) } - cmdOpts.Env["WAVETERM_JWT"] = jwtStr + cmdOpts.Env[wshutil.WaveJwtTokenVarName] = jwtStr } } if bc.ControllerType == BlockController_Shell { diff --git a/pkg/wshrpc/wshrpctypes.go b/pkg/wshrpc/wshrpctypes.go index 04be12c75..cde23d94b 100644 --- a/pkg/wshrpc/wshrpctypes.go +++ b/pkg/wshrpc/wshrpctypes.go @@ -94,6 +94,9 @@ type RpcContext struct { func HackRpcContextIntoData(dataPtr any, rpcContext RpcContext) { dataVal := reflect.ValueOf(dataPtr).Elem() + if dataVal.Kind() != reflect.Struct { + return + } dataType := dataVal.Type() for i := 0; i < dataVal.NumField(); i++ { field := dataVal.Field(i) diff --git a/pkg/wshrpc/wshserver/wshserverutil.go b/pkg/wshrpc/wshserver/wshserverutil.go index dc876f73d..20824e86e 100644 --- a/pkg/wshrpc/wshserver/wshserverutil.go +++ b/pkg/wshrpc/wshserver/wshserverutil.go @@ -8,11 +8,9 @@ import ( "fmt" "log" "net" - "os" "reflect" "github.com/wavetermdev/thenextwave/pkg/util/utilfn" - "github.com/wavetermdev/thenextwave/pkg/wavebase" "github.com/wavetermdev/thenextwave/pkg/wshrpc" "github.com/wavetermdev/thenextwave/pkg/wshutil" ) @@ -159,48 +157,18 @@ func mainWshServerHandler(handler *wshutil.RpcResponseHandler) bool { } } -func MakeUnixListener(sockName string) (net.Listener, error) { - os.Remove(sockName) // ignore error - rtn, err := net.Listen("unix", sockName) - if err != nil { - return nil, fmt.Errorf("error creating listener at %v: %v", sockName, err) - } - os.Chmod(sockName, 0700) - log.Printf("Server listening on %s\n", sockName) - return rtn, nil -} - -func runWshRpcWithStream(conn net.Conn) { - defer conn.Close() - inputCh := make(chan []byte, DefaultInputChSize) - outputCh := make(chan []byte, DefaultOutputChSize) - go wshutil.AdaptMsgChToStream(outputCh, conn) - go wshutil.AdaptStreamToMsgCh(conn, inputCh) - wshutil.MakeWshRpc(inputCh, outputCh, wshrpc.RpcContext{}, mainWshServerHandler) -} - func RunWshRpcOverListener(listener net.Listener) { - go func() { - for { - conn, err := listener.Accept() - if err != nil { - log.Printf("error accepting connection: %v\n", err) - continue - } - go runWshRpcWithStream(conn) + defer log.Printf("domain socket listener shutting down\n") + for { + conn, err := listener.Accept() + if err != nil { + log.Printf("error accepting connection: %v\n", err) + continue } - }() -} - -func RunDomainSocketWshServer() error { - sockName := wavebase.GetDomainSocketName() - listener, err := MakeUnixListener(sockName) - if err != nil { - return fmt.Errorf("error starging unix listener for wsh-server: %w", err) + log.Print("got domain socket connection\n") + // TODO deal with closing connection + go wshutil.SetupConnRpcClient(conn, mainWshServerHandler) } - defer listener.Close() - RunWshRpcOverListener(listener) - return nil } func MakeWshServer(inputCh chan []byte, outputCh chan []byte, initialCtx wshrpc.RpcContext) { diff --git a/pkg/wshutil/wshrpcio.go b/pkg/wshutil/wshrpcio.go index c976696dc..9d9ef361d 100644 --- a/pkg/wshutil/wshrpcio.go +++ b/pkg/wshutil/wshrpcio.go @@ -43,28 +43,32 @@ func streamToLines_processBuf(lineBuf *lineBuf, readBuf []byte, lineFn func([]by } } -func streamToLines(input io.Reader, lineFn func([]byte)) { +func streamToLines(input io.Reader, lineFn func([]byte)) error { var lineBuf lineBuf readBuf := make([]byte, 16*1024) for { n, err := input.Read(readBuf) streamToLines_processBuf(&lineBuf, readBuf[:n], lineFn) if err != nil { - break + return err } } } -func AdaptStreamToMsgCh(input io.Reader, output chan []byte) { - streamToLines(input, func(line []byte) { +func AdaptStreamToMsgCh(input io.Reader, output chan []byte) error { + return streamToLines(input, func(line []byte) { output <- line }) } -func AdaptMsgChToStream(outputCh chan []byte, output io.Writer) error { +func AdaptOutputChToStream(outputCh chan []byte, output io.Writer) error { for msg := range outputCh { if _, err := output.Write(msg); err != nil { - return fmt.Errorf("error writing to output: %w", err) + return fmt.Errorf("error writing to output (AdaptOutputChToStream): %w", err) + } + // write trailing newline + if _, err := output.Write([]byte{'\n'}); err != nil { + return fmt.Errorf("error writing trailing newline to output (AdaptOutputChToStream): %w", err) } } return nil diff --git a/pkg/wshutil/wshutil.go b/pkg/wshutil/wshutil.go index e1919e5d8..7a478cc79 100644 --- a/pkg/wshutil/wshutil.go +++ b/pkg/wshutil/wshutil.go @@ -9,6 +9,7 @@ import ( "fmt" "io" "log" + "net" "os" "os/signal" "sync" @@ -35,6 +36,11 @@ const BEL = 0x07 const ST = 0x9c const ESC = 0x1b +const DefaultOutputChSize = 32 +const DefaultInputChSize = 32 + +const WaveJwtTokenVarName = "WAVETERM_JWT" + // OSC escape types // OSC 23198 ; (JSON | base64-JSON) ST // JSON = must escape all ASCII control characters ([\x00-\x1F\x7F]) @@ -181,8 +187,8 @@ func RestoreTermState() { // returns (wshRpc, wrappedStdin) func SetupTerminalRpcClient(handlerFn func(*RpcResponseHandler) bool) (*WshRpc, io.Reader) { - messageCh := make(chan []byte, 32) - outputCh := make(chan []byte, 32) + messageCh := make(chan []byte, DefaultInputChSize) + outputCh := make(chan []byte, DefaultOutputChSize) ptyBuf := MakePtyBuffer(WaveServerOSCPrefix, os.Stdin, messageCh) rpcClient := MakeWshRpc(messageCh, outputCh, wshrpc.RpcContext{}, handlerFn) go func() { @@ -194,6 +200,42 @@ func SetupTerminalRpcClient(handlerFn func(*RpcResponseHandler) bool) (*WshRpc, return rpcClient, ptyBuf } +func SetupConnRpcClient(conn net.Conn, handlerFn func(*RpcResponseHandler) bool) (*WshRpc, chan error, error) { + inputCh := make(chan []byte, DefaultInputChSize) + outputCh := make(chan []byte, DefaultOutputChSize) + writeErrCh := make(chan error, 1) + go func() { + writeErr := AdaptOutputChToStream(outputCh, conn) + if writeErr != nil { + writeErrCh <- writeErr + close(writeErrCh) + } + }() + go func() { + // when input is closed, close the connection + defer conn.Close() + AdaptStreamToMsgCh(conn, inputCh) + }() + rtn := MakeWshRpc(inputCh, outputCh, wshrpc.RpcContext{}, handlerFn) + return rtn, writeErrCh, nil +} + +func SetupDomainSocketRpcClient(sockName string, handlerFn func(*RpcResponseHandler) bool) (*WshRpc, error) { + conn, err := net.Dial("unix", sockName) + if err != nil { + return nil, fmt.Errorf("failed to connect to Unix domain socket: %w", err) + } + rtn, errCh, err := SetupConnRpcClient(conn, handlerFn) + go func() { + defer conn.Close() + err := <-errCh + if err != nil && err != io.EOF { + log.Printf("error in domain socket connection: %v\n", err) + } + }() + return rtn, err +} + func MakeClientJWTToken(rpcCtx wshrpc.RpcContext, sockName string) (string, error) { claims := jwt.MapClaims{} claims["iat"] = time.Now().Unix() @@ -246,9 +288,21 @@ func ValidateAndExtractRpcContextFromToken(tokenStr string) (*wshrpc.RpcContext, return nil, fmt.Errorf("iss claim is missing or invalid") } rpcCtx := &wshrpc.RpcContext{} - rpcCtx.BlockId = claims["blockid"].(string) - rpcCtx.TabId = claims["tabid"].(string) - rpcCtx.WindowId = claims["windowid"].(string) + if claims["blockid"] != nil { + if blockId, ok := claims["blockid"].(string); ok { + rpcCtx.BlockId = blockId + } + } + if claims["tabid"] != nil { + if tabId, ok := claims["tabid"].(string); ok { + rpcCtx.TabId = tabId + } + } + if claims["windowid"] != nil { + if windowId, ok := claims["windowid"].(string); ok { + rpcCtx.WindowId = windowId + } + } return rpcCtx, nil }