From a083db686b87ce30e59a3857bf27de9980d7ec8f Mon Sep 17 00:00:00 2001 From: Sylvia Crowe Date: Mon, 16 Dec 2024 12:15:51 -0800 Subject: [PATCH] feat: add new keyword cascade resolution --- pkg/remote/connutil.go | 11 +---- pkg/remote/sshclient.go | 93 +++++++++++++++++++++++++++++++++++------ 2 files changed, 81 insertions(+), 23 deletions(-) diff --git a/pkg/remote/connutil.go b/pkg/remote/connutil.go index 025b78622..ee8910fab 100644 --- a/pkg/remote/connutil.go +++ b/pkg/remote/connutil.go @@ -13,7 +13,6 @@ import ( "os/user" "path/filepath" "regexp" - "strconv" "strings" "github.com/wavetermdev/waveterm/pkg/panichandler" @@ -27,16 +26,8 @@ func ParseOpts(input string) (*SSHOpts, error) { if m == nil { return nil, fmt.Errorf("invalid format of user@host argument") } - remoteUser, remoteHost, remotePortStr := m[1], m[2], m[3] + remoteUser, remoteHost, remotePort := m[1], m[2], m[3] remoteUser = strings.Trim(remoteUser, "@") - var remotePort int - if remotePortStr != "" { - var err error - remotePort, err = strconv.Atoi(remotePortStr) - if err != nil { - return nil, fmt.Errorf("invalid port specified on user@host argument") - } - } return &SSHOpts{SSHHost: remoteHost, SSHUser: remoteUser, SSHPort: remotePort}, nil } diff --git a/pkg/remote/sshclient.go b/pkg/remote/sshclient.go index 8a9a1e50f..ac1cedaa7 100644 --- a/pkg/remote/sshclient.go +++ b/pkg/remote/sshclient.go @@ -661,21 +661,27 @@ func ConnectToClient(connCtx context.Context, opts *SSHOpts, currentClient *ssh. return nil, debugInfo.JumpNum, ConnectionError{ConnectionDebugInfo: debugInfo, Err: err} } - connFlags.SshUser = &opts.SSHUser - connFlags.SshHostName = &opts.SSHHost - portStr := fmt.Sprintf("%d", opts.SSHPort) - connFlags.SshPort = &portStr + if opts.SSHUser != "" { + connFlags.SshUser = &opts.SSHUser + } + //connFlags.SshHostName = &opts.SSHHost + if opts.SSHPort != "" { + connFlags.SshPort = &opts.SSHPort + } rawName := opts.String() - savedKeywords, ok := wconfig.ReadFullConfig().Connections[rawName] + fullConfig := wconfig.ReadFullConfig() + internalSshConfigKeywords, ok := fullConfig.Connections[rawName] if !ok { - savedKeywords = wshrpc.ConnKeywords{} + internalSshConfigKeywords = wshrpc.ConnKeywords{} } + partialMerged := mergeKeywords(sshConfigKeywords, &internalSshConfigKeywords) + sshKeywords := mergeKeywords(partialMerged, connFlags) - sshKeywords, err := combineSshKeywords(connFlags, sshConfigKeywords, &savedKeywords) - if err != nil { - return nil, debugInfo.JumpNum, ConnectionError{ConnectionDebugInfo: debugInfo, Err: err} - } + // handle these separately since they append + sshKeywords.SshIdentityFile = append(sshKeywords.SshIdentityFile, connFlags.SshIdentityFile...) + sshKeywords.SshIdentityFile = append(sshKeywords.SshIdentityFile, internalSshConfigKeywords.SshIdentityFile...) + sshKeywords.SshIdentityFile = append(sshKeywords.SshIdentityFile, sshConfigKeywords.SshIdentityFile...) for _, proxyName := range sshKeywords.SshProxyJump { proxyOpts, err := ParseOpts(proxyName) @@ -777,7 +783,15 @@ func findSshConfigKeywords(hostPattern string) (*wshrpc.ConnKeywords, error) { if err != nil { return nil, err } - sshKeywords.SshUser = ptr(trimquotes.TryTrimQuotes(userRaw)) + userClean := trimquotes.TryTrimQuotes(userRaw) + if userClean == "" { + userDetails, err := user.Current() + if err != nil { + return nil, err + } + userClean = userDetails.Username + } + sshKeywords.SshUser = &userClean hostNameRaw, err := WaveSshConfigUserSettings().GetStrict(hostPattern, "HostName") if err != nil { @@ -883,7 +897,7 @@ func findSshConfigKeywords(hostPattern string) (*wshrpc.ConnKeywords, error) { type SSHOpts struct { SSHHost string `json:"sshhost"` SSHUser string `json:"sshuser"` - SSHPort int `json:"sshport,omitempty"` + SSHPort string `json:"sshport,omitempty"` } func (opts SSHOpts) String() string { @@ -892,8 +906,61 @@ func (opts SSHOpts) String() string { stringRepr = opts.SSHUser + "@" } stringRepr = stringRepr + opts.SSHHost - if opts.SSHPort != 0 { + if opts.SSHPort != "22" && opts.SSHPort != "" { stringRepr = stringRepr + ":" + fmt.Sprint(opts.SSHPort) } return stringRepr } + +func mergeKeywords(oldKeywords *wshrpc.ConnKeywords, newKeywords *wshrpc.ConnKeywords) *wshrpc.ConnKeywords { + if oldKeywords == nil { + oldKeywords = &wshrpc.ConnKeywords{} + } + if newKeywords == nil { + return oldKeywords + } + outKeywords := *oldKeywords + + if newKeywords.SshHostName != nil { + outKeywords.SshHostName = newKeywords.SshHostName + } + if newKeywords.SshUser != nil { + outKeywords.SshUser = newKeywords.SshUser + } + if newKeywords.SshPort != nil { + outKeywords.SshPort = newKeywords.SshPort + } + // skip identityfile (handled separately due to different behavior) + if newKeywords.SshBatchMode != nil { + outKeywords.SshBatchMode = newKeywords.SshBatchMode + } + if newKeywords.SshPubkeyAuthentication != nil { + outKeywords.SshPubkeyAuthentication = newKeywords.SshPubkeyAuthentication + } + if newKeywords.SshPasswordAuthentication != nil { + outKeywords.SshPasswordAuthentication = newKeywords.SshPasswordAuthentication + } + if newKeywords.SshKbdInteractiveAuthentication != nil { + outKeywords.SshKbdInteractiveAuthentication = newKeywords.SshKbdInteractiveAuthentication + } + if newKeywords.SshPreferredAuthentications != nil { + outKeywords.SshPreferredAuthentications = newKeywords.SshPreferredAuthentications + } + if newKeywords.SshAddKeysToAgent != nil { + outKeywords.SshAddKeysToAgent = newKeywords.SshAddKeysToAgent + } + if newKeywords.SshIdentityAgent != nil { + outKeywords.SshIdentityAgent = newKeywords.SshIdentityAgent + } + if newKeywords.SshProxyJump != nil { + outKeywords.SshProxyJump = newKeywords.SshProxyJump + } + if newKeywords.SshUserKnownHostsFile != nil { + outKeywords.SshUserKnownHostsFile = newKeywords.SshUserKnownHostsFile + } + if newKeywords.SshGlobalKnownHostsFile != nil { + outKeywords.SshGlobalKnownHostsFile = newKeywords.SshGlobalKnownHostsFile + } + + return &outKeywords +}