force our exit trap to always run (for rtnstate commands) (#556)

* add command validation to shellapi.  mock out bash/zsh versions

* implement validate command fn bash and zsh

* test validate command

* change rtnstate commands to always end with a builtin, so we always get our exit trap to run

* simplify the rtnstate modification, don't add the 'wait' (as this is a different problem/feature)

* update schema
This commit is contained in:
Mike Sawka 2024-04-09 11:33:23 -07:00 committed by GitHub
parent 1f5309e097
commit 6919dbfb5f
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
10 changed files with 145 additions and 49 deletions

View File

@ -6,6 +6,7 @@ package shellapi
import ( import (
"bytes" "bytes"
"context" "context"
"errors"
"fmt" "fmt"
"os/exec" "os/exec"
"runtime" "runtime"
@ -266,6 +267,22 @@ func (bashShellApi) MakeShellStateDiff(oldState *packet.ShellState, oldStateHash
return rtn, nil return rtn, nil
} }
func (bashShellApi) ValidateCommandSyntax(cmdStr string) error {
ctx, cancelFn := context.WithTimeout(context.Background(), ValidateTimeout)
defer cancelFn()
cmd := exec.CommandContext(ctx, GetLocalBashPath(), "-n", "-c", cmdStr)
output, err := cmd.CombinedOutput()
if err == nil {
return nil
}
errStr := utilfn.GetFirstLine(string(output))
errStr = strings.TrimPrefix(errStr, "bash: -c: ")
if len(errStr) == 0 {
return errors.New("bash syntax error")
}
return errors.New(errStr)
}
func (bashShellApi) ApplyShellStateDiff(oldState *packet.ShellState, diff *packet.ShellStateDiff) (*packet.ShellState, error) { func (bashShellApi) ApplyShellStateDiff(oldState *packet.ShellState, diff *packet.ShellStateDiff) (*packet.ShellState, error) {
if oldState == nil { if oldState == nil {
return nil, fmt.Errorf("cannot apply diff, oldState is nil") return nil, fmt.Errorf("cannot apply diff, oldState is nil")

View File

@ -214,9 +214,16 @@ func bashParseDeclareOutput(state *packet.ShellState, declareBytes []byte, pvarB
firstParseErr = err firstParseErr = err
} }
} }
if decl != nil && !BashNoStoreVarNames[decl.Name] { if decl == nil {
declMap[decl.Name] = decl continue
} }
if BashNoStoreVarNames[decl.Name] {
continue
}
if strings.HasPrefix(decl.Name, "_wavetemp_") {
continue
}
declMap[decl.Name] = decl
} }
pvarMap := parseExtVarOutput(pvarBytes, "", "") pvarMap := parseExtVarOutput(pvarBytes, "", "")
utilfn.CombineMaps(declMap, pvarMap) utilfn.CombineMaps(declMap, pvarMap)

View File

@ -28,6 +28,7 @@ import (
) )
const GetVersionTimeout = 5 * time.Second const GetVersionTimeout = 5 * time.Second
const ValidateTimeout = 2 * time.Second
const GetGitBranchCmdStr = `printf "GITBRANCH %s\x00" "$(git rev-parse --abbrev-ref HEAD 2>/dev/null)"` const GetGitBranchCmdStr = `printf "GITBRANCH %s\x00" "$(git rev-parse --abbrev-ref HEAD 2>/dev/null)"`
const GetK8sContextCmdStr = `printf "K8SCONTEXT %s\x00" "$(kubectl config current-context 2>/dev/null)"` const GetK8sContextCmdStr = `printf "K8SCONTEXT %s\x00" "$(kubectl config current-context 2>/dev/null)"`
const GetK8sNamespaceCmdStr = `printf "K8SNAMESPACE %s\x00" "$(kubectl config view --minify --output 'jsonpath={..namespace}' 2>/dev/null)"` const GetK8sNamespaceCmdStr = `printf "K8SNAMESPACE %s\x00" "$(kubectl config view --minify --output 'jsonpath={..namespace}' 2>/dev/null)"`
@ -69,6 +70,7 @@ type ShellStateOutput struct {
type ShellApi interface { type ShellApi interface {
GetShellType() string GetShellType() string
MakeExitTrap(fdNum int) (string, []byte) MakeExitTrap(fdNum int) (string, []byte)
ValidateCommandSyntax(cmdStr string) error
GetLocalMajorVersion() string GetLocalMajorVersion() string
GetLocalShellPath() string GetLocalShellPath() string
GetRemoteShellPath() string GetRemoteShellPath() string

View File

@ -211,27 +211,41 @@ type ZshMap = map[ZshParamKey]string
type zshShellApi struct{} type zshShellApi struct{}
func (z zshShellApi) GetShellType() string { func (zshShellApi) GetShellType() string {
return packet.ShellType_zsh return packet.ShellType_zsh
} }
func (z zshShellApi) MakeExitTrap(fdNum int) (string, []byte) { func (zshShellApi) MakeExitTrap(fdNum int) (string, []byte) {
return MakeZshExitTrap(fdNum) return MakeZshExitTrap(fdNum)
} }
func (z zshShellApi) GetLocalMajorVersion() string { func (zshShellApi) GetLocalMajorVersion() string {
return GetLocalZshMajorVersion() return GetLocalZshMajorVersion()
} }
func (z zshShellApi) GetLocalShellPath() string { func (zshShellApi) GetLocalShellPath() string {
return "/bin/zsh" return "/bin/zsh"
} }
func (z zshShellApi) GetRemoteShellPath() string { func (zshShellApi) GetRemoteShellPath() string {
return "zsh" return "zsh"
} }
func (z zshShellApi) MakeRunCommand(cmdStr string, opts RunCommandOpts) string { func (zshShellApi) ValidateCommandSyntax(cmdStr string) error {
ctx, cancelFn := context.WithTimeout(context.Background(), ValidateTimeout)
defer cancelFn()
cmd := exec.CommandContext(ctx, GetLocalZshPath(), "-n", "-c", cmdStr)
output, err := cmd.CombinedOutput()
if err == nil {
return nil
}
if len(output) == 0 {
return errors.New("zsh syntax error")
}
return errors.New(utilfn.GetFirstLine(string(output)))
}
func (zshShellApi) MakeRunCommand(cmdStr string, opts RunCommandOpts) string {
if !opts.Sudo { if !opts.Sudo {
return cmdStr return cmdStr
} }
@ -242,7 +256,7 @@ func (z zshShellApi) MakeRunCommand(cmdStr string, opts RunCommandOpts) string {
} }
} }
func (z zshShellApi) MakeShExecCommand(cmdStr string, rcFileName string, usePty bool) *exec.Cmd { func (zshShellApi) MakeShExecCommand(cmdStr string, rcFileName string, usePty bool) *exec.Cmd {
return exec.Command(GetLocalZshPath(), "-l", "-i", "-c", cmdStr) return exec.Command(GetLocalZshPath(), "-l", "-i", "-c", cmdStr)
} }
@ -274,7 +288,7 @@ func (z zshShellApi) GetShellState(ctx context.Context, outCh chan ShellStateOut
outCh <- ShellStateOutput{ShellState: rtn, Stats: stats} outCh <- ShellStateOutput{ShellState: rtn, Stats: stats}
} }
func (z zshShellApi) GetBaseShellOpts() string { func (zshShellApi) GetBaseShellOpts() string {
return BaseZshOpts return BaseZshOpts
} }
@ -343,6 +357,9 @@ func (z zshShellApi) MakeRcFileStr(pk *packet.RunPacketType) string {
if strings.HasPrefix(varDecl.Name, "ZFTP_") { if strings.HasPrefix(varDecl.Name, "ZFTP_") {
continue continue
} }
if strings.HasPrefix(varDecl.Name, "_wavetemp_") {
continue
}
if varDecl.IsExtVar { if varDecl.IsExtVar {
continue continue
} }
@ -709,7 +726,7 @@ func makeZshFuncsStrForShellState(fnMap map[ZshParamKey]string) string {
return buf.String() return buf.String()
} }
func (z zshShellApi) ParseShellStateOutput(outputBytes []byte) (*packet.ShellState, *packet.ShellStateStats, error) { func (zshShellApi) ParseShellStateOutput(outputBytes []byte) (*packet.ShellState, *packet.ShellStateStats, error) {
if scbase.IsDevMode() && DebugState { if scbase.IsDevMode() && DebugState {
writeStateToFile(packet.ShellType_zsh, outputBytes) writeStateToFile(packet.ShellType_zsh, outputBytes)
} }

View File

@ -2,7 +2,9 @@ package shellapi
import ( import (
"fmt" "fmt"
"log"
"testing" "testing"
"time"
) )
func testSingleDecl(declStr string) { func testSingleDecl(declStr string) {
@ -45,3 +47,35 @@ func TestZshSafeDeclName(t *testing.T) {
t.Errorf("should not be safe") t.Errorf("should not be safe")
} }
} }
func testValidate(t *testing.T, shell string, cmd string, expectErr bool) {
var sapi ShellApi
if shell == "bash" {
sapi = bashShellApi{}
} else if shell == "zsh" {
sapi = zshShellApi{}
} else {
t.Errorf("unknown shell %q", shell)
return
}
tstart := time.Now()
err := sapi.ValidateCommandSyntax(cmd)
log.Printf("shell:%s dur:%v err: %v\n", shell, time.Since(tstart), err)
if expectErr && err == nil {
t.Errorf("cmd %q, expected error", cmd)
}
if !expectErr && err != nil {
t.Errorf("cmd %q, unexpected error: %v", cmd, err)
}
}
func TestValidate(t *testing.T) {
testValidate(t, "zsh", "echo foo", false)
testValidate(t, "zsh", "foo >& &", true)
testValidate(t, "zsh", "cd .", false)
testValidate(t, "zsh", "echo foo | grep foo", false)
testValidate(t, "zsh", "x; echo \"hello", true)
testValidate(t, "bash", "echo foo", false)
testValidate(t, "bash", "foo >& &", true)
testValidate(t, "bash", "cd .; echo \"", true)
}

View File

@ -12,6 +12,7 @@ import (
"github.com/wavetermdev/waveterm/waveshell/pkg/base" "github.com/wavetermdev/waveterm/waveshell/pkg/base"
"github.com/wavetermdev/waveterm/waveshell/pkg/packet" "github.com/wavetermdev/waveterm/waveshell/pkg/packet"
"github.com/wavetermdev/waveterm/waveshell/pkg/utilfn"
"golang.org/x/crypto/ssh" "golang.org/x/crypto/ssh"
"golang.org/x/mod/semver" "golang.org/x/mod/semver"
) )
@ -274,7 +275,7 @@ func (cproc *ClientProc) ProxySingleOutput(ck base.CommandKey, sender *packet.Pa
cmdDuration := endTs.Sub(cproc.StartTs) cmdDuration := endTs.Sub(cproc.StartTs)
donePacket := packet.MakeCmdDonePacket(ck) donePacket := packet.MakeCmdDonePacket(ck)
donePacket.Ts = endTs.UnixMilli() donePacket.Ts = endTs.UnixMilli()
donePacket.ExitCode = GetExitCode(exitErr) donePacket.ExitCode = utilfn.GetExitCode(exitErr)
donePacket.DurationMs = int64(cmdDuration / time.Millisecond) donePacket.DurationMs = int64(cmdDuration / time.Millisecond)
sender.SendPacket(donePacket) sender.SendPacket(donePacket)
} }

