waveterm/waveshell/pkg/shexec/client.go
Sylvie Crowe 018bb14b6a
Use ssh library for remote connections (#250)
* create proof of concept ssh library integration

This is a first attempt to integrate the golang crypto/ssh library for
handling remote connections. As it stands, this feature is limited to
identity files without passphrases. It needs to be expanded to include
key+passphrase and password verifications as well.

* add password and keyboard-interactive ssh auth

This adds several new ssh auth methods. In addition to the PublicKey
method used previously, this adds password authentication,
keyboard-interactive authentication, and PublicKey+Passphrase
authentication.

Furthermore, it refactors the ssh connection code into its own wavesrv
file rather than storing it in waveshell's shexec file.

* clean up old mshell launch methods

While debugging the addition of the ssh library, I had several versions
of the MShellProc Launch function. Since this seems mostly stable, I
have removed the old version and the experimental version in favor of
the combined version.

* allow switching between new and old ssh for dev

It is inconvenient to create milestones without being able to merge into
the main branch. But due to the experimental nature of the ssh changes,
it is not desired to use these changes in the main branch yet. This
change disables the new ssh launcher by default. It can be used by
changing the UseSshLibrary constant to true in remote.go. With this, it
becomes possible to merge these changes into the main branch without
them being used in production.

* fix: allow retry after ssh auth failure

Previously, the error status was not set when an ssh connection failed.
Because of this, an ssh connection failure would lock the failed remote
until waveterm was rebooted. This fix properly sets the error status so
this cannot happen.
2024-01-25 10:18:11 -08:00

218 lines
5.8 KiB
Go

// Copyright 2023, Command Line Inc.
// SPDX-License-Identifier: Apache-2.0
package shexec
import (
"context"
"fmt"
"io"
"os/exec"
"time"
"github.com/wavetermdev/waveterm/waveshell/pkg/base"
"github.com/wavetermdev/waveterm/waveshell/pkg/packet"
"golang.org/x/crypto/ssh"
"golang.org/x/mod/semver"
)
// TODO - track buffer sizes for sending input

// NotFoundVersion is the fixed version string "v0.0".
// NOTE(review): presumably used as a placeholder version when no mshell
// client is found — it is not referenced in this file; confirm against callers.
const NotFoundVersion = "v0.0"
// CmdWrap adapts a local *exec.Cmd so it can be driven through the same
// Kill/Wait/Start/Sender/Parser method set as a remote session.
type CmdWrap struct {
	Cmd *exec.Cmd
}

// Kill forcibly terminates the wrapped process.
//
// It is a no-op when there is no command or the command was never started:
// exec.Cmd.Process is only populated by a successful Start, so calling Kill
// from a cleanup path before that point must not dereference it.
func (cw CmdWrap) Kill() {
	if cw.Cmd == nil || cw.Cmd.Process == nil {
		return
	}
	// Best effort: Kill has no error return, and a failure here normally just
	// means the process has already exited.
	_ = cw.Cmd.Process.Kill()
}

// Wait blocks until the wrapped command exits and returns the result of
// exec.Cmd.Wait unchanged.
func (cw CmdWrap) Wait() error {
	return cw.Cmd.Wait()
}
// Sender opens the command's stdin pipe and wraps it in a packet sender.
//
// It returns the sender together with the underlying pipe so the caller can
// close both. If the pipe cannot be created, both results are nil and the
// returned error wraps the cause.
func (cw CmdWrap) Sender() (*packet.PacketSender, io.WriteCloser, error) {
	inputWriter, err := cw.Cmd.StdinPipe()
	if err != nil {
		// %w (rather than %v) keeps the cause reachable via errors.Is/As; the
		// rendered message is unchanged.
		return nil, nil, fmt.Errorf("creating stdin pipe: %w", err)
	}
	return packet.MakePacketSender(inputWriter, nil), inputWriter, nil
}
// Parser opens the command's stdout and stderr pipes and builds a combined
// packet parser over both streams.
//
// It returns the combined parser plus the two raw pipe readers so the caller
// can close them. The stdout parser is created with IgnoreUntilValid set; the
// stderr parser uses default options. On failure all results are nil and the
// returned error wraps the cause.
func (cw CmdWrap) Parser() (*packet.PacketParser, io.ReadCloser, io.ReadCloser, error) {
	stdoutReader, err := cw.Cmd.StdoutPipe()
	if err != nil {
		return nil, nil, nil, fmt.Errorf("creating stdout pipe: %w", err)
	}
	stderrReader, err := cw.Cmd.StderrPipe()
	if err != nil {
		// The stdout read end is already open and is not handed back to the
		// caller on this path, so release it here instead of stranding it.
		_ = stdoutReader.Close()
		return nil, nil, nil, fmt.Errorf("creating stderr pipe: %w", err)
	}
	stdoutPacketParser := packet.MakePacketParser(stdoutReader, &packet.PacketParserOpts{IgnoreUntilValid: true})
	stderrPacketParser := packet.MakePacketParser(stderrReader, nil)
	packetParser := packet.CombinePacketParsers(stdoutPacketParser, stderrPacketParser, true)
	return packetParser, stdoutReader, stderrReader, nil
}
// Start launches the wrapped command without waiting for it to finish and
// hands back whatever exec.Cmd.Start reports.
func (cw CmdWrap) Start() error {
	startErr := cw.Cmd.Start()
	return startErr
}
// SessionWrap adapts an ssh session to the same method set as CmdWrap.
// StartCmd is the command line handed to Session.Start.
type SessionWrap struct {
	Session  *ssh.Session
	StartCmd string
}

// Kill tears the connection down by closing the ssh session. The error from
// Close is intentionally discarded.
func (sw SessionWrap) Kill() {
	_ = sw.Session.Close()
}

// Wait blocks on the session and returns the result of Session.Wait as-is.
func (sw SessionWrap) Wait() error {
	waitErr := sw.Session.Wait()
	return waitErr
}

// Start runs StartCmd on the session via Session.Start and returns its error.
func (sw SessionWrap) Start() error {
	startErr := sw.Session.Start(sw.StartCmd)
	return startErr
}
// Sender opens the session's stdin pipe and wraps it in a packet sender.
//
// It returns the sender together with the underlying pipe so the caller can
// close both. If the pipe cannot be created, both results are nil and the
// returned error wraps the cause.
func (sw SessionWrap) Sender() (*packet.PacketSender, io.WriteCloser, error) {
	inputWriter, err := sw.Session.StdinPipe()
	if err != nil {
		// %w (rather than %v) keeps the cause reachable via errors.Is/As; the
		// rendered message is unchanged.
		return nil, nil, fmt.Errorf("creating stdin pipe: %w", err)
	}
	return packet.MakePacketSender(inputWriter, nil), inputWriter, nil
}
// Parser opens the session's stdout and stderr pipes and builds a combined
// packet parser over both streams.
//
// The returned readers are wrapped in io.NopCloser, so calling Close on them
// does nothing; the streams end when the session itself is closed (see Kill).
// The stdout parser is created with IgnoreUntilValid set; the stderr parser
// uses default options. On failure all results are nil and the returned error
// wraps the cause.
func (sw SessionWrap) Parser() (*packet.PacketParser, io.ReadCloser, io.ReadCloser, error) {
	stdoutReader, err := sw.Session.StdoutPipe()
	if err != nil {
		// %w keeps the cause reachable via errors.Is/As; message text is unchanged.
		return nil, nil, nil, fmt.Errorf("creating stdout pipe: %w", err)
	}
	stderrReader, err := sw.Session.StderrPipe()
	if err != nil {
		return nil, nil, nil, fmt.Errorf("creating stderr pipe: %w", err)
	}
	stdoutPacketParser := packet.MakePacketParser(stdoutReader, &packet.PacketParserOpts{IgnoreUntilValid: true})
	stderrPacketParser := packet.MakePacketParser(stderrReader, nil)
	packetParser := packet.CombinePacketParsers(stdoutPacketParser, stderrPacketParser, true)
	return packetParser, io.NopCloser(stdoutReader), io.NopCloser(stderrReader), nil
}
// ConnInterface is the set of operations MakeClientProc and ClientProc need
// from a client connection. In this file it is implemented by CmdWrap (local
// exec.Cmd) and SessionWrap (ssh session).
type ConnInterface interface {
// Kill terminates the underlying process or session.
Kill()
// Wait blocks until the underlying process or session finishes.
Wait() error
// Sender returns a packet sender over the connection's stdin, plus the raw
// stdin writer. MakeClientProc calls it before Start.
Sender() (*packet.PacketSender, io.WriteCloser, error)
// Parser returns a combined packet parser over stdout and stderr, plus the
// stdout and stderr readers (in that order). MakeClientProc calls it before
// Start.
Parser() (*packet.PacketParser, io.ReadCloser, io.ReadCloser, error)
// Start launches the underlying process or session command.
Start() error
}
// ClientProc bundles a started client connection with the pipes and packet
// endpoints created for it by MakeClientProc. Close releases all of them.
type ClientProc struct {
// Cmd is the underlying connection; Close calls its Kill method.
Cmd ConnInterface
// InitPk is the validated init packet received from the client.
InitPk *packet.InitPacketType
// StartTs is taken at the beginning of MakeClientProc, before the client
// is started; ProxySingleOutput uses it to compute the command duration.
StartTs time.Time
// StdinWriter, StdoutReader and StderrReader are the raw streams returned
// by the connection's Sender and Parser methods.
StdinWriter io.WriteCloser
StdoutReader io.ReadCloser
StderrReader io.ReadCloser
// Input sends packets to the client (over StdinWriter).
Input *packet.PacketSender
// Output yields packets parsed from the client's stdout and stderr.
Output *packet.PacketParser
}
// MakeClientProc wires up the packet sender and parser for ecmd, starts it,
// and waits for the first packet on the parser's main channel, which must be
// a valid init packet.
//
// returns (clientproc, initpk, error):
//   - on success: the running ClientProc and its init packet;
//   - if the client reports NotFound, or its major.minor version differs from
//     base.MShellVersion: a nil ClientProc, the received init packet (so the
//     caller can inspect it), and an error;
//   - on any other failure: nil, nil and an error (ctx.Err() when ctx finishes
//     before a packet arrives).
//
// Every failure path closes the sender and pipes created here. The connection
// is only killed if Start succeeded.
func MakeClientProc(ctx context.Context, ecmd ConnInterface) (*ClientProc, *packet.InitPacketType, error) {
	// Cmd is deliberately left unset until Start succeeds: Close skips Kill
	// when Cmd is nil, so early-failure cleanup releases the pipes without
	// trying to kill something that was never launched.
	cproc := &ClientProc{StartTs: time.Now()}
	var err error
	cproc.Input, cproc.StdinWriter, err = ecmd.Sender()
	if err != nil {
		return nil, nil, err
	}
	cproc.Output, cproc.StdoutReader, cproc.StderrReader, err = ecmd.Parser()
	if err != nil {
		// The sender and stdin pipe already exist; do not leak them.
		cproc.Close()
		return nil, nil, err
	}
	if err = ecmd.Start(); err != nil {
		cproc.Close()
		return nil, nil, fmt.Errorf("running local client: %w", err)
	}
	cproc.Cmd = ecmd

	var pk packet.PacketType
	select {
	case pk = <-cproc.Output.MainCh:
	case <-ctx.Done():
		cproc.Close()
		return nil, nil, ctx.Err()
	}
	if pk == nil {
		// A nil packet means MainCh was closed before anything arrived.
		cproc.Close()
		return nil, nil, fmt.Errorf("no init packet received from mshell client")
	}
	if pk.GetType() != packet.InitPacketStr {
		cproc.Close()
		return nil, nil, fmt.Errorf("invalid packet received from mshell client: %s", packet.AsString(pk))
	}
	// Checked assertion: a packet that claims the init type but has another
	// concrete type is reported as invalid instead of panicking.
	initPk, ok := pk.(*packet.InitPacketType)
	if !ok {
		cproc.Close()
		return nil, nil, fmt.Errorf("invalid packet received from mshell client: %s", packet.AsString(pk))
	}
	if initPk.NotFound {
		cproc.Close()
		return nil, initPk, fmt.Errorf("mshell client not found")
	}
	if semver.MajorMinor(initPk.Version) != semver.MajorMinor(base.MShellVersion) {
		cproc.Close()
		return nil, initPk, fmt.Errorf("invalid remote mshell version '%s', must be '=%s'", initPk.Version, semver.MajorMinor(base.MShellVersion))
	}
	cproc.InitPk = initPk
	return cproc, initPk, nil
}
// Close releases everything the ClientProc holds, in a fixed order: the
// packet sender, then the stdin/stdout/stderr streams, and finally the
// connection itself via Kill. Unset members are skipped and close errors are
// ignored.
func (cproc *ClientProc) Close() {
	if cproc.Input != nil {
		cproc.Input.Close()
	}
	streams := []io.Closer{cproc.StdinWriter, cproc.StdoutReader, cproc.StderrReader}
	for _, stream := range streams {
		if stream == nil {
			continue
		}
		stream.Close()
	}
	if cproc.Cmd == nil {
		return
	}
	cproc.Cmd.Kill()
}
// ProxySingleOutput relays every packet from the client's output channel to
// sender until that channel is closed, invoking packetCallback (when non-nil)
// on each packet before it is forwarded. It then waits for the connection to
// finish. If the client never emitted a cmd-done packet, one is synthesized
// for ck carrying the current time, the exit code derived from Wait's error,
// and the elapsed time since StartTs in milliseconds.
func (cproc *ClientProc) ProxySingleOutput(ck base.CommandKey, sender *packet.PacketSender, packetCallback func(packet.PacketType)) {
	sawDone := false
	for pk := range cproc.Output.MainCh {
		if packetCallback != nil {
			packetCallback(pk)
		}
		switch pk.GetType() {
		case packet.CmdDonePacketStr:
			sawDone = true
		}
		sender.SendPacket(pk)
	}
	exitErr := cproc.Cmd.Wait()
	if sawDone {
		// The client already reported completion; nothing to synthesize.
		return
	}
	finished := time.Now()
	donePk := packet.MakeCmdDonePacket(ck)
	donePk.Ts = finished.UnixMilli()
	donePk.ExitCode = GetExitCode(exitErr)
	donePk.DurationMs = finished.Sub(cproc.StartTs).Milliseconds()
	sender.SendPacket(donePk)
}