// Copyright 2024, Command Line Inc. // SPDX-License-Identifier: Apache-2.0 package shellexec import ( "bytes" "fmt" "io" "os" "os/exec" "reflect" "syscall" "github.com/creack/pty" "github.com/wavetermdev/thenextwave/pkg/util/shellutil" "github.com/wavetermdev/thenextwave/pkg/wavebase" ) type TermSize struct { Rows int `json:"rows"` Cols int `json:"cols"` } type ShellProc struct { Cmd *exec.Cmd Pty *os.File } func (sp *ShellProc) Close() { sp.Cmd.Process.Kill() go func() { sp.Cmd.Process.Wait() sp.Pty.Close() }() } func setBoolConditionally(rval reflect.Value, field string, value bool) { if rval.Elem().FieldByName(field).IsValid() { rval.Elem().FieldByName(field).SetBool(value) } } func setSysProcAttrs(cmd *exec.Cmd) { rval := reflect.ValueOf(cmd.SysProcAttr) setBoolConditionally(rval, "Setsid", true) setBoolConditionally(rval, "Setctty", true) } func StartShellProc(termSize TermSize) (*ShellProc, error) { shellPath := shellutil.DetectLocalShellPath() ecmd := exec.Command(shellPath, "-i", "-l") ecmd.Env = os.Environ() ecmd.Dir = wavebase.GetHomeDir() envToAdd := shellutil.WaveshellEnvVars(shellutil.DefaultTermType) if os.Getenv("LANG") == "" { envToAdd["LANG"] = wavebase.DetermineLang() } shellutil.UpdateCmdEnv(ecmd, envToAdd) cmdPty, cmdTty, err := pty.Open() if err != nil { return nil, fmt.Errorf("opening new pty: %w", err) } if termSize.Rows == 0 || termSize.Cols == 0 { termSize.Rows = shellutil.DefaultTermRows termSize.Cols = shellutil.DefaultTermCols } if termSize.Rows <= 0 || termSize.Cols <= 0 { return nil, fmt.Errorf("invalid term size: %v", termSize) } pty.Setsize(cmdPty, &pty.Winsize{Rows: uint16(termSize.Rows), Cols: uint16(termSize.Cols)}) ecmd.Stdin = cmdTty ecmd.Stdout = cmdTty ecmd.Stderr = cmdTty ecmd.SysProcAttr = &syscall.SysProcAttr{} setSysProcAttrs(ecmd) err = ecmd.Start() cmdTty.Close() if err != nil { cmdPty.Close() return nil, err } return &ShellProc{Cmd: ecmd, Pty: cmdPty}, nil } func RunSimpleCmdInPty(ecmd *exec.Cmd, termSize TermSize) ([]byte, error) { ecmd.Env = os.Environ() shellutil.UpdateCmdEnv(ecmd, shellutil.WaveshellEnvVars(shellutil.DefaultTermType)) cmdPty, cmdTty, err := pty.Open() if err != nil { return nil, fmt.Errorf("opening new pty: %w", err) } if termSize.Rows == 0 || termSize.Cols == 0 { termSize.Rows = shellutil.DefaultTermRows termSize.Cols = shellutil.DefaultTermCols } if termSize.Rows <= 0 || termSize.Cols <= 0 { return nil, fmt.Errorf("invalid term size: %v", termSize) } pty.Setsize(cmdPty, &pty.Winsize{Rows: uint16(termSize.Rows), Cols: uint16(termSize.Cols)}) ecmd.Stdin = cmdTty ecmd.Stdout = cmdTty ecmd.Stderr = cmdTty ecmd.SysProcAttr = &syscall.SysProcAttr{} setSysProcAttrs(ecmd) err = ecmd.Start() cmdTty.Close() if err != nil { cmdPty.Close() return nil, err } defer cmdPty.Close() ioDone := make(chan bool) var outputBuf bytes.Buffer go func() { // ignore error (/dev/ptmx has read error when process is done) defer close(ioDone) io.Copy(&outputBuf, cmdPty) }() exitErr := ecmd.Wait() if exitErr != nil { return nil, exitErr } <-ioDone return outputBuf.Bytes(), nil }