waveterm/waveshell/pkg/shexec/client.go
Sylvie Crowe 6c115716b0
Integrate SSH Library with Waveshell Installation/Auto-update (#322)
* refactor launch code to integrate install easier

The previous set up of launch was difficult to navigate. This makes it
much clearer which will make the auto install flow easier to manage.

* feat: integrate auto install into new ssh setup

This change makes it possible to auto install using the ssh library
instead of making a call to the ssh cli command. This will auto install
if the installed waveshell version is incorrect or cannot be found.

* chore: clean up some lints for sshclient

There was a context that didn't have it's cancel function deferred and
an error that wasn't being handle. They're fixed now.

* fix: disconnect client if requested or launch fail

A recent commit made it so a client remained part of the MShellProc
after being disconnected. This is undesireable since a manual
disconnection indicates that the user will need to enter their
credentials again if required. Similarly, if the launch fails with an
error, the expectation is that credentials will have to be entered
again.

* fix: use legacy timer for the time being

The legacy timer frustrates me because it adds a lot of state to the
MShellProc struct that is complicated to manage. But it currently works,
so I will be keeping it for the time being.

* fix: change separator between remoteref and name

With the inclusion of the port number in the canonical id, the :
separator between the remoteref and remote name causes problems if the
port is parsed instead. This changes it to a # in order to avoid this
conflict.

* fix: check for null when closing extra files

It is possible for the list of extra files to contain null files. This
change ensures the null files will not be erroneously closed.

* fix: change connecting method to show port once

With port added to the canonicalname, it no longer makes sense to append
the port afterward.

* feat: use user input modal for sudo connection

The sudo connection used to have a unique way of entering a password.
This change provides an alternative method using the user input modal
that the other connection methods use. It does not work perfectly with
this revision, but the basic building blocks are in place. It needs a
few timer updates to be complete.

* fix: remove old timer to prevent conflicts with it

With this change the old timer is no longer needed. It is not fully
removed yet, but it is disabled so as to not get in the way.
Additionally, error handling has been slightly improved.

There is still a bug where an incorrect password prints a new password
prompt after the error message. That needs to be fixed in the future.
2024-02-29 11:37:03 -08:00

282 lines
7.2 KiB
Go

// Copyright 2023, Command Line Inc.
// SPDX-License-Identifier: Apache-2.0
package shexec
import (
"context"
"fmt"
"io"
"os/exec"
"time"
"github.com/wavetermdev/waveterm/waveshell/pkg/base"
"github.com/wavetermdev/waveterm/waveshell/pkg/packet"
"golang.org/x/crypto/ssh"
"golang.org/x/mod/semver"
)
// TODO - track buffer sizes for sending input
const NotFoundVersion = "v0.0"
type CmdWrap struct {
Cmd *exec.Cmd
}
func (cw CmdWrap) Kill() {
cw.Cmd.Process.Kill()
}
func (cw CmdWrap) Wait() error {
return cw.Cmd.Wait()
}
func (cw CmdWrap) Sender() (*packet.PacketSender, io.WriteCloser, error) {
inputWriter, err := cw.Cmd.StdinPipe()
if err != nil {
return nil, nil, fmt.Errorf("creating stdin pipe: %v", err)
}
sender := packet.MakePacketSender(inputWriter, nil)
return sender, inputWriter, nil
}
func (cw CmdWrap) Parser() (*packet.PacketParser, io.ReadCloser, io.ReadCloser, error) {
stdoutReader, err := cw.Cmd.StdoutPipe()
if err != nil {
return nil, nil, nil, fmt.Errorf("creating stdout pipe: %v", err)
}
stderrReader, err := cw.Cmd.StderrPipe()
if err != nil {
return nil, nil, nil, fmt.Errorf("creating stderr pipe: %v", err)
}
stdoutPacketParser := packet.MakePacketParser(stdoutReader, &packet.PacketParserOpts{IgnoreUntilValid: true})
stderrPacketParser := packet.MakePacketParser(stderrReader, nil)
packetParser := packet.CombinePacketParsers(stdoutPacketParser, stderrPacketParser, true)
return packetParser, stdoutReader, stderrReader, nil
}
func (cw CmdWrap) Start() error {
defer func() {
for _, extraFile := range cw.Cmd.ExtraFiles {
if extraFile != nil {
extraFile.Close()
}
}
}()
return cw.Cmd.Start()
}
func (cw CmdWrap) StdinPipe() (io.WriteCloser, error) {
return cw.Cmd.StdinPipe()
}
func (cw CmdWrap) StdoutPipe() (io.ReadCloser, error) {
return cw.Cmd.StdoutPipe()
}
func (cw CmdWrap) StderrPipe() (io.ReadCloser, error) {
return cw.Cmd.StderrPipe()
}
type SessionWrap struct {
Session *ssh.Session
StartCmd string
}
func (sw SessionWrap) Kill() {
sw.Session.Close()
}
func (sw SessionWrap) Wait() error {
return sw.Session.Wait()
}
func (sw SessionWrap) Start() error {
return sw.Session.Start(sw.StartCmd)
}
func (sw SessionWrap) Sender() (*packet.PacketSender, io.WriteCloser, error) {
inputWriter, err := sw.Session.StdinPipe()
if err != nil {
return nil, nil, fmt.Errorf("creating stdin pipe: %v", err)
}
sender := packet.MakePacketSender(inputWriter, nil)
return sender, inputWriter, nil
}
func (sw SessionWrap) Parser() (*packet.PacketParser, io.ReadCloser, io.ReadCloser, error) {
stdoutReader, err := sw.Session.StdoutPipe()
if err != nil {
return nil, nil, nil, fmt.Errorf("creating stdout pipe: %v", err)
}
stderrReader, err := sw.Session.StderrPipe()
if err != nil {
return nil, nil, nil, fmt.Errorf("creating stderr pipe: %v", err)
}
stdoutPacketParser := packet.MakePacketParser(stdoutReader, &packet.PacketParserOpts{IgnoreUntilValid: true})
stderrPacketParser := packet.MakePacketParser(stderrReader, nil)
packetParser := packet.CombinePacketParsers(stdoutPacketParser, stderrPacketParser, true)
return packetParser, io.NopCloser(stdoutReader), io.NopCloser(stderrReader), nil
}
func (sw SessionWrap) StdinPipe() (io.WriteCloser, error) {
return sw.Session.StdinPipe()
}
func (sw SessionWrap) StdoutPipe() (io.ReadCloser, error) {
stdoutReader, err := sw.Session.StdoutPipe()
if err != nil {
return nil, err
}
return io.NopCloser(stdoutReader), nil
}
func (sw SessionWrap) StderrPipe() (io.ReadCloser, error) {
stderrReader, err := sw.Session.StderrPipe()
if err != nil {
return nil, err
}
return io.NopCloser(stderrReader), nil
}
type ConnInterface interface {
Kill()
Wait() error
Sender() (*packet.PacketSender, io.WriteCloser, error)
Parser() (*packet.PacketParser, io.ReadCloser, io.ReadCloser, error)
Start() error
StdinPipe() (io.WriteCloser, error)
StdoutPipe() (io.ReadCloser, error)
StderrPipe() (io.ReadCloser, error)
}
type ClientProc struct {
Cmd ConnInterface
InitPk *packet.InitPacketType
StartTs time.Time
StdinWriter io.WriteCloser
StdoutReader io.ReadCloser
StderrReader io.ReadCloser
Input *packet.PacketSender
Output *packet.PacketParser
}
type WaveshellLaunchError struct {
InitPk *packet.InitPacketType
}
func (wle WaveshellLaunchError) Error() string {
if wle.InitPk.NotFound {
return "waveshell client not found"
} else if semver.MajorMinor(wle.InitPk.Version) != semver.MajorMinor(base.MShellVersion) {
return fmt.Sprintf("invalid remote waveshell version '%s', must be '=%s'", wle.InitPk.Version, semver.MajorMinor(base.MShellVersion))
}
return fmt.Sprintf("invalid waveshell: init packet=%v", *wle.InitPk)
}
type InvalidPacketError struct {
InvalidPk *packet.PacketType
}
func (ipe InvalidPacketError) Error() string {
if ipe.InvalidPk == nil {
return "no init packet received from waveshell client"
}
return fmt.Sprintf("invalid packet received from waveshell client: %s", packet.AsString(*ipe.InvalidPk))
}
// returns (clientproc, initpk, error)
func MakeClientProc(ctx context.Context, ecmd ConnInterface) (*ClientProc, error) {
startTs := time.Now()
sender, inputWriter, err := ecmd.Sender()
if err != nil {
return nil, err
}
packetParser, stdoutReader, stderrReader, err := ecmd.Parser()
if err != nil {
return nil, err
}
err = ecmd.Start()
if err != nil {
return nil, fmt.Errorf("running local client: %w", err)
}
cproc := &ClientProc{
Cmd: ecmd,
StartTs: startTs,
StdinWriter: inputWriter,
StdoutReader: stdoutReader,
StderrReader: stderrReader,
Input: sender,
Output: packetParser,
}
var pk packet.PacketType
select {
case pk = <-packetParser.MainCh:
case <-ctx.Done():
cproc.Close()
return nil, ctx.Err()
}
if pk == nil {
cproc.Close()
return nil, InvalidPacketError{}
}
if pk.GetType() != packet.InitPacketStr {
cproc.Close()
return nil, InvalidPacketError{InvalidPk: &pk}
}
initPk := pk.(*packet.InitPacketType)
if initPk.NotFound {
cproc.Close()
return nil, WaveshellLaunchError{InitPk: initPk}
}
if semver.MajorMinor(initPk.Version) != semver.MajorMinor(base.MShellVersion) {
cproc.Close()
return nil, WaveshellLaunchError{InitPk: initPk}
}
cproc.InitPk = initPk
return cproc, nil
}
func (cproc *ClientProc) Close() {
if cproc.Input != nil {
cproc.Input.Close()
}
if cproc.StdinWriter != nil {
cproc.StdinWriter.Close()
}
if cproc.StdoutReader != nil {
cproc.StdoutReader.Close()
}
if cproc.StderrReader != nil {
cproc.StderrReader.Close()
}
if cproc.Cmd != nil {
cproc.Cmd.Kill()
}
}
func (cproc *ClientProc) ProxySingleOutput(ck base.CommandKey, sender *packet.PacketSender, packetCallback func(packet.PacketType)) {
sentDonePk := false
for pk := range cproc.Output.MainCh {
if packetCallback != nil {
packetCallback(pk)
}
if pk.GetType() == packet.CmdDonePacketStr {
sentDonePk = true
}
sender.SendPacket(pk)
}
exitErr := cproc.Cmd.Wait()
if !sentDonePk {
endTs := time.Now()
cmdDuration := endTs.Sub(cproc.StartTs)
donePacket := packet.MakeCmdDonePacket(ck)
donePacket.Ts = endTs.UnixMilli()
donePacket.ExitCode = GetExitCode(exitErr)
donePacket.DurationMs = int64(cmdDuration / time.Millisecond)
sender.SendPacket(donePacket)
}
}