mirror of
https://github.com/wavetermdev/waveterm.git
synced 2025-01-02 18:39:05 +01:00
force our exit trap to always run (for rtnstate commands) (#556)
* add command validation to shellapi. mock out bash/zsh versions * implement validate command fn bash and zsh * test validate command * change rtnstate commands to always end with a builtin, so we always get our exit trap to run * simplify the rtnstate modification, don't add the 'wait' (as this is a different problem/feature) * update schema
This commit is contained in:
parent
1f5309e097
commit
6919dbfb5f
@ -6,6 +6,7 @@ package shellapi
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"os/exec"
|
||||
"runtime"
|
||||
@ -266,6 +267,22 @@ func (bashShellApi) MakeShellStateDiff(oldState *packet.ShellState, oldStateHash
|
||||
return rtn, nil
|
||||
}
|
||||
|
||||
func (bashShellApi) ValidateCommandSyntax(cmdStr string) error {
|
||||
ctx, cancelFn := context.WithTimeout(context.Background(), ValidateTimeout)
|
||||
defer cancelFn()
|
||||
cmd := exec.CommandContext(ctx, GetLocalBashPath(), "-n", "-c", cmdStr)
|
||||
output, err := cmd.CombinedOutput()
|
||||
if err == nil {
|
||||
return nil
|
||||
}
|
||||
errStr := utilfn.GetFirstLine(string(output))
|
||||
errStr = strings.TrimPrefix(errStr, "bash: -c: ")
|
||||
if len(errStr) == 0 {
|
||||
return errors.New("bash syntax error")
|
||||
}
|
||||
return errors.New(errStr)
|
||||
}
|
||||
|
||||
func (bashShellApi) ApplyShellStateDiff(oldState *packet.ShellState, diff *packet.ShellStateDiff) (*packet.ShellState, error) {
|
||||
if oldState == nil {
|
||||
return nil, fmt.Errorf("cannot apply diff, oldState is nil")
|
||||
|
@ -214,9 +214,16 @@ func bashParseDeclareOutput(state *packet.ShellState, declareBytes []byte, pvarB
|
||||
firstParseErr = err
|
||||
}
|
||||
}
|
||||
if decl != nil && !BashNoStoreVarNames[decl.Name] {
|
||||
declMap[decl.Name] = decl
|
||||
if decl == nil {
|
||||
continue
|
||||
}
|
||||
if BashNoStoreVarNames[decl.Name] {
|
||||
continue
|
||||
}
|
||||
if strings.HasPrefix(decl.Name, "_wavetemp_") {
|
||||
continue
|
||||
}
|
||||
declMap[decl.Name] = decl
|
||||
}
|
||||
pvarMap := parseExtVarOutput(pvarBytes, "", "")
|
||||
utilfn.CombineMaps(declMap, pvarMap)
|
||||
|
@ -28,6 +28,7 @@ import (
|
||||
)
|
||||
|
||||
const GetVersionTimeout = 5 * time.Second
|
||||
const ValidateTimeout = 2 * time.Second
|
||||
const GetGitBranchCmdStr = `printf "GITBRANCH %s\x00" "$(git rev-parse --abbrev-ref HEAD 2>/dev/null)"`
|
||||
const GetK8sContextCmdStr = `printf "K8SCONTEXT %s\x00" "$(kubectl config current-context 2>/dev/null)"`
|
||||
const GetK8sNamespaceCmdStr = `printf "K8SNAMESPACE %s\x00" "$(kubectl config view --minify --output 'jsonpath={..namespace}' 2>/dev/null)"`
|
||||
@ -69,6 +70,7 @@ type ShellStateOutput struct {
|
||||
type ShellApi interface {
|
||||
GetShellType() string
|
||||
MakeExitTrap(fdNum int) (string, []byte)
|
||||
ValidateCommandSyntax(cmdStr string) error
|
||||
GetLocalMajorVersion() string
|
||||
GetLocalShellPath() string
|
||||
GetRemoteShellPath() string
|
||||
|
@ -211,27 +211,41 @@ type ZshMap = map[ZshParamKey]string
|
||||
|
||||
type zshShellApi struct{}
|
||||
|
||||
func (z zshShellApi) GetShellType() string {
|
||||
func (zshShellApi) GetShellType() string {
|
||||
return packet.ShellType_zsh
|
||||
}
|
||||
|
||||
func (z zshShellApi) MakeExitTrap(fdNum int) (string, []byte) {
|
||||
func (zshShellApi) MakeExitTrap(fdNum int) (string, []byte) {
|
||||
return MakeZshExitTrap(fdNum)
|
||||
}
|
||||
|
||||
func (z zshShellApi) GetLocalMajorVersion() string {
|
||||
func (zshShellApi) GetLocalMajorVersion() string {
|
||||
return GetLocalZshMajorVersion()
|
||||
}
|
||||
|
||||
func (z zshShellApi) GetLocalShellPath() string {
|
||||
func (zshShellApi) GetLocalShellPath() string {
|
||||
return "/bin/zsh"
|
||||
}
|
||||
|
||||
func (z zshShellApi) GetRemoteShellPath() string {
|
||||
func (zshShellApi) GetRemoteShellPath() string {
|
||||
return "zsh"
|
||||
}
|
||||
|
||||
func (z zshShellApi) MakeRunCommand(cmdStr string, opts RunCommandOpts) string {
|
||||
func (zshShellApi) ValidateCommandSyntax(cmdStr string) error {
|
||||
ctx, cancelFn := context.WithTimeout(context.Background(), ValidateTimeout)
|
||||
defer cancelFn()
|
||||
cmd := exec.CommandContext(ctx, GetLocalZshPath(), "-n", "-c", cmdStr)
|
||||
output, err := cmd.CombinedOutput()
|
||||
if err == nil {
|
||||
return nil
|
||||
}
|
||||
if len(output) == 0 {
|
||||
return errors.New("zsh syntax error")
|
||||
}
|
||||
return errors.New(utilfn.GetFirstLine(string(output)))
|
||||
}
|
||||
|
||||
func (zshShellApi) MakeRunCommand(cmdStr string, opts RunCommandOpts) string {
|
||||
if !opts.Sudo {
|
||||
return cmdStr
|
||||
}
|
||||
@ -242,7 +256,7 @@ func (z zshShellApi) MakeRunCommand(cmdStr string, opts RunCommandOpts) string {
|
||||
}
|
||||
}
|
||||
|
||||
func (z zshShellApi) MakeShExecCommand(cmdStr string, rcFileName string, usePty bool) *exec.Cmd {
|
||||
func (zshShellApi) MakeShExecCommand(cmdStr string, rcFileName string, usePty bool) *exec.Cmd {
|
||||
return exec.Command(GetLocalZshPath(), "-l", "-i", "-c", cmdStr)
|
||||
}
|
||||
|
||||
@ -274,7 +288,7 @@ func (z zshShellApi) GetShellState(ctx context.Context, outCh chan ShellStateOut
|
||||
outCh <- ShellStateOutput{ShellState: rtn, Stats: stats}
|
||||
}
|
||||
|
||||
func (z zshShellApi) GetBaseShellOpts() string {
|
||||
func (zshShellApi) GetBaseShellOpts() string {
|
||||
return BaseZshOpts
|
||||
}
|
||||
|
||||
@ -343,6 +357,9 @@ func (z zshShellApi) MakeRcFileStr(pk *packet.RunPacketType) string {
|
||||
if strings.HasPrefix(varDecl.Name, "ZFTP_") {
|
||||
continue
|
||||
}
|
||||
if strings.HasPrefix(varDecl.Name, "_wavetemp_") {
|
||||
continue
|
||||
}
|
||||
if varDecl.IsExtVar {
|
||||
continue
|
||||
}
|
||||
@ -709,7 +726,7 @@ func makeZshFuncsStrForShellState(fnMap map[ZshParamKey]string) string {
|
||||
return buf.String()
|
||||
}
|
||||
|
||||
func (z zshShellApi) ParseShellStateOutput(outputBytes []byte) (*packet.ShellState, *packet.ShellStateStats, error) {
|
||||
func (zshShellApi) ParseShellStateOutput(outputBytes []byte) (*packet.ShellState, *packet.ShellStateStats, error) {
|
||||
if scbase.IsDevMode() && DebugState {
|
||||
writeStateToFile(packet.ShellType_zsh, outputBytes)
|
||||
}
|
||||
|
@ -2,7 +2,9 @@ package shellapi
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"log"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func testSingleDecl(declStr string) {
|
||||
@ -45,3 +47,35 @@ func TestZshSafeDeclName(t *testing.T) {
|
||||
t.Errorf("should not be safe")
|
||||
}
|
||||
}
|
||||
|
||||
func testValidate(t *testing.T, shell string, cmd string, expectErr bool) {
|
||||
var sapi ShellApi
|
||||
if shell == "bash" {
|
||||
sapi = bashShellApi{}
|
||||
} else if shell == "zsh" {
|
||||
sapi = zshShellApi{}
|
||||
} else {
|
||||
t.Errorf("unknown shell %q", shell)
|
||||
return
|
||||
}
|
||||
tstart := time.Now()
|
||||
err := sapi.ValidateCommandSyntax(cmd)
|
||||
log.Printf("shell:%s dur:%v err: %v\n", shell, time.Since(tstart), err)
|
||||
if expectErr && err == nil {
|
||||
t.Errorf("cmd %q, expected error", cmd)
|
||||
}
|
||||
if !expectErr && err != nil {
|
||||
t.Errorf("cmd %q, unexpected error: %v", cmd, err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidate(t *testing.T) {
|
||||
testValidate(t, "zsh", "echo foo", false)
|
||||
testValidate(t, "zsh", "foo >& &", true)
|
||||
testValidate(t, "zsh", "cd .", false)
|
||||
testValidate(t, "zsh", "echo foo | grep foo", false)
|
||||
testValidate(t, "zsh", "x; echo \"hello", true)
|
||||
testValidate(t, "bash", "echo foo", false)
|
||||
testValidate(t, "bash", "foo >& &", true)
|
||||
testValidate(t, "bash", "cd .; echo \"", true)
|
||||
}
|
||||
|
@ -12,6 +12,7 @@ import (
|
||||
|
||||
"github.com/wavetermdev/waveterm/waveshell/pkg/base"
|
||||
"github.com/wavetermdev/waveterm/waveshell/pkg/packet"
|
||||
"github.com/wavetermdev/waveterm/waveshell/pkg/utilfn"
|
||||
"golang.org/x/crypto/ssh"
|
||||
"golang.org/x/mod/semver"
|
||||
)
|
||||
@ -274,7 +275,7 @@ func (cproc *ClientProc) ProxySingleOutput(ck base.CommandKey, sender *packet.Pa
|
||||
cmdDuration := endTs.Sub(cproc.StartTs)
|
||||
donePacket := packet.MakeCmdDonePacket(ck)
|
||||
donePacket.Ts = endTs.UnixMilli()
|
||||
donePacket.ExitCode = GetExitCode(exitErr)
|
||||
donePacket.ExitCode = utilfn.GetExitCode(exitErr)
|
||||
donePacket.DurationMs = int64(cmdDuration / time.Millisecond)
|
||||
sender.SendPacket(donePacket)
|
||||
}
|
||||
|
@ -31,6 +31,7 @@ import (
|
||||
"github.com/wavetermdev/waveterm/waveshell/pkg/shellapi"
|
||||
"github.com/wavetermdev/waveterm/waveshell/pkg/shellenv"
|
||||
"github.com/wavetermdev/waveterm/waveshell/pkg/shellutil"
|
||||
"github.com/wavetermdev/waveterm/waveshell/pkg/utilfn"
|
||||
"github.com/wavetermdev/waveterm/waveshell/pkg/wlog"
|
||||
"golang.org/x/mod/semver"
|
||||
"golang.org/x/sys/unix"
|
||||
@ -826,6 +827,10 @@ func RunCommandSimple(pk *packet.RunPacketType, sender *packet.PacketSender, fro
|
||||
var rtnStateWriter *os.File
|
||||
rcFileStr := sapi.MakeRcFileStr(pk)
|
||||
if pk.ReturnState {
|
||||
err := sapi.ValidateCommandSyntax(pk.Command)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
pr, pw, err := os.Pipe()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("cannot create returnstate pipe: %v", err)
|
||||
@ -894,7 +899,12 @@ func RunCommandSimple(pk *packet.RunPacketType, sender *packet.PacketSender, fro
|
||||
os.Remove(cmd.TmpRcFileName)
|
||||
}()
|
||||
}
|
||||
cmd.Cmd = sapi.MakeShExecCommand(pk.Command, rcFileName, pk.UsePty)
|
||||
fullCmdStr := pk.Command
|
||||
if pk.ReturnState {
|
||||
// this ensures that the last command is a shell buitin so we always get our exit trap to run
|
||||
fullCmdStr = fullCmdStr + "\nexit $? 2> /dev/null"
|
||||
}
|
||||
cmd.Cmd = sapi.MakeShExecCommand(fullCmdStr, rcFileName, pk.UsePty)
|
||||
if !pk.StateComplete {
|
||||
cmd.Cmd.Env = os.Environ()
|
||||
}
|
||||
@ -1075,34 +1085,6 @@ func copyToCirFile(dest *cirfile.File, src io.Reader) error {
|
||||
}
|
||||
}
|
||||
|
||||
func GetCmdExitCode(cmd *exec.Cmd, err error) int {
|
||||
if cmd == nil || cmd.ProcessState == nil {
|
||||
return GetExitCode(err)
|
||||
}
|
||||
status, ok := cmd.ProcessState.Sys().(syscall.WaitStatus)
|
||||
if !ok {
|
||||
return cmd.ProcessState.ExitCode()
|
||||
}
|
||||
signaled := status.Signaled()
|
||||
if signaled {
|
||||
signal := status.Signal()
|
||||
return 128 + int(signal)
|
||||
}
|
||||
exitStatus := status.ExitStatus()
|
||||
return exitStatus
|
||||
}
|
||||
|
||||
func GetExitCode(err error) int {
|
||||
if err == nil {
|
||||
return 0
|
||||
}
|
||||
if exitErr, ok := err.(*exec.ExitError); ok {
|
||||
return exitErr.ExitCode()
|
||||
} else {
|
||||
return -1
|
||||
}
|
||||
}
|
||||
|
||||
func (c *ShExecType) ProcWait() error {
|
||||
exitErr := c.Cmd.Wait()
|
||||
c.Lock.Lock()
|
||||
@ -1139,7 +1121,7 @@ func (c *ShExecType) WaitForCommand() *packet.CmdDonePacketType {
|
||||
endTs := time.Now()
|
||||
cmdDuration := endTs.Sub(c.StartTs)
|
||||
donePacket.Ts = endTs.UnixMilli()
|
||||
donePacket.ExitCode = GetCmdExitCode(c.Cmd, exitErr)
|
||||
donePacket.ExitCode = utilfn.GetCmdExitCode(c.Cmd, exitErr)
|
||||
donePacket.DurationMs = int64(cmdDuration / time.Millisecond)
|
||||
if c.FileNames != nil {
|
||||
os.Remove(c.FileNames.StdinFifo) // best effort (no need to check error)
|
||||
|
@ -15,9 +15,11 @@ import (
|
||||
mathrand "math/rand"
|
||||
"net/http"
|
||||
"os"
|
||||
"os/exec"
|
||||
"regexp"
|
||||
"sort"
|
||||
"strings"
|
||||
"syscall"
|
||||
"unicode/utf8"
|
||||
)
|
||||
|
||||
@ -635,3 +637,39 @@ func DetectMimeType(path string) string {
|
||||
}
|
||||
return rtn
|
||||
}
|
||||
|
||||
func GetCmdExitCode(cmd *exec.Cmd, err error) int {
|
||||
if cmd == nil || cmd.ProcessState == nil {
|
||||
return GetExitCode(err)
|
||||
}
|
||||
status, ok := cmd.ProcessState.Sys().(syscall.WaitStatus)
|
||||
if !ok {
|
||||
return cmd.ProcessState.ExitCode()
|
||||
}
|
||||
signaled := status.Signaled()
|
||||
if signaled {
|
||||
signal := status.Signal()
|
||||
return 128 + int(signal)
|
||||
}
|
||||
exitStatus := status.ExitStatus()
|
||||
return exitStatus
|
||||
}
|
||||
|
||||
func GetExitCode(err error) int {
|
||||
if err == nil {
|
||||
return 0
|
||||
}
|
||||
if exitErr, ok := err.(*exec.ExitError); ok {
|
||||
return exitErr.ExitCode()
|
||||
} else {
|
||||
return -1
|
||||
}
|
||||
}
|
||||
|
||||
func GetFirstLine(s string) string {
|
||||
idx := strings.Index(s, "\n")
|
||||
if idx == -1 {
|
||||
return s
|
||||
}
|
||||
return s[0:idx]
|
||||
}
|
||||
|
@ -27,7 +27,7 @@ CREATE TABLE remote_instance (
|
||||
festate json NOT NULL,
|
||||
statebasehash varchar(36) NOT NULL,
|
||||
statediffhasharr json NOT NULL
|
||||
);
|
||||
, shelltype varchar(20) NOT NULL DEFAULT 'bash');
|
||||
CREATE TABLE state_base (
|
||||
basehash varchar(36) PRIMARY KEY,
|
||||
ts bigint NOT NULL,
|
||||
@ -55,10 +55,8 @@ CREATE TABLE remote (
|
||||
lastconnectts bigint NOT NULL,
|
||||
local boolean NOT NULL,
|
||||
archived boolean NOT NULL,
|
||||
remoteidx int NOT NULL,
|
||||
statevars json NOT NULL DEFAULT '{}',
|
||||
sshconfigsrc varchar(36) NOT NULL DEFAULT 'waveterm-manual',
|
||||
openaiopts json NOT NULL DEFAULT '{}');
|
||||
remoteidx int NOT NULL
|
||||
, statevars json NOT NULL DEFAULT '{}', openaiopts json NOT NULL DEFAULT '{}', sshconfigsrc varchar(36) NOT NULL DEFAULT 'waveterm-manual', shellpref varchar(20) NOT NULL DEFAULT 'detect');
|
||||
CREATE TABLE history (
|
||||
historyid varchar(36) PRIMARY KEY,
|
||||
ts bigint NOT NULL,
|
||||
@ -203,7 +201,7 @@ CREATE TABLE IF NOT EXISTS "cmd" (
|
||||
rtnstate boolean NOT NULL,
|
||||
rtnbasehash varchar(36) NOT NULL,
|
||||
rtndiffhasharr json NOT NULL,
|
||||
runout json NOT NULL,
|
||||
runout json NOT NULL, restartts bigint NOT NULL DEFAULT 0,
|
||||
PRIMARY KEY (screenid, lineid)
|
||||
);
|
||||
CREATE TABLE cmd_migrate20 (
|
||||
|
@ -1735,7 +1735,7 @@ func (msh *MShellProc) Launch(interactive bool) {
|
||||
msh.WriteToPtyBuffer("connected to %s\n", remoteCopy.RemoteCanonicalName)
|
||||
go func() {
|
||||
exitErr := cproc.Cmd.Wait()
|
||||
exitCode := shexec.GetExitCode(exitErr)
|
||||
exitCode := utilfn.GetExitCode(exitErr)
|
||||
msh.WithLock(func() {
|
||||
if msh.Status == StatusConnected || msh.Status == StatusConnecting {
|
||||
msh.Status = StatusDisconnected
|
||||
|
Loading…
Reference in New Issue
Block a user