initscript -- support for local files, and overrides in connections.json (#1818)

This commit is contained in:
Mike Sawka 2025-01-23 15:41:13 -08:00 committed by GitHub
parent 8bf90c0a4d
commit 1913cc5c99
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
8 changed files with 423 additions and 48 deletions

170
cmd/wsh/cmd/setmeta_test.go Normal file
View File

@ -0,0 +1,170 @@
// Copyright 2025, Command Line Inc.
// SPDX-License-Identifier: Apache-2.0
package cmd
import (
"reflect"
"testing"
)
func TestParseMetaSets(t *testing.T) {
tests := []struct {
name string
input []string
want map[string]any
wantErr bool
}{
{
name: "basic types",
input: []string{"str=hello", "num=42", "float=3.14", "bool=true", "null=null"},
want: map[string]any{
"str": "hello",
"num": int64(42),
"float": float64(3.14),
"bool": true,
"null": nil,
},
},
{
name: "json values",
input: []string{
`arr=[1,2,3]`,
`obj={"foo":"bar"}`,
`str="quoted"`,
},
want: map[string]any{
"arr": []any{float64(1), float64(2), float64(3)},
"obj": map[string]any{"foo": "bar"},
"str": "quoted",
},
},
{
name: "nested paths",
input: []string{
"a/b=55",
"a/c=2",
},
want: map[string]any{
"a": map[string]any{
"b": int64(55),
"c": int64(2),
},
},
},
{
name: "deep nesting",
input: []string{
"a/b/c/d=hello",
},
want: map[string]any{
"a": map[string]any{
"b": map[string]any{
"c": map[string]any{
"d": "hello",
},
},
},
},
},
{
name: "override nested value",
input: []string{
"a/b/c=1",
"a/b=2",
},
want: map[string]any{
"a": map[string]any{
"b": int64(2),
},
},
},
{
name: "override with null",
input: []string{
"a/b=1",
"a/c=2",
"a=null",
},
want: map[string]any{
"a": nil,
},
},
{
name: "mixed types in path",
input: []string{
"a/b=1",
"a/c=[1,2,3]",
"a/d/e=true",
},
want: map[string]any{
"a": map[string]any{
"b": int64(1),
"c": []any{float64(1), float64(2), float64(3)},
"d": map[string]any{
"e": true,
},
},
},
},
{
name: "invalid format",
input: []string{"invalid"},
wantErr: true,
},
{
name: "invalid json",
input: []string{`a={"invalid`},
wantErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := parseMetaSets(tt.input)
if (err != nil) != tt.wantErr {
t.Errorf("parseMetaSets() error = %v, wantErr %v", err, tt.wantErr)
return
}
if !tt.wantErr && !reflect.DeepEqual(got, tt.want) {
t.Errorf("parseMetaSets() = %v, want %v", got, tt.want)
}
})
}
}
func TestParseMetaValue(t *testing.T) {
tests := []struct {
name string
input string
want any
wantErr bool
}{
{"empty string", "", nil, false},
{"null", "null", nil, false},
{"true", "true", true, false},
{"false", "false", false, false},
{"integer", "42", int64(42), false},
{"negative integer", "-42", int64(-42), false},
{"hex integer", "0xff", int64(255), false},
{"float", "3.14", float64(3.14), false},
{"string", "hello", "hello", false},
{"json array", "[1,2,3]", []any{float64(1), float64(2), float64(3)}, false},
{"json object", `{"foo":"bar"}`, map[string]any{"foo": "bar"}, false},
{"quoted string", `"quoted"`, "quoted", false},
{"invalid json", `{"invalid`, nil, true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := parseMetaValue(tt.input)
if (err != nil) != tt.wantErr {
t.Errorf("parseMetaValue() error = %v, wantErr %v", err, tt.wantErr)
return
}
if !tt.wantErr && !reflect.DeepEqual(got, tt.want) {
t.Errorf("parseMetaValue() = %v, want %v", got, tt.want)
}
})
}
}

View File

