// Copyright 2023, Command Line Inc. // SPDX-License-Identifier: Apache-2.0 package packet import ( "bytes" "context" "encoding/base64" "encoding/json" "errors" "fmt" "io" "io/fs" "log" "os" "reflect" "sync" "time" "github.com/wavetermdev/waveterm/waveshell/pkg/base" "github.com/wavetermdev/waveterm/waveshell/pkg/wlog" ) // single : run, >cmddata, >cmddone, data, <>dataack, run, >cmddata, >cmddone, run, >cmddata, >cmddone, data, <>dataack, cd, >getcmd, >untailcmd, >input, error, <>message, <>ping, streamfile, writefile, filedata*, %s\n", AsString(packet)) } _, err = w.Write(outBytes) if err != nil { return &SendError{IsWriteError: true, PacketType: packet.GetType(), Err: err} } return nil } type PacketSender struct { Lock *sync.Mutex SendCh chan PacketType Done bool DoneCh chan bool ErrHandler func(*PacketSender, PacketType, error) ExitErr error } func MakePacketSender(output io.Writer, errHandler func(*PacketSender, PacketType, error)) *PacketSender { sender := &PacketSender{ Lock: &sync.Mutex{}, SendCh: make(chan PacketType, PacketSenderQueueSize), DoneCh: make(chan bool), ErrHandler: errHandler, } go func() { defer close(sender.DoneCh) defer sender.Close() for pk := range sender.SendCh { err := SendPacket(output, pk) if err != nil { sender.goHandleError(pk, err) if serr, ok := err.(*SendError); ok && serr.IsMarshalError { // marshaler errors are recoverable continue } // write errors are not recoverable sender.Lock.Lock() sender.ExitErr = err sender.Lock.Unlock() return } } }() return sender } func (sender *PacketSender) SendLogPacket(entry wlog.LogEntry) { sender.SendPacket(MakeLogPacket(entry)) } func (sender *PacketSender) goHandleError(pk PacketType, err error) { sender.Lock.Lock() defer sender.Lock.Unlock() if sender.ErrHandler != nil { go sender.ErrHandler(sender, pk, err) } } func MakeChannelPacketSender(packetCh chan PacketType) *PacketSender { sender := &PacketSender{ Lock: &sync.Mutex{}, SendCh: make(chan PacketType, PacketSenderQueueSize), DoneCh: make(chan bool), } go func() { defer close(sender.DoneCh) defer sender.Close() for pk := range sender.SendCh { packetCh <- pk } }() return sender } func (sender *PacketSender) Close() { sender.Lock.Lock() defer sender.Lock.Unlock() if sender.Done { return } sender.Done = true close(sender.SendCh) } // returns ExitErr if set func (sender *PacketSender) WaitForDone() error { <-sender.DoneCh sender.Lock.Lock() defer sender.Lock.Unlock() return sender.ExitErr } // this is "advisory", as there is a race condition between the loop closing and setting Done. // that's okay because that's an impossible race condition anyway (you could enqueue the packet // and then the connection dies, or it dies half way, etc.). this just stops blindly adding // packets forever when the loop is done. func (sender *PacketSender) checkStatus() error { sender.Lock.Lock() defer sender.Lock.Unlock() if sender.Done { return fmt.Errorf("cannot send packet, sender write loop is closed") } return nil } func (sender *PacketSender) SendPacketCtx(ctx context.Context, pk PacketType) error { err := sender.checkStatus() if err != nil { return err } select { case sender.SendCh <- pk: return nil case <-ctx.Done(): return ctx.Err() } } func (sender *PacketSender) SendPacket(pk PacketType) error { if pk == nil { log.Printf("tried to send nil packet\n") return fmt.Errorf("tried to send nil packet") } if pk.GetType() == "" { log.Printf("tried to send invalid packet: %T\n", pk) return fmt.Errorf("tried to send packet without a type: %T", pk) } err := sender.checkStatus() if err != nil { return err } sender.SendCh <- pk return nil } func (sender *PacketSender) SendErrorResponse(reqId string, err error) error { pk := MakeErrorResponsePacket(reqId, err) return sender.SendPacket(pk) } func (sender *PacketSender) SendResponse(reqId string, data interface{}) error { pk := MakeResponsePacket(reqId, data) return sender.SendPacket(pk) } func (sender *PacketSender) SendMessageFmt(fmtStr string, args ...interface{}) error { return sender.SendPacket(MakeMessagePacket(fmt.Sprintf(fmtStr, args...))) } type UnknownPacketReporter interface { UnknownPacket(pk PacketType) } type DefaultUPR struct{} func (DefaultUPR) UnknownPacket(pk PacketType) { if pk.GetType() == RawPacketStr { rawPacket := pk.(*RawPacketType) fmt.Fprintf(os.Stderr, "%s\n", rawPacket.Data) } else if pk.GetType() == CmdStartPacketStr { return // do nothing } else { wlog.Logf("[upr] invalid packet received '%s'", AsExtType(pk)) } } type MessageUPR struct { CK base.CommandKey Sender *PacketSender } func (upr MessageUPR) UnknownPacket(pk PacketType) { msg := FmtMessagePacket("[error] invalid packet received %s", AsString(pk)) msg.CK = upr.CK upr.Sender.SendPacket(msg) } // todo: clean hanging entries in RunMap when in server mode type RunPacketBuilder struct { RunMap map[base.CommandKey]*RunPacketType } func MakeRunPacketBuilder() *RunPacketBuilder { return &RunPacketBuilder{ RunMap: make(map[base.CommandKey]*RunPacketType), } } // returns (consumed, fullRunPacket) func (b *RunPacketBuilder) ProcessPacket(pk PacketType) (bool, *RunPacketType) { if pk.GetType() == RunPacketStr { runPacket := pk.(*RunPacketType) if len(runPacket.RunData) == 0 { return true, runPacket } b.RunMap[runPacket.CK] = runPacket return true, nil } if pk.GetType() == DataEndPacketStr { endPacket := pk.(*DataEndPacketType) runPacket := b.RunMap[endPacket.CK] // might be nil delete(b.RunMap, endPacket.CK) return true, runPacket } if pk.GetType() == DataPacketStr { dataPacket := pk.(*DataPacketType) runPacket := b.RunMap[dataPacket.CK] if runPacket == nil { return false, nil } for idx, runData := range runPacket.RunData { if runData.FdNum == dataPacket.FdNum { // can ignore error, will get caught later with RunData.DataLen check realData, _ := base64.StdEncoding.DecodeString(dataPacket.Data64) runData.Data = append(runData.Data, realData...) runPacket.RunData[idx] = runData break } } return true, nil } return false, nil }