Use ssh library for remote connections (#250)

* create proof of concept ssh library integration

This is a first attempt to integrate the golang crypto/ssh library for
handling remote connections. As it stands, this features is limited to
identity files without passphrases. It needs to be expanded to include
key+passphrase and password verifications as well.

* add password and keyboard-interactive ssh auth

This adds several new ssh auth methods. In addition to the PublicKey
method used previously, this adds password authentication,
keyboard-interactive authentication, and PublicKey+Passphrase
authentication.

Furthermore, it refactores the ssh connection code into its own wavesrv
file rather than storing int in waveshell's shexec file.

* clean up old mshell launch methods

In the debugging the addition of the ssh library, i had several versions
of the MShellProc Launch function. Since this seems mostly stable, I
have removed the old version and the experimental version in favor of
the combined version.

* allow switching between new and old ssh for dev

It is inconvenient to create milestones without being able to merge into
the main branch. But due to the experimental nature of the ssh changes,
it is not desired to use these changes in the main branch yet. This
change disables the new ssh launcher by default. It can be used by
changing the UseSshLibrary constant to true in remote.go. With this, it
becomes possible to merge these changes into the main branch without
them being used in production.

* fix: allow retry after ssh auth failure

Previously, the error status was not set when an ssh connection failed.
Because of this, an ssh connection failure would lock the failed remote
until waveterm was rebooted. This fix properly sets the error status so
this cannot happen.
This commit is contained in:
Sylvie Crowe 2024-01-25 10:18:11 -08:00 committed by GitHub
parent 99f5c094d2
commit 018bb14b6a
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 420 additions and 22 deletions

View File

@ -651,7 +651,7 @@ func (m *MServer) runCommand(runPacket *packet.RunPacketType) {
m.Sender.SendErrorResponse(runPacket.ReqId, fmt.Errorf("server run packets require valid ck: %s", err))
return
}
cproc, _, err := shexec.MakeClientProc(context.Background(), ecmd)
cproc, _, err := shexec.MakeClientProc(context.Background(), shexec.CmdWrap{Cmd: ecmd})
if err != nil {
m.Sender.SendErrorResponse(runPacket.ReqId, fmt.Errorf("starting mshell client: %s", err))
return

View File

@ -12,6 +12,7 @@ import (
"github.com/wavetermdev/waveterm/waveshell/pkg/base"
"github.com/wavetermdev/waveterm/waveshell/pkg/packet"
"golang.org/x/crypto/ssh"
"golang.org/x/mod/semver"
)
@ -19,8 +20,97 @@ import (
const NotFoundVersion = "v0.0"
type CmdWrap struct {
Cmd *exec.Cmd
}
func (cw CmdWrap) Kill() {
cw.Cmd.Process.Kill()
}
func (cw CmdWrap) Wait() error {
return cw.Cmd.Wait()
}
func (cw CmdWrap) Sender() (*packet.PacketSender, io.WriteCloser, error) {
inputWriter, err := cw.Cmd.StdinPipe()
if err != nil {
return nil, nil, fmt.Errorf("creating stdin pipe: %v", err)
}
sender := packet.MakePacketSender(inputWriter, nil)
return sender, inputWriter, nil
}
func (cw CmdWrap) Parser() (*packet.PacketParser, io.ReadCloser, io.ReadCloser, error) {
stdoutReader, err := cw.Cmd.StdoutPipe()
if err != nil {
return nil, nil, nil, fmt.Errorf("creating stdout pipe: %v", err)
}
stderrReader, err := cw.Cmd.StderrPipe()
if err != nil {
return nil, nil, nil, fmt.Errorf("creating stderr pipe: %v", err)
}
stdoutPacketParser := packet.MakePacketParser(stdoutReader, &packet.PacketParserOpts{IgnoreUntilValid: true})
stderrPacketParser := packet.MakePacketParser(stderrReader, nil)
packetParser := packet.CombinePacketParsers(stdoutPacketParser, stderrPacketParser, true)
return packetParser, stdoutReader, stderrReader, nil
}
func (cw CmdWrap) Start() error {
return cw.Cmd.Start()
}
type SessionWrap struct {
Session *ssh.Session
StartCmd string
}
func (sw SessionWrap) Kill() {
sw.Session.Close()
}
func (sw SessionWrap) Wait() error {
return sw.Session.Wait()
}
func (sw SessionWrap) Start() error {
return sw.Session.Start(sw.StartCmd)
}
func (sw SessionWrap) Sender() (*packet.PacketSender, io.WriteCloser, error) {
inputWriter, err := sw.Session.StdinPipe()
if err != nil {
return nil, nil, fmt.Errorf("creating stdin pipe: %v", err)
}
sender := packet.MakePacketSender(inputWriter, nil)
return sender, inputWriter, nil
}
func (sw SessionWrap) Parser() (*packet.PacketParser, io.ReadCloser, io.ReadCloser, error) {
stdoutReader, err := sw.Session.StdoutPipe()
if err != nil {
return nil, nil, nil, fmt.Errorf("creating stdout pipe: %v", err)
}
stderrReader, err := sw.Session.StderrPipe()
if err != nil {
return nil, nil, nil, fmt.Errorf("creating stderr pipe: %v", err)
}
stdoutPacketParser := packet.MakePacketParser(stdoutReader, &packet.PacketParserOpts{IgnoreUntilValid: true})
stderrPacketParser := packet.MakePacketParser(stderrReader, nil)
packetParser := packet.CombinePacketParsers(stdoutPacketParser, stderrPacketParser, true)
return packetParser, io.NopCloser(stdoutReader), io.NopCloser(stderrReader), nil
}
type ConnInterface interface {
Kill()
Wait() error
Sender() (*packet.PacketSender, io.WriteCloser, error)
Parser() (*packet.PacketParser, io.ReadCloser, io.ReadCloser, error)
Start() error
}
type ClientProc struct {
Cmd *exec.Cmd
Cmd ConnInterface
InitPk *packet.InitPacketType
StartTs time.Time
StdinWriter io.WriteCloser
@ -31,28 +121,20 @@ type ClientProc struct {
}
// returns (clientproc, initpk, error)
func MakeClientProc(ctx context.Context, ecmd *exec.Cmd) (*ClientProc, *packet.InitPacketType, error) {
inputWriter, err := ecmd.StdinPipe()
if err != nil {
return nil, nil, fmt.Errorf("creating stdin pipe: %v", err)
}
stdoutReader, err := ecmd.StdoutPipe()
if err != nil {
return nil, nil, fmt.Errorf("creating stdout pipe: %v", err)
}
stderrReader, err := ecmd.StderrPipe()
if err != nil {
return nil, nil, fmt.Errorf("creating stderr pipe: %v", err)
}
func MakeClientProc(ctx context.Context, ecmd ConnInterface) (*ClientProc, *packet.InitPacketType, error) {
startTs := time.Now()
sender, inputWriter, err := ecmd.Sender()
if err != nil {
return nil, nil, err
}
packetParser, stdoutReader, stderrReader, err := ecmd.Parser()
if err != nil {
return nil, nil, err
}
err = ecmd.Start()
if err != nil {
return nil, nil, fmt.Errorf("running local client: %w", err)
}
sender := packet.MakePacketSender(inputWriter, nil)
stdoutPacketParser := packet.MakePacketParser(stdoutReader, &packet.PacketParserOpts{IgnoreUntilValid: true})
stderrPacketParser := packet.MakePacketParser(stderrReader, nil)
packetParser := packet.CombinePacketParsers(stdoutPacketParser, stderrPacketParser, true)
cproc := &ClientProc{
Cmd: ecmd,
StartTs: startTs,
@ -107,7 +189,7 @@ func (cproc *ClientProc) Close() {
cproc.StderrReader.Close()
}
if cproc.Cmd != nil {
cproc.Cmd.Process.Kill()
cproc.Cmd.Kill()
}
}

View File

@ -36,9 +36,12 @@ import (
"github.com/wavetermdev/waveterm/wavesrv/pkg/scbase"
"github.com/wavetermdev/waveterm/wavesrv/pkg/scpacket"
"github.com/wavetermdev/waveterm/wavesrv/pkg/sstore"
"golang.org/x/crypto/ssh"
"golang.org/x/mod/semver"
)
const UseSshLibrary = false
const RemoteTypeMShell = "mshell"
const DefaultTerm = "xterm-256color"
const DefaultMaxPtySize = 1024 * 1024
@ -77,6 +80,12 @@ else
fi
`
const WaveshellServerRunOnlyFmt = `
PATH=$PATH:~/.mshell;
[%PINGPACKET%]
mshell-[%VERSION%] --server
`
func MakeLocalMShellCommandStr(isSudo bool) (string, error) {
mshellPath, err := scbase.LocalMShellBinaryPath()
if err != nil {
@ -95,6 +104,13 @@ func MakeServerCommandStr() string {
return rtn
}
func MakeServerRunOnlyCommandStr() string {
rtn := strings.ReplaceAll(WaveshellServerRunOnlyFmt, "[%VERSION%]", semver.MajorMinor(scbase.MShellVersion))
rtn = strings.ReplaceAll(rtn, "[%PINGPACKET%]", PrintPingPacket)
return rtn
}
const (
StatusConnected = sstore.RemoteStatus_Connected
StatusConnecting = sstore.RemoteStatus_Connecting
@ -121,6 +137,12 @@ type pendingStateKey struct {
RemotePtr sstore.RemotePtrType
}
// for conditional launch method based on ssh library in use
// remove once ssh library is stabilized
type Launcher interface {
Launch(*MShellProc, bool)
}
type MShellProc struct {
Lock *sync.Mutex
Remote *sstore.RemoteType
@ -149,6 +171,7 @@ type MShellProc struct {
RunningCmds map[base.CommandKey]RunCmdType
WaitingCmds []RunCmdType
PendingStateCmds map[pendingStateKey]base.CommandKey // key=[remoteinstance name]
launcher Launcher // for conditional launch method based on ssh library in use. remove once ssh library is stabilized
}
type RunCmdType struct {
@ -169,6 +192,12 @@ func CanComplete(remoteType string) bool {
}
}
// for conditional launch method based on ssh library in use
// remove once ssh library is stabilized
func (msh *MShellProc) Launch(interactive bool) {
msh.launcher.Launch(msh, interactive)
}
func (msh *MShellProc) GetStatus() string {
msh.Lock.Lock()
defer msh.Lock.Unlock()
@ -662,7 +691,14 @@ func MakeMShell(r *sstore.RemoteType) *MShellProc {
RunningCmds: make(map[base.CommandKey]RunCmdType),
PendingStateCmds: make(map[pendingStateKey]base.CommandKey),
StateMap: server.MakeShellStateMap(),
launcher: LegacyLauncher{}, // for conditional launch method based on ssh library in use. remove once ssh library is stabilized
}
// for conditional launch method based on ssh library in use
// remove once ssh library is stabilized
if UseSshLibrary {
rtn.launcher = NewLauncher{}
}
rtn.WriteToPtyBuffer("console for connection [%s]\n", r.GetName())
return rtn
}
@ -1218,7 +1254,179 @@ func (msh *MShellProc) getActiveShellTypes(ctx context.Context) ([]string, error
return utilfn.CombineStrArrays(rtn, activeShells), nil
}
func (msh *MShellProc) Launch(interactive bool) {
// for conditional launch method based on ssh library in use
// remove once ssh library is stabilized
type NewLauncher struct{}
// func (msh *MShellProc) LaunchNew(interactive bool) {
func (NewLauncher) Launch(msh *MShellProc, interactive bool) {
remoteCopy := msh.GetRemoteCopy()
if remoteCopy.Archived {
msh.WriteToPtyBuffer("cannot launch archived remote\n")
return
}
curStatus := msh.GetStatus()
if curStatus == StatusConnected {
msh.WriteToPtyBuffer("remote is already connected (no action taken)\n")
return
}
if curStatus == StatusConnecting {
msh.WriteToPtyBuffer("remote is already connecting, disconnect before trying to connect again\n")
return
}
sapi, err := shellapi.MakeShellApi(msh.GetShellType())
if err != nil {
msh.WriteToPtyBuffer("*error, %v\n", err)
return
}
istatus := msh.GetInstallStatus()
if istatus == StatusConnecting {
msh.WriteToPtyBuffer("remote is trying to install, cancel install before trying to connect again\n")
return
}
if remoteCopy.SSHOpts.SSHPort != 0 && remoteCopy.SSHOpts.SSHPort != 22 {
msh.WriteToPtyBuffer("connecting to %s (port %d)...\n", remoteCopy.RemoteCanonicalName, remoteCopy.SSHOpts.SSHPort)
} else {
msh.WriteToPtyBuffer("connecting to %s...\n", remoteCopy.RemoteCanonicalName)
}
sshOpts := convertSSHOpts(remoteCopy.SSHOpts)
sshOpts.SSHErrorsToTty = true
if remoteCopy.ConnectMode != sstore.ConnectModeManual && remoteCopy.SSHOpts.SSHPassword == "" && !interactive {
sshOpts.BatchMode = true
}
makeClientCtx, makeClientCancelFn := context.WithCancel(context.Background())
defer makeClientCancelFn()
msh.WithLock(func() {
msh.Err = nil
msh.ErrNoInitPk = false
msh.Status = StatusConnecting
msh.MakeClientCancelFn = makeClientCancelFn
deadlineTime := time.Now().Add(RemoteConnectTimeout)
msh.MakeClientDeadline = &deadlineTime
go msh.NotifyRemoteUpdate()
})
go msh.watchClientDeadlineTime()
var cmdStr string
var cproc *shexec.ClientProc
var initPk *packet.InitPacketType
if sshOpts.SSHHost == "" && remoteCopy.Local {
cmdStr, err = MakeLocalMShellCommandStr(remoteCopy.IsSudo())
if err != nil {
msh.WriteToPtyBuffer("*error, cannot find local mshell binary: %v\n", err)
return
}
ecmd := sshOpts.MakeSSHExecCmd(cmdStr, sapi)
var cmdPty *os.File
cmdPty, err = msh.addControllingTty(ecmd)
if err != nil {
statusErr := fmt.Errorf("cannot attach controlling tty to mshell command: %w", err)
msh.WriteToPtyBuffer("*error, %s\n", statusErr.Error())
msh.setErrorStatus(statusErr)
return
}
defer func() {
if len(ecmd.ExtraFiles) > 0 {
ecmd.ExtraFiles[len(ecmd.ExtraFiles)-1].Close()
}
}()
go msh.RunPtyReadLoop(cmdPty)
if remoteCopy.SSHOpts.SSHPassword != "" {
go msh.WaitAndSendPassword(remoteCopy.SSHOpts.SSHPassword)
}
cproc, initPk, err = shexec.MakeClientProc(makeClientCtx, shexec.CmdWrap{Cmd: ecmd})
} else {
var client *ssh.Client
client, err = ConnectToClient(remoteCopy.SSHOpts)
if err != nil {
statusErr := fmt.Errorf("ssh cannot connect to client: %w", err)
msh.WriteToPtyBuffer("*error, %s\n", statusErr.Error())
msh.setErrorStatus(statusErr)
return
}
var session *ssh.Session
session, err = client.NewSession()
if err != nil {
statusErr := fmt.Errorf("ssh cannot create session: %w", err)
msh.WriteToPtyBuffer("*error, %s\n", statusErr.Error())
msh.setErrorStatus(statusErr)
return
}
cproc, initPk, err = shexec.MakeClientProc(makeClientCtx, shexec.SessionWrap{Session: session, StartCmd: MakeServerRunOnlyCommandStr()})
}
// TODO check if initPk.State is not nil
var mshellVersion string
var hitDeadline bool
msh.WithLock(func() {
msh.MakeClientCancelFn = nil
if time.Now().After(*msh.MakeClientDeadline) {
hitDeadline = true
}
msh.MakeClientDeadline = nil
if initPk == nil {
msh.ErrNoInitPk = true
}
if initPk != nil {
msh.UName = initPk.UName
mshellVersion = initPk.Version
if semver.Compare(mshellVersion, scbase.MShellVersion) < 0 {
// only set NeedsMShellUpgrade if we got an InitPk
msh.NeedsMShellUpgrade = true
}
msh.InitPkShellType = initPk.Shell
}
msh.StateMap.Clear()
// no notify here, because we'll call notify in either case below
})
if err == context.Canceled {
if hitDeadline {
msh.WriteToPtyBuffer("*connect timeout\n")
msh.setErrorStatus(errors.New("connect timeout"))
} else {
msh.WriteToPtyBuffer("*forced disconnection\n")
msh.WithLock(func() {
msh.Status = StatusDisconnected
go msh.NotifyRemoteUpdate()
})
}
return
}
if err == nil && semver.MajorMinor(mshellVersion) != semver.MajorMinor(scbase.MShellVersion) {
err = fmt.Errorf("mshell version is not compatible current=%s remote=%s", scbase.MShellVersion, mshellVersion)
}
if err != nil {
msh.setErrorStatus(err)
msh.WriteToPtyBuffer("*error connecting to remote: %v\n", err)
go msh.tryAutoInstall()
return
}
msh.updateRemoteStateVars(context.Background(), msh.RemoteId, initPk)
msh.WithLock(func() {
msh.ServerProc = cproc
msh.Status = StatusConnected
})
go func() {
exitErr := cproc.Cmd.Wait()
exitCode := shexec.GetExitCode(exitErr)
msh.WithLock(func() {
if msh.Status == StatusConnected || msh.Status == StatusConnecting {
msh.Status = StatusDisconnected
go msh.NotifyRemoteUpdate()
}
})
msh.WriteToPtyBuffer("*disconnected exitcode=%d\n", exitCode)
}()
go msh.ProcessPackets()
msh.initActiveShells()
go msh.NotifyRemoteUpdate()
return
}
// for conditional launch method based on ssh library in use
// remove once ssh library is stabilized
type LegacyLauncher struct{}
// func (msh *MShellProc) LaunchLegacy(interactive bool) {
func (LegacyLauncher) Launch(msh *MShellProc, interactive bool) {
remoteCopy := msh.GetRemoteCopy()
if remoteCopy.Archived {
msh.WriteToPtyBuffer("cannot launch archived remote\n")
@ -1293,7 +1501,7 @@ func (msh *MShellProc) Launch(interactive bool) {
go msh.NotifyRemoteUpdate()
})
go msh.watchClientDeadlineTime()
cproc, initPk, err := shexec.MakeClientProc(makeClientCtx, ecmd)
cproc, initPk, err := shexec.MakeClientProc(makeClientCtx, shexec.CmdWrap{Cmd: ecmd})
// TODO check if initPk.State is not nil
var mshellVersion string
var hitDeadline bool

View File

@ -0,0 +1,108 @@
// Copyright 2024, Command Line Inc.
// SPDX-License-Identifier: Apache-2.0
package remote
import (
"errors"
"fmt"
"os"
"os/user"
"strconv"
"strings"
"github.com/kevinburke/ssh_config"
"github.com/wavetermdev/waveterm/waveshell/pkg/base"
"github.com/wavetermdev/waveterm/wavesrv/pkg/sstore"
"golang.org/x/crypto/ssh"
)
func createPublicKeyAuth(identityFile string, passphrase string) (ssh.AuthMethod, error) {
privateKey, err := os.ReadFile(base.ExpandHomeDir(identityFile))
if err != nil {
return nil, fmt.Errorf("failed to read ssh key file. err: %+v", err)
}
signer, err := ssh.ParsePrivateKey(privateKey)
if err != nil {
if errors.Is(err, &ssh.PassphraseMissingError{}) {
signer, err = ssh.ParsePrivateKeyWithPassphrase(privateKey, []byte(passphrase))
if err != nil {
return nil, fmt.Errorf("failed to parse private ssh key with passphrase. err: %+v", err)
}
} else {
return nil, fmt.Errorf("failed to parse private ssh key. err: %+v", err)
}
}
return ssh.PublicKeys(signer), nil
}
func createKeyboardInteractiveAuth(password string) ssh.AuthMethod {
challenge := func(name, instruction string, questions []string, echos []bool) (answers []string, err error) {
for _, q := range questions {
if strings.Contains(strings.ToLower(q), "password") {
answers = append(answers, password)
} else {
answers = append(answers, "")
}
}
return answers, nil
}
return ssh.KeyboardInteractive(challenge)
}
func ConnectToClient(opts *sstore.SSHOpts) (*ssh.Client, error) {
ssh_config.ReloadConfigs()
configIdentity, _ := ssh_config.GetStrict(opts.SSHHost, "IdentityFile")
var identityFile string
if opts.SSHIdentity != "" {
identityFile = opts.SSHIdentity
} else {
identityFile = configIdentity
}
hostKeyCallback := ssh.InsecureIgnoreHostKey()
var authMethods []ssh.AuthMethod
publicKeyAuth, err := createPublicKeyAuth(identityFile, opts.SSHPassword)
if err == nil {
authMethods = append(authMethods, publicKeyAuth)
}
authMethods = append(authMethods, createKeyboardInteractiveAuth(opts.SSHPassword))
authMethods = append(authMethods, ssh.Password(opts.SSHPassword))
configUser, _ := ssh_config.GetStrict(opts.SSHHost, "User")
configHostName, _ := ssh_config.GetStrict(opts.SSHHost, "HostName")
configPort, _ := ssh_config.GetStrict(opts.SSHHost, "Port")
var username string
if opts.SSHUser != "" {
username = opts.SSHUser
} else if configUser != "" {
username = configUser
} else {
user, err := user.Current()
if err != nil {
return nil, fmt.Errorf("failed to get user for ssh: %+v", err)
}
username = user.Username
}
var hostName string
if configHostName != "" {
hostName = configHostName
} else {
hostName = opts.SSHHost
}
clientConfig := &ssh.ClientConfig{
User: username,
Auth: authMethods,
HostKeyCallback: hostKeyCallback,
}
var port string
if opts.SSHPort != 0 && opts.SSHPort != 22 {
port = strconv.Itoa(opts.SSHPort)
} else if configPort != "" && configPort != "22" {
port = configPort
} else {
port = "22"
}
networkAddr := hostName + ":" + port
return ssh.Dial("tcp", networkAddr, clientConfig)
}