waveterm/pkg/shellexec/shellexec.go
Sylvie Crowe 6c53eb884a
fix: remove context from wsl connection struct (#1663)
This fixes a WSL bug where disconnecting the controller does not allow
the controller to restart until the app reboots.
2025-01-02 13:05:28 -08:00

515 lines
15 KiB
Go

// Copyright 2024, Command Line Inc.
// SPDX-License-Identifier: Apache-2.0
package shellexec
import (
"bytes"
"context"
"fmt"
"io"
"log"
"os"
"os/exec"
"path/filepath"
"runtime"
"strings"
"sync"
"syscall"
"time"
"github.com/creack/pty"
"github.com/wavetermdev/waveterm/pkg/panichandler"
"github.com/wavetermdev/waveterm/pkg/remote"
"github.com/wavetermdev/waveterm/pkg/remote/conncontroller"
"github.com/wavetermdev/waveterm/pkg/util/shellutil"
"github.com/wavetermdev/waveterm/pkg/util/utilfn"
"github.com/wavetermdev/waveterm/pkg/wavebase"
"github.com/wavetermdev/waveterm/pkg/waveobj"
"github.com/wavetermdev/waveterm/pkg/wshutil"
"github.com/wavetermdev/waveterm/pkg/wsl"
)
const DefaultGracefulKillWait = 400 * time.Millisecond
type CommandOptsType struct {
Interactive bool `json:"interactive,omitempty"`
Login bool `json:"login,omitempty"`
Cwd string `json:"cwd,omitempty"`
Env map[string]string `json:"env,omitempty"`
ShellPath string `json:"shellPath,omitempty"`
ShellOpts []string `json:"shellOpts,omitempty"`
}
type ShellProc struct {
ConnName string
Cmd ConnInterface
CloseOnce *sync.Once
DoneCh chan any // closed after proc.Wait() returns
WaitErr error // WaitErr is synchronized by DoneCh (written before DoneCh is closed) and CloseOnce
}
func (sp *ShellProc) Close() {
sp.Cmd.KillGraceful(DefaultGracefulKillWait)
go func() {
defer func() {
panichandler.PanicHandler("ShellProc.Close", recover())
}()
waitErr := sp.Cmd.Wait()
sp.SetWaitErrorAndSignalDone(waitErr)
// windows cannot handle the pty being
// closed twice, so we let the pty
// close itself instead
if runtime.GOOS != "windows" {
sp.Cmd.Close()
}
}()
}
func (sp *ShellProc) SetWaitErrorAndSignalDone(waitErr error) {
sp.CloseOnce.Do(func() {
sp.WaitErr = waitErr
close(sp.DoneCh)
})
}
func (sp *ShellProc) Wait() error {
<-sp.DoneCh
return sp.WaitErr
}
// returns (done, waitError)
func (sp *ShellProc) WaitNB() (bool, error) {
select {
case <-sp.DoneCh:
return true, sp.WaitErr
default:
return false, nil
}
}
func ExitCodeFromWaitErr(err error) int {
if err == nil {
return 0
}
if exitErr, ok := err.(*exec.ExitError); ok {
if status, ok := exitErr.Sys().(syscall.WaitStatus); ok {
return status.ExitStatus()
}
}
return -1
}
func checkCwd(cwd string) error {
if cwd == "" {
return fmt.Errorf("cwd is empty")
}
if _, err := os.Stat(cwd); err != nil {
return fmt.Errorf("error statting cwd %q: %w", cwd, err)
}
return nil
}
type PipePty struct {
remoteStdinWrite *os.File
remoteStdoutRead *os.File
}
func (pp *PipePty) Fd() uintptr {
return pp.remoteStdinWrite.Fd()
}
func (pp *PipePty) Name() string {
return "pipe-pty"
}
func (pp *PipePty) Read(p []byte) (n int, err error) {
return pp.remoteStdoutRead.Read(p)
}
func (pp *PipePty) Write(p []byte) (n int, err error) {
return pp.remoteStdinWrite.Write(p)
}
func (pp *PipePty) Close() error {
err1 := pp.remoteStdinWrite.Close()
err2 := pp.remoteStdoutRead.Close()
if err1 != nil {
return err1
}
return err2
}
func (pp *PipePty) WriteString(s string) (n int, err error) {
return pp.Write([]byte(s))
}
func StartWslShellProc(ctx context.Context, termSize waveobj.TermSize, cmdStr string, cmdOpts CommandOptsType, conn *wsl.WslConn) (*ShellProc, error) {
utilCtx, cancelFn := context.WithTimeout(ctx, 2*time.Second)
defer cancelFn()
client := conn.GetClient()
shellPath := cmdOpts.ShellPath
if shellPath == "" {
remoteShellPath, err := wsl.DetectShell(utilCtx, client)
if err != nil {
return nil, err
}
shellPath = remoteShellPath
}
var shellOpts []string
log.Printf("detected shell: %s", shellPath)
err := wsl.InstallClientRcFiles(utilCtx, client)
if err != nil {
log.Printf("error installing rc files: %v", err)
return nil, err
}
homeDir := wsl.GetHomeDir(utilCtx, client)
shellOpts = append(shellOpts, "~", "-d", client.Name())
var subShellOpts []string
if cmdStr == "" {
/* transform command in order to inject environment vars */
if isBashShell(shellPath) {
log.Printf("recognized as bash shell")
// add --rcfile
// cant set -l or -i with --rcfile
subShellOpts = append(subShellOpts, "--rcfile", fmt.Sprintf(`%s/.waveterm/%s/.bashrc`, homeDir, shellutil.BashIntegrationDir))
} else if isFishShell(shellPath) {
carg := fmt.Sprintf(`"set -x PATH \"%s\"/.waveterm/%s $PATH"`, homeDir, shellutil.WaveHomeBinDir)
subShellOpts = append(subShellOpts, "-C", carg)
} else if wsl.IsPowershell(shellPath) {
// powershell is weird about quoted path executables and requires an ampersand first
shellPath = "& " + shellPath
subShellOpts = append(subShellOpts, "-ExecutionPolicy", "Bypass", "-NoExit", "-File", fmt.Sprintf("%s/.waveterm/%s/wavepwsh.ps1", homeDir, shellutil.PwshIntegrationDir))
} else {
if cmdOpts.Login {
subShellOpts = append(subShellOpts, "-l")
}
if cmdOpts.Interactive {
subShellOpts = append(subShellOpts, "-i")
}
// can't set environment vars this way
// will try to do later if possible
}
} else {
shellPath = cmdStr
if cmdOpts.Login {
subShellOpts = append(subShellOpts, "-l")
}
if cmdOpts.Interactive {
subShellOpts = append(subShellOpts, "-i")
}
subShellOpts = append(subShellOpts, "-c", cmdStr)
}
jwtToken, ok := cmdOpts.Env[wshutil.WaveJwtTokenVarName]
if !ok {
return nil, fmt.Errorf("no jwt token provided to connection")
}
if remote.IsPowershell(shellPath) {
shellOpts = append(shellOpts, "--", fmt.Sprintf(`$env:%s=%s;`, wshutil.WaveJwtTokenVarName, jwtToken))
} else {
shellOpts = append(shellOpts, "--", fmt.Sprintf(`%s=%s`, wshutil.WaveJwtTokenVarName, jwtToken))
}
if isZshShell(shellPath) {
shellOpts = append(shellOpts, fmt.Sprintf(`ZDOTDIR=%s/.waveterm/%s`, homeDir, shellutil.ZshIntegrationDir))
}
shellOpts = append(shellOpts, shellPath)
shellOpts = append(shellOpts, subShellOpts...)
log.Printf("full cmd is: %s %s", "wsl.exe", strings.Join(shellOpts, " "))
ecmd := exec.Command("wsl.exe", shellOpts...)
if termSize.Rows == 0 || termSize.Cols == 0 {
termSize.Rows = shellutil.DefaultTermRows
termSize.Cols = shellutil.DefaultTermCols
}
if termSize.Rows <= 0 || termSize.Cols <= 0 {
return nil, fmt.Errorf("invalid term size: %v", termSize)
}
cmdPty, err := pty.StartWithSize(ecmd, &pty.Winsize{Rows: uint16(termSize.Rows), Cols: uint16(termSize.Cols)})
if err != nil {
return nil, err
}
cmdWrap := MakeCmdWrap(ecmd, cmdPty)
return &ShellProc{Cmd: cmdWrap, ConnName: conn.GetName(), CloseOnce: &sync.Once{}, DoneCh: make(chan any)}, nil
}
func StartRemoteShellProcNoWsh(termSize waveobj.TermSize, cmdStr string, cmdOpts CommandOptsType, conn *conncontroller.SSHConn) (*ShellProc, error) {
client := conn.GetClient()
session, err := client.NewSession()
if err != nil {
return nil, err
}
remoteStdinRead, remoteStdinWriteOurs, err := os.Pipe()
if err != nil {
return nil, err
}
remoteStdoutReadOurs, remoteStdoutWrite, err := os.Pipe()
if err != nil {
return nil, err
}
pipePty := &PipePty{
remoteStdinWrite: remoteStdinWriteOurs,
remoteStdoutRead: remoteStdoutReadOurs,
}
if termSize.Rows == 0 || termSize.Cols == 0 {
termSize.Rows = shellutil.DefaultTermRows
termSize.Cols = shellutil.DefaultTermCols
}
if termSize.Rows <= 0 || termSize.Cols <= 0 {
return nil, fmt.Errorf("invalid term size: %v", termSize)
}
session.Stdin = remoteStdinRead
session.Stdout = remoteStdoutWrite
session.Stderr = remoteStdoutWrite
session.RequestPty("xterm-256color", termSize.Rows, termSize.Cols, nil)
sessionWrap := MakeSessionWrap(session, "", pipePty)
err = session.Shell()
if err != nil {
pipePty.Close()
return nil, err
}
return &ShellProc{Cmd: sessionWrap, ConnName: conn.GetName(), CloseOnce: &sync.Once{}, DoneCh: make(chan any)}, nil
}
func StartRemoteShellProc(termSize waveobj.TermSize, cmdStr string, cmdOpts CommandOptsType, conn *conncontroller.SSHConn) (*ShellProc, error) {
client := conn.GetClient()
shellPath := cmdOpts.ShellPath
if shellPath == "" {
remoteShellPath, err := remote.DetectShell(client)
if err != nil {
return nil, err
}
shellPath = remoteShellPath
}
var shellOpts []string
var cmdCombined string
log.Printf("detected shell: %s", shellPath)
err := remote.InstallClientRcFiles(client)
if err != nil {
log.Printf("error installing rc files: %v", err)
return nil, err
}
shellOpts = append(shellOpts, cmdOpts.ShellOpts...)
homeDir := remote.GetHomeDir(client)
if cmdStr == "" {
/* transform command in order to inject environment vars */
if isBashShell(shellPath) {
log.Printf("recognized as bash shell")
// add --rcfile
// cant set -l or -i with --rcfile
shellOpts = append(shellOpts, "--rcfile", fmt.Sprintf(`"%s"/.waveterm/%s/.bashrc`, homeDir, shellutil.BashIntegrationDir))
} else if isFishShell(shellPath) {
carg := fmt.Sprintf(`"set -x PATH \"%s\"/.waveterm/%s $PATH"`, homeDir, shellutil.WaveHomeBinDir)
shellOpts = append(shellOpts, "-C", carg)
} else if remote.IsPowershell(shellPath) {
// powershell is weird about quoted path executables and requires an ampersand first
shellPath = "& " + shellPath
shellOpts = append(shellOpts, "-ExecutionPolicy", "Bypass", "-NoExit", "-File", fmt.Sprintf("%s/.waveterm/%s/wavepwsh.ps1", homeDir, shellutil.PwshIntegrationDir))
} else {
if cmdOpts.Login {
shellOpts = append(shellOpts, "-l")
} else if cmdOpts.Interactive {
shellOpts = append(shellOpts, "-i")
}
// zdotdir setting moved to after session is created
}
cmdCombined = fmt.Sprintf("%s %s", shellPath, strings.Join(shellOpts, " "))
log.Printf("combined command is: %s", cmdCombined)
} else {
shellPath = cmdStr
shellOpts = append(shellOpts, "-c", cmdStr)
cmdCombined = fmt.Sprintf("%s %s", shellPath, strings.Join(shellOpts, " "))
log.Printf("combined command is: %s", cmdCombined)
}
session, err := client.NewSession()
if err != nil {
return nil, err
}
remoteStdinRead, remoteStdinWriteOurs, err := os.Pipe()
if err != nil {
return nil, err
}
remoteStdoutReadOurs, remoteStdoutWrite, err := os.Pipe()
if err != nil {
return nil, err
}
pipePty := &PipePty{
remoteStdinWrite: remoteStdinWriteOurs,
remoteStdoutRead: remoteStdoutReadOurs,
}
if termSize.Rows == 0 || termSize.Cols == 0 {
termSize.Rows = shellutil.DefaultTermRows
termSize.Cols = shellutil.DefaultTermCols
}
if termSize.Rows <= 0 || termSize.Cols <= 0 {
return nil, fmt.Errorf("invalid term size: %v", termSize)
}
session.Stdin = remoteStdinRead
session.Stdout = remoteStdoutWrite
session.Stderr = remoteStdoutWrite
for envKey, envVal := range cmdOpts.Env {
// note these might fail depending on server settings, but we still try
session.Setenv(envKey, envVal)
}
if isZshShell(shellPath) {
cmdCombined = fmt.Sprintf(`ZDOTDIR="%s/.waveterm/%s" %s`, homeDir, shellutil.ZshIntegrationDir, cmdCombined)
}
jwtToken, ok := cmdOpts.Env[wshutil.WaveJwtTokenVarName]
if !ok {
return nil, fmt.Errorf("no jwt token provided to connection")
}
if remote.IsPowershell(shellPath) {
cmdCombined = fmt.Sprintf(`$env:%s="%s"; %s`, wshutil.WaveJwtTokenVarName, jwtToken, cmdCombined)
} else {
cmdCombined = fmt.Sprintf(`%s=%s %s`, wshutil.WaveJwtTokenVarName, jwtToken, cmdCombined)
}
session.RequestPty("xterm-256color", termSize.Rows, termSize.Cols, nil)
sessionWrap := MakeSessionWrap(session, cmdCombined, pipePty)
err = sessionWrap.Start()
if err != nil {
pipePty.Close()
return nil, err
}
return &ShellProc{Cmd: sessionWrap, ConnName: conn.GetName(), CloseOnce: &sync.Once{}, DoneCh: make(chan any)}, nil
}
func isZshShell(shellPath string) bool {
// get the base path, and then check contains
shellBase := filepath.Base(shellPath)
return strings.Contains(shellBase, "zsh")
}
func isBashShell(shellPath string) bool {
// get the base path, and then check contains
shellBase := filepath.Base(shellPath)
return strings.Contains(shellBase, "bash")
}
func isFishShell(shellPath string) bool {
// get the base path, and then check contains
shellBase := filepath.Base(shellPath)
return strings.Contains(shellBase, "fish")
}
func StartShellProc(termSize waveobj.TermSize, cmdStr string, cmdOpts CommandOptsType) (*ShellProc, error) {
shellutil.InitCustomShellStartupFiles()
var ecmd *exec.Cmd
var shellOpts []string
shellPath := cmdOpts.ShellPath
if shellPath == "" {
shellPath = shellutil.DetectLocalShellPath()
}
shellOpts = append(shellOpts, cmdOpts.ShellOpts...)
if cmdStr == "" {
if isBashShell(shellPath) {
// add --rcfile
// cant set -l or -i with --rcfile
shellOpts = append(shellOpts, "--rcfile", shellutil.GetBashRcFileOverride())
} else if isFishShell(shellPath) {
wshBinDir := filepath.Join(wavebase.GetWaveDataDir(), shellutil.WaveHomeBinDir)
quotedWshBinDir := utilfn.ShellQuote(wshBinDir, false, 300)
shellOpts = append(shellOpts, "-C", fmt.Sprintf("set -x PATH %s $PATH", quotedWshBinDir))
} else if remote.IsPowershell(shellPath) {
shellOpts = append(shellOpts, "-ExecutionPolicy", "Bypass", "-NoExit", "-File", shellutil.GetWavePowershellEnv())
} else {
if cmdOpts.Login {
shellOpts = append(shellOpts, "-l")
} else if cmdOpts.Interactive {
shellOpts = append(shellOpts, "-i")
}
}
ecmd = exec.Command(shellPath, shellOpts...)
ecmd.Env = os.Environ()
if isZshShell(shellPath) {
shellutil.UpdateCmdEnv(ecmd, map[string]string{"ZDOTDIR": shellutil.GetZshZDotDir()})
}
} else {
shellOpts = append(shellOpts, "-c", cmdStr)
ecmd = exec.Command(shellPath, shellOpts...)
ecmd.Env = os.Environ()
}
if cmdOpts.Cwd != "" {
ecmd.Dir = cmdOpts.Cwd
}
if cwdErr := checkCwd(ecmd.Dir); cwdErr != nil {
ecmd.Dir = wavebase.GetHomeDir()
}
envToAdd := shellutil.WaveshellLocalEnvVars(shellutil.DefaultTermType)
if os.Getenv("LANG") == "" {
envToAdd["LANG"] = wavebase.DetermineLang()
}
shellutil.UpdateCmdEnv(ecmd, envToAdd)
shellutil.UpdateCmdEnv(ecmd, cmdOpts.Env)
if termSize.Rows == 0 || termSize.Cols == 0 {
termSize.Rows = shellutil.DefaultTermRows
termSize.Cols = shellutil.DefaultTermCols
}
if termSize.Rows <= 0 || termSize.Cols <= 0 {
return nil, fmt.Errorf("invalid term size: %v", termSize)
}
cmdPty, err := pty.StartWithSize(ecmd, &pty.Winsize{Rows: uint16(termSize.Rows), Cols: uint16(termSize.Cols)})
if err != nil {
return nil, err
}
cmdWrap := MakeCmdWrap(ecmd, cmdPty)
return &ShellProc{Cmd: cmdWrap, CloseOnce: &sync.Once{}, DoneCh: make(chan any)}, nil
}
func RunSimpleCmdInPty(ecmd *exec.Cmd, termSize waveobj.TermSize) ([]byte, error) {
ecmd.Env = os.Environ()
shellutil.UpdateCmdEnv(ecmd, shellutil.WaveshellLocalEnvVars(shellutil.DefaultTermType))
if termSize.Rows == 0 || termSize.Cols == 0 {
termSize.Rows = shellutil.DefaultTermRows
termSize.Cols = shellutil.DefaultTermCols
}
if termSize.Rows <= 0 || termSize.Cols <= 0 {
return nil, fmt.Errorf("invalid term size: %v", termSize)
}
cmdPty, err := pty.StartWithSize(ecmd, &pty.Winsize{Rows: uint16(termSize.Rows), Cols: uint16(termSize.Cols)})
if err != nil {
cmdPty.Close()
return nil, err
}
if runtime.GOOS != "windows" {
defer cmdPty.Close()
}
ioDone := make(chan bool)
var outputBuf bytes.Buffer
go func() {
panichandler.PanicHandler("RunSimpleCmdInPty:ioCopy", recover())
// ignore error (/dev/ptmx has read error when process is done)
defer close(ioDone)
io.Copy(&outputBuf, cmdPty)
}()
exitErr := ecmd.Wait()
if exitErr != nil {
return nil, exitErr
}
<-ioDone
return outputBuf.Bytes(), nil
}