go runner / runner-single fork flow working

This commit is contained in:
sawka 2022-06-10 21:37:21 -07:00
parent eeaeac8dc8
commit 1a3886c437
4 changed files with 321 additions and 73 deletions

View File

@ -11,23 +11,30 @@ import (
"os" "os"
"os/signal" "os/signal"
"syscall" "syscall"
"time"
"github.com/google/uuid"
"github.com/scripthaus-dev/sh2-runner/pkg/base"
"github.com/scripthaus-dev/sh2-runner/pkg/packet" "github.com/scripthaus-dev/sh2-runner/pkg/packet"
"github.com/scripthaus-dev/sh2-runner/pkg/shexec" "github.com/scripthaus-dev/sh2-runner/pkg/shexec"
) )
func setupSignals(cmd *shexec.ShExecType) { // in single run mode, we don't want the runner to die from signals
// since we want the single runner to persist even if session / main runner
// is terminated.
func setupSingleSignals(cmd *shexec.ShExecType) {
sigCh := make(chan os.Signal, 1) sigCh := make(chan os.Signal, 1)
signal.Notify(sigCh, syscall.SIGINT, syscall.SIGTERM) signal.Notify(sigCh, syscall.SIGINT, syscall.SIGTERM, syscall.SIGHUP)
go func() { go func() {
for sig := range sigCh { for range sigCh {
cmd.Cmd.Process.Signal(sig) // do nothing
} }
}() }()
} }
func main() { func doSingle(cmdId string) {
packetCh := packet.PacketParser(os.Stdin) packetCh := packet.PacketParser(os.Stdin)
sender := packet.MakePacketSender(os.Stdout)
var runPacket *packet.RunPacketType var runPacket *packet.RunPacketType
for pk := range packetCh { for pk := range packetCh {
if pk.GetType() == packet.PingPacketStr { if pk.GetType() == packet.PingPacketStr {
@ -37,24 +44,134 @@ func main() {
runPacket, _ = pk.(*packet.RunPacketType) runPacket, _ = pk.(*packet.RunPacketType)
break break
} }
if pk.GetType() == packet.ErrorPacketStr { sender.SendErrorPacket(fmt.Sprintf("invalid packet '%s' sent to runner", pk.GetType()))
packet.SendPacket(os.Stdout, pk)
return
}
packet.SendErrorPacket(os.Stdout, fmt.Sprintf("invalid packet '%s' sent to runner", pk.GetType()))
return return
} }
if runPacket == nil { if runPacket == nil {
packet.SendErrorPacket(os.Stdout, "did not receive a 'run' packet") sender.SendErrorPacket("did not receive a 'run' packet")
return return
} }
cmd, err := shexec.RunCommand(runPacket) if runPacket.CmdId == "" {
runPacket.CmdId = cmdId
}
if runPacket.CmdId != cmdId {
sender.SendErrorPacket(fmt.Sprintf("run packet cmdid[%s] did not match arg[%s]", runPacket.CmdId, cmdId))
return
}
cmd, err := shexec.RunCommand(runPacket, sender)
if err != nil { if err != nil {
packet.SendErrorPacket(os.Stdout, fmt.Sprintf("error running command: %v", err)) sender.SendErrorPacket(fmt.Sprintf("error running command: %v", err))
return return
} }
setupSignals(cmd) setupSingleSignals(cmd)
packet.SendPacket(os.Stdout, packet.MakeOkCmdPacket(fmt.Sprintf("running command %s/%s", runPacket.SessionId, runPacket.CmdId), runPacket.CmdId, cmd.Cmd.Process.Pid)) startPacket := packet.MakeCmdStartPacket()
cmd.WaitForCommand() startPacket.Ts = time.Now().UnixMilli()
packet.SendPacket(os.Stdout, packet.MakeDonePacket()) startPacket.CmdId = runPacket.CmdId
startPacket.Pid = cmd.Cmd.Process.Pid
startPacket.RunnerPid = os.Getpid()
sender.SendPacket(startPacket)
donePacket := cmd.WaitForCommand(runPacket.CmdId)
sender.SendPacket(donePacket)
sender.CloseSendCh()
sender.WaitForDone()
}
func doMainRun(pk *packet.RunPacketType, sender *packet.PacketSender) {
if pk.CmdId == "" {
pk.CmdId = uuid.New().String()
}
err := shexec.ValidateRunPacket(pk)
if err != nil {
sender.SendPacket(packet.MakeIdErrorPacket(pk.CmdId, fmt.Sprintf("invalid run packet: %v", err)))
return
}
fileNames, err := base.GetCommandFileNames(pk.SessionId, pk.CmdId)
if err != nil {
sender.SendPacket(packet.MakeIdErrorPacket(pk.CmdId, fmt.Sprintf("cannot get command file names: %v", err)))
return
}
cmd, err := shexec.MakeRunnerExec(pk.CmdId)
if err != nil {
sender.SendPacket(packet.MakeIdErrorPacket(pk.CmdId, fmt.Sprintf("cannot make runner command: %v", err)))
return
}
cmdStdin, err := cmd.StdinPipe()
if err != nil {
sender.SendPacket(packet.MakeIdErrorPacket(pk.CmdId, fmt.Sprintf("cannot pipe stdin to command: %v", err)))
return
}
runnerOutFd, err := os.OpenFile(fileNames.RunnerOutFile, os.O_CREATE|os.O_TRUNC|os.O_APPEND|os.O_WRONLY, 0600)
if err != nil {
sender.SendPacket(packet.MakeIdErrorPacket(pk.CmdId, fmt.Sprintf("cannot open runner out file '%s': %v", fileNames.RunnerOutFile, err)))
return
}
cmd.Stdout = runnerOutFd
cmd.Stderr = runnerOutFd
err = cmd.Start()
if err != nil {
sender.SendPacket(packet.MakeIdErrorPacket(pk.CmdId, fmt.Sprintf("error starting command: %v", err)))
return
}
go func() {
err = packet.SendPacket(cmdStdin, pk)
if err != nil {
sender.SendPacket(packet.MakeIdErrorPacket(pk.CmdId, fmt.Sprintf("error sending forked runner command: %v", err)))
return
}
cmdStdin.Close()
// clean up zombies
cmd.Wait()
}()
}
func doMain() {
homeDir, err := base.GetScHomeDir()
if err != nil {
packet.SendErrorPacket(os.Stdout, err.Error())
return
}
err = os.Chdir(homeDir)
if err != nil {
packet.SendErrorPacket(os.Stdout, fmt.Sprintf("cannot change directory to scripthaus home '%s': %v", homeDir, err))
return
}
err = base.EnsureRunnerPath()
if err != nil {
packet.SendErrorPacket(os.Stdout, err.Error())
return
}
packetCh := packet.PacketParser(os.Stdin)
sender := packet.MakePacketSender(os.Stdout)
sender.SendPacket(packet.MakeMessagePacket(fmt.Sprintf("starting scripthaus runner @ %s", homeDir)))
for pk := range packetCh {
if pk.GetType() == packet.PingPacketStr {
continue
}
if pk.GetType() == packet.RunPacketStr {
doMainRun(pk.(*packet.RunPacketType), sender)
continue
}
if pk.GetType() == packet.ErrorPacketStr {
errPk := pk.(*packet.ErrorPacketType)
errPk.Error = "invalid packet sent to runner: " + errPk.Error
sender.SendPacket(errPk)
continue
}
sender.SendErrorPacket(fmt.Sprintf("invalid packet '%s' sent to runner", pk.GetType()))
}
}
func main() {
if len(os.Args) >= 2 {
cmdId, err := uuid.Parse(os.Args[1])
if err != nil {
packet.SendErrorPacket(os.Stdout, fmt.Sprintf("invalid non-cmdid passed to runner", err))
return
}
doSingle(cmdId.String())
return
} else {
doMain()
}
} }

View File

@ -27,9 +27,9 @@ const ScReadyString = "scripthaus runner ready"
const OSCEscError = "error" const OSCEscError = "error"
type CommandFileNames struct { type CommandFileNames struct {
PtyOutFile string PtyOutFile string
StdinFifo string StdinFifo string
DoneFile string RunnerOutFile string
} }
func GetScHomeDir() (string, error) { func GetScHomeDir() (string, error) {
@ -54,9 +54,9 @@ func GetCommandFileNames(sessionId string, cmdId string) (*CommandFileNames, err
} }
base := path.Join(sdir, cmdId) base := path.Join(sdir, cmdId)
return &CommandFileNames{ return &CommandFileNames{
PtyOutFile: base + ".ptyout", PtyOutFile: base + ".ptyout",
StdinFifo: base + ".stdin", StdinFifo: base + ".stdin",
DoneFile: base + ".done", RunnerOutFile: base + ".runout",
}, nil }, nil
} }
@ -108,32 +108,50 @@ func EnsureSessionDir(sessionId string) (string, error) {
return sdir, nil return sdir, nil
} }
func GetScRunnerPath() string { func GetScRunnerPath() (string, error) {
runnerPath := os.Getenv(ScRunnerVarName) runnerPath := os.Getenv(ScRunnerVarName)
if runnerPath != "" { if runnerPath != "" {
return runnerPath return runnerPath, nil
} }
scHome, err := GetScHomeDir() scHome, err := GetScHomeDir()
if err != nil { if err != nil {
panic(err) return "", err
} }
return path.Join(scHome, RunnerBaseName) return path.Join(scHome, RunnerBaseName), nil
} }
func GetScSessionsDir() string { func EnsureRunnerPath() error {
scHome, err := GetScHomeDir() runnerPath, err := GetScRunnerPath()
if err != nil { if err != nil {
panic(err) return err
} }
return path.Join(scHome, SessionsDirBaseName) info, err := os.Stat(runnerPath)
if err != nil {
if errors.Is(err, fs.ErrNotExist) {
return fmt.Errorf("cannot find scripthaus runner at path '%s'", runnerPath)
}
return fmt.Errorf("error stating scripthaus runner at path '%s'", runnerPath)
}
if info.Mode()&0100 == 0 {
return fmt.Errorf("scripthaus runner at path '%s' is not executable mode=%#o", runnerPath, info.Mode())
}
return nil
} }
func GetSessionDBName(sessionId string) string { func GetScSessionsDir() (string, error) {
scHome, err := GetScHomeDir() scHome, err := GetScHomeDir()
if err != nil { if err != nil {
panic(err) return "", err
} }
return path.Join(scHome, SessionDBName) return path.Join(scHome, SessionsDirBaseName), nil
}
func GetSessionDBName(sessionId string) (string, error) {
scHome, err := GetScHomeDir()
if err != nil {
return "", err
}
return path.Join(scHome, SessionDBName), nil
} }
// SH OSC Escapes (code 198, S=19, H=8) // SH OSC Escapes (code 198, S=19, H=8)

