From 657440269141d01d9a741e5e83a107f9ab17cc04 Mon Sep 17 00:00:00 2001 From: sawka Date: Mon, 27 Jun 2022 12:14:07 -0700 Subject: [PATCH] create packetparser type, refactor to use --- main-mshell.go | 14 +++---- pkg/mpio/mpio.go | 12 +++--- pkg/packet/packet.go | 73 --------------------------------- pkg/packet/parser.go | 97 ++++++++++++++++++++++++++++++++++++++++++++ pkg/shexec/shexec.go | 14 +++---- 5 files changed, 117 insertions(+), 93 deletions(-) create mode 100644 pkg/packet/parser.go diff --git a/main-mshell.go b/main-mshell.go index 64ff09488..ea8df6151 100644 --- a/main-mshell.go +++ b/main-mshell.go @@ -38,10 +38,10 @@ func setupSingleSignals(cmd *shexec.ShExecType) { } func doSingle(ck base.CommandKey) { - packetCh := packet.PacketParser(os.Stdin) + packetParser := packet.MakePacketParser(os.Stdin) sender := packet.MakePacketSender(os.Stdout) var runPacket *packet.RunPacketType - for pk := range packetCh { + for pk := range packetParser.MainCh { if pk.GetType() == packet.PingPacketStr { continue } @@ -156,7 +156,7 @@ func doMain() { packet.SendErrorPacket(os.Stdout, err.Error()) return } - packetCh := packet.PacketParser(os.Stdin) + packetParser := packet.MakePacketParser(os.Stdin) sender := packet.MakePacketSender(os.Stdout) tailer, err := cmdtail.MakeTailer(sender.SendCh) if err != nil { @@ -172,7 +172,7 @@ func doMain() { initPacket.User = user.Username } sender.SendPacket(initPacket) - for pk := range packetCh { + for pk := range packetParser.MainCh { if pk.GetType() == packet.PingPacketStr { continue } @@ -212,7 +212,7 @@ func doMain() { } func handleRemote() { - packetCh := packet.PacketParser(os.Stdin) + packetParser := packet.MakePacketParser(os.Stdin) sender := packet.MakePacketSender(os.Stdout) defer func() { // wait for sender to complete @@ -223,7 +223,7 @@ func handleRemote() { initPacket.Version = MShellVersion sender.SendPacket(initPacket) var runPacket *packet.RunPacketType - for pk := range packetCh { + for pk := range packetParser.MainCh { if pk.GetType() == packet.PingPacketStr { continue } @@ -251,7 +251,7 @@ func handleRemote() { defer cmd.Close() startPacket := cmd.MakeCmdStartPacket() sender.SendPacket(startPacket) - cmd.RunRemoteIOAndWait(packetCh, sender) + cmd.RunRemoteIOAndWait(packetParser, sender) } func handleServer() { diff --git a/pkg/mpio/mpio.go b/pkg/mpio/mpio.go index ef56355f2..227625a18 100644 --- a/pkg/mpio/mpio.go +++ b/pkg/mpio/mpio.go @@ -28,7 +28,7 @@ type Multiplexer struct { CloseAfterStart []*os.File // synchronized Sender *packet.PacketSender - Input chan packet.PacketType + Input *packet.PacketParser Started bool Debug bool @@ -171,20 +171,20 @@ func (m *Multiplexer) launchReaders(wg *sync.WaitGroup) { } } -func (m *Multiplexer) startIO(packetCh chan packet.PacketType, sender *packet.PacketSender) { +func (m *Multiplexer) startIO(packetParser *packet.PacketParser, sender *packet.PacketSender) { m.Lock.Lock() defer m.Lock.Unlock() if m.Started { panic("Multiplexer is already running, cannot start again") } - m.Input = packetCh + m.Input = packetParser m.Sender = sender m.Started = true } func (m *Multiplexer) runPacketInputLoop() *packet.CmdDonePacketType { defer m.HandleInputDone() - for pk := range m.Input { + for pk := range m.Input.MainCh { if m.Debug { fmt.Printf("PK> %s\n", packet.AsString(pk)) } @@ -263,8 +263,8 @@ func (m *Multiplexer) closeTempStartFds() { m.CloseAfterStart = nil } -func (m *Multiplexer) RunIOAndWait(packetCh chan packet.PacketType, sender *packet.PacketSender, waitOnReaders bool, waitOnWriters bool, waitForInputLoop bool) *packet.CmdDonePacketType { - m.startIO(packetCh, sender) +func (m *Multiplexer) RunIOAndWait(packetParser *packet.PacketParser, sender *packet.PacketSender, waitOnReaders bool, waitOnWriters bool, waitForInputLoop bool) *packet.CmdDonePacketType { + m.startIO(packetParser, sender) m.closeTempStartFds() var wg sync.WaitGroup if waitOnReaders { diff --git a/pkg/packet/packet.go b/pkg/packet/packet.go index d996ae9a0..87848fce8 100644 --- a/pkg/packet/packet.go +++ b/pkg/packet/packet.go @@ -7,14 +7,11 @@ package packet import ( - "bufio" "bytes" "encoding/json" "fmt" "io" "reflect" - "strconv" - "strings" "sync" "github.com/scripthaus-dev/mshell/pkg/base" @@ -602,76 +599,6 @@ func (sender *PacketSender) SendMessage(fmtStr string, args ...interface{}) erro return sender.SendPacket(MakeMessagePacket(fmt.Sprintf(fmtStr, args...))) } -func CombinePacketParsers(p1 chan PacketType, p2 chan PacketType) chan PacketType { - rtnCh := make(chan PacketType) - var wg sync.WaitGroup - wg.Add(2) - go func() { - defer wg.Done() - for v := range p1 { - rtnCh <- v - } - }() - go func() { - defer wg.Done() - for v := range p2 { - rtnCh <- v - } - }() - go func() { - wg.Wait() - close(rtnCh) - }() - return rtnCh -} - -func PacketParser(input io.Reader) chan PacketType { - rtnCh := make(chan PacketType) - bufReader := bufio.NewReader(input) - go func() { - defer func() { - close(rtnCh) - }() - for { - line, err := bufReader.ReadString('\n') - if err == io.EOF { - return - } - if err != nil { - errPacket := MakeErrorPacket(fmt.Sprintf("reading packets from input: %v", err)) - rtnCh <- errPacket - return - } - if line == "\n" { - continue - } - // ##[len][json]\n - // ##14{"hello":true}\n - bracePos := strings.Index(line, "{") - if !strings.HasPrefix(line, "##") || bracePos == -1 { - rtnCh <- MakeRawPacket(line[:len(line)-1]) - continue - } - packetLen, err := strconv.Atoi(line[2:bracePos]) - if err != nil || packetLen != len(line)-bracePos-1 { - rtnCh <- MakeRawPacket(line[:len(line)-1]) - continue - } - pk, err := ParseJsonPacket([]byte(line[bracePos:])) - if err != nil { - errPk := MakeErrorPacket(fmt.Sprintf("parsing packet json from input: %v", err)) - rtnCh <- errPk - return - } - if pk.GetType() == DonePacketStr { - return - } - rtnCh <- pk - } - }() - return rtnCh -} - type ErrorReporter interface { ReportError(err error) } diff --git a/pkg/packet/parser.go b/pkg/packet/parser.go new file mode 100644 index 000000000..907ae83fe --- /dev/null +++ b/pkg/packet/parser.go @@ -0,0 +1,97 @@ +// Copyright 2022 Dashborg Inc +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at https://mozilla.org/MPL/2.0/. + +package packet + +import ( + "bufio" + "fmt" + "io" + "strconv" + "strings" + "sync" +) + +type PacketParser struct { + Lock *sync.Mutex + MainCh chan PacketType +} + +func CombinePacketParsers(p1 *PacketParser, p2 *PacketParser) *PacketParser { + rtnParser := &PacketParser{ + Lock: &sync.Mutex{}, + MainCh: make(chan PacketType), + } + var wg sync.WaitGroup + wg.Add(2) + go func() { + defer wg.Done() + for v := range p1.MainCh { + rtnParser.MainCh <- v + } + }() + go func() { + defer wg.Done() + for v := range p2.MainCh { + rtnParser.MainCh <- v + } + }() + go func() { + wg.Wait() + close(rtnParser.MainCh) + }() + return rtnParser +} + +func MakePacketParser(input io.Reader) *PacketParser { + parser := &PacketParser{ + Lock: &sync.Mutex{}, + MainCh: make(chan PacketType), + } + bufReader := bufio.NewReader(input) + go func() { + defer func() { + close(parser.MainCh) + }() + for { + line, err := bufReader.ReadString('\n') + if err == io.EOF { + return + } + if err != nil { + errPacket := MakeErrorPacket(fmt.Sprintf("reading packets from input: %v", err)) + parser.MainCh <- errPacket + return + } + if line == "\n" { + continue + } + // ##[len][json]\n + // ##14{"hello":true}\n + bracePos := strings.Index(line, "{") + if !strings.HasPrefix(line, "##") || bracePos == -1 { + parser.MainCh <- MakeRawPacket(line[:len(line)-1]) + continue + } + packetLen, err := strconv.Atoi(line[2:bracePos]) + if err != nil || packetLen != len(line)-bracePos-1 { + parser.MainCh <- MakeRawPacket(line[:len(line)-1]) + continue + } + pk, err := ParseJsonPacket([]byte(line[bracePos:])) + if err != nil { + errPk := MakeErrorPacket(fmt.Sprintf("parsing packet json from input: %v", err)) + parser.MainCh <- errPk + return + } + if pk.GetType() == DonePacketStr { + return + } + parser.MainCh <- pk + } + }() + return parser +} diff --git a/pkg/shexec/shexec.go b/pkg/shexec/shexec.go index 5ea68d2b3..fad15ca77 100644 --- a/pkg/shexec/shexec.go +++ b/pkg/shexec/shexec.go @@ -373,12 +373,12 @@ func RunClientSSHCommandAndWait(opts *ClientOpts) (*packet.CmdDonePacketType, er return nil, fmt.Errorf("running ssh command: %w", err) } defer cmd.Close() - stdoutPacketCh := packet.PacketParser(stdoutReader) - stderrPacketCh := packet.PacketParser(stderrReader) - packetCh := packet.CombinePacketParsers(stdoutPacketCh, stderrPacketCh) + stdoutPacketParser := packet.MakePacketParser(stdoutReader) + stderrPacketParser := packet.MakePacketParser(stderrReader) + packetParser := packet.CombinePacketParsers(stdoutPacketParser, stderrPacketParser) sender := packet.MakePacketSender(inputWriter) versionOk := false - for pk := range packetCh { + for pk := range packetParser.MainCh { if pk.GetType() == packet.RawPacketStr { rawPk := pk.(*packet.RawPacketType) fmt.Printf("%s\n", rawPk.Data) @@ -400,7 +400,7 @@ func RunClientSSHCommandAndWait(opts *ClientOpts) (*packet.CmdDonePacketType, er if opts.Debug { cmd.Multiplexer.Debug = true } - remoteDonePacket := cmd.Multiplexer.RunIOAndWait(packetCh, sender, false, true, true) + remoteDonePacket := cmd.Multiplexer.RunIOAndWait(packetParser, sender, false, true, true) donePacket := cmd.WaitForCommand() if remoteDonePacket != nil { donePacket = remoteDonePacket @@ -408,9 +408,9 @@ func RunClientSSHCommandAndWait(opts *ClientOpts) (*packet.CmdDonePacketType, er return donePacket, nil } -func (cmd *ShExecType) RunRemoteIOAndWait(packetCh chan packet.PacketType, sender *packet.PacketSender) { +func (cmd *ShExecType) RunRemoteIOAndWait(packetParser *packet.PacketParser, sender *packet.PacketSender) { defer cmd.Close() - cmd.Multiplexer.RunIOAndWait(packetCh, sender, true, false, false) + cmd.Multiplexer.RunIOAndWait(packetParser, sender, true, false, false) donePacket := cmd.WaitForCommand() sender.SendPacket(donePacket) }