diff --git a/main-runner.go b/main-runner.go index 04fb9f4a6..19b0ba3c2 100644 --- a/main-runner.go +++ b/main-runner.go @@ -11,23 +11,30 @@ import ( "os" "os/signal" "syscall" + "time" + "github.com/google/uuid" + "github.com/scripthaus-dev/sh2-runner/pkg/base" "github.com/scripthaus-dev/sh2-runner/pkg/packet" "github.com/scripthaus-dev/sh2-runner/pkg/shexec" ) -func setupSignals(cmd *shexec.ShExecType) { +// in single run mode, we don't want the runner to die from signals +// since we want the single runner to persist even if session / main runner +// is terminated. +func setupSingleSignals(cmd *shexec.ShExecType) { sigCh := make(chan os.Signal, 1) - signal.Notify(sigCh, syscall.SIGINT, syscall.SIGTERM) + signal.Notify(sigCh, syscall.SIGINT, syscall.SIGTERM, syscall.SIGHUP) go func() { - for sig := range sigCh { - cmd.Cmd.Process.Signal(sig) + for range sigCh { + // do nothing } }() } -func main() { +func doSingle(cmdId string) { packetCh := packet.PacketParser(os.Stdin) + sender := packet.MakePacketSender(os.Stdout) var runPacket *packet.RunPacketType for pk := range packetCh { if pk.GetType() == packet.PingPacketStr { @@ -37,24 +44,134 @@ func main() { runPacket, _ = pk.(*packet.RunPacketType) break } - if pk.GetType() == packet.ErrorPacketStr { - packet.SendPacket(os.Stdout, pk) - return - } - packet.SendErrorPacket(os.Stdout, fmt.Sprintf("invalid packet '%s' sent to runner", pk.GetType())) + sender.SendErrorPacket(fmt.Sprintf("invalid packet '%s' sent to runner", pk.GetType())) return } if runPacket == nil { - packet.SendErrorPacket(os.Stdout, "did not receive a 'run' packet") + sender.SendErrorPacket("did not receive a 'run' packet") return } - cmd, err := shexec.RunCommand(runPacket) + if runPacket.CmdId == "" { + runPacket.CmdId = cmdId + } + if runPacket.CmdId != cmdId { + sender.SendErrorPacket(fmt.Sprintf("run packet cmdid[%s] did not match arg[%s]", runPacket.CmdId, cmdId)) + return + } + cmd, err := shexec.RunCommand(runPacket, sender) if err != nil { - packet.SendErrorPacket(os.Stdout, fmt.Sprintf("error running command: %v", err)) + sender.SendErrorPacket(fmt.Sprintf("error running command: %v", err)) return } - setupSignals(cmd) - packet.SendPacket(os.Stdout, packet.MakeOkCmdPacket(fmt.Sprintf("running command %s/%s", runPacket.SessionId, runPacket.CmdId), runPacket.CmdId, cmd.Cmd.Process.Pid)) - cmd.WaitForCommand() - packet.SendPacket(os.Stdout, packet.MakeDonePacket()) + setupSingleSignals(cmd) + startPacket := packet.MakeCmdStartPacket() + startPacket.Ts = time.Now().UnixMilli() + startPacket.CmdId = runPacket.CmdId + startPacket.Pid = cmd.Cmd.Process.Pid + startPacket.RunnerPid = os.Getpid() + sender.SendPacket(startPacket) + donePacket := cmd.WaitForCommand(runPacket.CmdId) + sender.SendPacket(donePacket) + sender.CloseSendCh() + sender.WaitForDone() +} + +func doMainRun(pk *packet.RunPacketType, sender *packet.PacketSender) { + if pk.CmdId == "" { + pk.CmdId = uuid.New().String() + } + err := shexec.ValidateRunPacket(pk) + if err != nil { + sender.SendPacket(packet.MakeIdErrorPacket(pk.CmdId, fmt.Sprintf("invalid run packet: %v", err))) + return + } + fileNames, err := base.GetCommandFileNames(pk.SessionId, pk.CmdId) + if err != nil { + sender.SendPacket(packet.MakeIdErrorPacket(pk.CmdId, fmt.Sprintf("cannot get command file names: %v", err))) + return + } + cmd, err := shexec.MakeRunnerExec(pk.CmdId) + if err != nil { + sender.SendPacket(packet.MakeIdErrorPacket(pk.CmdId, fmt.Sprintf("cannot make runner command: %v", err))) + return + } + cmdStdin, err := cmd.StdinPipe() + if err != nil { + sender.SendPacket(packet.MakeIdErrorPacket(pk.CmdId, fmt.Sprintf("cannot pipe stdin to command: %v", err))) + return + } + runnerOutFd, err := os.OpenFile(fileNames.RunnerOutFile, os.O_CREATE|os.O_TRUNC|os.O_APPEND|os.O_WRONLY, 0600) + if err != nil { + sender.SendPacket(packet.MakeIdErrorPacket(pk.CmdId, fmt.Sprintf("cannot open runner out file '%s': %v", fileNames.RunnerOutFile, err))) + return + } + cmd.Stdout = runnerOutFd + cmd.Stderr = runnerOutFd + err = cmd.Start() + if err != nil { + sender.SendPacket(packet.MakeIdErrorPacket(pk.CmdId, fmt.Sprintf("error starting command: %v", err))) + return + } + go func() { + err = packet.SendPacket(cmdStdin, pk) + if err != nil { + sender.SendPacket(packet.MakeIdErrorPacket(pk.CmdId, fmt.Sprintf("error sending forked runner command: %v", err))) + return + } + cmdStdin.Close() + + // clean up zombies + cmd.Wait() + }() +} + +func doMain() { + homeDir, err := base.GetScHomeDir() + if err != nil { + packet.SendErrorPacket(os.Stdout, err.Error()) + return + } + err = os.Chdir(homeDir) + if err != nil { + packet.SendErrorPacket(os.Stdout, fmt.Sprintf("cannot change directory to scripthaus home '%s': %v", homeDir, err)) + return + } + err = base.EnsureRunnerPath() + if err != nil { + packet.SendErrorPacket(os.Stdout, err.Error()) + return + } + packetCh := packet.PacketParser(os.Stdin) + sender := packet.MakePacketSender(os.Stdout) + sender.SendPacket(packet.MakeMessagePacket(fmt.Sprintf("starting scripthaus runner @ %s", homeDir))) + for pk := range packetCh { + if pk.GetType() == packet.PingPacketStr { + continue + } + if pk.GetType() == packet.RunPacketStr { + doMainRun(pk.(*packet.RunPacketType), sender) + continue + } + if pk.GetType() == packet.ErrorPacketStr { + errPk := pk.(*packet.ErrorPacketType) + errPk.Error = "invalid packet sent to runner: " + errPk.Error + sender.SendPacket(errPk) + continue + } + sender.SendErrorPacket(fmt.Sprintf("invalid packet '%s' sent to runner", pk.GetType())) + } +} + +func main() { + if len(os.Args) >= 2 { + cmdId, err := uuid.Parse(os.Args[1]) + if err != nil { + packet.SendErrorPacket(os.Stdout, fmt.Sprintf("invalid non-cmdid passed to runner", err)) + return + } + doSingle(cmdId.String()) + return + } else { + doMain() + } } diff --git a/pkg/base/base.go b/pkg/base/base.go index ae6774cec..fd2d64ed0 100644 --- a/pkg/base/base.go +++ b/pkg/base/base.go @@ -27,9 +27,9 @@ const ScReadyString = "scripthaus runner ready" const OSCEscError = "error" type CommandFileNames struct { - PtyOutFile string - StdinFifo string - DoneFile string + PtyOutFile string + StdinFifo string + RunnerOutFile string } func GetScHomeDir() (string, error) { @@ -54,9 +54,9 @@ func GetCommandFileNames(sessionId string, cmdId string) (*CommandFileNames, err } base := path.Join(sdir, cmdId) return &CommandFileNames{ - PtyOutFile: base + ".ptyout", - StdinFifo: base + ".stdin", - DoneFile: base + ".done", + PtyOutFile: base + ".ptyout", + StdinFifo: base + ".stdin", + RunnerOutFile: base + ".runout", }, nil } @@ -108,32 +108,50 @@ func EnsureSessionDir(sessionId string) (string, error) { return sdir, nil } -func GetScRunnerPath() string { +func GetScRunnerPath() (string, error) { runnerPath := os.Getenv(ScRunnerVarName) if runnerPath != "" { - return runnerPath + return runnerPath, nil } scHome, err := GetScHomeDir() if err != nil { - panic(err) + return "", err } - return path.Join(scHome, RunnerBaseName) + return path.Join(scHome, RunnerBaseName), nil } -func GetScSessionsDir() string { - scHome, err := GetScHomeDir() +func EnsureRunnerPath() error { + runnerPath, err := GetScRunnerPath() if err != nil { - panic(err) + return err } - return path.Join(scHome, SessionsDirBaseName) + info, err := os.Stat(runnerPath) + if err != nil { + if errors.Is(err, fs.ErrNotExist) { + return fmt.Errorf("cannot find scripthaus runner at path '%s'", runnerPath) + } + return fmt.Errorf("error stating scripthaus runner at path '%s'", runnerPath) + } + if info.Mode()&0100 == 0 { + return fmt.Errorf("scripthaus runner at path '%s' is not executable mode=%#o", runnerPath, info.Mode()) + } + return nil } -func GetSessionDBName(sessionId string) string { +func GetScSessionsDir() (string, error) { scHome, err := GetScHomeDir() if err != nil { - panic(err) + return "", err } - return path.Join(scHome, SessionDBName) + return path.Join(scHome, SessionsDirBaseName), nil +} + +func GetSessionDBName(sessionId string) (string, error) { + scHome, err := GetScHomeDir() + if err != nil { + return "", err + } + return path.Join(scHome, SessionDBName), nil } // SH OSC Escapes (code 198, S=19, H=8) diff --git a/pkg/packet/packet.go b/pkg/packet/packet.go index 52e3a0da7..623284673 100644 --- a/pkg/packet/packet.go +++ b/pkg/packet/packet.go @@ -11,13 +11,16 @@ import ( "encoding/json" "fmt" "io" + "sync" ) const RunPacketStr = "run" const PingPacketStr = "ping" const DonePacketStr = "done" const ErrorPacketStr = "error" -const OkCmdPacketStr = "okcmd" +const MessagePacketStr = "message" +const CmdStartPacketStr = "cmdstart" +const CmdDonePacketStr = "cmddone" type PingPacketType struct { Type string `json:"type"` @@ -31,6 +34,19 @@ func MakePingPacket() *PingPacketType { return &PingPacketType{Type: PingPacketStr} } +type MessagePacketType struct { + Type string `json:"type"` + Message string `json:"message"` +} + +func (*MessagePacketType) GetType() string { + return MessagePacketStr +} + +func MakeMessagePacket(message string) *MessagePacketType { + return &MessagePacketType{Type: MessagePacketStr, Message: message} +} + type DonePacketType struct { Type string `json:"type"` } @@ -43,27 +59,44 @@ func MakeDonePacket() *DonePacketType { return &DonePacketType{Type: DonePacketStr} } -type OkCmdPacketType struct { - Type string `json:"type"` - Message string `json:"message"` - CmdId string `json:"cmdid"` - Pid int `json:"pid"` +type CmdDonePacketType struct { + Type string `json:"type"` + Ts int64 `json:"ts"` + CmdId string `json:"cmdid"` + ExitCode int `json:"exitcode"` + DurationMs int64 `json:"durationms"` } -func (*OkCmdPacketType) GetType() string { - return OkCmdPacketStr +func (*CmdDonePacketType) GetType() string { + return CmdDonePacketStr } -func MakeOkCmdPacket(message string, cmdId string, pid int) *OkCmdPacketType { - return &OkCmdPacketType{Type: OkCmdPacketStr, Message: message, CmdId: cmdId, Pid: pid} +func MakeCmdDonePacket() *CmdDonePacketType { + return &CmdDonePacketType{Type: CmdDonePacketStr} +} + +type CmdStartPacketType struct { + Type string `json:"type"` + Ts int64 `json:"ts"` + CmdId string `json:"cmdid"` + Pid int `json:"pid"` + RunnerPid int `json:"runnerpid"` +} + +func (*CmdStartPacketType) GetType() string { + return CmdStartPacketStr +} + +func MakeCmdStartPacket() *CmdStartPacketType { + return &CmdStartPacketType{Type: CmdStartPacketStr} } type RunPacketType struct { Type string `json:"type"` SessionId string `json:"sessionid"` CmdId string `json:"cmdid"` - ChDir string `json:"chdir"` - Env map[string]string `json:"env"` + ChDir string `json:"chdir,omitempty"` + Env map[string]string `json:"env,omitempty"` Command string `json:"command"` } @@ -76,6 +109,7 @@ type BarePacketType struct { } type ErrorPacketType struct { + Id string `json:"id,omitempty"` Type string `json:"type"` Error string `json:"error"` } @@ -88,6 +122,10 @@ func MakeErrorPacket(errorStr string) *ErrorPacketType { return &ErrorPacketType{Type: ErrorPacketStr, Error: errorStr} } +func MakeIdErrorPacket(id string, errorStr string) *ErrorPacketType { + return &ErrorPacketType{Type: ErrorPacketStr, Id: id, Error: errorStr} +} + type PacketType interface { GetType() string } @@ -123,13 +161,21 @@ func ParseJsonPacket(jsonBuf []byte) (PacketType, error) { } return &errorPacket, nil } - if bareCmd.Type == OkCmdPacketStr { - var okPacket OkCmdPacketType - err = json.Unmarshal(jsonBuf, &okPacket) + if bareCmd.Type == CmdStartPacketStr { + var startPacket CmdStartPacketType + err = json.Unmarshal(jsonBuf, &startPacket) if err != nil { return nil, err } - return &okPacket, nil + return &startPacket, nil + } + if bareCmd.Type == CmdDonePacketStr { + var donePacket CmdDonePacketType + err = json.Unmarshal(jsonBuf, &donePacket) + if err != nil { + return nil, err + } + return &donePacket, nil } return nil, fmt.Errorf("invalid packet-type '%s'", bareCmd.Type) } @@ -154,6 +200,73 @@ func SendErrorPacket(w io.Writer, errorStr string) error { return SendPacket(w, MakeErrorPacket(errorStr)) } +type PacketSender struct { + Lock *sync.Mutex + SendCh chan PacketType + Err error + Done bool + DoneCh chan bool +} + +func MakePacketSender(output io.Writer) *PacketSender { + sender := &PacketSender{ + Lock: &sync.Mutex{}, + SendCh: make(chan PacketType), + DoneCh: make(chan bool), + } + go func() { + defer func() { + sender.Lock.Lock() + sender.Done = true + sender.Lock.Unlock() + close(sender.DoneCh) + }() + for pk := range sender.SendCh { + err := SendPacket(output, pk) + if err != nil { + sender.Lock.Lock() + sender.Err = err + sender.Lock.Unlock() + return + } + } + }() + return sender +} + +func (sender *PacketSender) CloseSendCh() { + close(sender.SendCh) +} + +func (sender *PacketSender) WaitForDone() { + <-sender.DoneCh +} + +func (sender *PacketSender) checkStatus() error { + sender.Lock.Lock() + defer sender.Lock.Unlock() + if sender.Done { + return fmt.Errorf("cannot send packet, sender write loop is closed") + } + if sender.Err != nil { + return fmt.Errorf("cannot send packet, sender had error: %w", sender.Err) + } + return nil +} + +func (sender *PacketSender) SendPacket(pk PacketType) error { + err := sender.checkStatus() + if err != nil { + return err + } + sender.SendCh <- pk + return nil +} + +func (sender *PacketSender) SendErrorPacket(errVal string) error { + return sender.SendPacket(MakeErrorPacket(errVal)) +} + func PacketParser(input io.Reader) chan PacketType { bufReader := bufio.NewReader(input) rtnCh := make(chan PacketType) diff --git a/pkg/shexec/shexec.go b/pkg/shexec/shexec.go index 4ef5f6607..70e75f1d4 100644 --- a/pkg/shexec/shexec.go +++ b/pkg/shexec/shexec.go @@ -7,7 +7,6 @@ package shexec import ( - "encoding/json" "errors" "fmt" "io" @@ -24,11 +23,6 @@ import ( "github.com/scripthaus-dev/sh2-runner/pkg/packet" ) -type DoneData struct { - DurationMs int64 `json:"durationms"` - ExitCode int `json:"exitcode"` -} - type ShExecType struct { FileNames *base.CommandFileNames Cmd *exec.Cmd @@ -95,6 +89,15 @@ func MakeExecCmd(pk *packet.RunPacketType, cmdTty *os.File) *exec.Cmd { return ecmd } +func MakeRunnerExec(cmdId string) (*exec.Cmd, error) { + runnerPath, err := base.GetScRunnerPath() + if err != nil { + return nil, err + } + ecmd := exec.Command(runnerPath, cmdId) + return ecmd, nil +} + // this will never return (unless there is an error creating/opening the file), as fifoFile will never EOF func MakeAndCopyStdinFifo(dst *os.File, fifoName string) error { os.Remove(fifoName) @@ -147,8 +150,8 @@ func ValidateRunPacket(pk *packet.RunPacketType) error { return nil } -// returning nil error means the process has successfully been kicked-off -func RunCommand(pk *packet.RunPacketType) (*ShExecType, error) { +// when err is nil, the command will have already been started +func RunCommand(pk *packet.RunPacketType, sender *packet.PacketSender) (*ShExecType, error) { if pk.CmdId == "" { pk.CmdId = uuid.New().String() } @@ -184,14 +187,14 @@ func RunCommand(pk *packet.RunPacketType) (*ShExecType, error) { // copy pty output to .ptyout file _, copyErr := io.Copy(ptyOutFd, cmdPty) if copyErr != nil { - base.WriteErrorMsg(fileNames.PtyOutFile, fmt.Sprintf("copying pty output to ptyout file: %v", copyErr)) + sender.SendErrorPacket(fmt.Sprintf("copying pty output to ptyout file: %v", copyErr)) } }() go func() { // copy .stdin fifo contents to pty input copyFifoErr := MakeAndCopyStdinFifo(cmdPty, fileNames.StdinFifo) if copyFifoErr != nil { - base.WriteErrorMsg(fileNames.PtyOutFile, fmt.Sprintf("reading from stdin fifo: %v", copyFifoErr)) + sender.SendErrorPacket(fmt.Sprintf("reading from stdin fifo: %v", copyFifoErr)) } }() return &ShExecType{ @@ -202,9 +205,10 @@ func RunCommand(pk *packet.RunPacketType) (*ShExecType, error) { }, nil } -func (c *ShExecType) WaitForCommand() { +func (c *ShExecType) WaitForCommand(cmdId string) *packet.CmdDonePacketType { err := c.Cmd.Wait() - cmdDuration := time.Since(c.StartTs) + endTs := time.Now() + cmdDuration := endTs.Sub(c.StartTs) exitCode := 0 if err != nil { exitErr, ok := err.(*exec.ExitError) @@ -212,15 +216,11 @@ func (c *ShExecType) WaitForCommand() { exitCode = exitErr.ExitCode() } } - doneData := DoneData{ - DurationMs: int64(cmdDuration / time.Millisecond), - ExitCode: exitCode, - } - doneDataBytes, _ := json.Marshal(doneData) - doneDataBytes = append(doneDataBytes, '\n') - err = os.WriteFile(c.FileNames.DoneFile, doneDataBytes, 0600) - if err != nil { - base.WriteErrorMsg(c.FileNames.PtyOutFile, fmt.Sprintf("reading from stdin fifo: %v", err)) - } - return + donePacket := packet.MakeCmdDonePacket() + donePacket.Ts = endTs.UnixMilli() + donePacket.CmdId = cmdId + donePacket.ExitCode = exitCode + donePacket.DurationMs = int64(cmdDuration / time.Millisecond) + os.Remove(c.FileNames.StdinFifo) // best effort (no need to check error) + return donePacket }