From b3a7c466e54969b309175324f7e8b82c6464bd1b Mon Sep 17 00:00:00 2001 From: Sylvie Crowe <107814465+oneirocosm@users.noreply.github.com> Date: Wed, 4 Sep 2024 02:13:00 -0700 Subject: [PATCH] Powershell Wsh Integration (#320) Add wsh to the path in powershell. Should work locally and in remote connections. Should work on both windows and unix systems. --- pkg/remote/conncontroller/conncontroller.go | 11 +++++++- pkg/remote/connutil.go | 19 +++++++++++-- pkg/shellexec/shellexec.go | 31 +++++++++------------ pkg/util/shellutil/shellutil.go | 30 ++++++++++++++++++++ 4 files changed, 70 insertions(+), 21 deletions(-) diff --git a/pkg/remote/conncontroller/conncontroller.go b/pkg/remote/conncontroller/conncontroller.go index 818e619e5..93271cf53 100644 --- a/pkg/remote/conncontroller/conncontroller.go +++ b/pkg/remote/conncontroller/conncontroller.go @@ -212,7 +212,16 @@ func (conn *SSHConn) StartConnServer() error { pipeRead, pipeWrite := io.Pipe() sshSession.Stdout = pipeWrite sshSession.Stderr = pipeWrite - cmdStr := fmt.Sprintf("%s=\"%s\" %s connserver", wshutil.WaveJwtTokenVarName, jwtToken, wshPath) + shellPath, err := remote.DetectShell(client) + if err != nil { + return err + } + var cmdStr string + if remote.IsPowershell(shellPath) { + cmdStr = fmt.Sprintf("$env:%s=\"%s\"; %s connserver", wshutil.WaveJwtTokenVarName, jwtToken, wshPath) + } else { + cmdStr = fmt.Sprintf("%s=\"%s\" %s connserver", wshutil.WaveJwtTokenVarName, jwtToken, wshPath) + } log.Printf("starting conn controller: %s\n", cmdStr) err = sshSession.Start(cmdStr) if err != nil { diff --git a/pkg/remote/connutil.go b/pkg/remote/connutil.go index 28d9174df..de58ad9ff 100644 --- a/pkg/remote/connutil.go +++ b/pkg/remote/connutil.go @@ -311,10 +311,25 @@ func GetHomeDir(client *ssh.Client) string { return "~" } - out, err := session.Output("pwd") + out, err := session.Output(`echo "$HOME"`) + if err == nil { + return strings.TrimSpace(string(out)) + } + + session, err = client.NewSession() if err != nil { return "~" } - return strings.TrimSpace(string(out)) + out, err = session.Output(`echo %userprofile%`) + if err == nil { + return strings.TrimSpace(string(out)) + } + return "~" +} + +func IsPowershell(shellPath string) bool { + // get the base path, and then check contains + shellBase := filepath.Base(shellPath) + return strings.Contains(shellBase, "powershell") || strings.Contains(shellBase, "pwsh") } diff --git a/pkg/shellexec/shellexec.go b/pkg/shellexec/shellexec.go index 2936fd91e..056d98e66 100644 --- a/pkg/shellexec/shellexec.go +++ b/pkg/shellexec/shellexec.go @@ -11,8 +11,6 @@ import ( "os" "os/exec" "path/filepath" - "reflect" - "regexp" "runtime" "strings" "sync" @@ -91,18 +89,6 @@ func ExitCodeFromWaitErr(err error) int { } -func setBoolConditionally(rval reflect.Value, field string, value bool) { - if rval.Elem().FieldByName(field).IsValid() { - rval.Elem().FieldByName(field).SetBool(value) - } -} - -func setSysProcAttrs(cmd *exec.Cmd) { - rval := reflect.ValueOf(cmd.SysProcAttr) - setBoolConditionally(rval, "Setsid", true) - setBoolConditionally(rval, "Setctty", true) -} - func checkCwd(cwd string) error { if cwd == "" { return fmt.Errorf("cwd is empty") @@ -113,8 +99,6 @@ func checkCwd(cwd string) error { return nil } -var userHostRe = regexp.MustCompile(`^([a-zA-Z0-9][a-zA-Z0-9._@\\-]*@)?([a-z0-9][a-z0-9.-]*)(?::([0-9]+))?$`) - type PipePty struct { remoteStdinWrite *os.File remoteStdoutRead *os.File @@ -174,6 +158,10 @@ func StartRemoteShellProc(termSize waveobj.TermSize, cmdStr string, cmdOpts Comm // add --rcfile // cant set -l or -i with --rcfile shellOpts = append(shellOpts, "--rcfile", fmt.Sprintf(`"%s"/.waveterm/bash-integration/.bashrc`, homeDir)) + } else if remote.IsPowershell(shellPath) { + // powershell is weird about quoted path executables and requires an ampersand first + shellPath = "& " + shellPath + shellOpts = append(shellOpts, "-NoExit", "-File", homeDir+"/.waveterm/pwsh-integration/wavepwsh.ps1") } else { if cmdOpts.Login { shellOpts = append(shellOpts, "-l") @@ -241,7 +229,12 @@ func StartRemoteShellProc(termSize waveobj.TermSize, cmdStr string, cmdOpts Comm if !ok { return nil, fmt.Errorf("no jwt token provided to connection") } - cmdCombined = fmt.Sprintf(`%s=%s %s`, wshutil.WaveJwtTokenVarName, jwtToken, cmdCombined) + + if remote.IsPowershell(shellPath) { + cmdCombined = fmt.Sprintf(`$env:%s="%s"; %s`, wshutil.WaveJwtTokenVarName, jwtToken, cmdCombined) + } else { + cmdCombined = fmt.Sprintf(`%s=%s %s`, wshutil.WaveJwtTokenVarName, jwtToken, cmdCombined) + } session.RequestPty("xterm-256color", termSize.Rows, termSize.Cols, nil) @@ -277,7 +270,9 @@ func StartShellProc(termSize waveobj.TermSize, cmdStr string, cmdOpts CommandOpt // add --rcfile // cant set -l or -i with --rcfile shellOpts = append(shellOpts, "--rcfile", shellutil.GetBashRcFileOverride()) - } else if runtime.GOOS != "windows" { + } else if remote.IsPowershell(shellPath) { + shellOpts = append(shellOpts, "-NoExit", "-File", shellutil.GetWavePowershellEnv()) + } else { if cmdOpts.Login { shellOpts = append(shellOpts, "-l") } diff --git a/pkg/util/shellutil/shellutil.go b/pkg/util/shellutil/shellutil.go index fa832ea15..5c6e16f55 100644 --- a/pkg/util/shellutil/shellutil.go +++ b/pkg/util/shellutil/shellutil.go @@ -38,6 +38,7 @@ const AppPathBinDir = "bin" const ( ZshIntegrationDir = "zsh-integration" BashIntegrationDir = "bash-integration" + PwshIntegrationDir = "pwsh-integration" WaveHomeBinDir = "bin" ZshStartup_Zprofile = ` @@ -77,6 +78,12 @@ elif [ -f ~/.profile ]; then fi export PATH={{.WSHBINDIR}}:$PATH +` + PwshStartup_wavepwsh = ` +# no need to source regular profiles since we cannot +# overwrite those with powershell. Instead we will source +# this file with -NoExit +$env:PATH = "{{.WSHBINDIR}}" + "{{.PATHSEP}}" + $env:PATH ` ) @@ -194,6 +201,10 @@ func GetBashRcFileOverride() string { return filepath.Join(wavebase.GetWaveHomeDir(), BashIntegrationDir, ".bashrc") } +func GetWavePowershellEnv() string { + return filepath.Join(wavebase.GetWaveHomeDir(), PwshIntegrationDir, "wavepwsh.ps1") +} + func GetZshZDotDir() string { return filepath.Join(wavebase.GetWaveHomeDir(), ZshIntegrationDir) } @@ -218,6 +229,11 @@ func InitRcFiles(waveHome string, wshBinDir string) error { if err != nil { return err } + pwshDir := filepath.Join(waveHome, PwshIntegrationDir) + err = wavebase.CacheEnsureDir(pwshDir, PwshIntegrationDir, 0755, PwshIntegrationDir) + if err != nil { + return err + } // write files to directory zprofilePath := filepath.Join(zshDir, ".zprofile") @@ -243,6 +259,16 @@ func InitRcFiles(waveHome string, wshBinDir string) error { if err != nil { return fmt.Errorf("error writing bash-integration .bashrc: %v", err) } + var pathSep string + if runtime.GOOS == "windows" { + pathSep = ";" + } else { + pathSep = ":" + } + err = utilfn.WriteTemplateToFile(filepath.Join(pwshDir, "wavepwsh.ps1"), PwshStartup_wavepwsh, map[string]string{"WSHBINDIR": toPwshEnvVarRef(wshBinDir), "PATHSEP": pathSep}) + if err != nil { + return fmt.Errorf("error writing pwsh-integration wavepwsh.ps1: %v", err) + } return nil } @@ -282,3 +308,7 @@ func initCustomShellStartupFilesInternal() error { func computeWshBaseName() string { return fmt.Sprintf("wsh-%s-%s.%s", wavebase.WaveVersion, runtime.GOOS, runtime.GOARCH) } + +func toPwshEnvVarRef(input string) string { + return strings.Replace(input, "$", "$env:", -1) +}