View File

@ -31,6 +31,7 @@ import (
"github.com/wavetermdev/waveterm/waveshell/pkg/shellapi" "github.com/wavetermdev/waveterm/waveshell/pkg/shellapi"
"github.com/wavetermdev/waveterm/waveshell/pkg/shellenv" "github.com/wavetermdev/waveterm/waveshell/pkg/shellenv"
"github.com/wavetermdev/waveterm/waveshell/pkg/shellutil" "github.com/wavetermdev/waveterm/waveshell/pkg/shellutil"
"github.com/wavetermdev/waveterm/waveshell/pkg/utilfn"
"github.com/wavetermdev/waveterm/waveshell/pkg/wlog" "github.com/wavetermdev/waveterm/waveshell/pkg/wlog"
"golang.org/x/mod/semver" "golang.org/x/mod/semver"
"golang.org/x/sys/unix" "golang.org/x/sys/unix"
@ -826,6 +827,10 @@ func RunCommandSimple(pk *packet.RunPacketType, sender *packet.PacketSender, fro
var rtnStateWriter *os.File var rtnStateWriter *os.File
rcFileStr := sapi.MakeRcFileStr(pk) rcFileStr := sapi.MakeRcFileStr(pk)
if pk.ReturnState { if pk.ReturnState {
err := sapi.ValidateCommandSyntax(pk.Command)
if err != nil {
return nil, err
}
pr, pw, err := os.Pipe() pr, pw, err := os.Pipe()
if err != nil { if err != nil {
return nil, fmt.Errorf("cannot create returnstate pipe: %v", err) return nil, fmt.Errorf("cannot create returnstate pipe: %v", err)
@ -894,7 +899,12 @@ func RunCommandSimple(pk *packet.RunPacketType, sender *packet.PacketSender, fro
os.Remove(cmd.TmpRcFileName) os.Remove(cmd.TmpRcFileName)
}() }()
} }
cmd.Cmd = sapi.MakeShExecCommand(pk.Command, rcFileName, pk.UsePty) fullCmdStr := pk.Command
if pk.ReturnState {
// this ensures that the last command is a shell buitin so we always get our exit trap to run
fullCmdStr = fullCmdStr + "\nexit $? 2> /dev/null"
}
cmd.Cmd = sapi.MakeShExecCommand(fullCmdStr, rcFileName, pk.UsePty)
if !pk.StateComplete { if !pk.StateComplete {
cmd.Cmd.Env = os.Environ() cmd.Cmd.Env = os.Environ()
} }
@ -1075,34 +1085,6 @@ func copyToCirFile(dest *cirfile.File, src io.Reader) error {
} }
} }
func GetCmdExitCode(cmd *exec.Cmd, err error) int {
if cmd == nil || cmd.ProcessState == nil {
return GetExitCode(err)
}
status, ok := cmd.ProcessState.Sys().(syscall.WaitStatus)
if !ok {
return cmd.ProcessState.ExitCode()
}
signaled := status.Signaled()
if signaled {
signal := status.Signal()
return 128 + int(signal)
}
exitStatus := status.ExitStatus()
return exitStatus
}
func GetExitCode(err error) int {
if err == nil {
return 0
}
if exitErr, ok := err.(*exec.ExitError); ok {
return exitErr.ExitCode()
} else {
return -1
}
}
func (c *ShExecType) ProcWait() error { func (c *ShExecType) ProcWait() error {
exitErr := c.Cmd.Wait() exitErr := c.Cmd.Wait()
c.Lock.Lock() c.Lock.Lock()
@ -1139,7 +1121,7 @@ func (c *ShExecType) WaitForCommand() *packet.CmdDonePacketType {
endTs := time.Now() endTs := time.Now()
cmdDuration := endTs.Sub(c.StartTs) cmdDuration := endTs.Sub(c.StartTs)
donePacket.Ts = endTs.UnixMilli() donePacket.Ts = endTs.UnixMilli()
donePacket.ExitCode = GetCmdExitCode(c.Cmd, exitErr) donePacket.ExitCode = utilfn.GetCmdExitCode(c.Cmd, exitErr)
donePacket.DurationMs = int64(cmdDuration / time.Millisecond) donePacket.DurationMs = int64(cmdDuration / time.Millisecond)
if c.FileNames != nil { if c.FileNames != nil {
os.Remove(c.FileNames.StdinFifo) // best effort (no need to check error) os.Remove(c.FileNames.StdinFifo) // best effort (no need to check error)

View File

@ -15,9 +15,11 @@ import (
mathrand "math/rand" mathrand "math/rand"
"net/http" "net/http"
"os" "os"
"os/exec"
"regexp" "regexp"
"sort" "sort"
"strings" "strings"
"syscall"
"unicode/utf8" "unicode/utf8"
) )
@ -635,3 +637,39 @@ func DetectMimeType(path string) string {
} }
return rtn return rtn
} }
func GetCmdExitCode(cmd *exec.Cmd, err error) int {
if cmd == nil || cmd.ProcessState == nil {
return GetExitCode(err)
}
status, ok := cmd.ProcessState.Sys().(syscall.WaitStatus)
if !ok {
return cmd.ProcessState.ExitCode()
}
signaled := status.Signaled()
if signaled {
signal := status.Signal()
return 128 + int(signal)
}
exitStatus := status.ExitStatus()
return exitStatus
}
func GetExitCode(err error) int {
if err == nil {
return 0
}
if exitErr, ok := err.(*exec.ExitError); ok {
return exitErr.ExitCode()
} else {
return -1
}
}
func GetFirstLine(s string) string {
idx := strings.Index(s, "\n")
if idx == -1 {
return s
}
return s[0:idx]
}

View File

@ -27,7 +27,7 @@ CREATE TABLE remote_instance (
festate json NOT NULL, festate json NOT NULL,
statebasehash varchar(36) NOT NULL, statebasehash varchar(36) NOT NULL,
statediffhasharr json NOT NULL statediffhasharr json NOT NULL
); , shelltype varchar(20) NOT NULL DEFAULT 'bash');
CREATE TABLE state_base ( CREATE TABLE state_base (
basehash varchar(36) PRIMARY KEY, basehash varchar(36) PRIMARY KEY,
ts bigint NOT NULL, ts bigint NOT NULL,
@ -55,10 +55,8 @@ CREATE TABLE remote (
lastconnectts bigint NOT NULL, lastconnectts bigint NOT NULL,
local boolean NOT NULL, local boolean NOT NULL,
archived boolean NOT NULL, archived boolean NOT NULL,
remoteidx int NOT NULL, remoteidx int NOT NULL
statevars json NOT NULL DEFAULT '{}', , statevars json NOT NULL DEFAULT '{}', openaiopts json NOT NULL DEFAULT '{}', sshconfigsrc varchar(36) NOT NULL DEFAULT 'waveterm-manual', shellpref varchar(20) NOT NULL DEFAULT 'detect');
sshconfigsrc varchar(36) NOT NULL DEFAULT 'waveterm-manual',
openaiopts json NOT NULL DEFAULT '{}');
CREATE TABLE history ( CREATE TABLE history (
historyid varchar(36) PRIMARY KEY, historyid varchar(36) PRIMARY KEY,
ts bigint NOT NULL, ts bigint NOT NULL,
@ -203,7 +201,7 @@ CREATE TABLE IF NOT EXISTS "cmd" (
rtnstate boolean NOT NULL, rtnstate boolean NOT NULL,
rtnbasehash varchar(36) NOT NULL, rtnbasehash varchar(36) NOT NULL,
rtndiffhasharr json NOT NULL, rtndiffhasharr json NOT NULL,
runout json NOT NULL, runout json NOT NULL, restartts bigint NOT NULL DEFAULT 0,
PRIMARY KEY (screenid, lineid) PRIMARY KEY (screenid, lineid)
); );
CREATE TABLE cmd_migrate20 ( CREATE TABLE cmd_migrate20 (

View File

@ -1735,7 +1735,7 @@ func (msh *MShellProc) Launch(interactive bool) {
msh.WriteToPtyBuffer("connected to %s\n", remoteCopy.RemoteCanonicalName) msh.WriteToPtyBuffer("connected to %s\n", remoteCopy.RemoteCanonicalName)
go func() { go func() {
exitErr := cproc.Cmd.Wait() exitErr := cproc.Cmd.Wait()
exitCode := shexec.GetExitCode(exitErr) exitCode := utilfn.GetExitCode(exitErr)
msh.WithLock(func() { msh.WithLock(func() {
if msh.Status == StatusConnected || msh.Status == StatusConnecting { if msh.Status == StatusConnected || msh.Status == StatusConnecting {
msh.Status = StatusDisconnected msh.Status = StatusDisconnected