mirror of
https://github.com/wavetermdev/waveterm.git
synced 2025-01-17 20:51:55 +01:00
got basic mshell --server functionality working to dispatch multiple commands
This commit is contained in:
parent
1d44afc10e
commit
9054c3cdcc
@ -80,34 +80,34 @@ func doSingle(ck base.CommandKey) {
|
||||
func doMainRun(pk *packet.RunPacketType, sender *packet.PacketSender) {
|
||||
err := shexec.ValidateRunPacket(pk)
|
||||
if err != nil {
|
||||
sender.SendPacket(packet.MakeCKErrorPacket(pk.CK, fmt.Sprintf("invalid run packet: %v", err)))
|
||||
sender.SendCKErrorPacket(pk.CK, fmt.Sprintf("invalid run packet: %v", err))
|
||||
return
|
||||
}
|
||||
fileNames, err := base.GetCommandFileNames(pk.CK)
|
||||
if err != nil {
|
||||
sender.SendPacket(packet.MakeCKErrorPacket(pk.CK, fmt.Sprintf("cannot get command file names: %v", err)))
|
||||
sender.SendCKErrorPacket(pk.CK, fmt.Sprintf("cannot get command file names: %v", err))
|
||||
return
|
||||
}
|
||||
cmd, err := shexec.MakeRunnerExec(pk.CK)
|
||||
if err != nil {
|
||||
sender.SendPacket(packet.MakeCKErrorPacket(pk.CK, fmt.Sprintf("cannot make mshell command: %v", err)))
|
||||
sender.SendCKErrorPacket(pk.CK, fmt.Sprintf("cannot make mshell command: %v", err))
|
||||
return
|
||||
}
|
||||
cmdStdin, err := cmd.StdinPipe()
|
||||
if err != nil {
|
||||
sender.SendPacket(packet.MakeCKErrorPacket(pk.CK, fmt.Sprintf("cannot pipe stdin to command: %v", err)))
|
||||
sender.SendCKErrorPacket(pk.CK, fmt.Sprintf("cannot pipe stdin to command: %v", err))
|
||||
return
|
||||
}
|
||||
// touch ptyout file (should exist for tailer to work correctly)
|
||||
ptyOutFd, err := os.OpenFile(fileNames.PtyOutFile, os.O_CREATE|os.O_TRUNC|os.O_APPEND|os.O_WRONLY, 0600)
|
||||
if err != nil {
|
||||
sender.SendPacket(packet.MakeCKErrorPacket(pk.CK, fmt.Sprintf("cannot open pty out file '%s': %v", fileNames.PtyOutFile, err)))
|
||||
sender.SendCKErrorPacket(pk.CK, fmt.Sprintf("cannot open pty out file '%s': %v", fileNames.PtyOutFile, err))
|
||||
return
|
||||
}
|
||||
ptyOutFd.Close() // just opened to create the file, can close right after
|
||||
runnerOutFd, err := os.OpenFile(fileNames.RunnerOutFile, os.O_CREATE|os.O_TRUNC|os.O_APPEND|os.O_WRONLY, 0600)
|
||||
if err != nil {
|
||||
sender.SendPacket(packet.MakeCKErrorPacket(pk.CK, fmt.Sprintf("cannot open runner out file '%s': %v", fileNames.RunnerOutFile, err)))
|
||||
sender.SendCKErrorPacket(pk.CK, fmt.Sprintf("cannot open runner out file '%s': %v", fileNames.RunnerOutFile, err))
|
||||
return
|
||||
}
|
||||
defer runnerOutFd.Close()
|
||||
@ -115,13 +115,13 @@ func doMainRun(pk *packet.RunPacketType, sender *packet.PacketSender) {
|
||||
cmd.Stderr = runnerOutFd
|
||||
err = cmd.Start()
|
||||
if err != nil {
|
||||
sender.SendPacket(packet.MakeCKErrorPacket(pk.CK, fmt.Sprintf("error starting command: %v", err)))
|
||||
sender.SendCKErrorPacket(pk.CK, fmt.Sprintf("error starting command: %v", err))
|
||||
return
|
||||
}
|
||||
go func() {
|
||||
err = packet.SendPacket(cmdStdin, pk)
|
||||
if err != nil {
|
||||
sender.SendPacket(packet.MakeCKErrorPacket(pk.CK, fmt.Sprintf("error sending forked runner command: %v", err)))
|
||||
sender.SendCKErrorPacket(pk.CK, fmt.Sprintf("error sending forked runner command: %v", err))
|
||||
return
|
||||
}
|
||||
cmdStdin.Close()
|
||||
@ -237,11 +237,6 @@ func handleSingle() {
|
||||
runPacket, _ = pk.(*packet.RunPacketType)
|
||||
break
|
||||
}
|
||||
if pk.GetType() == packet.RawPacketStr {
|
||||
rawPk := pk.(*packet.RawPacketType)
|
||||
sender.SendMessage("got raw packet '%s'", rawPk.Data)
|
||||
continue
|
||||
}
|
||||
sender.SendErrorPacket(fmt.Sprintf("invalid packet '%s' sent to mshell", pk.GetType()))
|
||||
return
|
||||
}
|
||||
@ -251,7 +246,7 @@ func handleSingle() {
|
||||
}
|
||||
cmd, err := shexec.RunCommand(runPacket, sender)
|
||||
if err != nil {
|
||||
sender.SendErrorPacket(fmt.Sprintf("error running command: %v", err))
|
||||
sender.SendCKErrorPacket(runPacket.CK, fmt.Sprintf("error running command: %v", err))
|
||||
return
|
||||
}
|
||||
defer cmd.Close()
|
||||
@ -433,7 +428,7 @@ func handleClient() (int, error) {
|
||||
if err != nil {
|
||||
return 1, err
|
||||
}
|
||||
donePacket, err := shexec.RunClientSSHCommandAndWait(runPacket, shexec.StdContext{}, opts.SSHOpts, opts.Debug)
|
||||
donePacket, err := shexec.RunClientSSHCommandAndWait(runPacket, shexec.StdContext{}, opts.SSHOpts, nil, opts.Debug)
|
||||
if err != nil {
|
||||
return 1, err
|
||||
}
|
||||
|
@ -31,16 +31,21 @@ type Multiplexer struct {
|
||||
Sender *packet.PacketSender
|
||||
Input *packet.PacketParser
|
||||
Started bool
|
||||
UPR packet.UnknownPacketReporter
|
||||
|
||||
Debug bool
|
||||
}
|
||||
|
||||
func MakeMultiplexer(ck base.CommandKey) *Multiplexer {
|
||||
func MakeMultiplexer(ck base.CommandKey, upr packet.UnknownPacketReporter) *Multiplexer {
|
||||
if upr == nil {
|
||||
upr = packet.DefaultUPR{}
|
||||
}
|
||||
return &Multiplexer{
|
||||
Lock: &sync.Mutex{},
|
||||
CK: ck,
|
||||
FdReaders: make(map[int]*FdReader),
|
||||
FdWriters: make(map[int]*FdWriter),
|
||||
UPR: upr,
|
||||
}
|
||||
}
|
||||
|
||||
@ -207,17 +212,7 @@ func (m *Multiplexer) runPacketInputLoop() *packet.CmdDonePacketType {
|
||||
donePacket := pk.(*packet.CmdDonePacketType)
|
||||
return donePacket
|
||||
}
|
||||
if pk.GetType() == packet.ErrorPacketStr {
|
||||
errPacket := pk.(*packet.ErrorPacketType)
|
||||
// at this point, just send the error packet to stderr rather than try to do something special
|
||||
fmt.Fprintf(os.Stderr, "%s\n", errPacket.Error)
|
||||
return nil
|
||||
}
|
||||
if pk.GetType() == packet.RawPacketStr {
|
||||
rawPacket := pk.(*packet.RawPacketType)
|
||||
fmt.Fprintf(os.Stderr, "%s\n", rawPacket.Data)
|
||||
continue
|
||||
}
|
||||
m.UPR.UnknownPacket(pk)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
@ -11,6 +11,7 @@ import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"os"
|
||||
"reflect"
|
||||
"sync"
|
||||
|
||||
@ -609,33 +610,30 @@ func (sender *PacketSender) SendErrorPacket(errVal string) error {
|
||||
return sender.SendPacket(MakeErrorPacket(errVal))
|
||||
}
|
||||
|
||||
func (sender *PacketSender) SendCKErrorPacket(ck base.CommandKey, errVal string) error {
|
||||
return sender.SendPacket(MakeCKErrorPacket(ck, errVal))
|
||||
}
|
||||
|
||||
func (sender *PacketSender) SendMessage(fmtStr string, args ...interface{}) error {
|
||||
return sender.SendPacket(MakeMessagePacket(fmt.Sprintf(fmtStr, args...)))
|
||||
}
|
||||
|
||||
type ErrorReporter interface {
|
||||
ReportError(err error)
|
||||
type UnknownPacketReporter interface {
|
||||
UnknownPacket(pk PacketType)
|
||||
}
|
||||
|
||||
func PacketToByteArrBridge(pkCh chan PacketType, byteCh chan []byte, errorReporter ErrorReporter, closeOnDone bool) {
|
||||
go func() {
|
||||
defer func() {
|
||||
if closeOnDone {
|
||||
close(byteCh)
|
||||
type DefaultUPR struct{}
|
||||
|
||||
func (DefaultUPR) UnknownPacket(pk PacketType) {
|
||||
if pk.GetType() == ErrorPacketStr {
|
||||
errPacket := pk.(*ErrorPacketType)
|
||||
// at this point, just send the error packet to stderr rather than try to do something special
|
||||
fmt.Fprintf(os.Stderr, "[error] %s\n", errPacket.Error)
|
||||
} else if pk.GetType() == RawPacketStr {
|
||||
rawPacket := pk.(*RawPacketType)
|
||||
fmt.Fprintf(os.Stderr, "%s\n", rawPacket.Data)
|
||||
} else {
|
||||
fmt.Fprintf(os.Stderr, "[error] invalid packet received '%s'", AsExtType(pk))
|
||||
}
|
||||
}()
|
||||
for pk := range pkCh {
|
||||
if pk == nil {
|
||||
continue
|
||||
}
|
||||
jsonBytes, err := json.Marshal(pk)
|
||||
if err != nil {
|
||||
if errorReporter != nil {
|
||||
errorReporter.ReportError(fmt.Errorf("error marshaling packet: %w", err))
|
||||
}
|
||||
continue
|
||||
}
|
||||
byteCh <- jsonBytes
|
||||
}
|
||||
}()
|
||||
|
||||
}
|
||||
|
@ -22,7 +22,8 @@ type MServer struct {
|
||||
Lock *sync.Mutex
|
||||
MainInput *packet.PacketParser
|
||||
Sender *packet.PacketSender
|
||||
FdContext *serverFdContext
|
||||
FdContextMap map[base.CommandKey]*serverFdContext
|
||||
Debug bool
|
||||
}
|
||||
|
||||
func (m *MServer) Close() {
|
||||
@ -38,17 +39,6 @@ type serverFdContext struct {
|
||||
Readers map[int]*mpio.PacketReader
|
||||
}
|
||||
|
||||
func (m *MServer) MakeServerFdContext(ck base.CommandKey) *serverFdContext {
|
||||
rtn := &serverFdContext{
|
||||
M: m,
|
||||
Lock: &sync.Mutex{},
|
||||
Sender: m.Sender,
|
||||
CK: ck,
|
||||
Readers: make(map[int]*mpio.PacketReader),
|
||||
}
|
||||
return rtn
|
||||
}
|
||||
|
||||
func (c *serverFdContext) processDataPacket(pk *packet.DataPacketType) {
|
||||
c.Lock.Lock()
|
||||
reader := c.Readers[pk.FdNum]
|
||||
@ -62,8 +52,44 @@ func (c *serverFdContext) processDataPacket(pk *packet.DataPacketType) {
|
||||
return
|
||||
}
|
||||
reader.AddData(pk)
|
||||
}
|
||||
|
||||
func (m *MServer) MakeServerFdContext(ck base.CommandKey) *serverFdContext {
|
||||
rtn := &serverFdContext{
|
||||
M: m,
|
||||
Lock: &sync.Mutex{},
|
||||
Sender: m.Sender,
|
||||
CK: ck,
|
||||
Readers: make(map[int]*mpio.PacketReader),
|
||||
}
|
||||
return rtn
|
||||
}
|
||||
|
||||
func (m *MServer) ProcessCommandPacket(pk packet.CommandPacketType) {
|
||||
ck := pk.GetCK()
|
||||
if ck == "" {
|
||||
m.Sender.SendErrorPacket(fmt.Sprintf("received '%s' packet without ck", pk.GetType()))
|
||||
return
|
||||
}
|
||||
m.Lock.Lock()
|
||||
fdContext := m.FdContextMap[ck]
|
||||
m.Lock.Unlock()
|
||||
if fdContext == nil {
|
||||
m.Sender.SendCKErrorPacket(ck, fmt.Sprintf("no server context for ck '%s'", ck))
|
||||
return
|
||||
}
|
||||
if pk.GetType() == packet.DataPacketStr {
|
||||
dataPacket := pk.(*packet.DataPacketType)
|
||||
fdContext.processDataPacket(dataPacket)
|
||||
return
|
||||
} else if pk.GetType() == packet.DataAckPacketStr {
|
||||
m.Sender.SendPacket(pk)
|
||||
return
|
||||
} else {
|
||||
m.Sender.SendCKErrorPacket(ck, fmt.Sprintf("invalid packet '%s' received", packet.AsExtType(pk)))
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
func (c *serverFdContext) GetWriter(fdNum int) io.WriteCloser {
|
||||
return mpio.MakePacketWriter(fdNum, c.Sender, c.CK)
|
||||
@ -78,21 +104,42 @@ func (c *serverFdContext) GetReader(fdNum int) io.ReadCloser {
|
||||
}
|
||||
|
||||
func (m *MServer) runCommand(runPacket *packet.RunPacketType) {
|
||||
if err := runPacket.CK.Validate("packet"); err != nil {
|
||||
m.Sender.SendErrorPacket(fmt.Sprintf("server run packets require valid ck: %s", err))
|
||||
return
|
||||
}
|
||||
fdContext := m.MakeServerFdContext(runPacket.CK)
|
||||
m.Lock.Lock()
|
||||
m.FdContext = fdContext
|
||||
m.FdContextMap[runPacket.CK] = fdContext
|
||||
m.Lock.Unlock()
|
||||
go func() {
|
||||
donePk, err := shexec.RunClientSSHCommandAndWait(runPacket, fdContext, shexec.SSHOpts{}, true)
|
||||
fmt.Printf("done: err:%v, %v\n", err, donePk)
|
||||
donePk, err := shexec.RunClientSSHCommandAndWait(runPacket, fdContext, shexec.SSHOpts{}, m, m.Debug)
|
||||
if donePk != nil {
|
||||
m.Sender.SendPacket(donePk)
|
||||
}
|
||||
if err != nil {
|
||||
m.Sender.SendCKErrorPacket(runPacket.CK, err.Error())
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
func (m *MServer) UnknownPacket(pk packet.PacketType) {
|
||||
m.Sender.SendPacket(pk)
|
||||
}
|
||||
|
||||
func RunServer() (int, error) {
|
||||
debug := false
|
||||
if len(os.Args) >= 3 && os.Args[2] == "--debug" {
|
||||
debug = true
|
||||
}
|
||||
server := &MServer{
|
||||
Lock: &sync.Mutex{},
|
||||
FdContextMap: make(map[base.CommandKey]*serverFdContext),
|
||||
Debug: debug,
|
||||
}
|
||||
if debug {
|
||||
packet.GlobalDebug = true
|
||||
}
|
||||
server.MainInput = packet.MakePacketParser(os.Stdin)
|
||||
server.Sender = packet.MakePacketSender(os.Stdout)
|
||||
defer server.Close()
|
||||
@ -101,7 +148,9 @@ func RunServer() (int, error) {
|
||||
initPacket.Version = base.MShellVersion
|
||||
server.Sender.SendPacket(initPacket)
|
||||
for pk := range server.MainInput.MainCh {
|
||||
if server.Debug {
|
||||
fmt.Printf("PK> %s\n", packet.AsString(pk))
|
||||
}
|
||||
if pk.GetType() == packet.PingPacketStr {
|
||||
continue
|
||||
}
|
||||
@ -110,9 +159,8 @@ func RunServer() (int, error) {
|
||||
server.runCommand(runPacket)
|
||||
continue
|
||||
}
|
||||
if pk.GetType() == packet.DataPacketStr {
|
||||
dataPacket := pk.(*packet.DataPacketType)
|
||||
server.FdContext.processDataPacket(dataPacket)
|
||||
if cmdPk, ok := pk.(packet.CommandPacketType); ok {
|
||||
server.ProcessCommandPacket(cmdPk)
|
||||
continue
|
||||
}
|
||||
server.Sender.SendErrorPacket(fmt.Sprintf("invalid packet '%s' sent to mshell", packet.AsExtType(pk)))
|
||||
|
@ -99,12 +99,12 @@ type FdContext interface {
|
||||
GetReader(fdNum int) io.ReadCloser
|
||||
}
|
||||
|
||||
func MakeShExec(ck base.CommandKey) *ShExecType {
|
||||
func MakeShExec(ck base.CommandKey, upr packet.UnknownPacketReporter) *ShExecType {
|
||||
return &ShExecType{
|
||||
Lock: &sync.Mutex{},
|
||||
StartTs: time.Now(),
|
||||
CK: ck,
|
||||
Multiplexer: mpio.MakeMultiplexer(ck),
|
||||
Multiplexer: mpio.MakeMultiplexer(ck, upr),
|
||||
}
|
||||
}
|
||||
|
||||
@ -524,8 +524,8 @@ func HasDupStdin(fds []packet.RemoteFd) bool {
|
||||
return false
|
||||
}
|
||||
|
||||
func RunClientSSHCommandAndWait(runPacket *packet.RunPacketType, fdContext FdContext, sshOpts SSHOpts, debug bool) (*packet.CmdDonePacketType, error) {
|
||||
cmd := MakeShExec("")
|
||||
func RunClientSSHCommandAndWait(runPacket *packet.RunPacketType, fdContext FdContext, sshOpts SSHOpts, upr packet.UnknownPacketReporter, debug bool) (*packet.CmdDonePacketType, error) {
|
||||
cmd := MakeShExec(runPacket.CK, upr)
|
||||
ecmd := sshOpts.MakeSSHExecCmd(ClientCommand)
|
||||
cmd.Cmd = ecmd
|
||||
inputWriter, err := ecmd.StdinPipe()
|
||||
@ -656,7 +656,7 @@ func (cmd *ShExecType) RunRemoteIOAndWait(packetParser *packet.PacketParser, sen
|
||||
}
|
||||
|
||||
func runCommandSimple(pk *packet.RunPacketType, sender *packet.PacketSender) (*ShExecType, error) {
|
||||
cmd := MakeShExec(pk.CK)
|
||||
cmd := MakeShExec(pk.CK, nil)
|
||||
cmd.Cmd = exec.Command("bash", "-c", pk.Command)
|
||||
UpdateCmdEnv(cmd.Cmd, pk.Env)
|
||||
if pk.Cwd != "" {
|
||||
@ -736,7 +736,7 @@ func runCommandDetached(pk *packet.RunPacketType, sender *packet.PacketSender) (
|
||||
defer func() {
|
||||
cmdTty.Close()
|
||||
}()
|
||||
rtn := MakeShExec(pk.CK)
|
||||
rtn := MakeShExec(pk.CK, nil)
|
||||
ecmd := MakeExecCmd(pk, cmdTty)
|
||||
err = ecmd.Start()
|
||||
if err != nil {
|
||||
@ -750,14 +750,14 @@ func runCommandDetached(pk *packet.RunPacketType, sender *packet.PacketSender) (
|
||||
// copy pty output to .ptyout file
|
||||
_, copyErr := io.Copy(ptyOutFd, cmdPty)
|
||||
if copyErr != nil {
|
||||
sender.SendErrorPacket(fmt.Sprintf("copying pty output to ptyout file: %v", copyErr))
|
||||
sender.SendCKErrorPacket(pk.CK, fmt.Sprintf("copying pty output to ptyout file: %v", copyErr))
|
||||
}
|
||||
}()
|
||||
go func() {
|
||||
// copy .stdin fifo contents to pty input
|
||||
copyFifoErr := MakeAndCopyStdinFifo(cmdPty, fileNames.StdinFifo)
|
||||
if copyFifoErr != nil {
|
||||
sender.SendErrorPacket(fmt.Sprintf("reading from stdin fifo: %v", copyFifoErr))
|
||||
sender.SendCKErrorPacket(pk.CK, fmt.Sprintf("reading from stdin fifo: %v", copyFifoErr))
|
||||
}
|
||||
}()
|
||||
rtn.FileNames = fileNames
|
||||
|
Loading…
Reference in New Issue
Block a user