add rpc to combined packet parser

This commit is contained in:
sawka 2022-07-06 22:46:59 -07:00
parent 353605f815
commit 2652a3509b
3 changed files with 22 additions and 152 deletions

View File

@ -1,96 +0,0 @@
// Copyright 2022 Dashborg Inc
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this
// file, You can obtain one at https://mozilla.org/MPL/2.0/.
package mpio
import (
"encoding/base64"
"errors"
"io"
"sync"
"github.com/scripthaus-dev/mshell/pkg/packet"
)
type PacketReader struct {
CVar *sync.Cond
FdNum int
Buf []byte
Eof bool
Err error
}
func MakePacketReader(fdNum int) *PacketReader {
return &PacketReader{
CVar: sync.NewCond(&sync.Mutex{}),
FdNum: fdNum,
}
}
func (pr *PacketReader) AddData(pk *packet.DataPacketType) {
pr.CVar.L.Lock()
defer pr.CVar.L.Unlock()
defer pr.CVar.Broadcast()
if pr.Eof || pr.Err != nil {
return
}
if pk.Data64 != "" {
realData, err := base64.StdEncoding.DecodeString(pk.Data64)
if err != nil {
pr.Err = err
return
}
pr.Buf = append(pr.Buf, realData...)
}
pr.Eof = pk.Eof
if pk.Error != "" {
pr.Err = errors.New(pk.Error)
}
return
}
func (pr *PacketReader) Read(buf []byte) (int, error) {
pr.CVar.L.Lock()
defer pr.CVar.L.Unlock()
for {
if pr.Err != nil {
return 0, pr.Err
}
if pr.Eof {
return 0, io.EOF
}
if len(pr.Buf) == 0 {
pr.CVar.Wait()
continue
}
nr := copy(buf, pr.Buf)
pr.Buf = pr.Buf[nr:]
if len(pr.Buf) == 0 {
pr.Buf = nil
}
return nr, nil
}
}
func (pr *PacketReader) Close() error {
pr.CVar.L.Lock()
defer pr.CVar.L.Unlock()
defer pr.CVar.Broadcast()
if pr.Err == nil {
pr.Err = io.ErrClosedPipe
}
return nil
}
type NullReader struct{}
func (NullReader) Read(buf []byte) (int, error) {
return 0, io.EOF
}
func (NullReader) Close() error {
return nil
}

View File

@ -1,40 +0,0 @@
// Copyright 2022 Dashborg Inc
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this
// file, You can obtain one at https://mozilla.org/MPL/2.0/.
package mpio
import (
"encoding/base64"
"github.com/scripthaus-dev/mshell/pkg/base"
"github.com/scripthaus-dev/mshell/pkg/packet"
)
type PacketWriter struct {
FdNum int
Sender *packet.PacketSender
CK base.CommandKey
}
func MakePacketWriter(fdNum int, sender *packet.PacketSender, ck base.CommandKey) *PacketWriter {
return &PacketWriter{FdNum: fdNum, Sender: sender, CK: ck}
}
func (pw *PacketWriter) Write(data []byte) (int, error) {
pk := packet.MakeDataPacket()
pk.CK = pw.CK
pk.FdNum = pw.FdNum
pk.Data64 = base64.StdEncoding.EncodeToString(data)
return len(data), pw.Sender.SendPacket(pk)
}
func (pw *PacketWriter) Close() error {
pk := packet.MakeDataPacket()
pk.CK = pw.CK
pk.FdNum = pw.FdNum
pk.Eof = true
return pw.Sender.SendPacket(pk)
}

View File

@ -37,14 +37,22 @@ func CombinePacketParsers(p1 *PacketParser, p2 *PacketParser) *PacketParser {
wg.Add(2) wg.Add(2)
go func() { go func() {
defer wg.Done() defer wg.Done()
for v := range p1.MainCh { for pk := range p1.MainCh {
rtnParser.MainCh <- v sent := rtnParser.trySendRpcResponse(pk)
if sent {
continue
}
rtnParser.MainCh <- pk
} }
}() }()
go func() { go func() {
defer wg.Done() defer wg.Done()
for v := range p2.MainCh { for pk := range p2.MainCh {
rtnParser.MainCh <- v sent := rtnParser.trySendRpcResponse(pk)
if sent {
continue
}
rtnParser.MainCh <- pk
} }
}() }()
go func() { go func() {
@ -56,7 +64,7 @@ func CombinePacketParsers(p1 *PacketParser, p2 *PacketParser) *PacketParser {
// should have already registered rpc // should have already registered rpc
func (p *PacketParser) WaitForResponse(ctx context.Context, reqId string) RpcResponsePacketType { func (p *PacketParser) WaitForResponse(ctx context.Context, reqId string) RpcResponsePacketType {
entry := p.getRpcEntry(reqId, false) entry := p.getRpcEntry(reqId)
if entry == nil { if entry == nil {
return nil return nil
} }
@ -92,18 +100,18 @@ func (p *PacketParser) RegisterRpcSz(reqId string, queueSize int) chan RpcRespon
return ch return ch
} }
func (p *PacketParser) getRpcEntry(reqId string, remove bool) *RpcEntry { func (p *PacketParser) getRpcEntry(reqId string) *RpcEntry {
p.Lock.Lock() p.Lock.Lock()
defer p.Lock.Unlock() defer p.Lock.Unlock()
entry := p.RpcMap[reqId] entry := p.RpcMap[reqId]
if entry != nil && remove {
delete(p.RpcMap, reqId)
close(entry.RespCh)
}
return entry return entry
} }
func (p *PacketParser) trySendRpcResponse(respPk RpcResponsePacketType) bool { func (p *PacketParser) trySendRpcResponse(pk PacketType) bool {
respPk, ok := pk.(RpcResponsePacketType)
if !ok {
return false
}
p.Lock.Lock() p.Lock.Lock()
defer p.Lock.Unlock() defer p.Lock.Unlock()
entry := p.RpcMap[respPk.GetResponseId()] entry := p.RpcMap[respPk.GetResponseId()]
@ -185,12 +193,10 @@ func MakePacketParser(input io.Reader) *PacketParser {
if pk.GetType() == PingPacketStr { if pk.GetType() == PingPacketStr {
continue continue
} }
if respPk, ok := pk.(RpcResponsePacketType); ok { sent := parser.trySendRpcResponse(pk)
sent := parser.trySendRpcResponse(respPk)
if sent { if sent {
continue continue
} }
}
parser.MainCh <- pk parser.MainCh <- pk
} }
}() }()