Integrate Skeema Knownhosts fix (#287)

This fix makes it possible to differentiate between keys when multiple
are provided by the remote server. It does not solve the case of
multiple keys of the same type being shared, but it handles multiple
keys of different types being shared, which is much more common.
This commit is contained in:
Sylvie Crowe 2024-08-28 13:18:43 -07:00 committed by GitHub
parent a7606b8363
commit ea5e5e1bac
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 38 additions and 23 deletions

1
go.mod
View File

@ -19,6 +19,7 @@ require (
github.com/sashabaranov/go-openai v1.28.3 github.com/sashabaranov/go-openai v1.28.3
github.com/sawka/txwrap v0.2.0 github.com/sawka/txwrap v0.2.0
github.com/shirou/gopsutil/v4 v4.24.7 github.com/shirou/gopsutil/v4 v4.24.7
github.com/skeema/knownhosts v1.3.0
github.com/spf13/cobra v1.8.1 github.com/spf13/cobra v1.8.1
github.com/wavetermdev/htmltoken v0.1.0 github.com/wavetermdev/htmltoken v0.1.0
golang.org/x/crypto v0.26.0 golang.org/x/crypto v0.26.0

2
go.sum
View File

@ -63,6 +63,8 @@ github.com/shoenig/go-m1cpu v0.1.6 h1:nxdKQNcEB6vzgA2E2bvzKIYRuNj7XNJ4S/aRSwKzFt
github.com/shoenig/go-m1cpu v0.1.6/go.mod h1:1JJMcUBvfNwpq05QDQVAnx3gUHr9IYF7GNg9SUEw2VQ= github.com/shoenig/go-m1cpu v0.1.6/go.mod h1:1JJMcUBvfNwpq05QDQVAnx3gUHr9IYF7GNg9SUEw2VQ=
github.com/shoenig/test v0.6.4 h1:kVTaSd7WLz5WZ2IaoM0RSzRsUD+m8wRR+5qvntpn4LU= github.com/shoenig/test v0.6.4 h1:kVTaSd7WLz5WZ2IaoM0RSzRsUD+m8wRR+5qvntpn4LU=
github.com/shoenig/test v0.6.4/go.mod h1:byHiCGXqrVaflBLAMq/srcZIHynQPQgeyvkvXnjqq0k= github.com/shoenig/test v0.6.4/go.mod h1:byHiCGXqrVaflBLAMq/srcZIHynQPQgeyvkvXnjqq0k=
github.com/skeema/knownhosts v1.3.0 h1:AM+y0rI04VksttfwjkSTNQorvGqmwATnvnAHpSgc0LY=
github.com/skeema/knownhosts v1.3.0/go.mod h1:sPINvnADmT/qYH1kfv+ePMmOBTH6Tbl7b5LvTDjFK7M=
github.com/spf13/cobra v1.8.1 h1:e5/vxKd/rZsfSJMUX1agtjeTDf+qv1/JdBF8gg5k9ZM= github.com/spf13/cobra v1.8.1 h1:e5/vxKd/rZsfSJMUX1agtjeTDf+qv1/JdBF8gg5k9ZM=
github.com/spf13/cobra v1.8.1/go.mod h1:wHxEcudfqmLYa8iTfL+OuZPbBZkmvliBWKIezN3kD9Y= github.com/spf13/cobra v1.8.1/go.mod h1:wHxEcudfqmLYa8iTfL+OuZPbBZkmvliBWKIezN3kD9Y=
github.com/spf13/pflag v1.0.5 h1:iy+VFUOCP1a+8yFto/drg2CJ5u0yRoB7fZw3DKv/JXA= github.com/spf13/pflag v1.0.5 h1:iy+VFUOCP1a+8yFto/drg2CJ5u0yRoB7fZw3DKv/JXA=

View File

@ -22,16 +22,19 @@ import (
"time" "time"
"github.com/kevinburke/ssh_config" "github.com/kevinburke/ssh_config"
"github.com/skeema/knownhosts"
"github.com/wavetermdev/thenextwave/pkg/userinput" "github.com/wavetermdev/thenextwave/pkg/userinput"
"github.com/wavetermdev/thenextwave/pkg/wavebase" "github.com/wavetermdev/thenextwave/pkg/wavebase"
"golang.org/x/crypto/ssh" "golang.org/x/crypto/ssh"
"golang.org/x/crypto/ssh/knownhosts" xknownhosts "golang.org/x/crypto/ssh/knownhosts"
) )
type UserInputCancelError struct { type UserInputCancelError struct {
Err error Err error
} }
type HostKeyAlgorithms = func(hostWithPort string) (algos []string)
func (uice UserInputCancelError) Error() string { func (uice UserInputCancelError) Error() string {
return uice.Err.Error() return uice.Err.Error()
} }
@ -352,7 +355,7 @@ func lineContainsMatch(line []byte, matches [][]byte) bool {
return false return false
} }
func createHostKeyCallback(opts *SSHOpts) (ssh.HostKeyCallback, error) { func createHostKeyCallback(opts *SSHOpts) (ssh.HostKeyCallback, HostKeyAlgorithms, error) {
ssh_config.ReloadConfigs() ssh_config.ReloadConfigs()
rawUserKnownHostsFiles, _ := ssh_config.GetStrict(opts.SSHHost, "UserKnownHostsFile") rawUserKnownHostsFiles, _ := ssh_config.GetStrict(opts.SSHHost, "UserKnownHostsFile")
userKnownHostsFiles := strings.Fields(rawUserKnownHostsFiles) // TODO - smarter splitting escaped spaces and quotes userKnownHostsFiles := strings.Fields(rawUserKnownHostsFiles) // TODO - smarter splitting escaped spaces and quotes
@ -361,7 +364,7 @@ func createHostKeyCallback(opts *SSHOpts) (ssh.HostKeyCallback, error) {
osUser, err := user.Current() osUser, err := user.Current()
if err != nil { if err != nil {
return nil, err return nil, nil, err
} }
var unexpandedKnownHostsFiles []string var unexpandedKnownHostsFiles []string
if osUser.Username == "root" { if osUser.Username == "root" {
@ -377,7 +380,7 @@ func createHostKeyCallback(opts *SSHOpts) (ssh.HostKeyCallback, error) {
// there are no good known hosts files // there are no good known hosts files
if len(knownHostsFiles) == 0 { if len(knownHostsFiles) == 0 {
return nil, fmt.Errorf("no known_hosts files provided by ssh. defaults are overridden") return nil, nil, fmt.Errorf("no known_hosts files provided by ssh. defaults are overridden")
} }
var unreadableFiles []string var unreadableFiles []string
@ -386,9 +389,9 @@ func createHostKeyCallback(opts *SSHOpts) (ssh.HostKeyCallback, error) {
// incorrectly. if a problem file is found, it is removed from our list // incorrectly. if a problem file is found, it is removed from our list
// and we try again // and we try again
var basicCallback ssh.HostKeyCallback var basicCallback ssh.HostKeyCallback
var hostKeyAlgorithms HostKeyAlgorithms
for basicCallback == nil && len(knownHostsFiles) > 0 { for basicCallback == nil && len(knownHostsFiles) > 0 {
var err error keyDb, err := knownhosts.NewDB(knownHostsFiles...)
basicCallback, err = knownhosts.New(knownHostsFiles...)
if serr, ok := err.(*os.PathError); ok { if serr, ok := err.(*os.PathError); ok {
badFile := serr.Path badFile := serr.Path
unreadableFiles = append(unreadableFiles, badFile) unreadableFiles = append(unreadableFiles, badFile)
@ -399,18 +402,26 @@ func createHostKeyCallback(opts *SSHOpts) (ssh.HostKeyCallback, error) {
} }
} }
if len(okFiles) >= len(knownHostsFiles) { if len(okFiles) >= len(knownHostsFiles) {
return nil, fmt.Errorf("problem file (%s) doesn't exist. this should not be possible", badFile) return nil, nil, fmt.Errorf("problem file (%s) doesn't exist. this should not be possible", badFile)
} }
knownHostsFiles = okFiles knownHostsFiles = okFiles
} else if err != nil { } else if err != nil {
// TODO handle obscure problems if possible // TODO handle obscure problems if possible
return nil, fmt.Errorf("known_hosts formatting error: %+v", err) return nil, nil, fmt.Errorf("known_hosts formatting error: %+v", err)
} else {
basicCallback = keyDb.HostKeyCallback()
hostKeyAlgorithms = keyDb.HostKeyAlgorithms
} }
} }
if basicCallback == nil { if basicCallback == nil {
basicCallback = func(hostname string, remote net.Addr, key ssh.PublicKey) error { basicCallback = func(hostname string, remote net.Addr, key ssh.PublicKey) error {
return &knownhosts.KeyError{} return &xknownhosts.KeyError{}
}
// need to return nil here to avoid null pointer from attempting to call
// the one provided by the db if nothing was found
hostKeyAlgorithms = func(hostWithPort string) (algos []string) {
return nil
} }
} }
@ -419,21 +430,21 @@ func createHostKeyCallback(opts *SSHOpts) (ssh.HostKeyCallback, error) {
if err == nil { if err == nil {
// success // success
return nil return nil
} else if _, ok := err.(*knownhosts.RevokedError); ok { } else if _, ok := err.(*xknownhosts.RevokedError); ok {
// revoked credentials are refused outright // revoked credentials are refused outright
return err return err
} else if _, ok := err.(*knownhosts.KeyError); !ok { } else if _, ok := err.(*xknownhosts.KeyError); !ok {
// this is an unknown error (note the !ok is opposite of usual) // this is an unknown error (note the !ok is opposite of usual)
return err return err
} }
serr, _ := err.(*knownhosts.KeyError) serr, _ := err.(*xknownhosts.KeyError)
if len(serr.Want) == 0 { if len(serr.Want) == 0 {
// the key was not found // the key was not found
// try to write to a file that could be read // try to write to a file that could be read
err := fmt.Errorf("placeholder, should not be returned") // a null value here can cause problems with empty slice err := fmt.Errorf("placeholder, should not be returned") // a null value here can cause problems with empty slice
for _, filename := range knownHostsFiles { for _, filename := range knownHostsFiles {
newLine := knownhosts.Line([]string{knownhosts.Normalize(hostname)}, key) newLine := xknownhosts.Line([]string{xknownhosts.Normalize(hostname)}, key)
getUserVerification := createUnknownKeyVerifier(filename, hostname, remote.String(), key) getUserVerification := createUnknownKeyVerifier(filename, hostname, remote.String(), key)
err = writeToKnownHosts(filename, newLine, getUserVerification) err = writeToKnownHosts(filename, newLine, getUserVerification)
if err == nil { if err == nil {
@ -448,7 +459,7 @@ func createHostKeyCallback(opts *SSHOpts) (ssh.HostKeyCallback, error) {
// should catch cases where there is no known_hosts file // should catch cases where there is no known_hosts file
if err != nil { if err != nil {
for _, filename := range unreadableFiles { for _, filename := range unreadableFiles {
newLine := knownhosts.Line([]string{knownhosts.Normalize(hostname)}, key) newLine := xknownhosts.Line([]string{xknownhosts.Normalize(hostname)}, key)
getUserVerification := createMissingKnownHostsVerifier(filename, hostname, remote.String(), key) getUserVerification := createMissingKnownHostsVerifier(filename, hostname, remote.String(), key)
err = writeToKnownHosts(filename, newLine, getUserVerification) err = writeToKnownHosts(filename, newLine, getUserVerification)
if err == nil { if err == nil {
@ -496,7 +507,7 @@ func createHostKeyCallback(opts *SSHOpts) (ssh.HostKeyCallback, error) {
return fmt.Errorf("remote host identification has changed") return fmt.Errorf("remote host identification has changed")
} }
updatedCallback, err := knownhosts.New(knownHostsFiles...) updatedCallback, err := xknownhosts.New(knownHostsFiles...)
if err != nil { if err != nil {
return err return err
} }
@ -504,7 +515,7 @@ func createHostKeyCallback(opts *SSHOpts) (ssh.HostKeyCallback, error) {
return updatedCallback(hostname, remote, key) return updatedCallback(hostname, remote, key)
} }
return waveHostKeyCallback, nil return waveHostKeyCallback, hostKeyAlgorithms, nil
} }
func DialContext(ctx context.Context, network string, addr string, config *ssh.ClientConfig) (*ssh.Client, error) { func DialContext(ctx context.Context, network string, addr string, config *ssh.ClientConfig) (*ssh.Client, error) {
@ -530,7 +541,7 @@ func ConnectToClient(connCtx context.Context, opts *SSHOpts) (*ssh.Client, error
if err != nil { if err != nil {
return nil, err return nil, err
} }
remoteName := sshKeywords.User + "@" + knownhosts.Normalize(sshKeywords.HostName+":"+sshKeywords.Port) remoteName := sshKeywords.User + "@" + xknownhosts.Normalize(sshKeywords.HostName+":"+sshKeywords.Port)
publicKeyCallback := ssh.PublicKeysCallback(createPublicKeyCallback(connCtx, sshKeywords, "")) publicKeyCallback := ssh.PublicKeysCallback(createPublicKeyCallback(connCtx, sshKeywords, ""))
keyboardInteractive := ssh.KeyboardInteractive(createCombinedKbdInteractiveChallenge(connCtx, "", remoteName)) keyboardInteractive := ssh.KeyboardInteractive(createCombinedKbdInteractiveChallenge(connCtx, "", remoteName))
@ -571,17 +582,18 @@ func ConnectToClient(connCtx context.Context, opts *SSHOpts) (*ssh.Client, error
authMethods = append(authMethods, authMethod) authMethods = append(authMethods, authMethod)
} }
hostKeyCallback, err := createHostKeyCallback(opts) hostKeyCallback, hostKeyAlgorithms, err := createHostKeyCallback(opts)
if err != nil { if err != nil {
return nil, err return nil, err
} }
networkAddr := sshKeywords.HostName + ":" + sshKeywords.Port
clientConfig := &ssh.ClientConfig{ clientConfig := &ssh.ClientConfig{
User: sshKeywords.User, User: sshKeywords.User,
Auth: authMethods, Auth: authMethods,
HostKeyCallback: hostKeyCallback, HostKeyCallback: hostKeyCallback,
HostKeyAlgorithms: hostKeyAlgorithms(networkAddr),
} }
networkAddr := sshKeywords.HostName + ":" + sshKeywords.Port
return DialContext(connCtx, "tcp", networkAddr, clientConfig) return DialContext(connCtx, "tcp", networkAddr, clientConfig)
} }