got basic mshell --server functionality working to dispatch multiple commands

This commit is contained in:
sawka 2022-06-28 19:01:33 -07:00
parent 1d44afc10e
commit 9054c3cdcc
5 changed files with 120 additions and 84 deletions

View File

@ -80,34 +80,34 @@ func doSingle(ck base.CommandKey) {
func doMainRun(pk *packet.RunPacketType, sender *packet.PacketSender) {
err := shexec.ValidateRunPacket(pk)
if err != nil {
sender.SendPacket(packet.MakeCKErrorPacket(pk.CK, fmt.Sprintf("invalid run packet: %v", err)))
sender.SendCKErrorPacket(pk.CK, fmt.Sprintf("invalid run packet: %v", err))
return
}
fileNames, err := base.GetCommandFileNames(pk.CK)
if err != nil {
sender.SendPacket(packet.MakeCKErrorPacket(pk.CK, fmt.Sprintf("cannot get command file names: %v", err)))
sender.SendCKErrorPacket(pk.CK, fmt.Sprintf("cannot get command file names: %v", err))
return
}
cmd, err := shexec.MakeRunnerExec(pk.CK)
if err != nil {
sender.SendPacket(packet.MakeCKErrorPacket(pk.CK, fmt.Sprintf("cannot make mshell command: %v", err)))
sender.SendCKErrorPacket(pk.CK, fmt.Sprintf("cannot make mshell command: %v", err))
return
}
cmdStdin, err := cmd.StdinPipe()
if err != nil {
sender.SendPacket(packet.MakeCKErrorPacket(pk.CK, fmt.Sprintf("cannot pipe stdin to command: %v", err)))
sender.SendCKErrorPacket(pk.CK, fmt.Sprintf("cannot pipe stdin to command: %v", err))
return
}
// touch ptyout file (should exist for tailer to work correctly)
ptyOutFd, err := os.OpenFile(fileNames.PtyOutFile, os.O_CREATE|os.O_TRUNC|os.O_APPEND|os.O_WRONLY, 0600)
if err != nil {
sender.SendPacket(packet.MakeCKErrorPacket(pk.CK, fmt.Sprintf("cannot open pty out file '%s': %v", fileNames.PtyOutFile, err)))
sender.SendCKErrorPacket(pk.CK, fmt.Sprintf("cannot open pty out file '%s': %v", fileNames.PtyOutFile, err))
return
}
ptyOutFd.Close() // just opened to create the file, can close right after
runnerOutFd, err := os.OpenFile(fileNames.RunnerOutFile, os.O_CREATE|os.O_TRUNC|os.O_APPEND|os.O_WRONLY, 0600)
if err != nil {
sender.SendPacket(packet.MakeCKErrorPacket(pk.CK, fmt.Sprintf("cannot open runner out file '%s': %v", fileNames.RunnerOutFile, err)))
sender.SendCKErrorPacket(pk.CK, fmt.Sprintf("cannot open runner out file '%s': %v", fileNames.RunnerOutFile, err))
return
}
defer runnerOutFd.Close()
@ -115,13 +115,13 @@ func doMainRun(pk *packet.RunPacketType, sender *packet.PacketSender) {
cmd.Stderr = runnerOutFd
err = cmd.Start()
if err != nil {
sender.SendPacket(packet.MakeCKErrorPacket(pk.CK, fmt.Sprintf("error starting command: %v", err)))
sender.SendCKErrorPacket(pk.CK, fmt.Sprintf("error starting command: %v", err))
return
}
go func() {
err = packet.SendPacket(cmdStdin, pk)
if err != nil {
sender.SendPacket(packet.MakeCKErrorPacket(pk.CK, fmt.Sprintf("error sending forked runner command: %v", err)))
sender.SendCKErrorPacket(pk.CK, fmt.Sprintf("error sending forked runner command: %v", err))
return
}
cmdStdin.Close()
@ -237,11 +237,6 @@ func handleSingle() {
runPacket, _ = pk.(*packet.RunPacketType)
break
}
if pk.GetType() == packet.RawPacketStr {
rawPk := pk.(*packet.RawPacketType)
sender.SendMessage("got raw packet '%s'", rawPk.Data)
continue
}
sender.SendErrorPacket(fmt.Sprintf("invalid packet '%s' sent to mshell", pk.GetType()))
return
}
@ -251,7 +246,7 @@ func handleSingle() {
}
cmd, err := shexec.RunCommand(runPacket, sender)
if err != nil {
sender.SendErrorPacket(fmt.Sprintf("error running command: %v", err))
sender.SendCKErrorPacket(runPacket.CK, fmt.Sprintf("error running command: %v", err))
return
}
defer cmd.Close()
@ -433,7 +428,7 @@ func handleClient() (int, error) {
if err != nil {
return 1, err
}
donePacket, err := shexec.RunClientSSHCommandAndWait(runPacket, shexec.StdContext{}, opts.SSHOpts, opts.Debug)
donePacket, err := shexec.RunClientSSHCommandAndWait(runPacket, shexec.StdContext{}, opts.SSHOpts, nil, opts.Debug)
if err != nil {
return 1, err
}

View File

@ -31,16 +31,21 @@ type Multiplexer struct {
Sender *packet.PacketSender
Input *packet.PacketParser
Started bool
UPR packet.UnknownPacketReporter
Debug bool
}
func MakeMultiplexer(ck base.CommandKey) *Multiplexer {
func MakeMultiplexer(ck base.CommandKey, upr packet.UnknownPacketReporter) *Multiplexer {
if upr == nil {
upr = packet.DefaultUPR{}
}
return &Multiplexer{
Lock: &sync.Mutex{},
CK: ck,
FdReaders: make(map[int]*FdReader),
FdWriters: make(map[int]*FdWriter),
UPR: upr,
}
}
@ -207,17 +212,7 @@ func (m *Multiplexer) runPacketInputLoop() *packet.CmdDonePacketType {
donePacket := pk.(*packet.CmdDonePacketType)
return donePacket
}
if pk.GetType() == packet.ErrorPacketStr {
errPacket := pk.(*packet.ErrorPacketType)
// at this point, just send the error packet to stderr rather than try to do something special
fmt.Fprintf(os.Stderr, "%s\n", errPacket.Error)
return nil
}
if pk.GetType() == packet.RawPacketStr {
rawPacket := pk.(*packet.RawPacketType)
fmt.Fprintf(os.Stderr, "%s\n", rawPacket.Data)
continue
}
m.UPR.UnknownPacket(pk)
}
return nil
}

View File

@ -11,6 +11,7 @@ import (
"encoding/json"
"fmt"
"io"
"os"
"reflect"
"sync"
@ -609,33 +610,30 @@ func (sender *PacketSender) SendErrorPacket(errVal string) error {
return sender.SendPacket(MakeErrorPacket(errVal))
}
func (sender *PacketSender) SendCKErrorPacket(ck base.CommandKey, errVal string) error {
return sender.SendPacket(MakeCKErrorPacket(ck, errVal))
}
func (sender *PacketSender) SendMessage(fmtStr string, args ...interface{}) error {
return sender.SendPacket(MakeMessagePacket(fmt.Sprintf(fmtStr, args...)))
}
type ErrorReporter interface {
ReportError(err error)
type UnknownPacketReporter interface {
UnknownPacket(pk PacketType)
}
func PacketToByteArrBridge(pkCh chan PacketType, byteCh chan []byte, errorReporter ErrorReporter, closeOnDone bool) {
go func() {
defer func() {
if closeOnDone {
close(byteCh)
type DefaultUPR struct{}
func (DefaultUPR) UnknownPacket(pk PacketType) {
if pk.GetType() == ErrorPacketStr {
errPacket := pk.(*ErrorPacketType)
// at this point, just send the error packet to stderr rather than try to do something special
fmt.Fprintf(os.Stderr, "[error] %s\n", errPacket.Error)
} else if pk.GetType() == RawPacketStr {
rawPacket := pk.(*RawPacketType)
fmt.Fprintf(os.Stderr, "%s\n", rawPacket.Data)
} else {
fmt.Fprintf(os.Stderr, "[error] invalid packet received '%s'", AsExtType(pk))
}
}()
for pk := range pkCh {
if pk == nil {
continue
}
jsonBytes, err := json.Marshal(pk)
if err != nil {
if errorReporter != nil {
errorReporter.ReportError(fmt.Errorf("error marshaling packet: %w", err))
}
continue
}
byteCh <- jsonBytes
}
}()
}

View File

@ -22,7 +22,8 @@ type MServer struct {
Lock *sync.Mutex
MainInput *packet.PacketParser
Sender *packet.PacketSender
FdContext *serverFdContext
FdContextMap map[base.CommandKey]*serverFdContext
Debug bool
}
func (m *MServer) Close() {
@ -38,17 +39,6 @@ type serverFdContext struct {
Readers map[int]*mpio.PacketReader
}
func (m *MServer) MakeServerFdContext(ck base.CommandKey) *serverFdContext {
rtn := &serverFdContext{
M: m,
Lock: &sync.Mutex{},
Sender: m.Sender,
CK: ck,
Readers: make(map[int]*mpio.PacketReader),
}
return rtn
}
func (c *serverFdContext) processDataPacket(pk *packet.DataPacketType) {
c.Lock.Lock()
reader := c.Readers[pk.FdNum]
@ -62,8 +52,44 @@ func (c *serverFdContext) processDataPacket(pk *packet.DataPacketType) {
return
}
reader.AddData(pk)
}
func (m *MServer) MakeServerFdContext(ck base.CommandKey) *serverFdContext {
rtn := &serverFdContext{
M: m,
Lock: &sync.Mutex{},
Sender: m.Sender,
CK: ck,
Readers: make(map[int]*mpio.PacketReader),
}
return rtn
}
func (m *MServer) ProcessCommandPacket(pk packet.CommandPacketType) {
ck := pk.GetCK()
if ck == "" {
m.Sender.SendErrorPacket(fmt.Sprintf("received '%s' packet without ck", pk.GetType()))
return
}
m.Lock.Lock()
fdContext := m.FdContextMap[ck]
m.Lock.Unlock()
if fdContext == nil {
m.Sender.SendCKErrorPacket(ck, fmt.Sprintf("no server context for ck '%s'", ck))
return
}
if pk.GetType() == packet.DataPacketStr {
dataPacket := pk.(*packet.DataPacketType)
fdContext.processDataPacket(dataPacket)
return
} else if pk.GetType() == packet.DataAckPacketStr {
m.Sender.SendPacket(pk)
return
} else {
m.Sender.SendCKErrorPacket(ck, fmt.Sprintf("invalid packet '%s' received", packet.AsExtType(pk)))
return
}
}
func (c *serverFdContext) GetWriter(fdNum int) io.WriteCloser {
return mpio.MakePacketWriter(fdNum, c.Sender, c.CK)
@ -78,21 +104,42 @@ func (c *serverFdContext) GetReader(fdNum int) io.ReadCloser {
}
func (m *MServer) runCommand(runPacket *packet.RunPacketType) {
if err := runPacket.CK.Validate("packet"); err != nil {
m.Sender.SendErrorPacket(fmt.Sprintf("server run packets require valid ck: %s", err))
return
}
fdContext := m.MakeServerFdContext(runPacket.CK)
m.Lock.Lock()
m.FdContext = fdContext
m.FdContextMap[runPacket.CK] = fdContext
m.Lock.Unlock()
go func() {
donePk, err := shexec.RunClientSSHCommandAndWait(runPacket, fdContext, shexec.SSHOpts{}, true)
fmt.Printf("done: err:%v, %v\n", err, donePk)
donePk, err := shexec.RunClientSSHCommandAndWait(runPacket, fdContext, shexec.SSHOpts{}, m, m.Debug)
if donePk != nil {
m.Sender.SendPacket(donePk)
}
if err != nil {
m.Sender.SendCKErrorPacket(runPacket.CK, err.Error())
}
}()
}
func (m *MServer) UnknownPacket(pk packet.PacketType) {
m.Sender.SendPacket(pk)
}
func RunServer() (int, error) {
debug := false
if len(os.Args) >= 3 && os.Args[2] == "--debug" {
debug = true
}
server := &MServer{
Lock: &sync.Mutex{},
FdContextMap: make(map[base.CommandKey]*serverFdContext),
Debug: debug,
}
if debug {
packet.GlobalDebug = true
}
server.MainInput = packet.MakePacketParser(os.Stdin)
server.Sender = packet.MakePacketSender(os.Stdout)
defer server.Close()
@ -101,7 +148,9 @@ func RunServer() (int, error) {
initPacket.Version = base.MShellVersion
server.Sender.SendPacket(initPacket)
for pk := range server.MainInput.MainCh {
if server.Debug {
fmt.Printf("PK> %s\n", packet.AsString(pk))
}
if pk.GetType() == packet.PingPacketStr {
continue
}
@ -110,9 +159,8 @@ func RunServer() (int, error) {
server.runCommand(runPacket)
continue
}
if pk.GetType() == packet.DataPacketStr {
dataPacket := pk.(*packet.DataPacketType)
server.FdContext.processDataPacket(dataPacket)
if cmdPk, ok := pk.(packet.CommandPacketType); ok {
server.ProcessCommandPacket(cmdPk)
continue
}
server.Sender.SendErrorPacket(fmt.Sprintf("invalid packet '%s' sent to mshell", packet.AsExtType(pk)))

View File

@ -99,12 +99,12 @@ type FdContext interface {
GetReader(fdNum int) io.ReadCloser
}
func MakeShExec(ck base.CommandKey) *ShExecType {
func MakeShExec(ck base.CommandKey, upr packet.UnknownPacketReporter) *ShExecType {
return &ShExecType{
Lock: &sync.Mutex{},
StartTs: time.Now(),
CK: ck,
Multiplexer: mpio.MakeMultiplexer(ck),
Multiplexer: mpio.MakeMultiplexer(ck, upr),
}
}
@ -524,8 +524,8 @@ func HasDupStdin(fds []packet.RemoteFd) bool {
return false
}
func RunClientSSHCommandAndWait(runPacket *packet.RunPacketType, fdContext FdContext, sshOpts SSHOpts, debug bool) (*packet.CmdDonePacketType, error) {
cmd := MakeShExec("")
func RunClientSSHCommandAndWait(runPacket *packet.RunPacketType, fdContext FdContext, sshOpts SSHOpts, upr packet.UnknownPacketReporter, debug bool) (*packet.CmdDonePacketType, error) {
cmd := MakeShExec(runPacket.CK, upr)
ecmd := sshOpts.MakeSSHExecCmd(ClientCommand)
cmd.Cmd = ecmd
inputWriter, err := ecmd.StdinPipe()
@ -656,7 +656,7 @@ func (cmd *ShExecType) RunRemoteIOAndWait(packetParser *packet.PacketParser, sen
}
func runCommandSimple(pk *packet.RunPacketType, sender *packet.PacketSender) (*ShExecType, error) {
cmd := MakeShExec(pk.CK)
cmd := MakeShExec(pk.CK, nil)
cmd.Cmd = exec.Command("bash", "-c", pk.Command)
UpdateCmdEnv(cmd.Cmd, pk.Env)
if pk.Cwd != "" {
@ -736,7 +736,7 @@ func runCommandDetached(pk *packet.RunPacketType, sender *packet.PacketSender) (
defer func() {
cmdTty.Close()
}()
rtn := MakeShExec(pk.CK)
rtn := MakeShExec(pk.CK, nil)
ecmd := MakeExecCmd(pk, cmdTty)
err = ecmd.Start()
if err != nil {
@ -750,14 +750,14 @@ func runCommandDetached(pk *packet.RunPacketType, sender *packet.PacketSender) (
// copy pty output to .ptyout file
_, copyErr := io.Copy(ptyOutFd, cmdPty)
if copyErr != nil {
sender.SendErrorPacket(fmt.Sprintf("copying pty output to ptyout file: %v", copyErr))
sender.SendCKErrorPacket(pk.CK, fmt.Sprintf("copying pty output to ptyout file: %v", copyErr))
}
}()
go func() {
// copy .stdin fifo contents to pty input
copyFifoErr := MakeAndCopyStdinFifo(cmdPty, fileNames.StdinFifo)
if copyFifoErr != nil {
sender.SendErrorPacket(fmt.Sprintf("reading from stdin fifo: %v", copyFifoErr))
sender.SendCKErrorPacket(pk.CK, fmt.Sprintf("reading from stdin fifo: %v", copyFifoErr))
}
}()
rtn.FileNames = fileNames