2022-06-24 19:24:02 +02:00
|
|
|
// Copyright 2022 Dashborg Inc
|
|
|
|
//
|
|
|
|
// This Source Code Form is subject to the terms of the Mozilla Public
|
|
|
|
// License, v. 2.0. If a copy of the MPL was not distributed with this
|
|
|
|
// file, You can obtain one at https://mozilla.org/MPL/2.0/.
|
|
|
|
|
|
|
|
package mpio
|
|
|
|
|
|
|
|
import (
|
2022-06-25 08:42:00 +02:00
|
|
|
"encoding/base64"
|
2022-06-24 19:24:02 +02:00
|
|
|
"fmt"
|
|
|
|
"os"
|
|
|
|
"sync"
|
|
|
|
|
|
|
|
"github.com/scripthaus-dev/mshell/pkg/packet"
|
|
|
|
)
|
|
|
|
|
2022-06-25 08:42:00 +02:00
|
|
|
// Size limits for fd I/O in this package. They are not referenced in this
// file section; presumably they are consumed by FdReader/FdWriter — confirm
// there before changing them.
const (
	ReadBufSize        = 32 << 10 // 32 KiB
	WriteBufSize       = 32 << 10 // 32 KiB
	MaxSingleWriteSize = 4 << 10  // 4 KiB
)
|
|
|
|
|
|
|
|
type Multiplexer struct {
|
|
|
|
Lock *sync.Mutex
|
|
|
|
SessionId string
|
|
|
|
CmdId string
|
|
|
|
FdReaders map[int]*FdReader // synchronized
|
|
|
|
FdWriters map[int]*FdWriter // synchronized
|
|
|
|
CloseAfterStart []*os.File // synchronized
|
|
|
|
|
|
|
|
Sender *packet.PacketSender
|
|
|
|
Input chan packet.PacketType
|
|
|
|
Started bool
|
2022-06-25 08:42:00 +02:00
|
|
|
|
|
|
|
Debug bool
|
2022-06-24 19:24:02 +02:00
|
|
|
}
|
|
|
|
|
|
|
|
func MakeMultiplexer(sessionId string, cmdId string) *Multiplexer {
|
|
|
|
return &Multiplexer{
|
|
|
|
Lock: &sync.Mutex{},
|
|
|
|
SessionId: sessionId,
|
|
|
|
CmdId: cmdId,
|
|
|
|
FdReaders: make(map[int]*FdReader),
|
|
|
|
FdWriters: make(map[int]*FdWriter),
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
func (m *Multiplexer) Close() {
|
|
|
|
m.Lock.Lock()
|
|
|
|
defer m.Lock.Unlock()
|
|
|
|
|
2022-06-24 22:25:09 +02:00
|
|
|
for _, fr := range m.FdReaders {
|
|
|
|
fr.Close()
|
2022-06-24 19:24:02 +02:00
|
|
|
}
|
2022-06-24 22:25:09 +02:00
|
|
|
for _, fw := range m.FdWriters {
|
|
|
|
fw.Close()
|
2022-06-24 19:24:02 +02:00
|
|
|
}
|
|
|
|
for _, fd := range m.CloseAfterStart {
|
|
|
|
fd.Close()
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
2022-06-24 22:25:09 +02:00
|
|
|
func (m *Multiplexer) HandleInputDone() {
|
|
|
|
m.Lock.Lock()
|
|
|
|
defer m.Lock.Unlock()
|
|
|
|
|
|
|
|
// close readers (obviously the done command needs no more input)
|
|
|
|
for _, fr := range m.FdReaders {
|
|
|
|
fr.Close()
|
|
|
|
}
|
|
|
|
|
|
|
|
// ensure EOF on all writers (ignore error)
|
|
|
|
for _, fw := range m.FdWriters {
|
|
|
|
fw.AddData(nil, true)
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
2022-06-24 19:24:02 +02:00
|
|
|
// returns the *writer* to connect to process, reader is put in FdReaders
|
|
|
|
func (m *Multiplexer) MakeReaderPipe(fdNum int) (*os.File, error) {
|
|
|
|
pr, pw, err := os.Pipe()
|
|
|
|
if err != nil {
|
|
|
|
return nil, err
|
|
|
|
}
|
|
|
|
m.Lock.Lock()
|
|
|
|
defer m.Lock.Unlock()
|
2022-06-24 22:25:09 +02:00
|
|
|
m.FdReaders[fdNum] = MakeFdReader(m, pr, fdNum, true)
|
2022-06-24 19:24:02 +02:00
|
|
|
m.CloseAfterStart = append(m.CloseAfterStart, pw)
|
|
|
|
return pw, nil
|
|
|
|
}
|
|
|
|
|
|
|
|
// returns the *reader* to connect to process, writer is put in FdWriters
|
|
|
|
func (m *Multiplexer) MakeWriterPipe(fdNum int) (*os.File, error) {
|
|
|
|
pr, pw, err := os.Pipe()
|
|
|
|
if err != nil {
|
|
|
|
return nil, err
|
|
|
|
}
|
|
|
|
m.Lock.Lock()
|
|
|
|
defer m.Lock.Unlock()
|
2022-06-24 22:25:09 +02:00
|
|
|
m.FdWriters[fdNum] = MakeFdWriter(m, pw, fdNum, true)
|
2022-06-24 19:24:02 +02:00
|
|
|
m.CloseAfterStart = append(m.CloseAfterStart, pr)
|
|
|
|
return pr, nil
|
|
|
|
}
|
|
|
|
|
2022-06-26 10:41:58 +02:00
|
|
|
func (m *Multiplexer) MakeStringFdReader(fdNum int, contents string) error {
|
|
|
|
pw, err := m.MakeReaderPipe(fdNum)
|
|
|
|
if err != nil {
|
|
|
|
return err
|
|
|
|
}
|
|
|
|
go func() {
|
|
|
|
pw.Write([]byte(contents))
|
|
|
|
pw.Close()
|
|
|
|
}()
|
|
|
|
return nil
|
|
|
|
}
|
|
|
|
|
2022-06-25 08:42:00 +02:00
|
|
|
func (m *Multiplexer) MakeRawFdReader(fdNum int, fd *os.File, shouldClose bool) {
|
2022-06-24 22:25:09 +02:00
|
|
|
m.Lock.Lock()
|
|
|
|
defer m.Lock.Unlock()
|
2022-06-25 08:42:00 +02:00
|
|
|
m.FdReaders[fdNum] = MakeFdReader(m, fd, fdNum, shouldClose)
|
2022-06-24 22:25:09 +02:00
|
|
|
}
|
|
|
|
|
2022-06-25 08:42:00 +02:00
|
|
|
func (m *Multiplexer) MakeRawFdWriter(fdNum int, fd *os.File, shouldClose bool) {
|
2022-06-24 22:25:09 +02:00
|
|
|
m.Lock.Lock()
|
|
|
|
defer m.Lock.Unlock()
|
2022-06-25 08:42:00 +02:00
|
|
|
m.FdWriters[fdNum] = MakeFdWriter(m, fd, fdNum, shouldClose)
|
2022-06-24 22:25:09 +02:00
|
|
|
}
|
|
|
|
|
2022-06-24 19:24:02 +02:00
|
|
|
func (m *Multiplexer) makeDataAckPacket(fdNum int, ackLen int, err error) *packet.DataAckPacketType {
|
|
|
|
ack := packet.MakeDataAckPacket()
|
|
|
|
ack.SessionId = m.SessionId
|
|
|
|
ack.CmdId = m.CmdId
|
|
|
|
ack.FdNum = fdNum
|
|
|
|
ack.AckLen = ackLen
|
|
|
|
if err != nil {
|
|
|
|
ack.Error = err.Error()
|
|
|
|
}
|
|
|
|
return ack
|
|
|
|
}
|
|
|
|
|
|
|
|
func (m *Multiplexer) makeDataPacket(fdNum int, data []byte, err error) *packet.DataPacketType {
|
|
|
|
pk := packet.MakeDataPacket()
|
|
|
|
pk.SessionId = m.SessionId
|
|
|
|
pk.CmdId = m.CmdId
|
|
|
|
pk.FdNum = fdNum
|
2022-06-25 08:42:00 +02:00
|
|
|
pk.Data64 = base64.StdEncoding.EncodeToString(data)
|
2022-06-24 19:24:02 +02:00
|
|
|
if err != nil {
|
|
|
|
pk.Error = err.Error()
|
|
|
|
}
|
|
|
|
return pk
|
|
|
|
}
|
|
|
|
|
|
|
|
func (m *Multiplexer) sendPacket(p packet.PacketType) {
|
|
|
|
m.Sender.SendPacket(p)
|
|
|
|
}
|
|
|
|
|
2022-06-24 22:25:09 +02:00
|
|
|
func (m *Multiplexer) launchWriters(wg *sync.WaitGroup) {
|
2022-06-24 19:24:02 +02:00
|
|
|
m.Lock.Lock()
|
|
|
|
defer m.Lock.Unlock()
|
2022-06-24 22:25:09 +02:00
|
|
|
if wg != nil {
|
|
|
|
wg.Add(len(m.FdWriters))
|
|
|
|
}
|
2022-06-24 19:24:02 +02:00
|
|
|
for _, fw := range m.FdWriters {
|
2022-06-24 22:25:09 +02:00
|
|
|
go fw.WriteLoop(wg)
|
2022-06-24 19:24:02 +02:00
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
func (m *Multiplexer) launchReaders(wg *sync.WaitGroup) {
|
|
|
|
m.Lock.Lock()
|
|
|
|
defer m.Lock.Unlock()
|
2022-06-24 22:25:09 +02:00
|
|
|
if wg != nil {
|
|
|
|
wg.Add(len(m.FdReaders))
|
|
|
|
}
|
2022-06-24 19:24:02 +02:00
|
|
|
for _, fr := range m.FdReaders {
|
|
|
|
go fr.ReadLoop(wg)
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
func (m *Multiplexer) startIO(packetCh chan packet.PacketType, sender *packet.PacketSender) {
|
|
|
|
m.Lock.Lock()
|
|
|
|
defer m.Lock.Unlock()
|
|
|
|
if m.Started {
|
|
|
|
panic("Multiplexer is already running, cannot start again")
|
|
|
|
}
|
|
|
|
m.Input = packetCh
|
|
|
|
m.Sender = sender
|
|
|
|
m.Started = true
|
|
|
|
}
|
|
|
|
|
2022-06-24 22:25:09 +02:00
|
|
|
func (m *Multiplexer) runPacketInputLoop() *packet.CmdDonePacketType {
|
|
|
|
defer m.HandleInputDone()
|
2022-06-24 19:24:02 +02:00
|
|
|
for pk := range m.Input {
|
2022-06-25 08:42:00 +02:00
|
|
|
if m.Debug {
|
|
|
|
fmt.Printf("PK> %s\n", packet.AsString(pk))
|
|
|
|
}
|
2022-06-24 19:24:02 +02:00
|
|
|
if pk.GetType() == packet.DataPacketStr {
|
|
|
|
dataPacket := pk.(*packet.DataPacketType)
|
|
|
|
err := m.processDataPacket(dataPacket)
|
|
|
|
if err != nil {
|
|
|
|
errPacket := m.makeDataAckPacket(dataPacket.FdNum, 0, err)
|
|
|
|
m.sendPacket(errPacket)
|
|
|
|
}
|
|
|
|
continue
|
|
|
|
}
|
|
|
|
if pk.GetType() == packet.DataAckPacketStr {
|
|
|
|
ackPacket := pk.(*packet.DataAckPacketType)
|
|
|
|
m.processAckPacket(ackPacket)
|
2022-06-24 22:25:09 +02:00
|
|
|
continue
|
|
|
|
}
|
|
|
|
if pk.GetType() == packet.CmdDonePacketStr {
|
|
|
|
donePacket := pk.(*packet.CmdDonePacketType)
|
|
|
|
return donePacket
|
2022-06-24 19:24:02 +02:00
|
|
|
}
|
2022-06-25 08:42:00 +02:00
|
|
|
if pk.GetType() == packet.ErrorPacketStr {
|
|
|
|
errPacket := pk.(*packet.ErrorPacketType)
|
|
|
|
// at this point, just send the error packet to stderr rather than try to do something special
|
|
|
|
fmt.Fprintf(os.Stderr, "%s\n", errPacket.Error)
|
|
|
|
return nil
|
|
|
|
}
|
|
|
|
if pk.GetType() == packet.RawPacketStr {
|
|
|
|
rawPacket := pk.(*packet.RawPacketType)
|
|
|
|
fmt.Fprintf(os.Stderr, "%s\n", rawPacket.Data)
|
|
|
|
continue
|
|
|
|
}
|
2022-06-24 19:24:02 +02:00
|
|
|
}
|
2022-06-24 22:25:09 +02:00
|
|
|
return nil
|
2022-06-24 19:24:02 +02:00
|
|
|
}
|
|
|
|
|
|
|
|
func (m *Multiplexer) processDataPacket(dataPacket *packet.DataPacketType) error {
|
2022-06-25 08:42:00 +02:00
|
|
|
realData, err := base64.StdEncoding.DecodeString(dataPacket.Data64)
|
|
|
|
if err != nil {
|
|
|
|
return fmt.Errorf("decoding base64 data: %w", err)
|
|
|
|
}
|
2022-06-24 19:24:02 +02:00
|
|
|
m.Lock.Lock()
|
|
|
|
defer m.Lock.Unlock()
|
|
|
|
fw := m.FdWriters[dataPacket.FdNum]
|
|
|
|
if fw == nil {
|
|
|
|
// add a closed FdWriter as a placeholder so we only send one error
|
2022-06-24 22:25:09 +02:00
|
|
|
fw := MakeFdWriter(m, nil, dataPacket.FdNum, false)
|
2022-06-24 19:24:02 +02:00
|
|
|
fw.Close()
|
|
|
|
m.FdWriters[dataPacket.FdNum] = fw
|
2022-06-25 08:42:00 +02:00
|
|
|
return fmt.Errorf("write to closed file (no fd)")
|
2022-06-24 19:24:02 +02:00
|
|
|
}
|
2022-06-25 08:42:00 +02:00
|
|
|
err = fw.AddData(realData, dataPacket.Eof)
|
2022-06-24 19:24:02 +02:00
|
|
|
if err != nil {
|
|
|
|
fw.Close()
|
|
|
|
return err
|
|
|
|
}
|
|
|
|
return nil
|
|
|
|
}
|
|
|
|
|
|
|
|
func (m *Multiplexer) processAckPacket(ackPacket *packet.DataAckPacketType) {
|
|
|
|
m.Lock.Lock()
|
|
|
|
defer m.Lock.Unlock()
|
|
|
|
fr := m.FdReaders[ackPacket.FdNum]
|
|
|
|
if fr == nil {
|
|
|
|
return
|
|
|
|
}
|
|
|
|
fr.NotifyAck(ackPacket.AckLen)
|
|
|
|
}
|
|
|
|
|
|
|
|
func (m *Multiplexer) closeTempStartFds() {
|
|
|
|
m.Lock.Lock()
|
|
|
|
defer m.Lock.Unlock()
|
|
|
|
for _, fd := range m.CloseAfterStart {
|
|
|
|
fd.Close()
|
|
|
|
}
|
|
|
|
m.CloseAfterStart = nil
|
|
|
|
}
|
|
|
|
|
2022-06-24 22:25:09 +02:00
|
|
|
func (m *Multiplexer) RunIOAndWait(packetCh chan packet.PacketType, sender *packet.PacketSender, waitOnReaders bool, waitOnWriters bool, waitForInputLoop bool) *packet.CmdDonePacketType {
|
2022-06-24 19:24:02 +02:00
|
|
|
m.startIO(packetCh, sender)
|
|
|
|
m.closeTempStartFds()
|
|
|
|
var wg sync.WaitGroup
|
2022-06-24 22:25:09 +02:00
|
|
|
if waitOnReaders {
|
|
|
|
m.launchReaders(&wg)
|
|
|
|
} else {
|
|
|
|
m.launchReaders(nil)
|
|
|
|
}
|
|
|
|
if waitOnWriters {
|
|
|
|
m.launchWriters(&wg)
|
|
|
|
} else {
|
|
|
|
m.launchWriters(nil)
|
|
|
|
}
|
|
|
|
var donePacket *packet.CmdDonePacketType
|
|
|
|
if waitForInputLoop {
|
|
|
|
wg.Add(1)
|
|
|
|
}
|
|
|
|
go func() {
|
|
|
|
if waitForInputLoop {
|
|
|
|
defer wg.Done()
|
|
|
|
}
|
|
|
|
pkRtn := m.runPacketInputLoop()
|
|
|
|
if pkRtn != nil {
|
|
|
|
m.Lock.Lock()
|
|
|
|
donePacket = pkRtn
|
|
|
|
m.Lock.Unlock()
|
|
|
|
}
|
|
|
|
}()
|
2022-06-24 19:24:02 +02:00
|
|
|
wg.Wait()
|
2022-06-24 22:25:09 +02:00
|
|
|
|
|
|
|
m.Lock.Lock()
|
|
|
|
defer m.Lock.Unlock()
|
|
|
|
return donePacket
|
2022-06-24 19:24:02 +02:00
|
|
|
}
|