move multiplexed IO to its own package independent of SHExecType (to use in mshell client)

This commit is contained in:
sawka 2022-06-24 10:24:02 -07:00
parent 4256ff5231
commit 0267836376
4 changed files with 270 additions and 217 deletions

View File

@ -4,7 +4,7 @@
// License, v. 2.0. If a copy of the MPL was not distributed with this // License, v. 2.0. If a copy of the MPL was not distributed with this
// file, You can obtain one at https://mozilla.org/MPL/2.0/. // file, You can obtain one at https://mozilla.org/MPL/2.0/.
package shexec package mpio
import ( import (
"io" "io"
@ -15,24 +15,23 @@ import (
) )
type FdReader struct { type FdReader struct {
CVar *sync.Cond CVar *sync.Cond
SessionId string M *Multiplexer
CmdId string FdNum int
FdNum int Fd *os.File
Fd *os.File BufSize int
BufSize int Closed bool
Closed bool
} }
func MakeFdReader(c *ShExecType, fd *os.File, fdNum int) *FdReader { func MakeFdReader(m *Multiplexer, fd *os.File, fdNum int) *FdReader {
return &FdReader{ fr := &FdReader{
CVar: sync.NewCond(&sync.Mutex{}), CVar: sync.NewCond(&sync.Mutex{}),
SessionId: c.RunPacket.SessionId, M: m,
CmdId: c.RunPacket.CmdId, FdNum: fdNum,
FdNum: fdNum, Fd: fd,
Fd: fd, BufSize: 0,
BufSize: 0,
} }
return fr
} }
func (r *FdReader) Close() { func (r *FdReader) Close() {
@ -63,14 +62,14 @@ func (r *FdReader) NotifyAck(ackLen int) {
// !! inverse locking. must already hold the lock when you call this method. // !! inverse locking. must already hold the lock when you call this method.
// will *unlock*, send the packet, and then *relock* once it is done. // will *unlock*, send the packet, and then *relock* once it is done.
// this can prevent an unlikely deadlock where we are holding r.CVar.L and stuck on sender.SendCh // this can prevent an unlikely deadlock where we are holding r.CVar.L and stuck on sender.SendCh
func (r *FdReader) sendPacket_unlock(sender *packet.PacketSender, pk packet.PacketType) { func (r *FdReader) sendPacket_unlock(pk packet.PacketType) {
r.CVar.L.Unlock() r.CVar.L.Unlock()
defer r.CVar.L.Lock() defer r.CVar.L.Lock()
sender.SendPacket(pk) r.M.sendPacket(pk)
} }
// returns (success) // returns (success)
func (r *FdReader) WriteWait(sender *packet.PacketSender, data []byte, isEof bool) bool { func (r *FdReader) WriteWait(data []byte, isEof bool) bool {
r.CVar.L.Lock() r.CVar.L.Lock()
defer r.CVar.L.Unlock() defer r.CVar.L.Unlock()
for { for {
@ -83,11 +82,11 @@ func (r *FdReader) WriteWait(sender *packet.PacketSender, data []byte, isEof boo
continue continue
} }
writeLen := min(bufAvail, len(data)) writeLen := min(bufAvail, len(data))
pk := r.MakeDataPacket(data[0:writeLen], nil) pk := r.M.makeDataPacket(r.FdNum, data[0:writeLen], nil)
pk.Eof = isEof && (writeLen == len(data)) pk.Eof = isEof && (writeLen == len(data))
r.BufSize += writeLen r.BufSize += writeLen
data = data[writeLen:] data = data[writeLen:]
r.sendPacket_unlock(sender, pk) r.sendPacket_unlock(pk)
if len(data) == 0 { if len(data) == 0 {
return true return true
} }
@ -103,25 +102,13 @@ func min(v1 int, v2 int) int {
return v2 return v2
} }
func (r *FdReader) MakeDataPacket(data []byte, err error) *packet.DataPacketType {
pk := packet.MakeDataPacket()
pk.SessionId = r.SessionId
pk.CmdId = r.CmdId
pk.FdNum = r.FdNum
pk.Data = string(data)
if err != nil {
pk.Error = err.Error()
}
return pk
}
func (r *FdReader) isClosed() bool { func (r *FdReader) isClosed() bool {
r.CVar.L.Lock() r.CVar.L.Lock()
defer r.CVar.L.Unlock() defer r.CVar.L.Unlock()
return r.Closed return r.Closed
} }
func (r *FdReader) ReadLoop(wg *sync.WaitGroup, sender *packet.PacketSender) { func (r *FdReader) ReadLoop(wg *sync.WaitGroup) {
defer r.Close() defer r.Close()
defer wg.Done() defer wg.Done()
buf := make([]byte, 4096) buf := make([]byte, 4096)
@ -131,14 +118,17 @@ func (r *FdReader) ReadLoop(wg *sync.WaitGroup, sender *packet.PacketSender) {
return // should not send data or error if we already closed the fd return // should not send data or error if we already closed the fd
} }
if nr > 0 || err == io.EOF { if nr > 0 || err == io.EOF {
isOpen := r.WriteWait(sender, buf[0:nr], (err == io.EOF)) isOpen := r.WriteWait(buf[0:nr], (err == io.EOF))
if !isOpen { if !isOpen {
return return
} }
if err == io.EOF {
return
}
} }
if err != nil { if err != nil {
errPk := r.MakeDataPacket(nil, err) errPk := r.M.makeDataPacket(r.FdNum, nil, err)
sender.SendPacket(errPk) r.M.sendPacket(errPk)
return return
} }
} }

View File

@ -4,37 +4,32 @@
// License, v. 2.0. If a copy of the MPL was not distributed with this // License, v. 2.0. If a copy of the MPL was not distributed with this
// file, You can obtain one at https://mozilla.org/MPL/2.0/. // file, You can obtain one at https://mozilla.org/MPL/2.0/.
package shexec package mpio
import ( import (
"fmt" "fmt"
"os" "os"
"sync" "sync"
"github.com/scripthaus-dev/mshell/pkg/packet"
) )
const MaxSingleWriteSize = 4 * 1024
type FdWriter struct { type FdWriter struct {
CVar *sync.Cond CVar *sync.Cond
SessionId string M *Multiplexer
CmdId string FdNum int
FdNum int Buffer []byte
Buffer []byte Fd *os.File
Fd *os.File Eof bool
Eof bool Closed bool
Closed bool
} }
func MakeFdWriter(c *ShExecType, fd *os.File, fdNum int) *FdWriter { func MakeFdWriter(m *Multiplexer, fd *os.File, fdNum int) *FdWriter {
return &FdWriter{ fw := &FdWriter{
CVar: sync.NewCond(&sync.Mutex{}), CVar: sync.NewCond(&sync.Mutex{}),
Fd: fd, Fd: fd,
SessionId: c.RunPacket.SessionId, M: m,
CmdId: c.RunPacket.CmdId, FdNum: fdNum,
FdNum: fdNum,
} }
return fw
} }
func (w *FdWriter) Close() { func (w *FdWriter) Close() {
@ -64,18 +59,6 @@ func (w *FdWriter) WaitForData() ([]byte, bool) {
} }
} }
func (w *FdWriter) MakeDataAckPacket(ackLen int, err error) *packet.DataAckPacketType {
ack := packet.MakeDataAckPacket()
ack.SessionId = w.SessionId
ack.CmdId = w.CmdId
ack.FdNum = w.FdNum
ack.AckLen = ackLen
if err != nil {
ack.Error = err.Error()
}
return ack
}
func (w *FdWriter) AddData(data []byte, eof bool) error { func (w *FdWriter) AddData(data []byte, eof bool) error {
w.CVar.L.Lock() w.CVar.L.Lock()
defer w.CVar.L.Unlock() defer w.CVar.L.Unlock()
@ -95,7 +78,7 @@ func (w *FdWriter) AddData(data []byte, eof bool) error {
return nil return nil
} }
func (w *FdWriter) WriteLoop(sender *packet.PacketSender) { func (w *FdWriter) WriteLoop() {
defer w.Close() defer w.Close()
for { for {
data, isEof := w.WaitForData() data, isEof := w.WaitForData()
@ -107,8 +90,8 @@ func (w *FdWriter) WriteLoop(sender *packet.PacketSender) {
chunkSize := min(len(data), MaxSingleWriteSize) chunkSize := min(len(data), MaxSingleWriteSize)
chunk := data[0:chunkSize] chunk := data[0:chunkSize]
nw, err := w.Fd.Write(chunk) nw, err := w.Fd.Write(chunk)
ack := w.MakeDataAckPacket(nw, err) ack := w.M.makeDataAckPacket(w.FdNum, nw, err)
sender.SendPacket(ack) w.M.sendPacket(ack)
if err != nil { if err != nil {
return return
} }

206
pkg/mpio/mpio.go Normal file
View File

@ -0,0 +1,206 @@
// Copyright 2022 Dashborg Inc
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this
// file, You can obtain one at https://mozilla.org/MPL/2.0/.
package mpio
import (
"fmt"
"os"
"sync"
"github.com/scripthaus-dev/mshell/pkg/packet"
)
const ReadBufSize = 128 * 1024
const WriteBufSize = 128 * 1024
const MaxSingleWriteSize = 4 * 1024
type Multiplexer struct {
Lock *sync.Mutex
SessionId string
CmdId string
FdReaders map[int]*FdReader // synchronized
FdWriters map[int]*FdWriter // synchronized
CloseAfterStart []*os.File // synchronized
Sender *packet.PacketSender
Input chan packet.PacketType
Started bool
}
func MakeMultiplexer(sessionId string, cmdId string) *Multiplexer {
return &Multiplexer{
Lock: &sync.Mutex{},
SessionId: sessionId,
CmdId: cmdId,
FdReaders: make(map[int]*FdReader),
FdWriters: make(map[int]*FdWriter),
}
}
func (m *Multiplexer) Close() {
m.Lock.Lock()
defer m.Lock.Unlock()
for _, fd := range m.FdReaders {
fd.Close()
}
for _, fd := range m.FdWriters {
fd.Close()
}
for _, fd := range m.CloseAfterStart {
fd.Close()
}
}
// returns the *writer* to connect to process, reader is put in FdReaders
func (m *Multiplexer) MakeReaderPipe(fdNum int) (*os.File, error) {
pr, pw, err := os.Pipe()
if err != nil {
return nil, err
}
m.Lock.Lock()
defer m.Lock.Unlock()
m.FdReaders[fdNum] = MakeFdReader(m, pr, fdNum)
m.CloseAfterStart = append(m.CloseAfterStart, pw)
return pw, nil
}
// returns the *reader* to connect to process, writer is put in FdWriters
func (m *Multiplexer) MakeWriterPipe(fdNum int) (*os.File, error) {
pr, pw, err := os.Pipe()
if err != nil {
return nil, err
}
m.Lock.Lock()
defer m.Lock.Unlock()
m.FdWriters[fdNum] = MakeFdWriter(m, pw, fdNum)
m.CloseAfterStart = append(m.CloseAfterStart, pr)
return pr, nil
}
func (m *Multiplexer) makeDataAckPacket(fdNum int, ackLen int, err error) *packet.DataAckPacketType {
ack := packet.MakeDataAckPacket()
ack.SessionId = m.SessionId
ack.CmdId = m.CmdId
ack.FdNum = fdNum
ack.AckLen = ackLen
if err != nil {
ack.Error = err.Error()
}
return ack
}
func (m *Multiplexer) makeDataPacket(fdNum int, data []byte, err error) *packet.DataPacketType {
pk := packet.MakeDataPacket()
pk.SessionId = m.SessionId
pk.CmdId = m.CmdId
pk.FdNum = fdNum
pk.Data = string(data)
if err != nil {
pk.Error = err.Error()
}
return pk
}
func (m *Multiplexer) sendPacket(p packet.PacketType) {
m.Sender.SendPacket(p)
}
func (m *Multiplexer) launchWriters() {
m.Lock.Lock()
defer m.Lock.Unlock()
for _, fw := range m.FdWriters {
go fw.WriteLoop()
}
}
func (m *Multiplexer) launchReaders(wg *sync.WaitGroup) {
m.Lock.Lock()
defer m.Lock.Unlock()
wg.Add(len(m.FdReaders))
for _, fr := range m.FdReaders {
go fr.ReadLoop(wg)
}
}
func (m *Multiplexer) startIO(packetCh chan packet.PacketType, sender *packet.PacketSender) {
m.Lock.Lock()
defer m.Lock.Unlock()
if m.Started {
panic("Multiplexer is already running, cannot start again")
}
m.Input = packetCh
m.Sender = sender
m.Started = true
}
func (m *Multiplexer) runPacketInputLoop() {
for pk := range m.Input {
if pk.GetType() == packet.DataPacketStr {
dataPacket := pk.(*packet.DataPacketType)
err := m.processDataPacket(dataPacket)
if err != nil {
errPacket := m.makeDataAckPacket(dataPacket.FdNum, 0, err)
m.sendPacket(errPacket)
}
continue
}
if pk.GetType() == packet.DataAckPacketStr {
ackPacket := pk.(*packet.DataAckPacketType)
m.processAckPacket(ackPacket)
}
// other packet types are ignored
}
}
func (m *Multiplexer) processDataPacket(dataPacket *packet.DataPacketType) error {
m.Lock.Lock()
defer m.Lock.Unlock()
fw := m.FdWriters[dataPacket.FdNum]
if fw == nil {
// add a closed FdWriter as a placeholder so we only send one error
fw := MakeFdWriter(m, nil, dataPacket.FdNum)
fw.Close()
m.FdWriters[dataPacket.FdNum] = fw
return fmt.Errorf("write to closed file")
}
err := fw.AddData([]byte(dataPacket.Data), dataPacket.Eof)
if err != nil {
fw.Close()
return err
}
return nil
}
func (m *Multiplexer) processAckPacket(ackPacket *packet.DataAckPacketType) {
m.Lock.Lock()
defer m.Lock.Unlock()
fr := m.FdReaders[ackPacket.FdNum]
if fr == nil {
return
}
fr.NotifyAck(ackPacket.AckLen)
}
func (m *Multiplexer) closeTempStartFds() {
m.Lock.Lock()
defer m.Lock.Unlock()
for _, fd := range m.CloseAfterStart {
fd.Close()
}
m.CloseAfterStart = nil
}
func (m *Multiplexer) RunIOAndWait(packetCh chan packet.PacketType, sender *packet.PacketSender) {
m.startIO(packetCh, sender)
m.closeTempStartFds()
var wg sync.WaitGroup
m.launchReaders(&wg)
m.launchWriters()
go m.runPacketInputLoop()
wg.Wait()
}

View File

@ -19,6 +19,7 @@ import (
"github.com/creack/pty" "github.com/creack/pty"
"github.com/google/uuid" "github.com/google/uuid"
"github.com/scripthaus-dev/mshell/pkg/base" "github.com/scripthaus-dev/mshell/pkg/base"
"github.com/scripthaus-dev/mshell/pkg/mpio"
"github.com/scripthaus-dev/mshell/pkg/packet" "github.com/scripthaus-dev/mshell/pkg/packet"
) )
@ -26,49 +27,33 @@ const DefaultRows = 25
const DefaultCols = 80 const DefaultCols = 80
const MaxRows = 1024 const MaxRows = 1024
const MaxCols = 1024 const MaxCols = 1024
const ReadBufSize = 128 * 1024
const WriteBufSize = 128 * 1024
const MaxFdNum = 1023 const MaxFdNum = 1023
const FirstExtraFilesFdNum = 3 const FirstExtraFilesFdNum = 3
type ShExecType struct { type ShExecType struct {
Lock *sync.Mutex Lock *sync.Mutex
StartTs time.Time StartTs time.Time
RunPacket *packet.RunPacketType RunPacket *packet.RunPacketType
FileNames *base.CommandFileNames FileNames *base.CommandFileNames
Cmd *exec.Cmd Cmd *exec.Cmd
CmdPty *os.File CmdPty *os.File
FdReaders map[int]*FdReader // synchronized Multiplexer *mpio.Multiplexer
FdWriters map[int]*FdWriter // synchronized
CloseAfterStart []*os.File // synchronized
} }
func MakeShExec(pk *packet.RunPacketType) *ShExecType { func MakeShExec(pk *packet.RunPacketType) *ShExecType {
return &ShExecType{ return &ShExecType{
Lock: &sync.Mutex{}, Lock: &sync.Mutex{},
StartTs: time.Now(), StartTs: time.Now(),
RunPacket: pk, RunPacket: pk,
FdReaders: make(map[int]*FdReader), Multiplexer: mpio.MakeMultiplexer(pk.SessionId, pk.CmdId),
FdWriters: make(map[int]*FdWriter),
} }
} }
func (c *ShExecType) Close() { func (c *ShExecType) Close() {
c.Lock.Lock()
defer c.Lock.Unlock()
if c.CmdPty != nil { if c.CmdPty != nil {
c.CmdPty.Close() c.CmdPty.Close()
} }
for _, fd := range c.FdReaders { c.Multiplexer.Close()
fd.Close()
}
for _, fw := range c.FdWriters {
fw.Close()
}
for _, fd := range c.CloseAfterStart {
fd.Close()
}
} }
func (c *ShExecType) MakeCmdStartPacket() *packet.CmdStartPacketType { func (c *ShExecType) MakeCmdStartPacket() *packet.CmdStartPacketType {
@ -224,116 +209,9 @@ func RunCommand(pk *packet.RunPacketType, sender *packet.PacketSender) (*ShExecT
} }
} }
// returns the *writer* to connect to process, reader is put in FdReaders
func (cmd *ShExecType) makeReaderPipe(fdNum int) (*os.File, error) {
pr, pw, err := os.Pipe()
if err != nil {
return nil, err
}
cmd.Lock.Lock()
defer cmd.Lock.Unlock()
cmd.FdReaders[fdNum] = MakeFdReader(cmd, pr, fdNum)
cmd.CloseAfterStart = append(cmd.CloseAfterStart, pw)
return pw, nil
}
// returns the *reader* to connect to process, writer is put in FdWriters
func (cmd *ShExecType) makeWriterPipe(fdNum int) (*os.File, error) {
pr, pw, err := os.Pipe()
if err != nil {
return nil, err
}
cmd.Lock.Lock()
defer cmd.Lock.Unlock()
cmd.FdWriters[fdNum] = MakeFdWriter(cmd, pw, fdNum)
cmd.CloseAfterStart = append(cmd.CloseAfterStart, pr)
return pr, nil
}
func (cmd *ShExecType) MakeDataAckPacket(fdNum int, ackLen int, err error) *packet.DataAckPacketType {
ack := packet.MakeDataAckPacket()
ack.SessionId = cmd.RunPacket.SessionId
ack.CmdId = cmd.RunPacket.CmdId
ack.FdNum = fdNum
ack.AckLen = ackLen
if err != nil {
ack.Error = err.Error()
}
return ack
}
func (cmd *ShExecType) launchWriters(sender *packet.PacketSender) {
cmd.Lock.Lock()
defer cmd.Lock.Unlock()
for _, fw := range cmd.FdWriters {
go fw.WriteLoop(sender)
}
}
func (cmd *ShExecType) processDataPacket(dataPacket *packet.DataPacketType) error {
cmd.Lock.Lock()
defer cmd.Lock.Unlock()
fw := cmd.FdWriters[dataPacket.FdNum]
if fw == nil {
// add a closed FdWriter as a placeholder so we only send one error
fw := MakeFdWriter(cmd, nil, dataPacket.FdNum)
fw.Close()
cmd.FdWriters[dataPacket.FdNum] = fw
return fmt.Errorf("write to closed file")
}
err := fw.AddData([]byte(dataPacket.Data), dataPacket.Eof)
if err != nil {
fw.Close()
return err
}
return nil
}
func (cmd *ShExecType) processAckPacket(ackPacket *packet.DataAckPacketType) {
cmd.Lock.Lock()
defer cmd.Lock.Unlock()
fr := cmd.FdReaders[ackPacket.FdNum]
if fr == nil {
return
}
fr.NotifyAck(ackPacket.AckLen)
}
func (cmd *ShExecType) runPacketInputLoop(packetCh chan packet.PacketType, sender *packet.PacketSender) {
for pk := range packetCh {
if pk.GetType() == packet.DataPacketStr {
dataPacket := pk.(*packet.DataPacketType)
err := cmd.processDataPacket(dataPacket)
if err != nil {
errPacket := cmd.MakeDataAckPacket(dataPacket.FdNum, 0, err)
sender.SendPacket(errPacket)
}
continue
}
if pk.GetType() == packet.DataAckPacketStr {
ackPacket := pk.(*packet.DataAckPacketType)
cmd.processAckPacket(ackPacket)
}
// other packet types are ignored
}
}
func (cmd *ShExecType) launchReaders(wg *sync.WaitGroup, sender *packet.PacketSender) {
cmd.Lock.Lock()
defer cmd.Lock.Unlock()
wg.Add(len(cmd.FdReaders))
for _, fr := range cmd.FdReaders {
go fr.ReadLoop(wg, sender)
}
}
func (cmd *ShExecType) RunIOAndWait(packetCh chan packet.PacketType, sender *packet.PacketSender) { func (cmd *ShExecType) RunIOAndWait(packetCh chan packet.PacketType, sender *packet.PacketSender) {
var wg sync.WaitGroup cmd.Multiplexer.RunIOAndWait(packetCh, sender)
cmd.launchReaders(&wg, sender)
cmd.launchWriters(sender)
go cmd.runPacketInputLoop(packetCh, sender)
donePacket := cmd.WaitForCommand() donePacket := cmd.WaitForCommand()
wg.Wait()
sender.SendPacket(donePacket) sender.SendPacket(donePacket)
} }
@ -345,17 +223,17 @@ func runCommandSimple(pk *packet.RunPacketType, sender *packet.PacketSender) (*S
cmd.Cmd.Dir = pk.Cwd cmd.Cmd.Dir = pk.Cwd
} }
var err error var err error
cmd.Cmd.Stdin, err = cmd.makeWriterPipe(0) cmd.Cmd.Stdin, err = cmd.Multiplexer.MakeWriterPipe(0)
if err != nil { if err != nil {
cmd.Close() cmd.Close()
return nil, err return nil, err
} }
cmd.Cmd.Stdout, err = cmd.makeReaderPipe(1) cmd.Cmd.Stdout, err = cmd.Multiplexer.MakeReaderPipe(1)
if err != nil { if err != nil {
cmd.Close() cmd.Close()
return nil, err return nil, err
} }
cmd.Cmd.Stderr, err = cmd.makeReaderPipe(2) cmd.Cmd.Stderr, err = cmd.Multiplexer.MakeReaderPipe(2)
if err != nil { if err != nil {
cmd.Close() cmd.Close()
return nil, err return nil, err
@ -391,7 +269,7 @@ func runCommandSimple(pk *packet.RunPacketType, sender *packet.PacketSender) (*S
} }
if rfd.Read { if rfd.Read {
// client file is open for reading, so we make a writer pipe // client file is open for reading, so we make a writer pipe
extraFiles[rfd.FdNum], err = cmd.makeWriterPipe(rfd.FdNum) extraFiles[rfd.FdNum], err = cmd.Multiplexer.MakeWriterPipe(rfd.FdNum)
if err != nil { if err != nil {
cmd.Close() cmd.Close()
return nil, err return nil, err
@ -399,7 +277,7 @@ func runCommandSimple(pk *packet.RunPacketType, sender *packet.PacketSender) (*S
} }
if rfd.Write { if rfd.Write {
// client file is open for writing, so we make a reader pipe // client file is open for writing, so we make a reader pipe
extraFiles[rfd.FdNum], err = cmd.makeReaderPipe(rfd.FdNum) extraFiles[rfd.FdNum], err = cmd.Multiplexer.MakeReaderPipe(rfd.FdNum)
if err != nil { if err != nil {
cmd.Close() cmd.Close()
return nil, err return nil, err
@ -415,10 +293,6 @@ func runCommandSimple(pk *packet.RunPacketType, sender *packet.PacketSender) (*S
cmd.Close() cmd.Close()
return nil, err return nil, err
} }
for _, fd := range cmd.CloseAfterStart {
fd.Close()
}
cmd.CloseAfterStart = nil
return cmd, nil return cmd, nil
} }