From 2c055b56d09ec7e39a0f9f9802ef40d51e4b9974 Mon Sep 17 00:00:00 2001 From: Mike Sawka Date: Fri, 8 Nov 2024 15:48:54 -0800 Subject: [PATCH] new resolver formats (tab:N), and also make the structure of the resolvers much more robust (#1254) --- cmd/wsh/cmd/wshcmd-deleteblock.go | 9 +- cmd/wsh/cmd/wshcmd-getmeta.go | 7 +- cmd/wsh/cmd/wshcmd-readfile.go | 7 +- cmd/wsh/cmd/wshcmd-root.go | 33 ----- cmd/wsh/cmd/wshcmd-setmeta.go | 9 +- cmd/wsh/cmd/wshcmd-web.go | 4 - pkg/waveobj/waveobj.go | 3 + pkg/waveobj/wtype.go | 9 ++ pkg/wshrpc/wshserver/resolvers.go | 198 ++++++++++++++++++++++++++++++ pkg/wshrpc/wshserver/wshserver.go | 70 ++--------- 10 files changed, 226 insertions(+), 123 deletions(-) create mode 100644 pkg/wshrpc/wshserver/resolvers.go diff --git a/cmd/wsh/cmd/wshcmd-deleteblock.go b/cmd/wsh/cmd/wshcmd-deleteblock.go index 9ccea751a..d80d703a3 100644 --- a/cmd/wsh/cmd/wshcmd-deleteblock.go +++ b/cmd/wsh/cmd/wshcmd-deleteblock.go @@ -21,18 +21,13 @@ func init() { func deleteBlockRun(cmd *cobra.Command, args []string) { oref := blockArg - err := validateEasyORef(oref) - if err != nil { - WriteStderr("[error]%v\n", err) - return - } fullORef, err := resolveSimpleId(oref) if err != nil { - WriteStderr("[error] resolving oref: %v\n", err) + WriteStderr("[error] %v\n", err) return } if fullORef.OType != "block" { - WriteStderr("[error] oref is not a block\n") + WriteStderr("[error] object reference is not a block\n") return } deleteBlockData := &wshrpc.CommandDeleteBlockData{ diff --git a/cmd/wsh/cmd/wshcmd-getmeta.go b/cmd/wsh/cmd/wshcmd-getmeta.go index 0edf3d151..7380866e6 100644 --- a/cmd/wsh/cmd/wshcmd-getmeta.go +++ b/cmd/wsh/cmd/wshcmd-getmeta.go @@ -74,14 +74,9 @@ func getMetaRun(cmd *cobra.Command, args []string) { WriteStderr("[error] oref is required") return } - err := validateEasyORef(oref) - if err != nil { - WriteStderr("[error] %v\n", err) - return - } fullORef, err := resolveSimpleId(oref) if err != nil { - WriteStderr("[error] resolving oref: %v\n", err) + WriteStderr("[error] %v\n", err) return } resp, err := wshclient.GetMetaCommand(RpcClient, wshrpc.CommandGetMetaData{ORef: *fullORef}, &wshrpc.RpcOpts{Timeout: 2000}) diff --git a/cmd/wsh/cmd/wshcmd-readfile.go b/cmd/wsh/cmd/wshcmd-readfile.go index 7c8444fbd..2f2ecbd50 100644 --- a/cmd/wsh/cmd/wshcmd-readfile.go +++ b/cmd/wsh/cmd/wshcmd-readfile.go @@ -29,14 +29,9 @@ func runReadFile(cmd *cobra.Command, args []string) { WriteStderr("[error] oref is required\n") return } - err := validateEasyORef(oref) - if err != nil { - WriteStderr("[error] %v\n", err) - return - } fullORef, err := resolveSimpleId(oref) if err != nil { - WriteStderr("error resolving oref: %v\n", err) + WriteStderr("[error] %v\n", err) return } resp64, err := wshclient.FileReadCommand(RpcClient, wshrpc.CommandFileData{ZoneId: fullORef.OID, FileName: args[1]}, &wshrpc.RpcOpts{Timeout: 5000}) diff --git a/cmd/wsh/cmd/wshcmd-root.go b/cmd/wsh/cmd/wshcmd-root.go index 6f29f65a5..fa6fe38c4 100644 --- a/cmd/wsh/cmd/wshcmd-root.go +++ b/cmd/wsh/cmd/wshcmd-root.go @@ -9,11 +9,9 @@ import ( "os" "regexp" "runtime/debug" - "strconv" "strings" "time" - "github.com/google/uuid" "github.com/spf13/cobra" "github.com/wavetermdev/waveterm/pkg/waveobj" "github.com/wavetermdev/waveterm/pkg/wshrpc" @@ -76,10 +74,6 @@ func resolveBlockArg() (*waveobj.ORef, error) { if oref == "" { return nil, fmt.Errorf("blockid is required") } - err := validateEasyORef(oref) - if err != nil { - return nil, err - } fullORef, err := resolveSimpleId(oref) if err != nil { return nil, fmt.Errorf("resolving blockid: %w", err) @@ -128,33 +122,6 @@ func setTermHtmlMode() { var oidRe = regexp.MustCompile(`^[0-9a-f]{8}$`) -func validateEasyORef(oref string) error { - if oref == "this" || oref == "tab" { - return nil - } - if num, err := strconv.Atoi(oref); err == nil && num >= 1 { - return nil - } - if strings.Contains(oref, ":") { - _, err := waveobj.ParseORef(oref) - if err != nil { - return fmt.Errorf("invalid ORef: %v", err) - } - return nil - } - if len(oref) == 8 { - if !oidRe.MatchString(oref) { - return fmt.Errorf("invalid short OID format, must only use 0-9a-f: %q", oref) - } - return nil - } - _, err := uuid.Parse(oref) - if err != nil { - return fmt.Errorf("invalid object reference (must be UUID, or a positive integer): %v", err) - } - return nil -} - func isFullORef(orefStr string) bool { _, err := waveobj.ParseORef(orefStr) return err == nil diff --git a/cmd/wsh/cmd/wshcmd-setmeta.go b/cmd/wsh/cmd/wshcmd-setmeta.go index 8e70e4033..44ee60b39 100644 --- a/cmd/wsh/cmd/wshcmd-setmeta.go +++ b/cmd/wsh/cmd/wshcmd-setmeta.go @@ -112,14 +112,9 @@ func setMetaRun(cmd *cobra.Command, args []string) { WriteStderr("[error] block (oref) is required\n") return } - err := validateEasyORef(blockArg) - if err != nil { - WriteStderr("[error] %v\n", err) - return - } - var jsonMeta map[string]interface{} if setMetaJsonFilePath != "" { + var err error jsonMeta, err = loadJSONFile(setMetaJsonFilePath) if err != nil { WriteStderr("[error] %v\n", err) @@ -146,7 +141,7 @@ func setMetaRun(cmd *cobra.Command, args []string) { } fullORef, err := resolveSimpleId(blockArg) if err != nil { - WriteStderr("[error] resolving oref: %v\n", err) + WriteStderr("[error] %v\n", err) return } diff --git a/cmd/wsh/cmd/wshcmd-web.go b/cmd/wsh/cmd/wshcmd-web.go index 775a0cfef..882059f94 100644 --- a/cmd/wsh/cmd/wshcmd-web.go +++ b/cmd/wsh/cmd/wshcmd-web.go @@ -55,10 +55,6 @@ func webGetRun(cmd *cobra.Command, args []string) error { if oref == "" { return fmt.Errorf("blockid not specified") } - err := validateEasyORef(oref) - if err != nil { - return err - } fullORef, err := resolveSimpleId(oref) if err != nil { return fmt.Errorf("resolving blockid: %w", err) diff --git a/pkg/waveobj/waveobj.go b/pkg/waveobj/waveobj.go index 13111b54a..c8d005523 100644 --- a/pkg/waveobj/waveobj.go +++ b/pkg/waveobj/waveobj.go @@ -86,6 +86,9 @@ func ParseORef(orefStr string) (ORef, error) { if !otypeRe.MatchString(otype) { return ORef{}, fmt.Errorf("invalid object type: %q", otype) } + if !ValidOTypes[otype] { + return ORef{}, fmt.Errorf("unknown object type: %q", otype) + } oid := fields[1] _, err := uuid.Parse(oid) if err != nil { diff --git a/pkg/waveobj/wtype.go b/pkg/waveobj/wtype.go index 5953a22b3..0293319b1 100644 --- a/pkg/waveobj/wtype.go +++ b/pkg/waveobj/wtype.go @@ -30,6 +30,15 @@ const ( OType_Block = "block" ) +var ValidOTypes = map[string]bool{ + OType_Client: true, + OType_Window: true, + OType_Workspace: true, + OType_Tab: true, + OType_LayoutState: true, + OType_Block: true, +} + type WaveObjUpdate struct { UpdateType string `json:"updatetype"` OType string `json:"otype"` diff --git a/pkg/wshrpc/wshserver/resolvers.go b/pkg/wshrpc/wshserver/resolvers.go new file mode 100644 index 000000000..54688a833 --- /dev/null +++ b/pkg/wshrpc/wshserver/resolvers.go @@ -0,0 +1,198 @@ +// Copyright 2024, Command Line Inc. +// SPDX-License-Identifier: Apache-2.0 + +package wshserver + +import ( + "context" + "fmt" + "regexp" + "strconv" + "strings" + + "github.com/google/uuid" + "github.com/wavetermdev/waveterm/pkg/waveobj" + "github.com/wavetermdev/waveterm/pkg/wshrpc" + "github.com/wavetermdev/waveterm/pkg/wstore" +) + +const SimpleId_This = "this" +const SimpleId_Tab = "tab" + +var ( + simpleTabNumRe = regexp.MustCompile(`^tab:(\d{1,3})$`) + shortUUIDRe = regexp.MustCompile(`^[0-9a-f]{8}$`) + SimpleId_BlockNum_Regex = regexp.MustCompile(`^\d+$`) +) + +// Helper function to validate UUIDs or 8-char UUIDs format +func isValidSimpleUUID(s string) bool { + // Try parsing as full UUID + _, err := uuid.Parse(s) + if err == nil { + return true + } + + // Check if it's an 8-char hex prefix + shortUUIDPattern := regexp.MustCompile(`^[0-9a-f]{8}$`) + return shortUUIDPattern.MatchString(strings.ToLower(s)) +} + +// First function: detect/choose discriminator +func parseSimpleId(simpleId string) (discriminator string, value string, err error) { + // Check for explicit discriminator with @ + if parts := strings.SplitN(simpleId, "@", 2); len(parts) == 2 { + return parts[0], parts[1], nil + } + + // Handle special keywords + if simpleId == SimpleId_This || simpleId == SimpleId_Tab { + return "this", simpleId, nil + } + + // Check if it's a simple ORef (type:uuid) + if _, err := waveobj.ParseORef(simpleId); err == nil { + return "oref", simpleId, nil + } + + // Check for tab:N format + if simpleTabNumRe.MatchString(simpleId) { + return "tabnum", simpleId, nil + } + + // Check for plain number (block reference) + if _, err := strconv.Atoi(simpleId); err == nil { + return "blocknum", simpleId, nil + } + + // Check for UUIDs + if _, err := uuid.Parse(simpleId); err == nil { + return "uuid", simpleId, nil + } + if shortUUIDRe.MatchString(strings.ToLower(simpleId)) { + return "uuid8", simpleId, nil + } + + return "", "", fmt.Errorf("invalid simple id format: %s", simpleId) +} + +// Individual resolvers +func resolveThis(ctx context.Context, data wshrpc.CommandResolveIdsData, value string) (*waveobj.ORef, error) { + if data.BlockId == "" { + return nil, fmt.Errorf("no blockid in request") + } + + if value == SimpleId_This { + return &waveobj.ORef{OType: waveobj.OType_Block, OID: data.BlockId}, nil + } + if value == SimpleId_Tab { + tabId, err := wstore.DBFindTabForBlockId(ctx, data.BlockId) + if err != nil { + return nil, fmt.Errorf("error finding tab: %v", err) + } + return &waveobj.ORef{OType: waveobj.OType_Tab, OID: tabId}, nil + } + return nil, fmt.Errorf("invalid value for 'this' resolver: %s", value) +} + +func resolveORef(ctx context.Context, value string) (*waveobj.ORef, error) { + parsedORef, err := waveobj.ParseORef(value) + if err != nil { + return nil, fmt.Errorf("error parsing oref: %v", err) + } + return &parsedORef, nil +} + +func resolveTabNum(ctx context.Context, data wshrpc.CommandResolveIdsData, value string) (*waveobj.ORef, error) { + m := simpleTabNumRe.FindStringSubmatch(value) + if m == nil { + return nil, fmt.Errorf("error parsing simple tab id: %s", value) + } + + tabNum, err := strconv.Atoi(m[1]) + if err != nil { + return nil, fmt.Errorf("error parsing simple tab num: %v", err) + } + + curTabId, err := wstore.DBFindTabForBlockId(ctx, data.BlockId) + if err != nil { + return nil, fmt.Errorf("error finding tab for block: %v", err) + } + + wsId, err := wstore.DBFindWorkspaceForTabId(ctx, curTabId) + if err != nil { + return nil, fmt.Errorf("error finding current workspace: %v", err) + } + + ws, err := wstore.DBMustGet[*waveobj.Workspace](ctx, wsId) + if err != nil { + return nil, fmt.Errorf("error getting workspace: %v", err) + } + + if tabNum < 1 || tabNum > len(ws.TabIds) { + return nil, fmt.Errorf("tab num out of range, workspace has %d tabs", len(ws.TabIds)) + } + + resolvedTabId := ws.TabIds[tabNum-1] + return &waveobj.ORef{OType: waveobj.OType_Tab, OID: resolvedTabId}, nil +} + +func resolveBlock(ctx context.Context, data wshrpc.CommandResolveIdsData, value string) (*waveobj.ORef, error) { + blockNum, err := strconv.Atoi(value) + if err != nil { + return nil, fmt.Errorf("error parsing block number: %v", err) + } + + tabId, err := wstore.DBFindTabForBlockId(ctx, data.BlockId) + if err != nil { + return nil, fmt.Errorf("error finding tab for blockid %s: %w", data.BlockId, err) + } + + tab, err := wstore.DBGet[*waveobj.Tab](ctx, tabId) + if err != nil { + return nil, fmt.Errorf("error retrieving tab %s: %w", tabId, err) + } + + layout, err := wstore.DBGet[*waveobj.LayoutState](ctx, tab.LayoutState) + if err != nil { + return nil, fmt.Errorf("error retrieving layout state %s: %w", tab.LayoutState, err) + } + + if layout.LeafOrder == nil { + return nil, fmt.Errorf("could not resolve block num %v, leaf order is empty", blockNum) + } + + leafIndex := blockNum - 1 // block nums are 1-indexed + if len(*layout.LeafOrder) <= leafIndex { + return nil, fmt.Errorf("could not find a node in the layout matching blockNum %v", blockNum) + } + + leafEntry := (*layout.LeafOrder)[leafIndex] + return &waveobj.ORef{OType: waveobj.OType_Block, OID: leafEntry.BlockId}, nil +} + +func resolveUUID(ctx context.Context, value string) (*waveobj.ORef, error) { + return wstore.DBResolveEasyOID(ctx, value) +} + +// Main resolver function +func resolveSimpleId(ctx context.Context, data wshrpc.CommandResolveIdsData, simpleId string) (*waveobj.ORef, error) { + discriminator, value, err := parseSimpleId(simpleId) + if err != nil { + return nil, err + } + switch discriminator { + case "this": + return resolveThis(ctx, data, value) + case "oref": + return resolveORef(ctx, value) + case "tabnum": + return resolveTabNum(ctx, data, value) + case "blocknum": + return resolveBlock(ctx, data, value) + case "uuid", "uuid8": + return resolveUUID(ctx, value) + default: + return nil, fmt.Errorf("unknown discriminator: %s", discriminator) + } +} diff --git a/pkg/wshrpc/wshserver/wshserver.go b/pkg/wshrpc/wshserver/wshserver.go index b4b3bdebe..f1c9cfa57 100644 --- a/pkg/wshrpc/wshserver/wshserver.go +++ b/pkg/wshrpc/wshserver/wshserver.go @@ -12,8 +12,6 @@ import ( "fmt" "io/fs" "log" - "regexp" - "strconv" "strings" "time" @@ -35,10 +33,6 @@ import ( "github.com/wavetermdev/waveterm/pkg/wstore" ) -const SimpleId_This = "this" -const SimpleId_Tab = "tab" - -var SimpleId_BlockNum_Regex = regexp.MustCompile(`^\d+$`) var InvalidWslDistroNames = []string{"docker-desktop", "docker-desktop-data"} type WshServer struct{} @@ -155,70 +149,26 @@ func sendWaveObjUpdate(oref waveobj.ORef) { }) } -func resolveSimpleId(ctx context.Context, data wshrpc.CommandResolveIdsData, simpleId string) (*waveobj.ORef, error) { - if simpleId == SimpleId_This { - if data.BlockId == "" { - return nil, fmt.Errorf("no blockid in request") - } - return &waveobj.ORef{OType: waveobj.OType_Block, OID: data.BlockId}, nil - } - if simpleId == SimpleId_Tab { - if data.BlockId == "" { - return nil, fmt.Errorf("no blockid in request") - } - tabId, err := wstore.DBFindTabForBlockId(ctx, data.BlockId) - if err != nil { - return nil, fmt.Errorf("error finding tab: %v", err) - } - return &waveobj.ORef{OType: waveobj.OType_Tab, OID: tabId}, nil - } - blockNum, err := strconv.Atoi(simpleId) - if err == nil { - tabId, err := wstore.DBFindTabForBlockId(ctx, data.BlockId) - if err != nil { - return nil, fmt.Errorf("error finding tab for blockid %s: %w", data.BlockId, err) - } - - tab, err := wstore.DBGet[*waveobj.Tab](ctx, tabId) - if err != nil { - return nil, fmt.Errorf("error retrieving tab %s: %w", tabId, err) - } - - layout, err := wstore.DBGet[*waveobj.LayoutState](ctx, tab.LayoutState) - if err != nil { - return nil, fmt.Errorf("error retrieving layout state %s: %w", tab.LayoutState, err) - } - - if layout.LeafOrder == nil { - return nil, fmt.Errorf("could not resolve block num %v, leaf order is empty", blockNum) - } - - leafIndex := blockNum - 1 // block nums are 1-indexed, we need the 0-indexed version - if len(*layout.LeafOrder) <= leafIndex { - return nil, fmt.Errorf("could not find a node in the layout matching blockNum %v", blockNum) - } - leafEntry := (*layout.LeafOrder)[leafIndex] - return &waveobj.ORef{OType: waveobj.OType_Block, OID: leafEntry.BlockId}, nil - } else if strings.Contains(simpleId, ":") { - rtn, err := waveobj.ParseORef(simpleId) - if err != nil { - return nil, fmt.Errorf("error parsing simple id: %w", err) - } - return &rtn, nil - } - return wstore.DBResolveEasyOID(ctx, simpleId) -} - func (ws *WshServer) ResolveIdsCommand(ctx context.Context, data wshrpc.CommandResolveIdsData) (wshrpc.CommandResolveIdsRtnData, error) { rtn := wshrpc.CommandResolveIdsRtnData{} rtn.ResolvedIds = make(map[string]waveobj.ORef) + var firstErr error for _, simpleId := range data.Ids { oref, err := resolveSimpleId(ctx, data, simpleId) + if err != nil { + if firstErr == nil { + firstErr = err + } + continue + } if err != nil || oref == nil { continue } rtn.ResolvedIds[simpleId] = *oref } + if firstErr != nil && len(data.Ids) == 1 { + return rtn, firstErr + } return rtn, nil }