From a4a4d53eb0dfaf5917da52ada30bd75a6434c311 Mon Sep 17 00:00:00 2001 From: sawka Date: Tue, 11 Apr 2023 23:52:58 -0700 Subject: [PATCH] grab git branch --- pkg/shexec/parser.go | 21 ++++++++++++++++++--- pkg/shexec/shexec.go | 18 +++++++++++++++--- 2 files changed, 33 insertions(+), 6 deletions(-) diff --git a/pkg/shexec/parser.go b/pkg/shexec/parser.go index 299cecc80..4e0dba226 100644 --- a/pkg/shexec/parser.go +++ b/pkg/shexec/parser.go @@ -398,7 +398,7 @@ func parseDeclareStmt(stmt *syntax.Stmt, src string) (*DeclareDeclType, error) { return rtn, nil } -func parseDeclareOutput(state *packet.ShellState, declareBytes []byte) error { +func parseDeclareOutput(state *packet.ShellState, declareBytes []byte, pvarBytes []byte) error { declareStr := string(declareBytes) r := bytes.NewReader(declareBytes) parser := syntax.NewParser(syntax.Variant(syntax.LangBash)) @@ -419,6 +419,21 @@ func parseDeclareOutput(state *packet.ShellState, declareBytes []byte) error { declMap[decl.Name] = decl } } + pvars := bytes.Split(pvarBytes, []byte{0}) + for _, pvarBA := range pvars { + pvarStr := string(pvarBA) + pvarFields := strings.SplitN(pvarStr, " ", 2) + if len(pvarFields) != 2 { + continue + } + if pvarFields[0] == "" || pvarFields[1] == "" { + continue + } + decl := &DeclareDeclType{Args: "x"} + decl.Name = "PROMPTVAR_" + pvarFields[0] + decl.Value = shellescape.Quote(pvarFields[1]) + declMap[decl.Name] = decl + } state.ShellVars = SerializeDeclMap(declMap) // this writes out the decls in a canonical order if firstParseErr != nil { state.Error = firstParseErr.Error() @@ -429,7 +444,7 @@ func parseDeclareOutput(state *packet.ShellState, declareBytes []byte) error { func ParseShellStateOutput(outputBytes []byte) (*packet.ShellState, error) { // 5 fields: version, cwd, env/vars, aliases, funcs fields := bytes.Split(outputBytes, []byte{0, 0}) - if len(fields) != 5 { + if len(fields) != 6 { return nil, fmt.Errorf("invalid shell state output, wrong number of fields, fields=%d", len(fields)) } rtn := &packet.ShellState{} @@ -445,7 +460,7 @@ func ParseShellStateOutput(outputBytes []byte) (*packet.ShellState, error) { cwdStr = cwdStr[0 : len(cwdStr)-1] } rtn.Cwd = string(cwdStr) - err := parseDeclareOutput(rtn, fields[2]) + err := parseDeclareOutput(rtn, fields[2], fields[5]) if err != nil { return nil, err } diff --git a/pkg/shexec/shexec.go b/pkg/shexec/shexec.go index 60c6660fb..3f1f7ee0f 100644 --- a/pkg/shexec/shexec.go +++ b/pkg/shexec/shexec.go @@ -51,7 +51,15 @@ const MaxTotalRunDataSize = 10 * MaxRunDataSize const GetStateTimeout = 5 * time.Second const BaseBashOpts = `set +m; set +H; shopt -s extglob` -const GetShellStateCmd = `echo bash v${BASH_VERSINFO[0]}.${BASH_VERSINFO[1]}.${BASH_VERSINFO[2]}; printf "\x00\x00"; pwd; printf "\x00\x00"; declare -p $(compgen -A variable); printf "\x00\x00"; alias -p; printf "\x00\x00"; declare -f;` + +var GetShellStateCmds = []string{ + `echo bash v${BASH_VERSINFO[0]}.${BASH_VERSINFO[1]}.${BASH_VERSINFO[2]};`, + `pwd;`, + `declare -p $(compgen -A variable);`, + `alias -p;`, + `declare -f;`, + `printf "GITBRANCH %s\x00" "$(git rev-parse --abbrev-ref HEAD 2>/dev/null)"`, +} const ClientCommandFmt = ` PATH=$PATH:~/.mshell; @@ -161,6 +169,10 @@ type ShExecUPR struct { UPR packet.UnknownPacketReporter } +func GetShellStateCmd() string { + return strings.Join(GetShellStateCmds, ` printf "\x00\x00";`) +} + func (s *ShExecType) processSpecialInputPacket(pk *packet.SpecialInputPacketType) error { base.Logf("processSpecialInputPacket: %#v\n", pk) if pk.WinSize != nil { @@ -1500,12 +1512,12 @@ func runSimpleCmdInPty(ecmd *exec.Cmd) ([]byte, error) { } func GetShellStateRedirectCommandStr(outputFdNum int) string { - return fmt.Sprintf("cat <(%s) > /dev/fd/%d", GetShellStateCmd, outputFdNum) + return fmt.Sprintf("cat <(%s) > /dev/fd/%d", GetShellStateCmd(), outputFdNum) } func GetShellState() (*packet.ShellState, error) { ctx, _ := context.WithTimeout(context.Background(), GetStateTimeout) - cmdStr := BaseBashOpts + "; " + GetShellStateCmd + cmdStr := BaseBashOpts + "; " + GetShellStateCmd() ecmd := exec.CommandContext(ctx, "bash", "-l", "-i", "-c", cmdStr) outputBytes, err := runSimpleCmdInPty(ecmd) if err != nil {