diff --git a/go.mod b/go.mod index f0f6e2c37..5b7239e41 100644 --- a/go.mod +++ b/go.mod @@ -19,6 +19,7 @@ require ( github.com/sashabaranov/go-openai v1.28.3 github.com/sawka/txwrap v0.2.0 github.com/shirou/gopsutil/v4 v4.24.7 + github.com/skeema/knownhosts v1.3.0 github.com/spf13/cobra v1.8.1 github.com/wavetermdev/htmltoken v0.1.0 golang.org/x/crypto v0.26.0 diff --git a/go.sum b/go.sum index 331e832e1..a88300e03 100644 --- a/go.sum +++ b/go.sum @@ -63,6 +63,8 @@ github.com/shoenig/go-m1cpu v0.1.6 h1:nxdKQNcEB6vzgA2E2bvzKIYRuNj7XNJ4S/aRSwKzFt github.com/shoenig/go-m1cpu v0.1.6/go.mod h1:1JJMcUBvfNwpq05QDQVAnx3gUHr9IYF7GNg9SUEw2VQ= github.com/shoenig/test v0.6.4 h1:kVTaSd7WLz5WZ2IaoM0RSzRsUD+m8wRR+5qvntpn4LU= github.com/shoenig/test v0.6.4/go.mod h1:byHiCGXqrVaflBLAMq/srcZIHynQPQgeyvkvXnjqq0k= +github.com/skeema/knownhosts v1.3.0 h1:AM+y0rI04VksttfwjkSTNQorvGqmwATnvnAHpSgc0LY= +github.com/skeema/knownhosts v1.3.0/go.mod h1:sPINvnADmT/qYH1kfv+ePMmOBTH6Tbl7b5LvTDjFK7M= github.com/spf13/cobra v1.8.1 h1:e5/vxKd/rZsfSJMUX1agtjeTDf+qv1/JdBF8gg5k9ZM= github.com/spf13/cobra v1.8.1/go.mod h1:wHxEcudfqmLYa8iTfL+OuZPbBZkmvliBWKIezN3kD9Y= github.com/spf13/pflag v1.0.5 h1:iy+VFUOCP1a+8yFto/drg2CJ5u0yRoB7fZw3DKv/JXA= diff --git a/pkg/remote/sshclient.go b/pkg/remote/sshclient.go index 160fc65e2..b3f4f8e5e 100644 --- a/pkg/remote/sshclient.go +++ b/pkg/remote/sshclient.go @@ -22,16 +22,19 @@ import ( "time" "github.com/kevinburke/ssh_config" + "github.com/skeema/knownhosts" "github.com/wavetermdev/thenextwave/pkg/userinput" "github.com/wavetermdev/thenextwave/pkg/wavebase" "golang.org/x/crypto/ssh" - "golang.org/x/crypto/ssh/knownhosts" + xknownhosts "golang.org/x/crypto/ssh/knownhosts" ) type UserInputCancelError struct { Err error } +type HostKeyAlgorithms = func(hostWithPort string) (algos []string) + func (uice UserInputCancelError) Error() string { return uice.Err.Error() } @@ -352,7 +355,7 @@ func lineContainsMatch(line []byte, matches [][]byte) bool { return false } -func createHostKeyCallback(opts *SSHOpts) (ssh.HostKeyCallback, error) { +func createHostKeyCallback(opts *SSHOpts) (ssh.HostKeyCallback, HostKeyAlgorithms, error) { ssh_config.ReloadConfigs() rawUserKnownHostsFiles, _ := ssh_config.GetStrict(opts.SSHHost, "UserKnownHostsFile") userKnownHostsFiles := strings.Fields(rawUserKnownHostsFiles) // TODO - smarter splitting escaped spaces and quotes @@ -361,7 +364,7 @@ func createHostKeyCallback(opts *SSHOpts) (ssh.HostKeyCallback, error) { osUser, err := user.Current() if err != nil { - return nil, err + return nil, nil, err } var unexpandedKnownHostsFiles []string if osUser.Username == "root" { @@ -377,7 +380,7 @@ func createHostKeyCallback(opts *SSHOpts) (ssh.HostKeyCallback, error) { // there are no good known hosts files if len(knownHostsFiles) == 0 { - return nil, fmt.Errorf("no known_hosts files provided by ssh. defaults are overridden") + return nil, nil, fmt.Errorf("no known_hosts files provided by ssh. defaults are overridden") } var unreadableFiles []string @@ -386,9 +389,9 @@ func createHostKeyCallback(opts *SSHOpts) (ssh.HostKeyCallback, error) { // incorrectly. if a problem file is found, it is removed from our list // and we try again var basicCallback ssh.HostKeyCallback + var hostKeyAlgorithms HostKeyAlgorithms for basicCallback == nil && len(knownHostsFiles) > 0 { - var err error - basicCallback, err = knownhosts.New(knownHostsFiles...) + keyDb, err := knownhosts.NewDB(knownHostsFiles...) if serr, ok := err.(*os.PathError); ok { badFile := serr.Path unreadableFiles = append(unreadableFiles, badFile) @@ -399,18 +402,26 @@ func createHostKeyCallback(opts *SSHOpts) (ssh.HostKeyCallback, error) { } } if len(okFiles) >= len(knownHostsFiles) { - return nil, fmt.Errorf("problem file (%s) doesn't exist. this should not be possible", badFile) + return nil, nil, fmt.Errorf("problem file (%s) doesn't exist. this should not be possible", badFile) } knownHostsFiles = okFiles } else if err != nil { // TODO handle obscure problems if possible - return nil, fmt.Errorf("known_hosts formatting error: %+v", err) + return nil, nil, fmt.Errorf("known_hosts formatting error: %+v", err) + } else { + basicCallback = keyDb.HostKeyCallback() + hostKeyAlgorithms = keyDb.HostKeyAlgorithms } } if basicCallback == nil { basicCallback = func(hostname string, remote net.Addr, key ssh.PublicKey) error { - return &knownhosts.KeyError{} + return &xknownhosts.KeyError{} + } + // need to return nil here to avoid null pointer from attempting to call + // the one provided by the db if nothing was found + hostKeyAlgorithms = func(hostWithPort string) (algos []string) { + return nil } } @@ -419,21 +430,21 @@ func createHostKeyCallback(opts *SSHOpts) (ssh.HostKeyCallback, error) { if err == nil { // success return nil - } else if _, ok := err.(*knownhosts.RevokedError); ok { + } else if _, ok := err.(*xknownhosts.RevokedError); ok { // revoked credentials are refused outright return err - } else if _, ok := err.(*knownhosts.KeyError); !ok { + } else if _, ok := err.(*xknownhosts.KeyError); !ok { // this is an unknown error (note the !ok is opposite of usual) return err } - serr, _ := err.(*knownhosts.KeyError) + serr, _ := err.(*xknownhosts.KeyError) if len(serr.Want) == 0 { // the key was not found // try to write to a file that could be read err := fmt.Errorf("placeholder, should not be returned") // a null value here can cause problems with empty slice for _, filename := range knownHostsFiles { - newLine := knownhosts.Line([]string{knownhosts.Normalize(hostname)}, key) + newLine := xknownhosts.Line([]string{xknownhosts.Normalize(hostname)}, key) getUserVerification := createUnknownKeyVerifier(filename, hostname, remote.String(), key) err = writeToKnownHosts(filename, newLine, getUserVerification) if err == nil { @@ -448,7 +459,7 @@ func createHostKeyCallback(opts *SSHOpts) (ssh.HostKeyCallback, error) { // should catch cases where there is no known_hosts file if err != nil { for _, filename := range unreadableFiles { - newLine := knownhosts.Line([]string{knownhosts.Normalize(hostname)}, key) + newLine := xknownhosts.Line([]string{xknownhosts.Normalize(hostname)}, key) getUserVerification := createMissingKnownHostsVerifier(filename, hostname, remote.String(), key) err = writeToKnownHosts(filename, newLine, getUserVerification) if err == nil { @@ -496,7 +507,7 @@ func createHostKeyCallback(opts *SSHOpts) (ssh.HostKeyCallback, error) { return fmt.Errorf("remote host identification has changed") } - updatedCallback, err := knownhosts.New(knownHostsFiles...) + updatedCallback, err := xknownhosts.New(knownHostsFiles...) if err != nil { return err } @@ -504,7 +515,7 @@ func createHostKeyCallback(opts *SSHOpts) (ssh.HostKeyCallback, error) { return updatedCallback(hostname, remote, key) } - return waveHostKeyCallback, nil + return waveHostKeyCallback, hostKeyAlgorithms, nil } func DialContext(ctx context.Context, network string, addr string, config *ssh.ClientConfig) (*ssh.Client, error) { @@ -530,7 +541,7 @@ func ConnectToClient(connCtx context.Context, opts *SSHOpts) (*ssh.Client, error if err != nil { return nil, err } - remoteName := sshKeywords.User + "@" + knownhosts.Normalize(sshKeywords.HostName+":"+sshKeywords.Port) + remoteName := sshKeywords.User + "@" + xknownhosts.Normalize(sshKeywords.HostName+":"+sshKeywords.Port) publicKeyCallback := ssh.PublicKeysCallback(createPublicKeyCallback(connCtx, sshKeywords, "")) keyboardInteractive := ssh.KeyboardInteractive(createCombinedKbdInteractiveChallenge(connCtx, "", remoteName)) @@ -571,17 +582,18 @@ func ConnectToClient(connCtx context.Context, opts *SSHOpts) (*ssh.Client, error authMethods = append(authMethods, authMethod) } - hostKeyCallback, err := createHostKeyCallback(opts) + hostKeyCallback, hostKeyAlgorithms, err := createHostKeyCallback(opts) if err != nil { return nil, err } - clientConfig := &ssh.ClientConfig{ - User: sshKeywords.User, - Auth: authMethods, - HostKeyCallback: hostKeyCallback, - } networkAddr := sshKeywords.HostName + ":" + sshKeywords.Port + clientConfig := &ssh.ClientConfig{ + User: sshKeywords.User, + Auth: authMethods, + HostKeyCallback: hostKeyCallback, + HostKeyAlgorithms: hostKeyAlgorithms(networkAddr), + } return DialContext(connCtx, "tcp", networkAddr, clientConfig) }