returnstate option for runpk (for sourcing files)

This commit is contained in:
sawka 2022-10-22 14:45:31 -07:00
parent b9c3940b99
commit d8b5508b77
3 changed files with 163 additions and 44 deletions

View File

@ -480,6 +480,7 @@ type CmdDonePacketType struct {
CK base.CommandKey `json:"ck"`
ExitCode int `json:"exitcode"`
DurationMs int64 `json:"durationms"`
FinalState *ShellState `json:"state,omitempty"`
}
func (*CmdDonePacketType) GetType() string {
@ -551,6 +552,7 @@ type RunPacketType struct {
Fds []RemoteFd `json:"fds,omitempty"`
RunData []RunDataType `json:"rundata,omitempty"`
Detached bool `json:"detached,omitempty"`
ReturnState bool `json:"returnstate,omitempty"`
}
func (*RunPacketType) GetType() string {

View File

@ -14,6 +14,8 @@ import (
// TODO - track buffer sizes for sending input
const NotFoundVersion = "v0.0"
type ClientProc struct {
Cmd *exec.Cmd
InitPk *packet.InitPacketType
@ -25,24 +27,24 @@ type ClientProc struct {
Output *packet.PacketParser
}
// returns (clientproc, uname, error)
func MakeClientProc(ctx context.Context, ecmd *exec.Cmd) (*ClientProc, string, error) {
// returns (clientproc, initpk, error)
func MakeClientProc(ctx context.Context, ecmd *exec.Cmd) (*ClientProc, *packet.InitPacketType, error) {
inputWriter, err := ecmd.StdinPipe()
if err != nil {
return nil, "", fmt.Errorf("creating stdin pipe: %v", err)
return nil, nil, fmt.Errorf("creating stdin pipe: %v", err)
}
stdoutReader, err := ecmd.StdoutPipe()
if err != nil {
return nil, "", fmt.Errorf("creating stdout pipe: %v", err)
return nil, nil, fmt.Errorf("creating stdout pipe: %v", err)
}
stderrReader, err := ecmd.StderrPipe()
if err != nil {
return nil, "", fmt.Errorf("creating stderr pipe: %v", err)
return nil, nil, fmt.Errorf("creating stderr pipe: %v", err)
}
startTs := time.Now()
err = ecmd.Start()
if err != nil {
return nil, "", fmt.Errorf("running local client: %w", err)
return nil, nil, fmt.Errorf("running local client: %w", err)
}
sender := packet.MakePacketSender(inputWriter)
stdoutPacketParser := packet.MakePacketParser(stdoutReader)
@ -63,29 +65,29 @@ func MakeClientProc(ctx context.Context, ecmd *exec.Cmd) (*ClientProc, string, e
case pk = <-packetParser.MainCh:
case <-ctx.Done():
cproc.Close()
return nil, "", ctx.Err()
return nil, nil, ctx.Err()
}
if pk != nil {
if pk.GetType() != packet.InitPacketStr {
cproc.Close()
return nil, "", fmt.Errorf("invalid packet received from mshell client: %s", packet.AsString(pk))
return nil, nil, fmt.Errorf("invalid packet received from mshell client: %s", packet.AsString(pk))
}
initPk := pk.(*packet.InitPacketType)
if initPk.NotFound {
cproc.Close()
return nil, initPk.UName, fmt.Errorf("mshell-%s command not found on local server", semver.MajorMinor(base.MShellVersion))
return nil, initPk, fmt.Errorf("mshell-%s command not found on local server", semver.MajorMinor(base.MShellVersion))
}
if semver.MajorMinor(initPk.Version) != semver.MajorMinor(base.MShellVersion) {
cproc.Close()
return nil, initPk.UName, fmt.Errorf("invalid remote mshell version '%s', must be '=%s'", initPk.Version, semver.MajorMinor(base.MShellVersion))
return nil, initPk, fmt.Errorf("invalid remote mshell version '%s', must be '=%s'", initPk.Version, semver.MajorMinor(base.MShellVersion))
}
cproc.InitPk = initPk
}
if cproc.InitPk == nil {
cproc.Close()
return nil, "", fmt.Errorf("no init packet received from mshell client")
return nil, nil, fmt.Errorf("no init packet received from mshell client")
}
return cproc, cproc.InitPk.UName, nil
return cproc, cproc.InitPk, nil
}
func (cproc *ClientProc) Close() {

View File

@ -18,6 +18,7 @@ import (
"os/user"
"runtime"
"strings"
"sync"
"syscall"
"time"
@ -81,6 +82,20 @@ const RunCommandFmt = `%s`
const RunSudoCommandFmt = `sudo -n -C %d bash /dev/fd/%d`
const RunSudoPasswordCommandFmt = `cat /dev/fd/%d | sudo -k -S -C %d bash -c "echo '[from-mshell]'; exec %d>&-; bash /dev/fd/%d < /dev/fd/%d"`
type ReturnStateBuf struct {
Lock *sync.Mutex
Buf []byte
Done bool
Err error
Reader *os.File
FdNum int
DoneCh chan bool
}
func MakeReturnStateBuf() *ReturnStateBuf {
return &ReturnStateBuf{Lock: &sync.Mutex{}, DoneCh: make(chan bool)}
}
type ShExecType struct {
StartTs time.Time
CK base.CommandKey
@ -93,6 +108,7 @@ type ShExecType struct {
DetachedOutput *packet.PacketSender
RunnerOutFd *os.File
MsgSender *packet.PacketSender // where to send out-of-band messages back to calling proceess
ReturnState *ReturnStateBuf
}
type StdContext struct{}
@ -192,6 +208,9 @@ func (c *ShExecType) Close() {
if c.RunnerOutFd != nil {
c.RunnerOutFd.Close()
}
if c.ReturnState != nil {
c.ReturnState.Reader.Close()
}
}
func (c *ShExecType) MakeCmdStartPacket(reqId string) *packet.CmdStartPacketType {
@ -926,6 +945,9 @@ func DetectGoArch(uname string) (string, string, error) {
func (cmd *ShExecType) RunRemoteIOAndWait(packetParser *packet.PacketParser, sender *packet.PacketSender) {
defer cmd.Close()
if cmd.ReturnState != nil {
go cmd.ReturnState.Run()
}
cmd.Multiplexer.RunIOAndWait(packetParser, sender, true, false, false)
donePacket := cmd.WaitForCommand()
sender.SendPacket(donePacket)
@ -939,42 +961,87 @@ func getTermType(pk *packet.RunPacketType) string {
return termType
}
func makeEnvCommandStr(pk *packet.RunPacketType) string {
fmtStr := `
shopt -q -s expand_aliases
func makeRcFileStr(pk *packet.RunPacketType) string {
rcFileStr := `
set +m
%s
%s
%s
set +H
shopt -s extglob
`
state := pk.State
if state == nil {
state = &packet.ShellState{}
if pk.State != nil && pk.State.Funcs != "" {
rcFileStr += pk.State.Funcs + "\n"
}
return fmt.Sprintf(fmtStr, state.Aliases, state.Funcs, pk.Command)
if pk.State != nil && pk.State.Aliases != "" {
rcFileStr += pk.State.Aliases + "\n"
}
if pk.ReturnState {
rcFileStr += `
_scripthaus_exittrap () {
%s --env; alias -p; printf \"\\x00\\x00\"; declare -f;
}
trap _scripthaus_exittrap EXIT
`
}
return rcFileStr
}
func RunCommandSimple(pk *packet.RunPacketType, sender *packet.PacketSender, fromServer bool) (*ShExecType, error) {
func makeExitTrap(fdNum int) (string, error) {
stateCmd, err := GetShellStateRedirectCommandStr(fdNum)
if err != nil {
return "", err
}
fmtStr := `
_scripthaus_exittrap () {
%s
}
trap _scripthaus_exittrap EXIT
`
return fmt.Sprintf(fmtStr, stateCmd), nil
}
func RunCommandSimple(pk *packet.RunPacketType, sender *packet.PacketSender, fromServer bool) (rtnShExec *ShExecType, rtnErr error) {
state := pk.State
if state == nil {
state = &packet.ShellState{}
}
cmd := MakeShExec(pk.CK, nil)
defer func() {
// on error, call cmd.Close()
if rtnErr != nil {
cmd.Close()
}
}()
if fromServer {
msgUpr := packet.MessageUPR{CK: pk.CK, Sender: sender}
upr := ShExecUPR{ShExec: cmd, UPR: msgUpr}
cmd.Multiplexer.UPR = upr
cmd.MsgSender = sender
}
commandStr := makeEnvCommandStr(pk)
commandFdNum, err := AddRunData(pk, commandStr, "command")
var rtnStateWriter *os.File
rcFileStr := makeRcFileStr(pk)
if pk.ReturnState {
pr, pw, err := os.Pipe()
if err != nil {
return nil, fmt.Errorf("cannot create returnstate pipe: %v", err)
}
cmd.ReturnState = MakeReturnStateBuf()
cmd.ReturnState.Reader = pr
cmd.ReturnState.FdNum = 20
rtnStateWriter = pw
defer pw.Close()
trapCmdStr, err := makeExitTrap(cmd.ReturnState.FdNum)
if err != nil {
return nil, err
}
rcFileStr += trapCmdStr
}
rcFileFdNum, err := AddRunData(pk, rcFileStr, "rcfile")
if err != nil {
return nil, err
}
if pk.UsePty {
cmd.Cmd = exec.Command("bash", "-i", fmt.Sprintf("/dev/fd/%d", commandFdNum))
cmd.Cmd = exec.Command("bash", "--rcfile", fmt.Sprintf("/dev/fd/%d", rcFileFdNum), "-i", "-c", pk.Command)
} else {
cmd.Cmd = exec.Command("bash", fmt.Sprintf("/dev/fd/%d", commandFdNum))
cmd.Cmd = exec.Command("bash", "--rcfile", fmt.Sprintf("/dev/fd/%d", rcFileFdNum), "-c", pk.Command)
}
if !pk.StateComplete {
cmd.Cmd.Env = os.Environ()
@ -985,7 +1052,6 @@ func RunCommandSimple(pk *packet.RunPacketType, sender *packet.PacketSender, fro
}
err = ValidateRemoteFds(pk.Fds)
if err != nil {
cmd.Close()
return nil, err
}
var cmdPty *os.File
@ -1014,24 +1080,20 @@ func RunCommandSimple(pk *packet.RunPacketType, sender *packet.PacketSender, fro
cmd.Multiplexer.MakeRawFdReader(1, cmdPty, false, true)
nullFd, err := os.Open("/dev/null")
if err != nil {
cmd.Close()
return nil, fmt.Errorf("cannot open /dev/null: %w", err)
}
cmd.Multiplexer.MakeRawFdReader(2, nullFd, true, false)
} else {
cmd.Cmd.Stdin, err = cmd.Multiplexer.MakeWriterPipe(0)
if err != nil {
cmd.Close()
return nil, err
}
cmd.Cmd.Stdout, err = cmd.Multiplexer.MakeReaderPipe(1)
if err != nil {
cmd.Close()
return nil, err
}
cmd.Cmd.Stderr, err = cmd.Multiplexer.MakeReaderPipe(2)
if err != nil {
cmd.Close()
return nil, err
}
}
@ -1042,7 +1104,6 @@ func RunCommandSimple(pk *packet.RunPacketType, sender *packet.PacketSender, fro
}
extraFiles[runData.FdNum], err = cmd.Multiplexer.MakeStaticWriterPipe(runData.FdNum, runData.Data)
if err != nil {
cmd.Close()
return nil, err
}
}
@ -1054,7 +1115,6 @@ func RunCommandSimple(pk *packet.RunPacketType, sender *packet.PacketSender, fro
// client file is open for reading, so we make a writer pipe
extraFiles[rfd.FdNum], err = cmd.Multiplexer.MakeWriterPipe(rfd.FdNum)
if err != nil {
cmd.Close()
return nil, err
}
}
@ -1062,23 +1122,53 @@ func RunCommandSimple(pk *packet.RunPacketType, sender *packet.PacketSender, fro
// client file is open for writing, so we make a reader pipe
extraFiles[rfd.FdNum], err = cmd.Multiplexer.MakeReaderPipe(rfd.FdNum)
if err != nil {
cmd.Close()
return nil, err
}
}
}
if cmd.ReturnState != nil {
if cmd.ReturnState.FdNum >= len(extraFiles) {
extraFiles = extraFiles[:cmd.ReturnState.FdNum+1]
}
extraFiles[cmd.ReturnState.FdNum] = rtnStateWriter
}
if len(extraFiles) > FirstExtraFilesFdNum {
cmd.Cmd.ExtraFiles = extraFiles[FirstExtraFilesFdNum:]
}
err = cmd.Cmd.Start()
if err != nil {
cmd.Close()
return nil, err
}
return cmd, nil
}
// TODO limit size of read state buffer
func (rs *ReturnStateBuf) Run() {
buf := make([]byte, 1024)
defer func() {
rs.Lock.Lock()
defer rs.Lock.Unlock()
rs.Reader.Close()
rs.Done = true
close(rs.DoneCh)
}()
for {
n, readErr := rs.Reader.Read(buf)
if readErr == io.EOF {
break
}
if readErr != nil {
rs.Lock.Lock()
rs.Err = readErr
rs.Lock.Unlock()
break
}
rs.Lock.Lock()
rs.Buf = append(rs.Buf, buf[0:n]...)
rs.Lock.Unlock()
}
}
// in detached run mode, we don't want mshell to die from signals
// since we want mshell to persist even if the mshell --server is terminated
func SetupSignalsForDetach() {
@ -1220,11 +1310,16 @@ func GetExitCode(err error) int {
}
func (c *ShExecType) WaitForCommand() *packet.CmdDonePacketType {
donePacket := packet.MakeCmdDonePacket(c.CK)
exitErr := c.Cmd.Wait()
if c.ReturnState != nil {
<-c.ReturnState.DoneCh
state, _ := ParseShellStateOutput(c.ReturnState.Buf) // TODO what to do with error?
donePacket.FinalState = state
}
endTs := time.Now()
cmdDuration := endTs.Sub(c.StartTs)
exitCode := GetExitCode(exitErr)
donePacket := packet.MakeCmdDonePacket(c.CK)
donePacket.Ts = endTs.UnixMilli()
donePacket.ExitCode = exitCode
donePacket.DurationMs = int64(cmdDuration / time.Millisecond)
@ -1343,17 +1438,23 @@ func runSimpleCmdInPty(ecmd *exec.Cmd) ([]byte, error) {
return outputBuf.Bytes(), nil
}
func GetShellState() (*packet.ShellState, error) {
func GetShellStateCommandStr() (string, error) {
execFile, err := os.Executable()
if err != nil {
return nil, fmt.Errorf("cannot find local mshell executable: %w", err)
return "", fmt.Errorf("cannot find local mshell executable: %w", err)
}
ctx, _ := context.WithTimeout(context.Background(), GetStateTimeout)
ecmd := exec.CommandContext(ctx, "bash", "-l", "-i", "-c", fmt.Sprintf("%s --env; alias -p; printf \"\\x00\\x00\"; declare -f", shellescape.Quote(execFile)))
outputBytes, err := runSimpleCmdInPty(ecmd)
return fmt.Sprintf(`%s --env; alias -p; printf \"\\x00\\x00\"; declare -f`, shellescape.Quote(execFile)), nil
}
func GetShellStateRedirectCommandStr(outputFdNum int) (string, error) {
cmdStr, err := GetShellStateCommandStr()
if err != nil {
return nil, err
return "", err
}
return fmt.Sprintf("cat <(%s) > /dev/fd/%d", cmdStr, outputFdNum), nil
}
func ParseShellStateOutput(outputBytes []byte) (*packet.ShellState, error) {
fields := bytes.Split(outputBytes, []byte{0, 0})
if len(fields) != 4 {
return nil, fmt.Errorf("invalid shell state output, wrong number of fields, fields=%d", len(fields))
@ -1367,3 +1468,17 @@ func GetShellState() (*packet.ShellState, error) {
rtn.Funcs = strings.ReplaceAll(string(fields[3]), "\r\n", "\n")
return rtn, nil
}
func GetShellState() (*packet.ShellState, error) {
ctx, _ := context.WithTimeout(context.Background(), GetStateTimeout)
cmdStr, err := GetShellStateCommandStr()
if err != nil {
return nil, err
}
ecmd := exec.CommandContext(ctx, "bash", "-l", "-i", "-c", cmdStr)
outputBytes, err := runSimpleCmdInPty(ecmd)
if err != nil {
return nil, err
}
return ParseShellStateOutput(outputBytes)
}