ProxyJump Support (#1107)

This adds basic ProxyJump support for handling ssh connections.
This commit is contained in:
Sylvie Crowe 2024-10-25 12:14:40 -07:00 committed by GitHub
parent e9fcb9b145
commit ac6f9a05d4
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 138 additions and 46 deletions

View File

@ -422,9 +422,9 @@ func (conn *SSHConn) WithLock(fn func()) {
}
func (conn *SSHConn) connectInternal(ctx context.Context) error {
client, err := remote.ConnectToClient(ctx, conn.Opts) //todo specify or remove opts
client, _, err := remote.ConnectToClient(ctx, conn.Opts, nil, 0)
if err != nil {
log.Printf("error: failed to connect to client %s: %v\n", conn.GetName(), err)
log.Printf("error: failed to connect to client %s: %s\n", conn.GetName(), err)
return err
}
fmtAddr := knownhosts.Normalize(fmt.Sprintf("%s@%s", client.User(), client.RemoteAddr().String()))

View File

@ -11,6 +11,7 @@ import (
"encoding/base64"
"fmt"
"log"
"math"
"net"
"os"
"os/exec"
@ -31,6 +32,8 @@ import (
xknownhosts "golang.org/x/crypto/ssh/knownhosts"
)
const SshProxyJumpMaxDepth = 10
type UserInputCancelError struct {
Err error
}
@ -41,6 +44,24 @@ func (uice UserInputCancelError) Error() string {
return uice.Err.Error()
}
type ConnectionDebugInfo struct {
CurrentClient *ssh.Client
NextOpts *SSHOpts
JumpNum int32
}
type ConnectionError struct {
*ConnectionDebugInfo
Err error
}
func (ce ConnectionError) Error() string {
if ce.CurrentClient == nil {
return fmt.Sprintf("Connecting to %+#v, Error: %v", ce.NextOpts, ce.Err)
}
return fmt.Sprintf("Connecting from %v to %+#v (jump number %d), Error: %v", ce.CurrentClient, ce.NextOpts, ce.JumpNum, ce.Err)
}
// This exists to trick the ssh library into continuing to try
// different public keys even when the current key cannot be
// properly parsed
@ -68,7 +89,7 @@ func createDummySigner() ([]ssh.Signer, error) {
// they were successes. An error in this function prevents any other
// keys from being attempted. But if there's an error because of a dummy
// file, the library can still try again with a new key.
func createPublicKeyCallback(connCtx context.Context, sshKeywords *SshKeywords, authSockSignersExt []ssh.Signer, agentClient agent.ExtendedAgent) func() ([]ssh.Signer, error) {
func createPublicKeyCallback(connCtx context.Context, sshKeywords *SshKeywords, authSockSignersExt []ssh.Signer, agentClient agent.ExtendedAgent, debugInfo *ConnectionDebugInfo) func() ([]ssh.Signer, error) {
var identityFiles []string
existingKeys := make(map[string][]byte)
@ -103,7 +124,7 @@ func createPublicKeyCallback(connCtx context.Context, sshKeywords *SshKeywords,
}
if len(*identityFilesPtr) == 0 {
return nil, fmt.Errorf("no identity files remaining")
return nil, ConnectionError{ConnectionDebugInfo: debugInfo, Err: fmt.Errorf("no identity files remaining")}
}
identityFile := (*identityFilesPtr)[0]
*identityFilesPtr = (*identityFilesPtr)[1:]
@ -123,7 +144,7 @@ func createPublicKeyCallback(connCtx context.Context, sshKeywords *SshKeywords,
PrivateKey: unencryptedPrivateKey,
})
}
return []ssh.Signer{signer}, err
return []ssh.Signer{signer}, nil
}
}
if _, ok := err.(*ssh.PassphraseMissingError); !ok {
@ -148,7 +169,8 @@ func createPublicKeyCallback(connCtx context.Context, sshKeywords *SshKeywords,
if err != nil {
// this is an error where we actually do want to stop
// trying keys
return nil, UserInputCancelError{Err: err}
return nil, ConnectionError{ConnectionDebugInfo: debugInfo, Err: UserInputCancelError{Err: err}}
}
unencryptedPrivateKey, err = ssh.ParseRawPrivateKeyWithPassphrase(privateKey, []byte([]byte(response.Text)))
if err != nil {
@ -165,11 +187,11 @@ func createPublicKeyCallback(connCtx context.Context, sshKeywords *SshKeywords,
PrivateKey: unencryptedPrivateKey,
})
}
return []ssh.Signer{signer}, err
return []ssh.Signer{signer}, nil
}
}
func createInteractivePasswordCallbackPrompt(connCtx context.Context, remoteDisplayName string) func() (secret string, err error) {
func createInteractivePasswordCallbackPrompt(connCtx context.Context, remoteDisplayName string, debugInfo *ConnectionDebugInfo) func() (secret string, err error) {
return func() (secret string, err error) {
ctx, cancelFn := context.WithTimeout(connCtx, 60*time.Second)
defer cancelFn()
@ -185,13 +207,13 @@ func createInteractivePasswordCallbackPrompt(connCtx context.Context, remoteDisp
}
response, err := userinput.GetUserInput(ctx, request)
if err != nil {
return "", err
return "", ConnectionError{ConnectionDebugInfo: debugInfo, Err: err}
}
return response.Text, nil
}
}
func createInteractiveKbdInteractiveChallenge(connCtx context.Context, remoteName string) func(name, instruction string, questions []string, echos []bool) (answers []string, err error) {
func createInteractiveKbdInteractiveChallenge(connCtx context.Context, remoteName string, debugInfo *ConnectionDebugInfo) func(name, instruction string, questions []string, echos []bool) (answers []string, err error) {
return func(name, instruction string, questions []string, echos []bool) (answers []string, err error) {
if len(questions) != len(echos) {
return nil, fmt.Errorf("bad response from server: questions has len %d, echos has len %d", len(questions), len(echos))
@ -200,7 +222,7 @@ func createInteractiveKbdInteractiveChallenge(connCtx context.Context, remoteNam
echo := echos[i]
answer, err := promptChallengeQuestion(connCtx, question, echo, remoteName)
if err != nil {
return nil, err
return nil, ConnectionError{ConnectionDebugInfo: debugInfo, Err: err}
}
answers = append(answers, answer)
}
@ -336,12 +358,9 @@ func lineContainsMatch(line []byte, matches [][]byte) bool {
return false
}
func createHostKeyCallback(opts *SSHOpts) (ssh.HostKeyCallback, HostKeyAlgorithms, error) {
ssh_config.ReloadConfigs()
rawUserKnownHostsFiles, _ := ssh_config.GetStrict(opts.SSHHost, "UserKnownHostsFile")
userKnownHostsFiles := strings.Fields(rawUserKnownHostsFiles) // TODO - smarter splitting escaped spaces and quotes
rawGlobalKnownHostsFiles, _ := ssh_config.GetStrict(opts.SSHHost, "GlobalKnownHostsFile")
globalKnownHostsFiles := strings.Fields(rawGlobalKnownHostsFiles) // TODO - smarter splitting escaped spaces and quotes
func createHostKeyCallback(sshKeywords *SshKeywords) (ssh.HostKeyCallback, HostKeyAlgorithms, error) {
globalKnownHostsFiles := sshKeywords.GlobalKnownHostsFile
userKnownHostsFiles := sshKeywords.UserKnownHostsFile
osUser, err := user.Current()
if err != nil {
@ -485,6 +504,7 @@ func createHostKeyCallback(opts *SSHOpts) (ssh.HostKeyCallback, HostKeyAlgorithm
"%s\n\n"+
"**Offending Keys** \n"+
"%s", key.Type(), correctKeyFingerprint, strings.Join(bulletListKnownHosts, " \n"), strings.Join(offendingKeysFmt, " \n"))
log.Print(errorMsg)
//update := scbus.MakeUpdatePacket()
// create update into alert message
@ -504,29 +524,7 @@ func createHostKeyCallback(opts *SSHOpts) (ssh.HostKeyCallback, HostKeyAlgorithm
return waveHostKeyCallback, hostKeyAlgorithms, nil
}
func DialContext(ctx context.Context, network string, addr string, config *ssh.ClientConfig) (*ssh.Client, error) {
d := net.Dialer{Timeout: config.Timeout}
conn, err := d.DialContext(ctx, network, addr)
if err != nil {
return nil, err
}
c, chans, reqs, err := ssh.NewClientConn(conn, addr, config)
if err != nil {
return nil, err
}
return ssh.NewClient(c, chans, reqs), nil
}
func ConnectToClient(connCtx context.Context, opts *SSHOpts) (*ssh.Client, error) {
sshConfigKeywords, err := findSshConfigKeywords(opts.SSHHost)
if err != nil {
return nil, err
}
sshKeywords, err := combineSshKeywords(opts, sshConfigKeywords)
if err != nil {
return nil, err
}
func createClientConfig(connCtx context.Context, sshKeywords *SshKeywords, debugInfo *ConnectionDebugInfo) (*ssh.ClientConfig, error) {
remoteName := sshKeywords.User + "@" + xknownhosts.Normalize(sshKeywords.HostName+":"+sshKeywords.Port)
var authSockSigners []ssh.Signer
@ -539,9 +537,9 @@ func ConnectToClient(connCtx context.Context, opts *SSHOpts) (*ssh.Client, error
authSockSigners, _ = agentClient.Signers()
}
publicKeyCallback := ssh.PublicKeysCallback(createPublicKeyCallback(connCtx, sshKeywords, authSockSigners, agentClient))
keyboardInteractive := ssh.KeyboardInteractive(createInteractiveKbdInteractiveChallenge(connCtx, remoteName))
passwordCallback := ssh.PasswordCallback(createInteractivePasswordCallbackPrompt(connCtx, remoteName))
publicKeyCallback := ssh.PublicKeysCallback(createPublicKeyCallback(connCtx, sshKeywords, authSockSigners, agentClient, debugInfo))
keyboardInteractive := ssh.KeyboardInteractive(createInteractiveKbdInteractiveChallenge(connCtx, remoteName, debugInfo))
passwordCallback := ssh.PasswordCallback(createInteractivePasswordCallbackPrompt(connCtx, remoteName, debugInfo))
// exclude gssapi-with-mic and hostbased until implemented
authMethodMap := map[string]ssh.AuthMethod{
@ -570,19 +568,90 @@ func ConnectToClient(connCtx context.Context, opts *SSHOpts) (*ssh.Client, error
authMethods = append(authMethods, authMethod)
}
hostKeyCallback, hostKeyAlgorithms, err := createHostKeyCallback(opts)
hostKeyCallback, hostKeyAlgorithms, err := createHostKeyCallback(sshKeywords)
if err != nil {
return nil, err
}
networkAddr := sshKeywords.HostName + ":" + sshKeywords.Port
clientConfig := &ssh.ClientConfig{
return &ssh.ClientConfig{
User: sshKeywords.User,
Auth: authMethods,
HostKeyCallback: hostKeyCallback,
HostKeyAlgorithms: hostKeyAlgorithms(networkAddr),
}, nil
}
func connectInternal(ctx context.Context, networkAddr string, clientConfig *ssh.ClientConfig, currentClient *ssh.Client) (*ssh.Client, error) {
var clientConn net.Conn
var err error
if currentClient == nil {
d := net.Dialer{Timeout: clientConfig.Timeout}
clientConn, err = d.DialContext(ctx, "tcp", networkAddr)
if err != nil {
return nil, err
}
return DialContext(connCtx, "tcp", networkAddr, clientConfig)
} else {
clientConn, err = currentClient.DialContext(ctx, "tcp", networkAddr)
if err != nil {
return nil, err
}
}
c, chans, reqs, err := ssh.NewClientConn(clientConn, networkAddr, clientConfig)
if err != nil {
return nil, err
}
return ssh.NewClient(c, chans, reqs), nil
}
func ConnectToClient(connCtx context.Context, opts *SSHOpts, currentClient *ssh.Client, jumpNum int32) (*ssh.Client, int32, error) {
debugInfo := &ConnectionDebugInfo{
CurrentClient: currentClient,
NextOpts: opts,
JumpNum: jumpNum,
}
if jumpNum > SshProxyJumpMaxDepth {
return nil, jumpNum, ConnectionError{ConnectionDebugInfo: debugInfo, Err: fmt.Errorf("ProxyJump %d exceeds Wave's max depth of %d", jumpNum, SshProxyJumpMaxDepth)}
}
// todo print final warning if logging gets turned off
sshConfigKeywords, err := findSshConfigKeywords(opts.SSHHost)
if err != nil {
return nil, debugInfo.JumpNum, ConnectionError{ConnectionDebugInfo: debugInfo, Err: err}
}
sshKeywords, err := combineSshKeywords(opts, sshConfigKeywords)
if err != nil {
return nil, debugInfo.JumpNum, ConnectionError{ConnectionDebugInfo: debugInfo, Err: err}
}
for _, proxyName := range sshKeywords.ProxyJump {
proxyOpts, err := ParseOpts(proxyName)
if err != nil {
return nil, debugInfo.JumpNum, ConnectionError{ConnectionDebugInfo: debugInfo, Err: err}
}
// ensure no overflow (this will likely never happen)
if jumpNum < math.MaxInt32 {
jumpNum += 1
}
debugInfo.CurrentClient, jumpNum, err = ConnectToClient(connCtx, proxyOpts, debugInfo.CurrentClient, jumpNum)
if err != nil {
// do not add a context on a recursive call
// (this can cause a recursive nested context that's arbitrarily deep)
return nil, jumpNum, err
}
}
clientConfig, err := createClientConfig(connCtx, sshKeywords, debugInfo)
if err != nil {
return nil, debugInfo.JumpNum, ConnectionError{ConnectionDebugInfo: debugInfo, Err: err}
}
networkAddr := sshKeywords.HostName + ":" + sshKeywords.Port
client, err := connectInternal(connCtx, networkAddr, clientConfig, debugInfo.CurrentClient)
if err != nil {
return client, debugInfo.JumpNum, ConnectionError{ConnectionDebugInfo: debugInfo, Err: err}
}
return client, debugInfo.JumpNum, nil
}
type SshKeywords struct {
@ -597,6 +666,9 @@ type SshKeywords struct {
PreferredAuthentications []string
AddKeysToAgent bool
IdentityAgent string
ProxyJump []string
UserKnownHostsFile []string
GlobalKnownHostsFile []string
}
func combineSshKeywords(opts *SSHOpts, configKeywords *SshKeywords) (*SshKeywords, error) {
@ -641,6 +713,9 @@ func combineSshKeywords(opts *SSHOpts, configKeywords *SshKeywords) (*SshKeyword
sshKeywords.PreferredAuthentications = configKeywords.PreferredAuthentications
sshKeywords.AddKeysToAgent = configKeywords.AddKeysToAgent
sshKeywords.IdentityAgent = configKeywords.IdentityAgent
sshKeywords.ProxyJump = configKeywords.ProxyJump
sshKeywords.UserKnownHostsFile = configKeywords.UserKnownHostsFile
sshKeywords.GlobalKnownHostsFile = configKeywords.GlobalKnownHostsFile
return sshKeywords, nil
}
@ -740,6 +815,23 @@ func findSshConfigKeywords(hostPattern string) (*SshKeywords, error) {
sshKeywords.IdentityAgent = agentPath
}
proxyJumpRaw, err := ssh_config.GetStrict(hostPattern, "ProxyJump")
if err != nil {
return nil, err
}
proxyJumpSplit := strings.Split(proxyJumpRaw, ",")
for _, proxyJumpName := range proxyJumpSplit {
proxyJumpName = strings.TrimSpace(proxyJumpName)
if proxyJumpName == "" || strings.ToLower(proxyJumpName) == "none" {
continue
}
sshKeywords.ProxyJump = append(sshKeywords.ProxyJump, proxyJumpName)
}
rawUserKnownHostsFile, _ := ssh_config.GetStrict(hostPattern, "UserKnownHostsFile")
sshKeywords.UserKnownHostsFile = strings.Fields(rawUserKnownHostsFile) // TODO - smarter splitting escaped spaces and quotes
rawGlobalKnownHostsFile, _ := ssh_config.GetStrict(hostPattern, "GlobalKnownHostsFile")
sshKeywords.GlobalKnownHostsFile = strings.Fields(rawGlobalKnownHostsFile) // TODO - smarter splitting escaped spaces and quotes
return sshKeywords, nil
}