checkpoint -- cleanup and sync optimizations for remote client (basically working). beginning work on local client

This commit is contained in:
sawka 2022-06-24 00:02:18 -07:00
parent 52831dc723
commit 4256ff5231
6 changed files with 222 additions and 29 deletions

View File

@ -9,6 +9,7 @@ package main
import (
"fmt"
"os"
"os/exec"
"os/signal"
"os/user"
"strings"
@ -234,6 +235,11 @@ func handleRemote() {
runPacket, _ = pk.(*packet.RunPacketType)
break
}
if pk.GetType() == packet.RawPacketStr {
rawPk := pk.(*packet.RawPacketType)
sender.SendMessage("got raw packet '%s'", rawPk.Data)
continue
}
sender.SendErrorPacket(fmt.Sprintf("invalid packet '%s' sent to mshell", pk.GetType()))
return
}
@ -251,26 +257,128 @@ func handleRemote() {
func handleServer() {
}
func handleClient() {
fmt.Printf("mshell client\n")
func detectOpenFds() {
}
func handleUsage(extended bool) {
type ClientOpts struct {
IsSSH bool
SSHOptsTerm bool
SSHOpts []string
Command string
Fds []packet.RemoteFd
Cwd string
}
func parseClientOpts() (*ClientOpts, error) {
opts := &ClientOpts{}
iter := base.MakeOptsIter(os.Args[1:])
for iter.HasNext() {
argStr := iter.Next()
if argStr == "--ssh" {
if opts.IsSSH {
return nil, fmt.Errorf("duplicate '--ssh' option")
}
opts.IsSSH = true
break
}
}
if opts.IsSSH {
// parse SSH opts
for iter.HasNext() {
argStr := iter.Next()
if argStr == "--" {
opts.SSHOptsTerm = true
break
}
if argStr == "--cwd" {
if !iter.HasNext() {
return nil, fmt.Errorf("'--cwd [dir]' missing directory")
}
}
opts.SSHOpts = append(opts.SSHOpts, argStr)
}
if !opts.SSHOptsTerm {
return nil, fmt.Errorf("ssh options must be terminated with '--' followed by [command]")
}
if !iter.HasNext() {
return nil, fmt.Errorf("no command specified")
}
opts.Command = strings.Join(iter.Rest(), " ")
if strings.TrimSpace(opts.Command) == "" {
return nil, fmt.Errorf("no command or empty command specified")
}
}
return opts, nil
}
func handleClient() (int, error) {
fmt.Printf("mshell client\n")
opts, err := parseClientOpts()
if err != nil {
return 1, fmt.Errorf("parsing opts: %w", err)
}
if !opts.IsSSH {
return 1, fmt.Errorf("when running in client mode '--ssh' option must be present")
}
fmt.Printf("opts: %v\n", opts)
sshRemoteCommand := `PATH=$PATH:~/.mshell; mshell --remote`
sshOpts := append(opts.SSHOpts, sshRemoteCommand)
ecmd := exec.Command("ssh", sshOpts...)
inputWriter, err := ecmd.StdinPipe()
if err != nil {
return 1, fmt.Errorf("creating stdin pipe: %v", err)
}
outputReader, err := ecmd.StdoutPipe()
if err != nil {
return 1, fmt.Errorf("creating stdout pipe: %v", err)
}
ecmd.Stderr = ecmd.Stdout
err = ecmd.Start()
if err != nil {
return 1, fmt.Errorf("running ssh command: %w", err)
}
parser := packet.PacketParser(outputReader)
go func() {
fmt.Printf("%v %v\n", parser, inputWriter)
}()
exitErr := ecmd.Wait()
return shexec.GetExitCode(exitErr), nil
}
func handleUsage() {
usage := `
Client Usage: mshell [mshell-opts] [ssh-opts] user@host [command]
Client Usage: mshell [mshell-opts] --ssh [ssh-opts] user@host -- [command]
mshell multiplexes input and output streams to a remote command over ssh.
Options:
--env 'X=Y,A=B' - set remote environment variables for command, comma or newline separated
--env 'X=Y;A=B' - set remote environment variables for command, semicolon separated
--env-file [file] - load environment variables from [file] (.env format)
--env-copy [glob] - copy local environment variables to remote using [glob] pattern
--cwd [dir] - execute remote command in [dir]
--no-auto-fds - do not auto-detect additional fds
--sudo - execute "sudo [command]"
--fds [fdspec] - open fds based off [fdspec], comma separated (implies --no-auto-fds)
<[num] opens for reading
>[num] opens for writing
e.g. --fds '<5,>6,>7'
[command] - a single argument (should be quoted)
Examples:
# execute a python script remotely, with stdin still hooked up correctly
mshell --cwd "~/work" --ssh -i key.pem ubuntu@somehost -- "python /dev/fd/4" 4< myscript.py
# capture multiple outputs
mshell --ssh ubuntu@test -- "cat file1.txt > /dev/fd/3; cat file2.txt > /dev/fd/4" 3> file1.txt 4> file2.txt
# environment variable copying, setting working directory
# note the single quotes on command (otherwise the local shell will expand the variables)
TEST1=hello TEST2=world mshell --cwd "~/work" --env-copy "TEST*" --ssh user@host -- 'echo $(pwd) $TEST1 $TEST2'
# execute a script, catpure stdout/stderr in fd-3 and fd-4
# useful if you need to see stdout for interacting with ssh (password or host auth)
mshell --ssh user@host -- "test.sh > /dev/fd/3 2> /dev/fd/4" 3> test.stdout 4> test.stderr
mshell is licensed under the MPLv2
Please see https://github.com/scripthaus-dev/mshell for extended usage modes, source code, bugs, and feature requests
@ -280,12 +388,12 @@ Please see https://github.com/scripthaus-dev/mshell for extended usage modes, so
func main() {
if len(os.Args) == 1 {
handleUsage(false)
handleUsage()
return
}
firstArg := os.Args[1]
if firstArg == "--help" {
handleUsage(true)
handleUsage()
return
} else if firstArg == "--version" {
fmt.Printf("mshell v%s\n", MShellVersion)
@ -297,7 +405,11 @@ func main() {
handleServer()
return
} else {
handleClient()
rtnCode, err := handleClient()
if err != nil {
fmt.Printf("[error] %v\n", err)
}
os.Exit(rtnCode)
return
}

39
pkg/base/optsiter.go Normal file
View File

@ -0,0 +1,39 @@
// Copyright 2022 Dashborg Inc
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this
// file, You can obtain one at https://mozilla.org/MPL/2.0/.
package base
import "strings"
type OptsIter struct {
Pos int
Opts []string
}
func MakeOptsIter(opts []string) *OptsIter {
return &OptsIter{Opts: opts}
}
func IsOption(argStr string) bool {
return strings.HasPrefix(argStr, "-") && argStr != "-" && !strings.HasPrefix(argStr, "-/")
}
func (iter *OptsIter) HasNext() bool {
return iter.Pos <= len(iter.Opts)-1
}
func (iter *OptsIter) Next() string {
if iter.Pos >= len(iter.Opts) {
return ""
}
rtn := iter.Opts[iter.Pos]
iter.Pos++
return rtn
}
func (iter *OptsIter) Rest() []string {
return iter.Opts[iter.Pos:]
}

View File

@ -43,6 +43,8 @@ const (
InputPacketStr = "input"
)
const PacketSenderQueueSize = 20
var TypeStrToFactory map[string]reflect.Type
func init() {
@ -450,7 +452,7 @@ type PacketSender struct {
func MakePacketSender(output io.Writer) *PacketSender {
sender := &PacketSender{
Lock: &sync.Mutex{},
SendCh: make(chan PacketType),
SendCh: make(chan PacketType, PacketSenderQueueSize),
DoneCh: make(chan bool),
}
go func() {

View File

@ -50,6 +50,9 @@ func (r *FdReader) Close() {
func (r *FdReader) NotifyAck(ackLen int) {
r.CVar.L.Lock()
defer r.CVar.L.Unlock()
if r.Closed {
return
}
r.BufSize -= ackLen
if r.BufSize < 0 {
r.BufSize = 0
@ -57,11 +60,17 @@ func (r *FdReader) NotifyAck(ackLen int) {
r.CVar.Broadcast()
}
// !! inverse locking. must already hold the lock when you call this method.
// will *unlock*, send the packet, and then *relock* once it is done.
// this can prevent an unlikely deadlock where we are holding r.CVar.L and stuck on sender.SendCh
func (r *FdReader) sendPacket_unlock(sender *packet.PacketSender, pk packet.PacketType) {
r.CVar.L.Unlock()
defer r.CVar.L.Lock()
sender.SendPacket(pk)
}
// returns (success)
func (r *FdReader) WriteWait(sender *packet.PacketSender, data []byte, isEof bool) bool {
if len(data) == 0 {
return true
}
r.CVar.L.Lock()
defer r.CVar.L.Unlock()
for {
@ -75,13 +84,15 @@ func (r *FdReader) WriteWait(sender *packet.PacketSender, data []byte, isEof boo
}
writeLen := min(bufAvail, len(data))
pk := r.MakeDataPacket(data[0:writeLen], nil)
sender.SendPacket(pk)
pk.Eof = isEof && (writeLen == len(data))
r.BufSize += writeLen
data = data[writeLen:]
r.sendPacket_unlock(sender, pk)
if len(data) == 0 {
return true
}
r.CVar.Wait()
// do *not* do a CVar.Wait() here -- because we *unlocked* to send the packet, we should
// recheck the condition before waiting to avoid deadlock.
}
}
@ -104,12 +115,21 @@ func (r *FdReader) MakeDataPacket(data []byte, err error) *packet.DataPacketType
return pk
}
func (r *FdReader) isClosed() bool {
r.CVar.L.Lock()
defer r.CVar.L.Unlock()
return r.Closed
}
func (r *FdReader) ReadLoop(wg *sync.WaitGroup, sender *packet.PacketSender) {
defer r.Close()
defer wg.Done()
buf := make([]byte, 4096)
for {
nr, err := r.Fd.Read(buf)
if r.isClosed() {
return // should not send data or error if we already closed the fd
}
if nr > 0 || err == io.EOF {
isOpen := r.WriteWait(sender, buf[0:nr], (err == io.EOF))
if !isOpen {

View File

@ -14,6 +14,8 @@ import (
"github.com/scripthaus-dev/mshell/pkg/packet"
)
const MaxSingleWriteSize = 4 * 1024
type FdWriter struct {
CVar *sync.Cond
SessionId string
@ -97,16 +99,20 @@ func (w *FdWriter) WriteLoop(sender *packet.PacketSender) {
defer w.Close()
for {
data, isEof := w.WaitForData()
if w.Closed {
return
}
if len(data) > 0 {
nw, err := w.Fd.Write(data)
// chunk the writes to make sure we send ample ack packets
for len(data) > 0 {
if w.Closed {
return
}
chunkSize := min(len(data), MaxSingleWriteSize)
chunk := data[0:chunkSize]
nw, err := w.Fd.Write(chunk)
ack := w.MakeDataAckPacket(nw, err)
sender.SendPacket(ack)
if err != nil {
return
}
data = data[chunkSize:]
}
if isEof {
return

View File

@ -270,7 +270,7 @@ func (cmd *ShExecType) launchWriters(sender *packet.PacketSender) {
}
}
func (cmd *ShExecType) writeDataPacket(dataPacket *packet.DataPacketType) error {
func (cmd *ShExecType) processDataPacket(dataPacket *packet.DataPacketType) error {
cmd.Lock.Lock()
defer cmd.Lock.Unlock()
fw := cmd.FdWriters[dataPacket.FdNum]
@ -289,18 +289,32 @@ func (cmd *ShExecType) writeDataPacket(dataPacket *packet.DataPacketType) error
return nil
}
func (cmd *ShExecType) runMainWriteLoop(packetCh chan packet.PacketType, sender *packet.PacketSender) {
func (cmd *ShExecType) processAckPacket(ackPacket *packet.DataAckPacketType) {
cmd.Lock.Lock()
defer cmd.Lock.Unlock()
fr := cmd.FdReaders[ackPacket.FdNum]
if fr == nil {
return
}
fr.NotifyAck(ackPacket.AckLen)
}
func (cmd *ShExecType) runPacketInputLoop(packetCh chan packet.PacketType, sender *packet.PacketSender) {
for pk := range packetCh {
if pk.GetType() != packet.DataPacketStr {
// other packets are ignored
if pk.GetType() == packet.DataPacketStr {
dataPacket := pk.(*packet.DataPacketType)
err := cmd.processDataPacket(dataPacket)
if err != nil {
errPacket := cmd.MakeDataAckPacket(dataPacket.FdNum, 0, err)
sender.SendPacket(errPacket)
}
continue
}
dataPacket := pk.(*packet.DataPacketType)
err := cmd.writeDataPacket(dataPacket)
if err != nil {
errPacket := cmd.MakeDataAckPacket(dataPacket.FdNum, 0, err)
sender.SendPacket(errPacket)
if pk.GetType() == packet.DataAckPacketStr {
ackPacket := pk.(*packet.DataAckPacketType)
cmd.processAckPacket(ackPacket)
}
// other packet types are ignored
}
}
@ -317,7 +331,7 @@ func (cmd *ShExecType) RunIOAndWait(packetCh chan packet.PacketType, sender *pac
var wg sync.WaitGroup
cmd.launchReaders(&wg, sender)
cmd.launchWriters(sender)
go cmd.runMainWriteLoop(packetCh, sender)
go cmd.runPacketInputLoop(packetCh, sender)
donePacket := cmd.WaitForCommand()
wg.Wait()
sender.SendPacket(donePacket)