working with fdreaders and fdwriters to properly buffer output and not exceed buffer size without acks

This commit is contained in:
sawka 2022-06-23 17:37:05 -07:00
parent c43d3ecc85
commit 29372be4ef
5 changed files with 342 additions and 44 deletions

View File

@ -245,7 +245,7 @@ func handleRemote() {
defer cmd.Close()
startPacket := cmd.MakeCmdStartPacket()
sender.SendPacket(startPacket)
cmd.RunIOAndWait(sender)
cmd.RunIOAndWait(packetCh, sender)
}
func handleServer() {

View File

@ -28,6 +28,7 @@ const (
PingPacketStr = "ping"
InitPacketStr = "init"
DataPacketStr = "data"
DataAckPacketStr = "dataack"
CmdStartPacketStr = "cmdstart"
CmdDonePacketStr = "cmddone"
ResponsePacketStr = "resp"
@ -62,6 +63,7 @@ func init() {
TypeStrToFactory[RawPacketStr] = reflect.TypeOf(RawPacketType{})
TypeStrToFactory[InputPacketStr] = reflect.TypeOf(InputPacketType{})
TypeStrToFactory[DataPacketStr] = reflect.TypeOf(DataPacketType{})
TypeStrToFactory[DataAckPacketStr] = reflect.TypeOf(DataAckPacketType{})
}
func MakePacket(packetType string) (PacketType, error) {
@ -128,6 +130,23 @@ func MakeDataPacket() *DataPacketType {
return &DataPacketType{Type: DataPacketStr}
}
type DataAckPacketType struct {
Type string `json:"type"`
SessionId string `json:"sessionid,omitempty"`
CmdId string `json:"cmdid,omitempty"`
FdNum int `json:"fdnum"`
AckLen int `json:"acklen"`
Error string `json:"error"`
}
func (*DataAckPacketType) GetType() string {
return DataAckPacketStr
}
func MakeDataAckPacket() *DataAckPacketType {
return &DataAckPacketType{Type: DataAckPacketStr}
}
// InputData gets written to PTY directly
// SigNum gets sent to process via a signal
// WinSize, if set, will run TIOCSWINSZ to set size, and then send SIGWINCH

125
pkg/shexec/bufreader.go Normal file
View File

@ -0,0 +1,125 @@
// Copyright 2022 Dashborg Inc
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this
// file, You can obtain one at https://mozilla.org/MPL/2.0/.
package shexec
import (
"io"
"os"
"sync"
"github.com/scripthaus-dev/mshell/pkg/packet"
)
type FdReader struct {
CVar *sync.Cond
SessionId string
CmdId string
FdNum int
Fd *os.File
BufSize int
Closed bool
}
func MakeFdReader(c *ShExecType, fd *os.File, fdNum int) *FdReader {
return &FdReader{
CVar: sync.NewCond(&sync.Mutex{}),
SessionId: c.RunPacket.SessionId,
CmdId: c.RunPacket.CmdId,
FdNum: fdNum,
Fd: fd,
BufSize: 0,
}
}
func (r *FdReader) Close() {
r.CVar.L.Lock()
defer r.CVar.L.Unlock()
if r.Closed {
return
}
if r.Fd != nil {
r.Fd.Close()
}
r.CVar.Broadcast()
}
func (r *FdReader) NotifyAck(ackLen int) {
r.CVar.L.Lock()
defer r.CVar.L.Unlock()
r.BufSize -= ackLen
if r.BufSize < 0 {
r.BufSize = 0
}
r.CVar.Broadcast()
}
// returns (success)
func (r *FdReader) WriteWait(sender *packet.PacketSender, data []byte, isEof bool) bool {
if len(data) == 0 {
return true
}
r.CVar.L.Lock()
defer r.CVar.L.Unlock()
for {
bufAvail := ReadBufSize - r.BufSize
if r.Closed {
return false
}
if bufAvail == 0 {
r.CVar.Wait()
continue
}
writeLen := min(bufAvail, len(data))
pk := r.MakeDataPacket(data[0:writeLen], nil)
sender.SendPacket(pk)
r.BufSize += writeLen
data = data[writeLen:]
if len(data) == 0 {
return true
}
r.CVar.Wait()
}
}
func min(v1 int, v2 int) int {
if v1 <= v2 {
return v1
}
return v2
}
func (r *FdReader) MakeDataPacket(data []byte, err error) *packet.DataPacketType {
pk := packet.MakeDataPacket()
pk.SessionId = r.SessionId
pk.CmdId = r.CmdId
pk.FdNum = r.FdNum
pk.Data = string(data)
if err != nil {
pk.Error = err.Error()
}
return pk
}
func (r *FdReader) ReadLoop(wg *sync.WaitGroup, sender *packet.PacketSender) {
defer r.Close()
defer wg.Done()
buf := make([]byte, 4096)
for {
nr, err := r.Fd.Read(buf)
if nr > 0 || err == io.EOF {
isOpen := r.WriteWait(sender, buf[0:nr], (err == io.EOF))
if !isOpen {
return
}
}
if err != nil {
errPk := r.MakeDataPacket(nil, err)
sender.SendPacket(errPk)
return
}
}
}

115
pkg/shexec/bufwriter.go Normal file
View File

@ -0,0 +1,115 @@
// Copyright 2022 Dashborg Inc
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this
// file, You can obtain one at https://mozilla.org/MPL/2.0/.
package shexec
import (
"fmt"
"os"
"sync"
"github.com/scripthaus-dev/mshell/pkg/packet"
)
type FdWriter struct {
CVar *sync.Cond
SessionId string
CmdId string
FdNum int
Buffer []byte
Fd *os.File
Eof bool
Closed bool
}
func MakeFdWriter(c *ShExecType, fd *os.File, fdNum int) *FdWriter {
return &FdWriter{
CVar: sync.NewCond(&sync.Mutex{}),
Fd: fd,
SessionId: c.RunPacket.SessionId,
CmdId: c.RunPacket.CmdId,
FdNum: fdNum,
}
}
func (w *FdWriter) Close() {
w.CVar.L.Lock()
defer w.CVar.L.Unlock()
if w.Closed {
return
}
w.Closed = true
if w.Fd != nil {
w.Fd.Close()
}
w.Buffer = nil
w.CVar.Broadcast()
}
func (w *FdWriter) WaitForData() ([]byte, bool) {
w.CVar.L.Lock()
defer w.CVar.L.Unlock()
for {
if len(w.Buffer) > 0 || w.Eof || w.Closed {
toWrite := w.Buffer
w.Buffer = nil
return toWrite, w.Eof
}
w.CVar.Wait()
}
}
func (w *FdWriter) MakeDataAckPacket(ackLen int, err error) *packet.DataAckPacketType {
ack := packet.MakeDataAckPacket()
ack.SessionId = w.SessionId
ack.CmdId = w.CmdId
ack.FdNum = w.FdNum
ack.AckLen = ackLen
if err != nil {
ack.Error = err.Error()
}
return ack
}
func (w *FdWriter) AddData(data []byte, eof bool) error {
w.CVar.L.Lock()
defer w.CVar.L.Unlock()
if w.Closed {
return fmt.Errorf("write to closed file")
}
if len(data) > 0 {
if len(data)+len(w.Buffer) > WriteBufSize {
return fmt.Errorf("write exceeds buffer size")
}
w.Buffer = append(w.Buffer, data...)
}
if eof {
w.Eof = true
}
w.CVar.Broadcast()
return nil
}
func (w *FdWriter) WriteLoop(sender *packet.PacketSender) {
defer w.Close()
for {
data, isEof := w.WaitForData()
if w.Closed {
return
}
if len(data) > 0 {
nw, err := w.Fd.Write(data)
ack := w.MakeDataAckPacket(nw, err)
sender.SendPacket(ack)
if err != nil {
return
}
}
if isEof {
return
}
}
}

View File

@ -26,36 +26,43 @@ const DefaultRows = 25
const DefaultCols = 80
const MaxRows = 1024
const MaxCols = 1024
const ReadBufSize = 128 * 1024
const WriteBufSize = 128 * 1024
type ShExecType struct {
Lock *sync.Mutex
StartTs time.Time
RunPacket *packet.RunPacketType
FileNames *base.CommandFileNames
Cmd *exec.Cmd
CmdPty *os.File
FdReaders map[int]*os.File
FdWriters map[int]*os.File
CloseAfterStart []*os.File
FdReaders map[int]*FdReader // synchronized
FdWriters map[int]*FdWriter // synchronized
CloseAfterStart []*os.File // synchronized
}
func MakeShExec(pk *packet.RunPacketType) *ShExecType {
return &ShExecType{
Lock: &sync.Mutex{},
StartTs: time.Now(),
RunPacket: pk,
FdReaders: make(map[int]*os.File),
FdWriters: make(map[int]*os.File),
FdReaders: make(map[int]*FdReader),
FdWriters: make(map[int]*FdWriter),
}
}
func (c *ShExecType) Close() {
c.Lock.Lock()
defer c.Lock.Unlock()
if c.CmdPty != nil {
c.CmdPty.Close()
}
for _, fd := range c.FdReaders {
fd.Close()
}
for _, fd := range c.FdWriters {
fd.Close()
for _, fw := range c.FdWriters {
fw.Close()
}
for _, fd := range c.CloseAfterStart {
fd.Close()
@ -221,7 +228,9 @@ func (cmd *ShExecType) makeReaderPipe(fdNum int) (*os.File, error) {
if err != nil {
return nil, err
}
cmd.FdReaders[fdNum] = pr
cmd.Lock.Lock()
defer cmd.Lock.Unlock()
cmd.FdReaders[fdNum] = MakeFdReader(cmd, pr, fdNum)
cmd.CloseAfterStart = append(cmd.CloseAfterStart, pw)
return pw, nil
}
@ -232,51 +241,81 @@ func (cmd *ShExecType) makeWriterPipe(fdNum int) (*os.File, error) {
if err != nil {
return nil, err
}
cmd.FdWriters[fdNum] = pw
cmd.Lock.Lock()
defer cmd.Lock.Unlock()
cmd.FdWriters[fdNum] = MakeFdWriter(cmd, pw, fdNum)
cmd.CloseAfterStart = append(cmd.CloseAfterStart, pr)
return pr, nil
}
func (cmd *ShExecType) MakeDataPacket(fdNum int, data []byte) *packet.DataPacketType {
pk := packet.MakeDataPacket()
pk.SessionId = cmd.RunPacket.SessionId
pk.CmdId = cmd.RunPacket.CmdId
pk.FdNum = fdNum
pk.Data = string(data)
return pk
func (cmd *ShExecType) MakeDataAckPacket(fdNum int, ackLen int, err error) *packet.DataAckPacketType {
ack := packet.MakeDataAckPacket()
ack.SessionId = cmd.RunPacket.SessionId
ack.CmdId = cmd.RunPacket.CmdId
ack.FdNum = fdNum
ack.AckLen = ackLen
if err != nil {
ack.Error = err.Error()
}
return ack
}
func (cmd *ShExecType) runReadLoop(wg *sync.WaitGroup, fdNum int, fd *os.File, sender *packet.PacketSender) {
go func() {
defer fd.Close()
defer wg.Done()
buf := make([]byte, 4096)
for {
nr, err := fd.Read(buf)
pk := cmd.MakeDataPacket(fdNum, buf[0:nr])
if err == io.EOF {
pk.Eof = true
sender.SendPacket(pk)
break
} else if err != nil {
pk.Error = err.Error()
sender.SendPacket(pk)
break
} else {
sender.SendPacket(pk)
}
func (cmd *ShExecType) launchWriters(sender *packet.PacketSender) {
cmd.Lock.Lock()
defer cmd.Lock.Unlock()
for _, fw := range cmd.FdWriters {
go fw.WriteLoop(sender)
}
}
func (cmd *ShExecType) writeDataPacket(dataPacket *packet.DataPacketType) error {
cmd.Lock.Lock()
defer cmd.Lock.Unlock()
fw := cmd.FdWriters[dataPacket.FdNum]
if fw == nil {
// add a closed FdWriter as a placeholder so we only send one error
fw := MakeFdWriter(cmd, nil, dataPacket.FdNum)
fw.Close()
cmd.FdWriters[dataPacket.FdNum] = fw
return fmt.Errorf("write to closed file")
}
err := fw.AddData([]byte(dataPacket.Data), dataPacket.Eof)
if err != nil {
fw.Close()
return err
}
return nil
}
func (cmd *ShExecType) runMainWriteLoop(packetCh chan packet.PacketType, sender *packet.PacketSender) {
for pk := range packetCh {
if pk.GetType() != packet.DataPacketStr {
// other packets are ignored
continue
}
}()
dataPacket := pk.(*packet.DataPacketType)
err := cmd.writeDataPacket(dataPacket)
if err != nil {
errPacket := cmd.MakeDataAckPacket(dataPacket.FdNum, 0, err)
sender.SendPacket(errPacket)
}
}
}
func (cmd *ShExecType) RunIOAndWait(sender *packet.PacketSender) {
var wg sync.WaitGroup
func (cmd *ShExecType) launchReaders(wg *sync.WaitGroup, sender *packet.PacketSender) {
cmd.Lock.Lock()
defer cmd.Lock.Unlock()
wg.Add(len(cmd.FdReaders))
go func() {
for fdNum, fd := range cmd.FdReaders {
cmd.runReadLoop(&wg, fdNum, fd, sender)
}
}()
for _, fr := range cmd.FdReaders {
go fr.ReadLoop(wg, sender)
}
}
func (cmd *ShExecType) RunIOAndWait(packetCh chan packet.PacketType, sender *packet.PacketSender) {
var wg sync.WaitGroup
cmd.launchReaders(&wg, sender)
cmd.launchWriters(sender)
go cmd.runMainWriteLoop(packetCh, sender)
donePacket := cmd.WaitForCommand()
wg.Wait()
sender.SendPacket(donePacket)