From d7eb2526f07d9a7518798d7610eb4c6936116451 Mon Sep 17 00:00:00 2001 From: sawka Date: Tue, 28 Jun 2022 15:04:08 -0700 Subject: [PATCH] refactor RunClientSSHCommandAndWait for server code --- main-mshell.go | 44 +++++++++++------- pkg/cmdtail/cmdtail.go | 12 ++--- pkg/packet/packet.go | 26 ++++++++--- pkg/server/server.go | 53 ++++++++++++++++++++++ pkg/shexec/shexec.go | 100 ++++++++++++++++------------------------- 5 files changed, 146 insertions(+), 89 deletions(-) create mode 100644 pkg/server/server.go diff --git a/main-mshell.go b/main-mshell.go index 103bf4511..d43aafd5f 100644 --- a/main-mshell.go +++ b/main-mshell.go @@ -19,6 +19,7 @@ import ( "github.com/scripthaus-dev/mshell/pkg/base" "github.com/scripthaus-dev/mshell/pkg/cmdtail" "github.com/scripthaus-dev/mshell/pkg/packet" + "github.com/scripthaus-dev/mshell/pkg/server" "github.com/scripthaus-dev/mshell/pkg/shexec" "golang.org/x/sys/unix" ) @@ -72,7 +73,7 @@ func doSingle(ck base.CommandKey) { sender.SendPacket(startPacket) donePacket := cmd.WaitForCommand() sender.SendPacket(donePacket) - sender.CloseSendCh() + sender.Close() sender.WaitForDone() } @@ -157,7 +158,7 @@ func doMain() { } packetParser := packet.MakePacketParser(os.Stdin) sender := packet.MakePacketSender(os.Stdout) - tailer, err := cmdtail.MakeTailer(sender.SendCh) + tailer, err := cmdtail.MakeTailer(sender) if err != nil { packet.SendErrorPacket(os.Stdout, err.Error()) return @@ -215,8 +216,8 @@ func handleSingle() { sender := packet.MakePacketSender(os.Stdout) defer func() { // wait for sender to complete - close(sender.SendCh) - <-sender.DoneCh + sender.Close() + sender.WaitForDone() }() if len(os.Args) >= 3 && os.Args[2] == "--version" { initPacket := packet.MakeInitPacket() @@ -259,9 +260,6 @@ func handleSingle() { cmd.RunRemoteIOAndWait(packetParser, sender) } -func handleServer() { -} - func detectOpenFds() ([]packet.RemoteFd, error) { var fds []packet.RemoteFd for fdNum := 3; fdNum <= 64; fdNum++ { @@ -309,7 +307,7 @@ func parseInstallOpts() (*shexec.InstallOpts, error) { return opts, nil } -func tryParseSSHOpt(iter *base.OptsIter, sshOpts *shexec.SharedSSHOpts) (bool, error) { +func tryParseSSHOpt(iter *base.OptsIter, sshOpts *shexec.SSHOpts) (bool, error) { argStr := iter.Current() if argStr == "--ssh" { if !iter.IsNextPlain() { @@ -378,7 +376,7 @@ func parseClientOpts() (*shexec.ClientOpts, error) { return nil, fmt.Errorf("'--sudo-with-password [pw]', missing password") } opts.Sudo = true - opts.SudoWithPass = true + opts.SSHOpts.SudoWithPass = true opts.SudoPw = iter.Next() continue } @@ -387,7 +385,7 @@ func parseClientOpts() (*shexec.ClientOpts, error) { return nil, fmt.Errorf("'--sudo-with-passfile [file]', missing file") } opts.Sudo = true - opts.SudoWithPass = true + opts.SSHOpts.SudoWithPass = true fileName := iter.Next() contents, err := os.ReadFile(fileName) if err != nil { @@ -427,7 +425,15 @@ func handleClient() (int, error) { return 1, err } opts.Fds = fds - donePacket, err := shexec.RunClientSSHCommandAndWait(opts) + err = shexec.ValidateRemoteFds(opts.Fds) + if err != nil { + return 1, err + } + runPacket, err := opts.MakeRunPacket() // modifies opts + if err != nil { + return 1, err + } + donePacket, err := shexec.RunClientSSHCommandAndWait(runPacket, opts.SSHOpts, opts.Debug) if err != nil { return 1, err } @@ -530,21 +536,29 @@ func main() { handleSingle() return } else if firstArg == "--server" { - handleServer() + rtnCode, err := server.RunServer() + if err != nil { + fmt.Fprintf(os.Stderr, "[error] %v\n", err) + } + if rtnCode != 0 { + os.Exit(rtnCode) + } return } else if firstArg == "--install" { rtnCode, err := handleInstall() if err != nil { - fmt.Printf("[error] %v\n", err) + fmt.Fprintf(os.Stderr, "[error] %v\n", err) } os.Exit(rtnCode) return } else { rtnCode, err := handleClient() if err != nil { - fmt.Printf("[error] %v\n", err) + fmt.Fprintf(os.Stderr, "[error] %v\n", err) + } + if rtnCode != 0 { + os.Exit(rtnCode) } - os.Exit(rtnCode) return } diff --git a/pkg/cmdtail/cmdtail.go b/pkg/cmdtail/cmdtail.go index 517b35f8c..663f8065a 100644 --- a/pkg/cmdtail/cmdtail.go +++ b/pkg/cmdtail/cmdtail.go @@ -77,7 +77,7 @@ type Tailer struct { WatchList map[base.CommandKey]CmdWatchEntry ScHomeDir string Watcher *fsnotify.Watcher - SendCh chan packet.PacketType + Sender *packet.PacketSender } func (t *Tailer) updateTailPos_nolock(cmdKey base.CommandKey, reqId string, pos TailPos) { @@ -129,7 +129,7 @@ func (t *Tailer) getEntryAndPos_nolock(cmdKey base.CommandKey, reqId string) (Cm return entry, pos, true } -func MakeTailer(sendCh chan packet.PacketType) (*Tailer, error) { +func MakeTailer(sender *packet.PacketSender) (*Tailer, error) { scHomeDir, err := base.GetScHomeDir() if err != nil { return nil, err @@ -138,7 +138,7 @@ func MakeTailer(sendCh chan packet.PacketType) (*Tailer, error) { Lock: &sync.Mutex{}, WatchList: make(map[base.CommandKey]CmdWatchEntry), ScHomeDir: scHomeDir, - SendCh: sendCh, + Sender: sender, } rtn.Watcher, err = fsnotify.NewWatcher() if err != nil { @@ -241,7 +241,7 @@ func (t *Tailer) RunDataTransfer(key base.CommandKey, reqId string) { for { dataPacket, keepRunning := t.runSingleDataTransfer(key, reqId) if dataPacket != nil { - t.SendCh <- dataPacket + t.Sender.SendPacket(dataPacket) } if !keepRunning { t.checkRemoveNoFollow(key, reqId) @@ -273,7 +273,7 @@ func (t *Tailer) updateFile(relFileName string) { } finfo, err := os.Stat(relFileName) if err != nil { - t.SendCh <- packet.FmtMessagePacket("error trying to stat file '%s': %v", relFileName, err) + t.Sender.SendPacket(packet.FmtMessagePacket("error trying to stat file '%s': %v", relFileName, err)) return } cmdKey := base.MakeCommandKey(m[1], m[2]) @@ -311,7 +311,7 @@ func (t *Tailer) Run() { return } // what to do with this error? just send a message - t.SendCh <- packet.FmtMessagePacket("error in tailer: %v", err) + t.Sender.SendPacket(packet.FmtMessagePacket("error in tailer: %v", err)) } } return diff --git a/pkg/packet/packet.go b/pkg/packet/packet.go index 01191ed39..07f58c194 100644 --- a/pkg/packet/packet.go +++ b/pkg/packet/packet.go @@ -483,6 +483,16 @@ type CommandPacketType interface { GetCK() base.CommandKey } +func AsExtType(pk PacketType) string { + if rpcPacket, ok := pk.(RpcPacketType); ok { + return fmt.Sprintf("%s[%s]", rpcPacket.GetType(), rpcPacket.GetPacketId()) + } else if cmdPacket, ok := pk.(CommandPacketType); ok { + return fmt.Sprintf("%s[%s]", cmdPacket.GetType(), cmdPacket.GetCK()) + } else { + return pk.GetType() + } +} + func ParseJsonPacket(jsonBuf []byte) (PacketType, error) { var bareCmd BarePacketType err := json.Unmarshal(jsonBuf, &bareCmd) @@ -545,12 +555,8 @@ func MakePacketSender(output io.Writer) *PacketSender { DoneCh: make(chan bool), } go func() { - defer func() { - sender.Lock.Lock() - sender.Done = true - sender.Lock.Unlock() - close(sender.DoneCh) - }() + defer close(sender.DoneCh) + defer sender.Close() for pk := range sender.SendCh { err := SendPacket(output, pk) if err != nil { @@ -564,7 +570,13 @@ func MakePacketSender(output io.Writer) *PacketSender { return sender } -func (sender *PacketSender) CloseSendCh() { +func (sender *PacketSender) Close() { + sender.Lock.Lock() + defer sender.Lock.Unlock() + if sender.Done { + return + } + sender.Done = true close(sender.SendCh) } diff --git a/pkg/server/server.go b/pkg/server/server.go new file mode 100644 index 000000000..61bd17601 --- /dev/null +++ b/pkg/server/server.go @@ -0,0 +1,53 @@ +// Copyright 2022 Dashborg Inc +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at https://mozilla.org/MPL/2.0/. + +package server + +import ( + "fmt" + "os" + "sync" + + "github.com/scripthaus-dev/mshell/pkg/base" + "github.com/scripthaus-dev/mshell/pkg/packet" +) + +type MServer struct { + Lock *sync.Mutex + MainInput *packet.PacketParser + Sender *packet.PacketSender +} + +func (m *MServer) Close() { + m.Sender.Close() + m.Sender.WaitForDone() +} + +func RunServer() (int, error) { + server := &MServer{ + Lock: &sync.Mutex{}, + } + server.MainInput = packet.MakePacketParser(os.Stdin) + server.Sender = packet.MakePacketSender(os.Stdout) + defer server.Close() + initPacket := packet.MakeInitPacket() + initPacket.Version = base.MShellVersion + server.Sender.SendPacket(initPacket) + for pk := range server.MainInput.MainCh { + fmt.Printf("PK> %s\n", packet.AsString(pk)) + if pk.GetType() == packet.PingPacketStr { + continue + } + if pk.GetType() == packet.RunPacketStr { + runPacket := pk.(*packet.RunPacketType) + fmt.Printf("RUN> %s\n", runPacket) + continue + } + server.Sender.SendErrorPacket(fmt.Sprintf("invalid packet '%s' sent to mshell", packet.AsExtType(pk))) + continue + } + return 0, nil +} diff --git a/pkg/shexec/shexec.go b/pkg/shexec/shexec.go index 9a45077bc..f0334e3f4 100644 --- a/pkg/shexec/shexec.go +++ b/pkg/shexec/shexec.go @@ -236,91 +236,74 @@ func RunCommand(pk *packet.RunPacketType, sender *packet.PacketSender) (*ShExecT } } -type SharedSSHOpts struct { - SSHHost string - SSHOptsStr string - SSHIdentity string - SSHUser string +type SSHOpts struct { + SSHHost string + SSHOptsStr string + SSHIdentity string + SSHUser string + SudoWithPass bool } type InstallOpts struct { - SSHOpts SharedSSHOpts + SSHOpts SSHOpts ArchStr string OptName string Detect bool } type ClientOpts struct { - SSHOpts SharedSSHOpts + SSHOpts SSHOpts Command string Fds []packet.RemoteFd Cwd string Debug bool Sudo bool - SudoWithPass bool SudoPw string CommandStdinFdNum int Detach bool } -func (opts *ClientOpts) MakeExecCmd() *exec.Cmd { - if opts.SSHOpts.SSHHost == "" { - ecmd := exec.Command("bash", "-c", strings.TrimSpace(ClientCommand)) +func (opts SSHOpts) MakeSSHExecCmd(remoteCommand string) *exec.Cmd { + remoteCommand = strings.TrimSpace(remoteCommand) + if opts.SSHHost == "" { + ecmd := exec.Command("bash", "-c", remoteCommand) return ecmd } else { var moreSSHOpts []string - if opts.SSHOpts.SSHIdentity != "" { - identityOpt := fmt.Sprintf("-i %s", shellescape.Quote(opts.SSHOpts.SSHIdentity)) + if opts.SSHIdentity != "" { + identityOpt := fmt.Sprintf("-i %s", shellescape.Quote(opts.SSHIdentity)) moreSSHOpts = append(moreSSHOpts, identityOpt) } - if opts.SSHOpts.SSHUser != "" { - userOpt := fmt.Sprintf("-l %s", shellescape.Quote(opts.SSHOpts.SSHUser)) + if opts.SSHUser != "" { + userOpt := fmt.Sprintf("-l %s", shellescape.Quote(opts.SSHUser)) moreSSHOpts = append(moreSSHOpts, userOpt) } - remoteCommand := strings.TrimSpace(ClientCommand) // note that SSHOptsStr is *not* escaped - sshCmd := fmt.Sprintf("ssh %s %s %s %s", strings.Join(moreSSHOpts, " "), opts.SSHOpts.SSHOptsStr, shellescape.Quote(opts.SSHOpts.SSHHost), shellescape.Quote(remoteCommand)) + sshCmd := fmt.Sprintf("ssh %s %s %s %s", strings.Join(moreSSHOpts, " "), opts.SSHOptsStr, shellescape.Quote(opts.SSHHost), shellescape.Quote(remoteCommand)) ecmd := exec.Command("bash", "-c", sshCmd) return ecmd } } -func (opts *InstallOpts) MakeExecCmd() *exec.Cmd { +func (opts SSHOpts) MakeMShellSSHOpts() string { var moreSSHOpts []string - if opts.SSHOpts.SSHIdentity != "" { - identityOpt := fmt.Sprintf("-i %s", shellescape.Quote(opts.SSHOpts.SSHIdentity)) + if opts.SSHIdentity != "" { + identityOpt := fmt.Sprintf("-i %s", shellescape.Quote(opts.SSHIdentity)) moreSSHOpts = append(moreSSHOpts, identityOpt) } - if opts.SSHOpts.SSHUser != "" { - userOpt := fmt.Sprintf("-l %s", shellescape.Quote(opts.SSHOpts.SSHUser)) + if opts.SSHUser != "" { + userOpt := fmt.Sprintf("-l %s", shellescape.Quote(opts.SSHUser)) moreSSHOpts = append(moreSSHOpts, userOpt) } - // note that SSHOptsStr is *not* escaped - command := strings.TrimSpace(InstallCommand) - sshCmd := fmt.Sprintf("ssh %s %s %s %s", strings.Join(moreSSHOpts, " "), opts.SSHOpts.SSHOptsStr, shellescape.Quote(opts.SSHOpts.SSHHost), shellescape.Quote(command)) - ecmd := exec.Command("bash", "-c", sshCmd) - return ecmd -} - -func (opts *ClientOpts) MakeInstallCommandString(goos string, goarch string) string { - var moreSSHOpts []string - if opts.SSHOpts.SSHIdentity != "" { - identityOpt := fmt.Sprintf("-i %s", shellescape.Quote(opts.SSHOpts.SSHIdentity)) - moreSSHOpts = append(moreSSHOpts, identityOpt) - } - if opts.SSHOpts.SSHUser != "" { - userOpt := fmt.Sprintf("-l %s", shellescape.Quote(opts.SSHOpts.SSHUser)) - moreSSHOpts = append(moreSSHOpts, userOpt) - } - if opts.SSHOpts.SSHOptsStr != "" { - optsOpt := fmt.Sprintf("--ssh-opts %s", shellescape.Quote(opts.SSHOpts.SSHOptsStr)) + if opts.SSHOptsStr != "" { + optsOpt := fmt.Sprintf("--ssh-opts %s", shellescape.Quote(opts.SSHOptsStr)) moreSSHOpts = append(moreSSHOpts, optsOpt) } - if opts.SSHOpts.SSHHost != "" { - sshArg := fmt.Sprintf("--ssh %s", shellescape.Quote(opts.SSHOpts.SSHHost)) + if opts.SSHHost != "" { + sshArg := fmt.Sprintf("--ssh %s", shellescape.Quote(opts.SSHHost)) moreSSHOpts = append(moreSSHOpts, sshArg) } - return fmt.Sprintf("mshell --install %s %s.%s", strings.Join(moreSSHOpts, " "), goos, goarch) + return strings.Join(moreSSHOpts, " ") } func (opts *ClientOpts) MakeRunPacket() (*packet.RunPacketType, error) { @@ -333,7 +316,7 @@ func (opts *ClientOpts) MakeRunPacket() (*packet.RunPacketType, error) { runPacket.Command = fmt.Sprintf(RunCommandFmt, opts.Command) return runPacket, nil } - if opts.SudoWithPass { + if opts.SSHOpts.SudoWithPass { pwFdNum, err := opts.NextFreeFdNum() if err != nil { return nil, err @@ -434,7 +417,7 @@ func sendOptFile(input io.WriteCloser, optName string) error { func RunInstallSSHCommand(opts *InstallOpts) error { tryDetect := opts.Detect - ecmd := opts.MakeExecCmd() + ecmd := opts.SSHOpts.MakeSSHExecCmd(InstallCommand) inputWriter, err := ecmd.StdinPipe() if err != nil { return fmt.Errorf("creating stdin pipe: %v", err) @@ -497,17 +480,9 @@ func RunInstallSSHCommand(opts *InstallOpts) error { return fmt.Errorf("did not receive version string from client, install not successful") } -func RunClientSSHCommandAndWait(opts *ClientOpts) (*packet.CmdDonePacketType, error) { - err := ValidateRemoteFds(opts.Fds) - if err != nil { - return nil, err - } - runPacket, err := opts.MakeRunPacket() // modifies opts - if err != nil { - return nil, err - } +func RunClientSSHCommandAndWait(runPacket *packet.RunPacketType, sshOpts SSHOpts, debug bool) (*packet.CmdDonePacketType, error) { cmd := MakeShExec("") - ecmd := opts.MakeExecCmd() + ecmd := sshOpts.MakeSSHExecCmd(ClientCommand) cmd.Cmd = ecmd inputWriter, err := ecmd.StdinPipe() if err != nil { @@ -521,7 +496,7 @@ func RunClientSSHCommandAndWait(opts *ClientOpts) (*packet.CmdDonePacketType, er if err != nil { return nil, fmt.Errorf("creating stderr pipe: %v", err) } - if !opts.SudoWithPass { + if !sshOpts.SudoWithPass { cmd.Multiplexer.MakeRawFdReader(0, os.Stdin, false) } cmd.Multiplexer.MakeRawFdWriter(1, os.Stdout, false) @@ -567,6 +542,9 @@ func RunClientSSHCommandAndWait(opts *ClientOpts) (*packet.CmdDonePacketType, er if pk.GetType() == packet.InitPacketStr { initPk := pk.(*packet.InitPacketType) if initPk.NotFound { + if sshOpts.SSHHost == "" { + return nil, fmt.Errorf("mshell command not found on local server") + } if initPk.UName == "" { return nil, fmt.Errorf("mshell command not found on remote server, no uname detected") } @@ -574,14 +552,14 @@ func RunClientSSHCommandAndWait(opts *ClientOpts) (*packet.CmdDonePacketType, er if err != nil { return nil, fmt.Errorf("mshell command not found on remote server, architecture cannot be detected (might be incompatible with mshell): %w", err) } - installCmd := opts.MakeInstallCommandString(goos, goarch) - return nil, fmt.Errorf("mshell command not found on remote server, can install with '%s' (or --auto-install)", installCmd) + sshOptsStr := sshOpts.MakeMShellSSHOpts() + return nil, fmt.Errorf("mshell command not found on remote server, can install with 'mshell --install %s %s.%s'", sshOptsStr, goos, goarch) } if initPk.Version != base.MShellVersion { return nil, fmt.Errorf("invalid remote mshell version 'v%s', must be v%s", initPk.Version, base.MShellVersion) } versionOk = true - if opts.Debug { + if debug { fmt.Printf("VERSION> %s\n", initPk.Version) } break @@ -591,7 +569,7 @@ func RunClientSSHCommandAndWait(opts *ClientOpts) (*packet.CmdDonePacketType, er return nil, fmt.Errorf("did not receive version from remote mshell") } sender.SendPacket(runPacket) - if opts.Debug { + if debug { cmd.Multiplexer.Debug = true } remoteDonePacket := cmd.Multiplexer.RunIOAndWait(packetParser, sender, false, true, true)