// Copyright 2023, Command Line Inc. // SPDX-License-Identifier: Apache-2.0 package packet import ( "bytes" "context" "encoding/base64" "encoding/json" "errors" "fmt" "io" "os" "reflect" "sync" "github.com/wavetermdev/waveterm/waveshell/pkg/base" ) // single : run, >cmddata, >cmddone, data, <>dataack, run, >cmddata, >cmddone, run, >cmddata, >cmddone, data, <>dataack, cd, >getcmd, >untailcmd, >input, error, <>message, <>ping, streamfile, writefile, filedata*, = 127 || (b < 32 && b != 10 && b != 13) { buf[idx] = '?' } } } type SendError struct { IsWriteError bool // fatal IsMarshalError bool // not fatal PacketType string Err error } func (e *SendError) Unwrap() error { return e.Err } func (e *SendError) Error() string { if e.IsMarshalError { return fmt.Sprintf("SendPacket marshal-error '%s' packet: %v", e.PacketType, e.Err) } else if e.IsWriteError { return fmt.Sprintf("SendPacket write-error packet[%s]: %v", e.PacketType, e.Err) } else { return e.Err.Error() } } func MarshalPacket(packet PacketType) ([]byte, error) { if packet == nil { return nil, fmt.Errorf("invalid nil packet") } jsonBytes, err := json.Marshal(packet) if err != nil { return nil, &SendError{IsMarshalError: true, PacketType: packet.GetType(), Err: err} } var outBuf bytes.Buffer outBuf.WriteByte('\n') outBuf.WriteString(fmt.Sprintf("##%d", len(jsonBytes))) outBuf.Write(jsonBytes) outBuf.WriteByte('\n') outBytes := outBuf.Bytes() sanitizeBytes(outBytes) return outBytes, nil } func SendPacket(w io.Writer, packet PacketType) error { if packet == nil { return nil } outBytes, err := MarshalPacket(packet) if err != nil { return err } if GlobalDebug { base.Logf("SEND> %s\n", AsString(packet)) } _, err = w.Write(outBytes) if err != nil { return &SendError{IsWriteError: true, PacketType: packet.GetType(), Err: err} } return nil } func SendCmdError(w io.Writer, ck base.CommandKey, err error) error { return SendPacket(w, MakeCmdErrorPacket(ck, err)) } type PacketSender struct { Lock *sync.Mutex SendCh chan
// NOTE(review): this file appears to have lost its newlines, and as written it
// does not compile:
//   - The line above begins with "//", so all of it is one comment. That covers
//     the license header, package clause, imports, SendError, MarshalPacket,
//     SendPacket, SendCmdError and the start of the PacketSender struct.
//   - On the line above, the text between "filedata*, " and "= 127 || ..." looks
//     like a lost span. Presumably everything from some '<' up to a '>' was
//     stripped. What remains is only the tail of a byte-replacing loop,
//     presumably sanitizeBytes, which MarshalPacket calls. Names the code uses,
//     such as PacketType, GlobalDebug, PacketSenderQueueSize,
//     MakeCmdErrorPacket and AsString, are not defined anywhere visible here.
//     Restore this file from version control rather than re-typing it from
//     this copy.
//   - On the line below, everything after the first "//" ("marshaler errors are
//     recoverable") is also a comment.
//
// Intended behavior of the PacketSender code on the line below:
//   - MakePacketSender starts a goroutine that writes each packet received from
//     SendCh to output via SendPacket. On any error it first calls
//     goHandleError. A *SendError with IsMarshalError set is then skipped.
//     Any other error is stored in ExitErr and ends the loop, so packets still
//     queued in SendCh are never written. On exit the goroutine runs Close
//     (deferred second, so it runs first) and then closes DoneCh.
//   - goHandleError runs ErrHandler, if non-nil, in a new goroutine while
//     holding Lock only long enough to start it.
//   - MakeChannelPacketSender forwards every queued packet to packetCh instead
//     of writing it. It sets no ErrHandler and never sets ExitErr.
//   - Close can be called more than once. The first call sets Done and closes
//     SendCh; later calls return immediately.
//   - WaitForDone blocks until DoneCh is closed, then returns ExitErr under Lock.
// NOTE(review): because Close closes SendCh, any goroutine blocked sending on
// SendCh when Close runs will panic with "send on closed channel". An example
// is a caller waiting on a full queue after a write error. The Done check
// mentioned in the comment at the end of the line below does not prevent this.
PacketType Done bool DoneCh chan bool ErrHandler func(*PacketSender, PacketType, error) ExitErr error } func MakePacketSender(output io.Writer, errHandler func(*PacketSender, PacketType, error)) *PacketSender { sender := &PacketSender{ Lock: &sync.Mutex{}, SendCh: make(chan PacketType, PacketSenderQueueSize), DoneCh: make(chan bool), ErrHandler: errHandler, } go func() { defer close(sender.DoneCh) defer sender.Close() for pk := range sender.SendCh { err := SendPacket(output, pk) if err != nil { sender.goHandleError(pk, err) if serr, ok := err.(*SendError); ok && serr.IsMarshalError { // marshaler errors are recoverable continue } // write errors are not recoverable sender.Lock.Lock() sender.ExitErr = err sender.Lock.Unlock() return } } }() return sender } func (sender *PacketSender) goHandleError(pk PacketType, err error) { sender.Lock.Lock() defer sender.Lock.Unlock() if sender.ErrHandler != nil { go sender.ErrHandler(sender, pk, err) } } func MakeChannelPacketSender(packetCh chan PacketType) *PacketSender { sender := &PacketSender{ Lock: &sync.Mutex{}, SendCh: make(chan PacketType, PacketSenderQueueSize), DoneCh: make(chan bool), } go func() { defer close(sender.DoneCh) defer sender.Close() for pk := range sender.SendCh { packetCh <- pk } }() return sender } func (sender *PacketSender) Close() { sender.Lock.Lock() defer sender.Lock.Unlock() if sender.Done { return } sender.Done = true close(sender.SendCh) } // returns ExitErr if set func (sender *PacketSender) WaitForDone() error { <-sender.DoneCh sender.Lock.Lock() defer sender.Lock.Unlock() return sender.ExitErr } // this is "advisory", as there is a race condition between the loop closing and setting Done. // that's okay because that's an impossible race condition anyway (you could enqueue the packet // and then the connection dies, or it dies half way, etc.). this just stops blindly adding // packets forever when the loop is done.
func (sender *PacketSender) checkStatus() error { sender.Lock.Lock() defer sender.Lock.Unlock() if sender.Done { return fmt.Errorf("cannot send packet, sender write loop is closed") } return nil } func (sender *PacketSender) SendPacketCtx(ctx context.Context, pk PacketType) error { err := sender.checkStatus() if err != nil { return err } select { case sender.SendCh <- pk: return nil case <-ctx.Done(): return ctx.Err() } } func (sender *PacketSender) SendPacket(pk PacketType) error { err := sender.checkStatus() if err != nil { return err } sender.SendCh <- pk return nil } func (sender *PacketSender) SendCmdError(ck base.CommandKey, err error) error { return sender.SendPacket(MakeCmdErrorPacket(ck, err)) } func (sender *PacketSender) SendErrorResponse(reqId string, err error) error { pk := MakeErrorResponsePacket(reqId, err) return sender.SendPacket(pk) } func (sender *PacketSender) SendResponse(reqId string, data interface{}) error { pk := MakeResponsePacket(reqId, data) return sender.SendPacket(pk) } func (sender *PacketSender) SendMessageFmt(fmtStr string, args ...interface{}) error { return sender.SendPacket(MakeMessagePacket(fmt.Sprintf(fmtStr, args...))) } type UnknownPacketReporter interface { UnknownPacket(pk PacketType) } type DefaultUPR struct{} func (DefaultUPR) UnknownPacket(pk PacketType) { if pk.GetType() == CmdErrorPacketStr { errPacket := pk.(*CmdErrorPacketType) // at this point, just send the error packet to stderr rather than try to do something special fmt.Fprintf(os.Stderr, "[error] %s\n", errPacket.Error) } else if pk.GetType() == RawPacketStr { rawPacket := pk.(*RawPacketType) fmt.Fprintf(os.Stderr, "%s\n", rawPacket.Data) } else if pk.GetType() == CmdStartPacketStr { return // do nothing } else { fmt.Fprintf(os.Stderr, "[error] invalid packet received '%s'", AsExtType(pk)) } } type MessageUPR struct { CK base.CommandKey Sender *PacketSender } func (upr MessageUPR) UnknownPacket(pk PacketType) { msg := FmtMessagePacket("[error] invalid packet 
received %s", AsString(pk)) msg.CK = upr.CK upr.Sender.SendPacket(msg) } // todo: clean hanging entries in RunMap when in server mode type RunPacketBuilder struct { RunMap map[base.CommandKey]*RunPacketType } func MakeRunPacketBuilder() *RunPacketBuilder { return &RunPacketBuilder{ RunMap: make(map[base.CommandKey]*RunPacketType), } } // returns (consumed, fullRunPacket) func (b *RunPacketBuilder) ProcessPacket(pk PacketType) (bool, *RunPacketType) { if pk.GetType() == RunPacketStr { runPacket := pk.(*RunPacketType) if len(runPacket.RunData) == 0 { return true, runPacket } b.RunMap[runPacket.CK] = runPacket return true, nil } if pk.GetType() == DataEndPacketStr { endPacket := pk.(*DataEndPacketType) runPacket := b.RunMap[endPacket.CK] // might be nil delete(b.RunMap, endPacket.CK) return true, runPacket } if pk.GetType() == DataPacketStr { dataPacket := pk.(*DataPacketType) runPacket := b.RunMap[dataPacket.CK] if runPacket == nil { return false, nil } for idx, runData := range runPacket.RunData { if runData.FdNum == dataPacket.FdNum { // can ignore error, will get caught later with RunData.DataLen check realData, _ := base64.StdEncoding.DecodeString(dataPacket.Data64) runData.Data = append(runData.Data, realData...) runPacket.RunData[idx] = runData break } } return true, nil } return false, nil }