From 6c115716b0e18f13e2671be594fe7a12c5e7b08e Mon Sep 17 00:00:00 2001 From: Sylvie Crowe <107814465+oneirocosm@users.noreply.github.com> Date: Thu, 29 Feb 2024 11:37:03 -0800 Subject: [PATCH] Integrate SSH Library with Waveshell Installation/Auto-update (#322) * refactor launch code to integrate install easier The previous set up of launch was difficult to navigate. This makes it much clearer which will make the auto install flow easier to manage. * feat: integrate auto install into new ssh setup This change makes it possible to auto install using the ssh library instead of making a call to the ssh cli command. This will auto install if the installed waveshell version is incorrect or cannot be found. * chore: clean up some lints for sshclient There was a context that didn't have it's cancel function deferred and an error that wasn't being handle. They're fixed now. * fix: disconnect client if requested or launch fail A recent commit made it so a client remained part of the MShellProc after being disconnected. This is undesireable since a manual disconnection indicates that the user will need to enter their credentials again if required. Similarly, if the launch fails with an error, the expectation is that credentials will have to be entered again. * fix: use legacy timer for the time being The legacy timer frustrates me because it adds a lot of state to the MShellProc struct that is complicated to manage. But it currently works, so I will be keeping it for the time being. * fix: change separator between remoteref and name With the inclusion of the port number in the canonical id, the : separator between the remoteref and remote name causes problems if the port is parsed instead. This changes it to a # in order to avoid this conflict. * fix: check for null when closing extra files It is possible for the list of extra files to contain null files. This change ensures the null files will not be erroneously closed. * fix: change connecting method to show port once With port added to the canonicalname, it no longer makes sense to append the port afterward. * feat: use user input modal for sudo connection The sudo connection used to have a unique way of entering a password. This change provides an alternative method using the user input modal that the other connection methods use. It does not work perfectly with this revision, but the basic building blocks are in place. It needs a few timer updates to be complete. * fix: remove old timer to prevent conflicts with it With this change the old timer is no longer needed. It is not fully removed yet, but it is disabled so as to not get in the way. Additionally, error handling has been slightly improved. There is still a bug where an incorrect password prints a new password prompt after the error message. That needs to be fixed in the future. --- go.work.sum | 4 +- waveshell/pkg/server/server.go | 2 +- waveshell/pkg/shexec/client.go | 112 ++++-- waveshell/pkg/shexec/shexec.go | 12 +- wavesrv/pkg/cmdrunner/resolver.go | 2 +- wavesrv/pkg/remote/remote.go | 548 +++++++++++++++++++----------- wavesrv/pkg/remote/sshclient.go | 7 +- 7 files changed, 451 insertions(+), 236 deletions(-) diff --git a/go.work.sum b/go.work.sum index f41591a7f..b26bd90d3 100644 --- a/go.work.sum +++ b/go.work.sum @@ -1,9 +1,7 @@ github.com/aws/aws-sdk-go-v2/service/s3 v1.27.11 h1:3/gm/JTX9bX8CpzTgIlrtYpB3EVBDxyg/GY/QdcIEZw= -github.com/google/go-querystring v1.1.0 h1:AnCroh3fv4ZBgVIf1Iwtovgjaw/GiKJo8M8yD/fhyJ8= +github.com/google/go-github/v57 v57.0.0 h1:L+Y3UPTY8ALM8x+TV0lg+IEBI+upibemtBD8Q9u7zHs= github.com/kevinburke/ssh_config v1.2.0 h1:x584FjTGwHzMwvHx18PXxbBVzfnxogHaAReU4gf13a4= github.com/kevinburke/ssh_config v1.2.0/go.mod h1:CT57kijsi8u/K/BOFA39wgDQJ9CxiF4nAY/ojJ6r6mM= -github.com/wavetermdev/ssh_config v0.0.0-20240109090616-36c8da3d7376 h1:tFhJgTu7lgd+hldLfPSzDCoWUpXI8wHKR3rxq5jTLkQ= -github.com/wavetermdev/ssh_config v0.0.0-20240109090616-36c8da3d7376/go.mod h1:q2RIzfka+BXARoNexmF9gkxEX7DmvbW9P4hIVx2Kg4M= github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY= golang.org/x/net v0.10.0/go.mod h1:0qNGK6F8kojg2nk9dLZ2mShWaEBan6FAoqfSigmmuDg= golang.org/x/sync v0.2.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= diff --git a/waveshell/pkg/server/server.go b/waveshell/pkg/server/server.go index b331d7370..d3729eb16 100644 --- a/waveshell/pkg/server/server.go +++ b/waveshell/pkg/server/server.go @@ -652,7 +652,7 @@ func (m *MServer) runCommand(runPacket *packet.RunPacketType) { m.Sender.SendErrorResponse(runPacket.ReqId, fmt.Errorf("server run packets require valid ck: %s", err)) return } - cproc, _, err := shexec.MakeClientProc(context.Background(), shexec.CmdWrap{Cmd: ecmd}) + cproc, err := shexec.MakeClientProc(context.Background(), shexec.CmdWrap{Cmd: ecmd}) if err != nil { m.Sender.SendErrorResponse(runPacket.ReqId, fmt.Errorf("starting mshell client: %s", err)) return diff --git a/waveshell/pkg/shexec/client.go b/waveshell/pkg/shexec/client.go index 61a28454a..9efcf2250 100644 --- a/waveshell/pkg/shexec/client.go +++ b/waveshell/pkg/shexec/client.go @@ -57,9 +57,28 @@ func (cw CmdWrap) Parser() (*packet.PacketParser, io.ReadCloser, io.ReadCloser, } func (cw CmdWrap) Start() error { + defer func() { + for _, extraFile := range cw.Cmd.ExtraFiles { + if extraFile != nil { + extraFile.Close() + } + } + }() return cw.Cmd.Start() } +func (cw CmdWrap) StdinPipe() (io.WriteCloser, error) { + return cw.Cmd.StdinPipe() +} + +func (cw CmdWrap) StdoutPipe() (io.ReadCloser, error) { + return cw.Cmd.StdoutPipe() +} + +func (cw CmdWrap) StderrPipe() (io.ReadCloser, error) { + return cw.Cmd.StderrPipe() +} + type SessionWrap struct { Session *ssh.Session StartCmd string @@ -101,12 +120,35 @@ func (sw SessionWrap) Parser() (*packet.PacketParser, io.ReadCloser, io.ReadClos return packetParser, io.NopCloser(stdoutReader), io.NopCloser(stderrReader), nil } +func (sw SessionWrap) StdinPipe() (io.WriteCloser, error) { + return sw.Session.StdinPipe() +} + +func (sw SessionWrap) StdoutPipe() (io.ReadCloser, error) { + stdoutReader, err := sw.Session.StdoutPipe() + if err != nil { + return nil, err + } + return io.NopCloser(stdoutReader), nil +} + +func (sw SessionWrap) StderrPipe() (io.ReadCloser, error) { + stderrReader, err := sw.Session.StderrPipe() + if err != nil { + return nil, err + } + return io.NopCloser(stderrReader), nil +} + type ConnInterface interface { Kill() Wait() error Sender() (*packet.PacketSender, io.WriteCloser, error) Parser() (*packet.PacketParser, io.ReadCloser, io.ReadCloser, error) Start() error + StdinPipe() (io.WriteCloser, error) + StdoutPipe() (io.ReadCloser, error) + StderrPipe() (io.ReadCloser, error) } type ClientProc struct { @@ -120,20 +162,44 @@ type ClientProc struct { Output *packet.PacketParser } +type WaveshellLaunchError struct { + InitPk *packet.InitPacketType +} + +func (wle WaveshellLaunchError) Error() string { + if wle.InitPk.NotFound { + return "waveshell client not found" + } else if semver.MajorMinor(wle.InitPk.Version) != semver.MajorMinor(base.MShellVersion) { + return fmt.Sprintf("invalid remote waveshell version '%s', must be '=%s'", wle.InitPk.Version, semver.MajorMinor(base.MShellVersion)) + } + return fmt.Sprintf("invalid waveshell: init packet=%v", *wle.InitPk) +} + +type InvalidPacketError struct { + InvalidPk *packet.PacketType +} + +func (ipe InvalidPacketError) Error() string { + if ipe.InvalidPk == nil { + return "no init packet received from waveshell client" + } + return fmt.Sprintf("invalid packet received from waveshell client: %s", packet.AsString(*ipe.InvalidPk)) +} + // returns (clientproc, initpk, error) -func MakeClientProc(ctx context.Context, ecmd ConnInterface) (*ClientProc, *packet.InitPacketType, error) { +func MakeClientProc(ctx context.Context, ecmd ConnInterface) (*ClientProc, error) { startTs := time.Now() sender, inputWriter, err := ecmd.Sender() if err != nil { - return nil, nil, err + return nil, err } packetParser, stdoutReader, stderrReader, err := ecmd.Parser() if err != nil { - return nil, nil, err + return nil, err } err = ecmd.Start() if err != nil { - return nil, nil, fmt.Errorf("running local client: %w", err) + return nil, fmt.Errorf("running local client: %w", err) } cproc := &ClientProc{ Cmd: ecmd, @@ -150,29 +216,27 @@ func MakeClientProc(ctx context.Context, ecmd ConnInterface) (*ClientProc, *pack case pk = <-packetParser.MainCh: case <-ctx.Done(): cproc.Close() - return nil, nil, ctx.Err() + return nil, ctx.Err() } - if pk != nil { - if pk.GetType() != packet.InitPacketStr { - cproc.Close() - return nil, nil, fmt.Errorf("invalid packet received from mshell client: %s", packet.AsString(pk)) - } - initPk := pk.(*packet.InitPacketType) - if initPk.NotFound { - cproc.Close() - return nil, initPk, fmt.Errorf("mshell client not found") - } - if semver.MajorMinor(initPk.Version) != semver.MajorMinor(base.MShellVersion) { - cproc.Close() - return nil, initPk, fmt.Errorf("invalid remote mshell version '%s', must be '=%s'", initPk.Version, semver.MajorMinor(base.MShellVersion)) - } - cproc.InitPk = initPk - } - if cproc.InitPk == nil { + if pk == nil { cproc.Close() - return nil, nil, fmt.Errorf("no init packet received from mshell client") + return nil, InvalidPacketError{} } - return cproc, cproc.InitPk, nil + if pk.GetType() != packet.InitPacketStr { + cproc.Close() + return nil, InvalidPacketError{InvalidPk: &pk} + } + initPk := pk.(*packet.InitPacketType) + if initPk.NotFound { + cproc.Close() + return nil, WaveshellLaunchError{InitPk: initPk} + } + if semver.MajorMinor(initPk.Version) != semver.MajorMinor(base.MShellVersion) { + cproc.Close() + return nil, WaveshellLaunchError{InitPk: initPk} + } + cproc.InitPk = initPk + return cproc, nil } func (cproc *ClientProc) Close() { diff --git a/waveshell/pkg/shexec/shexec.go b/waveshell/pkg/shexec/shexec.go index 6a858b786..2264440b2 100644 --- a/waveshell/pkg/shexec/shexec.go +++ b/waveshell/pkg/shexec/shexec.go @@ -437,6 +437,16 @@ func MakeMShellSingleCmd() (*exec.Cmd, error) { return ecmd, nil } +func MakeLocalExecCmd(cmdStr string, sapi shellapi.ShellApi) *exec.Cmd { + homeDir, _ := os.UserHomeDir() // ignore error + if homeDir == "" { + homeDir = "/" + } + ecmd := exec.Command(sapi.GetLocalShellPath(), "-c", cmdStr) + ecmd.Dir = homeDir + return ecmd +} + func (opts SSHOpts) MakeSSHExecCmd(remoteCommand string, sapi shellapi.ShellApi) *exec.Cmd { remoteCommand = strings.TrimSpace(remoteCommand) if opts.SSHHost == "" { @@ -587,7 +597,7 @@ func sendMShellBinary(input io.WriteCloser, mshellStream io.Reader) { }() } -func RunInstallFromCmd(ctx context.Context, ecmd *exec.Cmd, tryDetect bool, mshellStream io.Reader, mshellReaderFn MShellBinaryReaderFn, msgFn func(string)) error { +func RunInstallFromCmd(ctx context.Context, ecmd ConnInterface, tryDetect bool, mshellStream io.Reader, mshellReaderFn MShellBinaryReaderFn, msgFn func(string)) error { inputWriter, err := ecmd.StdinPipe() if err != nil { return fmt.Errorf("creating stdin pipe: %v", err) diff --git a/wavesrv/pkg/cmdrunner/resolver.go b/wavesrv/pkg/cmdrunner/resolver.go index 265cf1936..15c8854f9 100644 --- a/wavesrv/pkg/cmdrunner/resolver.go +++ b/wavesrv/pkg/cmdrunner/resolver.go @@ -443,7 +443,7 @@ func parseFullRemoteRef(fullRemoteRef string) (string, string, string, error) { if strings.HasPrefix(fullRemoteRef, "[") && strings.HasSuffix(fullRemoteRef, "]") { fullRemoteRef = fullRemoteRef[1 : len(fullRemoteRef)-1] } - fields := strings.Split(fullRemoteRef, ":") + fields := strings.Split(fullRemoteRef, "#") if len(fields) > 3 { return "", "", "", fmt.Errorf("invalid remote format '%s'", fullRemoteRef) } diff --git a/wavesrv/pkg/remote/remote.go b/wavesrv/pkg/remote/remote.go index e9b24f6cd..6e69e28df 100644 --- a/wavesrv/pkg/remote/remote.go +++ b/wavesrv/pkg/remote/remote.go @@ -37,6 +37,7 @@ import ( "github.com/wavetermdev/waveterm/wavesrv/pkg/scbus" "github.com/wavetermdev/waveterm/wavesrv/pkg/scpacket" "github.com/wavetermdev/waveterm/wavesrv/pkg/sstore" + "github.com/wavetermdev/waveterm/wavesrv/pkg/userinput" "golang.org/x/crypto/ssh" "golang.org/x/mod/semver" @@ -83,12 +84,6 @@ else fi ` -const WaveshellServerRunOnlyFmt = ` - PATH=$PATH:~/.mshell; - [%PINGPACKET%] - mshell-[%VERSION%] --server -` - func MakeLocalMShellCommandStr(isSudo bool) (string, error) { mshellPath, err := scbase.LocalMShellBinaryPath() if err != nil { @@ -107,13 +102,6 @@ func MakeServerCommandStr() string { return rtn } -func MakeServerRunOnlyCommandStr() string { - rtn := strings.ReplaceAll(WaveshellServerRunOnlyFmt, "[%VERSION%]", semver.MajorMinor(scbase.MShellVersion)) - rtn = strings.ReplaceAll(rtn, "[%PINGPACKET%]", PrintPingPacket) - return rtn - -} - const ( StatusConnected = sstore.RemoteStatus_Connected StatusConnecting = sstore.RemoteStatus_Connecting @@ -175,6 +163,7 @@ type MShellProc struct { RunningCmds map[base.CommandKey]RunCmdType PendingStateCmds map[pendingStateKey]base.CommandKey // key=[remoteinstance name] launcher Launcher // for conditional launch method based on ssh library in use. remove once ssh library is stabilized + Client *ssh.Client } type RunCmdType struct { @@ -602,7 +591,7 @@ func (msh *MShellProc) GetRemoteRuntimeState() RemoteRuntimeState { if msh.Status == StatusConnecting { state.WaitingForPassword = msh.isWaitingForPassword_nolock() if msh.MakeClientDeadline != nil { - state.ConnectTimeout = int((*msh.MakeClientDeadline).Sub(time.Now()) / time.Second) + state.ConnectTimeout = int(time.Until(*msh.MakeClientDeadline) / time.Second) if state.ConnectTimeout < 0 { state.ConnectTimeout = 0 } @@ -734,7 +723,7 @@ func MakeMShell(r *sstore.RemoteType) *MShellProc { func SendRemoteInput(pk *scpacket.RemoteInputPacketType) error { data, err := base64.StdEncoding.DecodeString(pk.InputData64) if err != nil { - return fmt.Errorf("cannot decode base64: %v\n", err) + return fmt.Errorf("cannot decode base64: %v", err) } msh := GetRemoteById(pk.RemoteId) if msh == nil { @@ -892,6 +881,7 @@ func (msh *MShellProc) Disconnect(force bool) { defer msh.Lock.Unlock() if msh.ServerProc != nil { msh.ServerProc.Close() + msh.Client = nil } if msh.MakeClientCancelFn != nil { msh.MakeClientCancelFn() @@ -989,6 +979,45 @@ func (msh *MShellProc) isWaitingForPassphrase_nolock() bool { return pwIdx != -1 } +func (msh *MShellProc) RunPasswordReadLoop(cmdPty *os.File) { + buf := make([]byte, PtyReadBufSize) + for { + _, readErr := cmdPty.Read(buf) + if readErr == io.EOF { + return + } + if readErr != nil { + msh.WriteToPtyBuffer("*error reading from controlling-pty: %v\n", readErr) + return + } + var newIsWaiting bool + msh.WithLock(func() { + newIsWaiting = msh.isWaitingForPassword_nolock() + }) + if newIsWaiting { + break + } + } + request := &userinput.UserInputRequestType{ + QueryText: "Please enter your password", + ResponseType: "text", + Title: "Sudo Password", + Markdown: false, + } + ctx, cancelFn := context.WithTimeout(context.Background(), 60*time.Second) + defer cancelFn() + response, err := userinput.GetUserInput(ctx, scbus.MainRpcBus, request) + if err != nil { + msh.WriteToPtyBuffer("*error timed out waiting for password: %v\n", err) + return + } + msh.WithLock(func() { + curOffset := msh.PtyBuffer.TotalWritten() + msh.PtyBuffer.Write([]byte(response.Text)) + sendRemotePtyUpdate(msh.Remote.RemoteId, curOffset, []byte(response.Text)) + }) +} + func (msh *MShellProc) RunPtyReadLoop(cmdPty *os.File) { buf := make([]byte, PtyReadBufSize) var isWaiting bool @@ -1015,6 +1044,141 @@ func (msh *MShellProc) RunPtyReadLoop(cmdPty *os.File) { } } +func (msh *MShellProc) CheckPasswordRequested(ctx context.Context, requiresPassword chan bool) { + for { + msh.WithLock(func() { + if msh.isWaitingForPassword_nolock() { + select { + case requiresPassword <- true: + default: + } + return + } + if msh.Status != StatusConnecting { + select { + case requiresPassword <- false: + default: + } + return + } + }) + select { + case <-ctx.Done(): + return + default: + } + time.Sleep(100 * time.Millisecond) + } +} + +func (msh *MShellProc) SendPassword(pw string) { + msh.WithLock(func() { + if msh.ControllingPty == nil { + return + } + pwBytes := []byte(pw + "\r") + msh.writeToPtyBuffer_nolock("~[sent password]\r\n") + _, err := msh.ControllingPty.Write(pwBytes) + if err != nil { + msh.writeToPtyBuffer_nolock("*cannot write password to controlling pty: %v\n", err) + } + }) +} + +func (msh *MShellProc) WaitAndSendPasswordNew(pw string) { + requiresPassword := make(chan bool, 1) + ctx, cancelFn := context.WithTimeout(context.Background(), 60*time.Second) + defer cancelFn() + if pw != "" { + // do an extra check with the saved password if it is provided + go msh.CheckPasswordRequested(ctx, requiresPassword) + select { + case <-ctx.Done(): + err := ctx.Err() + var errMsg error + if err == context.Canceled { + errMsg = fmt.Errorf("canceled by the user: %v", err) + } else { + errMsg = fmt.Errorf("timed out waiting for password prompt: %v", err) + } + msh.WriteToPtyBuffer("*error, %s\n", errMsg.Error()) + msh.setErrorStatus(errMsg) + return + case required := <-requiresPassword: + if !required { + // we don't need user input in this case, so we exit early + return + } + } + msh.SendPassword(pw) + } + + // ask for user input once + go msh.CheckPasswordRequested(ctx, requiresPassword) + select { + case <-ctx.Done(): + err := ctx.Err() + var errMsg error + if err == context.Canceled { + errMsg = fmt.Errorf("canceled by the user") + } else { + errMsg = fmt.Errorf("timed out waiting for password prompt") + } + msh.WriteToPtyBuffer("*error, %s\n", errMsg.Error()) + msh.setErrorStatus(errMsg) + return + case required := <-requiresPassword: + if !required { + // we don't need user input in this case, so we exit early + return + } + } + + request := &userinput.UserInputRequestType{ + QueryText: "Please enter your password", + ResponseType: "text", + Title: "Sudo Password", + Markdown: false, + } + response, err := userinput.GetUserInput(ctx, scbus.MainRpcBus, request) + if err != nil { + var errMsg error + if err == context.Canceled { + errMsg = fmt.Errorf("canceled by the user") + } else { + errMsg = fmt.Errorf("timed out waiting for user input") + } + msh.WriteToPtyBuffer("*error, %s\n", errMsg.Error()) + msh.setErrorStatus(errMsg) + return + } + msh.SendPassword(response.Text) + + //error out if requested again + go msh.CheckPasswordRequested(ctx, requiresPassword) + select { + case <-ctx.Done(): + err := ctx.Err() + var errMsg error + if err == context.Canceled { + errMsg = fmt.Errorf("canceled by the user") + } else { + errMsg = fmt.Errorf("timed out waiting for password prompt") + } + msh.WriteToPtyBuffer("*error, %s\n", errMsg.Error()) + msh.setErrorStatus(errMsg) + return + case required := <-requiresPassword: + if !required { + // we don't need user input in this case, so we exit early + return + } + } + errMsg := fmt.Errorf("*error, incorrect password") + msh.WriteToPtyBuffer("*error, %s\n", errMsg.Error()) + msh.setErrorStatus(errMsg) +} + func (msh *MShellProc) WaitAndSendPassword(pw string) { var numWaits int for { @@ -1073,29 +1237,34 @@ func (msh *MShellProc) RunInstall() { msh.WriteToPtyBuffer("*error: cannot install on remote that is already trying to install, cancel current install to try again\n") return } - sapi, err := shellapi.MakeShellApi(packet.ShellType_bash) + if remoteCopy.Local { + msh.WriteToPtyBuffer("*error: cannot install on a local remote\n") + return + } + _, err := shellapi.MakeShellApi(packet.ShellType_bash) if err != nil { msh.WriteToPtyBuffer("*error: %v\n", err) return } - msh.WriteToPtyBuffer("installing mshell %s to %s...\n", scbase.MShellVersion, remoteCopy.RemoteCanonicalName) - sshOpts := convertSSHOpts(remoteCopy.SSHOpts) - sshOpts.SSHErrorsToTty = true - cmdStr := shexec.MakeInstallCommandStr() - ecmd := sshOpts.MakeSSHExecCmd(cmdStr, sapi) - cmdPty, err := msh.addControllingTty(ecmd) + if msh.Client == nil { + client, err := ConnectToClient(remoteCopy.SSHOpts) + if err != nil { + statusErr := fmt.Errorf("ssh cannot connect to client: %w", err) + msh.setInstallErrorStatus(statusErr) + return + } + msh.WithLock(func() { + msh.Client = client + }) + } + session, err := msh.Client.NewSession() if err != nil { - statusErr := fmt.Errorf("cannot attach controlling tty to mshell install command: %w", err) + statusErr := fmt.Errorf("ssh cannot connect to client: %w", err) msh.setInstallErrorStatus(statusErr) return } - defer func() { - if len(ecmd.ExtraFiles) > 0 { - ecmd.ExtraFiles[len(ecmd.ExtraFiles)-1].Close() - } - cmdPty.Close() - }() - go msh.RunPtyReadLoop(cmdPty) + installSession := shexec.SessionWrap{Session: session, StartCmd: shexec.MakeInstallCommandStr()} + msh.WriteToPtyBuffer("installing waveshell %s to %s...\n", scbase.MShellVersion, remoteCopy.RemoteCanonicalName) clientCtx, clientCancelFn := context.WithCancel(context.Background()) defer clientCancelFn() msh.WithLock(func() { @@ -1107,7 +1276,7 @@ func (msh *MShellProc) RunInstall() { msgFn := func(msg string) { msh.WriteToPtyBuffer("%s", msg) } - err = shexec.RunInstallFromCmd(clientCtx, ecmd, true, nil, scbase.MShellBinaryReader, msgFn) + err = shexec.RunInstallFromCmd(clientCtx, installSession, true, nil, scbase.MShellBinaryReader, msgFn) if err == context.Canceled { msh.WriteToPtyBuffer("*install canceled\n") msh.WithLock(func() { @@ -1130,13 +1299,12 @@ func (msh *MShellProc) RunInstall() { msh.Err = nil connectMode = msh.Remote.ConnectMode }) - msh.WriteToPtyBuffer("successfully installed mshell %s to ~/.mshell\n", scbase.MShellVersion) + msh.WriteToPtyBuffer("successfully installed waveshell %s to ~/.mshell\n", scbase.MShellVersion) go msh.NotifyRemoteUpdate() if connectMode == sstore.ConnectModeStartup || connectMode == sstore.ConnectModeAuto { // the install was successful, and we don't have a manual connect mode, try to connect go msh.Launch(true) } - return } func (msh *MShellProc) updateRemoteStateVars(ctx context.Context, remoteId string, initPk *packet.InitPacketType) { @@ -1286,6 +1454,56 @@ func (msh *MShellProc) getActiveShellTypes(ctx context.Context) ([]string, error return utilfn.CombineStrArrays(rtn, activeShells), nil } +func (msh *MShellProc) createWaveshellSession(remoteCopy sstore.RemoteType) (shexec.ConnInterface, error) { + msh.WithLock(func() { + msh.Err = nil + msh.ErrNoInitPk = false + msh.Status = StatusConnecting + msh.MakeClientDeadline = nil + go msh.NotifyRemoteUpdate() + }) + sapi, err := shellapi.MakeShellApi(msh.GetShellType()) + if err != nil { + return nil, err + } + var wsSession shexec.ConnInterface + if remoteCopy.SSHOpts.SSHHost == "" && remoteCopy.Local { + cmdStr, err := MakeLocalMShellCommandStr(remoteCopy.IsSudo()) + if err != nil { + return nil, fmt.Errorf("cannot find local mshell binary: %v", err) + } + ecmd := shexec.MakeLocalExecCmd(cmdStr, sapi) + var cmdPty *os.File + cmdPty, err = msh.addControllingTty(ecmd) + if err != nil { + return nil, fmt.Errorf("cannot attach controlling tty to mshell command: %v", err) + } + go msh.RunPtyReadLoop(cmdPty) + go msh.WaitAndSendPasswordNew(remoteCopy.SSHOpts.SSHPassword) + wsSession = shexec.CmdWrap{Cmd: ecmd} + } else if msh.Client == nil { + client, err := ConnectToClient(remoteCopy.SSHOpts) + if err != nil { + return nil, fmt.Errorf("ssh cannot connect to client: %w", err) + } + msh.WithLock(func() { + msh.Client = client + }) + session, err := client.NewSession() + if err != nil { + return nil, fmt.Errorf("ssh cannot create session: %w", err) + } + wsSession = shexec.SessionWrap{Session: session, StartCmd: MakeServerCommandStr()} + } else { + session, err := msh.Client.NewSession() + if err != nil { + return nil, fmt.Errorf("ssh cannot create session: %w", err) + } + wsSession = shexec.SessionWrap{Session: session, StartCmd: MakeServerCommandStr()} + } + return wsSession, nil +} + // for conditional launch method based on ssh library in use // remove once ssh library is stabilized type NewLauncher struct{} @@ -1306,147 +1524,79 @@ func (NewLauncher) Launch(msh *MShellProc, interactive bool) { msh.WriteToPtyBuffer("remote is already connecting, disconnect before trying to connect again\n") return } - sapi, err := shellapi.MakeShellApi(msh.GetShellType()) - if err != nil { - msh.WriteToPtyBuffer("*error, %v\n", err) - return - } istatus := msh.GetInstallStatus() if istatus == StatusConnecting { msh.WriteToPtyBuffer("remote is trying to install, cancel install before trying to connect again\n") return } - if remoteCopy.SSHOpts.SSHPort != 0 && remoteCopy.SSHOpts.SSHPort != 22 { - msh.WriteToPtyBuffer("connecting to %s (port %d)...\n", remoteCopy.RemoteCanonicalName, remoteCopy.SSHOpts.SSHPort) - } else { - msh.WriteToPtyBuffer("connecting to %s...\n", remoteCopy.RemoteCanonicalName) - } - sshOpts := convertSSHOpts(remoteCopy.SSHOpts) - sshOpts.SSHErrorsToTty = true - if remoteCopy.ConnectMode != sstore.ConnectModeManual && remoteCopy.SSHOpts.SSHPassword == "" && !interactive { - sshOpts.BatchMode = true - } - var cproc *shexec.ClientProc - var initPk *packet.InitPacketType - if sshOpts.SSHHost == "" && remoteCopy.Local { - makeClientCtx, makeClientCancelFn := context.WithCancel(context.Background()) - defer makeClientCancelFn() + msh.WriteToPtyBuffer("connecting to %s...\n", remoteCopy.RemoteCanonicalName) + wsSession, err := msh.createWaveshellSession(remoteCopy) + if err != nil { + msh.WriteToPtyBuffer("*error, %s\n", err.Error()) + msh.setErrorStatus(err) msh.WithLock(func() { - msh.Err = nil - msh.ErrNoInitPk = false - msh.Status = StatusConnecting - msh.MakeClientCancelFn = makeClientCancelFn - deadlineTime := time.Now().Add(RemoteConnectTimeout) - msh.MakeClientDeadline = &deadlineTime - go msh.NotifyRemoteUpdate() + msh.Client = nil }) - go msh.watchClientDeadlineTime() - cmdStr, err := MakeLocalMShellCommandStr(remoteCopy.IsSudo()) - if err != nil { - msh.WriteToPtyBuffer("*error, cannot find local mshell binary: %v\n", err) - return - } - ecmd := sshOpts.MakeSSHExecCmd(cmdStr, sapi) - var cmdPty *os.File - cmdPty, err = msh.addControllingTty(ecmd) - if err != nil { - statusErr := fmt.Errorf("cannot attach controlling tty to mshell command: %w", err) - msh.WriteToPtyBuffer("*error, %s\n", statusErr.Error()) - msh.setErrorStatus(statusErr) - return - } - defer func() { - if len(ecmd.ExtraFiles) > 0 { - ecmd.ExtraFiles[len(ecmd.ExtraFiles)-1].Close() - } - }() - go msh.RunPtyReadLoop(cmdPty) - if remoteCopy.SSHOpts.SSHPassword != "" { - go msh.WaitAndSendPassword(remoteCopy.SSHOpts.SSHPassword) - } - cproc, initPk, err = shexec.MakeClientProc(makeClientCtx, shexec.CmdWrap{Cmd: ecmd}) - } else { - msh.WithLock(func() { - msh.Err = nil - msh.ErrNoInitPk = false - msh.Status = StatusConnecting - msh.MakeClientDeadline = nil - go msh.NotifyRemoteUpdate() - }) - var client *ssh.Client - client, err = ConnectToClient(remoteCopy.SSHOpts) - if err != nil { - statusErr := fmt.Errorf("ssh cannot connect to client: %w", err) - msh.WriteToPtyBuffer("*error, %s\n", statusErr.Error()) - msh.setErrorStatus(statusErr) - return - } - var session *ssh.Session - session, err = client.NewSession() - if err != nil { - statusErr := fmt.Errorf("ssh cannot create session: %w", err) - msh.WriteToPtyBuffer("*error, %s\n", statusErr.Error()) - msh.setErrorStatus(statusErr) - return - } - makeClientCtx, makeClientCancelFn := context.WithCancel(context.Background()) - defer makeClientCancelFn() - msh.WithLock(func() { - msh.MakeClientCancelFn = makeClientCancelFn - deadlineTime := time.Now().Add(RemoteConnectTimeout) - msh.MakeClientDeadline = &deadlineTime - go msh.NotifyRemoteUpdate() - }) - go msh.watchClientDeadlineTime() - cproc, initPk, err = shexec.MakeClientProc(makeClientCtx, shexec.SessionWrap{Session: session, StartCmd: MakeServerRunOnlyCommandStr()}) + return } - // TODO check if initPk.State is not nil - var mshellVersion string - var hitDeadline bool + var makeClientCtx context.Context + var makeClientCancelFn context.CancelFunc + msh.WithLock(func() { + makeClientCtx, makeClientCancelFn = context.WithCancel(context.Background()) + msh.MakeClientCancelFn = makeClientCancelFn + msh.MakeClientDeadline = nil + go msh.NotifyRemoteUpdate() + }) + defer makeClientCancelFn() + cproc, err := shexec.MakeClientProc(makeClientCtx, wsSession) msh.WithLock(func() { msh.MakeClientCancelFn = nil - if time.Now().After(*msh.MakeClientDeadline) { - hitDeadline = true - } msh.MakeClientDeadline = nil - if initPk == nil { - msh.ErrNoInitPk = true - } - if initPk != nil { - msh.UName = initPk.UName - mshellVersion = initPk.Version - if semver.Compare(mshellVersion, scbase.MShellVersion) < 0 { - // only set NeedsMShellUpgrade if we got an InitPk - msh.NeedsMShellUpgrade = true - } - msh.InitPkShellType = initPk.Shell - } + }) + if err == context.DeadlineExceeded { + msh.WriteToPtyBuffer("*connect timeout\n") + msh.setErrorStatus(errors.New("connect timeout")) + msh.WithLock(func() { + msh.Client = nil + }) + return + } else if err == context.Canceled { + msh.WriteToPtyBuffer("*forced disconnection\n") + msh.WithLock(func() { + msh.Status = StatusDisconnected + go msh.NotifyRemoteUpdate() + }) + msh.WithLock(func() { + msh.Client = nil + }) + return + } else if serr, ok := err.(shexec.WaveshellLaunchError); ok { + msh.WithLock(func() { + msh.UName = serr.InitPk.UName + msh.NeedsMShellUpgrade = true + msh.InitPkShellType = serr.InitPk.Shell + }) + msh.StateMap.Clear() + msh.WriteToPtyBuffer("*error, %s\n", serr.Error()) + msh.setErrorStatus(serr) + go msh.tryAutoInstall() + return + } else if err != nil { + msh.WriteToPtyBuffer("*error, %s\n", serr.Error()) + msh.setErrorStatus(err) + msh.WithLock(func() { + msh.Client = nil + }) + return + } + msh.WithLock(func() { + msh.UName = cproc.InitPk.UName + msh.InitPkShellType = cproc.InitPk.Shell msh.StateMap.Clear() // no notify here, because we'll call notify in either case below }) - if err == context.Canceled { - if hitDeadline { - msh.WriteToPtyBuffer("*connect timeout\n") - msh.setErrorStatus(errors.New("connect timeout")) - } else { - msh.WriteToPtyBuffer("*forced disconnection\n") - msh.WithLock(func() { - msh.Status = StatusDisconnected - go msh.NotifyRemoteUpdate() - }) - } - return - } - if err == nil && semver.MajorMinor(mshellVersion) != semver.MajorMinor(scbase.MShellVersion) { - err = fmt.Errorf("mshell version is not compatible current=%s remote=%s", scbase.MShellVersion, mshellVersion) - } - if err != nil { - msh.setErrorStatus(err) - msh.WriteToPtyBuffer("*error connecting to remote: %v\n", err) - go msh.tryAutoInstall() - return - } - msh.updateRemoteStateVars(context.Background(), msh.RemoteId, initPk) + + msh.updateRemoteStateVars(context.Background(), msh.RemoteId, cproc.InitPk) msh.WithLock(func() { msh.ServerProc = cproc msh.Status = StatusConnected @@ -1465,7 +1615,6 @@ func (NewLauncher) Launch(msh *MShellProc, interactive bool) { go msh.ProcessPackets() msh.initActiveShells() go msh.NotifyRemoteUpdate() - return } // for conditional launch method based on ssh library in use @@ -1536,66 +1685,58 @@ func (LegacyLauncher) Launch(msh *MShellProc, interactive bool) { if remoteCopy.SSHOpts.SSHPassword != "" { go msh.WaitAndSendPassword(remoteCopy.SSHOpts.SSHPassword) } - makeClientCtx, makeClientCancelFn := context.WithCancel(context.Background()) - defer makeClientCancelFn() + var makeClientCtx context.Context + var makeClientCancelFn context.CancelFunc msh.WithLock(func() { + deadlineTime := time.Now().Add(RemoteConnectTimeout) + makeClientCtx, makeClientCancelFn = context.WithDeadline(context.Background(), deadlineTime) + defer makeClientCancelFn() msh.Err = nil msh.ErrNoInitPk = false msh.Status = StatusConnecting msh.MakeClientCancelFn = makeClientCancelFn - deadlineTime := time.Now().Add(RemoteConnectTimeout) msh.MakeClientDeadline = &deadlineTime go msh.NotifyRemoteUpdate() }) go msh.watchClientDeadlineTime() - cproc, initPk, err := shexec.MakeClientProc(makeClientCtx, shexec.CmdWrap{Cmd: ecmd}) - // TODO check if initPk.State is not nil - var mshellVersion string - var hitDeadline bool + cproc, err := shexec.MakeClientProc(makeClientCtx, shexec.CmdWrap{Cmd: ecmd}) msh.WithLock(func() { msh.MakeClientCancelFn = nil - if time.Now().After(*msh.MakeClientDeadline) { - hitDeadline = true - } msh.MakeClientDeadline = nil - if initPk == nil { - msh.ErrNoInitPk = true - } - if initPk != nil { - msh.UName = initPk.UName - mshellVersion = initPk.Version - if semver.Compare(mshellVersion, scbase.MShellVersion) < 0 { - // only set NeedsMShellUpgrade if we got an InitPk - msh.NeedsMShellUpgrade = true - } - msh.InitPkShellType = initPk.Shell - } msh.StateMap.Clear() // no notify here, because we'll call notify in either case below }) - if err == context.Canceled { - if hitDeadline { - msh.WriteToPtyBuffer("*connect timeout\n") - msh.setErrorStatus(errors.New("connect timeout")) - } else { - msh.WriteToPtyBuffer("*forced disconnection\n") - msh.WithLock(func() { - msh.Status = StatusDisconnected - go msh.NotifyRemoteUpdate() - }) - } + if err == context.DeadlineExceeded { + msh.WriteToPtyBuffer("*connect timeout\n") + msh.setErrorStatus(errors.New("connect timeout")) return - } - if err == nil && semver.MajorMinor(mshellVersion) != semver.MajorMinor(scbase.MShellVersion) { - err = fmt.Errorf("mshell version is not compatible current=%s remote=%s", scbase.MShellVersion, mshellVersion) - } - if err != nil { - msh.setErrorStatus(err) - msh.WriteToPtyBuffer("*error connecting to remote: %v\n", err) + } else if err == context.Canceled { + msh.WriteToPtyBuffer("*forced disconnection\n") + msh.WithLock(func() { + msh.Status = StatusDisconnected + go msh.NotifyRemoteUpdate() + }) + return + } else if serr, ok := err.(shexec.WaveshellLaunchError); ok { + msh.WithLock(func() { + msh.UName = serr.InitPk.UName + if semver.Compare(serr.InitPk.Version, scbase.MShellVersion) < 0 { + // only set NeedsMShellUpgrade if we got an InitPk + msh.NeedsMShellUpgrade = true + } + msh.InitPkShellType = serr.InitPk.Shell + }) + msh.WriteToPtyBuffer("*error, %s\n", serr.Error()) + msh.setErrorStatus(serr) go msh.tryAutoInstall() return + } else if err != nil { + msh.WriteToPtyBuffer("*error, %s\n", serr.Error()) + msh.setErrorStatus(err) + return } - msh.updateRemoteStateVars(context.Background(), msh.RemoteId, initPk) + + msh.updateRemoteStateVars(context.Background(), msh.RemoteId, cproc.InitPk) msh.WithLock(func() { msh.ServerProc = cproc msh.Status = StatusConnected @@ -1614,7 +1755,6 @@ func (LegacyLauncher) Launch(msh *MShellProc, interactive bool) { go msh.ProcessPackets() msh.initActiveShells() go msh.NotifyRemoteUpdate() - return } func (msh *MShellProc) initActiveShells() { @@ -2099,7 +2239,6 @@ func (msh *MShellProc) handleCmdDonePacket(donePk *packet.CmdDonePacketType) { } } scbus.MainUpdateBus.DoUpdate(update) - return } func (msh *MShellProc) handleCmdFinalPacket(finalPk *packet.CmdFinalPacketType) { @@ -2143,7 +2282,6 @@ func (msh *MShellProc) handleCmdErrorPacket(errPk *packet.CmdErrorPacketType) { msh.WriteToPtyBuffer("cmderr> [remote %s] [error] adding cmderr: %v\n", msh.GetRemoteName(), err) return } - return } func (msh *MShellProc) ResetDataPos(ck base.CommandKey) { diff --git a/wavesrv/pkg/remote/sshclient.go b/wavesrv/pkg/remote/sshclient.go index 9ec5e9777..2b4e22a02 100644 --- a/wavesrv/pkg/remote/sshclient.go +++ b/wavesrv/pkg/remote/sshclient.go @@ -110,7 +110,8 @@ func createPublicKeyCallback(sshKeywords *SshKeywords, passphrase string) func() QueryText: fmt.Sprintf("Enter passphrase for the SSH key: %s", identityFile), Title: "Publickey Auth + Passphrase", } - ctx, _ := context.WithTimeout(context.Background(), 60*time.Second) + ctx, cancelFn := context.WithTimeout(context.Background(), 60*time.Second) + defer cancelFn() response, err := userinput.GetUserInput(ctx, scbus.MainRpcBus, request) if err != nil { // this is an error where we actually do want to stop @@ -267,6 +268,10 @@ func writeToKnownHosts(knownHostsFile string, newLine string, getUserVerificatio } _, err = f.WriteString(newLine) + if err != nil { + f.Close() + return err + } return f.Close() }