diff --git a/cmd/wsh/cmd/wshcmd-root.go b/cmd/wsh/cmd/wshcmd-root.go index 82ab1427f..ccd34894e 100644 --- a/cmd/wsh/cmd/wshcmd-root.go +++ b/cmd/wsh/cmd/wshcmd-root.go @@ -9,6 +9,7 @@ import ( "log" "os" "regexp" + "strconv" "strings" "time" @@ -96,6 +97,12 @@ func setTermHtmlMode() { var oidRe = regexp.MustCompile(`^[0-9a-f]{8}$`) func validateEasyORef(oref string) error { + if oref == "this" { + return nil + } + if num, err := strconv.Atoi(oref); err == nil && num >= 1 { + return nil + } if strings.Contains(oref, ":") { _, err := waveobj.ParseORef(oref) if err != nil { diff --git a/pkg/wshrpc/wshserver/wshserver.go b/pkg/wshrpc/wshserver/wshserver.go index 07e101f4b..6a75d261e 100644 --- a/pkg/wshrpc/wshserver/wshserver.go +++ b/pkg/wshrpc/wshserver/wshserver.go @@ -27,6 +27,8 @@ import ( "github.com/wavetermdev/thenextwave/pkg/wstore" ) +const SimpleId_This = "this" + func (ws *WshServer) AuthenticateCommand(ctx context.Context, data string) error { w := wshutil.GetWshRpcFromContext(ctx) if w == nil { @@ -190,6 +192,17 @@ func sendWaveObjUpdate(oref waveobj.ORef) { } func resolveSimpleId(ctx context.Context, simpleId string) (*waveobj.ORef, error) { + if simpleId == SimpleId_This { + wshRpc := wshutil.GetWshRpcFromContext(ctx) + if wshRpc == nil { + return nil, fmt.Errorf("no wshrpc in context") + } + rpcCtx := wshRpc.GetRpcContext() + if rpcCtx.BlockId == "" { + return nil, fmt.Errorf("no blockid in rpc context") + } + return &waveobj.ORef{OType: wstore.OType_Block, OID: rpcCtx.BlockId}, nil + } if strings.Contains(simpleId, ":") { rtn, err := waveobj.ParseORef(simpleId) if err != nil { diff --git a/pkg/wshutil/wshutil.go b/pkg/wshutil/wshutil.go index 7a478cc79..7a3d4b29b 100644 --- a/pkg/wshutil/wshutil.go +++ b/pkg/wshutil/wshutil.go @@ -287,6 +287,10 @@ func ValidateAndExtractRpcContextFromToken(tokenStr string) (*wshrpc.RpcContext, } else { return nil, fmt.Errorf("iss claim is missing or invalid") } + return mapClaimsToRpcContext(claims), nil +} + +func mapClaimsToRpcContext(claims jwt.MapClaims) *wshrpc.RpcContext { rpcCtx := &wshrpc.RpcContext{} if claims["blockid"] != nil { if blockId, ok := claims["blockid"].(string); ok { @@ -303,9 +307,25 @@ func ValidateAndExtractRpcContextFromToken(tokenStr string) (*wshrpc.RpcContext, rpcCtx.WindowId = windowId } } - return rpcCtx, nil + return rpcCtx } +// only for use on client +func ExtractUnverifiedRpcContext(tokenStr string) (*wshrpc.RpcContext, error) { + // this happens on the client who does not have access to the secret key + // we want to read the claims without validating the signature + token, _, err := new(jwt.Parser).ParseUnverified(tokenStr, jwt.MapClaims{}) + if err != nil { + return nil, fmt.Errorf("error parsing token: %w", err) + } + claims, ok := token.Claims.(jwt.MapClaims) + if !ok { + return nil, fmt.Errorf("error getting claims from token") + } + return mapClaimsToRpcContext(claims), nil +} + +// only for use on client func ExtractUnverifiedSocketName(tokenStr string) (string, error) { // this happens on the client who does not have access to the secret key // we want to read the claims without validating the signature