diff --git a/waveshell/pkg/shexec/shexec.go b/waveshell/pkg/shexec/shexec.go index d1be9414d..8ce9f10e1 100644 --- a/waveshell/pkg/shexec/shexec.go +++ b/waveshell/pkg/shexec/shexec.go @@ -24,7 +24,6 @@ import ( "github.com/alessio/shellescape" "github.com/creack/pty" "github.com/google/uuid" - "github.com/kevinburke/ssh_config" "github.com/wavetermdev/waveterm/waveshell/pkg/base" "github.com/wavetermdev/waveterm/waveshell/pkg/cirfile" "github.com/wavetermdev/waveterm/waveshell/pkg/mpio" @@ -32,7 +31,6 @@ import ( "github.com/wavetermdev/waveterm/waveshell/pkg/shellapi" "github.com/wavetermdev/waveterm/waveshell/pkg/shellenv" "github.com/wavetermdev/waveterm/waveshell/pkg/shellutil" - "golang.org/x/crypto/ssh" "golang.org/x/mod/semver" "golang.org/x/sys/unix" ) @@ -478,76 +476,6 @@ func (opts SSHOpts) MakeSSHExecCmd(remoteCommand string, sapi shellapi.ShellApi) } } -func (opts SSHOpts) ConnectToClient() (*ssh.Client, error) { - ssh_config.ReloadConfigs() - configIdentity, _ := ssh_config.GetStrict(opts.SSHHost, "IdentityFile") - var identityFile string - if opts.SSHIdentity != "" { - identityFile = opts.SSHIdentity - } else { - identityFile = configIdentity - } - - var authMethods []ssh.AuthMethod - var hostKeyCallback ssh.HostKeyCallback - if identityFile != "" { - sshKeyFile, err := os.ReadFile(base.ExpandHomeDir(identityFile)) - if err != nil { - return nil, fmt.Errorf("failed to read ssh key file. err: %+v", err) - } - signer, err := ssh.ParsePrivateKey(sshKeyFile) - if err != nil { - return nil, fmt.Errorf("failed to parse private ssh key. err: %+v", err) - } - /* - publicKey, err := ssh.ParsePublicKey(sshKeyFile) - if err != nil { - return nil, fmt.Errorf("failed to parse public ssh key. err: %+v", err) - } - */ - authMethods = append(authMethods, ssh.PublicKeys(signer)) - hostKeyCallback = ssh.InsecureIgnoreHostKey() - } else { - hostKeyCallback = ssh.InsecureIgnoreHostKey() - } - configUser, _ := ssh_config.GetStrict(opts.SSHHost, "User") - configHostName, _ := ssh_config.GetStrict(opts.SSHHost, "HostName") - configPort, _ := ssh_config.GetStrict(opts.SSHHost, "Port") - var username string - if opts.SSHUser != "" { - username = opts.SSHUser - } else if configUser != "" { - username = configUser - } else { - user, err := user.Current() - if err != nil { - return nil, fmt.Errorf("failed to get user for ssh: %+v", err) - } - username = user.Username - } - var hostName string - if configHostName != "" { - hostName = configHostName - } else { - hostName = opts.SSHHost - } - clientConfig := &ssh.ClientConfig{ - User: username, - Auth: authMethods, - HostKeyCallback: hostKeyCallback, - } - var port string - if opts.SSHPort != 0 && opts.SSHPort != 22 { - port = strconv.Itoa(opts.SSHPort) - } else if configPort != "" && configPort != "22" { - port = configPort - } else { - port = "22" - } - networkAddr := hostName + ":" + port - return ssh.Dial("tcp", networkAddr, clientConfig) -} - func (opts SSHOpts) MakeMShellSSHOpts() string { var moreSSHOpts []string if opts.SSHIdentity != "" { diff --git a/wavesrv/pkg/remote/remote.go b/wavesrv/pkg/remote/remote.go index a9e1e0d44..8747735b1 100644 --- a/wavesrv/pkg/remote/remote.go +++ b/wavesrv/pkg/remote/remote.go @@ -1250,7 +1250,7 @@ func (msh *MShellProc) LaunchWithSshLib(interactive bool) { if remoteCopy.ConnectMode != sstore.ConnectModeManual && remoteCopy.SSHOpts.SSHPassword == "" && !interactive { sshOpts.BatchMode = true } - client, err := sshOpts.ConnectToClient() + client, err := ConnectToClient(remoteCopy.SSHOpts) if err != nil { msh.WriteToPtyBuffer("*error, ssh cannot connect to client: %v\n", err) } @@ -1418,16 +1418,16 @@ func (msh *MShellProc) Launch(interactive bool) { cproc, initPk, err = shexec.MakeClientProc(makeClientCtx, shexec.CmdWrap{Cmd: ecmd}) } else { var client *ssh.Client - client, err = sshOpts.ConnectToClient() - es := fmt.Sprintf("err: %v\n", err) - os.WriteFile("/Users/oneirocosm/.waveterm-dev/temp.txt", []byte(es), 0644) + client, err = ConnectToClient(remoteCopy.SSHOpts) if err != nil { msh.WriteToPtyBuffer("*error, ssh cannot connect to client: %v\n", err) + return } var session *ssh.Session session, err = client.NewSession() if err != nil { msh.WriteToPtyBuffer("*error, ssh cannot create session: %v\n", err) + return } cproc, initPk, err = shexec.MakeClientProc(makeClientCtx, shexec.SessionWrap{Session: session, StartCmd: MakeServerRunOnlyCommandStr()}) } diff --git a/wavesrv/pkg/remote/sshclient.go b/wavesrv/pkg/remote/sshclient.go new file mode 100644 index 000000000..aa7cbabf9 --- /dev/null +++ b/wavesrv/pkg/remote/sshclient.go @@ -0,0 +1,108 @@ +// Copyright 2024, Command Line Inc. +// SPDX-License-Identifier: Apache-2.0 + +package remote + +import ( + "errors" + "fmt" + "os" + "os/user" + "strconv" + "strings" + + "github.com/kevinburke/ssh_config" + "github.com/wavetermdev/waveterm/waveshell/pkg/base" + "github.com/wavetermdev/waveterm/wavesrv/pkg/sstore" + "golang.org/x/crypto/ssh" +) + +func createPublicKeyAuth(identityFile string, passphrase string) (ssh.AuthMethod, error) { + privateKey, err := os.ReadFile(base.ExpandHomeDir(identityFile)) + if err != nil { + return nil, fmt.Errorf("failed to read ssh key file. err: %+v", err) + } + signer, err := ssh.ParsePrivateKey(privateKey) + if err != nil { + if errors.Is(err, &ssh.PassphraseMissingError{}) { + signer, err = ssh.ParsePrivateKeyWithPassphrase(privateKey, []byte(passphrase)) + if err != nil { + return nil, fmt.Errorf("failed to parse private ssh key with passphrase. err: %+v", err) + } + } else { + return nil, fmt.Errorf("failed to parse private ssh key. err: %+v", err) + } + } + return ssh.PublicKeys(signer), nil +} + +func createKeyboardInteractiveAuth(password string) ssh.AuthMethod { + challenge := func(name, instruction string, questions []string, echos []bool) (answers []string, err error) { + for _, q := range questions { + if strings.Contains(strings.ToLower(q), "password") { + answers = append(answers, password) + } else { + answers = append(answers, "") + } + } + return answers, nil + } + return ssh.KeyboardInteractive(challenge) +} + +func ConnectToClient(opts *sstore.SSHOpts) (*ssh.Client, error) { + ssh_config.ReloadConfigs() + configIdentity, _ := ssh_config.GetStrict(opts.SSHHost, "IdentityFile") + var identityFile string + if opts.SSHIdentity != "" { + identityFile = opts.SSHIdentity + } else { + identityFile = configIdentity + } + + hostKeyCallback := ssh.InsecureIgnoreHostKey() + var authMethods []ssh.AuthMethod + publicKeyAuth, err := createPublicKeyAuth(identityFile, opts.SSHPassword) + if err == nil { + authMethods = append(authMethods, publicKeyAuth) + } + authMethods = append(authMethods, createKeyboardInteractiveAuth(opts.SSHPassword)) + authMethods = append(authMethods, ssh.Password(opts.SSHPassword)) + + configUser, _ := ssh_config.GetStrict(opts.SSHHost, "User") + configHostName, _ := ssh_config.GetStrict(opts.SSHHost, "HostName") + configPort, _ := ssh_config.GetStrict(opts.SSHHost, "Port") + var username string + if opts.SSHUser != "" { + username = opts.SSHUser + } else if configUser != "" { + username = configUser + } else { + user, err := user.Current() + if err != nil { + return nil, fmt.Errorf("failed to get user for ssh: %+v", err) + } + username = user.Username + } + var hostName string + if configHostName != "" { + hostName = configHostName + } else { + hostName = opts.SSHHost + } + clientConfig := &ssh.ClientConfig{ + User: username, + Auth: authMethods, + HostKeyCallback: hostKeyCallback, + } + var port string + if opts.SSHPort != 0 && opts.SSHPort != 22 { + port = strconv.Itoa(opts.SSHPort) + } else if configPort != "" && configPort != "22" { + port = configPort + } else { + port = "22" + } + networkAddr := hostName + ":" + port + return ssh.Dial("tcp", networkAddr, clientConfig) +}