got basic mshell --server functionality working to dispatch multiple commands

This commit is contained in:
sawka 2022-06-28 19:01:33 -07:00
parent 1d44afc10e
commit 9054c3cdcc
5 changed files with 120 additions and 84 deletions

View File

@ -80,34 +80,34 @@ func doSingle(ck base.CommandKey) {
func doMainRun(pk *packet.RunPacketType, sender *packet.PacketSender) { func doMainRun(pk *packet.RunPacketType, sender *packet.PacketSender) {
err := shexec.ValidateRunPacket(pk) err := shexec.ValidateRunPacket(pk)
if err != nil { if err != nil {
sender.SendPacket(packet.MakeCKErrorPacket(pk.CK, fmt.Sprintf("invalid run packet: %v", err))) sender.SendCKErrorPacket(pk.CK, fmt.Sprintf("invalid run packet: %v", err))
return return
} }
fileNames, err := base.GetCommandFileNames(pk.CK) fileNames, err := base.GetCommandFileNames(pk.CK)
if err != nil { if err != nil {
sender.SendPacket(packet.MakeCKErrorPacket(pk.CK, fmt.Sprintf("cannot get command file names: %v", err))) sender.SendCKErrorPacket(pk.CK, fmt.Sprintf("cannot get command file names: %v", err))
return return
} }
cmd, err := shexec.MakeRunnerExec(pk.CK) cmd, err := shexec.MakeRunnerExec(pk.CK)
if err != nil { if err != nil {
sender.SendPacket(packet.MakeCKErrorPacket(pk.CK, fmt.Sprintf("cannot make mshell command: %v", err))) sender.SendCKErrorPacket(pk.CK, fmt.Sprintf("cannot make mshell command: %v", err))
return return
} }
cmdStdin, err := cmd.StdinPipe() cmdStdin, err := cmd.StdinPipe()
if err != nil { if err != nil {
sender.SendPacket(packet.MakeCKErrorPacket(pk.CK, fmt.Sprintf("cannot pipe stdin to command: %v", err))) sender.SendCKErrorPacket(pk.CK, fmt.Sprintf("cannot pipe stdin to command: %v", err))
return return
} }
// touch ptyout file (should exist for tailer to work correctly) // touch ptyout file (should exist for tailer to work correctly)
ptyOutFd, err := os.OpenFile(fileNames.PtyOutFile, os.O_CREATE|os.O_TRUNC|os.O_APPEND|os.O_WRONLY, 0600) ptyOutFd, err := os.OpenFile(fileNames.PtyOutFile, os.O_CREATE|os.O_TRUNC|os.O_APPEND|os.O_WRONLY, 0600)
if err != nil { if err != nil {
sender.SendPacket(packet.MakeCKErrorPacket(pk.CK, fmt.Sprintf("cannot open pty out file '%s': %v", fileNames.PtyOutFile, err))) sender.SendCKErrorPacket(pk.CK, fmt.Sprintf("cannot open pty out file '%s': %v", fileNames.PtyOutFile, err))
return return
} }
ptyOutFd.Close() // just opened to create the file, can close right after ptyOutFd.Close() // just opened to create the file, can close right after
runnerOutFd, err := os.OpenFile(fileNames.RunnerOutFile, os.O_CREATE|os.O_TRUNC|os.O_APPEND|os.O_WRONLY, 0600) runnerOutFd, err := os.OpenFile(fileNames.RunnerOutFile, os.O_CREATE|os.O_TRUNC|os.O_APPEND|os.O_WRONLY, 0600)
if err != nil { if err != nil {
sender.SendPacket(packet.MakeCKErrorPacket(pk.CK, fmt.Sprintf("cannot open runner out file '%s': %v", fileNames.RunnerOutFile, err))) sender.SendCKErrorPacket(pk.CK, fmt.Sprintf("cannot open runner out file '%s': %v", fileNames.RunnerOutFile, err))
return return
} }
defer runnerOutFd.Close() defer runnerOutFd.Close()
@ -115,13 +115,13 @@ func doMainRun(pk *packet.RunPacketType, sender *packet.PacketSender) {
cmd.Stderr = runnerOutFd cmd.Stderr = runnerOutFd
err = cmd.Start() err = cmd.Start()
if err != nil { if err != nil {
sender.SendPacket(packet.MakeCKErrorPacket(pk.CK, fmt.Sprintf("error starting command: %v", err))) sender.SendCKErrorPacket(pk.CK, fmt.Sprintf("error starting command: %v", err))
return return
} }
go func() { go func() {
err = packet.SendPacket(cmdStdin, pk) err = packet.SendPacket(cmdStdin, pk)
if err != nil { if err != nil {
sender.SendPacket(packet.MakeCKErrorPacket(pk.CK, fmt.Sprintf("error sending forked runner command: %v", err))) sender.SendCKErrorPacket(pk.CK, fmt.Sprintf("error sending forked runner command: %v", err))
return return
} }
cmdStdin.Close() cmdStdin.Close()
@ -237,11 +237,6 @@ func handleSingle() {
runPacket, _ = pk.(*packet.RunPacketType) runPacket, _ = pk.(*packet.RunPacketType)
break break
} }
if pk.GetType() == packet.RawPacketStr {
rawPk := pk.(*packet.RawPacketType)
sender.SendMessage("got raw packet '%s'", rawPk.Data)
continue
}
sender.SendErrorPacket(fmt.Sprintf("invalid packet '%s' sent to mshell", pk.GetType())) sender.SendErrorPacket(fmt.Sprintf("invalid packet '%s' sent to mshell", pk.GetType()))
return return
} }
@ -251,7 +246,7 @@ func handleSingle() {
} }
cmd, err := shexec.RunCommand(runPacket, sender) cmd, err := shexec.RunCommand(runPacket, sender)
if err != nil { if err != nil {
sender.SendErrorPacket(fmt.Sprintf("error running command: %v", err)) sender.SendCKErrorPacket(runPacket.CK, fmt.Sprintf("error running command: %v", err))
return return
} }
defer cmd.Close() defer cmd.Close()
@ -433,7 +428,7 @@ func handleClient() (int, error) {
if err != nil { if err != nil {
return 1, err return 1, err
} }
donePacket, err := shexec.RunClientSSHCommandAndWait(runPacket, shexec.StdContext{}, opts.SSHOpts, opts.Debug) donePacket, err := shexec.RunClientSSHCommandAndWait(runPacket, shexec.StdContext{}, opts.SSHOpts, nil, opts.Debug)
if err != nil { if err != nil {
return 1, err return 1, err
} }

View File

@ -31,16 +31,21 @@ type Multiplexer struct {
Sender *packet.PacketSender Sender *packet.PacketSender
Input *packet.PacketParser Input *packet.PacketParser
Started bool Started bool
UPR packet.UnknownPacketReporter
Debug bool Debug bool
} }
func MakeMultiplexer(ck base.CommandKey) *Multiplexer { func MakeMultiplexer(ck base.CommandKey, upr packet.UnknownPacketReporter) *Multiplexer {
if upr == nil {
upr = packet.DefaultUPR{}
}
return &Multiplexer{ return &Multiplexer{
Lock: &sync.Mutex{}, Lock: &sync.Mutex{},
CK: ck, CK: ck,
FdReaders: make(map[int]*FdReader), FdReaders: make(map[int]*FdReader),
FdWriters: make(map[int]*FdWriter), FdWriters: make(map[int]*FdWriter),
UPR: upr,
} }
} }
@ -207,17 +212,7 @@ func (m *Multiplexer) runPacketInputLoop() *packet.CmdDonePacketType {
donePacket := pk.(*packet.CmdDonePacketType) donePacket := pk.(*packet.CmdDonePacketType)
return donePacket return donePacket
} }
if pk.GetType() == packet.ErrorPacketStr { m.UPR.UnknownPacket(pk)
errPacket := pk.(*packet.ErrorPacketType)
// at this point, just send the error packet to stderr rather than try to do something special
fmt.Fprintf(os.Stderr, "%s\n", errPacket.Error)
return nil
}
if pk.GetType() == packet.RawPacketStr {
rawPacket := pk.(*packet.RawPacketType)
fmt.Fprintf(os.Stderr, "%s\n", rawPacket.Data)
continue
}
} }
return nil return nil
} }

