From 674a6ef11eea7846cebc733a893cb4bc71ba3cc8 Mon Sep 17 00:00:00 2001 From: sawka Date: Mon, 24 Oct 2022 21:26:39 -0700 Subject: [PATCH] grab shell vars with export vars --- main-mshell.go | 11 +-- pkg/packet/packet.go | 10 ++- pkg/shexec/parser.go | 200 +++++++++++++++++++++++++++++++++++++++++++ pkg/shexec/shexec.go | 55 +++--------- 4 files changed, 222 insertions(+), 54 deletions(-) create mode 100644 pkg/shexec/parser.go diff --git a/main-mshell.go b/main-mshell.go index 59a826ae1..af680c68a 100644 --- a/main-mshell.go +++ b/main-mshell.go @@ -526,13 +526,14 @@ func main() { } else if firstArg == "--version" { fmt.Printf("mshell %s\n", base.MShellVersion) return - } else if firstArg == "--env" { - rtnCode, err := handleEnv() + } else if firstArg == "--test-env" { + state, err := shexec.GetShellState() + if state != nil { + + } if err != nil { fmt.Fprintf(os.Stderr, "[error] %v\n", err) - } - if rtnCode != 0 { - os.Exit(rtnCode) + os.Exit(1) } } else if firstArg == "--single" { handleSingle(false) diff --git a/pkg/packet/packet.go b/pkg/packet/packet.go index 02346f642..88cd7369b 100644 --- a/pkg/packet/packet.go +++ b/pkg/packet/packet.go @@ -109,10 +109,12 @@ func MakePacket(packetType string) (PacketType, error) { } type ShellState struct { - Cwd string `json:"cwd,omitempty"` - Env0 []byte `json:"env0,omitempty"` - Aliases string `json:"aliases,omitempty"` - Funcs string `json:"funcs,omitempty"` + Version string `json:"version,omitempty"` + Cwd string `json:"cwd,omitempty"` + ShellVars string `json:"shellvars,omitempty"` + Env0 []byte `json:"env0,omitempty"` + Aliases string `json:"aliases,omitempty"` + Funcs string `json:"funcs,omitempty"` } type CmdDataPacketType struct { diff --git a/pkg/shexec/parser.go b/pkg/shexec/parser.go new file mode 100644 index 000000000..6f7eeef72 --- /dev/null +++ b/pkg/shexec/parser.go @@ -0,0 +1,200 @@ +package shexec + +import ( + "bytes" + "fmt" + "io" + "strings" + + "github.com/scripthaus-dev/mshell/pkg/packet" + "mvdan.cc/sh/v3/expand" + "mvdan.cc/sh/v3/syntax" +) + +type ParseEnviron struct { + Env map[string]string +} + +func (e *ParseEnviron) Get(name string) expand.Variable { + val, ok := e.Env[name] + if !ok { + return expand.Variable{} + } + return expand.Variable{ + Exported: true, + Kind: expand.String, + Str: val, + } +} + +func (e *ParseEnviron) Each(fn func(name string, vr expand.Variable) bool) { + for key, _ := range e.Env { + rtn := fn(key, e.Get(key)) + if !rtn { + break + } + } +} + +func doCmdSubst(commandStr string, w io.Writer, word *syntax.CmdSubst) error { + return nil +} + +func doProcSubst(w *syntax.ProcSubst) (string, error) { + return "", nil +} + +func GetParserConfig(envMap map[string]string) *expand.Config { + cfg := &expand.Config{ + Env: &ParseEnviron{Env: envMap}, + GlobStar: false, + NullGlob: false, + NoUnset: false, + CmdSubst: func(w io.Writer, word *syntax.CmdSubst) error { return doCmdSubst("", w, word) }, + ProcSubst: doProcSubst, + ReadDir: nil, + } + return cfg +} + +func QuotedLitToStr(word *syntax.Word) (string, error) { + cfg := GetParserConfig(nil) + return expand.Literal(cfg, word) +} + +// https://wiki.bash-hackers.org/syntax/shellvars +var NoStoreVarNames = map[string]bool{ + "BASH": true, + "BASHOPTS": true, + "BASHPID": true, + "BASH_ALIASES": true, + "BASH_ARGC": true, + "BASH_ARGV": true, + "BASH_ARGV0": true, + "BASH_CMDS": true, + "BASH_COMMAND": true, + "BASH_EXECUTION_STRING": true, + "BASH_LINENO": true, + "BASH_REMATCH": true, + "BASH_SOURCE": true, + "BASH_SUBSHELL": true, + "BASH_VERSINFO": true, + "BASH_VERSION": true, + "COPROC": true, + "DIRSTACK": true, + "EPOCHREALTIME": true, + "EPOCHSECONDS": true, + "FUNCNAME": true, + "HISTCMD": true, + "OLDPWD": true, + "PIPESTATUS": true, + "PPID": true, + "PWD": true, + "RANDOM": true, + "SECONDS": true, + "SHLVL": true, + "HISTFILE": true, + "HISTFILESIZE": true, + "HISTCONTROL": true, + "HISTIGNORE": true, + "HISTSIZE": true, + "HISTTIMEFORMAT": true, + "SRANDOM": true, +} + +func parseDeclareStmt(envBuffer *bytes.Buffer, varsBuffer *bytes.Buffer, stmt *syntax.Stmt, src []byte) error { + cmd := stmt.Cmd + decl, ok := cmd.(*syntax.DeclClause) + if !ok || decl.Variant.Value != "declare" || len(decl.Args) != 2 { + return fmt.Errorf("invalid declare variant") + } + declArgs := decl.Args[0] + if !declArgs.Naked || len(declArgs.Value.Parts) != 1 { + return fmt.Errorf("wrong number of declare args parts") + } + declArgLit, ok := declArgs.Value.Parts[0].(*syntax.Lit) + if !ok { + return fmt.Errorf("declare args is not a literal") + } + declArgStr := declArgLit.Value + if !strings.HasPrefix(declArgStr, "-") { + return fmt.Errorf("declare args not an argument (does not start with '-')") + } + declAssign := decl.Args[1] + if declAssign.Name == nil { + return fmt.Errorf("declare does not have a valid name") + } + varName := declAssign.Name.Value + if NoStoreVarNames[varName] { + return nil + } + if strings.Index(varName, "=") != -1 || strings.Index(varName, "\x00") != -1 { + return fmt.Errorf("invalid varname (cannot contain '=' or 0 byte)") + } + fullDeclBytes := src[decl.Pos().Offset():decl.End().Offset()] + if strings.Index(declArgStr, "x") == -1 { + // non-exported vars get written to vars as decl statements + varsBuffer.Write(fullDeclBytes) + varsBuffer.WriteRune('\n') + return nil + } + if declArgStr != "-x" { + return fmt.Errorf("can only export plain bash variables (no arrays)") + } + // exported vars are parsed into Env0 format + if declAssign.Naked || declAssign.Array != nil || declAssign.Index != nil || declAssign.Append || declAssign.Value == nil { + return fmt.Errorf("invalid variable to export") + } + varValue := declAssign.Value + varValueStr, err := QuotedLitToStr(varValue) + if err != nil { + return fmt.Errorf("parsing declare value: %w", err) + } + if strings.Index(varValueStr, "\x00") != -1 { + return fmt.Errorf("invalid export var value (cannot contain 0 byte)") + } + envBuffer.WriteString(fmt.Sprintf("%s=%s\x00", varName, varValueStr)) + return nil +} + +func parseDeclareOutput(state *packet.ShellState, declareBytes []byte) error { + r := bytes.NewReader(declareBytes) + parser := syntax.NewParser(syntax.Variant(syntax.LangBash)) + file, err := parser.Parse(r, "aliases") + if err != nil { + return err + } + var envBuffer, varsBuffer bytes.Buffer + for _, stmt := range file.Stmts { + err = parseDeclareStmt(&envBuffer, &varsBuffer, stmt, declareBytes) + if err != nil { + // TODO where to put parse errors? + continue + } + } + state.Env0 = envBuffer.Bytes() + state.ShellVars = varsBuffer.String() + return nil +} + +func ParseShellStateOutput(outputBytes []byte) (*packet.ShellState, error) { + // 5 fields: version, cwd, env/vars, aliases, funcs + fields := bytes.Split(outputBytes, []byte{0, 0}) + if len(fields) != 5 { + return nil, fmt.Errorf("invalid shell state output, wrong number of fields, fields=%d", len(fields)) + } + rtn := &packet.ShellState{} + rtn.Version = string(fields[0]) + if strings.Index(rtn.Version, "bash") == -1 { + return nil, fmt.Errorf("invalid shell state output, only bash is supported") + } + cwdStr := string(fields[1]) + if strings.HasSuffix(cwdStr, "\r\n") { + cwdStr = cwdStr[0 : len(cwdStr)-2] + } + rtn.Cwd = string(cwdStr) + parseDeclareOutput(rtn, fields[2]) + rtn.Aliases = strings.ReplaceAll(string(fields[3]), "\r\n", "\n") + rtn.Funcs = strings.ReplaceAll(string(fields[4]), "\r\n", "\n") + return rtn, nil +} diff --git a/pkg/shexec/shexec.go b/pkg/shexec/shexec.go index 9dc59a405..e7c8818e0 100644 --- a/pkg/shexec/shexec.go +++ b/pkg/shexec/shexec.go @@ -47,6 +47,8 @@ const MaxMaxPtySize = 100 * 1024 * 1024 const GetStateTimeout = 5 * time.Second +const GetShellStateCmd = `echo bash v${BASH_VERSINFO[0]}.${BASH_VERSINFO[1]}.${BASH_VERSINFO[2]}; printf "\x00\x00"; pwd; printf "\x00\x00"; declare -p $(compgen -A variable); printf "\x00\x00"; alias -p; printf "\x00\x00"; declare -f;` + const ClientCommandFmt = ` PATH=$PATH:~/.mshell; which mshell > /dev/null; @@ -976,7 +978,7 @@ shopt -s extglob if pk.ReturnState { rcFileStr += ` _scripthaus_exittrap () { - %s --env; alias -p; printf \"\\x00\\x00\"; declare -f; +` + GetShellStateCmd + ` } trap _scripthaus_exittrap EXIT ` @@ -984,18 +986,15 @@ trap _scripthaus_exittrap EXIT return rcFileStr } -func makeExitTrap(fdNum int) (string, error) { - stateCmd, err := GetShellStateRedirectCommandStr(fdNum) - if err != nil { - return "", err - } +func makeExitTrap(fdNum int) string { + stateCmd := GetShellStateRedirectCommandStr(fdNum) fmtStr := ` _scripthaus_exittrap () { %s } trap _scripthaus_exittrap EXIT ` - return fmt.Sprintf(fmtStr, stateCmd), nil + return fmt.Sprintf(fmtStr, stateCmd) } func RunCommandSimple(pk *packet.RunPacketType, sender *packet.PacketSender, fromServer bool) (rtnShExec *ShExecType, rtnErr error) { @@ -1028,10 +1027,7 @@ func RunCommandSimple(pk *packet.RunPacketType, sender *packet.PacketSender, fro cmd.ReturnState.FdNum = 20 rtnStateWriter = pw defer pw.Close() - trapCmdStr, err := makeExitTrap(cmd.ReturnState.FdNum) - if err != nil { - return nil, err - } + trapCmdStr := makeExitTrap(cmd.ReturnState.FdNum) rcFileStr += trapCmdStr } rcFileFdNum, err := AddRunData(pk, rcFileStr, "rcfile") @@ -1438,44 +1434,13 @@ func runSimpleCmdInPty(ecmd *exec.Cmd) ([]byte, error) { return outputBuf.Bytes(), nil } -func GetShellStateCommandStr() (string, error) { - execFile, err := os.Executable() - if err != nil { - return "", fmt.Errorf("cannot find local mshell executable: %w", err) - } - return fmt.Sprintf(`%s --env; alias -p; printf \"\\x00\\x00\"; declare -f`, shellescape.Quote(execFile)), nil -} - -func GetShellStateRedirectCommandStr(outputFdNum int) (string, error) { - cmdStr, err := GetShellStateCommandStr() - if err != nil { - return "", err - } - return fmt.Sprintf("cat <(%s) > /dev/fd/%d", cmdStr, outputFdNum), nil -} - -func ParseShellStateOutput(outputBytes []byte) (*packet.ShellState, error) { - fields := bytes.Split(outputBytes, []byte{0, 0}) - if len(fields) != 4 { - return nil, fmt.Errorf("invalid shell state output, wrong number of fields, fields=%d", len(fields)) - } - rtn := &packet.ShellState{} - rtn.Cwd = string(fields[0]) - if len(fields[1]) > 0 { - rtn.Env0 = append(fields[1], '\x00') - } - rtn.Aliases = strings.ReplaceAll(string(fields[2]), "\r\n", "\n") - rtn.Funcs = strings.ReplaceAll(string(fields[3]), "\r\n", "\n") - return rtn, nil +func GetShellStateRedirectCommandStr(outputFdNum int) string { + return fmt.Sprintf("cat <(%s) > /dev/fd/%d", GetShellStateCmd, outputFdNum) } func GetShellState() (*packet.ShellState, error) { ctx, _ := context.WithTimeout(context.Background(), GetStateTimeout) - cmdStr, err := GetShellStateCommandStr() - if err != nil { - return nil, err - } - ecmd := exec.CommandContext(ctx, "bash", "-l", "-i", "-c", cmdStr) + ecmd := exec.CommandContext(ctx, "bash", "-l", "-i", "-c", GetShellStateCmd) outputBytes, err := runSimpleCmdInPty(ecmd) if err != nil { return nil, err