From ac6f9a05d4f3661ca584eb61376a21633426b5be Mon Sep 17 00:00:00 2001 From: Sylvie Crowe <107814465+oneirocosm@users.noreply.github.com> Date: Fri, 25 Oct 2024 12:14:40 -0700 Subject: [PATCH] ProxyJump Support (#1107) This adds basic ProxyJump support for handling ssh connections. --- pkg/remote/conncontroller/conncontroller.go | 4 +- pkg/remote/sshclient.go | 180 +++++++++++++++----- 2 files changed, 138 insertions(+), 46 deletions(-) diff --git a/pkg/remote/conncontroller/conncontroller.go b/pkg/remote/conncontroller/conncontroller.go index d50f63978..545738f02 100644 --- a/pkg/remote/conncontroller/conncontroller.go +++ b/pkg/remote/conncontroller/conncontroller.go @@ -422,9 +422,9 @@ func (conn *SSHConn) WithLock(fn func()) { } func (conn *SSHConn) connectInternal(ctx context.Context) error { - client, err := remote.ConnectToClient(ctx, conn.Opts) //todo specify or remove opts + client, _, err := remote.ConnectToClient(ctx, conn.Opts, nil, 0) if err != nil { - log.Printf("error: failed to connect to client %s: %v\n", conn.GetName(), err) + log.Printf("error: failed to connect to client %s: %s\n", conn.GetName(), err) return err } fmtAddr := knownhosts.Normalize(fmt.Sprintf("%s@%s", client.User(), client.RemoteAddr().String())) diff --git a/pkg/remote/sshclient.go b/pkg/remote/sshclient.go index 13ee56d31..13893784e 100644 --- a/pkg/remote/sshclient.go +++ b/pkg/remote/sshclient.go @@ -11,6 +11,7 @@ import ( "encoding/base64" "fmt" "log" + "math" "net" "os" "os/exec" @@ -31,6 +32,8 @@ import ( xknownhosts "golang.org/x/crypto/ssh/knownhosts" ) +const SshProxyJumpMaxDepth = 10 + type UserInputCancelError struct { Err error } @@ -41,6 +44,24 @@ func (uice UserInputCancelError) Error() string { return uice.Err.Error() } +type ConnectionDebugInfo struct { + CurrentClient *ssh.Client + NextOpts *SSHOpts + JumpNum int32 +} + +type ConnectionError struct { + *ConnectionDebugInfo + Err error +} + +func (ce ConnectionError) Error() string { + if ce.CurrentClient == nil { + return fmt.Sprintf("Connecting to %+#v, Error: %v", ce.NextOpts, ce.Err) + } + return fmt.Sprintf("Connecting from %v to %+#v (jump number %d), Error: %v", ce.CurrentClient, ce.NextOpts, ce.JumpNum, ce.Err) +} + // This exists to trick the ssh library into continuing to try // different public keys even when the current key cannot be // properly parsed @@ -68,7 +89,7 @@ func createDummySigner() ([]ssh.Signer, error) { // they were successes. An error in this function prevents any other // keys from being attempted. But if there's an error because of a dummy // file, the library can still try again with a new key. -func createPublicKeyCallback(connCtx context.Context, sshKeywords *SshKeywords, authSockSignersExt []ssh.Signer, agentClient agent.ExtendedAgent) func() ([]ssh.Signer, error) { +func createPublicKeyCallback(connCtx context.Context, sshKeywords *SshKeywords, authSockSignersExt []ssh.Signer, agentClient agent.ExtendedAgent, debugInfo *ConnectionDebugInfo) func() ([]ssh.Signer, error) { var identityFiles []string existingKeys := make(map[string][]byte) @@ -103,7 +124,7 @@ func createPublicKeyCallback(connCtx context.Context, sshKeywords *SshKeywords, } if len(*identityFilesPtr) == 0 { - return nil, fmt.Errorf("no identity files remaining") + return nil, ConnectionError{ConnectionDebugInfo: debugInfo, Err: fmt.Errorf("no identity files remaining")} } identityFile := (*identityFilesPtr)[0] *identityFilesPtr = (*identityFilesPtr)[1:] @@ -123,7 +144,7 @@ func createPublicKeyCallback(connCtx context.Context, sshKeywords *SshKeywords, PrivateKey: unencryptedPrivateKey, }) } - return []ssh.Signer{signer}, err + return []ssh.Signer{signer}, nil } } if _, ok := err.(*ssh.PassphraseMissingError); !ok { @@ -148,7 +169,8 @@ func createPublicKeyCallback(connCtx context.Context, sshKeywords *SshKeywords, if err != nil { // this is an error where we actually do want to stop // trying keys - return nil, UserInputCancelError{Err: err} + + return nil, ConnectionError{ConnectionDebugInfo: debugInfo, Err: UserInputCancelError{Err: err}} } unencryptedPrivateKey, err = ssh.ParseRawPrivateKeyWithPassphrase(privateKey, []byte([]byte(response.Text))) if err != nil { @@ -165,11 +187,11 @@ func createPublicKeyCallback(connCtx context.Context, sshKeywords *SshKeywords, PrivateKey: unencryptedPrivateKey, }) } - return []ssh.Signer{signer}, err + return []ssh.Signer{signer}, nil } } -func createInteractivePasswordCallbackPrompt(connCtx context.Context, remoteDisplayName string) func() (secret string, err error) { +func createInteractivePasswordCallbackPrompt(connCtx context.Context, remoteDisplayName string, debugInfo *ConnectionDebugInfo) func() (secret string, err error) { return func() (secret string, err error) { ctx, cancelFn := context.WithTimeout(connCtx, 60*time.Second) defer cancelFn() @@ -185,13 +207,13 @@ func createInteractivePasswordCallbackPrompt(connCtx context.Context, remoteDisp } response, err := userinput.GetUserInput(ctx, request) if err != nil { - return "", err + return "", ConnectionError{ConnectionDebugInfo: debugInfo, Err: err} } return response.Text, nil } } -func createInteractiveKbdInteractiveChallenge(connCtx context.Context, remoteName string) func(name, instruction string, questions []string, echos []bool) (answers []string, err error) { +func createInteractiveKbdInteractiveChallenge(connCtx context.Context, remoteName string, debugInfo *ConnectionDebugInfo) func(name, instruction string, questions []string, echos []bool) (answers []string, err error) { return func(name, instruction string, questions []string, echos []bool) (answers []string, err error) { if len(questions) != len(echos) { return nil, fmt.Errorf("bad response from server: questions has len %d, echos has len %d", len(questions), len(echos)) @@ -200,7 +222,7 @@ func createInteractiveKbdInteractiveChallenge(connCtx context.Context, remoteNam echo := echos[i] answer, err := promptChallengeQuestion(connCtx, question, echo, remoteName) if err != nil { - return nil, err + return nil, ConnectionError{ConnectionDebugInfo: debugInfo, Err: err} } answers = append(answers, answer) } @@ -336,12 +358,9 @@ func lineContainsMatch(line []byte, matches [][]byte) bool { return false } -func createHostKeyCallback(opts *SSHOpts) (ssh.HostKeyCallback, HostKeyAlgorithms, error) { - ssh_config.ReloadConfigs() - rawUserKnownHostsFiles, _ := ssh_config.GetStrict(opts.SSHHost, "UserKnownHostsFile") - userKnownHostsFiles := strings.Fields(rawUserKnownHostsFiles) // TODO - smarter splitting escaped spaces and quotes - rawGlobalKnownHostsFiles, _ := ssh_config.GetStrict(opts.SSHHost, "GlobalKnownHostsFile") - globalKnownHostsFiles := strings.Fields(rawGlobalKnownHostsFiles) // TODO - smarter splitting escaped spaces and quotes +func createHostKeyCallback(sshKeywords *SshKeywords) (ssh.HostKeyCallback, HostKeyAlgorithms, error) { + globalKnownHostsFiles := sshKeywords.GlobalKnownHostsFile + userKnownHostsFiles := sshKeywords.UserKnownHostsFile osUser, err := user.Current() if err != nil { @@ -485,6 +504,7 @@ func createHostKeyCallback(opts *SSHOpts) (ssh.HostKeyCallback, HostKeyAlgorithm "%s\n\n"+ "**Offending Keys** \n"+ "%s", key.Type(), correctKeyFingerprint, strings.Join(bulletListKnownHosts, " \n"), strings.Join(offendingKeysFmt, " \n")) + log.Print(errorMsg) //update := scbus.MakeUpdatePacket() // create update into alert message @@ -504,29 +524,7 @@ func createHostKeyCallback(opts *SSHOpts) (ssh.HostKeyCallback, HostKeyAlgorithm return waveHostKeyCallback, hostKeyAlgorithms, nil } -func DialContext(ctx context.Context, network string, addr string, config *ssh.ClientConfig) (*ssh.Client, error) { - d := net.Dialer{Timeout: config.Timeout} - conn, err := d.DialContext(ctx, network, addr) - if err != nil { - return nil, err - } - c, chans, reqs, err := ssh.NewClientConn(conn, addr, config) - if err != nil { - return nil, err - } - return ssh.NewClient(c, chans, reqs), nil -} - -func ConnectToClient(connCtx context.Context, opts *SSHOpts) (*ssh.Client, error) { - sshConfigKeywords, err := findSshConfigKeywords(opts.SSHHost) - if err != nil { - return nil, err - } - - sshKeywords, err := combineSshKeywords(opts, sshConfigKeywords) - if err != nil { - return nil, err - } +func createClientConfig(connCtx context.Context, sshKeywords *SshKeywords, debugInfo *ConnectionDebugInfo) (*ssh.ClientConfig, error) { remoteName := sshKeywords.User + "@" + xknownhosts.Normalize(sshKeywords.HostName+":"+sshKeywords.Port) var authSockSigners []ssh.Signer @@ -539,9 +537,9 @@ func ConnectToClient(connCtx context.Context, opts *SSHOpts) (*ssh.Client, error authSockSigners, _ = agentClient.Signers() } - publicKeyCallback := ssh.PublicKeysCallback(createPublicKeyCallback(connCtx, sshKeywords, authSockSigners, agentClient)) - keyboardInteractive := ssh.KeyboardInteractive(createInteractiveKbdInteractiveChallenge(connCtx, remoteName)) - passwordCallback := ssh.PasswordCallback(createInteractivePasswordCallbackPrompt(connCtx, remoteName)) + publicKeyCallback := ssh.PublicKeysCallback(createPublicKeyCallback(connCtx, sshKeywords, authSockSigners, agentClient, debugInfo)) + keyboardInteractive := ssh.KeyboardInteractive(createInteractiveKbdInteractiveChallenge(connCtx, remoteName, debugInfo)) + passwordCallback := ssh.PasswordCallback(createInteractivePasswordCallbackPrompt(connCtx, remoteName, debugInfo)) // exclude gssapi-with-mic and hostbased until implemented authMethodMap := map[string]ssh.AuthMethod{ @@ -570,19 +568,90 @@ func ConnectToClient(connCtx context.Context, opts *SSHOpts) (*ssh.Client, error authMethods = append(authMethods, authMethod) } - hostKeyCallback, hostKeyAlgorithms, err := createHostKeyCallback(opts) + hostKeyCallback, hostKeyAlgorithms, err := createHostKeyCallback(sshKeywords) if err != nil { return nil, err } networkAddr := sshKeywords.HostName + ":" + sshKeywords.Port - clientConfig := &ssh.ClientConfig{ + return &ssh.ClientConfig{ User: sshKeywords.User, Auth: authMethods, HostKeyCallback: hostKeyCallback, HostKeyAlgorithms: hostKeyAlgorithms(networkAddr), + }, nil +} + +func connectInternal(ctx context.Context, networkAddr string, clientConfig *ssh.ClientConfig, currentClient *ssh.Client) (*ssh.Client, error) { + var clientConn net.Conn + var err error + if currentClient == nil { + d := net.Dialer{Timeout: clientConfig.Timeout} + clientConn, err = d.DialContext(ctx, "tcp", networkAddr) + if err != nil { + return nil, err + } + } else { + clientConn, err = currentClient.DialContext(ctx, "tcp", networkAddr) + if err != nil { + return nil, err + } } - return DialContext(connCtx, "tcp", networkAddr, clientConfig) + c, chans, reqs, err := ssh.NewClientConn(clientConn, networkAddr, clientConfig) + if err != nil { + return nil, err + } + return ssh.NewClient(c, chans, reqs), nil +} + +func ConnectToClient(connCtx context.Context, opts *SSHOpts, currentClient *ssh.Client, jumpNum int32) (*ssh.Client, int32, error) { + debugInfo := &ConnectionDebugInfo{ + CurrentClient: currentClient, + NextOpts: opts, + JumpNum: jumpNum, + } + if jumpNum > SshProxyJumpMaxDepth { + return nil, jumpNum, ConnectionError{ConnectionDebugInfo: debugInfo, Err: fmt.Errorf("ProxyJump %d exceeds Wave's max depth of %d", jumpNum, SshProxyJumpMaxDepth)} + } + // todo print final warning if logging gets turned off + sshConfigKeywords, err := findSshConfigKeywords(opts.SSHHost) + if err != nil { + return nil, debugInfo.JumpNum, ConnectionError{ConnectionDebugInfo: debugInfo, Err: err} + } + + sshKeywords, err := combineSshKeywords(opts, sshConfigKeywords) + if err != nil { + return nil, debugInfo.JumpNum, ConnectionError{ConnectionDebugInfo: debugInfo, Err: err} + } + + for _, proxyName := range sshKeywords.ProxyJump { + proxyOpts, err := ParseOpts(proxyName) + if err != nil { + return nil, debugInfo.JumpNum, ConnectionError{ConnectionDebugInfo: debugInfo, Err: err} + } + + // ensure no overflow (this will likely never happen) + if jumpNum < math.MaxInt32 { + jumpNum += 1 + } + + debugInfo.CurrentClient, jumpNum, err = ConnectToClient(connCtx, proxyOpts, debugInfo.CurrentClient, jumpNum) + if err != nil { + // do not add a context on a recursive call + // (this can cause a recursive nested context that's arbitrarily deep) + return nil, jumpNum, err + } + } + clientConfig, err := createClientConfig(connCtx, sshKeywords, debugInfo) + if err != nil { + return nil, debugInfo.JumpNum, ConnectionError{ConnectionDebugInfo: debugInfo, Err: err} + } + networkAddr := sshKeywords.HostName + ":" + sshKeywords.Port + client, err := connectInternal(connCtx, networkAddr, clientConfig, debugInfo.CurrentClient) + if err != nil { + return client, debugInfo.JumpNum, ConnectionError{ConnectionDebugInfo: debugInfo, Err: err} + } + return client, debugInfo.JumpNum, nil } type SshKeywords struct { @@ -597,6 +666,9 @@ type SshKeywords struct { PreferredAuthentications []string AddKeysToAgent bool IdentityAgent string + ProxyJump []string + UserKnownHostsFile []string + GlobalKnownHostsFile []string } func combineSshKeywords(opts *SSHOpts, configKeywords *SshKeywords) (*SshKeywords, error) { @@ -641,6 +713,9 @@ func combineSshKeywords(opts *SSHOpts, configKeywords *SshKeywords) (*SshKeyword sshKeywords.PreferredAuthentications = configKeywords.PreferredAuthentications sshKeywords.AddKeysToAgent = configKeywords.AddKeysToAgent sshKeywords.IdentityAgent = configKeywords.IdentityAgent + sshKeywords.ProxyJump = configKeywords.ProxyJump + sshKeywords.UserKnownHostsFile = configKeywords.UserKnownHostsFile + sshKeywords.GlobalKnownHostsFile = configKeywords.GlobalKnownHostsFile return sshKeywords, nil } @@ -740,6 +815,23 @@ func findSshConfigKeywords(hostPattern string) (*SshKeywords, error) { sshKeywords.IdentityAgent = agentPath } + proxyJumpRaw, err := ssh_config.GetStrict(hostPattern, "ProxyJump") + if err != nil { + return nil, err + } + proxyJumpSplit := strings.Split(proxyJumpRaw, ",") + for _, proxyJumpName := range proxyJumpSplit { + proxyJumpName = strings.TrimSpace(proxyJumpName) + if proxyJumpName == "" || strings.ToLower(proxyJumpName) == "none" { + continue + } + sshKeywords.ProxyJump = append(sshKeywords.ProxyJump, proxyJumpName) + } + rawUserKnownHostsFile, _ := ssh_config.GetStrict(hostPattern, "UserKnownHostsFile") + sshKeywords.UserKnownHostsFile = strings.Fields(rawUserKnownHostsFile) // TODO - smarter splitting escaped spaces and quotes + rawGlobalKnownHostsFile, _ := ssh_config.GetStrict(hostPattern, "GlobalKnownHostsFile") + sshKeywords.GlobalKnownHostsFile = strings.Fields(rawGlobalKnownHostsFile) // TODO - smarter splitting escaped spaces and quotes + return sshKeywords, nil }