mirror of
https://github.com/wavetermdev/waveterm.git
synced 2025-01-22 21:42:49 +01:00
018bb14b6a
* create proof of concept ssh library integration This is a first attempt to integrate the golang crypto/ssh library for handling remote connections. As it stands, this feature is limited to identity files without passphrases. It needs to be expanded to include key+passphrase and password verifications as well. * add password and keyboard-interactive ssh auth This adds several new ssh auth methods. In addition to the PublicKey method used previously, this adds password authentication, keyboard-interactive authentication, and PublicKey+Passphrase authentication. Furthermore, it refactors the ssh connection code into its own wavesrv file rather than storing it in waveshell's shexec file. * clean up old mshell launch methods While debugging the addition of the ssh library, I had several versions of the MShellProc Launch function. Since this seems mostly stable, I have removed the old version and the experimental version in favor of the combined version. * allow switching between new and old ssh for dev It is inconvenient to create milestones without being able to merge into the main branch. But due to the experimental nature of the ssh changes, it is not desired to use these changes in the main branch yet. This change disables the new ssh launcher by default. It can be used by changing the UseSshLibrary constant to true in remote.go. With this, it becomes possible to merge these changes into the main branch without them being used in production. * fix: allow retry after ssh auth failure Previously, the error status was not set when an ssh connection failed. Because of this, an ssh connection failure would lock the failed remote until waveterm was rebooted. This fix properly sets the error status so this cannot happen.
218 lines
5.8 KiB
Go
218 lines
5.8 KiB
Go
// Copyright 2023, Command Line Inc.
|
|
// SPDX-License-Identifier: Apache-2.0
|
|
|
|
package shexec
|
|
|
|
import (
|
|
"context"
|
|
"fmt"
|
|
"io"
|
|
"os/exec"
|
|
"time"
|
|
|
|
"github.com/wavetermdev/waveterm/waveshell/pkg/base"
|
|
"github.com/wavetermdev/waveterm/waveshell/pkg/packet"
|
|
"golang.org/x/crypto/ssh"
|
|
"golang.org/x/mod/semver"
|
|
)
|
|
|
|
// TODO - track buffer sizes for sending input
|
|
|
|
const (
	// NotFoundVersion is the placeholder version string "v0.0".
	// NOTE(review): presumably the version reported when no mshell client
	// is found on the remote; it is not referenced in this file, so confirm
	// against its callers.
	NotFoundVersion = "v0.0"
)
|
|
|
|
// CmdWrap adapts a local *exec.Cmd to ConnInterface: its methods expose the
// command's process lifecycle (Start/Wait/Kill) and its stdio pipes
// (Sender/Parser).
type CmdWrap struct {
	// Cmd is the underlying command that every CmdWrap method delegates to.
	Cmd *exec.Cmd
}
|
|
|
|
func (cw CmdWrap) Kill() {
|
|
cw.Cmd.Process.Kill()
|
|
}
|
|
|
|
func (cw CmdWrap) Wait() error {
|
|
return cw.Cmd.Wait()
|
|
}
|
|
|
|
func (cw CmdWrap) Sender() (*packet.PacketSender, io.WriteCloser, error) {
|
|
inputWriter, err := cw.Cmd.StdinPipe()
|
|
if err != nil {
|
|
return nil, nil, fmt.Errorf("creating stdin pipe: %v", err)
|
|
}
|
|
sender := packet.MakePacketSender(inputWriter, nil)
|
|
return sender, inputWriter, nil
|
|
}
|
|
|
|
func (cw CmdWrap) Parser() (*packet.PacketParser, io.ReadCloser, io.ReadCloser, error) {
|
|
stdoutReader, err := cw.Cmd.StdoutPipe()
|
|
if err != nil {
|
|
return nil, nil, nil, fmt.Errorf("creating stdout pipe: %v", err)
|
|
}
|
|
stderrReader, err := cw.Cmd.StderrPipe()
|
|
if err != nil {
|
|
return nil, nil, nil, fmt.Errorf("creating stderr pipe: %v", err)
|
|
}
|
|
stdoutPacketParser := packet.MakePacketParser(stdoutReader, &packet.PacketParserOpts{IgnoreUntilValid: true})
|
|
stderrPacketParser := packet.MakePacketParser(stderrReader, nil)
|
|
packetParser := packet.CombinePacketParsers(stdoutPacketParser, stderrPacketParser, true)
|
|
return packetParser, stdoutReader, stderrReader, nil
|
|
}
|
|
|
|
func (cw CmdWrap) Start() error {
|
|
return cw.Cmd.Start()
|
|
}
|
|
|
|
type SessionWrap struct {
|
|
Session *ssh.Session
|
|
StartCmd string
|
|
}
|
|
|
|
func (sw SessionWrap) Kill() {
|
|
sw.Session.Close()
|
|
}
|
|
|
|
func (sw SessionWrap) Wait() error {
|
|
return sw.Session.Wait()
|
|
}
|
|
|
|
func (sw SessionWrap) Start() error {
|
|
return sw.Session.Start(sw.StartCmd)
|
|
}
|
|
|
|
func (sw SessionWrap) Sender() (*packet.PacketSender, io.WriteCloser, error) {
|
|
inputWriter, err := sw.Session.StdinPipe()
|
|
if err != nil {
|
|
return nil, nil, fmt.Errorf("creating stdin pipe: %v", err)
|
|
}
|
|
sender := packet.MakePacketSender(inputWriter, nil)
|
|
return sender, inputWriter, nil
|
|
}
|
|
|
|
func (sw SessionWrap) Parser() (*packet.PacketParser, io.ReadCloser, io.ReadCloser, error) {
|
|
stdoutReader, err := sw.Session.StdoutPipe()
|
|
if err != nil {
|
|
return nil, nil, nil, fmt.Errorf("creating stdout pipe: %v", err)
|
|
}
|
|
stderrReader, err := sw.Session.StderrPipe()
|
|
if err != nil {
|
|
return nil, nil, nil, fmt.Errorf("creating stderr pipe: %v", err)
|
|
}
|
|
stdoutPacketParser := packet.MakePacketParser(stdoutReader, &packet.PacketParserOpts{IgnoreUntilValid: true})
|
|
stderrPacketParser := packet.MakePacketParser(stderrReader, nil)
|
|
packetParser := packet.CombinePacketParsers(stdoutPacketParser, stderrPacketParser, true)
|
|
return packetParser, io.NopCloser(stdoutReader), io.NopCloser(stderrReader), nil
|
|
}
|
|
|
|
type ConnInterface interface {
|
|
Kill()
|
|
Wait() error
|
|
Sender() (*packet.PacketSender, io.WriteCloser, error)
|
|
Parser() (*packet.PacketParser, io.ReadCloser, io.ReadCloser, error)
|
|
Start() error
|
|
}
|
|
|
|
type ClientProc struct {
|
|
Cmd ConnInterface
|
|
InitPk *packet.InitPacketType
|
|
StartTs time.Time
|
|
StdinWriter io.WriteCloser
|
|
StdoutReader io.ReadCloser
|
|
StderrReader io.ReadCloser
|
|
Input *packet.PacketSender
|
|
Output *packet.PacketParser
|
|
}
|
|
|
|
// returns (clientproc, initpk, error)
|
|
func MakeClientProc(ctx context.Context, ecmd ConnInterface) (*ClientProc, *packet.InitPacketType, error) {
|
|
startTs := time.Now()
|
|
sender, inputWriter, err := ecmd.Sender()
|
|
if err != nil {
|
|
return nil, nil, err
|
|
}
|
|
packetParser, stdoutReader, stderrReader, err := ecmd.Parser()
|
|
if err != nil {
|
|
return nil, nil, err
|
|
}
|
|
err = ecmd.Start()
|
|
if err != nil {
|
|
return nil, nil, fmt.Errorf("running local client: %w", err)
|
|
}
|
|
cproc := &ClientProc{
|
|
Cmd: ecmd,
|
|
StartTs: startTs,
|
|
StdinWriter: inputWriter,
|
|
StdoutReader: stdoutReader,
|
|
StderrReader: stderrReader,
|
|
Input: sender,
|
|
Output: packetParser,
|
|
}
|
|
|
|
var pk packet.PacketType
|
|
select {
|
|
case pk = <-packetParser.MainCh:
|
|
case <-ctx.Done():
|
|
cproc.Close()
|
|
return nil, nil, ctx.Err()
|
|
}
|
|
if pk != nil {
|
|
if pk.GetType() != packet.InitPacketStr {
|
|
cproc.Close()
|
|
return nil, nil, fmt.Errorf("invalid packet received from mshell client: %s", packet.AsString(pk))
|
|
}
|
|
initPk := pk.(*packet.InitPacketType)
|
|
if initPk.NotFound {
|
|
cproc.Close()
|
|
return nil, initPk, fmt.Errorf("mshell client not found")
|
|
}
|
|
if semver.MajorMinor(initPk.Version) != semver.MajorMinor(base.MShellVersion) {
|
|
cproc.Close()
|
|
return nil, initPk, fmt.Errorf("invalid remote mshell version '%s', must be '=%s'", initPk.Version, semver.MajorMinor(base.MShellVersion))
|
|
}
|
|
cproc.InitPk = initPk
|
|
}
|
|
if cproc.InitPk == nil {
|
|
cproc.Close()
|
|
return nil, nil, fmt.Errorf("no init packet received from mshell client")
|
|
}
|
|
return cproc, cproc.InitPk, nil
|
|
}
|
|
|
|
func (cproc *ClientProc) Close() {
|
|
if cproc.Input != nil {
|
|
cproc.Input.Close()
|
|
}
|
|
if cproc.StdinWriter != nil {
|
|
cproc.StdinWriter.Close()
|
|
}
|
|
if cproc.StdoutReader != nil {
|
|
cproc.StdoutReader.Close()
|
|
}
|
|
if cproc.StderrReader != nil {
|
|
cproc.StderrReader.Close()
|
|
}
|
|
if cproc.Cmd != nil {
|
|
cproc.Cmd.Kill()
|
|
}
|
|
}
|
|
|
|
func (cproc *ClientProc) ProxySingleOutput(ck base.CommandKey, sender *packet.PacketSender, packetCallback func(packet.PacketType)) {
|
|
sentDonePk := false
|
|
for pk := range cproc.Output.MainCh {
|
|
if packetCallback != nil {
|
|
packetCallback(pk)
|
|
}
|
|
if pk.GetType() == packet.CmdDonePacketStr {
|
|
sentDonePk = true
|
|
}
|
|
sender.SendPacket(pk)
|
|
}
|
|
exitErr := cproc.Cmd.Wait()
|
|
if !sentDonePk {
|
|
endTs := time.Now()
|
|
cmdDuration := endTs.Sub(cproc.StartTs)
|
|
donePacket := packet.MakeCmdDonePacket(ck)
|
|
donePacket.Ts = endTs.UnixMilli()
|
|
donePacket.ExitCode = GetExitCode(exitErr)
|
|
donePacket.DurationMs = int64(cmdDuration / time.Millisecond)
|
|
sender.SendPacket(donePacket)
|
|
}
|
|
}
|