diff --git a/main-mshell.go b/main-mshell.go index 0bdb4e779..e5120fd84 100644 --- a/main-mshell.go +++ b/main-mshell.go @@ -245,7 +245,7 @@ func handleRemote() { defer cmd.Close() startPacket := cmd.MakeCmdStartPacket() sender.SendPacket(startPacket) - cmd.RunIOAndWait(sender) + cmd.RunIOAndWait(packetCh, sender) } func handleServer() { diff --git a/pkg/packet/packet.go b/pkg/packet/packet.go index 33989978e..fc9fcaa4c 100644 --- a/pkg/packet/packet.go +++ b/pkg/packet/packet.go @@ -28,6 +28,7 @@ const ( PingPacketStr = "ping" InitPacketStr = "init" DataPacketStr = "data" + DataAckPacketStr = "dataack" CmdStartPacketStr = "cmdstart" CmdDonePacketStr = "cmddone" ResponsePacketStr = "resp" @@ -62,6 +63,7 @@ func init() { TypeStrToFactory[RawPacketStr] = reflect.TypeOf(RawPacketType{}) TypeStrToFactory[InputPacketStr] = reflect.TypeOf(InputPacketType{}) TypeStrToFactory[DataPacketStr] = reflect.TypeOf(DataPacketType{}) + TypeStrToFactory[DataAckPacketStr] = reflect.TypeOf(DataAckPacketType{}) } func MakePacket(packetType string) (PacketType, error) { @@ -128,6 +130,23 @@ func MakeDataPacket() *DataPacketType { return &DataPacketType{Type: DataPacketStr} } +type DataAckPacketType struct { + Type string `json:"type"` + SessionId string `json:"sessionid,omitempty"` + CmdId string `json:"cmdid,omitempty"` + FdNum int `json:"fdnum"` + AckLen int `json:"acklen"` + Error string `json:"error"` +} + +func (*DataAckPacketType) GetType() string { + return DataAckPacketStr +} + +func MakeDataAckPacket() *DataAckPacketType { + return &DataAckPacketType{Type: DataAckPacketStr} +} + // InputData gets written to PTY directly // SigNum gets sent to process via a signal // WinSize, if set, will run TIOCSWINSZ to set size, and then send SIGWINCH diff --git a/pkg/shexec/bufreader.go b/pkg/shexec/bufreader.go new file mode 100644 index 000000000..73aefc7bf --- /dev/null +++ b/pkg/shexec/bufreader.go @@ -0,0 +1,125 @@ +// Copyright 2022 Dashborg Inc +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at https://mozilla.org/MPL/2.0/. + +package shexec + +import ( + "io" + "os" + "sync" + + "github.com/scripthaus-dev/mshell/pkg/packet" +) + +type FdReader struct { + CVar *sync.Cond + SessionId string + CmdId string + FdNum int + Fd *os.File + BufSize int + Closed bool +} + +func MakeFdReader(c *ShExecType, fd *os.File, fdNum int) *FdReader { + return &FdReader{ + CVar: sync.NewCond(&sync.Mutex{}), + SessionId: c.RunPacket.SessionId, + CmdId: c.RunPacket.CmdId, + FdNum: fdNum, + Fd: fd, + BufSize: 0, + } +} + +func (r *FdReader) Close() { + r.CVar.L.Lock() + defer r.CVar.L.Unlock() + if r.Closed { + return + } + if r.Fd != nil { + r.Fd.Close() + } + r.CVar.Broadcast() +} + +func (r *FdReader) NotifyAck(ackLen int) { + r.CVar.L.Lock() + defer r.CVar.L.Unlock() + r.BufSize -= ackLen + if r.BufSize < 0 { + r.BufSize = 0 + } + r.CVar.Broadcast() +} + +// returns (success) +func (r *FdReader) WriteWait(sender *packet.PacketSender, data []byte, isEof bool) bool { + if len(data) == 0 { + return true + } + r.CVar.L.Lock() + defer r.CVar.L.Unlock() + for { + bufAvail := ReadBufSize - r.BufSize + if r.Closed { + return false + } + if bufAvail == 0 { + r.CVar.Wait() + continue + } + writeLen := min(bufAvail, len(data)) + pk := r.MakeDataPacket(data[0:writeLen], nil) + sender.SendPacket(pk) + r.BufSize += writeLen + data = data[writeLen:] + if len(data) == 0 { + return true + } + r.CVar.Wait() + } +} + +func min(v1 int, v2 int) int { + if v1 <= v2 { + return v1 + } + return v2 +} + +func (r *FdReader) MakeDataPacket(data []byte, err error) *packet.DataPacketType { + pk := packet.MakeDataPacket() + pk.SessionId = r.SessionId + pk.CmdId = r.CmdId + pk.FdNum = r.FdNum + pk.Data = string(data) + if err != nil { + pk.Error = err.Error() + } + return pk +} + +func (r *FdReader) ReadLoop(wg *sync.WaitGroup, sender *packet.PacketSender) { + defer r.Close() + defer wg.Done() + buf := make([]byte, 4096) + for { + nr, err := r.Fd.Read(buf) + if nr > 0 || err == io.EOF { + isOpen := r.WriteWait(sender, buf[0:nr], (err == io.EOF)) + if !isOpen { + return + } + } + if err != nil { + errPk := r.MakeDataPacket(nil, err) + sender.SendPacket(errPk) + return + } + } +} diff --git a/pkg/shexec/bufwriter.go b/pkg/shexec/bufwriter.go new file mode 100644 index 000000000..f51a06049 --- /dev/null +++ b/pkg/shexec/bufwriter.go @@ -0,0 +1,115 @@ +// Copyright 2022 Dashborg Inc +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at https://mozilla.org/MPL/2.0/. + +package shexec + +import ( + "fmt" + "os" + "sync" + + "github.com/scripthaus-dev/mshell/pkg/packet" +) + +type FdWriter struct { + CVar *sync.Cond + SessionId string + CmdId string + FdNum int + Buffer []byte + Fd *os.File + Eof bool + Closed bool +} + +func MakeFdWriter(c *ShExecType, fd *os.File, fdNum int) *FdWriter { + return &FdWriter{ + CVar: sync.NewCond(&sync.Mutex{}), + Fd: fd, + SessionId: c.RunPacket.SessionId, + CmdId: c.RunPacket.CmdId, + FdNum: fdNum, + } +} + +func (w *FdWriter) Close() { + w.CVar.L.Lock() + defer w.CVar.L.Unlock() + if w.Closed { + return + } + w.Closed = true + if w.Fd != nil { + w.Fd.Close() + } + w.Buffer = nil + w.CVar.Broadcast() +} + +func (w *FdWriter) WaitForData() ([]byte, bool) { + w.CVar.L.Lock() + defer w.CVar.L.Unlock() + for { + if len(w.Buffer) > 0 || w.Eof || w.Closed { + toWrite := w.Buffer + w.Buffer = nil + return toWrite, w.Eof + } + w.CVar.Wait() + } +} + +func (w *FdWriter) MakeDataAckPacket(ackLen int, err error) *packet.DataAckPacketType { + ack := packet.MakeDataAckPacket() + ack.SessionId = w.SessionId + ack.CmdId = w.CmdId + ack.FdNum = w.FdNum + ack.AckLen = ackLen + if err != nil { + ack.Error = err.Error() + } + return ack +} + +func (w *FdWriter) AddData(data []byte, eof bool) error { + w.CVar.L.Lock() + defer w.CVar.L.Unlock() + if w.Closed { + return fmt.Errorf("write to closed file") + } + if len(data) > 0 { + if len(data)+len(w.Buffer) > WriteBufSize { + return fmt.Errorf("write exceeds buffer size") + } + w.Buffer = append(w.Buffer, data...) + } + if eof { + w.Eof = true + } + w.CVar.Broadcast() + return nil +} + +func (w *FdWriter) WriteLoop(sender *packet.PacketSender) { + defer w.Close() + for { + data, isEof := w.WaitForData() + if w.Closed { + return + } + if len(data) > 0 { + nw, err := w.Fd.Write(data) + ack := w.MakeDataAckPacket(nw, err) + sender.SendPacket(ack) + if err != nil { + return + } + } + if isEof { + return + } + } +} diff --git a/pkg/shexec/shexec.go b/pkg/shexec/shexec.go index e2cd6a8ec..3afedc296 100644 --- a/pkg/shexec/shexec.go +++ b/pkg/shexec/shexec.go @@ -26,36 +26,43 @@ const DefaultRows = 25 const DefaultCols = 80 const MaxRows = 1024 const MaxCols = 1024 +const ReadBufSize = 128 * 1024 +const WriteBufSize = 128 * 1024 type ShExecType struct { + Lock *sync.Mutex StartTs time.Time RunPacket *packet.RunPacketType FileNames *base.CommandFileNames Cmd *exec.Cmd CmdPty *os.File - FdReaders map[int]*os.File - FdWriters map[int]*os.File - CloseAfterStart []*os.File + FdReaders map[int]*FdReader // synchronized + FdWriters map[int]*FdWriter // synchronized + CloseAfterStart []*os.File // synchronized } func MakeShExec(pk *packet.RunPacketType) *ShExecType { return &ShExecType{ + Lock: &sync.Mutex{}, StartTs: time.Now(), RunPacket: pk, - FdReaders: make(map[int]*os.File), - FdWriters: make(map[int]*os.File), + FdReaders: make(map[int]*FdReader), + FdWriters: make(map[int]*FdWriter), } } func (c *ShExecType) Close() { + c.Lock.Lock() + defer c.Lock.Unlock() + if c.CmdPty != nil { c.CmdPty.Close() } for _, fd := range c.FdReaders { fd.Close() } - for _, fd := range c.FdWriters { - fd.Close() + for _, fw := range c.FdWriters { + fw.Close() } for _, fd := range c.CloseAfterStart { fd.Close() @@ -221,7 +228,9 @@ func (cmd *ShExecType) makeReaderPipe(fdNum int) (*os.File, error) { if err != nil { return nil, err } - cmd.FdReaders[fdNum] = pr + cmd.Lock.Lock() + defer cmd.Lock.Unlock() + cmd.FdReaders[fdNum] = MakeFdReader(cmd, pr, fdNum) cmd.CloseAfterStart = append(cmd.CloseAfterStart, pw) return pw, nil } @@ -232,51 +241,81 @@ func (cmd *ShExecType) makeWriterPipe(fdNum int) (*os.File, error) { if err != nil { return nil, err } - cmd.FdWriters[fdNum] = pw + cmd.Lock.Lock() + defer cmd.Lock.Unlock() + cmd.FdWriters[fdNum] = MakeFdWriter(cmd, pw, fdNum) cmd.CloseAfterStart = append(cmd.CloseAfterStart, pr) return pr, nil } -func (cmd *ShExecType) MakeDataPacket(fdNum int, data []byte) *packet.DataPacketType { - pk := packet.MakeDataPacket() - pk.SessionId = cmd.RunPacket.SessionId - pk.CmdId = cmd.RunPacket.CmdId - pk.FdNum = fdNum - pk.Data = string(data) - return pk +func (cmd *ShExecType) MakeDataAckPacket(fdNum int, ackLen int, err error) *packet.DataAckPacketType { + ack := packet.MakeDataAckPacket() + ack.SessionId = cmd.RunPacket.SessionId + ack.CmdId = cmd.RunPacket.CmdId + ack.FdNum = fdNum + ack.AckLen = ackLen + if err != nil { + ack.Error = err.Error() + } + return ack } -func (cmd *ShExecType) runReadLoop(wg *sync.WaitGroup, fdNum int, fd *os.File, sender *packet.PacketSender) { - go func() { - defer fd.Close() - defer wg.Done() - buf := make([]byte, 4096) - for { - nr, err := fd.Read(buf) - pk := cmd.MakeDataPacket(fdNum, buf[0:nr]) - if err == io.EOF { - pk.Eof = true - sender.SendPacket(pk) - break - } else if err != nil { - pk.Error = err.Error() - sender.SendPacket(pk) - break - } else { - sender.SendPacket(pk) - } +func (cmd *ShExecType) launchWriters(sender *packet.PacketSender) { + cmd.Lock.Lock() + defer cmd.Lock.Unlock() + for _, fw := range cmd.FdWriters { + go fw.WriteLoop(sender) + } +} + +func (cmd *ShExecType) writeDataPacket(dataPacket *packet.DataPacketType) error { + cmd.Lock.Lock() + defer cmd.Lock.Unlock() + fw := cmd.FdWriters[dataPacket.FdNum] + if fw == nil { + // add a closed FdWriter as a placeholder so we only send one error + fw := MakeFdWriter(cmd, nil, dataPacket.FdNum) + fw.Close() + cmd.FdWriters[dataPacket.FdNum] = fw + return fmt.Errorf("write to closed file") + } + err := fw.AddData([]byte(dataPacket.Data), dataPacket.Eof) + if err != nil { + fw.Close() + return err + } + return nil +} + +func (cmd *ShExecType) runMainWriteLoop(packetCh chan packet.PacketType, sender *packet.PacketSender) { + for pk := range packetCh { + if pk.GetType() != packet.DataPacketStr { + // other packets are ignored + continue } - }() + dataPacket := pk.(*packet.DataPacketType) + err := cmd.writeDataPacket(dataPacket) + if err != nil { + errPacket := cmd.MakeDataAckPacket(dataPacket.FdNum, 0, err) + sender.SendPacket(errPacket) + } + } } -func (cmd *ShExecType) RunIOAndWait(sender *packet.PacketSender) { - var wg sync.WaitGroup +func (cmd *ShExecType) launchReaders(wg *sync.WaitGroup, sender *packet.PacketSender) { + cmd.Lock.Lock() + defer cmd.Lock.Unlock() wg.Add(len(cmd.FdReaders)) - go func() { - for fdNum, fd := range cmd.FdReaders { - cmd.runReadLoop(&wg, fdNum, fd, sender) - } - }() + for _, fr := range cmd.FdReaders { + go fr.ReadLoop(wg, sender) + } +} + +func (cmd *ShExecType) RunIOAndWait(packetCh chan packet.PacketType, sender *packet.PacketSender) { + var wg sync.WaitGroup + cmd.launchReaders(&wg, sender) + cmd.launchWriters(sender) + go cmd.runMainWriteLoop(packetCh, sender) donePacket := cmd.WaitForCommand() wg.Wait() sender.SendPacket(donePacket)