From 07d07472db28fc6f4e3cd5cf71305cd9cffc0cef Mon Sep 17 00:00:00 2001 From: Mike Sawka Date: Tue, 14 Jan 2025 15:29:36 -0800 Subject: [PATCH] move genconn quote, and getshelltype to shellutil (#1731) --- pkg/genconn/genconn.go | 7 +- pkg/remote/conncontroller/conncontroller.go | 3 +- pkg/remote/conncontroller/tokenswap.go | 73 +++++++++++++++++++ pkg/shellexec/shellexec.go | 48 +++--------- .../quote.go => util/shellutil/shellquote.go} | 36 ++++++++- .../shellutil/shellquote_test.go} | 2 +- pkg/util/shellutil/shellutil.go | 30 +++++++- 7 files changed, 151 insertions(+), 48 deletions(-) create mode 100644 pkg/remote/conncontroller/tokenswap.go rename pkg/{genconn/quote.go => util/shellutil/shellquote.go} (74%) rename pkg/{genconn/quote_test.go => util/shellutil/shellquote_test.go} (99%) diff --git a/pkg/genconn/genconn.go b/pkg/genconn/genconn.go index d81ed616e..59bb439a1 100644 --- a/pkg/genconn/genconn.go +++ b/pkg/genconn/genconn.go @@ -12,6 +12,7 @@ import ( "strings" "sync" + "github.com/wavetermdev/waveterm/pkg/util/shellutil" "github.com/wavetermdev/waveterm/pkg/util/syncbuf" ) @@ -114,17 +115,17 @@ func BuildShellCommand(opts CommandSpec) (string, error) { if !isValidEnvVarName(key) { return "", fmt.Errorf("invalid environment variable name: %q", key) } - envVars.WriteString(fmt.Sprintf("%s=%s ", key, HardQuote(value))) + envVars.WriteString(fmt.Sprintf("%s=%s ", key, shellutil.HardQuote(value))) } // Build the command shellCmd := opts.Cmd if opts.Cwd != "" { - shellCmd = fmt.Sprintf("cd %s && %s", HardQuote(opts.Cwd), shellCmd) + shellCmd = fmt.Sprintf("cd %s && %s", shellutil.HardQuote(opts.Cwd), shellCmd) } // Quote the command for `sh -c` - return fmt.Sprintf("sh -c %s", HardQuote(envVars.String()+shellCmd)), nil + return fmt.Sprintf("sh -c %s", shellutil.HardQuote(envVars.String()+shellCmd)), nil } func isValidEnvVarName(name string) bool { diff --git a/pkg/remote/conncontroller/conncontroller.go b/pkg/remote/conncontroller/conncontroller.go index 27bd9b7d4..a7e9ac03a 100644 --- a/pkg/remote/conncontroller/conncontroller.go +++ b/pkg/remote/conncontroller/conncontroller.go @@ -26,6 +26,7 @@ import ( "github.com/wavetermdev/waveterm/pkg/remote" "github.com/wavetermdev/waveterm/pkg/telemetry" "github.com/wavetermdev/waveterm/pkg/userinput" + "github.com/wavetermdev/waveterm/pkg/util/shellutil" "github.com/wavetermdev/waveterm/pkg/util/utilfn" "github.com/wavetermdev/waveterm/pkg/wavebase" "github.com/wavetermdev/waveterm/pkg/waveobj" @@ -298,7 +299,7 @@ func (conn *SSHConn) StartConnServer(ctx context.Context) (bool, string, string, } cmdStr := fmt.Sprintf(ConnServerCmdTemplate, wshPath, wshPath) log.Printf("starting conn controller: %q\n", cmdStr) - shWrappedCmdStr := fmt.Sprintf("sh -c %s", genconn.HardQuote(cmdStr)) + shWrappedCmdStr := fmt.Sprintf("sh -c %s", shellutil.HardQuote(cmdStr)) blocklogger.Debugf(ctx, "[conndebug] wrapped command:\n%s\n", shWrappedCmdStr) err = sshSession.Start(shWrappedCmdStr) if err != nil { diff --git a/pkg/remote/conncontroller/tokenswap.go b/pkg/remote/conncontroller/tokenswap.go new file mode 100644 index 000000000..14498316a --- /dev/null +++ b/pkg/remote/conncontroller/tokenswap.go @@ -0,0 +1,73 @@ +// Copyright 2025, Command Line Inc. +// SPDX-License-Identifier: Apache-2.0 + +package conncontroller + +import ( + "fmt" + + "github.com/wavetermdev/waveterm/pkg/util/shellutil" +) + +type TokenSwapEntry struct { + Token string + Env map[string]string + ScriptText string +} + +func encodeEnvVarsForBash(env map[string]string) (string, error) { + var encoded string + for k, v := range env { + // validate key + if !shellutil.IsValidEnvVarName(k) { + return "", fmt.Errorf("invalid env var name: %q", k) + } + encoded += fmt.Sprintf("export %s=%s\n", k, shellutil.HardQuote(v)) + } + return encoded, nil +} + +func encodeEnvVarsForFish(env map[string]string) (string, error) { + var encoded string + for k, v := range env { + // validate key + if !shellutil.IsValidEnvVarName(k) { + return "", fmt.Errorf("invalid env var name: %q", k) + } + encoded += fmt.Sprintf("set -x %s %s\n", k, shellutil.HardQuoteFish(v)) + } + return encoded, nil +} + +func encodeEnvVarsForPowerShell(env map[string]string) (string, error) { + var encoded string + for k, v := range env { + // validate key + if !shellutil.IsValidEnvVarName(k) { + return "", fmt.Errorf("invalid env var name: %q", k) + } + encoded += fmt.Sprintf("$env:%s = %s\n", k, shellutil.HardQuotePowerShell(v)) + } + return encoded, nil +} + +func EncodeEnvVarsForShell(shellType string, env map[string]string) (string, error) { + switch shellType { + case shellutil.ShellType_bash, shellutil.ShellType_zsh: + return encodeEnvVarsForBash(env) + case shellutil.ShellType_fish: + return encodeEnvVarsForFish(env) + case shellutil.ShellType_pwsh: + return encodeEnvVarsForPowerShell(env) + default: + return "", fmt.Errorf("unknown or unsupported shell type for env var encoding: %s", shellType) + } +} + +func (t *TokenSwapEntry) EncodeForShell(shellType string) (string, error) { + encodedEnv, err := EncodeEnvVarsForShell(shellType, t.Env) + if err != nil { + return "", err + } + return encodedEnv + "\n" + t.ScriptText, nil +} diff --git a/pkg/shellexec/shellexec.go b/pkg/shellexec/shellexec.go index 545d5e627..66f23b065 100644 --- a/pkg/shellexec/shellexec.go +++ b/pkg/shellexec/shellexec.go @@ -19,7 +19,6 @@ import ( "time" "github.com/creack/pty" - "github.com/wavetermdev/waveterm/pkg/genconn" "github.com/wavetermdev/waveterm/pkg/panichandler" "github.com/wavetermdev/waveterm/pkg/remote" "github.com/wavetermdev/waveterm/pkg/remote/conncontroller" @@ -35,14 +34,6 @@ import ( const DefaultGracefulKillWait = 400 * time.Millisecond -const ( - ShellType_bash = "bash" - ShellType_zsh = "zsh" - ShellType_fish = "fish" - ShellType_pwsh = "pwsh" - ShellType_unknown = "unknown" -) - type CommandOptsType struct { Interactive bool `json:"interactive,omitempty"` Login bool `json:"login,omitempty"` @@ -158,23 +149,6 @@ func (pp *PipePty) WriteString(s string) (n int, err error) { return pp.Write([]byte(s)) } -func getShellTypeFromShellPath(shellPath string) string { - shellBase := filepath.Base(shellPath) - if strings.Contains(shellBase, "bash") { - return ShellType_bash - } - if strings.Contains(shellBase, "zsh") { - return ShellType_zsh - } - if strings.Contains(shellBase, "fish") { - return ShellType_fish - } - if strings.Contains(shellBase, "pwsh") || strings.Contains(shellBase, "powershell") { - return ShellType_pwsh - } - return ShellType_unknown -} - func StartWslShellProc(ctx context.Context, termSize waveobj.TermSize, cmdStr string, cmdOpts CommandOptsType, conn *wsl.WslConn) (*ShellProc, error) { utilCtx, cancelFn := context.WithTimeout(ctx, 2*time.Second) defer cancelFn() @@ -349,17 +323,17 @@ func StartRemoteShellProc(ctx context.Context, termSize waveobj.TermSize, cmdStr return nil, err } shellOpts = append(shellOpts, cmdOpts.ShellOpts...) - shellType := getShellTypeFromShellPath(shellPath) + shellType := shellutil.GetShellTypeFromShellPath(shellPath) conn.Infof(ctx, "detected shell type: %s\n", shellType) if cmdStr == "" { /* transform command in order to inject environment vars */ - if shellType == ShellType_bash { + if shellType == shellutil.ShellType_bash { // add --rcfile // cant set -l or -i with --rcfile bashPath := fmt.Sprintf("~/.waveterm/%s/.bashrc", shellutil.BashIntegrationDir) shellOpts = append(shellOpts, "--rcfile", bashPath) - } else if shellType == ShellType_fish { + } else if shellType == shellutil.ShellType_fish { if cmdOpts.Login { shellOpts = append(shellOpts, "-l") } @@ -367,7 +341,7 @@ func StartRemoteShellProc(ctx context.Context, termSize waveobj.TermSize, cmdStr waveFishPath := fmt.Sprintf("~/.waveterm/%s/wave.fish", shellutil.FishIntegrationDir) carg := fmt.Sprintf(`"source %s"`, waveFishPath) shellOpts = append(shellOpts, "-C", carg) - } else if shellType == ShellType_pwsh { + } else if shellType == shellutil.ShellType_pwsh { pwshPath := fmt.Sprintf("~/.waveterm/%s/wavepwsh.ps1", shellutil.PwshIntegrationDir) // powershell is weird about quoted path executables and requires an ampersand first shellPath = "& " + shellPath @@ -424,7 +398,7 @@ func StartRemoteShellProc(ctx context.Context, termSize waveobj.TermSize, cmdStr session.Setenv(envKey, envVal) } - if shellType == ShellType_zsh { + if shellType == shellutil.ShellType_zsh { zshDir := fmt.Sprintf("~/.waveterm/%s", shellutil.ZshIntegrationDir) conn.Infof(ctx, "setting ZDOTDIR to %s\n", zshDir) cmdCombined = fmt.Sprintf(`ZDOTDIR=%s %s`, zshDir, cmdCombined) @@ -471,21 +445,21 @@ func StartLocalShellProc(termSize waveobj.TermSize, cmdStr string, cmdOpts Comma if shellPath == "" { shellPath = shellutil.DetectLocalShellPath() } - shellType := getShellTypeFromShellPath(shellPath) + shellType := shellutil.GetShellTypeFromShellPath(shellPath) shellOpts = append(shellOpts, cmdOpts.ShellOpts...) if cmdStr == "" { - if shellType == ShellType_bash { + if shellType == shellutil.ShellType_bash { // add --rcfile // cant set -l or -i with --rcfile shellOpts = append(shellOpts, "--rcfile", shellutil.GetLocalBashRcFileOverride()) - } else if shellType == ShellType_fish { + } else if shellType == shellutil.ShellType_fish { if cmdOpts.Login { shellOpts = append(shellOpts, "-l") } waveFishPath := shellutil.GetLocalWaveFishFilePath() - carg := fmt.Sprintf("source %s", genconn.HardQuote(waveFishPath)) + carg := fmt.Sprintf("source %s", shellutil.HardQuoteFish(waveFishPath)) shellOpts = append(shellOpts, "-C", carg) - } else if shellType == ShellType_pwsh { + } else if shellType == shellutil.ShellType_pwsh { shellOpts = append(shellOpts, "-ExecutionPolicy", "Bypass", "-NoExit", "-File", shellutil.GetLocalWavePowershellEnv()) } else { if cmdOpts.Login { @@ -497,7 +471,7 @@ func StartLocalShellProc(termSize waveobj.TermSize, cmdStr string, cmdOpts Comma } ecmd = exec.Command(shellPath, shellOpts...) ecmd.Env = os.Environ() - if shellType == ShellType_zsh { + if shellType == shellutil.ShellType_zsh { shellutil.UpdateCmdEnv(ecmd, map[string]string{"ZDOTDIR": shellutil.GetLocalZshZDotDir()}) } } else { diff --git a/pkg/genconn/quote.go b/pkg/util/shellutil/shellquote.go similarity index 74% rename from pkg/genconn/quote.go rename to pkg/util/shellutil/shellquote.go index 469359cbf..fa7b66d2c 100644 --- a/pkg/genconn/quote.go +++ b/pkg/util/shellutil/shellquote.go @@ -1,15 +1,20 @@ // Copyright 2025, Command Line Inc. // SPDX-License-Identifier: Apache-2.0 -package genconn +package shellutil import "regexp" var ( - safePattern = regexp.MustCompile(`^[a-zA-Z0-9_/.-]+$`) - psSafePattern = regexp.MustCompile(`^[a-zA-Z0-9_.-]+$`) + safePattern = regexp.MustCompile(`^[a-zA-Z0-9_/.-]+$`) + psSafePattern = regexp.MustCompile(`^[a-zA-Z0-9_.-]+$`) + envVarNamePattern = regexp.MustCompile(`^[A-Za-z_][A-Za-z0-9_]*$`) ) +func IsValidEnvVarName(name string) bool { + return envVarNamePattern.MatchString(name) +} + // TODO: fish quoting is slightly different // specifically \` will cause an inconsistency between fish and bash/zsh :/ // might need a specific fish quoting function, and an explicit fish shell detection @@ -40,6 +45,31 @@ func HardQuote(s string) string { return string(buf) } +func HardQuoteFish(s string) string { + if s == "" { + return "\"\"" + } + + if safePattern.MatchString(s) { + return s + } + + buf := make([]byte, 0, len(s)+5) + buf = append(buf, '"') + + for i := 0; i < len(s); i++ { + switch s[i] { + case '"', '\\', '$': // Escape only these characters + buf = append(buf, '\\', s[i]) + default: + buf = append(buf, s[i]) + } + } + + buf = append(buf, '"') + return string(buf) +} + func HardQuotePowerShell(s string) string { if s == "" { return "\"\"" diff --git a/pkg/genconn/quote_test.go b/pkg/util/shellutil/shellquote_test.go similarity index 99% rename from pkg/genconn/quote_test.go rename to pkg/util/shellutil/shellquote_test.go index 0cc40f158..334840ef4 100644 --- a/pkg/genconn/quote_test.go +++ b/pkg/util/shellutil/shellquote_test.go @@ -1,6 +1,6 @@ // Copyright 2025, Command Line Inc. // SPDX-License-Identifier: Apache-2.0 -package genconn +package shellutil import "testing" diff --git a/pkg/util/shellutil/shellutil.go b/pkg/util/shellutil/shellutil.go index f89568bcf..b9c6e2e4e 100644 --- a/pkg/util/shellutil/shellutil.go +++ b/pkg/util/shellutil/shellutil.go @@ -17,7 +17,6 @@ import ( "sync" "time" - "github.com/wavetermdev/waveterm/pkg/genconn" "github.com/wavetermdev/waveterm/pkg/util/utilfn" "github.com/wavetermdev/waveterm/pkg/wavebase" "github.com/wavetermdev/waveterm/pkg/waveobj" @@ -33,6 +32,14 @@ var userShellRegexp = regexp.MustCompile(`^UserShell: (.*)$`) const DefaultShellPath = "/bin/bash" +const ( + ShellType_bash = "bash" + ShellType_zsh = "zsh" + ShellType_fish = "fish" + ShellType_pwsh = "pwsh" + ShellType_unknown = "unknown" +) + const ( // there must be no spaces in these integration dir paths ZshIntegrationDir = "shell/zsh" @@ -306,8 +313,8 @@ func InitRcFiles(waveHome string, absWshBinDir string) error { pathSep = ":" } params := map[string]string{ - "WSHBINDIR": genconn.HardQuote(absWshBinDir), - "WSHBINDIR_PWSH": genconn.HardQuotePowerShell(absWshBinDir), + "WSHBINDIR": HardQuote(absWshBinDir), + "WSHBINDIR_PWSH": HardQuotePowerShell(absWshBinDir), "PATHSEP": pathSep, } @@ -379,3 +386,20 @@ func initCustomShellStartupFilesInternal() error { log.Printf("wsh binary successfully copied from %q to %q\n", wshBaseName, wshDstPath) return nil } + +func GetShellTypeFromShellPath(shellPath string) string { + shellBase := filepath.Base(shellPath) + if strings.Contains(shellBase, "bash") { + return ShellType_bash + } + if strings.Contains(shellBase, "zsh") { + return ShellType_zsh + } + if strings.Contains(shellBase, "fish") { + return ShellType_fish + } + if strings.Contains(shellBase, "pwsh") || strings.Contains(shellBase, "powershell") { + return ShellType_pwsh + } + return ShellType_unknown +}