diff --git a/pkg/mpio/packetreader.go b/pkg/mpio/packetreader.go deleted file mode 100644 index 8dfb30241..000000000 --- a/pkg/mpio/packetreader.go +++ /dev/null @@ -1,96 +0,0 @@ -// Copyright 2022 Dashborg Inc -// -// This Source Code Form is subject to the terms of the Mozilla Public -// License, v. 2.0. If a copy of the MPL was not distributed with this -// file, You can obtain one at https://mozilla.org/MPL/2.0/. - -package mpio - -import ( - "encoding/base64" - "errors" - "io" - "sync" - - "github.com/scripthaus-dev/mshell/pkg/packet" -) - -type PacketReader struct { - CVar *sync.Cond - FdNum int - Buf []byte - Eof bool - Err error -} - -func MakePacketReader(fdNum int) *PacketReader { - return &PacketReader{ - CVar: sync.NewCond(&sync.Mutex{}), - FdNum: fdNum, - } -} - -func (pr *PacketReader) AddData(pk *packet.DataPacketType) { - pr.CVar.L.Lock() - defer pr.CVar.L.Unlock() - defer pr.CVar.Broadcast() - if pr.Eof || pr.Err != nil { - return - } - if pk.Data64 != "" { - realData, err := base64.StdEncoding.DecodeString(pk.Data64) - if err != nil { - pr.Err = err - return - } - pr.Buf = append(pr.Buf, realData...) - } - pr.Eof = pk.Eof - if pk.Error != "" { - pr.Err = errors.New(pk.Error) - } - return -} - -func (pr *PacketReader) Read(buf []byte) (int, error) { - pr.CVar.L.Lock() - defer pr.CVar.L.Unlock() - for { - if pr.Err != nil { - return 0, pr.Err - } - if pr.Eof { - return 0, io.EOF - } - if len(pr.Buf) == 0 { - pr.CVar.Wait() - continue - } - nr := copy(buf, pr.Buf) - pr.Buf = pr.Buf[nr:] - if len(pr.Buf) == 0 { - pr.Buf = nil - } - return nr, nil - } -} - -func (pr *PacketReader) Close() error { - pr.CVar.L.Lock() - defer pr.CVar.L.Unlock() - defer pr.CVar.Broadcast() - if pr.Err == nil { - pr.Err = io.ErrClosedPipe - } - return nil -} - -type NullReader struct{} - -func (NullReader) Read(buf []byte) (int, error) { - return 0, io.EOF -} - -func (NullReader) Close() error { - return nil -} diff --git a/pkg/mpio/packetwriter.go b/pkg/mpio/packetwriter.go deleted file mode 100644 index 0665f0441..000000000 --- a/pkg/mpio/packetwriter.go +++ /dev/null @@ -1,40 +0,0 @@ -// Copyright 2022 Dashborg Inc -// -// This Source Code Form is subject to the terms of the Mozilla Public -// License, v. 2.0. If a copy of the MPL was not distributed with this -// file, You can obtain one at https://mozilla.org/MPL/2.0/. - -package mpio - -import ( - "encoding/base64" - - "github.com/scripthaus-dev/mshell/pkg/base" - "github.com/scripthaus-dev/mshell/pkg/packet" -) - -type PacketWriter struct { - FdNum int - Sender *packet.PacketSender - CK base.CommandKey -} - -func MakePacketWriter(fdNum int, sender *packet.PacketSender, ck base.CommandKey) *PacketWriter { - return &PacketWriter{FdNum: fdNum, Sender: sender, CK: ck} -} - -func (pw *PacketWriter) Write(data []byte) (int, error) { - pk := packet.MakeDataPacket() - pk.CK = pw.CK - pk.FdNum = pw.FdNum - pk.Data64 = base64.StdEncoding.EncodeToString(data) - return len(data), pw.Sender.SendPacket(pk) -} - -func (pw *PacketWriter) Close() error { - pk := packet.MakeDataPacket() - pk.CK = pw.CK - pk.FdNum = pw.FdNum - pk.Eof = true - return pw.Sender.SendPacket(pk) -} diff --git a/pkg/packet/parser.go b/pkg/packet/parser.go index d405d3b03..2e4199c22 100644 --- a/pkg/packet/parser.go +++ b/pkg/packet/parser.go @@ -37,14 +37,22 @@ func CombinePacketParsers(p1 *PacketParser, p2 *PacketParser) *PacketParser { wg.Add(2) go func() { defer wg.Done() - for v := range p1.MainCh { - rtnParser.MainCh <- v + for pk := range p1.MainCh { + sent := rtnParser.trySendRpcResponse(pk) + if sent { + continue + } + rtnParser.MainCh <- pk } }() go func() { defer wg.Done() - for v := range p2.MainCh { - rtnParser.MainCh <- v + for pk := range p2.MainCh { + sent := rtnParser.trySendRpcResponse(pk) + if sent { + continue + } + rtnParser.MainCh <- pk } }() go func() { @@ -56,7 +64,7 @@ func CombinePacketParsers(p1 *PacketParser, p2 *PacketParser) *PacketParser { // should have already registered rpc func (p *PacketParser) WaitForResponse(ctx context.Context, reqId string) RpcResponsePacketType { - entry := p.getRpcEntry(reqId, false) + entry := p.getRpcEntry(reqId) if entry == nil { return nil } @@ -92,18 +100,18 @@ func (p *PacketParser) RegisterRpcSz(reqId string, queueSize int) chan RpcRespon return ch } -func (p *PacketParser) getRpcEntry(reqId string, remove bool) *RpcEntry { +func (p *PacketParser) getRpcEntry(reqId string) *RpcEntry { p.Lock.Lock() defer p.Lock.Unlock() entry := p.RpcMap[reqId] - if entry != nil && remove { - delete(p.RpcMap, reqId) - close(entry.RespCh) - } return entry } -func (p *PacketParser) trySendRpcResponse(respPk RpcResponsePacketType) bool { +func (p *PacketParser) trySendRpcResponse(pk PacketType) bool { + respPk, ok := pk.(RpcResponsePacketType) + if !ok { + return false + } p.Lock.Lock() defer p.Lock.Unlock() entry := p.RpcMap[respPk.GetResponseId()] @@ -185,11 +193,9 @@ func MakePacketParser(input io.Reader) *PacketParser { if pk.GetType() == PingPacketStr { continue } - if respPk, ok := pk.(RpcResponsePacketType); ok { - sent := parser.trySendRpcResponse(respPk) - if sent { - continue - } + sent := parser.trySendRpcResponse(pk) + if sent { + continue } parser.MainCh <- pk }