diff --git a/waveshell/pkg/shellapi/bashapi.go b/waveshell/pkg/shellapi/bashapi.go index d2cbedfa0..8f1d18126 100644 --- a/waveshell/pkg/shellapi/bashapi.go +++ b/waveshell/pkg/shellapi/bashapi.go @@ -6,6 +6,7 @@ package shellapi import ( "bytes" "context" + "errors" "fmt" "os/exec" "runtime" @@ -266,6 +267,22 @@ func (bashShellApi) MakeShellStateDiff(oldState *packet.ShellState, oldStateHash return rtn, nil } +func (bashShellApi) ValidateCommandSyntax(cmdStr string) error { + ctx, cancelFn := context.WithTimeout(context.Background(), ValidateTimeout) + defer cancelFn() + cmd := exec.CommandContext(ctx, GetLocalBashPath(), "-n", "-c", cmdStr) + output, err := cmd.CombinedOutput() + if err == nil { + return nil + } + errStr := utilfn.GetFirstLine(string(output)) + errStr = strings.TrimPrefix(errStr, "bash: -c: ") + if len(errStr) == 0 { + return errors.New("bash syntax error") + } + return errors.New(errStr) +} + func (bashShellApi) ApplyShellStateDiff(oldState *packet.ShellState, diff *packet.ShellStateDiff) (*packet.ShellState, error) { if oldState == nil { return nil, fmt.Errorf("cannot apply diff, oldState is nil") diff --git a/waveshell/pkg/shellapi/bashparser.go b/waveshell/pkg/shellapi/bashparser.go index f56b33b37..889171ae4 100644 --- a/waveshell/pkg/shellapi/bashparser.go +++ b/waveshell/pkg/shellapi/bashparser.go @@ -214,9 +214,16 @@ func bashParseDeclareOutput(state *packet.ShellState, declareBytes []byte, pvarB firstParseErr = err } } - if decl != nil && !BashNoStoreVarNames[decl.Name] { - declMap[decl.Name] = decl + if decl == nil { + continue } + if BashNoStoreVarNames[decl.Name] { + continue + } + if strings.HasPrefix(decl.Name, "_wavetemp_") { + continue + } + declMap[decl.Name] = decl } pvarMap := parseExtVarOutput(pvarBytes, "", "") utilfn.CombineMaps(declMap, pvarMap) diff --git a/waveshell/pkg/shellapi/shellapi.go b/waveshell/pkg/shellapi/shellapi.go index 81c481269..c948a56ee 100644 --- a/waveshell/pkg/shellapi/shellapi.go +++ b/waveshell/pkg/shellapi/shellapi.go @@ -28,6 +28,7 @@ import ( ) const GetVersionTimeout = 5 * time.Second +const ValidateTimeout = 2 * time.Second const GetGitBranchCmdStr = `printf "GITBRANCH %s\x00" "$(git rev-parse --abbrev-ref HEAD 2>/dev/null)"` const GetK8sContextCmdStr = `printf "K8SCONTEXT %s\x00" "$(kubectl config current-context 2>/dev/null)"` const GetK8sNamespaceCmdStr = `printf "K8SNAMESPACE %s\x00" "$(kubectl config view --minify --output 'jsonpath={..namespace}' 2>/dev/null)"` @@ -69,6 +70,7 @@ type ShellStateOutput struct { type ShellApi interface { GetShellType() string MakeExitTrap(fdNum int) (string, []byte) + ValidateCommandSyntax(cmdStr string) error GetLocalMajorVersion() string GetLocalShellPath() string GetRemoteShellPath() string diff --git a/waveshell/pkg/shellapi/zshapi.go b/waveshell/pkg/shellapi/zshapi.go index 71cbdf4b4..461543c70 100644 --- a/waveshell/pkg/shellapi/zshapi.go +++ b/waveshell/pkg/shellapi/zshapi.go @@ -211,27 +211,41 @@ type ZshMap = map[ZshParamKey]string type zshShellApi struct{} -func (z zshShellApi) GetShellType() string { +func (zshShellApi) GetShellType() string { return packet.ShellType_zsh } -func (z zshShellApi) MakeExitTrap(fdNum int) (string, []byte) { +func (zshShellApi) MakeExitTrap(fdNum int) (string, []byte) { return MakeZshExitTrap(fdNum) } -func (z zshShellApi) GetLocalMajorVersion() string { +func (zshShellApi) GetLocalMajorVersion() string { return GetLocalZshMajorVersion() } -func (z zshShellApi) GetLocalShellPath() string { +func (zshShellApi) GetLocalShellPath() string { return "/bin/zsh" } -func (z zshShellApi) GetRemoteShellPath() string { +func (zshShellApi) GetRemoteShellPath() string { return "zsh" } -func (z zshShellApi) MakeRunCommand(cmdStr string, opts RunCommandOpts) string { +func (zshShellApi) ValidateCommandSyntax(cmdStr string) error { + ctx, cancelFn := context.WithTimeout(context.Background(), ValidateTimeout) + defer cancelFn() + cmd := exec.CommandContext(ctx, GetLocalZshPath(), "-n", "-c", cmdStr) + output, err := cmd.CombinedOutput() + if err == nil { + return nil + } + if len(output) == 0 { + return errors.New("zsh syntax error") + } + return errors.New(utilfn.GetFirstLine(string(output))) +} + +func (zshShellApi) MakeRunCommand(cmdStr string, opts RunCommandOpts) string { if !opts.Sudo { return cmdStr } @@ -242,7 +256,7 @@ func (z zshShellApi) MakeRunCommand(cmdStr string, opts RunCommandOpts) string { } } -func (z zshShellApi) MakeShExecCommand(cmdStr string, rcFileName string, usePty bool) *exec.Cmd { +func (zshShellApi) MakeShExecCommand(cmdStr string, rcFileName string, usePty bool) *exec.Cmd { return exec.Command(GetLocalZshPath(), "-l", "-i", "-c", cmdStr) } @@ -274,7 +288,7 @@ func (z zshShellApi) GetShellState(ctx context.Context, outCh chan ShellStateOut outCh <- ShellStateOutput{ShellState: rtn, Stats: stats} } -func (z zshShellApi) GetBaseShellOpts() string { +func (zshShellApi) GetBaseShellOpts() string { return BaseZshOpts } @@ -343,6 +357,9 @@ func (z zshShellApi) MakeRcFileStr(pk *packet.RunPacketType) string { if strings.HasPrefix(varDecl.Name, "ZFTP_") { continue } + if strings.HasPrefix(varDecl.Name, "_wavetemp_") { + continue + } if varDecl.IsExtVar { continue } @@ -709,7 +726,7 @@ func makeZshFuncsStrForShellState(fnMap map[ZshParamKey]string) string { return buf.String() } -func (z zshShellApi) ParseShellStateOutput(outputBytes []byte) (*packet.ShellState, *packet.ShellStateStats, error) { +func (zshShellApi) ParseShellStateOutput(outputBytes []byte) (*packet.ShellState, *packet.ShellStateStats, error) { if scbase.IsDevMode() && DebugState { writeStateToFile(packet.ShellType_zsh, outputBytes) } diff --git a/waveshell/pkg/shellapi/zshapi_test.go b/waveshell/pkg/shellapi/zshapi_test.go index 557ac99ec..b5d1762e4 100644 --- a/waveshell/pkg/shellapi/zshapi_test.go +++ b/waveshell/pkg/shellapi/zshapi_test.go @@ -2,7 +2,9 @@ package shellapi import ( "fmt" + "log" "testing" + "time" ) func testSingleDecl(declStr string) { @@ -45,3 +47,35 @@ func TestZshSafeDeclName(t *testing.T) { t.Errorf("should not be safe") } } + +func testValidate(t *testing.T, shell string, cmd string, expectErr bool) { + var sapi ShellApi + if shell == "bash" { + sapi = bashShellApi{} + } else if shell == "zsh" { + sapi = zshShellApi{} + } else { + t.Errorf("unknown shell %q", shell) + return + } + tstart := time.Now() + err := sapi.ValidateCommandSyntax(cmd) + log.Printf("shell:%s dur:%v err: %v\n", shell, time.Since(tstart), err) + if expectErr && err == nil { + t.Errorf("cmd %q, expected error", cmd) + } + if !expectErr && err != nil { + t.Errorf("cmd %q, unexpected error: %v", cmd, err) + } +} + +func TestValidate(t *testing.T) { + testValidate(t, "zsh", "echo foo", false) + testValidate(t, "zsh", "foo >& &", true) + testValidate(t, "zsh", "cd .", false) + testValidate(t, "zsh", "echo foo | grep foo", false) + testValidate(t, "zsh", "x; echo \"hello", true) + testValidate(t, "bash", "echo foo", false) + testValidate(t, "bash", "foo >& &", true) + testValidate(t, "bash", "cd .; echo \"", true) +} diff --git a/waveshell/pkg/shexec/client.go b/waveshell/pkg/shexec/client.go index 9efcf2250..04aea9b68 100644 --- a/waveshell/pkg/shexec/client.go +++ b/waveshell/pkg/shexec/client.go @@ -12,6 +12,7 @@ import ( "github.com/wavetermdev/waveterm/waveshell/pkg/base" "github.com/wavetermdev/waveterm/waveshell/pkg/packet" + "github.com/wavetermdev/waveterm/waveshell/pkg/utilfn" "golang.org/x/crypto/ssh" "golang.org/x/mod/semver" ) @@ -274,7 +275,7 @@ func (cproc *ClientProc) ProxySingleOutput(ck base.CommandKey, sender *packet.Pa cmdDuration := endTs.Sub(cproc.StartTs) donePacket := packet.MakeCmdDonePacket(ck) donePacket.Ts = endTs.UnixMilli() - donePacket.ExitCode = GetExitCode(exitErr) + donePacket.ExitCode = utilfn.GetExitCode(exitErr) donePacket.DurationMs = int64(cmdDuration / time.Millisecond) sender.SendPacket(donePacket) } diff --git a/waveshell/pkg/shexec/shexec.go b/waveshell/pkg/shexec/shexec.go index 9ff6211da..7a5bdef51 100644 --- a/waveshell/pkg/shexec/shexec.go +++ b/waveshell/pkg/shexec/shexec.go @@ -31,6 +31,7 @@ import ( "github.com/wavetermdev/waveterm/waveshell/pkg/shellapi" "github.com/wavetermdev/waveterm/waveshell/pkg/shellenv" "github.com/wavetermdev/waveterm/waveshell/pkg/shellutil" + "github.com/wavetermdev/waveterm/waveshell/pkg/utilfn" "github.com/wavetermdev/waveterm/waveshell/pkg/wlog" "golang.org/x/mod/semver" "golang.org/x/sys/unix" @@ -826,6 +827,10 @@ func RunCommandSimple(pk *packet.RunPacketType, sender *packet.PacketSender, fro var rtnStateWriter *os.File rcFileStr := sapi.MakeRcFileStr(pk) if pk.ReturnState { + err := sapi.ValidateCommandSyntax(pk.Command) + if err != nil { + return nil, err + } pr, pw, err := os.Pipe() if err != nil { return nil, fmt.Errorf("cannot create returnstate pipe: %v", err) @@ -894,7 +899,12 @@ func RunCommandSimple(pk *packet.RunPacketType, sender *packet.PacketSender, fro os.Remove(cmd.TmpRcFileName) }() } - cmd.Cmd = sapi.MakeShExecCommand(pk.Command, rcFileName, pk.UsePty) + fullCmdStr := pk.Command + if pk.ReturnState { + // this ensures that the last command is a shell buitin so we always get our exit trap to run + fullCmdStr = fullCmdStr + "\nexit $? 2> /dev/null" + } + cmd.Cmd = sapi.MakeShExecCommand(fullCmdStr, rcFileName, pk.UsePty) if !pk.StateComplete { cmd.Cmd.Env = os.Environ() } @@ -1075,34 +1085,6 @@ func copyToCirFile(dest *cirfile.File, src io.Reader) error { } } -func GetCmdExitCode(cmd *exec.Cmd, err error) int { - if cmd == nil || cmd.ProcessState == nil { - return GetExitCode(err) - } - status, ok := cmd.ProcessState.Sys().(syscall.WaitStatus) - if !ok { - return cmd.ProcessState.ExitCode() - } - signaled := status.Signaled() - if signaled { - signal := status.Signal() - return 128 + int(signal) - } - exitStatus := status.ExitStatus() - return exitStatus -} - -func GetExitCode(err error) int { - if err == nil { - return 0 - } - if exitErr, ok := err.(*exec.ExitError); ok { - return exitErr.ExitCode() - } else { - return -1 - } -} - func (c *ShExecType) ProcWait() error { exitErr := c.Cmd.Wait() c.Lock.Lock() @@ -1139,7 +1121,7 @@ func (c *ShExecType) WaitForCommand() *packet.CmdDonePacketType { endTs := time.Now() cmdDuration := endTs.Sub(c.StartTs) donePacket.Ts = endTs.UnixMilli() - donePacket.ExitCode = GetCmdExitCode(c.Cmd, exitErr) + donePacket.ExitCode = utilfn.GetCmdExitCode(c.Cmd, exitErr) donePacket.DurationMs = int64(cmdDuration / time.Millisecond) if c.FileNames != nil { os.Remove(c.FileNames.StdinFifo) // best effort (no need to check error) diff --git a/waveshell/pkg/utilfn/utilfn.go b/waveshell/pkg/utilfn/utilfn.go index 608e36cc8..b9ef0e5c6 100644 --- a/waveshell/pkg/utilfn/utilfn.go +++ b/waveshell/pkg/utilfn/utilfn.go @@ -15,9 +15,11 @@ import ( mathrand "math/rand" "net/http" "os" + "os/exec" "regexp" "sort" "strings" + "syscall" "unicode/utf8" ) @@ -635,3 +637,39 @@ func DetectMimeType(path string) string { } return rtn } + +func GetCmdExitCode(cmd *exec.Cmd, err error) int { + if cmd == nil || cmd.ProcessState == nil { + return GetExitCode(err) + } + status, ok := cmd.ProcessState.Sys().(syscall.WaitStatus) + if !ok { + return cmd.ProcessState.ExitCode() + } + signaled := status.Signaled() + if signaled { + signal := status.Signal() + return 128 + int(signal) + } + exitStatus := status.ExitStatus() + return exitStatus +} + +func GetExitCode(err error) int { + if err == nil { + return 0 + } + if exitErr, ok := err.(*exec.ExitError); ok { + return exitErr.ExitCode() + } else { + return -1 + } +} + +func GetFirstLine(s string) string { + idx := strings.Index(s, "\n") + if idx == -1 { + return s + } + return s[0:idx] +} diff --git a/wavesrv/db/schema.sql b/wavesrv/db/schema.sql index dcdfea604..d7225ff70 100644 --- a/wavesrv/db/schema.sql +++ b/wavesrv/db/schema.sql @@ -27,7 +27,7 @@ CREATE TABLE remote_instance ( festate json NOT NULL, statebasehash varchar(36) NOT NULL, statediffhasharr json NOT NULL -); +, shelltype varchar(20) NOT NULL DEFAULT 'bash'); CREATE TABLE state_base ( basehash varchar(36) PRIMARY KEY, ts bigint NOT NULL, @@ -55,10 +55,8 @@ CREATE TABLE remote ( lastconnectts bigint NOT NULL, local boolean NOT NULL, archived boolean NOT NULL, - remoteidx int NOT NULL, - statevars json NOT NULL DEFAULT '{}', - sshconfigsrc varchar(36) NOT NULL DEFAULT 'waveterm-manual', - openaiopts json NOT NULL DEFAULT '{}'); + remoteidx int NOT NULL +, statevars json NOT NULL DEFAULT '{}', openaiopts json NOT NULL DEFAULT '{}', sshconfigsrc varchar(36) NOT NULL DEFAULT 'waveterm-manual', shellpref varchar(20) NOT NULL DEFAULT 'detect'); CREATE TABLE history ( historyid varchar(36) PRIMARY KEY, ts bigint NOT NULL, @@ -203,7 +201,7 @@ CREATE TABLE IF NOT EXISTS "cmd" ( rtnstate boolean NOT NULL, rtnbasehash varchar(36) NOT NULL, rtndiffhasharr json NOT NULL, - runout json NOT NULL, + runout json NOT NULL, restartts bigint NOT NULL DEFAULT 0, PRIMARY KEY (screenid, lineid) ); CREATE TABLE cmd_migrate20 ( diff --git a/wavesrv/pkg/remote/remote.go b/wavesrv/pkg/remote/remote.go index d01b7dc43..99de0f9b6 100644 --- a/wavesrv/pkg/remote/remote.go +++ b/wavesrv/pkg/remote/remote.go @@ -1735,7 +1735,7 @@ func (msh *MShellProc) Launch(interactive bool) { msh.WriteToPtyBuffer("connected to %s\n", remoteCopy.RemoteCanonicalName) go func() { exitErr := cproc.Cmd.Wait() - exitCode := shexec.GetExitCode(exitErr) + exitCode := utilfn.GetExitCode(exitErr) msh.WithLock(func() { if msh.Status == StatusConnected || msh.Status == StatusConnecting { msh.Status = StatusDisconnected