View File

@ -11,6 +11,7 @@ import (
"encoding/json" "encoding/json"
"fmt" "fmt"
"io" "io"
"os"
"reflect" "reflect"
"sync" "sync"
@ -609,33 +610,30 @@ func (sender *PacketSender) SendErrorPacket(errVal string) error {
return sender.SendPacket(MakeErrorPacket(errVal)) return sender.SendPacket(MakeErrorPacket(errVal))
} }
func (sender *PacketSender) SendCKErrorPacket(ck base.CommandKey, errVal string) error {
return sender.SendPacket(MakeCKErrorPacket(ck, errVal))
}
func (sender *PacketSender) SendMessage(fmtStr string, args ...interface{}) error { func (sender *PacketSender) SendMessage(fmtStr string, args ...interface{}) error {
return sender.SendPacket(MakeMessagePacket(fmt.Sprintf(fmtStr, args...))) return sender.SendPacket(MakeMessagePacket(fmt.Sprintf(fmtStr, args...)))
} }
type ErrorReporter interface { type UnknownPacketReporter interface {
ReportError(err error) UnknownPacket(pk PacketType)
} }
func PacketToByteArrBridge(pkCh chan PacketType, byteCh chan []byte, errorReporter ErrorReporter, closeOnDone bool) { type DefaultUPR struct{}
go func() {
defer func() { func (DefaultUPR) UnknownPacket(pk PacketType) {
if closeOnDone { if pk.GetType() == ErrorPacketStr {
close(byteCh) errPacket := pk.(*ErrorPacketType)
// at this point, just send the error packet to stderr rather than try to do something special
fmt.Fprintf(os.Stderr, "[error] %s\n", errPacket.Error)
} else if pk.GetType() == RawPacketStr {
rawPacket := pk.(*RawPacketType)
fmt.Fprintf(os.Stderr, "%s\n", rawPacket.Data)
} else {
fmt.Fprintf(os.Stderr, "[error] invalid packet received '%s'", AsExtType(pk))
} }
}()
for pk := range pkCh {
if pk == nil {
continue
}
jsonBytes, err := json.Marshal(pk)
if err != nil {
if errorReporter != nil {
errorReporter.ReportError(fmt.Errorf("error marshaling packet: %w", err))
}
continue
}
byteCh <- jsonBytes
}
}()
} }

