SSH Agent Integration (#334)

Hook into an existing SSH Agent.
This allows us to pull keys already authenticated by the agent and write
to the agent ourselves.

---------

Co-authored-by: Evan Simkowitz <esimkowitz@users.noreply.github.com>
This commit is contained in:
Sylvie Crowe 2024-09-06 13:19:38 -07:00 committed by GitHub
parent 566bf461ff
commit a9533b0426
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 129 additions and 86 deletions

View File

@ -97,7 +97,7 @@ tasks:
vars: vars:
- ARCHS - ARCHS
cmds: cmds:
- cmd: CGO_ENABLED=1 GOARCH={{.GOARCH}} go build -tags "osusergo,netgo,sqlite_omit_load_extension" -ldflags "{{.GO_LDFLAGS}} -X main.BuildTime=$({{.DATE}} +'%Y%m%d%H%M') -X main.WaveVersion={{.VERSION}}" -o dist/bin/wavesrv.{{if eq .GOARCH "amd64"}}x64{{else}}{{.GOARCH}}{{end}}{{exeExt}} cmd/server/main-server.go - cmd: CGO_ENABLED=1 GOARCH={{.GOARCH}} go build -tags "osusergo,netcgo,sqlite_omit_load_extension" -ldflags "{{.GO_LDFLAGS}} -X main.BuildTime=$({{.DATE}} +'%Y%m%d%H%M') -X main.WaveVersion={{.VERSION}}" -o dist/bin/wavesrv.{{if eq .GOARCH "amd64"}}x64{{else}}{{.GOARCH}}{{end}}{{exeExt}} cmd/server/main-server.go
for: for:
var: ARCHS var: ARCHS
split: "," split: ","

View File

@ -8,24 +8,26 @@ import (
"context" "context"
"crypto/rand" "crypto/rand"
"crypto/rsa" "crypto/rsa"
"crypto/x509"
"encoding/base64" "encoding/base64"
"fmt" "fmt"
"log" "log"
"net" "net"
"os" "os"
"os/exec"
"os/user" "os/user"
"path/filepath" "path/filepath"
"strconv" "strconv"
"strings" "strings"
"sync"
"time" "time"
"github.com/kevinburke/ssh_config" "github.com/kevinburke/ssh_config"
"github.com/skeema/knownhosts" "github.com/skeema/knownhosts"
"github.com/wavetermdev/waveterm/pkg/trimquotes"
"github.com/wavetermdev/waveterm/pkg/userinput" "github.com/wavetermdev/waveterm/pkg/userinput"
"github.com/wavetermdev/waveterm/pkg/util/shellutil"
"github.com/wavetermdev/waveterm/pkg/wavebase" "github.com/wavetermdev/waveterm/pkg/wavebase"
"golang.org/x/crypto/ssh" "golang.org/x/crypto/ssh"
"golang.org/x/crypto/ssh/agent"
xknownhosts "golang.org/x/crypto/ssh/knownhosts" xknownhosts "golang.org/x/crypto/ssh/knownhosts"
) )
@ -66,7 +68,7 @@ func createDummySigner() ([]ssh.Signer, error) {
// they were successes. An error in this function prevents any other // they were successes. An error in this function prevents any other
// keys from being attempted. But if there's an error because of a dummy // keys from being attempted. But if there's an error because of a dummy
// file, the library can still try again with a new key. // file, the library can still try again with a new key.
func createPublicKeyCallback(connCtx context.Context, sshKeywords *SshKeywords, passphrase string) func() ([]ssh.Signer, error) { func createPublicKeyCallback(connCtx context.Context, sshKeywords *SshKeywords, authSockSignersExt []ssh.Signer, agentClient agent.ExtendedAgent) func() ([]ssh.Signer, error) {
var identityFiles []string var identityFiles []string
existingKeys := make(map[string][]byte) existingKeys := make(map[string][]byte)
@ -84,7 +86,18 @@ func createPublicKeyCallback(connCtx context.Context, sshKeywords *SshKeywords,
// require pointer to modify list in closure // require pointer to modify list in closure
identityFilesPtr := &identityFiles identityFilesPtr := &identityFiles
var authSockSigners []ssh.Signer
authSockSigners = append(authSockSigners, authSockSignersExt...)
authSockSignersPtr := &authSockSigners
return func() ([]ssh.Signer, error) { return func() ([]ssh.Signer, error) {
// try auth sock
if len(*authSockSignersPtr) != 0 {
authSockSigner := (*authSockSignersPtr)[0]
*authSockSignersPtr = (*authSockSignersPtr)[1:]
return []ssh.Signer{authSockSigner}, nil
}
if len(*identityFilesPtr) == 0 { if len(*identityFilesPtr) == 0 {
return nil, fmt.Errorf("no identity files remaining") return nil, fmt.Errorf("no identity files remaining")
} }
@ -96,22 +109,22 @@ func createPublicKeyCallback(connCtx context.Context, sshKeywords *SshKeywords,
// skip this key and try with the next // skip this key and try with the next
return createDummySigner() return createDummySigner()
} }
signer, err := ssh.ParsePrivateKey(privateKey)
if err == nil { unencryptedPrivateKey, err := ssh.ParseRawPrivateKey(privateKey)
return []ssh.Signer{signer}, err
}
if _, ok := err.(*ssh.PassphraseMissingError); !ok { if _, ok := err.(*ssh.PassphraseMissingError); !ok {
// skip this key and try with the next // skip this key and try with the next
return createDummySigner() return createDummySigner()
} }
signer, err = ssh.ParsePrivateKeyWithPassphrase(privateKey, []byte(passphrase))
if err == nil { if err == nil {
signer, err := ssh.NewSignerFromKey(unencryptedPrivateKey)
if err == nil {
if sshKeywords.AddKeysToAgent && agentClient != nil {
agentClient.Add(agent.AddedKey{
PrivateKey: unencryptedPrivateKey,
})
}
return []ssh.Signer{signer}, err return []ssh.Signer{signer}, err
} }
if err != x509.IncorrectPasswordError && err.Error() != "bcrypt_pbkdf: empty password" {
// skip this key and try with the next
return createDummySigner()
} }
// batch mode deactivates user input // batch mode deactivates user input
@ -133,21 +146,22 @@ func createPublicKeyCallback(connCtx context.Context, sshKeywords *SshKeywords,
// trying keys // trying keys
return nil, UserInputCancelError{Err: err} return nil, UserInputCancelError{Err: err}
} }
signer, err = ssh.ParsePrivateKeyWithPassphrase(privateKey, []byte(response.Text)) unencryptedPrivateKey, err = ssh.ParseRawPrivateKeyWithPassphrase(privateKey, []byte([]byte(response.Text)))
if err != nil { if err != nil {
// skip this key and try with the next // skip this key and try with the next
return createDummySigner() return createDummySigner()
} }
return []ssh.Signer{signer}, err signer, err := ssh.NewSignerFromKey(unencryptedPrivateKey)
if err != nil {
// skip this key and try with the next
return createDummySigner()
} }
} if sshKeywords.AddKeysToAgent && agentClient != nil {
agentClient.Add(agent.AddedKey{
func createDefaultPasswordCallbackPrompt(password string) func() (secret string, err error) { PrivateKey: unencryptedPrivateKey,
return func() (secret string, err error) { })
// this should be modified to return an error if no password is stored }
// but an empty password is not sufficient because some systems allow return []ssh.Signer{signer}, err
// empty passwords
return password, nil
} }
} }
@ -173,31 +187,6 @@ func createInteractivePasswordCallbackPrompt(connCtx context.Context, remoteDisp
} }
} }
func createCombinedPasswordCallbackPrompt(connCtx context.Context, password string, remoteDisplayName string) func() (secret string, err error) {
var once sync.Once
return func() (secret string, err error) {
var prompt func() (secret string, err error)
once.Do(func() { prompt = createDefaultPasswordCallbackPrompt(password) })
if prompt == nil {
prompt = createInteractivePasswordCallbackPrompt(connCtx, remoteDisplayName)
}
return prompt()
}
}
func createNaiveKbdInteractiveChallenge(password string) func(name, instruction string, questions []string, echos []bool) (answers []string, err error) {
return func(name, instruction string, questions []string, echos []bool) (answers []string, err error) {
for _, q := range questions {
if strings.Contains(strings.ToLower(q), "password") {
answers = append(answers, password)
} else {
answers = append(answers, "")
}
}
return answers, nil
}
}
func createInteractiveKbdInteractiveChallenge(connCtx context.Context, remoteName string) func(name, instruction string, questions []string, echos []bool) (answers []string, err error) { func createInteractiveKbdInteractiveChallenge(connCtx context.Context, remoteName string) func(name, instruction string, questions []string, echos []bool) (answers []string, err error) {
return func(name, instruction string, questions []string, echos []bool) (answers []string, err error) { return func(name, instruction string, questions []string, echos []bool) (answers []string, err error) {
if len(questions) != len(echos) { if len(questions) != len(echos) {
@ -238,18 +227,6 @@ func promptChallengeQuestion(connCtx context.Context, question string, echo bool
return response.Text, nil return response.Text, nil
} }
func createCombinedKbdInteractiveChallenge(connCtx context.Context, password string, remoteName string) ssh.KeyboardInteractiveChallenge {
var once sync.Once
return func(name, instruction string, questions []string, echos []bool) (answers []string, err error) {
var challenge ssh.KeyboardInteractiveChallenge
once.Do(func() { challenge = createNaiveKbdInteractiveChallenge(password) })
if challenge == nil {
challenge = createInteractiveKbdInteractiveChallenge(connCtx, remoteName)
}
return challenge(name, instruction, questions, echos)
}
}
func openKnownHostsForEdit(knownHostsFilename string) (*os.File, error) { func openKnownHostsForEdit(knownHostsFilename string) (*os.File, error) {
path, _ := filepath.Split(knownHostsFilename) path, _ := filepath.Split(knownHostsFilename)
err := os.MkdirAll(path, 0700) err := os.MkdirAll(path, 0700)
@ -543,30 +520,32 @@ func ConnectToClient(connCtx context.Context, opts *SSHOpts) (*ssh.Client, error
} }
remoteName := sshKeywords.User + "@" + xknownhosts.Normalize(sshKeywords.HostName+":"+sshKeywords.Port) remoteName := sshKeywords.User + "@" + xknownhosts.Normalize(sshKeywords.HostName+":"+sshKeywords.Port)
publicKeyCallback := ssh.PublicKeysCallback(createPublicKeyCallback(connCtx, sshKeywords, "")) var authSockSigners []ssh.Signer
keyboardInteractive := ssh.KeyboardInteractive(createCombinedKbdInteractiveChallenge(connCtx, "", remoteName)) var agentClient agent.ExtendedAgent
passwordCallback := ssh.PasswordCallback(createCombinedPasswordCallbackPrompt(connCtx, "", remoteName)) conn, err := net.Dial("unix", sshKeywords.IdentityAgent)
if err != nil {
// batch mode turns off interactive input. this means the number of log.Printf("Failed to open Identity Agent Socket: %v", err)
// attemtps must drop to 1 with this setup
var attemptsAllowed int
if sshKeywords.BatchMode {
attemptsAllowed = 1
} else { } else {
attemptsAllowed = 2 agentClient = agent.NewClient(conn)
authSockSigners, _ = agentClient.Signers()
} }
publicKeyCallback := ssh.PublicKeysCallback(createPublicKeyCallback(connCtx, sshKeywords, authSockSigners, agentClient))
keyboardInteractive := ssh.KeyboardInteractive(createInteractiveKbdInteractiveChallenge(connCtx, remoteName))
passwordCallback := ssh.PasswordCallback(createInteractivePasswordCallbackPrompt(connCtx, remoteName))
// exclude gssapi-with-mic and hostbased until implemented // exclude gssapi-with-mic and hostbased until implemented
authMethodMap := map[string]ssh.AuthMethod{ authMethodMap := map[string]ssh.AuthMethod{
"publickey": ssh.RetryableAuthMethod(publicKeyCallback, len(sshKeywords.IdentityFile)), "publickey": ssh.RetryableAuthMethod(publicKeyCallback, len(sshKeywords.IdentityFile)+len(authSockSigners)),
"keyboard-interactive": ssh.RetryableAuthMethod(keyboardInteractive, attemptsAllowed), "keyboard-interactive": ssh.RetryableAuthMethod(keyboardInteractive, 1),
"password": ssh.RetryableAuthMethod(passwordCallback, attemptsAllowed), "password": ssh.RetryableAuthMethod(passwordCallback, 1),
} }
// note: batch mode turns off interactive input
authMethodActiveMap := map[string]bool{ authMethodActiveMap := map[string]bool{
"publickey": sshKeywords.PubkeyAuthentication, "publickey": sshKeywords.PubkeyAuthentication,
"keyboard-interactive": sshKeywords.KbdInteractiveAuthentication, "keyboard-interactive": sshKeywords.KbdInteractiveAuthentication && !sshKeywords.BatchMode,
"password": sshKeywords.PasswordAuthentication, "password": sshKeywords.PasswordAuthentication && !sshKeywords.BatchMode,
} }
var authMethods []ssh.AuthMethod var authMethods []ssh.AuthMethod
@ -607,6 +586,8 @@ type SshKeywords struct {
PasswordAuthentication bool PasswordAuthentication bool
KbdInteractiveAuthentication bool KbdInteractiveAuthentication bool
PreferredAuthentications []string PreferredAuthentications []string
AddKeysToAgent bool
IdentityAgent string
} }
func combineSshKeywords(opts *SSHOpts, configKeywords *SshKeywords) (*SshKeywords, error) { func combineSshKeywords(opts *SSHOpts, configKeywords *SshKeywords) (*SshKeywords, error) {
@ -649,6 +630,8 @@ func combineSshKeywords(opts *SSHOpts, configKeywords *SshKeywords) (*SshKeyword
sshKeywords.PasswordAuthentication = configKeywords.PasswordAuthentication sshKeywords.PasswordAuthentication = configKeywords.PasswordAuthentication
sshKeywords.KbdInteractiveAuthentication = configKeywords.KbdInteractiveAuthentication sshKeywords.KbdInteractiveAuthentication = configKeywords.KbdInteractiveAuthentication
sshKeywords.PreferredAuthentications = configKeywords.PreferredAuthentications sshKeywords.PreferredAuthentications = configKeywords.PreferredAuthentications
sshKeywords.AddKeysToAgent = configKeywords.AddKeysToAgent
sshKeywords.IdentityAgent = configKeywords.IdentityAgent
return sshKeywords, nil return sshKeywords, nil
} }
@ -661,47 +644,54 @@ func findSshConfigKeywords(hostPattern string) (*SshKeywords, error) {
sshKeywords := &SshKeywords{} sshKeywords := &SshKeywords{}
var err error var err error
sshKeywords.User, err = ssh_config.GetStrict(hostPattern, "User") userRaw, err := ssh_config.GetStrict(hostPattern, "User")
if err != nil { if err != nil {
return nil, err return nil, err
} }
sshKeywords.User = trimquotes.TryTrimQuotes(userRaw)
sshKeywords.HostName, err = ssh_config.GetStrict(hostPattern, "HostName") hostNameRaw, err := ssh_config.GetStrict(hostPattern, "HostName")
if err != nil { if err != nil {
return nil, err return nil, err
} }
sshKeywords.HostName = trimquotes.TryTrimQuotes(hostNameRaw)
sshKeywords.Port, err = ssh_config.GetStrict(hostPattern, "Port") portRaw, err := ssh_config.GetStrict(hostPattern, "Port")
if err != nil { if err != nil {
return nil, err return nil, err
} }
sshKeywords.Port = trimquotes.TryTrimQuotes(portRaw)
sshKeywords.IdentityFile = ssh_config.GetAll(hostPattern, "IdentityFile") identityFileRaw := ssh_config.GetAll(hostPattern, "IdentityFile")
for i := 0; i < len(identityFileRaw); i++ {
identityFileRaw[i] = trimquotes.TryTrimQuotes(identityFileRaw[i])
}
sshKeywords.IdentityFile = identityFileRaw
batchModeRaw, err := ssh_config.GetStrict(hostPattern, "BatchMode") batchModeRaw, err := ssh_config.GetStrict(hostPattern, "BatchMode")
if err != nil { if err != nil {
return nil, err return nil, err
} }
sshKeywords.BatchMode = (strings.ToLower(batchModeRaw) == "yes") sshKeywords.BatchMode = (strings.ToLower(trimquotes.TryTrimQuotes(batchModeRaw)) == "yes")
// we currently do not support host-bound or unbound but will use yes when they are selected // we currently do not support host-bound or unbound but will use yes when they are selected
pubkeyAuthenticationRaw, err := ssh_config.GetStrict(hostPattern, "PubkeyAuthentication") pubkeyAuthenticationRaw, err := ssh_config.GetStrict(hostPattern, "PubkeyAuthentication")
if err != nil { if err != nil {
return nil, err return nil, err
} }
sshKeywords.PubkeyAuthentication = (strings.ToLower(pubkeyAuthenticationRaw) != "no") sshKeywords.PubkeyAuthentication = (strings.ToLower(trimquotes.TryTrimQuotes(pubkeyAuthenticationRaw)) != "no")
passwordAuthenticationRaw, err := ssh_config.GetStrict(hostPattern, "PasswordAuthentication") passwordAuthenticationRaw, err := ssh_config.GetStrict(hostPattern, "PasswordAuthentication")
if err != nil { if err != nil {
return nil, err return nil, err
} }
sshKeywords.PasswordAuthentication = (strings.ToLower(passwordAuthenticationRaw) != "no") sshKeywords.PasswordAuthentication = (strings.ToLower(trimquotes.TryTrimQuotes(passwordAuthenticationRaw)) != "no")
kbdInteractiveAuthenticationRaw, err := ssh_config.GetStrict(hostPattern, "KbdInteractiveAuthentication") kbdInteractiveAuthenticationRaw, err := ssh_config.GetStrict(hostPattern, "KbdInteractiveAuthentication")
if err != nil { if err != nil {
return nil, err return nil, err
} }
sshKeywords.KbdInteractiveAuthentication = (strings.ToLower(kbdInteractiveAuthenticationRaw) != "no") sshKeywords.KbdInteractiveAuthentication = (strings.ToLower(trimquotes.TryTrimQuotes(kbdInteractiveAuthenticationRaw)) != "no")
// these are parsed as a single string and must be separated // these are parsed as a single string and must be separated
// these are case sensitive in openssh so they are here too // these are case sensitive in openssh so they are here too
@ -709,7 +699,29 @@ func findSshConfigKeywords(hostPattern string) (*SshKeywords, error) {
if err != nil { if err != nil {
return nil, err return nil, err
} }
sshKeywords.PreferredAuthentications = strings.Split(preferredAuthenticationsRaw, ",") sshKeywords.PreferredAuthentications = strings.Split(trimquotes.TryTrimQuotes(preferredAuthenticationsRaw), ",")
addKeysToAgentRaw, err := ssh_config.GetStrict(hostPattern, "AddKeysToAgent")
if err != nil {
return nil, err
}
sshKeywords.AddKeysToAgent = (strings.ToLower(trimquotes.TryTrimQuotes(addKeysToAgentRaw)) == "yes")
identityAgentRaw, err := ssh_config.GetStrict(hostPattern, "IdentityAgent")
if err != nil {
return nil, err
}
if identityAgentRaw == "" {
shellPath := shellutil.DetectLocalShellPath()
authSockCommand := exec.Command(shellPath, "-c", "echo ${SSH_AUTH_SOCK}")
sshAuthSock, err := authSockCommand.Output()
if err == nil {
sshKeywords.IdentityAgent = wavebase.ExpandHomeDir(trimquotes.TryTrimQuotes(strings.TrimSpace(string(sshAuthSock))))
} else {
log.Printf("unable to find SSH_AUTH_SOCK: %v\n", err)
}
} else {
sshKeywords.IdentityAgent = wavebase.ExpandHomeDir(trimquotes.TryTrimQuotes(identityAgentRaw))
}
return sshKeywords, nil return sshKeywords, nil
} }

View File

@ -0,0 +1,28 @@
package trimquotes
import (
"strconv"
)
func TrimQuotes(s string) (string, bool) {
if len(s) > 2 && s[0] == '"' {
trimmed, err := strconv.Unquote(s)
if err != nil {
return s, false
}
return trimmed, true
}
return s, false
}
func TryTrimQuotes(s string) string {
trimmed, _ := TrimQuotes(s)
return trimmed
}
func ReplaceQuotes(s string, shouldReplace bool) string {
if shouldReplace {
return strconv.Quote(s)
}
return s
}

View File

@ -214,6 +214,9 @@ func GetWshBaseName(version string, goos string, goarch string) string {
if goarch == "amd64" { if goarch == "amd64" {
goarch = "x64" goarch = "x64"
} }
if goarch == "aarch64" {
goarch = "arm64"
}
if goos == "windows" { if goos == "windows" {
ext = ".exe" ext = ".exe"
} }