move multiplexed IO to its own package independent of SHExecType (to use in mshell client)

This commit is contained in:
sawka 2022-06-24 10:24:02 -07:00
parent 4256ff5231
commit 0267836376
4 changed files with 270 additions and 217 deletions

View File

@ -4,7 +4,7 @@
// License, v. 2.0. If a copy of the MPL was not distributed with this
// file, You can obtain one at https://mozilla.org/MPL/2.0/.
package shexec
package mpio
import (
"io"
@ -16,23 +16,22 @@ import (
type FdReader struct {
CVar *sync.Cond
SessionId string
CmdId string
M *Multiplexer
FdNum int
Fd *os.File
BufSize int
Closed bool
}
func MakeFdReader(c *ShExecType, fd *os.File, fdNum int) *FdReader {
return &FdReader{
func MakeFdReader(m *Multiplexer, fd *os.File, fdNum int) *FdReader {
fr := &FdReader{
CVar: sync.NewCond(&sync.Mutex{}),
SessionId: c.RunPacket.SessionId,
CmdId: c.RunPacket.CmdId,
M: m,
FdNum: fdNum,
Fd: fd,
BufSize: 0,
}
return fr
}
func (r *FdReader) Close() {
@ -63,14 +62,14 @@ func (r *FdReader) NotifyAck(ackLen int) {
// !! inverse locking. must already hold the lock when you call this method.
// will *unlock*, send the packet, and then *relock* once it is done.
// this can prevent an unlikely deadlock where we are holding r.CVar.L and stuck on sender.SendCh
func (r *FdReader) sendPacket_unlock(sender *packet.PacketSender, pk packet.PacketType) {
func (r *FdReader) sendPacket_unlock(pk packet.PacketType) {
r.CVar.L.Unlock()
defer r.CVar.L.Lock()
sender.SendPacket(pk)
r.M.sendPacket(pk)
}
// returns (success)
func (r *FdReader) WriteWait(sender *packet.PacketSender, data []byte, isEof bool) bool {
func (r *FdReader) WriteWait(data []byte, isEof bool) bool {
r.CVar.L.Lock()
defer r.CVar.L.Unlock()
for {
@ -83,11 +82,11 @@ func (r *FdReader) WriteWait(sender *packet.PacketSender, data []byte, isEof boo
continue
}
writeLen := min(bufAvail, len(data))
pk := r.MakeDataPacket(data[0:writeLen], nil)
pk := r.M.makeDataPacket(r.FdNum, data[0:writeLen], nil)
pk.Eof = isEof && (writeLen == len(data))
r.BufSize += writeLen
data = data[writeLen:]
r.sendPacket_unlock(sender, pk)
r.sendPacket_unlock(pk)
if len(data) == 0 {
return true
}
@ -103,25 +102,13 @@ func min(v1 int, v2 int) int {
return v2
}
func (r *FdReader) MakeDataPacket(data []byte, err error) *packet.DataPacketType {
pk := packet.MakeDataPacket()
pk.SessionId = r.SessionId
pk.CmdId = r.CmdId
pk.FdNum = r.FdNum
pk.Data = string(data)
if err != nil {
pk.Error = err.Error()
}
return pk
}
func (r *FdReader) isClosed() bool {
r.CVar.L.Lock()
defer r.CVar.L.Unlock()
return r.Closed
}
func (r *FdReader) ReadLoop(wg *sync.WaitGroup, sender *packet.PacketSender) {
func (r *FdReader) ReadLoop(wg *sync.WaitGroup) {
defer r.Close()
defer wg.Done()
buf := make([]byte, 4096)
@ -131,14 +118,17 @@ func (r *FdReader) ReadLoop(wg *sync.WaitGroup, sender *packet.PacketSender) {
return // should not send data or error if we already closed the fd
}
if nr > 0 || err == io.EOF {
isOpen := r.WriteWait(sender, buf[0:nr], (err == io.EOF))
isOpen := r.WriteWait(buf[0:nr], (err == io.EOF))
if !isOpen {
return
}
if err == io.EOF {
return
}
}
if err != nil {
errPk := r.MakeDataPacket(nil, err)
sender.SendPacket(errPk)
errPk := r.M.makeDataPacket(r.FdNum, nil, err)
r.M.sendPacket(errPk)
return
}
}

View File

@ -4,22 +4,17 @@
// License, v. 2.0. If a copy of the MPL was not distributed with this
// file, You can obtain one at https://mozilla.org/MPL/2.0/.
package shexec
package mpio
import (
"fmt"
"os"
"sync"
"github.com/scripthaus-dev/mshell/pkg/packet"
)
const MaxSingleWriteSize = 4 * 1024
type FdWriter struct {
CVar *sync.Cond
SessionId string
CmdId string
M *Multiplexer
FdNum int
Buffer []byte
Fd *os.File
@ -27,14 +22,14 @@ type FdWriter struct {
Closed bool
}
func MakeFdWriter(c *ShExecType, fd *os.File, fdNum int) *FdWriter {
return &FdWriter{
func MakeFdWriter(m *Multiplexer, fd *os.File, fdNum int) *FdWriter {
fw := &FdWriter{
CVar: sync.NewCond(&sync.Mutex{}),
Fd: fd,
SessionId: c.RunPacket.SessionId,
CmdId: c.RunPacket.CmdId,
M: m,
FdNum: fdNum,
}
return fw
}
func (w *FdWriter) Close() {
@ -64,18 +59,6 @@ func (w *FdWriter) WaitForData() ([]byte, bool) {
}
}
func (w *FdWriter) MakeDataAckPacket(ackLen int, err error) *packet.DataAckPacketType {
ack := packet.MakeDataAckPacket()
ack.SessionId = w.SessionId
ack.CmdId = w.CmdId
ack.FdNum = w.FdNum
ack.AckLen = ackLen
if err != nil {
ack.Error = err.Error()
}
return ack
}
func (w *FdWriter) AddData(data []byte, eof bool) error {
w.CVar.L.Lock()
defer w.CVar.L.Unlock()
@ -95,7 +78,7 @@ func (w *FdWriter) AddData(data []byte, eof bool) error {
return nil
}
func (w *FdWriter) WriteLoop(sender *packet.PacketSender) {
func (w *FdWriter) WriteLoop() {
defer w.Close()
for {
data, isEof := w.WaitForData()
@ -107,8 +90,8 @@ func (w *FdWriter) WriteLoop(sender *packet.PacketSender) {
chunkSize := min(len(data), MaxSingleWriteSize)
chunk := data[0:chunkSize]
nw, err := w.Fd.Write(chunk)
ack := w.MakeDataAckPacket(nw, err)
sender.SendPacket(ack)
ack := w.M.makeDataAckPacket(w.FdNum, nw, err)
w.M.sendPacket(ack)
if err != nil {
return
}

206
pkg/mpio/mpio.go Normal file
View File

@ -0,0 +1,206 @@
// Copyright 2022 Dashborg Inc
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this
// file, You can obtain one at https://mozilla.org/MPL/2.0/.
package mpio
import (
"fmt"
"os"
"sync"
"github.com/scripthaus-dev/mshell/pkg/packet"
)
const ReadBufSize = 128 * 1024
const WriteBufSize = 128 * 1024
const MaxSingleWriteSize = 4 * 1024
type Multiplexer struct {
Lock *sync.Mutex
SessionId string
CmdId string
FdReaders map[int]*FdReader // synchronized
FdWriters map[int]*FdWriter // synchronized
CloseAfterStart []*os.File // synchronized
Sender *packet.PacketSender
Input chan packet.PacketType
Started bool
}
func MakeMultiplexer(sessionId string, cmdId string) *Multiplexer {
return &Multiplexer{
Lock: &sync.Mutex{},
SessionId: sessionId,
CmdId: cmdId,
FdReaders: make(map[int]*FdReader),
FdWriters: make(map[int]*FdWriter),
}
}
func (m *Multiplexer) Close() {
m.Lock.Lock()
defer m.Lock.Unlock()
for _, fd := range m.FdReaders {
fd.Close()
}
for _, fd := range m.FdWriters {
fd.Close()
}
for _, fd := range m.CloseAfterStart {
fd.Close()
}
}
// returns the *writer* to connect to process, reader is put in FdReaders
func (m *Multiplexer) MakeReaderPipe(fdNum int) (*os.File, error) {
pr, pw, err := os.Pipe()
if err != nil {
return nil, err
}
m.Lock.Lock()
defer m.Lock.Unlock()
m.FdReaders[fdNum] = MakeFdReader(m, pr, fdNum)
m.CloseAfterStart = append(m.CloseAfterStart, pw)
return pw, nil
}
// returns the *reader* to connect to process, writer is put in FdWriters
func (m *Multiplexer) MakeWriterPipe(fdNum int) (*os.File, error) {
pr, pw, err := os.Pipe()
if err != nil {
return nil, err
}
m.Lock.Lock()
defer m.Lock.Unlock()
m.FdWriters[fdNum] = MakeFdWriter(m, pw, fdNum)
m.CloseAfterStart = append(m.CloseAfterStart, pr)
return pr, nil
}
func (m *Multiplexer) makeDataAckPacket(fdNum int, ackLen int, err error) *packet.DataAckPacketType {
ack := packet.MakeDataAckPacket()
ack.SessionId = m.SessionId
ack.CmdId = m.CmdId
ack.FdNum = fdNum
ack.AckLen = ackLen
if err != nil {
ack.Error = err.Error()
}
return ack
}
func (m *Multiplexer) makeDataPacket(fdNum int, data []byte, err error) *packet.DataPacketType {
pk := packet.MakeDataPacket()
pk.SessionId = m.SessionId
pk.CmdId = m.CmdId
pk.FdNum = fdNum
pk.Data = string(data)
if err != nil {
pk.Error = err.Error()
}
return pk
}
func (m *Multiplexer) sendPacket(p packet.PacketType) {
m.Sender.SendPacket(p)
}
func (m *Multiplexer) launchWriters() {
m.Lock.Lock()
defer m.Lock.Unlock()
for _, fw := range m.FdWriters {
go fw.WriteLoop()
}
}
func (m *Multiplexer) launchReaders(wg *sync.WaitGroup) {
m.Lock.Lock()
defer m.Lock.Unlock()
wg.Add(len(m.FdReaders))
for _, fr := range m.FdReaders {
go fr.ReadLoop(wg)
}
}
func (m *Multiplexer) startIO(packetCh chan packet.PacketType, sender *packet.PacketSender) {
m.Lock.Lock()
defer m.Lock.Unlock()
if m.Started {
panic("Multiplexer is already running, cannot start again")
}
m.Input = packetCh
m.Sender = sender
m.Started = true
}
func (m *Multiplexer) runPacketInputLoop() {
for pk := range m.Input {
if pk.GetType() == packet.DataPacketStr {
dataPacket := pk.(*packet.DataPacketType)
err := m.processDataPacket(dataPacket)
if err != nil {
errPacket := m.makeDataAckPacket(dataPacket.FdNum, 0, err)
m.sendPacket(errPacket)
}
continue
}
if pk.GetType() == packet.DataAckPacketStr {
ackPacket := pk.(*packet.DataAckPacketType)
m.processAckPacket(ackPacket)
}
// other packet types are ignored
}
}
func (m *Multiplexer) processDataPacket(dataPacket *packet.DataPacketType) error {
m.Lock.Lock()
defer m.Lock.Unlock()
fw := m.FdWriters[dataPacket.FdNum]
if fw == nil {
// add a closed FdWriter as a placeholder so we only send one error
fw := MakeFdWriter(m, nil, dataPacket.FdNum)
fw.Close()
m.FdWriters[dataPacket.FdNum] = fw
return fmt.Errorf("write to closed file")
}
err := fw.AddData([]byte(dataPacket.Data), dataPacket.Eof)
if err != nil {
fw.Close()
return err
}
return nil
}
func (m *Multiplexer) processAckPacket(ackPacket *packet.DataAckPacketType) {
m.Lock.Lock()
defer m.Lock.Unlock()
fr := m.FdReaders[ackPacket.FdNum]
if fr == nil {
return
}
fr.NotifyAck(ackPacket.AckLen)
}
func (m *Multiplexer) closeTempStartFds() {
m.Lock.Lock()
defer m.Lock.Unlock()
for _, fd := range m.CloseAfterStart {
fd.Close()
}
m.CloseAfterStart = nil
}
func (m *Multiplexer) RunIOAndWait(packetCh chan packet.PacketType, sender *packet.PacketSender) {
m.startIO(packetCh, sender)
m.closeTempStartFds()
var wg sync.WaitGroup
m.launchReaders(&wg)
m.launchWriters()
go m.runPacketInputLoop()
wg.Wait()
}

View File

@ -19,6 +19,7 @@ import (
"github.com/creack/pty"
"github.com/google/uuid"
"github.com/scripthaus-dev/mshell/pkg/base"
"github.com/scripthaus-dev/mshell/pkg/mpio"
"github.com/scripthaus-dev/mshell/pkg/packet"
)
@ -26,8 +27,6 @@ const DefaultRows = 25
const DefaultCols = 80
const MaxRows = 1024
const MaxCols = 1024
const ReadBufSize = 128 * 1024
const WriteBufSize = 128 * 1024
const MaxFdNum = 1023
const FirstExtraFilesFdNum = 3
@ -38,9 +37,7 @@ type ShExecType struct {
FileNames *base.CommandFileNames
Cmd *exec.Cmd
CmdPty *os.File
FdReaders map[int]*FdReader // synchronized
FdWriters map[int]*FdWriter // synchronized
CloseAfterStart []*os.File // synchronized
Multiplexer *mpio.Multiplexer
}
func MakeShExec(pk *packet.RunPacketType) *ShExecType {
@ -48,27 +45,15 @@ func MakeShExec(pk *packet.RunPacketType) *ShExecType {
Lock: &sync.Mutex{},
StartTs: time.Now(),
RunPacket: pk,
FdReaders: make(map[int]*FdReader),
FdWriters: make(map[int]*FdWriter),
Multiplexer: mpio.MakeMultiplexer(pk.SessionId, pk.CmdId),
}
}
func (c *ShExecType) Close() {
c.Lock.Lock()
defer c.Lock.Unlock()
if c.CmdPty != nil {
c.CmdPty.Close()
}
for _, fd := range c.FdReaders {
fd.Close()
}
for _, fw := range c.FdWriters {
fw.Close()
}
for _, fd := range c.CloseAfterStart {
fd.Close()
}
c.Multiplexer.Close()
}
func (c *ShExecType) MakeCmdStartPacket() *packet.CmdStartPacketType {
@ -224,116 +209,9 @@ func RunCommand(pk *packet.RunPacketType, sender *packet.PacketSender) (*ShExecT
}
}
// returns the *writer* to connect to process, reader is put in FdReaders
func (cmd *ShExecType) makeReaderPipe(fdNum int) (*os.File, error) {
pr, pw, err := os.Pipe()
if err != nil {
return nil, err
}
cmd.Lock.Lock()
defer cmd.Lock.Unlock()
cmd.FdReaders[fdNum] = MakeFdReader(cmd, pr, fdNum)
cmd.CloseAfterStart = append(cmd.CloseAfterStart, pw)
return pw, nil
}
// returns the *reader* to connect to process, writer is put in FdWriters
func (cmd *ShExecType) makeWriterPipe(fdNum int) (*os.File, error) {
pr, pw, err := os.Pipe()
if err != nil {
return nil, err
}
cmd.Lock.Lock()
defer cmd.Lock.Unlock()
cmd.FdWriters[fdNum] = MakeFdWriter(cmd, pw, fdNum)
cmd.CloseAfterStart = append(cmd.CloseAfterStart, pr)
return pr, nil
}
func (cmd *ShExecType) MakeDataAckPacket(fdNum int, ackLen int, err error) *packet.DataAckPacketType {
ack := packet.MakeDataAckPacket()
ack.SessionId = cmd.RunPacket.SessionId
ack.CmdId = cmd.RunPacket.CmdId
ack.FdNum = fdNum
ack.AckLen = ackLen
if err != nil {
ack.Error = err.Error()
}
return ack
}
func (cmd *ShExecType) launchWriters(sender *packet.PacketSender) {
cmd.Lock.Lock()
defer cmd.Lock.Unlock()
for _, fw := range cmd.FdWriters {
go fw.WriteLoop(sender)
}
}
func (cmd *ShExecType) processDataPacket(dataPacket *packet.DataPacketType) error {
cmd.Lock.Lock()
defer cmd.Lock.Unlock()
fw := cmd.FdWriters[dataPacket.FdNum]
if fw == nil {
// add a closed FdWriter as a placeholder so we only send one error
fw := MakeFdWriter(cmd, nil, dataPacket.FdNum)
fw.Close()
cmd.FdWriters[dataPacket.FdNum] = fw
return fmt.Errorf("write to closed file")
}
err := fw.AddData([]byte(dataPacket.Data), dataPacket.Eof)
if err != nil {
fw.Close()
return err
}
return nil
}
func (cmd *ShExecType) processAckPacket(ackPacket *packet.DataAckPacketType) {
cmd.Lock.Lock()
defer cmd.Lock.Unlock()
fr := cmd.FdReaders[ackPacket.FdNum]
if fr == nil {
return
}
fr.NotifyAck(ackPacket.AckLen)
}
func (cmd *ShExecType) runPacketInputLoop(packetCh chan packet.PacketType, sender *packet.PacketSender) {
for pk := range packetCh {
if pk.GetType() == packet.DataPacketStr {
dataPacket := pk.(*packet.DataPacketType)
err := cmd.processDataPacket(dataPacket)
if err != nil {
errPacket := cmd.MakeDataAckPacket(dataPacket.FdNum, 0, err)
sender.SendPacket(errPacket)
}
continue
}
if pk.GetType() == packet.DataAckPacketStr {
ackPacket := pk.(*packet.DataAckPacketType)
cmd.processAckPacket(ackPacket)
}
// other packet types are ignored
}
}
func (cmd *ShExecType) launchReaders(wg *sync.WaitGroup, sender *packet.PacketSender) {
cmd.Lock.Lock()
defer cmd.Lock.Unlock()
wg.Add(len(cmd.FdReaders))
for _, fr := range cmd.FdReaders {
go fr.ReadLoop(wg, sender)
}
}
func (cmd *ShExecType) RunIOAndWait(packetCh chan packet.PacketType, sender *packet.PacketSender) {
var wg sync.WaitGroup
cmd.launchReaders(&wg, sender)
cmd.launchWriters(sender)
go cmd.runPacketInputLoop(packetCh, sender)
cmd.Multiplexer.RunIOAndWait(packetCh, sender)
donePacket := cmd.WaitForCommand()
wg.Wait()
sender.SendPacket(donePacket)
}
@ -345,17 +223,17 @@ func runCommandSimple(pk *packet.RunPacketType, sender *packet.PacketSender) (*S
cmd.Cmd.Dir = pk.Cwd
}
var err error
cmd.Cmd.Stdin, err = cmd.makeWriterPipe(0)
cmd.Cmd.Stdin, err = cmd.Multiplexer.MakeWriterPipe(0)
if err != nil {
cmd.Close()
return nil, err
}
cmd.Cmd.Stdout, err = cmd.makeReaderPipe(1)
cmd.Cmd.Stdout, err = cmd.Multiplexer.MakeReaderPipe(1)
if err != nil {
cmd.Close()
return nil, err
}
cmd.Cmd.Stderr, err = cmd.makeReaderPipe(2)
cmd.Cmd.Stderr, err = cmd.Multiplexer.MakeReaderPipe(2)
if err != nil {
cmd.Close()
return nil, err
@ -391,7 +269,7 @@ func runCommandSimple(pk *packet.RunPacketType, sender *packet.PacketSender) (*S
}
if rfd.Read {
// client file is open for reading, so we make a writer pipe
extraFiles[rfd.FdNum], err = cmd.makeWriterPipe(rfd.FdNum)
extraFiles[rfd.FdNum], err = cmd.Multiplexer.MakeWriterPipe(rfd.FdNum)
if err != nil {
cmd.Close()
return nil, err
@ -399,7 +277,7 @@ func runCommandSimple(pk *packet.RunPacketType, sender *packet.PacketSender) (*S
}
if rfd.Write {
// client file is open for writing, so we make a reader pipe
extraFiles[rfd.FdNum], err = cmd.makeReaderPipe(rfd.FdNum)
extraFiles[rfd.FdNum], err = cmd.Multiplexer.MakeReaderPipe(rfd.FdNum)
if err != nil {
cmd.Close()
return nil, err
@ -415,10 +293,6 @@ func runCommandSimple(pk *packet.RunPacketType, sender *packet.PacketSender) (*S
cmd.Close()
return nil, err
}
for _, fd := range cmd.CloseAfterStart {
fd.Close()
}
cmd.CloseAfterStart = nil
return cmd, nil
}