From 014c6fb2ec6c78ab901ea61f8fee76ff5045407f Mon Sep 17 00:00:00 2001 From: sawka Date: Fri, 14 Jun 2024 14:43:47 -0700 Subject: [PATCH] redo ptybuffer, move to wshutil to help with stdin processing. change wsh to use cobra --- cmd/wsh/cmd/getmeta.go | 42 ++++++ cmd/wsh/cmd/root.go | 27 ++++ cmd/wsh/cmd/setmeta.go | 86 ++++++++++++ pkg/blockcontroller/blockcontroller.go | 37 +++--- pkg/blockcontroller/ptybuffer.go | 119 ----------------- pkg/waveobj/waveobj.go | 22 ++++ pkg/wshutil/wshcmdreader.go | 174 +++++++++++++++++++++++++ pkg/wshutil/wshcommands.go | 13 ++ pkg/wshutil/wshutil.go | 8 +- pkg/wstore/wstore_dbops.go | 23 ++++ 10 files changed, 411 insertions(+), 140 deletions(-) create mode 100644 cmd/wsh/cmd/getmeta.go create mode 100644 cmd/wsh/cmd/setmeta.go delete mode 100644 pkg/blockcontroller/ptybuffer.go create mode 100644 pkg/wshutil/wshcmdreader.go diff --git a/cmd/wsh/cmd/getmeta.go b/cmd/wsh/cmd/getmeta.go new file mode 100644 index 000000000..f6274104e --- /dev/null +++ b/cmd/wsh/cmd/getmeta.go @@ -0,0 +1,42 @@ +// Copyright 2024, Command Line Inc. +// SPDX-License-Identifier: Apache-2.0 + +package cmd + +import ( + "fmt" + "os" + + "github.com/spf13/cobra" + "github.com/wavetermdev/thenextwave/pkg/wshutil" +) + +var getMetaCmd = &cobra.Command{ + Use: "getmeta", + Short: "get metadata for an entity", + Args: cobra.ExactArgs(1), + Run: getMetaRun, +} + +func init() { + rootCmd.AddCommand(getMetaCmd) +} + +func getMetaRun(cmd *cobra.Command, args []string) { + oref := args[0] + if oref == "" { + fmt.Println("oref is required") + return + } + err := validateEasyORef(oref) + if err != nil { + fmt.Printf("%v\n", err) + return + } + getMetaWshCmd := &wshutil.BlockGetMetaCommand{ + Command: wshutil.BlockCommand_SetMeta, + OID: oref, + } + barr, _ := wshutil.EncodeWaveOSCMessage(getMetaWshCmd) + os.Stdout.Write(barr) +} diff --git a/cmd/wsh/cmd/root.go b/cmd/wsh/cmd/root.go index 3441a9f65..f79e4abc5 100644 --- a/cmd/wsh/cmd/root.go +++ b/cmd/wsh/cmd/root.go @@ -8,10 +8,14 @@ import ( "log" "os" "os/signal" + "regexp" + "strings" "sync" "syscall" + "github.com/google/uuid" "github.com/spf13/cobra" + "github.com/wavetermdev/thenextwave/pkg/waveobj" "github.com/wavetermdev/thenextwave/pkg/wshutil" "golang.org/x/term" ) @@ -78,6 +82,29 @@ func installShutdownSignalHandlers() { }() } +var oidRe = regexp.MustCompile(`^[0-9a-f]{8}$`) + +func validateEasyORef(oref string) error { + if strings.Index(oref, ":") >= 0 { + _, err := waveobj.ParseORef(oref) + if err != nil { + return fmt.Errorf("invalid ORef: %v", err) + } + return nil + } + if len(oref) == 8 { + if !oidRe.MatchString(oref) { + return fmt.Errorf("invalid short OID format, must only use 0-9a-f: %q", oref) + } + return nil + } + _, err := uuid.Parse(oref) + if err != nil { + return fmt.Errorf("invalid OID (must be UUID): %v", err) + } + return nil +} + // Execute executes the root command. func Execute() error { return rootCmd.Execute() diff --git a/cmd/wsh/cmd/setmeta.go b/cmd/wsh/cmd/setmeta.go new file mode 100644 index 000000000..7b19d0082 --- /dev/null +++ b/cmd/wsh/cmd/setmeta.go @@ -0,0 +1,86 @@ +// Copyright 2024, Command Line Inc. +// SPDX-License-Identifier: Apache-2.0 + +package cmd + +import ( + "encoding/json" + "fmt" + "os" + "strconv" + "strings" + + "github.com/spf13/cobra" + "github.com/wavetermdev/thenextwave/pkg/wshutil" +) + +var setMetaCmd = &cobra.Command{ + Use: "setmeta", + Short: "set metadata for an entity", + Args: cobra.MinimumNArgs(2), + Run: setMetaRun, +} + +func init() { + rootCmd.AddCommand(setMetaCmd) +} + +func parseMetaSets(metaSets []string) (map[string]interface{}, error) { + meta := make(map[string]interface{}) + for _, metaSet := range metaSets { + fields := strings.Split(metaSet, "=") + if len(fields) != 2 { + return nil, fmt.Errorf("invalid meta set: %q", metaSet) + } + setVal := fields[1] + if setVal == "" || setVal == "null" { + meta[fields[0]] = nil + } else if setVal == "true" { + meta[fields[0]] = true + } else if setVal == "false" { + meta[fields[0]] = false + } else if setVal[0] == '[' || setVal[0] == '{' { + var val interface{} + err := json.Unmarshal([]byte(setVal), &val) + if err != nil { + return nil, fmt.Errorf("invalid json value: %v", err) + } + meta[fields[0]] = val + } else { + ival, err := strconv.ParseInt(setVal, 10, 64) + if err == nil { + meta[fields[0]] = ival + } else { + meta[fields[0]] = setVal + } + } + meta[fields[0]] = fields[1] + } + return meta, nil +} + +func setMetaRun(cmd *cobra.Command, args []string) { + oref := args[0] + metaSetsStrs := args[1:] + if oref == "" { + fmt.Println("oref is required") + return + } + err := validateEasyORef(oref) + if err != nil { + fmt.Printf("%v\n", err) + return + } + meta, err := parseMetaSets(metaSetsStrs) + if err != nil { + fmt.Printf("%v\n", err) + return + } + setMetaWshCmd := &wshutil.BlockSetMetaCommand{ + Command: wshutil.BlockCommand_SetMeta, + OID: oref, + Meta: meta, + } + barr, _ := wshutil.EncodeWaveOSCMessage(setMetaWshCmd) + os.Stdout.Write(barr) +} diff --git a/pkg/blockcontroller/blockcontroller.go b/pkg/blockcontroller/blockcontroller.go index 115ef072f..60d84a885 100644 --- a/pkg/blockcontroller/blockcontroller.go +++ b/pkg/blockcontroller/blockcontroller.go @@ -46,10 +46,8 @@ type BlockController struct { InputCh chan wshutil.BlockCommand Status string CreatedHtmlFile bool - - PtyBuffer *PtyBuffer - ShellProc *shellexec.ShellProc - ShellInputCh chan *wshutil.BlockInputCommand + ShellProc *shellexec.ShellProc + ShellInputCh chan *wshutil.BlockInputCommand } func (bc *BlockController) WithLock(f func()) { @@ -187,6 +185,17 @@ func (bc *BlockController) DoRunShellCommand(rc *RunShellOpts) error { } shellInputCh := make(chan *wshutil.BlockInputCommand) bc.ShellInputCh = shellInputCh + commandCh := make(chan wshutil.BlockCommand, 32) + ptyBuffer := wshutil.MakePtyBuffer(bc.ShellProc.Pty, commandCh) + go func() { + for cmd := range commandCh { + if strings.HasPrefix(cmd.GetCommand(), "controller:") { + bc.InputCh <- cmd + } else { + ProcessStaticCommand(bc.BlockId, cmd) + } + } + }() go func() { defer func() { // needs synchronization @@ -197,12 +206,11 @@ func (bc *BlockController) DoRunShellCommand(rc *RunShellOpts) error { }() buf := make([]byte, 4096) for { - nr, err := bc.ShellProc.Pty.Read(buf) + nr, err := ptyBuffer.Read(buf) if nr > 0 { - bc.PtyBuffer.AppendData(buf[:nr]) - if bc.PtyBuffer.Err != nil { - log.Printf("error processing pty data: %v\n", bc.PtyBuffer.Err) - break + err := handleAppendBlockFile(bc.BlockId, BlockFile_Main, buf[:nr]) + if err != nil { + log.Printf("error appending to blockfile: %v\n", err) } } if err == io.EOF { @@ -303,17 +311,6 @@ func StartBlockController(ctx context.Context, blockId string) error { Status: "init", InputCh: make(chan wshutil.BlockCommand), } - ptyBuffer := MakePtyBuffer(func(fileName string, data []byte) error { - return handleAppendBlockFile(blockId, fileName, data) - }, func(cmd wshutil.BlockCommand) error { - if strings.HasPrefix(cmd.GetCommand(), "controller:") { - bc.InputCh <- cmd - } else { - ProcessStaticCommand(blockId, cmd) - } - return nil - }) - bc.PtyBuffer = ptyBuffer blockControllerMap[blockId] = bc go bc.Run(blockData) return nil diff --git a/pkg/blockcontroller/ptybuffer.go b/pkg/blockcontroller/ptybuffer.go deleted file mode 100644 index fb20a96c4..000000000 --- a/pkg/blockcontroller/ptybuffer.go +++ /dev/null @@ -1,119 +0,0 @@ -// Copyright 2024, Command Line Inc. -// SPDX-License-Identifier: Apache-2.0 - -package blockcontroller - -import ( - "encoding/json" - "fmt" - - "github.com/wavetermdev/thenextwave/pkg/wshutil" -) - -const ( - Mode_Normal = "normal" - Mode_Esc = "esc" - Mode_WaveEsc = "waveesc" -) - -type PtyBuffer struct { - Mode string - EscSeqBuf []byte - DataOutputFn func(string, []byte) error - CommandOutputFn func(wshutil.BlockCommand) error - Err error -} - -func MakePtyBuffer(dataOutputFn func(string, []byte) error, commandOutputFn func(wshutil.BlockCommand) error) *PtyBuffer { - return &PtyBuffer{ - Mode: Mode_Normal, - DataOutputFn: dataOutputFn, - CommandOutputFn: commandOutputFn, - } -} - -func (b *PtyBuffer) setErr(err error) { - if b.Err == nil { - b.Err = err - } -} - -func (b *PtyBuffer) processWaveEscSeq(escSeq []byte) { - jmsg := make(map[string]any) - err := json.Unmarshal(escSeq, &jmsg) - if err != nil { - b.setErr(fmt.Errorf("error unmarshalling Wave OSC sequence data: %w", err)) - return - } - cmd, err := wshutil.ParseCmdMap(jmsg) - if err != nil { - b.setErr(fmt.Errorf("error parsing Wave OSC command: %w", err)) - return - } - err = b.CommandOutputFn(cmd) - if err != nil { - b.setErr(fmt.Errorf("error processing Wave OSC command: %w", err)) - return - } -} - -func (b *PtyBuffer) AppendData(data []byte) { - outputBuf := make([]byte, 0, len(data)) - for _, ch := range data { - if b.Mode == Mode_WaveEsc { - if ch == wshutil.ESC { - // terminates the escape sequence (and the rest was invalid) - b.Mode = Mode_Normal - outputBuf = append(outputBuf, b.EscSeqBuf...) - outputBuf = append(outputBuf, ch) - b.EscSeqBuf = nil - } else if ch == wshutil.BEL || ch == wshutil.ST { - // terminates the escpae sequence (is a valid Wave OSC command) - b.Mode = Mode_Normal - waveEscSeq := b.EscSeqBuf[len(wshutil.WaveOSCPrefix):] - b.EscSeqBuf = nil - b.processWaveEscSeq(waveEscSeq) - } else { - b.EscSeqBuf = append(b.EscSeqBuf, ch) - } - continue - } - if b.Mode == Mode_Esc { - if ch == wshutil.ESC || ch == wshutil.BEL || ch == wshutil.ST { - // these all terminate the escape sequence (invalid, not a Wave OSC) - b.Mode = Mode_Normal - outputBuf = append(outputBuf, b.EscSeqBuf...) - outputBuf = append(outputBuf, ch) - } else { - if ch == wshutil.WaveOSCPrefixBytes[len(b.EscSeqBuf)] { - // we're still building what could be a Wave OSC sequence - b.EscSeqBuf = append(b.EscSeqBuf, ch) - } else { - // this is not a Wave OSC sequence, just an escape sequence - b.Mode = Mode_Normal - outputBuf = append(outputBuf, b.EscSeqBuf...) - outputBuf = append(outputBuf, ch) - continue - } - // check to see if we have a full Wave OSC prefix - if len(b.EscSeqBuf) == len(wshutil.WaveOSCPrefixBytes) { - b.Mode = Mode_WaveEsc - } - } - continue - } - // Mode_Normal - if ch == wshutil.ESC { - b.Mode = Mode_Esc - b.EscSeqBuf = []byte{ch} - continue - } - outputBuf = append(outputBuf, ch) - } - if len(outputBuf) > 0 { - err := b.DataOutputFn(BlockFile_Main, outputBuf) - if err != nil { - b.setErr(fmt.Errorf("error processing data output: %w", err)) - } - } -} diff --git a/pkg/waveobj/waveobj.go b/pkg/waveobj/waveobj.go index 99395886d..0d6bc3c13 100644 --- a/pkg/waveobj/waveobj.go +++ b/pkg/waveobj/waveobj.go @@ -7,8 +7,11 @@ import ( "encoding/json" "fmt" "reflect" + "regexp" + "strings" "sync" + "github.com/google/uuid" "github.com/mitchellh/mapstructure" ) @@ -39,6 +42,25 @@ func MakeORef(otype string, oid string) ORef { } } +var otypeRe = regexp.MustCompile(`^[a-z]+$`) + +func ParseORef(orefStr string) (ORef, error) { + fields := strings.Split(orefStr, ":") + if len(fields) != 2 { + return ORef{}, fmt.Errorf("invalid object reference: %q", orefStr) + } + otype := fields[0] + if !otypeRe.MatchString(otype) { + return ORef{}, fmt.Errorf("invalid object type: %q", otype) + } + oid := fields[1] + _, err := uuid.Parse(oid) + if err != nil { + return ORef{}, fmt.Errorf("invalid object id: %q", oid) + } + return ORef{OType: otype, OID: oid}, nil +} + type WaveObj interface { GetOType() string // should not depend on object state (should work with nil value) } diff --git a/pkg/wshutil/wshcmdreader.go b/pkg/wshutil/wshcmdreader.go new file mode 100644 index 000000000..655e16a05 --- /dev/null +++ b/pkg/wshutil/wshcmdreader.go @@ -0,0 +1,174 @@ +package wshutil + +import ( + "bytes" + "encoding/json" + "fmt" + "io" + "sync" +) + +const ( + Mode_Normal = "normal" + Mode_Esc = "esc" + Mode_WaveEsc = "waveesc" + BlockFile_Main = "main" // Assuming this is defined elsewhere +) + +const MaxBufferedDataSize = 256 * 1024 + +type PtyBuffer struct { + CVar *sync.Cond + DataBuf *bytes.Buffer + EscMode string + EscSeqBuf []byte + InputReader io.Reader + CommandCh chan BlockCommand + AtEOF bool + Err error +} + +func MakePtyBuffer(input io.Reader, commandCh chan BlockCommand) *PtyBuffer { + b := &PtyBuffer{ + CVar: sync.NewCond(&sync.Mutex{}), + DataBuf: &bytes.Buffer{}, + EscMode: Mode_Normal, + InputReader: input, + CommandCh: commandCh, + } + go b.run() + return b +} + +func (b *PtyBuffer) setErr(err error) { + b.CVar.L.Lock() + defer b.CVar.L.Unlock() + if b.Err == nil { + b.Err = err + } + b.CVar.Broadcast() +} + +func (b *PtyBuffer) setEOF() { + b.CVar.L.Lock() + defer b.CVar.L.Unlock() + b.AtEOF = true + b.CVar.Broadcast() +} + +func (b *PtyBuffer) processWaveEscSeq(escSeq []byte) { + jmsg := make(map[string]any) + err := json.Unmarshal(escSeq, &jmsg) + if err != nil { + b.setErr(fmt.Errorf("error unmarshalling Wave OSC sequence data: %w", err)) + return + } + cmd, err := ParseCmdMap(jmsg) + if err != nil { + b.setErr(fmt.Errorf("error parsing Wave OSC command: %w", err)) + return + } + b.CommandCh <- cmd +} + +func (b *PtyBuffer) run() { + defer close(b.CommandCh) + buf := make([]byte, 4096) + for { + n, err := b.InputReader.Read(buf) + b.processData(buf[:n]) + if err == io.EOF { + b.setEOF() + return + } + if err != nil { + b.setErr(fmt.Errorf("error reading input: %w", err)) + return + } + } +} + +func (b *PtyBuffer) processData(data []byte) { + outputBuf := make([]byte, 0, len(data)) + for _, ch := range data { + if b.EscMode == Mode_WaveEsc { + if ch == ESC { + // terminates the escape sequence (and the rest was invalid) + b.EscMode = Mode_Normal + outputBuf = append(outputBuf, b.EscSeqBuf...) + outputBuf = append(outputBuf, ch) + b.EscSeqBuf = nil + } else if ch == BEL || ch == ST { + // terminates the escpae sequence (is a valid Wave OSC command) + b.EscMode = Mode_Normal + waveEscSeq := b.EscSeqBuf[len(WaveOSCPrefix):] + b.EscSeqBuf = nil + b.processWaveEscSeq(waveEscSeq) + } else { + b.EscSeqBuf = append(b.EscSeqBuf, ch) + } + continue + } + if b.EscMode == Mode_Esc { + if ch == ESC || ch == BEL || ch == ST { + // these all terminate the escape sequence (invalid, not a Wave OSC) + b.EscMode = Mode_Normal + outputBuf = append(outputBuf, b.EscSeqBuf...) + outputBuf = append(outputBuf, ch) + } else { + if ch == WaveOSCPrefixBytes[len(b.EscSeqBuf)] { + // we're still building what could be a Wave OSC sequence + b.EscSeqBuf = append(b.EscSeqBuf, ch) + } else { + // this is not a Wave OSC sequence, just an escape sequence + b.EscMode = Mode_Normal + outputBuf = append(outputBuf, b.EscSeqBuf...) + outputBuf = append(outputBuf, ch) + continue + } + // check to see if we have a full Wave OSC prefix + if len(b.EscSeqBuf) == len(WaveOSCPrefixBytes) { + b.EscMode = Mode_WaveEsc + } + } + continue + } + // Mode_Normal + if ch == ESC { + b.EscMode = Mode_Esc + b.EscSeqBuf = []byte{ch} + continue + } + outputBuf = append(outputBuf, ch) + } + if len(outputBuf) > 0 { + b.writeData(outputBuf) + } +} + +func (b *PtyBuffer) writeData(data []byte) { + b.CVar.L.Lock() + defer b.CVar.L.Unlock() + // only wait if buffer is currently over max size, otherwise allow this append to go through + for b.DataBuf.Len() > MaxBufferedDataSize { + b.CVar.Wait() + } + b.DataBuf.Write(data) + b.CVar.Broadcast() +} + +func (b *PtyBuffer) Read(p []byte) (n int, err error) { + b.CVar.L.Lock() + defer b.CVar.L.Unlock() + for b.DataBuf.Len() == 0 { + if b.Err != nil { + return 0, b.Err + } + if b.AtEOF { + return 0, io.EOF + } + b.CVar.Wait() + } + b.CVar.Broadcast() + return b.DataBuf.Read(p) +} diff --git a/pkg/wshutil/wshcommands.go b/pkg/wshutil/wshcommands.go index 32d4da2b1..d87929556 100644 --- a/pkg/wshutil/wshcommands.go +++ b/pkg/wshutil/wshcommands.go @@ -19,6 +19,7 @@ const ( BlockCommand_Message = "message" BlockCommand_SetView = "setview" BlockCommand_SetMeta = "setmeta" + BlockCommand_GetMeta = "getmeta" BlockCommand_Input = "controller:input" BlockCommand_AppendBlockFile = "blockfile:append" BlockCommand_AppendIJson = "blockfile:appendijson" @@ -28,6 +29,7 @@ var CommandToTypeMap = map[string]reflect.Type{ BlockCommand_Input: reflect.TypeOf(BlockInputCommand{}), BlockCommand_SetView: reflect.TypeOf(BlockSetViewCommand{}), BlockCommand_SetMeta: reflect.TypeOf(BlockSetMetaCommand{}), + BlockCommand_GetMeta: reflect.TypeOf(BlockGetMetaCommand{}), BlockCommand_Message: reflect.TypeOf(BlockMessageCommand{}), BlockCommand_AppendBlockFile: reflect.TypeOf(BlockAppendFileCommand{}), BlockCommand_AppendIJson: reflect.TypeOf(BlockAppendIJsonCommand{}), @@ -98,8 +100,19 @@ func (svc *BlockSetViewCommand) GetCommand() string { return BlockCommand_SetView } +type BlockGetMetaCommand struct { + Command string `json:"command" tstype:"\"getmeta\""` + RpcId string `json:"rpcid"` + OID string `json:"oid"` // allows oref, 8-char oid, or full uuid +} + +func (gmc *BlockGetMetaCommand) GetCommand() string { + return BlockCommand_GetMeta +} + type BlockSetMetaCommand struct { Command string `json:"command" tstype:"\"setmeta\""` + OID string `json:"oid"` // allows oref, 8-char oid, or full uuid Meta map[string]any `json:"meta"` } diff --git a/pkg/wshutil/wshutil.go b/pkg/wshutil/wshutil.go index 36eab91bf..ade3cc1f3 100644 --- a/pkg/wshutil/wshutil.go +++ b/pkg/wshutil/wshutil.go @@ -13,6 +13,9 @@ import ( const WaveOSC = "23198" const WaveOSCPrefix = "\x1b]" + WaveOSC + ";" +const WaveResponseOSC = "23199" +const WaveResponseOSCPrefix = "\x1b]" + WaveResponseOSC + ";" + const HexChars = "0123456789ABCDEF" const BEL = 0x07 const ST = 0x9c @@ -25,9 +28,12 @@ var WaveOSCPrefixBytes = []byte(WaveOSCPrefix) // JSON = must escape all ASCII control characters ([\x00-\x1F\x7F]) // we can tell the difference between JSON and base64-JSON by the first character: '{' or not +// for responses (terminal -> program), we'll use OSC 23199 +// same json format + func EncodeWaveOSCMessage(cmd BlockCommand) ([]byte, error) { if cmd.GetCommand() == "" { - return nil, fmt.Errorf("Command field not set in struct") + return nil, fmt.Errorf("command field not set in struct") } ctype, ok := CommandToTypeMap[cmd.GetCommand()] if !ok { diff --git a/pkg/wstore/wstore_dbops.go b/pkg/wstore/wstore_dbops.go index e2cc7a019..5ad2aff97 100644 --- a/pkg/wstore/wstore_dbops.go +++ b/pkg/wstore/wstore_dbops.go @@ -7,6 +7,7 @@ import ( "context" "fmt" "log" + "reflect" "time" "github.com/wavetermdev/thenextwave/pkg/filestore" @@ -157,6 +158,28 @@ func DBSelectORefs(ctx context.Context, orefs []waveobj.ORef) ([]waveobj.WaveObj }) } +func DBResolveEasyOID(ctx context.Context, oid string) (*waveobj.ORef, error) { + return WithTxRtn(ctx, func(tx *TxWrap) (*waveobj.ORef, error) { + for _, rtype := range AllWaveObjTypes() { + otype := reflect.Zero(rtype).Interface().(waveobj.WaveObj).GetOType() + table := tableNameFromOType(otype) + var fullOID string + if len(oid) == 8 { + query := fmt.Sprintf("SELECT oid FROM %s WHERE oid LIKE ?", table) + fullOID = tx.GetString(query, oid+"%") + } else { + query := fmt.Sprintf("SELECT oid FROM %s WHERE oid = ?", table) + fullOID = tx.GetString(query, oid) + } + if fullOID != "" { + oref := waveobj.MakeORef(otype, fullOID) + return &oref, nil + } + } + return nil, ErrNotFound + }) +} + func DBSelectMap[T waveobj.WaveObj](ctx context.Context, ids []string) (map[string]T, error) { rtnArr, err := dbSelectOIDs(ctx, getOTypeGen[T](), ids) if err != nil {