Powershell Wsh Integration (#320)

Add wsh to the path in powershell. Should work locally and in remote
connections. Should work on both windows and unix systems.
This commit is contained in:
Sylvie Crowe 2024-09-04 02:13:00 -07:00 committed by GitHub
parent 4e60880b8f
commit b3a7c466e5
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 70 additions and 21 deletions

View File

@ -212,7 +212,16 @@ func (conn *SSHConn) StartConnServer() error {
pipeRead, pipeWrite := io.Pipe()
sshSession.Stdout = pipeWrite
sshSession.Stderr = pipeWrite
cmdStr := fmt.Sprintf("%s=\"%s\" %s connserver", wshutil.WaveJwtTokenVarName, jwtToken, wshPath)
shellPath, err := remote.DetectShell(client)
if err != nil {
return err
}
var cmdStr string
if remote.IsPowershell(shellPath) {
cmdStr = fmt.Sprintf("$env:%s=\"%s\"; %s connserver", wshutil.WaveJwtTokenVarName, jwtToken, wshPath)
} else {
cmdStr = fmt.Sprintf("%s=\"%s\" %s connserver", wshutil.WaveJwtTokenVarName, jwtToken, wshPath)
}
log.Printf("starting conn controller: %s\n", cmdStr)
err = sshSession.Start(cmdStr)
if err != nil {

View File

@ -311,10 +311,25 @@ func GetHomeDir(client *ssh.Client) string {
return "~"
}
out, err := session.Output("pwd")
out, err := session.Output(`echo "$HOME"`)
if err == nil {
return strings.TrimSpace(string(out))
}
session, err = client.NewSession()
if err != nil {
return "~"
}
return strings.TrimSpace(string(out))
out, err = session.Output(`echo %userprofile%`)
if err == nil {
return strings.TrimSpace(string(out))
}
return "~"
}
func IsPowershell(shellPath string) bool {
// get the base path, and then check contains
shellBase := filepath.Base(shellPath)
return strings.Contains(shellBase, "powershell") || strings.Contains(shellBase, "pwsh")
}

View File

@ -11,8 +11,6 @@ import (
"os"
"os/exec"
"path/filepath"
"reflect"
"regexp"
"runtime"
"strings"
"sync"
@ -91,18 +89,6 @@ func ExitCodeFromWaitErr(err error) int {
}
func setBoolConditionally(rval reflect.Value, field string, value bool) {
if rval.Elem().FieldByName(field).IsValid() {
rval.Elem().FieldByName(field).SetBool(value)
}
}
func setSysProcAttrs(cmd *exec.Cmd) {
rval := reflect.ValueOf(cmd.SysProcAttr)
setBoolConditionally(rval, "Setsid", true)
setBoolConditionally(rval, "Setctty", true)
}
func checkCwd(cwd string) error {
if cwd == "" {
return fmt.Errorf("cwd is empty")
@ -113,8 +99,6 @@ func checkCwd(cwd string) error {
return nil
}
var userHostRe = regexp.MustCompile(`^([a-zA-Z0-9][a-zA-Z0-9._@\\-]*@)?([a-z0-9][a-z0-9.-]*)(?::([0-9]+))?$`)
type PipePty struct {
remoteStdinWrite *os.File
remoteStdoutRead *os.File
@ -174,6 +158,10 @@ func StartRemoteShellProc(termSize waveobj.TermSize, cmdStr string, cmdOpts Comm
// add --rcfile
// cant set -l or -i with --rcfile
shellOpts = append(shellOpts, "--rcfile", fmt.Sprintf(`"%s"/.waveterm/bash-integration/.bashrc`, homeDir))
} else if remote.IsPowershell(shellPath) {
// powershell is weird about quoted path executables and requires an ampersand first
shellPath = "& " + shellPath
shellOpts = append(shellOpts, "-NoExit", "-File", homeDir+"/.waveterm/pwsh-integration/wavepwsh.ps1")
} else {
if cmdOpts.Login {
shellOpts = append(shellOpts, "-l")
@ -241,7 +229,12 @@ func StartRemoteShellProc(termSize waveobj.TermSize, cmdStr string, cmdOpts Comm
if !ok {
return nil, fmt.Errorf("no jwt token provided to connection")
}
cmdCombined = fmt.Sprintf(`%s=%s %s`, wshutil.WaveJwtTokenVarName, jwtToken, cmdCombined)
if remote.IsPowershell(shellPath) {
cmdCombined = fmt.Sprintf(`$env:%s="%s"; %s`, wshutil.WaveJwtTokenVarName, jwtToken, cmdCombined)
} else {
cmdCombined = fmt.Sprintf(`%s=%s %s`, wshutil.WaveJwtTokenVarName, jwtToken, cmdCombined)
}
session.RequestPty("xterm-256color", termSize.Rows, termSize.Cols, nil)
@ -277,7 +270,9 @@ func StartShellProc(termSize waveobj.TermSize, cmdStr string, cmdOpts CommandOpt
// add --rcfile
// cant set -l or -i with --rcfile
shellOpts = append(shellOpts, "--rcfile", shellutil.GetBashRcFileOverride())
} else if runtime.GOOS != "windows" {
} else if remote.IsPowershell(shellPath) {
shellOpts = append(shellOpts, "-NoExit", "-File", shellutil.GetWavePowershellEnv())
} else {
if cmdOpts.Login {
shellOpts = append(shellOpts, "-l")
}

View File

@ -38,6 +38,7 @@ const AppPathBinDir = "bin"
const (
ZshIntegrationDir = "zsh-integration"
BashIntegrationDir = "bash-integration"
PwshIntegrationDir = "pwsh-integration"
WaveHomeBinDir = "bin"
ZshStartup_Zprofile = `
@ -77,6 +78,12 @@ elif [ -f ~/.profile ]; then
fi
export PATH={{.WSHBINDIR}}:$PATH
`
PwshStartup_wavepwsh = `
# no need to source regular profiles since we cannot
# overwrite those with powershell. Instead we will source
# this file with -NoExit
$env:PATH = "{{.WSHBINDIR}}" + "{{.PATHSEP}}" + $env:PATH
`
)
@ -194,6 +201,10 @@ func GetBashRcFileOverride() string {
return filepath.Join(wavebase.GetWaveHomeDir(), BashIntegrationDir, ".bashrc")
}
func GetWavePowershellEnv() string {
return filepath.Join(wavebase.GetWaveHomeDir(), PwshIntegrationDir, "wavepwsh.ps1")
}
func GetZshZDotDir() string {
return filepath.Join(wavebase.GetWaveHomeDir(), ZshIntegrationDir)
}
@ -218,6 +229,11 @@ func InitRcFiles(waveHome string, wshBinDir string) error {
if err != nil {
return err
}
pwshDir := filepath.Join(waveHome, PwshIntegrationDir)
err = wavebase.CacheEnsureDir(pwshDir, PwshIntegrationDir, 0755, PwshIntegrationDir)
if err != nil {
return err
}
// write files to directory
zprofilePath := filepath.Join(zshDir, ".zprofile")
@ -243,6 +259,16 @@ func InitRcFiles(waveHome string, wshBinDir string) error {
if err != nil {
return fmt.Errorf("error writing bash-integration .bashrc: %v", err)
}
var pathSep string
if runtime.GOOS == "windows" {
pathSep = ";"
} else {
pathSep = ":"
}
err = utilfn.WriteTemplateToFile(filepath.Join(pwshDir, "wavepwsh.ps1"), PwshStartup_wavepwsh, map[string]string{"WSHBINDIR": toPwshEnvVarRef(wshBinDir), "PATHSEP": pathSep})
if err != nil {
return fmt.Errorf("error writing pwsh-integration wavepwsh.ps1: %v", err)
}
return nil
}
@ -282,3 +308,7 @@ func initCustomShellStartupFilesInternal() error {
func computeWshBaseName() string {
return fmt.Sprintf("wsh-%s-%s.%s", wavebase.WaveVersion, runtime.GOOS, runtime.GOARCH)
}
func toPwshEnvVarRef(input string) string {
return strings.Replace(input, "$", "$env:", -1)
}