@ -58,40 +58,88 @@ func loadJSONFile(filepath string) (map[string]interface{}, error) {
return result, nil
}
func parseMetaSets(metaSets []string) (map[string]interface{}, error) {
meta := make(map[string]interface{})
func parseMetaValue(setVal string) (any, error) {
if setVal == "" || setVal == "null" {
return nil, nil
}
if setVal == "true" {
return true, nil
}
if setVal == "false" {
return false, nil
}
if setVal[0] == '[' || setVal[0] == '{' || setVal[0] == '"' {
var val any
err := json.Unmarshal([]byte(setVal), &val)
if err != nil {
return nil, fmt.Errorf("invalid json value: %v", err)
}
return val, nil
}
// Try parsing as integer
ival, err := strconv.ParseInt(setVal, 0, 64)
if err == nil {
return ival, nil
}
// Try parsing as float
fval, err := strconv.ParseFloat(setVal, 64)
if err == nil {
return fval, nil
}
// Fallback to string
return setVal, nil
}
func setNestedValue(meta map[string]any, path []string, value any) {
// For single key, just set directly
if len(path) == 1 {
meta[path[0]] = value
return
}
// For nested path, traverse or create maps as needed
current := meta
for i := 0; i < len(path)-1; i++ {
key := path[i]
// If next level doesn't exist or isn't a map, create new map
next, exists := current[key]
if !exists {
nextMap := make(map[string]any)
current[key] = nextMap
current = nextMap
} else if nextMap, ok := next.(map[string]any); ok {
current = nextMap
} else {
// If existing value isn't a map, replace with new map
nextMap = make(map[string]any)
current[key] = nextMap
current = nextMap
}
}
// Set the final value
current[path[len(path)-1]] = value
}
func parseMetaSets(metaSets []string) (map[string]any, error) {
meta := make(map[string]any)
for _, metaSet := range metaSets {
fields := strings.SplitN(metaSet, "=", 2)
if len(fields) != 2 {
return nil, fmt.Errorf("invalid meta set: %q", metaSet)
}
setVal := fields[1]
if setVal == "" || setVal == "null" {
meta[fields[0]] = nil
} else if setVal == "true" {
meta[fields[0]] = true
} else if setVal == "false" {
meta[fields[0]] = false
} else if setVal[0] == '[' || setVal[0] == '{' || setVal[0] == '"' {
var val interface{}
err := json.Unmarshal([]byte(setVal), &val)
if err != nil {
return nil, fmt.Errorf("invalid json value: %v", err)
}
meta[fields[0]] = val
} else {
ival, err := strconv.ParseInt(setVal, 0, 64)
if err == nil {
meta[fields[0]] = ival
} else {
fval, err := strconv.ParseFloat(setVal, 64)
if err == nil {
meta[fields[0]] = fval
} else {
meta[fields[0]] = setVal
}
}
val, err := parseMetaValue(fields[1])
if err != nil {
return nil, err
}
// Split the key path and set nested value
path := strings.Split(fields[0], "/")
setNestedValue(meta, path, val)
}
return meta, nil
}

View File

@ -306,6 +306,13 @@ declare global {
"term:fontsize"?: number;
"term:fontfamily"?: string;
"term:theme"?: string;
"cmd:env"?: {[key: string]: string};
"cmd:initscript"?: string;
"cmd:initscript.sh"?: string;
"cmd:initscript.bash"?: string;
"cmd:initscript.zsh"?: string;
"cmd:initscript.pwsh"?: string;
"cmd:initscript.fish"?: string;
"ssh:user"?: string;
"ssh:hostname"?: string;
"ssh:port"?: string;

View File

@ -11,6 +11,7 @@ import (
"io"
"io/fs"
"log"
"os"
"strings"
"sync"
"sync/atomic"
@ -24,6 +25,7 @@ import (
"github.com/wavetermdev/waveterm/pkg/remote/conncontroller"
"github.com/wavetermdev/waveterm/pkg/shellexec"
"github.com/wavetermdev/waveterm/pkg/util/envutil"
"github.com/wavetermdev/waveterm/pkg/util/fileutil"
"github.com/wavetermdev/waveterm/pkg/util/shellutil"
"github.com/wavetermdev/waveterm/pkg/util/utilfn"
"github.com/wavetermdev/waveterm/pkg/wavebase"
@ -57,6 +59,7 @@ const (
const (
DefaultTermMaxFileSize = 256 * 1024
DefaultHtmlMaxFileSize = 256 * 1024
MaxInitScriptSize = 50 * 1024
)
const DefaultTimeout = 2 * time.Second
@ -232,25 +235,79 @@ func getCustomInitScriptKeyCascade(shellType string) []string {
return []string{waveobj.MetaKey_CmdInitScript}
}
func getCustomInitScript(meta waveobj.MetaMapType, connName string, shellType string) string {
func getCustomInitScript(logCtx context.Context, meta waveobj.MetaMapType, connName string, shellType string) string {
initScriptVal, metaKeyName := getCustomInitScriptValue(meta, connName, shellType)
if initScriptVal == "" {
return ""
}
if !fileutil.IsInitScriptPath(initScriptVal) {
blocklogger.Infof(logCtx, "[conndebug] inline initScript (size=%d) found in meta key: %s\n", len(initScriptVal), metaKeyName)
return initScriptVal
}
blocklogger.Infof(logCtx, "[conndebug] initScript detected as a file %q from meta key: %s\n", initScriptVal, metaKeyName)
initScriptVal, err := wavebase.ExpandHomeDir(initScriptVal)
if err != nil {
blocklogger.Infof(logCtx, "[conndebug] cannot expand home dir in Wave initscript file: %v\n", err)
return fmt.Sprintf("echo \"cannot expand home dir in Wave initscript file, from key %s\";\n", metaKeyName)
}
fileData, err := os.ReadFile(initScriptVal)
if err != nil {
blocklogger.Infof(logCtx, "[conndebug] cannot open Wave initscript file: %v\n", err)
return fmt.Sprintf("echo \"cannot open Wave initscript file, from key %s\";\n", metaKeyName)
}
if len(fileData) > MaxInitScriptSize {
blocklogger.Infof(logCtx, "[conndebug] initscript file too large, size=%d, max=%d\n", len(fileData), MaxInitScriptSize)
return fmt.Sprintf("echo \"initscript file too large, from key %s\";\n", metaKeyName)
}
if utilfn.HasBinaryData(fileData) {
blocklogger.Infof(logCtx, "[conndebug] initscript file contains binary data\n")
return fmt.Sprintf("echo \"initscript file contains binary data, from key %s\";\n", metaKeyName)
}
blocklogger.Infof(logCtx, "[conndebug] initscript file read successfully, size=%d\n", len(fileData))
return string(fileData)
}
// returns (value, metakey)
func getCustomInitScriptValue(meta waveobj.MetaMapType, connName string, shellType string) (string, string) {
keys := getCustomInitScriptKeyCascade(shellType)
connMeta := meta.GetConnectionOverride(connName)
if connMeta != nil {
for _, key := range keys {
if connMeta.HasKey(key) {
return connMeta.GetString(key, "")
return connMeta.GetString(key, ""), "blockmeta/[" + connName + "]/" + key
}
}
}
for _, key := range keys {
if meta.HasKey(key) {
return meta.GetString(key, "")
return meta.GetString(key, ""), "blockmeta/" + key
}
}
return ""
fullConfig := wconfig.GetWatcher().GetFullConfig()
connKeywords := fullConfig.Connections[connName]
connKeywordsMap := make(map[string]any)
err := utilfn.ReUnmarshal(&connKeywordsMap, connKeywords)
if err != nil {
log.Printf("error re-unmarshalling connKeywords: %v\n", err)
return "", ""
}
ckMeta := waveobj.MetaMapType(connKeywordsMap)
for _, key := range keys {
if ckMeta.HasKey(key) {
return ckMeta.GetString(key, ""), "connections.json/" + connName + "/" + key
}
}
return "", ""
}
func resolveEnvMap(blockId string, blockMeta waveobj.MetaMapType, connName string) (map[string]string, error) {
rtn := make(map[string]string)
config := wconfig.GetWatcher().GetFullConfig()
connKeywords := config.Connections[connName]
ckEnv := connKeywords.CmdEnv
for k, v := range ckEnv {
rtn[k] = v
}
ctx, cancelFn := context.WithTimeout(context.Background(), 2*time.Second)
defer cancelFn()
_, envFileData, err := filestore.WFS.ReadFile(ctx, blockId, wavebase.BlockFile_Env)
@ -260,25 +317,27 @@ func resolveEnvMap(blockId string, blockMeta waveobj.MetaMapType, connName strin
if err != nil {
return nil, fmt.Errorf("error reading command env file: %w", err)
}
rtn := make(map[string]string)
if len(envFileData) > 0 {
envMap := envutil.EnvToMap(string(envFileData))
for k, v := range envMap {
rtn[k] = v
}
}
cmdEnv := blockMeta.GetMap(waveobj.MetaKey_CmdEnv)
cmdEnv := blockMeta.GetStringMap(waveobj.MetaKey_CmdEnv, true)
for k, v := range cmdEnv {
if v == nil {
if v == waveobj.MetaMap_DeleteSentinel {
delete(rtn, k)
continue
}
if strVal, ok := v.(string); ok {
rtn[k] = strVal
}
if floatVal, ok := v.(float64); ok {
rtn[k] = fmt.Sprintf("%v", floatVal)
rtn[k] = v
}
connEnv := blockMeta.GetConnectionOverride(connName).GetStringMap(waveobj.MetaKey_CmdEnv, true)
for k, v := range connEnv {
if v == waveobj.MetaMap_DeleteSentinel {
delete(rtn, k)
continue
}
rtn[k] = v
}
return rtn, nil
}
@ -322,7 +381,7 @@ func (bc *BlockController) DoRunShellCommand(logCtx context.Context, rc *RunShel
return bc.manageRunningShellProcess(shellProc, rc, blockMeta)
}
func (bc *BlockController) makeSwapToken(ctx context.Context, blockMeta waveobj.MetaMapType, remoteName string, shellType string) *shellutil.TokenSwapEntry {
func (bc *BlockController) makeSwapToken(ctx context.Context, logCtx context.Context, blockMeta waveobj.MetaMapType, remoteName string, shellType string) *shellutil.TokenSwapEntry {
token := &shellutil.TokenSwapEntry{
Token: uuid.New().String(),
Env: make(map[string]string),
@ -360,7 +419,7 @@ func (bc *BlockController) makeSwapToken(ctx context.Context, blockMeta waveobj.
for k, v := range envMap {
token.Env[k] = v
}
token.ScriptText = getCustomInitScript(blockMeta, remoteName, shellType)
token.ScriptText = getCustomInitScript(logCtx, blockMeta, remoteName, shellType)
return token
}
@ -509,9 +568,9 @@ func (bc *BlockController) setupAndStartShellProcess(logCtx context.Context, rc
return nil, fmt.Errorf("unknown controller type %q", bc.ControllerType)
}
var shellProc *shellexec.ShellProc
swapToken := bc.makeSwapToken(ctx, blockMeta, remoteName, connUnion.ShellType)
swapToken := bc.makeSwapToken(ctx, logCtx, blockMeta, remoteName, connUnion.ShellType)
cmdOpts.SwapToken = swapToken
blocklogger.Infof(logCtx, "[conndebug] created swaptoken: %s\n", swapToken.Token)
blocklogger.Debugf(logCtx, "[conndebug] created swaptoken: %s\n", swapToken.Token)
if connUnion.ConnType == ConnType_Wsl {
wslConn := connUnion.WslConn
if !connUnion.WshEnabled {
@ -533,8 +592,8 @@ func (bc *BlockController) setupAndStartShellProcess(logCtx context.Context, rc
if err != nil {
wslConn.SetWshError(err)
wslConn.WshEnabled.Store(false)
log.Printf("error starting wsl shell proc with wsh: %v", err)
log.Print("attempting install without wsh")
blocklogger.Infof(logCtx, "[conndebug] error starting wsl shell proc with wsh: %v\n", err)
blocklogger.Infof(logCtx, "[conndebug] attempting install without wsh\n")
shellProc, err = shellexec.StartWslShellProcNoWsh(ctx, rc.TermSize, cmdStr, cmdOpts, wslConn)
if err != nil {
return nil, err
@ -562,8 +621,8 @@ func (bc *BlockController) setupAndStartShellProcess(logCtx context.Context, rc
if err != nil {
conn.SetWshError(err)
conn.WshEnabled.Store(false)
log.Printf("error starting remote shell proc with wsh: %v", err)
log.Print("attempting install without wsh")
blocklogger.Infof(logCtx, "[conndebug] error starting remote shell proc with wsh: %v\n", err)
blocklogger.Infof(logCtx, "[conndebug] attempting install without wsh\n")
shellProc, err = shellexec.StartRemoteShellProcNoWsh(ctx, rc.TermSize, cmdStr, cmdOpts, conn)
if err != nil {
return nil, err

View File

@ -11,6 +11,7 @@ import (
"net/http"
"os"
"path/filepath"
"regexp"
"strings"
"github.com/wavetermdev/waveterm/pkg/wavebase"
@ -116,3 +117,52 @@ func DetectMimeType(path string, fileInfo fs.FileInfo, extended bool) string {
}
return rtn
}
var (
systemBinDirs = []string{
"/bin/",
"/usr/bin/",
"/usr/local/bin/",
"/opt/bin/",
"/sbin/",
"/usr/sbin/",
}
suspiciousPattern = regexp.MustCompile(`[:;#!&$\t%="|>{}]`)
flagPattern = regexp.MustCompile(` --?[a-zA-Z0-9]`)
)
// IsInitScriptPath tries to determine if the input string is a path to a script
// rather than an inline script content.
func IsInitScriptPath(input string) bool {
if len(input) == 0 || strings.Contains(input, "\n") {
return false
}
if suspiciousPattern.MatchString(input) {
return false
}
if flagPattern.MatchString(input) {
return false
}
// Check for home directory path
if strings.HasPrefix(input, "~/") {
return true
}
// Path must be absolute (if not home directory)
if !filepath.IsAbs(input) {
return false
}
// Check if path starts with system binary directories
normalizedPath := filepath.ToSlash(input)
for _, binDir := range systemBinDirs {
if strings.HasPrefix(normalizedPath, binDir) {
return false
}
}
return true
}

View File

@ -932,3 +932,12 @@ func TimeoutFromContext(ctx context.Context, defaultTimeout time.Duration) time.
}
return time.Until(deadline)
}
func HasBinaryData(data []byte) bool {
for _, b := range data {
if b < 32 && b != '\n' && b != '\r' && b != '\t' && b != '\f' && b != '\b' {
return true
}
}
return false
}

View File

@ -3,8 +3,12 @@
package waveobj
import "github.com/google/uuid"
type MetaMapType map[string]any
var MetaMap_DeleteSentinel = uuid.NewString()
func (m MetaMapType) GetString(key string, def string) string {
if v, ok := m[key]; ok {
if s, ok := v.(string); ok {
@ -48,6 +52,26 @@ func (m MetaMapType) GetStringList(key string) []string {
return rtn
}
func (m MetaMapType) GetStringMap(key string, useDeleteSentinel bool) map[string]string {
mval := m.GetMap(key)
if len(mval) == 0 {
return nil
}
rtn := make(map[string]string, len(mval))
for k, v := range mval {
if v == nil {
if useDeleteSentinel {
rtn[k] = MetaMap_DeleteSentinel
}
continue
}
if s, ok := v.(string); ok {
rtn[k] = s
}
}
return rtn
}
func (m MetaMapType) GetBool(key string, def bool) bool {
if v, ok := m[key]; ok {
if b, ok := v.(bool); ok {

View File

@ -151,6 +151,14 @@ type ConnKeywords struct {
TermFontFamily string `json:"term:fontfamily,omitempty"`
TermTheme string `json:"term:theme,omitempty"`
CmdEnv map[string]string `json:"cmd:env,omitempty"`
CmdInitScript string `json:"cmd:initscript,omitempty"`
CmdInitScriptSh string `json:"cmd:initscript.sh,omitempty"`
CmdInitScriptBash string `json:"cmd:initscript.bash,omitempty"`
CmdInitScriptZsh string `json:"cmd:initscript.zsh,omitempty"`
CmdInitScriptPwsh string `json:"cmd:initscript.pwsh,omitempty"`
CmdInitScriptFish string `json:"cmd:initscript.fish,omitempty"`
SshUser *string `json:"ssh:user,omitempty"`
SshHostName *string `json:"ssh:hostname,omitempty"`
SshPort *string `json:"ssh:port,omitempty"`