View File

@ -11,13 +11,16 @@ import (
"encoding/json" "encoding/json"
"fmt" "fmt"
"io" "io"
"sync"
) )
const RunPacketStr = "run" const RunPacketStr = "run"
const PingPacketStr = "ping" const PingPacketStr = "ping"
const DonePacketStr = "done" const DonePacketStr = "done"
const ErrorPacketStr = "error" const ErrorPacketStr = "error"
const OkCmdPacketStr = "okcmd" const MessagePacketStr = "message"
const CmdStartPacketStr = "cmdstart"
const CmdDonePacketStr = "cmddone"
type PingPacketType struct { type PingPacketType struct {
Type string `json:"type"` Type string `json:"type"`
@ -31,6 +34,19 @@ func MakePingPacket() *PingPacketType {
return &PingPacketType{Type: PingPacketStr} return &PingPacketType{Type: PingPacketStr}
} }
type MessagePacketType struct {
Type string `json:"type"`
Message string `json:"message"`
}
func (*MessagePacketType) GetType() string {
return MessagePacketStr
}
func MakeMessagePacket(message string) *MessagePacketType {
return &MessagePacketType{Type: MessagePacketStr, Message: message}
}
type DonePacketType struct { type DonePacketType struct {
Type string `json:"type"` Type string `json:"type"`
} }
@ -43,27 +59,44 @@ func MakeDonePacket() *DonePacketType {
return &DonePacketType{Type: DonePacketStr} return &DonePacketType{Type: DonePacketStr}
} }
type OkCmdPacketType struct { type CmdDonePacketType struct {
Type string `json:"type"` Type string `json:"type"`
Message string `json:"message"` Ts int64 `json:"ts"`
CmdId string `json:"cmdid"` CmdId string `json:"cmdid"`
Pid int `json:"pid"` ExitCode int `json:"exitcode"`
DurationMs int64 `json:"durationms"`
} }
func (*OkCmdPacketType) GetType() string { func (*CmdDonePacketType) GetType() string {
return OkCmdPacketStr return CmdDonePacketStr
} }
func MakeOkCmdPacket(message string, cmdId string, pid int) *OkCmdPacketType { func MakeCmdDonePacket() *CmdDonePacketType {
return &OkCmdPacketType{Type: OkCmdPacketStr, Message: message, CmdId: cmdId, Pid: pid} return &CmdDonePacketType{Type: CmdDonePacketStr}
}
type CmdStartPacketType struct {
Type string `json:"type"`
Ts int64 `json:"ts"`
CmdId string `json:"cmdid"`
Pid int `json:"pid"`
RunnerPid int `json:"runnerpid"`
}
func (*CmdStartPacketType) GetType() string {
return CmdStartPacketStr
}
func MakeCmdStartPacket() *CmdStartPacketType {
return &CmdStartPacketType{Type: CmdStartPacketStr}
} }
type RunPacketType struct { type RunPacketType struct {
Type string `json:"type"` Type string `json:"type"`
SessionId string `json:"sessionid"` SessionId string `json:"sessionid"`
CmdId string `json:"cmdid"` CmdId string `json:"cmdid"`
ChDir string `json:"chdir"` ChDir string `json:"chdir,omitempty"`
Env map[string]string `json:"env"` Env map[string]string `json:"env,omitempty"`
Command string `json:"command"` Command string `json:"command"`
} }
@ -76,6 +109,7 @@ type BarePacketType struct {
} }
type ErrorPacketType struct { type ErrorPacketType struct {
Id string `json:"id,omitempty"`
Type string `json:"type"` Type string `json:"type"`
Error string `json:"error"` Error string `json:"error"`
} }
@ -88,6 +122,10 @@ func MakeErrorPacket(errorStr string) *ErrorPacketType {
return &ErrorPacketType{Type: ErrorPacketStr, Error: errorStr} return &ErrorPacketType{Type: ErrorPacketStr, Error: errorStr}
} }
func MakeIdErrorPacket(id string, errorStr string) *ErrorPacketType {
return &ErrorPacketType{Type: ErrorPacketStr, Id: id, Error: errorStr}
}
type PacketType interface { type PacketType interface {
GetType() string GetType() string
} }
@ -123,13 +161,21 @@ func ParseJsonPacket(jsonBuf []byte) (PacketType, error) {
} }
return &errorPacket, nil return &errorPacket, nil
} }
if bareCmd.Type == OkCmdPacketStr { if bareCmd.Type == CmdStartPacketStr {
var okPacket OkCmdPacketType var startPacket CmdStartPacketType
err = json.Unmarshal(jsonBuf, &okPacket) err = json.Unmarshal(jsonBuf, &startPacket)
if err != nil { if err != nil {
return nil, err return nil, err
} }
return &okPacket, nil return &startPacket, nil
}
if bareCmd.Type == CmdDonePacketStr {
var donePacket CmdDonePacketType
err = json.Unmarshal(jsonBuf, &donePacket)
if err != nil {
return nil, err
}
return &donePacket, nil
} }
return nil, fmt.Errorf("invalid packet-type '%s'", bareCmd.Type) return nil, fmt.Errorf("invalid packet-type '%s'", bareCmd.Type)
} }
@ -154,6 +200,73 @@ func SendErrorPacket(w io.Writer, errorStr string) error {
return SendPacket(w, MakeErrorPacket(errorStr)) return SendPacket(w, MakeErrorPacket(errorStr))
} }
type PacketSender struct {
Lock *sync.Mutex
SendCh chan PacketType
Err error
Done bool
DoneCh chan bool
}
func MakePacketSender(output io.Writer) *PacketSender {
sender := &PacketSender{
Lock: &sync.Mutex{},
SendCh: make(chan PacketType),
DoneCh: make(chan bool),
}
go func() {
defer func() {
sender.Lock.Lock()
sender.Done = true
sender.Lock.Unlock()
close(sender.DoneCh)
}()
for pk := range sender.SendCh {
err := SendPacket(output, pk)
if err != nil {
sender.Lock.Lock()
sender.Err = err
sender.Lock.Unlock()
return
}
}
}()
return sender
}
func (sender *PacketSender) CloseSendCh() {
close(sender.SendCh)
}
func (sender *PacketSender) WaitForDone() {
<-sender.DoneCh
}
func (sender *PacketSender) checkStatus() error {
sender.Lock.Lock()
defer sender.Lock.Unlock()
if sender.Done {
return fmt.Errorf("cannot send packet, sender write loop is closed")
}
if sender.Err != nil {
return fmt.Errorf("cannot send packet, sender had error: %w", sender.Err)
}
return nil
}
func (sender *PacketSender) SendPacket(pk PacketType) error {
err := sender.checkStatus()
if err != nil {
return err
}
sender.SendCh <- pk
return nil
}
func (sender *PacketSender) SendErrorPacket(errVal string) error {
return sender.SendPacket(MakeErrorPacket(errVal))
}
func PacketParser(input io.Reader) chan PacketType { func PacketParser(input io.Reader) chan PacketType {
bufReader := bufio.NewReader(input) bufReader := bufio.NewReader(input)
rtnCh := make(chan PacketType) rtnCh := make(chan PacketType)

View File

@ -7,7 +7,6 @@
package shexec package shexec
import ( import (
"encoding/json"
"errors" "errors"
"fmt" "fmt"
"io" "io"
@ -24,11 +23,6 @@ import (
"github.com/scripthaus-dev/sh2-runner/pkg/packet" "github.com/scripthaus-dev/sh2-runner/pkg/packet"
) )
type DoneData struct {
DurationMs int64 `json:"durationms"`
ExitCode int `json:"exitcode"`
}
type ShExecType struct { type ShExecType struct {
FileNames *base.CommandFileNames FileNames *base.CommandFileNames
Cmd *exec.Cmd Cmd *exec.Cmd
@ -95,6 +89,15 @@ func MakeExecCmd(pk *packet.RunPacketType, cmdTty *os.File) *exec.Cmd {
return ecmd return ecmd
} }
func MakeRunnerExec(cmdId string) (*exec.Cmd, error) {
runnerPath, err := base.GetScRunnerPath()
if err != nil {
return nil, err
}
ecmd := exec.Command(runnerPath, cmdId)
return ecmd, nil
}
// this will never return (unless there is an error creating/opening the file), as fifoFile will never EOF // this will never return (unless there is an error creating/opening the file), as fifoFile will never EOF
func MakeAndCopyStdinFifo(dst *os.File, fifoName string) error { func MakeAndCopyStdinFifo(dst *os.File, fifoName string) error {
os.Remove(fifoName) os.Remove(fifoName)
@ -147,8 +150,8 @@ func ValidateRunPacket(pk *packet.RunPacketType) error {
return nil return nil
} }
// returning nil error means the process has successfully been kicked-off // when err is nil, the command will have already been started
func RunCommand(pk *packet.RunPacketType) (*ShExecType, error) { func RunCommand(pk *packet.RunPacketType, sender *packet.PacketSender) (*ShExecType, error) {
if pk.CmdId == "" { if pk.CmdId == "" {
pk.CmdId = uuid.New().String() pk.CmdId = uuid.New().String()
} }
@ -184,14 +187,14 @@ func RunCommand(pk *packet.RunPacketType) (*ShExecType, error) {
// copy pty output to .ptyout file // copy pty output to .ptyout file
_, copyErr := io.Copy(ptyOutFd, cmdPty) _, copyErr := io.Copy(ptyOutFd, cmdPty)
if copyErr != nil { if copyErr != nil {
base.WriteErrorMsg(fileNames.PtyOutFile, fmt.Sprintf("copying pty output to ptyout file: %v", copyErr)) sender.SendErrorPacket(fmt.Sprintf("copying pty output to ptyout file: %v", copyErr))
} }
}() }()
go func() { go func() {
// copy .stdin fifo contents to pty input // copy .stdin fifo contents to pty input
copyFifoErr := MakeAndCopyStdinFifo(cmdPty, fileNames.StdinFifo) copyFifoErr := MakeAndCopyStdinFifo(cmdPty, fileNames.StdinFifo)
if copyFifoErr != nil { if copyFifoErr != nil {
base.WriteErrorMsg(fileNames.PtyOutFile, fmt.Sprintf("reading from stdin fifo: %v", copyFifoErr)) sender.SendErrorPacket(fmt.Sprintf("reading from stdin fifo: %v", copyFifoErr))
} }
}() }()
return &ShExecType{ return &ShExecType{
@ -202,9 +205,10 @@ func RunCommand(pk *packet.RunPacketType) (*ShExecType, error) {
}, nil }, nil
} }
func (c *ShExecType) WaitForCommand() { func (c *ShExecType) WaitForCommand(cmdId string) *packet.CmdDonePacketType {
err := c.Cmd.Wait() err := c.Cmd.Wait()
cmdDuration := time.Since(c.StartTs) endTs := time.Now()
cmdDuration := endTs.Sub(c.StartTs)
exitCode := 0 exitCode := 0
if err != nil { if err != nil {
exitErr, ok := err.(*exec.ExitError) exitErr, ok := err.(*exec.ExitError)
@ -212,15 +216,11 @@ func (c *ShExecType) WaitForCommand() {
exitCode = exitErr.ExitCode() exitCode = exitErr.ExitCode()
} }
} }
doneData := DoneData{ donePacket := packet.MakeCmdDonePacket()
DurationMs: int64(cmdDuration / time.Millisecond), donePacket.Ts = endTs.UnixMilli()
ExitCode: exitCode, donePacket.CmdId = cmdId
} donePacket.ExitCode = exitCode
doneDataBytes, _ := json.Marshal(doneData) donePacket.DurationMs = int64(cmdDuration / time.Millisecond)
doneDataBytes = append(doneDataBytes, '\n') os.Remove(c.FileNames.StdinFifo) // best effort (no need to check error)
err = os.WriteFile(c.FileNames.DoneFile, doneDataBytes, 0600) return donePacket
if err != nil {
base.WriteErrorMsg(c.FileNames.PtyOutFile, fmt.Sprintf("reading from stdin fifo: %v", err))
}
return
} }