From ea5e5e1baceece1de0912c858c8475405a2db2ef Mon Sep 17 00:00:00 2001 From: Sylvie Crowe <107814465+oneirocosm@users.noreply.github.com> Date: Wed, 28 Aug 2024 13:18:43 -0700 Subject: [PATCH] Integrate Skeema Knownhosts fix (#287) This fix makes it possible to differentiate between keys when multiple are provided by the remote server. It does not solve the case of multiple keys of the same type being shared, but it handles multiple keys of different types being shared, which is much more common. --- go.mod | 1 + go.sum | 2 ++ pkg/remote/sshclient.go | 58 +++++++++++++++++++++++++---------------- 3 files changed, 38 insertions(+), 23 deletions(-) diff --git a/go.mod b/go.mod index f0f6e2c37..5b7239e41 100644 --- a/go.mod +++ b/go.mod @@ -19,6 +19,7 @@ require ( github.com/sashabaranov/go-openai v1.28.3 github.com/sawka/txwrap v0.2.0 github.com/shirou/gopsutil/v4 v4.24.7 + github.com/skeema/knownhosts v1.3.0 github.com/spf13/cobra v1.8.1 github.com/wavetermdev/htmltoken v0.1.0 golang.org/x/crypto v0.26.0 diff --git a/go.sum b/go.sum index 331e832e1..a88300e03 100644 --- a/go.sum +++ b/go.sum @@ -63,6 +63,8 @@ github.com/shoenig/go-m1cpu v0.1.6 h1:nxdKQNcEB6vzgA2E2bvzKIYRuNj7XNJ4S/aRSwKzFt github.com/shoenig/go-m1cpu v0.1.6/go.mod h1:1JJMcUBvfNwpq05QDQVAnx3gUHr9IYF7GNg9SUEw2VQ= github.com/shoenig/test v0.6.4 h1:kVTaSd7WLz5WZ2IaoM0RSzRsUD+m8wRR+5qvntpn4LU= github.com/shoenig/test v0.6.4/go.mod h1:byHiCGXqrVaflBLAMq/srcZIHynQPQgeyvkvXnjqq0k= +github.com/skeema/knownhosts v1.3.0 h1:AM+y0rI04VksttfwjkSTNQorvGqmwATnvnAHpSgc0LY= +github.com/skeema/knownhosts v1.3.0/go.mod h1:sPINvnADmT/qYH1kfv+ePMmOBTH6Tbl7b5LvTDjFK7M= github.com/spf13/cobra v1.8.1 h1:e5/vxKd/rZsfSJMUX1agtjeTDf+qv1/JdBF8gg5k9ZM= github.com/spf13/cobra v1.8.1/go.mod h1:wHxEcudfqmLYa8iTfL+OuZPbBZkmvliBWKIezN3kD9Y= github.com/spf13/pflag v1.0.5 h1:iy+VFUOCP1a+8yFto/drg2CJ5u0yRoB7fZw3DKv/JXA= diff --git a/pkg/remote/sshclient.go b/pkg/remote/sshclient.go index 160fc65e2..b3f4f8e5e 100644 --- a/pkg/remote/sshclient.go +++ b/pkg/remote/sshclient.go @@ -22,16 +22,19 @@ import ( "time" "github.com/kevinburke/ssh_config" + "github.com/skeema/knownhosts" "github.com/wavetermdev/thenextwave/pkg/userinput" "github.com/wavetermdev/thenextwave/pkg/wavebase" "golang.org/x/crypto/ssh" - "golang.org/x/crypto/ssh/knownhosts" + xknownhosts "golang.org/x/crypto/ssh/knownhosts" ) type UserInputCancelError struct { Err error } +type HostKeyAlgorithms = func(hostWithPort string) (algos []string) + func (uice UserInputCancelError) Error() string { return uice.Err.Error() } @@ -352,7 +355,7 @@ func lineContainsMatch(line []byte, matches [][]byte) bool { return false } -func createHostKeyCallback(opts *SSHOpts) (ssh.HostKeyCallback, error) { +func createHostKeyCallback(opts *SSHOpts) (ssh.HostKeyCallback, HostKeyAlgorithms, error) { ssh_config.ReloadConfigs() rawUserKnownHostsFiles, _ := ssh_config.GetStrict(opts.SSHHost, "UserKnownHostsFile") userKnownHostsFiles := strings.Fields(rawUserKnownHostsFiles) // TODO - smarter splitting escaped spaces and quotes @@ -361,7 +364,7 @@ func createHostKeyCallback(opts *SSHOpts) (ssh.HostKeyCallback, error) { osUser, err := user.Current() if err != nil { - return nil, err + return nil, nil, err } var unexpandedKnownHostsFiles []string if osUser.Username == "root" { @@ -377,7 +380,7 @@ func createHostKeyCallback(opts *SSHOpts) (ssh.HostKeyCallback, error) { // there are no good known hosts files if len(knownHostsFiles) == 0 { - return nil, fmt.Errorf("no known_hosts files provided by ssh. defaults are overridden") + return nil, nil, fmt.Errorf("no known_hosts files provided by ssh. defaults are overridden") } var unreadableFiles []string @@ -386,9 +389,9 @@ func createHostKeyCallback(opts *SSHOpts) (ssh.HostKeyCallback, error) { // incorrectly. if a problem file is found, it is removed from our list // and we try again var basicCallback ssh.HostKeyCallback + var hostKeyAlgorithms HostKeyAlgorithms for basicCallback == nil && len(knownHostsFiles) > 0 { - var err error - basicCallback, err = knownhosts.New(knownHostsFiles...) + keyDb, err := knownhosts.NewDB(knownHostsFiles...) if serr, ok := err.(*os.PathError); ok { badFile := serr.Path unreadableFiles = append(unreadableFiles, badFile) @@ -399,18 +402,26 @@ func createHostKeyCallback(opts *SSHOpts) (ssh.HostKeyCallback, error) { } } if len(okFiles) >= len(knownHostsFiles) { - return nil, fmt.Errorf("problem file (%s) doesn't exist. this should not be possible", badFile) + return nil, nil, fmt.Errorf("problem file (%s) doesn't exist. this should not be possible", badFile) } knownHostsFiles = okFiles } else if err != nil { // TODO handle obscure problems if possible - return nil, fmt.Errorf("known_hosts formatting error: %+v", err) + return nil, nil, fmt.Errorf("known_hosts formatting error: %+v", err) + } else { + basicCallback = keyDb.HostKeyCallback() + hostKeyAlgorithms = keyDb.HostKeyAlgorithms } } if basicCallback == nil { basicCallback = func(hostname string, remote net.Addr, key ssh.PublicKey) error { - return &knownhosts.KeyError{} + return &xknownhosts.KeyError{} + } + // need to return nil here to avoid null pointer from attempting to call + // the one provided by the db if nothing was found + hostKeyAlgorithms = func(hostWithPort string) (algos []string) { + return nil } } @@ -419,21 +430,21 @@ func createHostKeyCallback(opts *SSHOpts) (ssh.HostKeyCallback, error) { if err == nil { // success return nil - } else if _, ok := err.(*knownhosts.RevokedError); ok { + } else if _, ok := err.(*xknownhosts.RevokedError); ok { // revoked credentials are refused outright return err - } else if _, ok := err.(*knownhosts.KeyError); !ok { + } else if _, ok := err.(*xknownhosts.KeyError); !ok { // this is an unknown error (note the !ok is opposite of usual) return err } - serr, _ := err.(*knownhosts.KeyError) + serr, _ := err.(*xknownhosts.KeyError) if len(serr.Want) == 0 { // the key was not found // try to write to a file that could be read err := fmt.Errorf("placeholder, should not be returned") // a null value here can cause problems with empty slice for _, filename := range knownHostsFiles { - newLine := knownhosts.Line([]string{knownhosts.Normalize(hostname)}, key) + newLine := xknownhosts.Line([]string{xknownhosts.Normalize(hostname)}, key) getUserVerification := createUnknownKeyVerifier(filename, hostname, remote.String(), key) err = writeToKnownHosts(filename, newLine, getUserVerification) if err == nil { @@ -448,7 +459,7 @@ func createHostKeyCallback(opts *SSHOpts) (ssh.HostKeyCallback, error) { // should catch cases where there is no known_hosts file if err != nil { for _, filename := range unreadableFiles { - newLine := knownhosts.Line([]string{knownhosts.Normalize(hostname)}, key) + newLine := xknownhosts.Line([]string{xknownhosts.Normalize(hostname)}, key) getUserVerification := createMissingKnownHostsVerifier(filename, hostname, remote.String(), key) err = writeToKnownHosts(filename, newLine, getUserVerification) if err == nil { @@ -496,7 +507,7 @@ func createHostKeyCallback(opts *SSHOpts) (ssh.HostKeyCallback, error) { return fmt.Errorf("remote host identification has changed") } - updatedCallback, err := knownhosts.New(knownHostsFiles...) + updatedCallback, err := xknownhosts.New(knownHostsFiles...) if err != nil { return err } @@ -504,7 +515,7 @@ func createHostKeyCallback(opts *SSHOpts) (ssh.HostKeyCallback, error) { return updatedCallback(hostname, remote, key) } - return waveHostKeyCallback, nil + return waveHostKeyCallback, hostKeyAlgorithms, nil } func DialContext(ctx context.Context, network string, addr string, config *ssh.ClientConfig) (*ssh.Client, error) { @@ -530,7 +541,7 @@ func ConnectToClient(connCtx context.Context, opts *SSHOpts) (*ssh.Client, error if err != nil { return nil, err } - remoteName := sshKeywords.User + "@" + knownhosts.Normalize(sshKeywords.HostName+":"+sshKeywords.Port) + remoteName := sshKeywords.User + "@" + xknownhosts.Normalize(sshKeywords.HostName+":"+sshKeywords.Port) publicKeyCallback := ssh.PublicKeysCallback(createPublicKeyCallback(connCtx, sshKeywords, "")) keyboardInteractive := ssh.KeyboardInteractive(createCombinedKbdInteractiveChallenge(connCtx, "", remoteName)) @@ -571,17 +582,18 @@ func ConnectToClient(connCtx context.Context, opts *SSHOpts) (*ssh.Client, error authMethods = append(authMethods, authMethod) } - hostKeyCallback, err := createHostKeyCallback(opts) + hostKeyCallback, hostKeyAlgorithms, err := createHostKeyCallback(opts) if err != nil { return nil, err } - clientConfig := &ssh.ClientConfig{ - User: sshKeywords.User, - Auth: authMethods, - HostKeyCallback: hostKeyCallback, - } networkAddr := sshKeywords.HostName + ":" + sshKeywords.Port + clientConfig := &ssh.ClientConfig{ + User: sshKeywords.User, + Auth: authMethods, + HostKeyCallback: hostKeyCallback, + HostKeyAlgorithms: hostKeyAlgorithms(networkAddr), + } return DialContext(connCtx, "tcp", networkAddr, clientConfig) }