From 85874f92cab16fc4d14db3ed1e83ecf6d1c74a4a Mon Sep 17 00:00:00 2001 From: Mike Sawka Date: Sun, 18 Aug 2024 21:26:44 -0700 Subject: [PATCH] set up remote connserver (#245) --- ...{wshcmd-server.go => wshcmd-connserver.go} | 14 +- cmd/wsh/cmd/wshcmd-deleteblock.go | 9 +- cmd/wsh/cmd/wshcmd-getmeta.go | 9 +- cmd/wsh/cmd/wshcmd-html.go | 7 +- cmd/wsh/cmd/wshcmd-readfile.go | 9 +- cmd/wsh/cmd/wshcmd-root.go | 27 ++- cmd/wsh/cmd/wshcmd-setmeta.go | 9 +- cmd/wsh/cmd/wshcmd-term.go | 9 +- cmd/wsh/cmd/wshcmd-view.go | 9 +- pkg/blockcontroller/blockcontroller.go | 9 +- pkg/remote/conncontroller/conncontroller.go | 224 ++++++++++++++++++ pkg/remote/{conncontroller.go => connutil.go} | 152 +----------- pkg/shellexec/shellexec.go | 5 +- pkg/util/shellutil/shellutil.go | 5 + pkg/web/ws.go | 5 +- pkg/wshrpc/wshclient/wshclientutil.go | 22 +- pkg/wshrpc/wshrpctypes.go | 1 + pkg/wshutil/wshproxy.go | 12 +- pkg/wshutil/wshrouter.go | 9 + pkg/wshutil/wshrpcio.go | 4 +- pkg/wshutil/wshutil.go | 19 +- 21 files changed, 357 insertions(+), 212 deletions(-) rename cmd/wsh/cmd/{wshcmd-server.go => wshcmd-connserver.go} (59%) create mode 100644 pkg/remote/conncontroller/conncontroller.go rename pkg/remote/{conncontroller.go => connutil.go} (62%) diff --git a/cmd/wsh/cmd/wshcmd-server.go b/cmd/wsh/cmd/wshcmd-connserver.go similarity index 59% rename from cmd/wsh/cmd/wshcmd-server.go rename to cmd/wsh/cmd/wshcmd-connserver.go index 003a2e955..519a41a65 100644 --- a/cmd/wsh/cmd/wshcmd-server.go +++ b/cmd/wsh/cmd/wshcmd-connserver.go @@ -7,15 +7,15 @@ import ( "os" "github.com/spf13/cobra" - "github.com/wavetermdev/thenextwave/pkg/wshrpc/wshclient" "github.com/wavetermdev/thenextwave/pkg/wshrpc/wshremote" ) var serverCmd = &cobra.Command{ - Use: "server", - Short: "remote server to power wave blocks", - Args: cobra.NoArgs, - Run: serverRun, + Use: "connserver", + Short: "remote server to power wave blocks", + Args: cobra.NoArgs, + Run: serverRun, + PreRunE: preRunSetupRpcClient, } func init() { @@ -23,10 +23,8 @@ func init() { } func serverRun(cmd *cobra.Command, args []string) { - WriteStdout("running wsh server\n") + WriteStdout("running wsh connserver\n") RpcClient.SetServerImpl(&wshremote.ServerImpl{LogWriter: os.Stdout}) - err := wshclient.TestCommand(RpcClient, "hello", nil) - WriteStdout("got test rtn: %v\n", err) select {} // run forever } diff --git a/cmd/wsh/cmd/wshcmd-deleteblock.go b/cmd/wsh/cmd/wshcmd-deleteblock.go index c4e72d069..7a280cf30 100644 --- a/cmd/wsh/cmd/wshcmd-deleteblock.go +++ b/cmd/wsh/cmd/wshcmd-deleteblock.go @@ -9,10 +9,11 @@ import ( ) var deleteBlockCmd = &cobra.Command{ - Use: "deleteblock", - Short: "delete a block", - Args: cobra.ExactArgs(1), - Run: deleteBlockRun, + Use: "deleteblock", + Short: "delete a block", + Args: cobra.ExactArgs(1), + Run: deleteBlockRun, + PreRunE: preRunSetupRpcClient, } func init() { diff --git a/cmd/wsh/cmd/wshcmd-getmeta.go b/cmd/wsh/cmd/wshcmd-getmeta.go index 522f49ca0..82db28a87 100644 --- a/cmd/wsh/cmd/wshcmd-getmeta.go +++ b/cmd/wsh/cmd/wshcmd-getmeta.go @@ -12,10 +12,11 @@ import ( ) var getMetaCmd = &cobra.Command{ - Use: "getmeta", - Short: "get metadata for an entity", - Args: cobra.RangeArgs(1, 2), - Run: getMetaRun, + Use: "getmeta", + Short: "get metadata for an entity", + Args: cobra.RangeArgs(1, 2), + Run: getMetaRun, + PreRunE: preRunSetupRpcClient, } func init() { diff --git a/cmd/wsh/cmd/wshcmd-html.go b/cmd/wsh/cmd/wshcmd-html.go index 912a3b289..5ed7198d5 100644 --- a/cmd/wsh/cmd/wshcmd-html.go +++ b/cmd/wsh/cmd/wshcmd-html.go @@ -15,9 +15,10 @@ func init() { } var htmlCmd = &cobra.Command{ - Use: "html", - Short: "Launch a demo html-mode terminal", - Run: htmlRun, + Use: "html", + Short: "Launch a demo html-mode terminal", + Run: htmlRun, + PreRunE: preRunSetupRpcClient, } func htmlRun(cmd *cobra.Command, args []string) { diff --git a/cmd/wsh/cmd/wshcmd-readfile.go b/cmd/wsh/cmd/wshcmd-readfile.go index aaa5cee49..f9d2621fa 100644 --- a/cmd/wsh/cmd/wshcmd-readfile.go +++ b/cmd/wsh/cmd/wshcmd-readfile.go @@ -12,10 +12,11 @@ import ( ) var readFileCmd = &cobra.Command{ - Use: "readfile", - Short: "read a blockfile", - Args: cobra.ExactArgs(2), - Run: runReadFile, + Use: "readfile", + Short: "read a blockfile", + Args: cobra.ExactArgs(2), + Run: runReadFile, + PreRunE: preRunSetupRpcClient, } func init() { diff --git a/cmd/wsh/cmd/wshcmd-root.go b/cmd/wsh/cmd/wshcmd-root.go index f9880ac5a..7e4651227 100644 --- a/cmd/wsh/cmd/wshcmd-root.go +++ b/cmd/wsh/cmd/wshcmd-root.go @@ -23,9 +23,10 @@ import ( var ( rootCmd = &cobra.Command{ - Use: "wsh", - Short: "CLI tool to control Wave Terminal", - Long: `wsh is a small utility that lets you do cool things with Wave Terminal, right from the command line`, + Use: "wsh", + Short: "CLI tool to control Wave Terminal", + Long: `wsh is a small utility that lets you do cool things with Wave Terminal, right from the command line`, + SilenceUsage: true, } ) @@ -60,9 +61,17 @@ func WriteStdout(fmtStr string, args ...interface{}) { fmt.Print(output) } +func preRunSetupRpcClient(cmd *cobra.Command, args []string) error { + err := setupRpcClient(nil) + if err != nil { + return err + } + return nil +} + // returns the wrapped stdin and a new rpc client (that wraps the stdin input and stdout output) func setupRpcClient(serverImpl wshutil.ServerImpl) error { - jwtToken := os.Getenv("WAVETERM_JWT") + jwtToken := os.Getenv(wshutil.WaveJwtTokenVarName) if jwtToken == "" { wshutil.SetTermRawModeAndInstallShutdownHandlers(true) UsingTermWshMode = true @@ -71,7 +80,7 @@ func setupRpcClient(serverImpl wshutil.ServerImpl) error { } sockName, err := wshutil.ExtractUnverifiedSocketName(jwtToken) if err != nil { - return fmt.Errorf("error extracting socket name from WAVETERM_JWT: %v", err) + return fmt.Errorf("error extracting socket name from %s: %v", wshutil.WaveJwtTokenVarName, err) } RpcClient, err = wshutil.SetupDomainSocketRpcClient(sockName, serverImpl) if err != nil { @@ -158,13 +167,7 @@ func Execute() { wshutil.DoShutdown("", 0, false) } }() - err := setupRpcClient(nil) - if err != nil { - log.Printf("[error] %v\n", err) - wshutil.DoShutdown("", 1, true) - return - } - err = rootCmd.Execute() + err := rootCmd.Execute() if err != nil { log.Printf("[error] %v\n", err) wshutil.DoShutdown("", 1, true) diff --git a/cmd/wsh/cmd/wshcmd-setmeta.go b/cmd/wsh/cmd/wshcmd-setmeta.go index aec41663c..e32d28d2c 100644 --- a/cmd/wsh/cmd/wshcmd-setmeta.go +++ b/cmd/wsh/cmd/wshcmd-setmeta.go @@ -14,10 +14,11 @@ import ( ) var setMetaCmd = &cobra.Command{ - Use: "setmeta", - Short: "set metadata for an entity", - Args: cobra.MinimumNArgs(2), - Run: setMetaRun, + Use: "setmeta", + Short: "set metadata for an entity", + Args: cobra.MinimumNArgs(2), + Run: setMetaRun, + PreRunE: preRunSetupRpcClient, } func init() { diff --git a/cmd/wsh/cmd/wshcmd-term.go b/cmd/wsh/cmd/wshcmd-term.go index c3b83d2b5..46ea6fda9 100644 --- a/cmd/wsh/cmd/wshcmd-term.go +++ b/cmd/wsh/cmd/wshcmd-term.go @@ -15,10 +15,11 @@ import ( ) var termCmd = &cobra.Command{ - Use: "term", - Short: "open a terminal in directory", - Args: cobra.RangeArgs(0, 1), - Run: termRun, + Use: "term", + Short: "open a terminal in directory", + Args: cobra.RangeArgs(0, 1), + Run: termRun, + PreRunE: preRunSetupRpcClient, } func init() { diff --git a/cmd/wsh/cmd/wshcmd-view.go b/cmd/wsh/cmd/wshcmd-view.go index b3dcc368e..746685946 100644 --- a/cmd/wsh/cmd/wshcmd-view.go +++ b/cmd/wsh/cmd/wshcmd-view.go @@ -17,10 +17,11 @@ import ( var viewNewBlock bool var viewCmd = &cobra.Command{ - Use: "view", - Short: "preview a file or directory", - Args: cobra.ExactArgs(1), - Run: viewRun, + Use: "view", + Short: "preview a file or directory", + Args: cobra.ExactArgs(1), + Run: viewRun, + PreRunE: preRunSetupRpcClient, } func init() { diff --git a/pkg/blockcontroller/blockcontroller.go b/pkg/blockcontroller/blockcontroller.go index 4a452d413..7f08f08b7 100644 --- a/pkg/blockcontroller/blockcontroller.go +++ b/pkg/blockcontroller/blockcontroller.go @@ -18,6 +18,7 @@ import ( "github.com/wavetermdev/thenextwave/pkg/eventbus" "github.com/wavetermdev/thenextwave/pkg/filestore" "github.com/wavetermdev/thenextwave/pkg/remote" + "github.com/wavetermdev/thenextwave/pkg/remote/conncontroller" "github.com/wavetermdev/thenextwave/pkg/shellexec" "github.com/wavetermdev/thenextwave/pkg/wavebase" "github.com/wavetermdev/thenextwave/pkg/waveobj" @@ -277,7 +278,7 @@ func (bc *BlockController) DoRunShellCommand(rc *RunShellOpts, blockMeta waveobj if err != nil { return err } - conn, err := remote.GetConn(credentialCtx, opts) + conn, err := conncontroller.GetConn(credentialCtx, opts) if err != nil { return err } @@ -288,7 +289,7 @@ func (bc *BlockController) DoRunShellCommand(rc *RunShellOpts, blockMeta waveobj } cmdOpts.Env[wshutil.WaveJwtTokenVarName] = jwtStr } - shellProc, err = shellexec.StartRemoteShellProc(rc.TermSize, cmdStr, cmdOpts, conn) + shellProc, err = shellexec.StartRemoteShellProc(rc.TermSize, cmdStr, cmdOpts, conn.Client) if err != nil { return err } @@ -317,7 +318,7 @@ func (bc *BlockController) DoRunShellCommand(rc *RunShellOpts, blockMeta waveobj // we don't need to authenticate this wshProxy since it is coming direct wshProxy := wshutil.MakeRpcProxy() wshProxy.SetRpcContext(&wshrpc.RpcContext{TabId: bc.TabId, BlockId: bc.BlockId}) - wshutil.DefaultRouter.RegisterRoute("controller:"+bc.BlockId, wshProxy) + wshutil.DefaultRouter.RegisterRoute(wshutil.MakeControllerRouteId(bc.BlockId), wshProxy) ptyBuffer := wshutil.MakePtyBuffer(wshutil.WaveOSCPrefix, bc.ShellProc.Cmd, wshProxy.FromRemoteCh) go func() { // handles regular output from the pty (goes to the blockfile and xterm) @@ -376,7 +377,7 @@ func (bc *BlockController) DoRunShellCommand(rc *RunShellOpts, blockMeta waveobj go func() { // wait for the shell to finish defer func() { - wshutil.DefaultRouter.UnregisterRoute("controller:" + bc.BlockId) + wshutil.DefaultRouter.UnregisterRoute(wshutil.MakeControllerRouteId(bc.BlockId)) bc.UpdateControllerAndSendUpdate(func() bool { bc.ShellProcStatus = Status_Done return true diff --git a/pkg/remote/conncontroller/conncontroller.go b/pkg/remote/conncontroller/conncontroller.go new file mode 100644 index 000000000..b8f0169dc --- /dev/null +++ b/pkg/remote/conncontroller/conncontroller.go @@ -0,0 +1,224 @@ +// Copyright 2024, Command Line Inc. +// SPDX-License-Identifier: Apache-2.0 + +package conncontroller + +import ( + "context" + "fmt" + "io" + "log" + "net" + "strings" + "sync" + + "github.com/wavetermdev/thenextwave/pkg/remote" + "github.com/wavetermdev/thenextwave/pkg/userinput" + "github.com/wavetermdev/thenextwave/pkg/util/shellutil" + "github.com/wavetermdev/thenextwave/pkg/util/utilfn" + "github.com/wavetermdev/thenextwave/pkg/wavebase" + "github.com/wavetermdev/thenextwave/pkg/wshrpc" + "github.com/wavetermdev/thenextwave/pkg/wshutil" + "golang.org/x/crypto/ssh" +) + +var globalLock = &sync.Mutex{} +var clientControllerMap = make(map[remote.SSHOpts]*SSHConn) + +type SSHConn struct { + Lock *sync.Mutex + Opts *remote.SSHOpts + Client *ssh.Client + SockName string + DomainSockListener net.Listener + ConnController *ssh.Session +} + +func (conn *SSHConn) Close() error { + if conn.DomainSockListener != nil { + conn.DomainSockListener.Close() + conn.DomainSockListener = nil + } + if conn.ConnController != nil { + conn.ConnController.Close() + conn.ConnController = nil + } + err := conn.Client.Close() + conn.Client = nil + return err +} + +func (conn *SSHConn) OpenDomainSocketListener() error { + if conn.DomainSockListener != nil { + return nil + } + randStr, err := utilfn.RandomHexString(16) // 64-bits of randomness + if err != nil { + return fmt.Errorf("error generating random string: %w", err) + } + sockName := fmt.Sprintf("/tmp/waveterm-%s.sock", randStr) + log.Printf("remote domain socket %s %q\n", conn.Opts.String(), sockName) + listener, err := conn.Client.ListenUnix(sockName) + if err != nil { + return fmt.Errorf("unable to request connection domain socket: %v", err) + } + conn.SockName = sockName + conn.DomainSockListener = listener + go func() { + defer func() { + conn.Lock.Lock() + defer conn.Lock.Unlock() + conn.DomainSockListener = nil + }() + wshutil.RunWshRpcOverListener(listener) + }() + return nil +} + +func (conn *SSHConn) StartConnServer() error { + conn.Lock.Lock() + defer conn.Lock.Unlock() + if conn.ConnController != nil { + return nil + } + wshPath := remote.GetWshPath(conn.Client) + rpcCtx := wshrpc.RpcContext{ + Conn: conn.Opts.String(), + } + jwtToken, err := wshutil.MakeClientJWTToken(rpcCtx, conn.SockName) + if err != nil { + return fmt.Errorf("unable to create jwt token for conn controller: %w", err) + } + sshSession, err := conn.Client.NewSession() + if err != nil { + return fmt.Errorf("unable to create ssh session for conn controller: %w", err) + } + pipeRead, pipeWrite := io.Pipe() + sshSession.Stdout = pipeWrite + sshSession.Stderr = pipeWrite + conn.ConnController = sshSession + cmdStr := fmt.Sprintf("%s=\"%s\" %s connserver", wshutil.WaveJwtTokenVarName, jwtToken, wshPath) + log.Printf("starting conn controller: %s\n", cmdStr) + err = sshSession.Start(cmdStr) + if err != nil { + return fmt.Errorf("unable to start conn controller: %w", err) + } + // service the I/O + go func() { + // wait for termination, clear the controller + waitErr := sshSession.Wait() + log.Printf("conn controller (%q) terminated: %v", conn.Opts.String(), waitErr) + conn.Lock.Lock() + defer conn.Lock.Unlock() + conn.ConnController = nil + }() + go func() { + readErr := wshutil.StreamToLines(pipeRead, func(line []byte) { + lineStr := string(line) + if !strings.HasSuffix(lineStr, "\n") { + lineStr += "\n" + } + log.Printf("[conncontroller:%s:output] %s", conn.Opts.String(), lineStr) + }) + if readErr != nil && readErr != io.EOF { + log.Printf("[conncontroller:%s] error reading output: %v\n", conn.Opts.String(), readErr) + } + }() + return nil +} + +func (conn *SSHConn) checkAndInstallWsh(ctx context.Context) error { + client := conn.Client + // check that correct wsh extensions are installed + expectedVersion := fmt.Sprintf("wsh v%s", wavebase.WaveVersion) + clientVersion, err := remote.GetWshVersion(client) + if err == nil && clientVersion == expectedVersion { + return nil + } + var queryText string + var title string + if err != nil { + queryText = "Waveterm requires `wsh` shell extensions installed on your client to ensure a seamless experience. Would you like to install them?" + title = "Install Wsh Shell Extensions" + } else { + queryText = fmt.Sprintf("Waveterm requires `wsh` shell extensions installed on your client to be updated from %s to %s. Would you like to update?", clientVersion, expectedVersion) + title = "Update Wsh Shell Extensions" + } + request := &userinput.UserInputRequest{ + ResponseType: "confirm", + QueryText: queryText, + Title: title, + CheckBoxMsg: "Don't show me this again", + } + response, err := userinput.GetUserInput(ctx, request) + if err != nil || !response.Confirm { + return err + } + log.Printf("attempting to install wsh to `%s@%s`", client.User(), client.RemoteAddr().String()) + clientOs, err := remote.GetClientOs(client) + if err != nil { + return err + } + clientArch, err := remote.GetClientArch(client) + if err != nil { + return err + } + // attempt to install extension + wshLocalPath := shellutil.GetWshBinaryPath(wavebase.WaveVersion, clientOs, clientArch) + err = remote.CpHostToRemote(client, wshLocalPath, "~/.waveterm/bin/wsh") + if err != nil { + return err + } + log.Printf("successfully installed wsh on %s\n", conn.Opts.String()) + return nil +} + +func GetConn(ctx context.Context, opts *remote.SSHOpts) (*SSHConn, error) { + globalLock.Lock() + defer globalLock.Unlock() + + // attempt to retrieve if already opened + conn, ok := clientControllerMap[*opts] + if ok { + return conn, nil + } + + client, err := remote.ConnectToClient(ctx, opts) //todo specify or remove opts + if err != nil { + return nil, err + } + conn = &SSHConn{Lock: &sync.Mutex{}, Opts: opts, Client: client} + err = conn.OpenDomainSocketListener() + if err != nil { + conn.Close() + return nil, err + } + + installErr := conn.checkAndInstallWsh(ctx) + if installErr != nil { + conn.Close() + return nil, fmt.Errorf("conncontroller %s wsh install error: %v", conn.Opts.String(), installErr) + } + + csErr := conn.StartConnServer() + if csErr != nil { + conn.Close() + return nil, fmt.Errorf("conncontroller %s start wsh connserver error: %v", conn.Opts.String(), csErr) + } + + // save successful connection to map + clientControllerMap[*opts] = conn + + return conn, nil +} + +func DisconnectClient(opts *remote.SSHOpts) error { + globalLock.Lock() + defer globalLock.Unlock() + + client, ok := clientControllerMap[*opts] + if ok { + return client.Close() + } + return fmt.Errorf("client %v not found", opts) +} diff --git a/pkg/remote/conncontroller.go b/pkg/remote/connutil.go similarity index 62% rename from pkg/remote/conncontroller.go rename to pkg/remote/connutil.go index 00fd00976..263e31289 100644 --- a/pkg/remote/conncontroller.go +++ b/pkg/remote/connutil.go @@ -2,156 +2,20 @@ package remote import ( "bytes" - "context" "fmt" "html/template" "io" "log" - "net" "os" "path/filepath" "regexp" "strconv" "strings" - "sync" - "github.com/wavetermdev/thenextwave/pkg/userinput" - "github.com/wavetermdev/thenextwave/pkg/util/shellutil" - "github.com/wavetermdev/thenextwave/pkg/util/utilfn" - "github.com/wavetermdev/thenextwave/pkg/wavebase" - "github.com/wavetermdev/thenextwave/pkg/wshutil" "golang.org/x/crypto/ssh" ) var userHostRe = regexp.MustCompile(`^([a-zA-Z0-9][a-zA-Z0-9._@\\-]*@)?([a-z0-9][a-z0-9.-]*)(?::([0-9]+))?$`) -var globalLock = &sync.Mutex{} -var clientControllerMap = make(map[SSHOpts]*SSHConn) - -type SSHConn struct { - Lock *sync.Mutex - Opts *SSHOpts - Client *ssh.Client - SockName string - DomainSockListener net.Listener -} - -func (conn *SSHConn) Close() error { - if conn.DomainSockListener != nil { - conn.DomainSockListener.Close() - } - return conn.Client.Close() -} - -func (conn *SSHConn) OpenDomainSocketListener() error { - if conn.DomainSockListener != nil { - return nil - } - randStr, err := utilfn.RandomHexString(16) // 64-bits of randomness - if err != nil { - return fmt.Errorf("error generating random string: %w", err) - } - sockName := fmt.Sprintf("/tmp/waveterm-%s.sock", randStr) - log.Printf("remote domain socket %s %q\n", conn.Opts.String(), sockName) - listener, err := conn.Client.ListenUnix(sockName) - if err != nil { - return fmt.Errorf("unable to request connection domain socket: %v", err) - } - conn.SockName = sockName - conn.DomainSockListener = listener - go func() { - wshutil.RunWshRpcOverListener(listener) - }() - return nil -} - -func GetConn(ctx context.Context, opts *SSHOpts) (*SSHConn, error) { - globalLock.Lock() - defer globalLock.Unlock() - - // attempt to retrieve if already opened - conn, ok := clientControllerMap[*opts] - if ok { - return conn, nil - } - - client, err := ConnectToClient(ctx, opts) //todo specify or remove opts - if err != nil { - return nil, err - } - conn = &SSHConn{Lock: &sync.Mutex{}, Opts: opts, Client: client} - err = conn.OpenDomainSocketListener() - if err != nil { - conn.Close() - return nil, err - } - - // check that correct wsh extensions are installed - expectedVersion := fmt.Sprintf("wsh v%s", wavebase.WaveVersion) - clientVersion, err := getWshVersion(client) - if err == nil && clientVersion == expectedVersion { - // save successful connection to map - clientControllerMap[*opts] = conn - return conn, nil - } - - var queryText string - var title string - if err != nil { - queryText = "Waveterm requires `wsh` shell extensions installed on your client to ensure a seamless experience. Would you like to install them?" - title = "Install Wsh Shell Extensions" - } else { - queryText = fmt.Sprintf("Waveterm requires `wsh` shell extensions installed on your client to be updated from %s to %s. Would you like to update?", clientVersion, expectedVersion) - title = "Update Wsh Shell Extensions" - - } - - request := &userinput.UserInputRequest{ - ResponseType: "confirm", - QueryText: queryText, - Title: title, - CheckBoxMsg: "Don't show me this again", - } - response, err := userinput.GetUserInput(ctx, request) - if err != nil || !response.Confirm { - return nil, err - } - - log.Printf("attempting to install wsh to `%s@%s`", client.User(), client.RemoteAddr().String()) - - clientOs, err := getClientOs(client) - if err != nil { - return nil, err - } - - clientArch, err := getClientArch(client) - if err != nil { - return nil, err - } - - // attempt to install extension - wshLocalPath := shellutil.GetWshBinaryPath(wavebase.WaveVersion, clientOs, clientArch) - err = cpHostToRemote(client, wshLocalPath, "~/.waveterm/bin/wsh") - if err != nil { - return nil, err - } - log.Printf("successful install") - - // save successful connection to map - clientControllerMap[*opts] = conn - - return conn, nil -} - -func DisconnectClient(opts *SSHOpts) error { - globalLock.Lock() - defer globalLock.Unlock() - - client, ok := clientControllerMap[*opts] - if ok { - return client.Close() - } - return fmt.Errorf("client %v not found", opts) -} func ParseOpts(input string) (*SSHOpts, error) { m := userHostRe.FindStringSubmatch(input) @@ -173,7 +37,7 @@ func ParseOpts(input string) (*SSHOpts, error) { } func DetectShell(client *ssh.Client) (string, error) { - wshPath := getWshPath(client) + wshPath := GetWshPath(client) session, err := client.NewSession() if err != nil { @@ -191,8 +55,8 @@ func DetectShell(client *ssh.Client) (string, error) { return fmt.Sprintf(`"%s"`, strings.TrimSpace(string(out))), nil } -func getWshVersion(client *ssh.Client) (string, error) { - wshPath := getWshPath(client) +func GetWshVersion(client *ssh.Client) (string, error) { + wshPath := GetWshPath(client) session, err := client.NewSession() if err != nil { @@ -207,7 +71,7 @@ func getWshVersion(client *ssh.Client) (string, error) { return strings.TrimSpace(string(out)), nil } -func getWshPath(client *ssh.Client) string { +func GetWshPath(client *ssh.Client) string { defaultPath := filepath.Join("~", ".waveterm", "bin", "wsh") session, err := client.NewSession() @@ -267,7 +131,7 @@ func hasBashInstalled(client *ssh.Client) (bool, error) { return false, nil } -func getClientOs(client *ssh.Client) (string, error) { +func GetClientOs(client *ssh.Client) (string, error) { session, err := client.NewSession() if err != nil { return "", err @@ -306,7 +170,7 @@ func getClientOs(client *ssh.Client) (string, error) { return "", fmt.Errorf("unable to determine os: {unix: %s, cmd: %s, powershell: %s}", unixErr, cmdErr, psErr) } -func getClientArch(client *ssh.Client) (string, error) { +func GetClientArch(client *ssh.Client) (string, error) { session, err := client.NewSession() if err != nil { return "", err @@ -360,7 +224,7 @@ mv {{.tempPath}} {{.installPath}}; \ chmod a+x {{.installPath}}; \ ` -func cpHostToRemote(client *ssh.Client, sourcePath string, destPath string) error { +func CpHostToRemote(client *ssh.Client, sourcePath string, destPath string) error { // warning: does not work on windows remote yet bashInstalled, err := hasBashInstalled(client) if err != nil { @@ -414,7 +278,7 @@ func cpHostToRemote(client *ssh.Client, sourcePath string, destPath string) erro } func InstallClientRcFiles(client *ssh.Client) error { - path := getWshPath(client) + path := GetWshPath(client) session, err := client.NewSession() if err != nil { diff --git a/pkg/shellexec/shellexec.go b/pkg/shellexec/shellexec.go index d31fab0a3..701830ec8 100644 --- a/pkg/shellexec/shellexec.go +++ b/pkg/shellexec/shellexec.go @@ -24,6 +24,7 @@ import ( "github.com/wavetermdev/thenextwave/pkg/wavebase" "github.com/wavetermdev/thenextwave/pkg/wshutil" "github.com/wavetermdev/thenextwave/pkg/wstore" + "golang.org/x/crypto/ssh" ) type CommandOptsType struct { @@ -149,8 +150,7 @@ func (pp *PipePty) WriteString(s string) (n int, err error) { return pp.Write([]byte(s)) } -func StartRemoteShellProc(termSize wstore.TermSize, cmdStr string, cmdOpts CommandOptsType, conn *remote.SSHConn) (*ShellProc, error) { - client := conn.Client +func StartRemoteShellProc(termSize wstore.TermSize, cmdStr string, cmdOpts CommandOptsType, client *ssh.Client) (*ShellProc, error) { shellPath, err := remote.DetectShell(client) if err != nil { return nil, err @@ -195,6 +195,7 @@ func StartRemoteShellProc(termSize wstore.TermSize, cmdStr string, cmdOpts Comma } shellOpts = append(shellOpts, "-c", cmdStr) cmdCombined = fmt.Sprintf("%s %s", shellPath, strings.Join(shellOpts, " ")) + log.Printf("combined command is: %s", cmdCombined) } session, err := client.NewSession() diff --git a/pkg/util/shellutil/shellutil.go b/pkg/util/shellutil/shellutil.go index 7bb02bbab..858b88522 100644 --- a/pkg/util/shellutil/shellutil.go +++ b/pkg/util/shellutil/shellutil.go @@ -19,6 +19,7 @@ import ( "github.com/wavetermdev/thenextwave/pkg/util/utilfn" "github.com/wavetermdev/thenextwave/pkg/wavebase" + "github.com/wavetermdev/thenextwave/pkg/wstore" ) const DefaultTermType = "xterm-256color" @@ -125,6 +126,10 @@ func internalMacUserShell() string { return m[1] } +func DefaultTermSize() wstore.TermSize { + return wstore.TermSize{Rows: DefaultTermRows, Cols: DefaultTermCols} +} + func WaveshellLocalEnvVars(termType string) map[string]string { rtn := make(map[string]string) if termType != "" { diff --git a/pkg/web/ws.go b/pkg/web/ws.go index 324aa3d53..fa335631d 100644 --- a/pkg/web/ws.go +++ b/pkg/web/ws.go @@ -257,10 +257,9 @@ func HandleWsInternal(w http.ResponseWriter, r *http.Request) error { defer eventbus.UnregisterWSChannel(wsConnId) // we create a wshproxy to handle rpc messages to/from the window wproxy := wshutil.MakeRpcProxy() - rpcRouteId := "window:" + windowId - wshutil.DefaultRouter.RegisterRoute(rpcRouteId, wproxy) + wshutil.DefaultRouter.RegisterRoute(wshutil.MakeWindowRouteId(windowId), wproxy) defer func() { - wshutil.DefaultRouter.UnregisterRoute(rpcRouteId) + wshutil.DefaultRouter.UnregisterRoute(wshutil.MakeWindowRouteId(windowId)) close(wproxy.ToRemoteCh) }() // WshServerFactoryFn(rpcInputCh, rpcOutputCh, wshrpc.RpcContext{}) diff --git a/pkg/wshrpc/wshclient/wshclientutil.go b/pkg/wshrpc/wshclient/wshclientutil.go index 6ac547a8d..a5066a343 100644 --- a/pkg/wshrpc/wshclient/wshclientutil.go +++ b/pkg/wshrpc/wshclient/wshclientutil.go @@ -4,6 +4,8 @@ package wshclient import ( + "errors" + "github.com/wavetermdev/thenextwave/pkg/util/utilfn" "github.com/wavetermdev/thenextwave/pkg/wshrpc" "github.com/wavetermdev/thenextwave/pkg/wshutil" @@ -14,6 +16,9 @@ func sendRpcRequestCallHelper[T any](w *wshutil.WshRpc, command string, data int opts = &wshrpc.RpcOpts{} } var respData T + if w == nil { + return respData, errors.New("nil wshrpc passed to wshclient") + } if opts.NoResponse { err := w.SendCommand(command, data, opts) if err != nil { @@ -32,17 +37,26 @@ func sendRpcRequestCallHelper[T any](w *wshutil.WshRpc, command string, data int return respData, nil } +func rtnErr[T any](ch chan wshrpc.RespOrErrorUnion[T], err error) { + go func() { + ch <- wshrpc.RespOrErrorUnion[T]{Error: err} + close(ch) + }() +} + func sendRpcRequestResponseStreamHelper[T any](w *wshutil.WshRpc, command string, data interface{}, opts *wshrpc.RpcOpts) chan wshrpc.RespOrErrorUnion[T] { if opts == nil { opts = &wshrpc.RpcOpts{} } respChan := make(chan wshrpc.RespOrErrorUnion[T]) + if w == nil { + rtnErr(respChan, errors.New("nil wshrpc passed to wshclient")) + return respChan + } reqHandler, err := w.SendComplexRequest(command, data, opts) if err != nil { - go func() { - respChan <- wshrpc.RespOrErrorUnion[T]{Error: err} - close(respChan) - }() + rtnErr(respChan, err) + return respChan } else { go func() { defer close(respChan) diff --git a/pkg/wshrpc/wshrpctypes.go b/pkg/wshrpc/wshrpctypes.go index 91e567d10..e590b4ad2 100644 --- a/pkg/wshrpc/wshrpctypes.go +++ b/pkg/wshrpc/wshrpctypes.go @@ -107,6 +107,7 @@ type RpcOpts struct { type RpcContext struct { BlockId string `json:"blockid,omitempty"` TabId string `json:"tabid,omitempty"` + Conn string `json:"conn,omitempty"` } func HackRpcContextIntoData(dataPtr any, rpcContext RpcContext) { diff --git a/pkg/wshutil/wshproxy.go b/pkg/wshutil/wshproxy.go index a990e0f28..c64d5b785 100644 --- a/pkg/wshutil/wshproxy.go +++ b/pkg/wshutil/wshproxy.go @@ -6,6 +6,7 @@ package wshutil import ( "encoding/json" "fmt" + "log" "sync" "github.com/google/uuid" @@ -81,11 +82,13 @@ func handleAuthenticationCommand(msg RpcMessage) (*wshrpc.RpcContext, error) { if newCtx == nil { return nil, fmt.Errorf("no context found in jwt token") } - if newCtx.BlockId == "" { - return nil, fmt.Errorf("no blockId found in jwt token") + if newCtx.BlockId == "" && newCtx.Conn == "" { + return nil, fmt.Errorf("no blockid or conn found in jwt token") } - if _, err := uuid.Parse(newCtx.BlockId); err != nil { - return nil, fmt.Errorf("invalid blockId in jwt token") + if newCtx.BlockId != "" { + if _, err := uuid.Parse(newCtx.BlockId); err != nil { + return nil, fmt.Errorf("invalid blockId in jwt token") + } } return newCtx, nil } @@ -114,6 +117,7 @@ func (p *WshRpcProxy) HandleAuthentication() (*wshrpc.RpcContext, error) { } newCtx, err := handleAuthenticationCommand(msg) if err != nil { + log.Printf("error handling authentication: %v\n", err) p.sendResponseError(msg, err) continue } diff --git a/pkg/wshutil/wshrouter.go b/pkg/wshutil/wshrouter.go index 07413e7a3..9d2ad523b 100644 --- a/pkg/wshutil/wshrouter.go +++ b/pkg/wshutil/wshrouter.go @@ -42,6 +42,14 @@ func MakeConnectionRouteId(connId string) string { return "conn:" + connId } +func MakeControllerRouteId(blockId string) string { + return "controller:" + blockId +} + +func MakeWindowRouteId(windowId string) string { + return "window:" + windowId +} + var DefaultRouter = NewWshRouter() func NewWshRouter() *WshRouter { @@ -230,6 +238,7 @@ func (router *WshRouter) RegisterRoute(routeId string, rpc AbstractRpcClient) { log.Printf("error: WshRouter cannot register sys route\n") return } + log.Printf("registering wsh route %q\n", routeId) router.Lock.Lock() defer router.Lock.Unlock() router.RouteMap[routeId] = rpc diff --git a/pkg/wshutil/wshrpcio.go b/pkg/wshutil/wshrpcio.go index 9d9ef361d..00d564d71 100644 --- a/pkg/wshutil/wshrpcio.go +++ b/pkg/wshutil/wshrpcio.go @@ -43,7 +43,7 @@ func streamToLines_processBuf(lineBuf *lineBuf, readBuf []byte, lineFn func([]by } } -func streamToLines(input io.Reader, lineFn func([]byte)) error { +func StreamToLines(input io.Reader, lineFn func([]byte)) error { var lineBuf lineBuf readBuf := make([]byte, 16*1024) for { @@ -56,7 +56,7 @@ func streamToLines(input io.Reader, lineFn func([]byte)) error { } func AdaptStreamToMsgCh(input io.Reader, output chan []byte) error { - return streamToLines(input, func(line []byte) { + return StreamToLines(input, func(line []byte) { output <- line }) } diff --git a/pkg/wshutil/wshutil.go b/pkg/wshutil/wshutil.go index 5d4377b8e..cdc0f4d8b 100644 --- a/pkg/wshutil/wshutil.go +++ b/pkg/wshutil/wshutil.go @@ -248,6 +248,9 @@ func MakeClientJWTToken(rpcCtx wshrpc.RpcContext, sockName string) (string, erro if rpcCtx.TabId != "" { claims["tabid"] = rpcCtx.TabId } + if rpcCtx.Conn != "" { + claims["conn"] = rpcCtx.Conn + } token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims) tokenStr, err := token.SignedString([]byte(wavebase.JwtSecret)) if err != nil { @@ -299,6 +302,11 @@ func mapClaimsToRpcContext(claims jwt.MapClaims) *wshrpc.RpcContext { rpcCtx.TabId = tabId } } + if claims["conn"] != nil { + if conn, ok := claims["conn"].(string); ok { + rpcCtx.Conn = conn + } + } return rpcCtx } @@ -306,9 +314,12 @@ func RunWshRpcOverListener(listener net.Listener) { defer log.Printf("domain socket listener shutting down\n") for { conn, err := listener.Accept() + if err == io.EOF { + break + } if err != nil { log.Printf("error accepting connection: %v\n", err) - continue + break } log.Print("got domain socket connection\n") go handleDomainSocketClient(conn) @@ -337,7 +348,11 @@ func handleDomainSocketClient(conn net.Conn) { // now that we're authenticated, set the ctx and attach to the router log.Printf("domain socket connection authenticated: %#v\n", rpcCtx) proxy.SetRpcContext(rpcCtx) - DefaultRouter.RegisterRoute("controller:"+rpcCtx.BlockId, proxy) + if rpcCtx.BlockId != "" { + DefaultRouter.RegisterRoute(MakeControllerRouteId(rpcCtx.BlockId), proxy) + } else if rpcCtx.Conn != "" { + DefaultRouter.RegisterRoute(MakeConnectionRouteId(rpcCtx.Conn), proxy) + } } // only for use on client