View File

@ -22,7 +22,8 @@ type MServer struct {
Lock *sync.Mutex Lock *sync.Mutex
MainInput *packet.PacketParser MainInput *packet.PacketParser
Sender *packet.PacketSender Sender *packet.PacketSender
FdContext *serverFdContext FdContextMap map[base.CommandKey]*serverFdContext
Debug bool
} }
func (m *MServer) Close() { func (m *MServer) Close() {
@ -38,17 +39,6 @@ type serverFdContext struct {
Readers map[int]*mpio.PacketReader Readers map[int]*mpio.PacketReader
} }
func (m *MServer) MakeServerFdContext(ck base.CommandKey) *serverFdContext {
rtn := &serverFdContext{
M: m,
Lock: &sync.Mutex{},
Sender: m.Sender,
CK: ck,
Readers: make(map[int]*mpio.PacketReader),
}
return rtn
}
func (c *serverFdContext) processDataPacket(pk *packet.DataPacketType) { func (c *serverFdContext) processDataPacket(pk *packet.DataPacketType) {
c.Lock.Lock() c.Lock.Lock()
reader := c.Readers[pk.FdNum] reader := c.Readers[pk.FdNum]
@ -62,7 +52,43 @@ func (c *serverFdContext) processDataPacket(pk *packet.DataPacketType) {
return return
} }
reader.AddData(pk) reader.AddData(pk)
}
func (m *MServer) MakeServerFdContext(ck base.CommandKey) *serverFdContext {
rtn := &serverFdContext{
M: m,
Lock: &sync.Mutex{},
Sender: m.Sender,
CK: ck,
Readers: make(map[int]*mpio.PacketReader),
}
return rtn
}
func (m *MServer) ProcessCommandPacket(pk packet.CommandPacketType) {
ck := pk.GetCK()
if ck == "" {
m.Sender.SendErrorPacket(fmt.Sprintf("received '%s' packet without ck", pk.GetType()))
return return
}
m.Lock.Lock()
fdContext := m.FdContextMap[ck]
m.Lock.Unlock()
if fdContext == nil {
m.Sender.SendCKErrorPacket(ck, fmt.Sprintf("no server context for ck '%s'", ck))
return
}
if pk.GetType() == packet.DataPacketStr {
dataPacket := pk.(*packet.DataPacketType)
fdContext.processDataPacket(dataPacket)
return
} else if pk.GetType() == packet.DataAckPacketStr {
m.Sender.SendPacket(pk)
return
} else {
m.Sender.SendCKErrorPacket(ck, fmt.Sprintf("invalid packet '%s' received", packet.AsExtType(pk)))
return
}
} }
func (c *serverFdContext) GetWriter(fdNum int) io.WriteCloser { func (c *serverFdContext) GetWriter(fdNum int) io.WriteCloser {
@ -78,21 +104,42 @@ func (c *serverFdContext) GetReader(fdNum int) io.ReadCloser {
} }
func (m *MServer) runCommand(runPacket *packet.RunPacketType) { func (m *MServer) runCommand(runPacket *packet.RunPacketType) {
if err := runPacket.CK.Validate("packet"); err != nil {
m.Sender.SendErrorPacket(fmt.Sprintf("server run packets require valid ck: %s", err))
return
}
fdContext := m.MakeServerFdContext(runPacket.CK) fdContext := m.MakeServerFdContext(runPacket.CK)
m.Lock.Lock() m.Lock.Lock()
m.FdContext = fdContext m.FdContextMap[runPacket.CK] = fdContext
m.Lock.Unlock() m.Lock.Unlock()
go func() { go func() {
donePk, err := shexec.RunClientSSHCommandAndWait(runPacket, fdContext, shexec.SSHOpts{}, true) donePk, err := shexec.RunClientSSHCommandAndWait(runPacket, fdContext, shexec.SSHOpts{}, m, m.Debug)
fmt.Printf("done: err:%v, %v\n", err, donePk) if donePk != nil {
m.Sender.SendPacket(donePk)
}
if err != nil {
m.Sender.SendCKErrorPacket(runPacket.CK, err.Error())
}
}() }()
} }
func (m *MServer) UnknownPacket(pk packet.PacketType) {
m.Sender.SendPacket(pk)
}
func RunServer() (int, error) { func RunServer() (int, error) {
debug := false
if len(os.Args) >= 3 && os.Args[2] == "--debug" {
debug = true
}
server := &MServer{ server := &MServer{
Lock: &sync.Mutex{}, Lock: &sync.Mutex{},
FdContextMap: make(map[base.CommandKey]*serverFdContext),
Debug: debug,
} }
if debug {
packet.GlobalDebug = true packet.GlobalDebug = true
}
server.MainInput = packet.MakePacketParser(os.Stdin) server.MainInput = packet.MakePacketParser(os.Stdin)
server.Sender = packet.MakePacketSender(os.Stdout) server.Sender = packet.MakePacketSender(os.Stdout)
defer server.Close() defer server.Close()
@ -101,7 +148,9 @@ func RunServer() (int, error) {
initPacket.Version = base.MShellVersion initPacket.Version = base.MShellVersion
server.Sender.SendPacket(initPacket) server.Sender.SendPacket(initPacket)
for pk := range server.MainInput.MainCh { for pk := range server.MainInput.MainCh {
if server.Debug {
fmt.Printf("PK> %s\n", packet.AsString(pk)) fmt.Printf("PK> %s\n", packet.AsString(pk))
}
if pk.GetType() == packet.PingPacketStr { if pk.GetType() == packet.PingPacketStr {
continue continue
} }
@ -110,9 +159,8 @@ func RunServer() (int, error) {
server.runCommand(runPacket) server.runCommand(runPacket)
continue continue
} }
if pk.GetType() == packet.DataPacketStr { if cmdPk, ok := pk.(packet.CommandPacketType); ok {
dataPacket := pk.(*packet.DataPacketType) server.ProcessCommandPacket(cmdPk)
server.FdContext.processDataPacket(dataPacket)
continue continue
} }
server.Sender.SendErrorPacket(fmt.Sprintf("invalid packet '%s' sent to mshell", packet.AsExtType(pk))) server.Sender.SendErrorPacket(fmt.Sprintf("invalid packet '%s' sent to mshell", packet.AsExtType(pk)))

View File

@ -99,12 +99,12 @@ type FdContext interface {
GetReader(fdNum int) io.ReadCloser GetReader(fdNum int) io.ReadCloser
} }
func MakeShExec(ck base.CommandKey) *ShExecType { func MakeShExec(ck base.CommandKey, upr packet.UnknownPacketReporter) *ShExecType {
return &ShExecType{ return &ShExecType{
Lock: &sync.Mutex{}, Lock: &sync.Mutex{},
StartTs: time.Now(), StartTs: time.Now(),
CK: ck, CK: ck,
Multiplexer: mpio.MakeMultiplexer(ck), Multiplexer: mpio.MakeMultiplexer(ck, upr),
} }
} }
@ -524,8 +524,8 @@ func HasDupStdin(fds []packet.RemoteFd) bool {
return false return false
} }
func RunClientSSHCommandAndWait(runPacket *packet.RunPacketType, fdContext FdContext, sshOpts SSHOpts, debug bool) (*packet.CmdDonePacketType, error) { func RunClientSSHCommandAndWait(runPacket *packet.RunPacketType, fdContext FdContext, sshOpts SSHOpts, upr packet.UnknownPacketReporter, debug bool) (*packet.CmdDonePacketType, error) {
cmd := MakeShExec("") cmd := MakeShExec(runPacket.CK, upr)
ecmd := sshOpts.MakeSSHExecCmd(ClientCommand) ecmd := sshOpts.MakeSSHExecCmd(ClientCommand)
cmd.Cmd = ecmd cmd.Cmd = ecmd
inputWriter, err := ecmd.StdinPipe() inputWriter, err := ecmd.StdinPipe()
@ -656,7 +656,7 @@ func (cmd *ShExecType) RunRemoteIOAndWait(packetParser *packet.PacketParser, sen
} }
func runCommandSimple(pk *packet.RunPacketType, sender *packet.PacketSender) (*ShExecType, error) { func runCommandSimple(pk *packet.RunPacketType, sender *packet.PacketSender) (*ShExecType, error) {
cmd := MakeShExec(pk.CK) cmd := MakeShExec(pk.CK, nil)
cmd.Cmd = exec.Command("bash", "-c", pk.Command) cmd.Cmd = exec.Command("bash", "-c", pk.Command)
UpdateCmdEnv(cmd.Cmd, pk.Env) UpdateCmdEnv(cmd.Cmd, pk.Env)
if pk.Cwd != "" { if pk.Cwd != "" {
@ -736,7 +736,7 @@ func runCommandDetached(pk *packet.RunPacketType, sender *packet.PacketSender) (
defer func() { defer func() {
cmdTty.Close() cmdTty.Close()
}() }()
rtn := MakeShExec(pk.CK) rtn := MakeShExec(pk.CK, nil)
ecmd := MakeExecCmd(pk, cmdTty) ecmd := MakeExecCmd(pk, cmdTty)
err = ecmd.Start() err = ecmd.Start()
if err != nil { if err != nil {
@ -750,14 +750,14 @@ func runCommandDetached(pk *packet.RunPacketType, sender *packet.PacketSender) (
// copy pty output to .ptyout file // copy pty output to .ptyout file
_, copyErr := io.Copy(ptyOutFd, cmdPty) _, copyErr := io.Copy(ptyOutFd, cmdPty)
if copyErr != nil { if copyErr != nil {
sender.SendErrorPacket(fmt.Sprintf("copying pty output to ptyout file: %v", copyErr)) sender.SendCKErrorPacket(pk.CK, fmt.Sprintf("copying pty output to ptyout file: %v", copyErr))
} }
}() }()
go func() { go func() {
// copy .stdin fifo contents to pty input // copy .stdin fifo contents to pty input
copyFifoErr := MakeAndCopyStdinFifo(cmdPty, fileNames.StdinFifo) copyFifoErr := MakeAndCopyStdinFifo(cmdPty, fileNames.StdinFifo)
if copyFifoErr != nil { if copyFifoErr != nil {
sender.SendErrorPacket(fmt.Sprintf("reading from stdin fifo: %v", copyFifoErr)) sender.SendCKErrorPacket(pk.CK, fmt.Sprintf("reading from stdin fifo: %v", copyFifoErr))
} }
}() }()
rtn.FileNames = fileNames rtn.FileNames = fileNames