force our exit trap to always run (for rtnstate commands) (#556)

* add command validation to shellapi.  mock out bash/zsh versions

* implement validate command fn bash and zsh

* test validate command

* change rtnstate commands to always end with a builtin, so we always get our exit trap to run

* simplify the rtnstate modification, don't add the 'wait' (as this is a different problem/feature)

* update schema
This commit is contained in:
Mike Sawka 2024-04-09 11:33:23 -07:00 committed by GitHub
parent 1f5309e097
commit 6919dbfb5f
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
10 changed files with 145 additions and 49 deletions

View File

@ -6,6 +6,7 @@ package shellapi
import (
"bytes"
"context"
"errors"
"fmt"
"os/exec"
"runtime"
@ -266,6 +267,22 @@ func (bashShellApi) MakeShellStateDiff(oldState *packet.ShellState, oldStateHash
return rtn, nil
}
func (bashShellApi) ValidateCommandSyntax(cmdStr string) error {
ctx, cancelFn := context.WithTimeout(context.Background(), ValidateTimeout)
defer cancelFn()
cmd := exec.CommandContext(ctx, GetLocalBashPath(), "-n", "-c", cmdStr)
output, err := cmd.CombinedOutput()
if err == nil {
return nil
}
errStr := utilfn.GetFirstLine(string(output))
errStr = strings.TrimPrefix(errStr, "bash: -c: ")
if len(errStr) == 0 {
return errors.New("bash syntax error")
}
return errors.New(errStr)
}
func (bashShellApi) ApplyShellStateDiff(oldState *packet.ShellState, diff *packet.ShellStateDiff) (*packet.ShellState, error) {
if oldState == nil {
return nil, fmt.Errorf("cannot apply diff, oldState is nil")

View File

@ -214,9 +214,16 @@ func bashParseDeclareOutput(state *packet.ShellState, declareBytes []byte, pvarB
firstParseErr = err
}
}
if decl != nil && !BashNoStoreVarNames[decl.Name] {
declMap[decl.Name] = decl
if decl == nil {
continue
}
if BashNoStoreVarNames[decl.Name] {
continue
}
if strings.HasPrefix(decl.Name, "_wavetemp_") {
continue
}
declMap[decl.Name] = decl
}
pvarMap := parseExtVarOutput(pvarBytes, "", "")
utilfn.CombineMaps(declMap, pvarMap)

View File

@ -28,6 +28,7 @@ import (
)
const GetVersionTimeout = 5 * time.Second
const ValidateTimeout = 2 * time.Second
const GetGitBranchCmdStr = `printf "GITBRANCH %s\x00" "$(git rev-parse --abbrev-ref HEAD 2>/dev/null)"`
const GetK8sContextCmdStr = `printf "K8SCONTEXT %s\x00" "$(kubectl config current-context 2>/dev/null)"`
const GetK8sNamespaceCmdStr = `printf "K8SNAMESPACE %s\x00" "$(kubectl config view --minify --output 'jsonpath={..namespace}' 2>/dev/null)"`
@ -69,6 +70,7 @@ type ShellStateOutput struct {
type ShellApi interface {
GetShellType() string
MakeExitTrap(fdNum int) (string, []byte)
ValidateCommandSyntax(cmdStr string) error
GetLocalMajorVersion() string
GetLocalShellPath() string
GetRemoteShellPath() string

View File

@ -211,27 +211,41 @@ type ZshMap = map[ZshParamKey]string
type zshShellApi struct{}
func (z zshShellApi) GetShellType() string {
func (zshShellApi) GetShellType() string {
return packet.ShellType_zsh
}
func (z zshShellApi) MakeExitTrap(fdNum int) (string, []byte) {
func (zshShellApi) MakeExitTrap(fdNum int) (string, []byte) {
return MakeZshExitTrap(fdNum)
}
func (z zshShellApi) GetLocalMajorVersion() string {
func (zshShellApi) GetLocalMajorVersion() string {
return GetLocalZshMajorVersion()
}
func (z zshShellApi) GetLocalShellPath() string {
func (zshShellApi) GetLocalShellPath() string {
return "/bin/zsh"
}
func (z zshShellApi) GetRemoteShellPath() string {
func (zshShellApi) GetRemoteShellPath() string {
return "zsh"
}
func (z zshShellApi) MakeRunCommand(cmdStr string, opts RunCommandOpts) string {
func (zshShellApi) ValidateCommandSyntax(cmdStr string) error {
ctx, cancelFn := context.WithTimeout(context.Background(), ValidateTimeout)
defer cancelFn()
cmd := exec.CommandContext(ctx, GetLocalZshPath(), "-n", "-c", cmdStr)
output, err := cmd.CombinedOutput()
if err == nil {
return nil
}
if len(output) == 0 {
return errors.New("zsh syntax error")
}
return errors.New(utilfn.GetFirstLine(string(output)))
}
func (zshShellApi) MakeRunCommand(cmdStr string, opts RunCommandOpts) string {
if !opts.Sudo {
return cmdStr
}
@ -242,7 +256,7 @@ func (z zshShellApi) MakeRunCommand(cmdStr string, opts RunCommandOpts) string {
}
}
func (z zshShellApi) MakeShExecCommand(cmdStr string, rcFileName string, usePty bool) *exec.Cmd {
func (zshShellApi) MakeShExecCommand(cmdStr string, rcFileName string, usePty bool) *exec.Cmd {
return exec.Command(GetLocalZshPath(), "-l", "-i", "-c", cmdStr)
}
@ -274,7 +288,7 @@ func (z zshShellApi) GetShellState(ctx context.Context, outCh chan ShellStateOut
outCh <- ShellStateOutput{ShellState: rtn, Stats: stats}
}
func (z zshShellApi) GetBaseShellOpts() string {
func (zshShellApi) GetBaseShellOpts() string {
return BaseZshOpts
}
@ -343,6 +357,9 @@ func (z zshShellApi) MakeRcFileStr(pk *packet.RunPacketType) string {
if strings.HasPrefix(varDecl.Name, "ZFTP_") {
continue
}
if strings.HasPrefix(varDecl.Name, "_wavetemp_") {
continue
}
if varDecl.IsExtVar {
continue
}
@ -709,7 +726,7 @@ func makeZshFuncsStrForShellState(fnMap map[ZshParamKey]string) string {
return buf.String()
}
func (z zshShellApi) ParseShellStateOutput(outputBytes []byte) (*packet.ShellState, *packet.ShellStateStats, error) {
func (zshShellApi) ParseShellStateOutput(outputBytes []byte) (*packet.ShellState, *packet.ShellStateStats, error) {
if scbase.IsDevMode() && DebugState {
writeStateToFile(packet.ShellType_zsh, outputBytes)
}

View File

@ -2,7 +2,9 @@ package shellapi
import (
"fmt"
"log"
"testing"
"time"
)
func testSingleDecl(declStr string) {
@ -45,3 +47,35 @@ func TestZshSafeDeclName(t *testing.T) {
t.Errorf("should not be safe")
}
}
func testValidate(t *testing.T, shell string, cmd string, expectErr bool) {
var sapi ShellApi
if shell == "bash" {
sapi = bashShellApi{}
} else if shell == "zsh" {
sapi = zshShellApi{}
} else {
t.Errorf("unknown shell %q", shell)
return
}
tstart := time.Now()
err := sapi.ValidateCommandSyntax(cmd)
log.Printf("shell:%s dur:%v err: %v\n", shell, time.Since(tstart), err)
if expectErr && err == nil {
t.Errorf("cmd %q, expected error", cmd)
}
if !expectErr && err != nil {
t.Errorf("cmd %q, unexpected error: %v", cmd, err)
}
}
func TestValidate(t *testing.T) {
testValidate(t, "zsh", "echo foo", false)
testValidate(t, "zsh", "foo >& &", true)
testValidate(t, "zsh", "cd .", false)
testValidate(t, "zsh", "echo foo | grep foo", false)
testValidate(t, "zsh", "x; echo \"hello", true)
testValidate(t, "bash", "echo foo", false)
testValidate(t, "bash", "foo >& &", true)
testValidate(t, "bash", "cd .; echo \"", true)
}

View File

@ -12,6 +12,7 @@ import (
"github.com/wavetermdev/waveterm/waveshell/pkg/base"
"github.com/wavetermdev/waveterm/waveshell/pkg/packet"
"github.com/wavetermdev/waveterm/waveshell/pkg/utilfn"
"golang.org/x/crypto/ssh"
"golang.org/x/mod/semver"
)
@ -274,7 +275,7 @@ func (cproc *ClientProc) ProxySingleOutput(ck base.CommandKey, sender *packet.Pa
cmdDuration := endTs.Sub(cproc.StartTs)
donePacket := packet.MakeCmdDonePacket(ck)
donePacket.Ts = endTs.UnixMilli()
donePacket.ExitCode = GetExitCode(exitErr)
donePacket.ExitCode = utilfn.GetExitCode(exitErr)
donePacket.DurationMs = int64(cmdDuration / time.Millisecond)
sender.SendPacket(donePacket)
}

View File

@ -31,6 +31,7 @@ import (
"github.com/wavetermdev/waveterm/waveshell/pkg/shellapi"
"github.com/wavetermdev/waveterm/waveshell/pkg/shellenv"
"github.com/wavetermdev/waveterm/waveshell/pkg/shellutil"
"github.com/wavetermdev/waveterm/waveshell/pkg/utilfn"
"github.com/wavetermdev/waveterm/waveshell/pkg/wlog"
"golang.org/x/mod/semver"
"golang.org/x/sys/unix"
@ -826,6 +827,10 @@ func RunCommandSimple(pk *packet.RunPacketType, sender *packet.PacketSender, fro
var rtnStateWriter *os.File
rcFileStr := sapi.MakeRcFileStr(pk)
if pk.ReturnState {
err := sapi.ValidateCommandSyntax(pk.Command)
if err != nil {
return nil, err
}
pr, pw, err := os.Pipe()
if err != nil {
return nil, fmt.Errorf("cannot create returnstate pipe: %v", err)
@ -894,7 +899,12 @@ func RunCommandSimple(pk *packet.RunPacketType, sender *packet.PacketSender, fro
os.Remove(cmd.TmpRcFileName)
}()
}
cmd.Cmd = sapi.MakeShExecCommand(pk.Command, rcFileName, pk.UsePty)
fullCmdStr := pk.Command
if pk.ReturnState {
// this ensures that the last command is a shell buitin so we always get our exit trap to run
fullCmdStr = fullCmdStr + "\nexit $? 2> /dev/null"
}
cmd.Cmd = sapi.MakeShExecCommand(fullCmdStr, rcFileName, pk.UsePty)
if !pk.StateComplete {
cmd.Cmd.Env = os.Environ()
}
@ -1075,34 +1085,6 @@ func copyToCirFile(dest *cirfile.File, src io.Reader) error {
}
}
func GetCmdExitCode(cmd *exec.Cmd, err error) int {
if cmd == nil || cmd.ProcessState == nil {
return GetExitCode(err)
}
status, ok := cmd.ProcessState.Sys().(syscall.WaitStatus)
if !ok {
return cmd.ProcessState.ExitCode()
}
signaled := status.Signaled()
if signaled {
signal := status.Signal()
return 128 + int(signal)
}
exitStatus := status.ExitStatus()
return exitStatus
}
func GetExitCode(err error) int {
if err == nil {
return 0
}
if exitErr, ok := err.(*exec.ExitError); ok {
return exitErr.ExitCode()
} else {
return -1
}
}
func (c *ShExecType) ProcWait() error {
exitErr := c.Cmd.Wait()
c.Lock.Lock()
@ -1139,7 +1121,7 @@ func (c *ShExecType) WaitForCommand() *packet.CmdDonePacketType {
endTs := time.Now()
cmdDuration := endTs.Sub(c.StartTs)
donePacket.Ts = endTs.UnixMilli()
donePacket.ExitCode = GetCmdExitCode(c.Cmd, exitErr)
donePacket.ExitCode = utilfn.GetCmdExitCode(c.Cmd, exitErr)
donePacket.DurationMs = int64(cmdDuration / time.Millisecond)
if c.FileNames != nil {
os.Remove(c.FileNames.StdinFifo) // best effort (no need to check error)

View File

@ -15,9 +15,11 @@ import (
mathrand "math/rand"
"net/http"
"os"
"os/exec"
"regexp"
"sort"
"strings"
"syscall"
"unicode/utf8"
)
@ -635,3 +637,39 @@ func DetectMimeType(path string) string {
}
return rtn
}
func GetCmdExitCode(cmd *exec.Cmd, err error) int {
if cmd == nil || cmd.ProcessState == nil {
return GetExitCode(err)
}
status, ok := cmd.ProcessState.Sys().(syscall.WaitStatus)
if !ok {
return cmd.ProcessState.ExitCode()
}
signaled := status.Signaled()
if signaled {
signal := status.Signal()
return 128 + int(signal)
}
exitStatus := status.ExitStatus()
return exitStatus
}
func GetExitCode(err error) int {
if err == nil {
return 0
}
if exitErr, ok := err.(*exec.ExitError); ok {
return exitErr.ExitCode()
} else {
return -1
}
}
func GetFirstLine(s string) string {
idx := strings.Index(s, "\n")
if idx == -1 {
return s
}
return s[0:idx]
}

View File

@ -27,7 +27,7 @@ CREATE TABLE remote_instance (
festate json NOT NULL,
statebasehash varchar(36) NOT NULL,
statediffhasharr json NOT NULL
);
, shelltype varchar(20) NOT NULL DEFAULT 'bash');
CREATE TABLE state_base (
basehash varchar(36) PRIMARY KEY,
ts bigint NOT NULL,
@ -55,10 +55,8 @@ CREATE TABLE remote (
lastconnectts bigint NOT NULL,
local boolean NOT NULL,
archived boolean NOT NULL,
remoteidx int NOT NULL,
statevars json NOT NULL DEFAULT '{}',
sshconfigsrc varchar(36) NOT NULL DEFAULT 'waveterm-manual',
openaiopts json NOT NULL DEFAULT '{}');
remoteidx int NOT NULL
, statevars json NOT NULL DEFAULT '{}', openaiopts json NOT NULL DEFAULT '{}', sshconfigsrc varchar(36) NOT NULL DEFAULT 'waveterm-manual', shellpref varchar(20) NOT NULL DEFAULT 'detect');
CREATE TABLE history (
historyid varchar(36) PRIMARY KEY,
ts bigint NOT NULL,
@ -203,7 +201,7 @@ CREATE TABLE IF NOT EXISTS "cmd" (
rtnstate boolean NOT NULL,
rtnbasehash varchar(36) NOT NULL,
rtndiffhasharr json NOT NULL,
runout json NOT NULL,
runout json NOT NULL, restartts bigint NOT NULL DEFAULT 0,
PRIMARY KEY (screenid, lineid)
);
CREATE TABLE cmd_migrate20 (

View File

@ -1735,7 +1735,7 @@ func (msh *MShellProc) Launch(interactive bool) {
msh.WriteToPtyBuffer("connected to %s\n", remoteCopy.RemoteCanonicalName)
go func() {
exitErr := cproc.Cmd.Wait()
exitCode := shexec.GetExitCode(exitErr)
exitCode := utilfn.GetExitCode(exitErr)
msh.WithLock(func() {
if msh.Status == StatusConnected || msh.Status == StatusConnecting {
msh.Status = StatusDisconnected