mirror of
https://github.com/wavetermdev/waveterm.git
synced 2024-12-21 16:38:23 +01:00
Use ssh library for remote connections (#250)
* create proof of concept ssh library integration This is a first attempt to integrate the golang crypto/ssh library for handling remote connections. As it stands, this features is limited to identity files without passphrases. It needs to be expanded to include key+passphrase and password verifications as well. * add password and keyboard-interactive ssh auth This adds several new ssh auth methods. In addition to the PublicKey method used previously, this adds password authentication, keyboard-interactive authentication, and PublicKey+Passphrase authentication. Furthermore, it refactores the ssh connection code into its own wavesrv file rather than storing int in waveshell's shexec file. * clean up old mshell launch methods In the debugging the addition of the ssh library, i had several versions of the MShellProc Launch function. Since this seems mostly stable, I have removed the old version and the experimental version in favor of the combined version. * allow switching between new and old ssh for dev It is inconvenient to create milestones without being able to merge into the main branch. But due to the experimental nature of the ssh changes, it is not desired to use these changes in the main branch yet. This change disables the new ssh launcher by default. It can be used by changing the UseSshLibrary constant to true in remote.go. With this, it becomes possible to merge these changes into the main branch without them being used in production. * fix: allow retry after ssh auth failure Previously, the error status was not set when an ssh connection failed. Because of this, an ssh connection failure would lock the failed remote until waveterm was rebooted. This fix properly sets the error status so this cannot happen.
This commit is contained in:
parent
99f5c094d2
commit
018bb14b6a
@ -651,7 +651,7 @@ func (m *MServer) runCommand(runPacket *packet.RunPacketType) {
|
||||
m.Sender.SendErrorResponse(runPacket.ReqId, fmt.Errorf("server run packets require valid ck: %s", err))
|
||||
return
|
||||
}
|
||||
cproc, _, err := shexec.MakeClientProc(context.Background(), ecmd)
|
||||
cproc, _, err := shexec.MakeClientProc(context.Background(), shexec.CmdWrap{Cmd: ecmd})
|
||||
if err != nil {
|
||||
m.Sender.SendErrorResponse(runPacket.ReqId, fmt.Errorf("starting mshell client: %s", err))
|
||||
return
|
||||
|
@ -12,6 +12,7 @@ import (
|
||||
|
||||
"github.com/wavetermdev/waveterm/waveshell/pkg/base"
|
||||
"github.com/wavetermdev/waveterm/waveshell/pkg/packet"
|
||||
"golang.org/x/crypto/ssh"
|
||||
"golang.org/x/mod/semver"
|
||||
)
|
||||
|
||||
@ -19,8 +20,97 @@ import (
|
||||
|
||||
const NotFoundVersion = "v0.0"
|
||||
|
||||
type CmdWrap struct {
|
||||
Cmd *exec.Cmd
|
||||
}
|
||||
|
||||
func (cw CmdWrap) Kill() {
|
||||
cw.Cmd.Process.Kill()
|
||||
}
|
||||
|
||||
func (cw CmdWrap) Wait() error {
|
||||
return cw.Cmd.Wait()
|
||||
}
|
||||
|
||||
func (cw CmdWrap) Sender() (*packet.PacketSender, io.WriteCloser, error) {
|
||||
inputWriter, err := cw.Cmd.StdinPipe()
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("creating stdin pipe: %v", err)
|
||||
}
|
||||
sender := packet.MakePacketSender(inputWriter, nil)
|
||||
return sender, inputWriter, nil
|
||||
}
|
||||
|
||||
func (cw CmdWrap) Parser() (*packet.PacketParser, io.ReadCloser, io.ReadCloser, error) {
|
||||
stdoutReader, err := cw.Cmd.StdoutPipe()
|
||||
if err != nil {
|
||||
return nil, nil, nil, fmt.Errorf("creating stdout pipe: %v", err)
|
||||
}
|
||||
stderrReader, err := cw.Cmd.StderrPipe()
|
||||
if err != nil {
|
||||
return nil, nil, nil, fmt.Errorf("creating stderr pipe: %v", err)
|
||||
}
|
||||
stdoutPacketParser := packet.MakePacketParser(stdoutReader, &packet.PacketParserOpts{IgnoreUntilValid: true})
|
||||
stderrPacketParser := packet.MakePacketParser(stderrReader, nil)
|
||||
packetParser := packet.CombinePacketParsers(stdoutPacketParser, stderrPacketParser, true)
|
||||
return packetParser, stdoutReader, stderrReader, nil
|
||||
}
|
||||
|
||||
func (cw CmdWrap) Start() error {
|
||||
return cw.Cmd.Start()
|
||||
}
|
||||
|
||||
type SessionWrap struct {
|
||||
Session *ssh.Session
|
||||
StartCmd string
|
||||
}
|
||||
|
||||
func (sw SessionWrap) Kill() {
|
||||
sw.Session.Close()
|
||||
}
|
||||
|
||||
func (sw SessionWrap) Wait() error {
|
||||
return sw.Session.Wait()
|
||||
}
|
||||
|
||||
func (sw SessionWrap) Start() error {
|
||||
return sw.Session.Start(sw.StartCmd)
|
||||
}
|
||||
|
||||
func (sw SessionWrap) Sender() (*packet.PacketSender, io.WriteCloser, error) {
|
||||
inputWriter, err := sw.Session.StdinPipe()
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("creating stdin pipe: %v", err)
|
||||
}
|
||||
sender := packet.MakePacketSender(inputWriter, nil)
|
||||
return sender, inputWriter, nil
|
||||
}
|
||||
|
||||
func (sw SessionWrap) Parser() (*packet.PacketParser, io.ReadCloser, io.ReadCloser, error) {
|
||||
stdoutReader, err := sw.Session.StdoutPipe()
|
||||
if err != nil {
|
||||
return nil, nil, nil, fmt.Errorf("creating stdout pipe: %v", err)
|
||||
}
|
||||
stderrReader, err := sw.Session.StderrPipe()
|
||||
if err != nil {
|
||||
return nil, nil, nil, fmt.Errorf("creating stderr pipe: %v", err)
|
||||
}
|
||||
stdoutPacketParser := packet.MakePacketParser(stdoutReader, &packet.PacketParserOpts{IgnoreUntilValid: true})
|
||||
stderrPacketParser := packet.MakePacketParser(stderrReader, nil)
|
||||
packetParser := packet.CombinePacketParsers(stdoutPacketParser, stderrPacketParser, true)
|
||||
return packetParser, io.NopCloser(stdoutReader), io.NopCloser(stderrReader), nil
|
||||
}
|
||||
|
||||
type ConnInterface interface {
|
||||
Kill()
|
||||
Wait() error
|
||||
Sender() (*packet.PacketSender, io.WriteCloser, error)
|
||||
Parser() (*packet.PacketParser, io.ReadCloser, io.ReadCloser, error)
|
||||
Start() error
|
||||
}
|
||||
|
||||
type ClientProc struct {
|
||||
Cmd *exec.Cmd
|
||||
Cmd ConnInterface
|
||||
InitPk *packet.InitPacketType
|
||||
StartTs time.Time
|
||||
StdinWriter io.WriteCloser
|
||||
@ -31,28 +121,20 @@ type ClientProc struct {
|
||||
}
|
||||
|
||||
// returns (clientproc, initpk, error)
|
||||
func MakeClientProc(ctx context.Context, ecmd *exec.Cmd) (*ClientProc, *packet.InitPacketType, error) {
|
||||
inputWriter, err := ecmd.StdinPipe()
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("creating stdin pipe: %v", err)
|
||||
}
|
||||
stdoutReader, err := ecmd.StdoutPipe()
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("creating stdout pipe: %v", err)
|
||||
}
|
||||
stderrReader, err := ecmd.StderrPipe()
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("creating stderr pipe: %v", err)
|
||||
}
|
||||
func MakeClientProc(ctx context.Context, ecmd ConnInterface) (*ClientProc, *packet.InitPacketType, error) {
|
||||
startTs := time.Now()
|
||||
sender, inputWriter, err := ecmd.Sender()
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
packetParser, stdoutReader, stderrReader, err := ecmd.Parser()
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
err = ecmd.Start()
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("running local client: %w", err)
|
||||
}
|
||||
sender := packet.MakePacketSender(inputWriter, nil)
|
||||
stdoutPacketParser := packet.MakePacketParser(stdoutReader, &packet.PacketParserOpts{IgnoreUntilValid: true})
|
||||
stderrPacketParser := packet.MakePacketParser(stderrReader, nil)
|
||||
packetParser := packet.CombinePacketParsers(stdoutPacketParser, stderrPacketParser, true)
|
||||
cproc := &ClientProc{
|
||||
Cmd: ecmd,
|
||||
StartTs: startTs,
|
||||
@ -107,7 +189,7 @@ func (cproc *ClientProc) Close() {
|
||||
cproc.StderrReader.Close()
|
||||
}
|
||||
if cproc.Cmd != nil {
|
||||
cproc.Cmd.Process.Kill()
|
||||
cproc.Cmd.Kill()
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -36,9 +36,12 @@ import (
|
||||
"github.com/wavetermdev/waveterm/wavesrv/pkg/scbase"
|
||||
"github.com/wavetermdev/waveterm/wavesrv/pkg/scpacket"
|
||||
"github.com/wavetermdev/waveterm/wavesrv/pkg/sstore"
|
||||
"golang.org/x/crypto/ssh"
|
||||
"golang.org/x/mod/semver"
|
||||
)
|
||||
|
||||
const UseSshLibrary = false
|
||||
|
||||
const RemoteTypeMShell = "mshell"
|
||||
const DefaultTerm = "xterm-256color"
|
||||
const DefaultMaxPtySize = 1024 * 1024
|
||||
@ -77,6 +80,12 @@ else
|
||||
fi
|
||||
`
|
||||
|
||||
const WaveshellServerRunOnlyFmt = `
|
||||
PATH=$PATH:~/.mshell;
|
||||
[%PINGPACKET%]
|
||||
mshell-[%VERSION%] --server
|
||||
`
|
||||
|
||||
func MakeLocalMShellCommandStr(isSudo bool) (string, error) {
|
||||
mshellPath, err := scbase.LocalMShellBinaryPath()
|
||||
if err != nil {
|
||||
@ -95,6 +104,13 @@ func MakeServerCommandStr() string {
|
||||
return rtn
|
||||
}
|
||||
|
||||
func MakeServerRunOnlyCommandStr() string {
|
||||
rtn := strings.ReplaceAll(WaveshellServerRunOnlyFmt, "[%VERSION%]", semver.MajorMinor(scbase.MShellVersion))
|
||||
rtn = strings.ReplaceAll(rtn, "[%PINGPACKET%]", PrintPingPacket)
|
||||
return rtn
|
||||
|
||||
}
|
||||
|
||||
const (
|
||||
StatusConnected = sstore.RemoteStatus_Connected
|
||||
StatusConnecting = sstore.RemoteStatus_Connecting
|
||||
@ -121,6 +137,12 @@ type pendingStateKey struct {
|
||||
RemotePtr sstore.RemotePtrType
|
||||
}
|
||||
|
||||
// for conditional launch method based on ssh library in use
|
||||
// remove once ssh library is stabilized
|
||||
type Launcher interface {
|
||||
Launch(*MShellProc, bool)
|
||||
}
|
||||
|
||||
type MShellProc struct {
|
||||
Lock *sync.Mutex
|
||||
Remote *sstore.RemoteType
|
||||
@ -149,6 +171,7 @@ type MShellProc struct {
|
||||
RunningCmds map[base.CommandKey]RunCmdType
|
||||
WaitingCmds []RunCmdType
|
||||
PendingStateCmds map[pendingStateKey]base.CommandKey // key=[remoteinstance name]
|
||||
launcher Launcher // for conditional launch method based on ssh library in use. remove once ssh library is stabilized
|
||||
}
|
||||
|
||||
type RunCmdType struct {
|
||||
@ -169,6 +192,12 @@ func CanComplete(remoteType string) bool {
|
||||
}
|
||||
}
|
||||
|
||||
// for conditional launch method based on ssh library in use
|
||||
// remove once ssh library is stabilized
|
||||
func (msh *MShellProc) Launch(interactive bool) {
|
||||
msh.launcher.Launch(msh, interactive)
|
||||
}
|
||||
|
||||
func (msh *MShellProc) GetStatus() string {
|
||||
msh.Lock.Lock()
|
||||
defer msh.Lock.Unlock()
|
||||
@ -662,7 +691,14 @@ func MakeMShell(r *sstore.RemoteType) *MShellProc {
|
||||
RunningCmds: make(map[base.CommandKey]RunCmdType),
|
||||
PendingStateCmds: make(map[pendingStateKey]base.CommandKey),
|
||||
StateMap: server.MakeShellStateMap(),
|
||||
launcher: LegacyLauncher{}, // for conditional launch method based on ssh library in use. remove once ssh library is stabilized
|
||||
}
|
||||
// for conditional launch method based on ssh library in use
|
||||
// remove once ssh library is stabilized
|
||||
if UseSshLibrary {
|
||||
rtn.launcher = NewLauncher{}
|
||||
}
|
||||
|
||||
rtn.WriteToPtyBuffer("console for connection [%s]\n", r.GetName())
|
||||
return rtn
|
||||
}
|
||||
@ -1218,7 +1254,179 @@ func (msh *MShellProc) getActiveShellTypes(ctx context.Context) ([]string, error
|
||||
return utilfn.CombineStrArrays(rtn, activeShells), nil
|
||||
}
|
||||
|
||||
func (msh *MShellProc) Launch(interactive bool) {
|
||||
// for conditional launch method based on ssh library in use
|
||||
// remove once ssh library is stabilized
|
||||
type NewLauncher struct{}
|
||||
|
||||
// func (msh *MShellProc) LaunchNew(interactive bool) {
|
||||
func (NewLauncher) Launch(msh *MShellProc, interactive bool) {
|
||||
remoteCopy := msh.GetRemoteCopy()
|
||||
if remoteCopy.Archived {
|
||||
msh.WriteToPtyBuffer("cannot launch archived remote\n")
|
||||
return
|
||||
}
|
||||
curStatus := msh.GetStatus()
|
||||
if curStatus == StatusConnected {
|
||||
msh.WriteToPtyBuffer("remote is already connected (no action taken)\n")
|
||||
return
|
||||
}
|
||||
if curStatus == StatusConnecting {
|
||||
msh.WriteToPtyBuffer("remote is already connecting, disconnect before trying to connect again\n")
|
||||
return
|
||||
}
|
||||
sapi, err := shellapi.MakeShellApi(msh.GetShellType())
|
||||
if err != nil {
|
||||
msh.WriteToPtyBuffer("*error, %v\n", err)
|
||||
return
|
||||
}
|
||||
istatus := msh.GetInstallStatus()
|
||||
if istatus == StatusConnecting {
|
||||
msh.WriteToPtyBuffer("remote is trying to install, cancel install before trying to connect again\n")
|
||||
return
|
||||
}
|
||||
if remoteCopy.SSHOpts.SSHPort != 0 && remoteCopy.SSHOpts.SSHPort != 22 {
|
||||
msh.WriteToPtyBuffer("connecting to %s (port %d)...\n", remoteCopy.RemoteCanonicalName, remoteCopy.SSHOpts.SSHPort)
|
||||
} else {
|
||||
msh.WriteToPtyBuffer("connecting to %s...\n", remoteCopy.RemoteCanonicalName)
|
||||
}
|
||||
sshOpts := convertSSHOpts(remoteCopy.SSHOpts)
|
||||
sshOpts.SSHErrorsToTty = true
|
||||
if remoteCopy.ConnectMode != sstore.ConnectModeManual && remoteCopy.SSHOpts.SSHPassword == "" && !interactive {
|
||||
sshOpts.BatchMode = true
|
||||
}
|
||||
makeClientCtx, makeClientCancelFn := context.WithCancel(context.Background())
|
||||
defer makeClientCancelFn()
|
||||
msh.WithLock(func() {
|
||||
msh.Err = nil
|
||||
msh.ErrNoInitPk = false
|
||||
msh.Status = StatusConnecting
|
||||
msh.MakeClientCancelFn = makeClientCancelFn
|
||||
deadlineTime := time.Now().Add(RemoteConnectTimeout)
|
||||
msh.MakeClientDeadline = &deadlineTime
|
||||
go msh.NotifyRemoteUpdate()
|
||||
})
|
||||
go msh.watchClientDeadlineTime()
|
||||
var cmdStr string
|
||||
var cproc *shexec.ClientProc
|
||||
var initPk *packet.InitPacketType
|
||||
if sshOpts.SSHHost == "" && remoteCopy.Local {
|
||||
cmdStr, err = MakeLocalMShellCommandStr(remoteCopy.IsSudo())
|
||||
if err != nil {
|
||||
msh.WriteToPtyBuffer("*error, cannot find local mshell binary: %v\n", err)
|
||||
return
|
||||
}
|
||||
ecmd := sshOpts.MakeSSHExecCmd(cmdStr, sapi)
|
||||
var cmdPty *os.File
|
||||
cmdPty, err = msh.addControllingTty(ecmd)
|
||||
if err != nil {
|
||||
statusErr := fmt.Errorf("cannot attach controlling tty to mshell command: %w", err)
|
||||
msh.WriteToPtyBuffer("*error, %s\n", statusErr.Error())
|
||||
msh.setErrorStatus(statusErr)
|
||||
return
|
||||
}
|
||||
defer func() {
|
||||
if len(ecmd.ExtraFiles) > 0 {
|
||||
ecmd.ExtraFiles[len(ecmd.ExtraFiles)-1].Close()
|
||||
}
|
||||
}()
|
||||
go msh.RunPtyReadLoop(cmdPty)
|
||||
if remoteCopy.SSHOpts.SSHPassword != "" {
|
||||
go msh.WaitAndSendPassword(remoteCopy.SSHOpts.SSHPassword)
|
||||
}
|
||||
cproc, initPk, err = shexec.MakeClientProc(makeClientCtx, shexec.CmdWrap{Cmd: ecmd})
|
||||
} else {
|
||||
var client *ssh.Client
|
||||
client, err = ConnectToClient(remoteCopy.SSHOpts)
|
||||
if err != nil {
|
||||
statusErr := fmt.Errorf("ssh cannot connect to client: %w", err)
|
||||
msh.WriteToPtyBuffer("*error, %s\n", statusErr.Error())
|
||||
msh.setErrorStatus(statusErr)
|
||||
return
|
||||
}
|
||||
var session *ssh.Session
|
||||
session, err = client.NewSession()
|
||||
if err != nil {
|
||||
statusErr := fmt.Errorf("ssh cannot create session: %w", err)
|
||||
msh.WriteToPtyBuffer("*error, %s\n", statusErr.Error())
|
||||
msh.setErrorStatus(statusErr)
|
||||
return
|
||||
}
|
||||
cproc, initPk, err = shexec.MakeClientProc(makeClientCtx, shexec.SessionWrap{Session: session, StartCmd: MakeServerRunOnlyCommandStr()})
|
||||
}
|
||||
// TODO check if initPk.State is not nil
|
||||
var mshellVersion string
|
||||
var hitDeadline bool
|
||||
msh.WithLock(func() {
|
||||
msh.MakeClientCancelFn = nil
|
||||
if time.Now().After(*msh.MakeClientDeadline) {
|
||||
hitDeadline = true
|
||||
}
|
||||
msh.MakeClientDeadline = nil
|
||||
if initPk == nil {
|
||||
msh.ErrNoInitPk = true
|
||||
}
|
||||
if initPk != nil {
|
||||
msh.UName = initPk.UName
|
||||
mshellVersion = initPk.Version
|
||||
if semver.Compare(mshellVersion, scbase.MShellVersion) < 0 {
|
||||
// only set NeedsMShellUpgrade if we got an InitPk
|
||||
msh.NeedsMShellUpgrade = true
|
||||
}
|
||||
msh.InitPkShellType = initPk.Shell
|
||||
}
|
||||
msh.StateMap.Clear()
|
||||
// no notify here, because we'll call notify in either case below
|
||||
})
|
||||
if err == context.Canceled {
|
||||
if hitDeadline {
|
||||
msh.WriteToPtyBuffer("*connect timeout\n")
|
||||
msh.setErrorStatus(errors.New("connect timeout"))
|
||||
} else {
|
||||
msh.WriteToPtyBuffer("*forced disconnection\n")
|
||||
msh.WithLock(func() {
|
||||
msh.Status = StatusDisconnected
|
||||
go msh.NotifyRemoteUpdate()
|
||||
})
|
||||
}
|
||||
return
|
||||
}
|
||||
if err == nil && semver.MajorMinor(mshellVersion) != semver.MajorMinor(scbase.MShellVersion) {
|
||||
err = fmt.Errorf("mshell version is not compatible current=%s remote=%s", scbase.MShellVersion, mshellVersion)
|
||||
}
|
||||
if err != nil {
|
||||
msh.setErrorStatus(err)
|
||||
msh.WriteToPtyBuffer("*error connecting to remote: %v\n", err)
|
||||
go msh.tryAutoInstall()
|
||||
return
|
||||
}
|
||||
msh.updateRemoteStateVars(context.Background(), msh.RemoteId, initPk)
|
||||
msh.WithLock(func() {
|
||||
msh.ServerProc = cproc
|
||||
msh.Status = StatusConnected
|
||||
})
|
||||
go func() {
|
||||
exitErr := cproc.Cmd.Wait()
|
||||
exitCode := shexec.GetExitCode(exitErr)
|
||||
msh.WithLock(func() {
|
||||
if msh.Status == StatusConnected || msh.Status == StatusConnecting {
|
||||
msh.Status = StatusDisconnected
|
||||
go msh.NotifyRemoteUpdate()
|
||||
}
|
||||
})
|
||||
msh.WriteToPtyBuffer("*disconnected exitcode=%d\n", exitCode)
|
||||
}()
|
||||
go msh.ProcessPackets()
|
||||
msh.initActiveShells()
|
||||
go msh.NotifyRemoteUpdate()
|
||||
return
|
||||
}
|
||||
|
||||
// for conditional launch method based on ssh library in use
|
||||
// remove once ssh library is stabilized
|
||||
type LegacyLauncher struct{}
|
||||
|
||||
// func (msh *MShellProc) LaunchLegacy(interactive bool) {
|
||||
func (LegacyLauncher) Launch(msh *MShellProc, interactive bool) {
|
||||
remoteCopy := msh.GetRemoteCopy()
|
||||
if remoteCopy.Archived {
|
||||
msh.WriteToPtyBuffer("cannot launch archived remote\n")
|
||||
@ -1293,7 +1501,7 @@ func (msh *MShellProc) Launch(interactive bool) {
|
||||
go msh.NotifyRemoteUpdate()
|
||||
})
|
||||
go msh.watchClientDeadlineTime()
|
||||
cproc, initPk, err := shexec.MakeClientProc(makeClientCtx, ecmd)
|
||||
cproc, initPk, err := shexec.MakeClientProc(makeClientCtx, shexec.CmdWrap{Cmd: ecmd})
|
||||
// TODO check if initPk.State is not nil
|
||||
var mshellVersion string
|
||||
var hitDeadline bool
|
||||
|
108
wavesrv/pkg/remote/sshclient.go
Normal file
108
wavesrv/pkg/remote/sshclient.go
Normal file
@ -0,0 +1,108 @@
|
||||
// Copyright 2024, Command Line Inc.
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
package remote
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"os"
|
||||
"os/user"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/kevinburke/ssh_config"
|
||||
"github.com/wavetermdev/waveterm/waveshell/pkg/base"
|
||||
"github.com/wavetermdev/waveterm/wavesrv/pkg/sstore"
|
||||
"golang.org/x/crypto/ssh"
|
||||
)
|
||||
|
||||
func createPublicKeyAuth(identityFile string, passphrase string) (ssh.AuthMethod, error) {
|
||||
privateKey, err := os.ReadFile(base.ExpandHomeDir(identityFile))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to read ssh key file. err: %+v", err)
|
||||
}
|
||||
signer, err := ssh.ParsePrivateKey(privateKey)
|
||||
if err != nil {
|
||||
if errors.Is(err, &ssh.PassphraseMissingError{}) {
|
||||
signer, err = ssh.ParsePrivateKeyWithPassphrase(privateKey, []byte(passphrase))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to parse private ssh key with passphrase. err: %+v", err)
|
||||
}
|
||||
} else {
|
||||
return nil, fmt.Errorf("failed to parse private ssh key. err: %+v", err)
|
||||
}
|
||||
}
|
||||
return ssh.PublicKeys(signer), nil
|
||||
}
|
||||
|
||||
func createKeyboardInteractiveAuth(password string) ssh.AuthMethod {
|
||||
challenge := func(name, instruction string, questions []string, echos []bool) (answers []string, err error) {
|
||||
for _, q := range questions {
|
||||
if strings.Contains(strings.ToLower(q), "password") {
|
||||
answers = append(answers, password)
|
||||
} else {
|
||||
answers = append(answers, "")
|
||||
}
|
||||
}
|
||||
return answers, nil
|
||||
}
|
||||
return ssh.KeyboardInteractive(challenge)
|
||||
}
|
||||
|
||||
func ConnectToClient(opts *sstore.SSHOpts) (*ssh.Client, error) {
|
||||
ssh_config.ReloadConfigs()
|
||||
configIdentity, _ := ssh_config.GetStrict(opts.SSHHost, "IdentityFile")
|
||||
var identityFile string
|
||||
if opts.SSHIdentity != "" {
|
||||
identityFile = opts.SSHIdentity
|
||||
} else {
|
||||
identityFile = configIdentity
|
||||
}
|
||||
|
||||
hostKeyCallback := ssh.InsecureIgnoreHostKey()
|
||||
var authMethods []ssh.AuthMethod
|
||||
publicKeyAuth, err := createPublicKeyAuth(identityFile, opts.SSHPassword)
|
||||
if err == nil {
|
||||
authMethods = append(authMethods, publicKeyAuth)
|
||||
}
|
||||
authMethods = append(authMethods, createKeyboardInteractiveAuth(opts.SSHPassword))
|
||||
authMethods = append(authMethods, ssh.Password(opts.SSHPassword))
|
||||
|
||||
configUser, _ := ssh_config.GetStrict(opts.SSHHost, "User")
|
||||
configHostName, _ := ssh_config.GetStrict(opts.SSHHost, "HostName")
|
||||
configPort, _ := ssh_config.GetStrict(opts.SSHHost, "Port")
|
||||
var username string
|
||||
if opts.SSHUser != "" {
|
||||
username = opts.SSHUser
|
||||
} else if configUser != "" {
|
||||
username = configUser
|
||||
} else {
|
||||
user, err := user.Current()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get user for ssh: %+v", err)
|
||||
}
|
||||
username = user.Username
|
||||
}
|
||||
var hostName string
|
||||
if configHostName != "" {
|
||||
hostName = configHostName
|
||||
} else {
|
||||
hostName = opts.SSHHost
|
||||
}
|
||||
clientConfig := &ssh.ClientConfig{
|
||||
User: username,
|
||||
Auth: authMethods,
|
||||
HostKeyCallback: hostKeyCallback,
|
||||
}
|
||||
var port string
|
||||
if opts.SSHPort != 0 && opts.SSHPort != 22 {
|
||||
port = strconv.Itoa(opts.SSHPort)
|
||||
} else if configPort != "" && configPort != "22" {
|
||||
port = configPort
|
||||
} else {
|
||||
port = "22"
|
||||
}
|
||||
networkAddr := hostName + ":" + port
|
||||
return ssh.Dial("tcp", networkAddr, clientConfig)
|
||||
}
|
Loading…
Reference in New Issue
Block a user