diff --git a/wavesrv/pkg/remote/sshclient.go b/wavesrv/pkg/remote/sshclient.go index 31d35de33..26eeee522 100644 --- a/wavesrv/pkg/remote/sshclient.go +++ b/wavesrv/pkg/remote/sshclient.go @@ -454,55 +454,134 @@ func createHostKeyCallback(opts *sstore.SSHOpts) (ssh.HostKeyCallback, error) { } func ConnectToClient(opts *sstore.SSHOpts) (*ssh.Client, error) { - ssh_config.ReloadConfigs() - configIdentityFiles := ssh_config.GetAll(opts.SSHHost, "IdentityFile") - identityFiles := []string{opts.SSHIdentity} - identityFiles = append(identityFiles, configIdentityFiles...) + sshConfigKeywords, err := findSshConfigKeywords(opts.SSHHost) + if err != nil { + return nil, err + } + + sshKeywords, err := combineSshKeywords(opts, sshConfigKeywords) + if err != nil { + return nil, err + } + + publicKeyCallback := ssh.PublicKeysCallback(createPublicKeyCallback(&sshKeywords.IdentityFile, opts.SSHPassword)) + + var authMethods []ssh.AuthMethod + authMethods = append(authMethods, ssh.RetryableAuthMethod(publicKeyCallback, len(sshKeywords.IdentityFile))) + authMethods = append(authMethods, ssh.RetryableAuthMethod(ssh.KeyboardInteractive(createCombinedKbdInteractiveChallenge(opts.SSHPassword)), 2)) + authMethods = append(authMethods, ssh.RetryableAuthMethod(ssh.PasswordCallback(createCombinedPasswordCallbackPrompt(opts.SSHPassword)), 2)) hostKeyCallback, err := createHostKeyCallback(opts) if err != nil { return nil, err } - publicKeyCallback := ssh.PublicKeysCallback(createPublicKeyCallback(&identityFiles, opts.SSHPassword)) - var authMethods []ssh.AuthMethod - authMethods = append(authMethods, ssh.RetryableAuthMethod(publicKeyCallback, len(identityFiles))) - authMethods = append(authMethods, ssh.RetryableAuthMethod(ssh.KeyboardInteractive(createCombinedKbdInteractiveChallenge(opts.SSHPassword)), 2)) - authMethods = append(authMethods, ssh.RetryableAuthMethod(ssh.PasswordCallback(createCombinedPasswordCallbackPrompt(opts.SSHPassword)), 2)) - configUser, _ := ssh_config.GetStrict(opts.SSHHost, "User") - configHostName, _ := ssh_config.GetStrict(opts.SSHHost, "HostName") - configPort, _ := ssh_config.GetStrict(opts.SSHHost, "Port") - var username string + clientConfig := &ssh.ClientConfig{ + User: sshKeywords.User, + Auth: authMethods, + HostKeyCallback: hostKeyCallback, + } + networkAddr := sshKeywords.HostName + ":" + sshKeywords.Port + return ssh.Dial("tcp", networkAddr, clientConfig) +} + +type SshKeywords struct { + User string + HostName string + Port string + IdentityFile []string + BatchMode bool + PasswordAuthentication bool + KbdInteractiveAuthentication bool + PreferredAuthentications []string +} + +func combineSshKeywords(opts *sstore.SSHOpts, configKeywords *SshKeywords) (*SshKeywords, error) { + sshKeywords := &SshKeywords{} + if opts.SSHUser != "" { - username = opts.SSHUser - } else if configUser != "" { - username = configUser + sshKeywords.User = opts.SSHUser + } else if configKeywords.User != "" { + sshKeywords.User = configKeywords.User } else { user, err := user.Current() if err != nil { return nil, fmt.Errorf("failed to get user for ssh: %+v", err) } - username = user.Username + sshKeywords.User = user.Username } - var hostName string - if configHostName != "" { - hostName = configHostName + + // we have to check the host value because of the weird way + // we store the pattern as the hostname for imported remotes + if configKeywords.HostName != "" { + sshKeywords.HostName = configKeywords.HostName } else { - hostName = opts.SSHHost + sshKeywords.HostName = opts.SSHHost } - clientConfig := &ssh.ClientConfig{ - User: username, - Auth: authMethods, - HostKeyCallback: hostKeyCallback, - } - var port string + if opts.SSHPort != 0 && opts.SSHPort != 22 { - port = strconv.Itoa(opts.SSHPort) - } else if configPort != "" && configPort != "22" { - port = configPort + sshKeywords.Port = strconv.Itoa(opts.SSHPort) + } else if configKeywords.Port != "" && configKeywords.Port != "22" { + sshKeywords.Port = configKeywords.Port } else { - port = "22" + sshKeywords.Port = "22" } - networkAddr := hostName + ":" + port - return ssh.Dial("tcp", networkAddr, clientConfig) + + sshKeywords.IdentityFile = []string{opts.SSHIdentity} + sshKeywords.IdentityFile = append(sshKeywords.IdentityFile, configKeywords.IdentityFile...) + + // these are not officially supported in the waveterm frontend but can be configured + // in ssh config files + sshKeywords.BatchMode = configKeywords.BatchMode + sshKeywords.PasswordAuthentication = configKeywords.PasswordAuthentication + sshKeywords.KbdInteractiveAuthentication = configKeywords.KbdInteractiveAuthentication + sshKeywords.PreferredAuthentications = configKeywords.PreferredAuthentications + + return sshKeywords, nil +} + +func findSshConfigKeywords(hostPattern string) (*SshKeywords, error) { + ssh_config.ReloadConfigs() + sshKeywords := &SshKeywords{} + var err error + + sshKeywords.User, err = ssh_config.GetStrict(hostPattern, "User") + if err != nil { + return nil, err + } + + sshKeywords.HostName, err = ssh_config.GetStrict(hostPattern, "HostName") + if err != nil { + return nil, err + } + + sshKeywords.Port, err = ssh_config.GetStrict(hostPattern, "Port") + if err != nil { + return nil, err + } + + sshKeywords.IdentityFile = ssh_config.GetAll(hostPattern, "IdentityFile") + + batchModeRaw, err := ssh_config.GetStrict(hostPattern, "BatchMode") + if err != nil { + return nil, err + } + sshKeywords.BatchMode = (strings.ToLower(batchModeRaw) == "yes") + + passwordAuthenticationRaw, err := ssh_config.GetStrict(hostPattern, "PasswordAuthentication") + if err != nil { + return nil, err + } + sshKeywords.PasswordAuthentication = (strings.ToLower(passwordAuthenticationRaw) == "yes") + + kbdInteractiveAuthenticationRaw, err := ssh_config.GetStrict(hostPattern, "KbdInteractiveAuthentication") + if err != nil { + return nil, err + } + sshKeywords.KbdInteractiveAuthentication = (strings.ToLower(kbdInteractiveAuthenticationRaw) == "yes") + + // these are case sensitive in openssh so they are here too + sshKeywords.PreferredAuthentications = ssh_config.GetAll(hostPattern, "PreferredAuthentications") + + return sshKeywords, nil }