diff --git a/waveshell/pkg/server/server.go b/waveshell/pkg/server/server.go index 8a4578a53..6f83ff663 100644 --- a/waveshell/pkg/server/server.go +++ b/waveshell/pkg/server/server.go @@ -651,7 +651,7 @@ func (m *MServer) runCommand(runPacket *packet.RunPacketType) { m.Sender.SendErrorResponse(runPacket.ReqId, fmt.Errorf("server run packets require valid ck: %s", err)) return } - cproc, _, err := shexec.MakeClientProc(context.Background(), ecmd) + cproc, _, err := shexec.MakeClientProc(context.Background(), shexec.CmdWrap{Cmd: ecmd}) if err != nil { m.Sender.SendErrorResponse(runPacket.ReqId, fmt.Errorf("starting mshell client: %s", err)) return diff --git a/waveshell/pkg/shexec/client.go b/waveshell/pkg/shexec/client.go index 22770f924..61a28454a 100644 --- a/waveshell/pkg/shexec/client.go +++ b/waveshell/pkg/shexec/client.go @@ -12,6 +12,7 @@ import ( "github.com/wavetermdev/waveterm/waveshell/pkg/base" "github.com/wavetermdev/waveterm/waveshell/pkg/packet" + "golang.org/x/crypto/ssh" "golang.org/x/mod/semver" ) @@ -19,8 +20,97 @@ import ( const NotFoundVersion = "v0.0" +type CmdWrap struct { + Cmd *exec.Cmd +} + +func (cw CmdWrap) Kill() { + cw.Cmd.Process.Kill() +} + +func (cw CmdWrap) Wait() error { + return cw.Cmd.Wait() +} + +func (cw CmdWrap) Sender() (*packet.PacketSender, io.WriteCloser, error) { + inputWriter, err := cw.Cmd.StdinPipe() + if err != nil { + return nil, nil, fmt.Errorf("creating stdin pipe: %v", err) + } + sender := packet.MakePacketSender(inputWriter, nil) + return sender, inputWriter, nil +} + +func (cw CmdWrap) Parser() (*packet.PacketParser, io.ReadCloser, io.ReadCloser, error) { + stdoutReader, err := cw.Cmd.StdoutPipe() + if err != nil { + return nil, nil, nil, fmt.Errorf("creating stdout pipe: %v", err) + } + stderrReader, err := cw.Cmd.StderrPipe() + if err != nil { + return nil, nil, nil, fmt.Errorf("creating stderr pipe: %v", err) + } + stdoutPacketParser := packet.MakePacketParser(stdoutReader, &packet.PacketParserOpts{IgnoreUntilValid: true}) + stderrPacketParser := packet.MakePacketParser(stderrReader, nil) + packetParser := packet.CombinePacketParsers(stdoutPacketParser, stderrPacketParser, true) + return packetParser, stdoutReader, stderrReader, nil +} + +func (cw CmdWrap) Start() error { + return cw.Cmd.Start() +} + +type SessionWrap struct { + Session *ssh.Session + StartCmd string +} + +func (sw SessionWrap) Kill() { + sw.Session.Close() +} + +func (sw SessionWrap) Wait() error { + return sw.Session.Wait() +} + +func (sw SessionWrap) Start() error { + return sw.Session.Start(sw.StartCmd) +} + +func (sw SessionWrap) Sender() (*packet.PacketSender, io.WriteCloser, error) { + inputWriter, err := sw.Session.StdinPipe() + if err != nil { + return nil, nil, fmt.Errorf("creating stdin pipe: %v", err) + } + sender := packet.MakePacketSender(inputWriter, nil) + return sender, inputWriter, nil +} + +func (sw SessionWrap) Parser() (*packet.PacketParser, io.ReadCloser, io.ReadCloser, error) { + stdoutReader, err := sw.Session.StdoutPipe() + if err != nil { + return nil, nil, nil, fmt.Errorf("creating stdout pipe: %v", err) + } + stderrReader, err := sw.Session.StderrPipe() + if err != nil { + return nil, nil, nil, fmt.Errorf("creating stderr pipe: %v", err) + } + stdoutPacketParser := packet.MakePacketParser(stdoutReader, &packet.PacketParserOpts{IgnoreUntilValid: true}) + stderrPacketParser := packet.MakePacketParser(stderrReader, nil) + packetParser := packet.CombinePacketParsers(stdoutPacketParser, stderrPacketParser, true) + return packetParser, io.NopCloser(stdoutReader), io.NopCloser(stderrReader), nil +} + +type ConnInterface interface { + Kill() + Wait() error + Sender() (*packet.PacketSender, io.WriteCloser, error) + Parser() (*packet.PacketParser, io.ReadCloser, io.ReadCloser, error) + Start() error +} + type ClientProc struct { - Cmd *exec.Cmd + Cmd ConnInterface InitPk *packet.InitPacketType StartTs time.Time StdinWriter io.WriteCloser @@ -31,28 +121,20 @@ type ClientProc struct { } // returns (clientproc, initpk, error) -func MakeClientProc(ctx context.Context, ecmd *exec.Cmd) (*ClientProc, *packet.InitPacketType, error) { - inputWriter, err := ecmd.StdinPipe() - if err != nil { - return nil, nil, fmt.Errorf("creating stdin pipe: %v", err) - } - stdoutReader, err := ecmd.StdoutPipe() - if err != nil { - return nil, nil, fmt.Errorf("creating stdout pipe: %v", err) - } - stderrReader, err := ecmd.StderrPipe() - if err != nil { - return nil, nil, fmt.Errorf("creating stderr pipe: %v", err) - } +func MakeClientProc(ctx context.Context, ecmd ConnInterface) (*ClientProc, *packet.InitPacketType, error) { startTs := time.Now() + sender, inputWriter, err := ecmd.Sender() + if err != nil { + return nil, nil, err + } + packetParser, stdoutReader, stderrReader, err := ecmd.Parser() + if err != nil { + return nil, nil, err + } err = ecmd.Start() if err != nil { return nil, nil, fmt.Errorf("running local client: %w", err) } - sender := packet.MakePacketSender(inputWriter, nil) - stdoutPacketParser := packet.MakePacketParser(stdoutReader, &packet.PacketParserOpts{IgnoreUntilValid: true}) - stderrPacketParser := packet.MakePacketParser(stderrReader, nil) - packetParser := packet.CombinePacketParsers(stdoutPacketParser, stderrPacketParser, true) cproc := &ClientProc{ Cmd: ecmd, StartTs: startTs, @@ -107,7 +189,7 @@ func (cproc *ClientProc) Close() { cproc.StderrReader.Close() } if cproc.Cmd != nil { - cproc.Cmd.Process.Kill() + cproc.Cmd.Kill() } } diff --git a/waveshell/pkg/shexec/shexec.go b/waveshell/pkg/shexec/shexec.go index 8ce9f10e1..d1be9414d 100644 --- a/waveshell/pkg/shexec/shexec.go +++ b/waveshell/pkg/shexec/shexec.go @@ -24,6 +24,7 @@ import ( "github.com/alessio/shellescape" "github.com/creack/pty" "github.com/google/uuid" + "github.com/kevinburke/ssh_config" "github.com/wavetermdev/waveterm/waveshell/pkg/base" "github.com/wavetermdev/waveterm/waveshell/pkg/cirfile" "github.com/wavetermdev/waveterm/waveshell/pkg/mpio" @@ -31,6 +32,7 @@ import ( "github.com/wavetermdev/waveterm/waveshell/pkg/shellapi" "github.com/wavetermdev/waveterm/waveshell/pkg/shellenv" "github.com/wavetermdev/waveterm/waveshell/pkg/shellutil" + "golang.org/x/crypto/ssh" "golang.org/x/mod/semver" "golang.org/x/sys/unix" ) @@ -476,6 +478,76 @@ func (opts SSHOpts) MakeSSHExecCmd(remoteCommand string, sapi shellapi.ShellApi) } } +func (opts SSHOpts) ConnectToClient() (*ssh.Client, error) { + ssh_config.ReloadConfigs() + configIdentity, _ := ssh_config.GetStrict(opts.SSHHost, "IdentityFile") + var identityFile string + if opts.SSHIdentity != "" { + identityFile = opts.SSHIdentity + } else { + identityFile = configIdentity + } + + var authMethods []ssh.AuthMethod + var hostKeyCallback ssh.HostKeyCallback + if identityFile != "" { + sshKeyFile, err := os.ReadFile(base.ExpandHomeDir(identityFile)) + if err != nil { + return nil, fmt.Errorf("failed to read ssh key file. err: %+v", err) + } + signer, err := ssh.ParsePrivateKey(sshKeyFile) + if err != nil { + return nil, fmt.Errorf("failed to parse private ssh key. err: %+v", err) + } + /* + publicKey, err := ssh.ParsePublicKey(sshKeyFile) + if err != nil { + return nil, fmt.Errorf("failed to parse public ssh key. err: %+v", err) + } + */ + authMethods = append(authMethods, ssh.PublicKeys(signer)) + hostKeyCallback = ssh.InsecureIgnoreHostKey() + } else { + hostKeyCallback = ssh.InsecureIgnoreHostKey() + } + configUser, _ := ssh_config.GetStrict(opts.SSHHost, "User") + configHostName, _ := ssh_config.GetStrict(opts.SSHHost, "HostName") + configPort, _ := ssh_config.GetStrict(opts.SSHHost, "Port") + var username string + if opts.SSHUser != "" { + username = opts.SSHUser + } else if configUser != "" { + username = configUser + } else { + user, err := user.Current() + if err != nil { + return nil, fmt.Errorf("failed to get user for ssh: %+v", err) + } + username = user.Username + } + var hostName string + if configHostName != "" { + hostName = configHostName + } else { + hostName = opts.SSHHost + } + clientConfig := &ssh.ClientConfig{ + User: username, + Auth: authMethods, + HostKeyCallback: hostKeyCallback, + } + var port string + if opts.SSHPort != 0 && opts.SSHPort != 22 { + port = strconv.Itoa(opts.SSHPort) + } else if configPort != "" && configPort != "22" { + port = configPort + } else { + port = "22" + } + networkAddr := hostName + ":" + port + return ssh.Dial("tcp", networkAddr, clientConfig) +} + func (opts SSHOpts) MakeMShellSSHOpts() string { var moreSSHOpts []string if opts.SSHIdentity != "" { diff --git a/wavesrv/pkg/remote/remote.go b/wavesrv/pkg/remote/remote.go index 9a4be7916..a9e1e0d44 100644 --- a/wavesrv/pkg/remote/remote.go +++ b/wavesrv/pkg/remote/remote.go @@ -36,6 +36,7 @@ import ( "github.com/wavetermdev/waveterm/wavesrv/pkg/scbase" "github.com/wavetermdev/waveterm/wavesrv/pkg/scpacket" "github.com/wavetermdev/waveterm/wavesrv/pkg/sstore" + "golang.org/x/crypto/ssh" "golang.org/x/mod/semver" ) @@ -77,6 +78,12 @@ else fi ` +const WaveshellServerRunOnlyFmt = ` + PATH=$PATH:~/.mshell; + [%PINGPACKET%] + mshell-[%VERSION%] --server +` + func MakeLocalMShellCommandStr(isSudo bool) (string, error) { mshellPath, err := scbase.LocalMShellBinaryPath() if err != nil { @@ -95,6 +102,13 @@ func MakeServerCommandStr() string { return rtn } +func MakeServerRunOnlyCommandStr() string { + rtn := strings.ReplaceAll(WaveshellServerRunOnlyFmt, "[%VERSION%]", semver.MajorMinor(scbase.MShellVersion)) + rtn = strings.ReplaceAll(rtn, "[%PINGPACKET%]", PrintPingPacket) + return rtn + +} + const ( StatusConnected = sstore.RemoteStatus_Connected StatusConnecting = sstore.RemoteStatus_Connecting @@ -1206,7 +1220,286 @@ func (msh *MShellProc) getActiveShellTypes(ctx context.Context) ([]string, error return utilfn.CombineStrArrays(rtn, activeShells), nil } +func (msh *MShellProc) LaunchWithSshLib(interactive bool) { + remoteCopy := msh.GetRemoteCopy() + if remoteCopy.Archived { + msh.WriteToPtyBuffer("cannot launch archived remote\n") + return + } + curStatus := msh.GetStatus() + if curStatus == StatusConnected { + msh.WriteToPtyBuffer("remote is already connected (no action taken)\n") + return + } + if curStatus == StatusConnecting { + msh.WriteToPtyBuffer("remote is already connecting, disconnect before trying to connect again\n") + return + } + istatus := msh.GetInstallStatus() + if istatus == StatusConnecting { + msh.WriteToPtyBuffer("remote is trying to install, cancel install before trying to connect again\n") + return + } + if remoteCopy.SSHOpts.SSHPort != 0 && remoteCopy.SSHOpts.SSHPort != 22 { + msh.WriteToPtyBuffer("connecting to %s (port %d)...\n", remoteCopy.RemoteCanonicalName, remoteCopy.SSHOpts.SSHPort) + } else { + msh.WriteToPtyBuffer("connecting to %s...\n", remoteCopy.RemoteCanonicalName) + } + sshOpts := convertSSHOpts(remoteCopy.SSHOpts) + sshOpts.SSHErrorsToTty = true + if remoteCopy.ConnectMode != sstore.ConnectModeManual && remoteCopy.SSHOpts.SSHPassword == "" && !interactive { + sshOpts.BatchMode = true + } + client, err := sshOpts.ConnectToClient() + if err != nil { + msh.WriteToPtyBuffer("*error, ssh cannot connect to client: %v\n", err) + } + makeClientCtx, makeClientCancelFn := context.WithCancel(context.Background()) + defer makeClientCancelFn() + msh.WithLock(func() { + msh.Err = nil + msh.ErrNoInitPk = false + msh.Status = StatusConnecting + msh.MakeClientCancelFn = makeClientCancelFn + deadlineTime := time.Now().Add(RemoteConnectTimeout) + msh.MakeClientDeadline = &deadlineTime + go msh.NotifyRemoteUpdate() + }) + go msh.watchClientDeadlineTime() + session, err := client.NewSession() + if err != nil { + msh.WriteToPtyBuffer("*error, ssh cannot create session: %v\n", err) + } + cproc, initPk, err := shexec.MakeClientProc(makeClientCtx, shexec.SessionWrap{Session: session, StartCmd: MakeServerRunOnlyCommandStr()}) + // TODO check if initPk.State is not nil + var mshellVersion string + var hitDeadline bool + msh.WithLock(func() { + msh.MakeClientCancelFn = nil + if time.Now().After(*msh.MakeClientDeadline) { + hitDeadline = true + } + msh.MakeClientDeadline = nil + if initPk == nil { + msh.ErrNoInitPk = true + } + if initPk != nil { + msh.UName = initPk.UName + mshellVersion = initPk.Version + if semver.Compare(mshellVersion, scbase.MShellVersion) < 0 { + // only set NeedsMShellUpgrade if we got an InitPk + msh.NeedsMShellUpgrade = true + } + msh.InitPkShellType = initPk.Shell + } + msh.StateMap.Clear() + // no notify here, because we'll call notify in either case below + }) + if err == context.Canceled { + if hitDeadline { + msh.WriteToPtyBuffer("*connect timeout\n") + msh.setErrorStatus(errors.New("connect timeout")) + } else { + msh.WriteToPtyBuffer("*forced disconnection\n") + msh.WithLock(func() { + msh.Status = StatusDisconnected + go msh.NotifyRemoteUpdate() + }) + } + return + } + if err == nil && semver.MajorMinor(mshellVersion) != semver.MajorMinor(scbase.MShellVersion) { + err = fmt.Errorf("mshell version is not compatible current=%s remote=%s", scbase.MShellVersion, mshellVersion) + } + if err != nil { + cs := fmt.Sprintf("error: %v\n", err) + os.WriteFile("/Users/oneirocosm/.waveterm-dev/temp.txt", []byte(cs), 0644) + msh.setErrorStatus(err) + msh.WriteToPtyBuffer("*error connecting to remote: %v\n", err) + go msh.tryAutoInstall() + return + } + msh.updateRemoteStateVars(context.Background(), msh.RemoteId, initPk) + msh.WithLock(func() { + msh.ServerProc = cproc + msh.Status = StatusConnected + }) + go func() { + exitErr := cproc.Cmd.Wait() + exitCode := shexec.GetExitCode(exitErr) + msh.WithLock(func() { + if msh.Status == StatusConnected || msh.Status == StatusConnecting { + msh.Status = StatusDisconnected + go msh.NotifyRemoteUpdate() + } + }) + msh.WriteToPtyBuffer("*disconnected exitcode=%d\n", exitCode) + }() + go msh.ProcessPackets() + msh.initActiveShells() + go msh.NotifyRemoteUpdate() + return +} + func (msh *MShellProc) Launch(interactive bool) { + remoteCopy := msh.GetRemoteCopy() + if remoteCopy.Archived { + msh.WriteToPtyBuffer("cannot launch archived remote\n") + return + } + curStatus := msh.GetStatus() + if curStatus == StatusConnected { + msh.WriteToPtyBuffer("remote is already connected (no action taken)\n") + return + } + if curStatus == StatusConnecting { + msh.WriteToPtyBuffer("remote is already connecting, disconnect before trying to connect again\n") + return + } + sapi, err := shellapi.MakeShellApi(msh.GetShellType()) + if err != nil { + msh.WriteToPtyBuffer("*error, %v\n", err) + return + } + istatus := msh.GetInstallStatus() + if istatus == StatusConnecting { + msh.WriteToPtyBuffer("remote is trying to install, cancel install before trying to connect again\n") + return + } + if remoteCopy.SSHOpts.SSHPort != 0 && remoteCopy.SSHOpts.SSHPort != 22 { + msh.WriteToPtyBuffer("connecting to %s (port %d)...\n", remoteCopy.RemoteCanonicalName, remoteCopy.SSHOpts.SSHPort) + } else { + msh.WriteToPtyBuffer("connecting to %s...\n", remoteCopy.RemoteCanonicalName) + } + sshOpts := convertSSHOpts(remoteCopy.SSHOpts) + sshOpts.SSHErrorsToTty = true + if remoteCopy.ConnectMode != sstore.ConnectModeManual && remoteCopy.SSHOpts.SSHPassword == "" && !interactive { + sshOpts.BatchMode = true + } + makeClientCtx, makeClientCancelFn := context.WithCancel(context.Background()) + defer makeClientCancelFn() + msh.WithLock(func() { + msh.Err = nil + msh.ErrNoInitPk = false + msh.Status = StatusConnecting + msh.MakeClientCancelFn = makeClientCancelFn + deadlineTime := time.Now().Add(RemoteConnectTimeout) + msh.MakeClientDeadline = &deadlineTime + go msh.NotifyRemoteUpdate() + }) + go msh.watchClientDeadlineTime() + var cmdStr string + var cproc *shexec.ClientProc + var initPk *packet.InitPacketType + if sshOpts.SSHHost == "" && remoteCopy.Local { + cmdStr, err = MakeLocalMShellCommandStr(remoteCopy.IsSudo()) + if err != nil { + msh.WriteToPtyBuffer("*error, cannot find local mshell binary: %v\n", err) + return + } + ecmd := sshOpts.MakeSSHExecCmd(cmdStr, sapi) + var cmdPty *os.File + cmdPty, err = msh.addControllingTty(ecmd) + if err != nil { + statusErr := fmt.Errorf("cannot attach controlling tty to mshell command: %w", err) + msh.WriteToPtyBuffer("*error, %s\n", statusErr.Error()) + msh.setErrorStatus(statusErr) + return + } + defer func() { + if len(ecmd.ExtraFiles) > 0 { + ecmd.ExtraFiles[len(ecmd.ExtraFiles)-1].Close() + } + }() + go msh.RunPtyReadLoop(cmdPty) + if remoteCopy.SSHOpts.SSHPassword != "" { + go msh.WaitAndSendPassword(remoteCopy.SSHOpts.SSHPassword) + } + cproc, initPk, err = shexec.MakeClientProc(makeClientCtx, shexec.CmdWrap{Cmd: ecmd}) + } else { + var client *ssh.Client + client, err = sshOpts.ConnectToClient() + es := fmt.Sprintf("err: %v\n", err) + os.WriteFile("/Users/oneirocosm/.waveterm-dev/temp.txt", []byte(es), 0644) + if err != nil { + msh.WriteToPtyBuffer("*error, ssh cannot connect to client: %v\n", err) + } + var session *ssh.Session + session, err = client.NewSession() + if err != nil { + msh.WriteToPtyBuffer("*error, ssh cannot create session: %v\n", err) + } + cproc, initPk, err = shexec.MakeClientProc(makeClientCtx, shexec.SessionWrap{Session: session, StartCmd: MakeServerRunOnlyCommandStr()}) + } + // TODO check if initPk.State is not nil + var mshellVersion string + var hitDeadline bool + msh.WithLock(func() { + msh.MakeClientCancelFn = nil + if time.Now().After(*msh.MakeClientDeadline) { + hitDeadline = true + } + msh.MakeClientDeadline = nil + if initPk == nil { + msh.ErrNoInitPk = true + } + if initPk != nil { + msh.UName = initPk.UName + mshellVersion = initPk.Version + if semver.Compare(mshellVersion, scbase.MShellVersion) < 0 { + // only set NeedsMShellUpgrade if we got an InitPk + msh.NeedsMShellUpgrade = true + } + msh.InitPkShellType = initPk.Shell + } + msh.StateMap.Clear() + // no notify here, because we'll call notify in either case below + }) + if err == context.Canceled { + if hitDeadline { + msh.WriteToPtyBuffer("*connect timeout\n") + msh.setErrorStatus(errors.New("connect timeout")) + } else { + msh.WriteToPtyBuffer("*forced disconnection\n") + msh.WithLock(func() { + msh.Status = StatusDisconnected + go msh.NotifyRemoteUpdate() + }) + } + return + } + if err == nil && semver.MajorMinor(mshellVersion) != semver.MajorMinor(scbase.MShellVersion) { + err = fmt.Errorf("mshell version is not compatible current=%s remote=%s", scbase.MShellVersion, mshellVersion) + } + if err != nil { + msh.setErrorStatus(err) + msh.WriteToPtyBuffer("*error connecting to remote: %v\n", err) + go msh.tryAutoInstall() + return + } + msh.updateRemoteStateVars(context.Background(), msh.RemoteId, initPk) + msh.WithLock(func() { + msh.ServerProc = cproc + msh.Status = StatusConnected + }) + go func() { + exitErr := cproc.Cmd.Wait() + exitCode := shexec.GetExitCode(exitErr) + msh.WithLock(func() { + if msh.Status == StatusConnected || msh.Status == StatusConnecting { + msh.Status = StatusDisconnected + go msh.NotifyRemoteUpdate() + } + }) + msh.WriteToPtyBuffer("*disconnected exitcode=%d\n", exitCode) + }() + go msh.ProcessPackets() + msh.initActiveShells() + go msh.NotifyRemoteUpdate() + return +} + +func (msh *MShellProc) LaunchOld(interactive bool) { remoteCopy := msh.GetRemoteCopy() if remoteCopy.Archived { msh.WriteToPtyBuffer("cannot launch archived remote\n") @@ -1281,7 +1574,7 @@ func (msh *MShellProc) Launch(interactive bool) { go msh.NotifyRemoteUpdate() }) go msh.watchClientDeadlineTime() - cproc, initPk, err := shexec.MakeClientProc(makeClientCtx, ecmd) + cproc, initPk, err := shexec.MakeClientProc(makeClientCtx, shexec.CmdWrap{Cmd: ecmd}) // TODO check if initPk.State is not nil var mshellVersion string var hitDeadline bool