diff --git a/pkg/packet/packet.go b/pkg/packet/packet.go index 88cd7369b..4a33fee1f 100644 --- a/pkg/packet/packet.go +++ b/pkg/packet/packet.go @@ -51,6 +51,7 @@ const ( RawPacketStr = "raw" SpecialInputPacketStr = "sinput" // command CompGenPacketStr = "compgen" // rpc + ReInitPacketStr = "reinit" // rpc ) const PacketSenderQueueSize = 20 @@ -111,10 +112,10 @@ func MakePacket(packetType string) (PacketType, error) { type ShellState struct { Version string `json:"version,omitempty"` Cwd string `json:"cwd,omitempty"` - ShellVars string `json:"shellvars,omitempty"` - Env0 []byte `json:"env0,omitempty"` + ShellVars []byte `json:"shellvars,omitempty"` Aliases string `json:"aliases,omitempty"` Funcs string `json:"funcs,omitempty"` + Error string `json:"error,omitempty"` } type CmdDataPacketType struct { @@ -445,6 +446,7 @@ func FmtMessagePacket(fmtStr string, args ...interface{}) *MessagePacketType { type InitPacketType struct { Type string `json:"type"` + RespId string `json:"respid,omitempty"` Version string `json:"version"` MShellHomeDir string `json:"mshellhomedir,omitempty"` HomeDir string `json:"homedir,omitempty"` @@ -460,6 +462,14 @@ func (*InitPacketType) GetType() string { return InitPacketStr } +func (pk *InitPacketType) GetResponseId() string { + return pk.RespId +} + +func (pk *InitPacketType) GetResponseDone() bool { + return true +} + func MakeInitPacket() *InitPacketType { return &InitPacketType{Type: InitPacketStr} } diff --git a/pkg/shexec/parser.go b/pkg/shexec/parser.go index 6f7eeef72..aafb39c37 100644 --- a/pkg/shexec/parser.go +++ b/pkg/shexec/parser.go @@ -4,8 +4,10 @@ import ( "bytes" "fmt" "io" + "regexp" "strings" + "github.com/alessio/shellescape" "github.com/scripthaus-dev/mshell/pkg/packet" "mvdan.cc/sh/v3/expand" "mvdan.cc/sh/v3/syntax" @@ -78,8 +80,6 @@ var NoStoreVarNames = map[string]bool{ "BASH_REMATCH": true, "BASH_SOURCE": true, "BASH_SUBSHELL": true, - "BASH_VERSINFO": true, - "BASH_VERSION": true, "COPROC": true, "DIRSTACK": true, "EPOCHREALTIME": true, @@ -100,61 +100,200 @@ var NoStoreVarNames = map[string]bool{ "HISTSIZE": true, "HISTTIMEFORMAT": true, "SRANDOM": true, + + // we want these in our remote state object + // "EUID": true, + // "SHELLOPTS": true, + // "UID": true, + // "BASH_VERSINFO": true, + // "BASH_VERSION": true, } -func parseDeclareStmt(envBuffer *bytes.Buffer, varsBuffer *bytes.Buffer, stmt *syntax.Stmt, src []byte) error { +type DeclareDeclType struct { + Args string + Name string + Value string +} + +var declareDeclArgsRe = regexp.MustCompile("^[aAxrifx]*$") +var bashValidIdentifierRe = regexp.MustCompile("^[a-zA-Z_][a-zA-Z0-9_]*$") + +func (d *DeclareDeclType) Validate() error { + if len(d.Name) == 0 || !IsValidBashIdentifier(d.Name) { + return fmt.Errorf("invalid shell variable name (invalid bash identifier)") + } + if strings.Index(d.Value, "\x00") >= 0 { + return fmt.Errorf("invalid shell variable value (cannot contain 0 byte)") + } + if !declareDeclArgsRe.MatchString(d.Args) { + return fmt.Errorf("invalid shell variable type %s", shellescape.Quote(d.Args)) + } + return nil +} + +func (d *DeclareDeclType) Serialize() string { + return fmt.Sprintf("%s|%s=%s\x00", d.Args, d.Name, d.Value) +} + +func (d *DeclareDeclType) EnvString() string { + return d.Name + "=" + d.Value +} + +func (d *DeclareDeclType) DeclareStmt() string { + var argsStr string + if d.Args == "" { + argsStr = "--" + } else { + argsStr = "-" + d.Args + } + return fmt.Sprintf("declare %s %s=%s", argsStr, d.Name, shellescape.Quote(d.Value)) +} + +// envline should be valid +func ParseDeclLine(envLine string) *DeclareDeclType { + eqIdx := strings.Index(envLine, "=") + if eqIdx == -1 { + return nil + } + namePart := envLine[0:eqIdx] + valPart := envLine[eqIdx+1:] + pipeIdx := strings.Index(namePart, "|") + if pipeIdx == -1 { + return nil + } + return &DeclareDeclType{ + Args: namePart[0:pipeIdx], + Name: namePart[pipeIdx+1:], + Value: valPart, + } +} + +func DeclMapFromState(state *packet.ShellState) map[string]*DeclareDeclType { + if state == nil { + return nil + } + rtn := make(map[string]*DeclareDeclType) + vars := bytes.Split(state.ShellVars, []byte{0}) + for _, varLine := range vars { + decl := ParseDeclLine(string(varLine)) + if decl != nil { + rtn[decl.Name] = decl + } + } + return rtn +} + +func SerializeDeclMap(declMap map[string]*DeclareDeclType) []byte { + var rtn bytes.Buffer + for _, decl := range declMap { + rtn.WriteString(decl.Serialize()) + } + return rtn.Bytes() +} + +func EnvMapFromState(state *packet.ShellState) map[string]string { + if state == nil { + return nil + } + rtn := make(map[string]string) + vars := bytes.Split(state.ShellVars, []byte{0}) + for _, varLine := range vars { + decl := ParseDeclLine(string(varLine)) + if decl != nil && decl.IsExport() { + rtn[decl.Name] = decl.Value + } + } + return rtn +} + +func ShellVarMapFromState(state *packet.ShellState) map[string]string { + if state == nil { + return nil + } + rtn := make(map[string]string) + vars := bytes.Split(state.ShellVars, []byte{0}) + for _, varLine := range vars { + decl := ParseDeclLine(string(varLine)) + if decl != nil { + rtn[decl.Name] = decl.Value + } + } + return rtn +} + +func VarDeclsFromState(state *packet.ShellState) []*DeclareDeclType { + if state == nil { + return nil + } + var rtn []*DeclareDeclType + vars := bytes.Split(state.ShellVars, []byte{0}) + for _, varLine := range vars { + decl := ParseDeclLine(string(varLine)) + if decl != nil { + rtn = append(rtn, decl) + } + } + return rtn +} + +func IsValidBashIdentifier(s string) bool { + return bashValidIdentifierRe.MatchString(s) +} + +func (d *DeclareDeclType) IsExport() bool { + return strings.Index(d.Args, "x") >= 0 +} + +func (d *DeclareDeclType) IsReadOnly() bool { + return strings.Index(d.Args, "r") >= 0 +} + +func parseDeclareStmt(stmt *syntax.Stmt, src []byte) (*DeclareDeclType, error) { cmd := stmt.Cmd decl, ok := cmd.(*syntax.DeclClause) if !ok || decl.Variant.Value != "declare" || len(decl.Args) != 2 { - return fmt.Errorf("invalid declare variant") + return nil, fmt.Errorf("invalid declare variant") } + rtn := &DeclareDeclType{} declArgs := decl.Args[0] if !declArgs.Naked || len(declArgs.Value.Parts) != 1 { - return fmt.Errorf("wrong number of declare args parts") + return nil, fmt.Errorf("wrong number of declare args parts") } - declArgLit, ok := declArgs.Value.Parts[0].(*syntax.Lit) + declArgsLit, ok := declArgs.Value.Parts[0].(*syntax.Lit) if !ok { - return fmt.Errorf("declare args is not a literal") + return nil, fmt.Errorf("declare args is not a literal") } - declArgStr := declArgLit.Value - if !strings.HasPrefix(declArgStr, "-") { - return fmt.Errorf("declare args not an argument (does not start with '-')") + if !strings.HasPrefix(declArgsLit.Value, "-") { + return nil, fmt.Errorf("declare args not an argument (does not start with '-')") + } + if declArgsLit.Value == "--" { + rtn.Args = "" + } else { + rtn.Args = declArgsLit.Value[1:] } declAssign := decl.Args[1] if declAssign.Name == nil { - return fmt.Errorf("declare does not have a valid name") + return nil, fmt.Errorf("declare does not have a valid name") } - varName := declAssign.Name.Value - if NoStoreVarNames[varName] { - return nil + rtn.Name = declAssign.Name.Value + if declAssign.Naked || declAssign.Index != nil || declAssign.Append { + return nil, fmt.Errorf("invalid decl format") } - if strings.Index(varName, "=") != -1 || strings.Index(varName, "\x00") != -1 { - return fmt.Errorf("invalid varname (cannot contain '=' or 0 byte)") + if declAssign.Value != nil { + varValueStr, err := QuotedLitToStr(declAssign.Value) + if err != nil { + return nil, fmt.Errorf("parsing declare value: %w", err) + } + rtn.Value = varValueStr + } else if declAssign.Array != nil { + rtn.Value = string(src[declAssign.Array.Pos().Offset():declAssign.Array.End().Offset()]) + } else { + return nil, fmt.Errorf("invalid decl, not plain value or array") } - fullDeclBytes := src[decl.Pos().Offset():decl.End().Offset()] - if strings.Index(declArgStr, "x") == -1 { - // non-exported vars get written to vars as decl statements - varsBuffer.Write(fullDeclBytes) - varsBuffer.WriteRune('\n') - return nil + if err := rtn.Validate(); err != nil { + return nil, err } - if declArgStr != "-x" { - return fmt.Errorf("can only export plain bash variables (no arrays)") - } - // exported vars are parsed into Env0 format - if declAssign.Naked || declAssign.Array != nil || declAssign.Index != nil || declAssign.Append || declAssign.Value == nil { - return fmt.Errorf("invalid variable to export") - } - varValue := declAssign.Value - varValueStr, err := QuotedLitToStr(varValue) - if err != nil { - return fmt.Errorf("parsing declare value: %w", err) - } - if strings.Index(varValueStr, "\x00") != -1 { - return fmt.Errorf("invalid export var value (cannot contain 0 byte)") - } - envBuffer.WriteString(fmt.Sprintf("%s=%s\x00", varName, varValueStr)) - return nil + return rtn, nil } func parseDeclareOutput(state *packet.ShellState, declareBytes []byte) error { @@ -164,16 +303,23 @@ func parseDeclareOutput(state *packet.ShellState, declareBytes []byte) error { if err != nil { return err } - var envBuffer, varsBuffer bytes.Buffer + var varsBuffer bytes.Buffer + var firstParseErr error for _, stmt := range file.Stmts { - err = parseDeclareStmt(&envBuffer, &varsBuffer, stmt, declareBytes) + decl, err := parseDeclareStmt(stmt, declareBytes) if err != nil { - // TODO where to put parse errors? - continue + if firstParseErr == nil { + firstParseErr = err + } + } + if decl != nil && !NoStoreVarNames[decl.Name] { + varsBuffer.WriteString(decl.Serialize()) } } - state.Env0 = envBuffer.Bytes() - state.ShellVars = varsBuffer.String() + state.ShellVars = varsBuffer.Bytes() + if firstParseErr != nil { + state.Error = firstParseErr.Error() + } return nil } @@ -193,7 +339,10 @@ func ParseShellStateOutput(outputBytes []byte) (*packet.ShellState, error) { cwdStr = cwdStr[0 : len(cwdStr)-2] } rtn.Cwd = string(cwdStr) - parseDeclareOutput(rtn, fields[2]) + err := parseDeclareOutput(rtn, fields[2]) + if err != nil { + return nil, err + } rtn.Aliases = strings.ReplaceAll(string(fields[3]), "\r\n", "\n") rtn.Funcs = strings.ReplaceAll(string(fields[4]), "\r\n", "\n") return rtn, nil diff --git a/pkg/shexec/shexec.go b/pkg/shexec/shexec.go index e7c8818e0..7d29ccdc9 100644 --- a/pkg/shexec/shexec.go +++ b/pkg/shexec/shexec.go @@ -282,7 +282,7 @@ func MakeDetachedExecCmd(pk *packet.RunPacketType, cmdTty *os.File) (*exec.Cmd, if !pk.StateComplete { ecmd.Env = os.Environ() } - UpdateCmdEnv(ecmd, ParseEnv0(state.Env0)) + UpdateCmdEnv(ecmd, EnvMapFromState(state)) UpdateCmdEnv(ecmd, map[string]string{"TERM": getTermType(pk)}) if state.Cwd != "" { ecmd.Dir = base.ExpandHomeDir(state.Cwd) @@ -964,26 +964,38 @@ func getTermType(pk *packet.RunPacketType) string { } func makeRcFileStr(pk *packet.RunPacketType) string { - rcFileStr := ` + var rcBuf bytes.Buffer + rcBuf.WriteString(` set +m set +H shopt -s extglob -` +`) + + varDecls := VarDeclsFromState(pk.State) + for _, varDecl := range varDecls { + if varDecl.IsExport() || varDecl.IsReadOnly() { + continue + } + rcBuf.WriteString(varDecl.DeclareStmt()) + rcBuf.WriteString("\n") + } if pk.State != nil && pk.State.Funcs != "" { - rcFileStr += pk.State.Funcs + "\n" + rcBuf.WriteString(pk.State.Funcs) + rcBuf.WriteString("\n") } if pk.State != nil && pk.State.Aliases != "" { - rcFileStr += pk.State.Aliases + "\n" + rcBuf.WriteString(pk.State.Aliases) + rcBuf.WriteString("\n") } if pk.ReturnState { - rcFileStr += ` + rcBuf.WriteString(` _scripthaus_exittrap () { ` + GetShellStateCmd + ` } trap _scripthaus_exittrap EXIT -` +`) } - return rcFileStr + return rcBuf.String() } func makeExitTrap(fdNum int) string { @@ -1042,7 +1054,7 @@ func RunCommandSimple(pk *packet.RunPacketType, sender *packet.PacketSender, fro if !pk.StateComplete { cmd.Cmd.Env = os.Environ() } - UpdateCmdEnv(cmd.Cmd, ParseEnv0(state.Env0)) + UpdateCmdEnv(cmd.Cmd, EnvMapFromState(state)) if state.Cwd != "" { cmd.Cmd.Dir = base.ExpandHomeDir(state.Cwd) }