// Copyright 2022 Dashborg Inc // // This Source Code Form is subject to the terms of the Mozilla Public // License, v. 2.0. If a copy of the MPL was not distributed with this // file, You can obtain one at https://mozilla.org/MPL/2.0/. package packet import ( "bytes" "context" "encoding/base64" "encoding/json" "fmt" "io" "os" "reflect" "sync" "github.com/scripthaus-dev/mshell/pkg/base" ) // single : run, >cmddata, >cmddone, data, <>dataack, run, >cmddata, >cmddone, run, >cmddata, >cmddone, data, <>dataack, cd, >getcmd, >untailcmd, >input, error, <>message, <>ping, = 127 || (b < 32 && b != 10 && b != 13) { buf[idx] = '?' } } } func SendPacket(w io.Writer, packet PacketType) error { if packet == nil { return nil } jsonBytes, err := json.Marshal(packet) if err != nil { return fmt.Errorf("marshaling '%s' packet: %w", packet.GetType(), err) } var outBuf bytes.Buffer outBuf.WriteByte('\n') outBuf.WriteString(fmt.Sprintf("##%d", len(jsonBytes))) outBuf.Write(jsonBytes) outBuf.WriteByte('\n') if GlobalDebug { fmt.Printf("SEND> %s\n", AsString(packet)) } outBytes := outBuf.Bytes() sanitizeBytes(outBytes) _, err = w.Write(outBytes) if err != nil { return err } return nil } func SendCmdError(w io.Writer, ck base.CommandKey, err error) error { return SendPacket(w, MakeCmdErrorPacket(ck, err)) } type PacketSender struct { Lock *sync.Mutex SendCh chan PacketType Err error Done bool DoneCh chan bool } func MakePacketSender(output io.Writer) *PacketSender { sender := &PacketSender{ Lock: &sync.Mutex{}, SendCh: make(chan PacketType, PacketSenderQueueSize), DoneCh: make(chan bool), } go func() { defer close(sender.DoneCh) defer sender.Close() for pk := range sender.SendCh { err := SendPacket(output, pk) if err != nil { sender.Lock.Lock() sender.Err = err sender.Lock.Unlock() return } } }() return sender } func MakeChannelPacketSender(packetCh chan PacketType) *PacketSender { sender := &PacketSender{ Lock: &sync.Mutex{}, SendCh: make(chan PacketType, PacketSenderQueueSize), DoneCh: make(chan bool), } go func() { defer close(sender.DoneCh) defer sender.Close() for pk := range sender.SendCh { packetCh <- pk } }() return sender } func (sender *PacketSender) Close() { sender.Lock.Lock() defer sender.Lock.Unlock() if sender.Done { return } sender.Done = true close(sender.SendCh) } func (sender *PacketSender) WaitForDone() { <-sender.DoneCh } func (sender *PacketSender) checkStatus() error { sender.Lock.Lock() defer sender.Lock.Unlock() if sender.Done { return fmt.Errorf("cannot send packet, sender write loop is closed") } if sender.Err != nil { return fmt.Errorf("cannot send packet, sender had error: %w", sender.Err) } return nil } func (sender *PacketSender) SendPacketCtx(ctx context.Context, pk PacketType) error { err := sender.checkStatus() if err != nil { return err } select { case sender.SendCh <- pk: return nil case <-ctx.Done(): return ctx.Err() } } func (sender *PacketSender) SendPacket(pk PacketType) error { err := sender.checkStatus() if err != nil { return err } sender.SendCh <- pk return nil } func (sender *PacketSender) SendCmdError(ck base.CommandKey, err error) error { return sender.SendPacket(MakeCmdErrorPacket(ck, err)) } func (sender *PacketSender) SendErrorResponse(reqId string, err error) error { pk := MakeErrorResponsePacket(reqId, err) return sender.SendPacket(pk) } func (sender *PacketSender) SendResponse(reqId string, data interface{}) error { pk := MakeResponsePacket(reqId, data) return sender.SendPacket(pk) } func (sender *PacketSender) SendMessage(fmtStr string, args ...interface{}) error { return sender.SendPacket(MakeMessagePacket(fmt.Sprintf(fmtStr, args...))) } type UnknownPacketReporter interface { UnknownPacket(pk PacketType) } type DefaultUPR struct{} func (DefaultUPR) UnknownPacket(pk PacketType) { if pk.GetType() == CmdErrorPacketStr { errPacket := pk.(*CmdErrorPacketType) // at this point, just send the error packet to stderr rather than try to do something special fmt.Fprintf(os.Stderr, "[error] %s\n", errPacket.Error) } else if pk.GetType() == RawPacketStr { rawPacket := pk.(*RawPacketType) fmt.Fprintf(os.Stderr, "%s\n", rawPacket.Data) } else if pk.GetType() == CmdStartPacketStr { return // do nothing } else { fmt.Fprintf(os.Stderr, "[error] invalid packet received '%s'", AsExtType(pk)) } } // todo: clean hanging entries in RunMap when in server mode type RunPacketBuilder struct { RunMap map[base.CommandKey]*RunPacketType } func MakeRunPacketBuilder() *RunPacketBuilder { return &RunPacketBuilder{ RunMap: make(map[base.CommandKey]*RunPacketType), } } // returns (consumed, fullRunPacket) func (b *RunPacketBuilder) ProcessPacket(pk PacketType) (bool, *RunPacketType) { if pk.GetType() == RunPacketStr { runPacket := pk.(*RunPacketType) if len(runPacket.RunData) == 0 { return true, runPacket } b.RunMap[runPacket.CK] = runPacket return true, nil } if pk.GetType() == DataEndPacketStr { endPacket := pk.(*DataEndPacketType) runPacket := b.RunMap[endPacket.CK] // might be nil delete(b.RunMap, endPacket.CK) return true, runPacket } if pk.GetType() == DataPacketStr { dataPacket := pk.(*DataPacketType) runPacket := b.RunMap[dataPacket.CK] if runPacket == nil { return false, nil } for idx, runData := range runPacket.RunData { if runData.FdNum == dataPacket.FdNum { // can ignore error, will get caught later with RunData.DataLen check realData, _ := base64.StdEncoding.DecodeString(dataPacket.Data64) runData.Data = append(runData.Data, realData...) runPacket.RunData[idx] = runData break } } return true, nil } return false, nil }