mirror of
https://github.com/wavetermdev/waveterm.git
synced 2025-01-08 19:38:51 +01:00
fix mshell server to just blindly proxy mshell single command input/output, better simpler code
This commit is contained in:
parent
51df0479ff
commit
1b69bb0ac8
@ -8,21 +8,20 @@ package server
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io"
|
||||
"os"
|
||||
"sync"
|
||||
|
||||
"github.com/scripthaus-dev/mshell/pkg/base"
|
||||
"github.com/scripthaus-dev/mshell/pkg/mpio"
|
||||
"github.com/scripthaus-dev/mshell/pkg/packet"
|
||||
"github.com/scripthaus-dev/mshell/pkg/shexec"
|
||||
)
|
||||
|
||||
// TODO create unblockable packet-sender (backed by an array) for clientproc
|
||||
type MServer struct {
|
||||
Lock *sync.Mutex
|
||||
MainInput *packet.PacketParser
|
||||
Sender *packet.PacketSender
|
||||
FdContextMap map[base.CommandKey]*serverFdContext
|
||||
ClientMap map[base.CommandKey]*shexec.ClientProc
|
||||
Debug bool
|
||||
}
|
||||
|
||||
@ -31,43 +30,6 @@ func (m *MServer) Close() {
|
||||
m.Sender.WaitForDone()
|
||||
}
|
||||
|
||||
type serverFdContext struct {
|
||||
M *MServer
|
||||
Lock *sync.Mutex
|
||||
Sender *packet.PacketSender
|
||||
CK base.CommandKey
|
||||
Readers map[int]*mpio.PacketReader
|
||||
}
|
||||
|
||||
func (c *serverFdContext) processDataPacket(pk *packet.DataPacketType) {
|
||||
c.Lock.Lock()
|
||||
reader := c.Readers[pk.FdNum]
|
||||
c.Lock.Unlock()
|
||||
if reader == nil {
|
||||
ackPacket := packet.MakeDataAckPacket()
|
||||
ackPacket.CK = c.CK
|
||||
ackPacket.FdNum = pk.FdNum
|
||||
ackPacket.Error = "write to closed file (no fd)"
|
||||
c.M.Sender.SendPacket(ackPacket)
|
||||
return
|
||||
}
|
||||
reader.AddData(pk)
|
||||
}
|
||||
|
||||
func (m *MServer) MakeServerFdContext(ck base.CommandKey) *serverFdContext {
|
||||
m.Lock.Lock()
|
||||
defer m.Lock.Unlock()
|
||||
rtn := &serverFdContext{
|
||||
M: m,
|
||||
Lock: &sync.Mutex{},
|
||||
Sender: m.Sender,
|
||||
CK: ck,
|
||||
Readers: make(map[int]*mpio.PacketReader),
|
||||
}
|
||||
m.FdContextMap[ck] = rtn
|
||||
return rtn
|
||||
}
|
||||
|
||||
func (m *MServer) ProcessCommandPacket(pk packet.CommandPacketType) {
|
||||
ck := pk.GetCK()
|
||||
if ck == "" {
|
||||
@ -75,41 +37,14 @@ func (m *MServer) ProcessCommandPacket(pk packet.CommandPacketType) {
|
||||
return
|
||||
}
|
||||
m.Lock.Lock()
|
||||
fdContext := m.FdContextMap[ck]
|
||||
cproc := m.ClientMap[ck]
|
||||
m.Lock.Unlock()
|
||||
if fdContext == nil {
|
||||
m.Sender.SendCmdError(ck, fmt.Errorf("no server context for ck '%s'", ck))
|
||||
if cproc == nil {
|
||||
m.Sender.SendCmdError(ck, fmt.Errorf("no client proc for ck '%s'", ck))
|
||||
return
|
||||
}
|
||||
if pk.GetType() == packet.DataPacketStr {
|
||||
dataPacket := pk.(*packet.DataPacketType)
|
||||
fdContext.processDataPacket(dataPacket)
|
||||
cproc.Input.SendPacket(pk)
|
||||
return
|
||||
} else if pk.GetType() == packet.DataAckPacketStr {
|
||||
m.Sender.SendPacket(pk)
|
||||
return
|
||||
} else {
|
||||
m.Sender.SendCmdError(ck, fmt.Errorf("invalid packet '%s' received", packet.AsExtType(pk)))
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
func (c *serverFdContext) GetWriter(fdNum int) io.WriteCloser {
|
||||
return mpio.MakePacketWriter(fdNum, c.Sender, c.CK)
|
||||
}
|
||||
|
||||
func (c *serverFdContext) GetReader(fdNum int) io.ReadCloser {
|
||||
c.Lock.Lock()
|
||||
defer c.Lock.Unlock()
|
||||
reader := mpio.MakePacketReader(fdNum)
|
||||
c.Readers[fdNum] = reader
|
||||
return reader
|
||||
}
|
||||
|
||||
func (m *MServer) RemoveFdContext(ck base.CommandKey) {
|
||||
m.Lock.Lock()
|
||||
defer m.Lock.Unlock()
|
||||
delete(m.FdContextMap, ck)
|
||||
}
|
||||
|
||||
func (m *MServer) runCommand(runPacket *packet.RunPacketType) {
|
||||
@ -117,21 +52,26 @@ func (m *MServer) runCommand(runPacket *packet.RunPacketType) {
|
||||
m.Sender.SendErrorResponse(runPacket.ReqId, fmt.Errorf("server run packets require valid ck: %s", err))
|
||||
return
|
||||
}
|
||||
fdContext := m.MakeServerFdContext(runPacket.CK)
|
||||
go func() {
|
||||
defer m.RemoveFdContext(runPacket.CK)
|
||||
donePk, err := shexec.RunClientSSHCommandAndWait(runPacket, fdContext, shexec.SSHOpts{}, m, m.Debug)
|
||||
if donePk != nil && !runPacket.Detached {
|
||||
m.Sender.SendPacket(donePk)
|
||||
}
|
||||
cproc, err := shexec.MakeClientProc(runPacket.CK)
|
||||
if err != nil {
|
||||
m.Sender.SendErrorResponse(runPacket.ReqId, err)
|
||||
m.Sender.SendErrorResponse(runPacket.ReqId, fmt.Errorf("starting mshell client: %s", err))
|
||||
return
|
||||
}
|
||||
fmt.Printf("client start: %v\n", runPacket.CK)
|
||||
m.Lock.Lock()
|
||||
m.ClientMap[runPacket.CK] = cproc
|
||||
m.Lock.Unlock()
|
||||
go func() {
|
||||
defer func() {
|
||||
m.Lock.Lock()
|
||||
delete(m.ClientMap, runPacket.CK)
|
||||
m.Lock.Unlock()
|
||||
cproc.Close()
|
||||
fmt.Printf("client done: %v\n", runPacket.CK)
|
||||
}()
|
||||
shexec.SendRunPacketAndRunData(cproc.Input, runPacket)
|
||||
cproc.ProxyOutput(m.Sender)
|
||||
}()
|
||||
}
|
||||
|
||||
func (m *MServer) UnknownPacket(pk packet.PacketType) {
|
||||
m.Sender.SendPacket(pk)
|
||||
}
|
||||
|
||||
func RunServer() (int, error) {
|
||||
@ -141,7 +81,7 @@ func RunServer() (int, error) {
|
||||
}
|
||||
server := &MServer{
|
||||
Lock: &sync.Mutex{},
|
||||
FdContextMap: make(map[base.CommandKey]*serverFdContext),
|
||||
ClientMap: make(map[base.CommandKey]*shexec.ClientProc),
|
||||
Debug: debug,
|
||||
}
|
||||
if debug {
|
||||
@ -161,12 +101,7 @@ func RunServer() (int, error) {
|
||||
if server.Debug {
|
||||
fmt.Printf("PK> %s\n", packet.AsString(pk))
|
||||
}
|
||||
|
||||
// run-start combo
|
||||
ok, runPacket := builder.ProcessPacket(pk)
|
||||
if server.Debug {
|
||||
fmt.Printf("PP> %s | %v\n", pk.GetType(), ok)
|
||||
}
|
||||
if ok {
|
||||
if runPacket != nil {
|
||||
server.runCommand(runPacket)
|
||||
@ -174,20 +109,11 @@ func RunServer() (int, error) {
|
||||
}
|
||||
continue
|
||||
}
|
||||
if startPk, ok := pk.(*packet.CmdStartPacketType); ok {
|
||||
if server.Debug {
|
||||
fmt.Printf("START> %v", startPk)
|
||||
}
|
||||
server.Sender.SendPacket(startPk)
|
||||
continue
|
||||
}
|
||||
|
||||
// command packet
|
||||
if cmdPk, ok := pk.(packet.CommandPacketType); ok {
|
||||
server.ProcessCommandPacket(cmdPk)
|
||||
continue
|
||||
}
|
||||
server.Sender.SendMessage(fmt.Sprintf("invalid packet '%s' sent to mshell", packet.AsString(pk)))
|
||||
server.Sender.SendMessage(fmt.Sprintf("invalid packet '%s' sent to mshell server", packet.AsString(pk)))
|
||||
continue
|
||||
}
|
||||
return 0, nil
|
||||
|
121
pkg/shexec/client.go
Normal file
121
pkg/shexec/client.go
Normal file
@ -0,0 +1,121 @@
|
||||
package shexec
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io"
|
||||
"os/exec"
|
||||
"time"
|
||||
|
||||
"github.com/scripthaus-dev/mshell/pkg/base"
|
||||
"github.com/scripthaus-dev/mshell/pkg/packet"
|
||||
)
|
||||
|
||||
type ClientProc struct {
|
||||
Cmd *exec.Cmd
|
||||
CK base.CommandKey
|
||||
StartTs time.Time
|
||||
StdinWriter io.WriteCloser
|
||||
StdoutReader io.ReadCloser
|
||||
StderrReader io.ReadCloser
|
||||
Input *packet.PacketSender
|
||||
Output *packet.PacketParser
|
||||
}
|
||||
|
||||
func MakeClientProc(ck base.CommandKey) (*ClientProc, error) {
|
||||
ecmd, err := SSHOpts{}.MakeMShellSingleCmd()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
inputWriter, err := ecmd.StdinPipe()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("creating stdin pipe: %v", err)
|
||||
}
|
||||
stdoutReader, err := ecmd.StdoutPipe()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("creating stdout pipe: %v", err)
|
||||
}
|
||||
stderrReader, err := ecmd.StderrPipe()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("creating stderr pipe: %v", err)
|
||||
}
|
||||
startTs := time.Now()
|
||||
err = ecmd.Start()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("running local client: %w", err)
|
||||
}
|
||||
sender := packet.MakePacketSender(inputWriter)
|
||||
stdoutPacketParser := packet.MakePacketParser(stdoutReader)
|
||||
stderrPacketParser := packet.MakePacketParser(stderrReader)
|
||||
packetParser := packet.CombinePacketParsers(stdoutPacketParser, stderrPacketParser)
|
||||
cproc := &ClientProc{
|
||||
Cmd: ecmd,
|
||||
CK: ck,
|
||||
StartTs: startTs,
|
||||
StdinWriter: inputWriter,
|
||||
StdoutReader: stdoutReader,
|
||||
StderrReader: stderrReader,
|
||||
Input: sender,
|
||||
Output: packetParser,
|
||||
}
|
||||
versionOk := false
|
||||
for pk := range packetParser.MainCh {
|
||||
if pk.GetType() != packet.InitPacketStr {
|
||||
cproc.Close()
|
||||
return nil, fmt.Errorf("invalid packet received from mshell client: %s", packet.AsString(pk))
|
||||
}
|
||||
initPk := pk.(*packet.InitPacketType)
|
||||
if initPk.NotFound {
|
||||
cproc.Close()
|
||||
return nil, fmt.Errorf("mshell command not found on local server")
|
||||
}
|
||||
if initPk.Version != base.MShellVersion {
|
||||
cproc.Close()
|
||||
return nil, fmt.Errorf("invalid remote mshell version 'v%s', must be v%s", initPk.Version, base.MShellVersion)
|
||||
}
|
||||
versionOk = true
|
||||
break
|
||||
}
|
||||
if !versionOk {
|
||||
cproc.Close()
|
||||
return nil, fmt.Errorf("no init packet received from mshell client")
|
||||
}
|
||||
return cproc, nil
|
||||
}
|
||||
|
||||
func (cproc *ClientProc) Close() {
|
||||
if cproc.Input != nil {
|
||||
cproc.Input.Close()
|
||||
}
|
||||
if cproc.StdinWriter != nil {
|
||||
cproc.StdinWriter.Close()
|
||||
}
|
||||
if cproc.StdoutReader != nil {
|
||||
cproc.StdoutReader.Close()
|
||||
}
|
||||
if cproc.StderrReader != nil {
|
||||
cproc.StderrReader.Close()
|
||||
}
|
||||
if cproc.Cmd != nil {
|
||||
cproc.Cmd.Process.Kill()
|
||||
}
|
||||
}
|
||||
|
||||
func (cproc *ClientProc) ProxyOutput(sender *packet.PacketSender) {
|
||||
sentDonePk := false
|
||||
for pk := range cproc.Output.MainCh {
|
||||
if pk.GetType() == packet.CmdDonePacketStr {
|
||||
sentDonePk = true
|
||||
}
|
||||
sender.SendPacket(pk)
|
||||
}
|
||||
exitErr := cproc.Cmd.Wait()
|
||||
if !sentDonePk {
|
||||
endTs := time.Now()
|
||||
cmdDuration := endTs.Sub(cproc.StartTs)
|
||||
donePacket := packet.MakeCmdDonePacket(cproc.CK)
|
||||
donePacket.Ts = endTs.UnixMilli()
|
||||
donePacket.ExitCode = GetExitCode(exitErr)
|
||||
donePacket.DurationMs = int64(cmdDuration / time.Millisecond)
|
||||
sender.SendPacket(donePacket)
|
||||
}
|
||||
}
|
Loading…
Reference in New Issue
Block a user