go runner / runner-single fork flow working

This commit is contained in:
sawka 2022-06-10 21:37:21 -07:00
parent eeaeac8dc8
commit 1a3886c437
4 changed files with 321 additions and 73 deletions

View File

@ -11,23 +11,30 @@ import (
"os"
"os/signal"
"syscall"
"time"
"github.com/google/uuid"
"github.com/scripthaus-dev/sh2-runner/pkg/base"
"github.com/scripthaus-dev/sh2-runner/pkg/packet"
"github.com/scripthaus-dev/sh2-runner/pkg/shexec"
)
func setupSignals(cmd *shexec.ShExecType) {
// in single run mode, we don't want the runner to die from signals
// since we want the single runner to persist even if session / main runner
// is terminated.
func setupSingleSignals(cmd *shexec.ShExecType) {
sigCh := make(chan os.Signal, 1)
signal.Notify(sigCh, syscall.SIGINT, syscall.SIGTERM)
signal.Notify(sigCh, syscall.SIGINT, syscall.SIGTERM, syscall.SIGHUP)
go func() {
for sig := range sigCh {
cmd.Cmd.Process.Signal(sig)
for range sigCh {
// do nothing
}
}()
}
func main() {
func doSingle(cmdId string) {
packetCh := packet.PacketParser(os.Stdin)
sender := packet.MakePacketSender(os.Stdout)
var runPacket *packet.RunPacketType
for pk := range packetCh {
if pk.GetType() == packet.PingPacketStr {
@ -37,24 +44,134 @@ func main() {
runPacket, _ = pk.(*packet.RunPacketType)
break
}
if pk.GetType() == packet.ErrorPacketStr {
packet.SendPacket(os.Stdout, pk)
return
}
packet.SendErrorPacket(os.Stdout, fmt.Sprintf("invalid packet '%s' sent to runner", pk.GetType()))
sender.SendErrorPacket(fmt.Sprintf("invalid packet '%s' sent to runner", pk.GetType()))
return
}
if runPacket == nil {
packet.SendErrorPacket(os.Stdout, "did not receive a 'run' packet")
sender.SendErrorPacket("did not receive a 'run' packet")
return
}
cmd, err := shexec.RunCommand(runPacket)
if runPacket.CmdId == "" {
runPacket.CmdId = cmdId
}
if runPacket.CmdId != cmdId {
sender.SendErrorPacket(fmt.Sprintf("run packet cmdid[%s] did not match arg[%s]", runPacket.CmdId, cmdId))
return
}
cmd, err := shexec.RunCommand(runPacket, sender)
if err != nil {
packet.SendErrorPacket(os.Stdout, fmt.Sprintf("error running command: %v", err))
sender.SendErrorPacket(fmt.Sprintf("error running command: %v", err))
return
}
setupSignals(cmd)
packet.SendPacket(os.Stdout, packet.MakeOkCmdPacket(fmt.Sprintf("running command %s/%s", runPacket.SessionId, runPacket.CmdId), runPacket.CmdId, cmd.Cmd.Process.Pid))
cmd.WaitForCommand()
packet.SendPacket(os.Stdout, packet.MakeDonePacket())
setupSingleSignals(cmd)
startPacket := packet.MakeCmdStartPacket()
startPacket.Ts = time.Now().UnixMilli()
startPacket.CmdId = runPacket.CmdId
startPacket.Pid = cmd.Cmd.Process.Pid
startPacket.RunnerPid = os.Getpid()
sender.SendPacket(startPacket)
donePacket := cmd.WaitForCommand(runPacket.CmdId)
sender.SendPacket(donePacket)
sender.CloseSendCh()
sender.WaitForDone()
}
func doMainRun(pk *packet.RunPacketType, sender *packet.PacketSender) {
if pk.CmdId == "" {
pk.CmdId = uuid.New().String()
}
err := shexec.ValidateRunPacket(pk)
if err != nil {
sender.SendPacket(packet.MakeIdErrorPacket(pk.CmdId, fmt.Sprintf("invalid run packet: %v", err)))
return
}
fileNames, err := base.GetCommandFileNames(pk.SessionId, pk.CmdId)
if err != nil {
sender.SendPacket(packet.MakeIdErrorPacket(pk.CmdId, fmt.Sprintf("cannot get command file names: %v", err)))
return
}
cmd, err := shexec.MakeRunnerExec(pk.CmdId)
if err != nil {
sender.SendPacket(packet.MakeIdErrorPacket(pk.CmdId, fmt.Sprintf("cannot make runner command: %v", err)))
return
}
cmdStdin, err := cmd.StdinPipe()
if err != nil {
sender.SendPacket(packet.MakeIdErrorPacket(pk.CmdId, fmt.Sprintf("cannot pipe stdin to command: %v", err)))
return
}
runnerOutFd, err := os.OpenFile(fileNames.RunnerOutFile, os.O_CREATE|os.O_TRUNC|os.O_APPEND|os.O_WRONLY, 0600)
if err != nil {
sender.SendPacket(packet.MakeIdErrorPacket(pk.CmdId, fmt.Sprintf("cannot open runner out file '%s': %v", fileNames.RunnerOutFile, err)))
return
}
cmd.Stdout = runnerOutFd
cmd.Stderr = runnerOutFd
err = cmd.Start()
if err != nil {
sender.SendPacket(packet.MakeIdErrorPacket(pk.CmdId, fmt.Sprintf("error starting command: %v", err)))
return
}
go func() {
err = packet.SendPacket(cmdStdin, pk)
if err != nil {
sender.SendPacket(packet.MakeIdErrorPacket(pk.CmdId, fmt.Sprintf("error sending forked runner command: %v", err)))
return
}
cmdStdin.Close()
// clean up zombies
cmd.Wait()
}()
}
func doMain() {
homeDir, err := base.GetScHomeDir()
if err != nil {
packet.SendErrorPacket(os.Stdout, err.Error())
return
}
err = os.Chdir(homeDir)
if err != nil {
packet.SendErrorPacket(os.Stdout, fmt.Sprintf("cannot change directory to scripthaus home '%s': %v", homeDir, err))
return
}
err = base.EnsureRunnerPath()
if err != nil {
packet.SendErrorPacket(os.Stdout, err.Error())
return
}
packetCh := packet.PacketParser(os.Stdin)
sender := packet.MakePacketSender(os.Stdout)
sender.SendPacket(packet.MakeMessagePacket(fmt.Sprintf("starting scripthaus runner @ %s", homeDir)))
for pk := range packetCh {
if pk.GetType() == packet.PingPacketStr {
continue
}
if pk.GetType() == packet.RunPacketStr {
doMainRun(pk.(*packet.RunPacketType), sender)
continue
}
if pk.GetType() == packet.ErrorPacketStr {
errPk := pk.(*packet.ErrorPacketType)
errPk.Error = "invalid packet sent to runner: " + errPk.Error
sender.SendPacket(errPk)
continue
}
sender.SendErrorPacket(fmt.Sprintf("invalid packet '%s' sent to runner", pk.GetType()))
}
}
func main() {
if len(os.Args) >= 2 {
cmdId, err := uuid.Parse(os.Args[1])
if err != nil {
packet.SendErrorPacket(os.Stdout, fmt.Sprintf("invalid non-cmdid passed to runner", err))
return
}
doSingle(cmdId.String())
return
} else {
doMain()
}
}

View File

@ -27,9 +27,9 @@ const ScReadyString = "scripthaus runner ready"
const OSCEscError = "error"
type CommandFileNames struct {
PtyOutFile string
StdinFifo string
DoneFile string
PtyOutFile string
StdinFifo string
RunnerOutFile string
}
func GetScHomeDir() (string, error) {
@ -54,9 +54,9 @@ func GetCommandFileNames(sessionId string, cmdId string) (*CommandFileNames, err
}
base := path.Join(sdir, cmdId)
return &CommandFileNames{
PtyOutFile: base + ".ptyout",
StdinFifo: base + ".stdin",
DoneFile: base + ".done",
PtyOutFile: base + ".ptyout",
StdinFifo: base + ".stdin",
RunnerOutFile: base + ".runout",
}, nil
}
@ -108,32 +108,50 @@ func EnsureSessionDir(sessionId string) (string, error) {
return sdir, nil
}
func GetScRunnerPath() string {
func GetScRunnerPath() (string, error) {
runnerPath := os.Getenv(ScRunnerVarName)
if runnerPath != "" {
return runnerPath
return runnerPath, nil
}
scHome, err := GetScHomeDir()
if err != nil {
panic(err)
return "", err
}
return path.Join(scHome, RunnerBaseName)
return path.Join(scHome, RunnerBaseName), nil
}
func GetScSessionsDir() string {
scHome, err := GetScHomeDir()
func EnsureRunnerPath() error {
runnerPath, err := GetScRunnerPath()
if err != nil {
panic(err)
return err
}
return path.Join(scHome, SessionsDirBaseName)
info, err := os.Stat(runnerPath)
if err != nil {
if errors.Is(err, fs.ErrNotExist) {
return fmt.Errorf("cannot find scripthaus runner at path '%s'", runnerPath)
}
return fmt.Errorf("error stating scripthaus runner at path '%s'", runnerPath)
}
if info.Mode()&0100 == 0 {
return fmt.Errorf("scripthaus runner at path '%s' is not executable mode=%#o", runnerPath, info.Mode())
}
return nil
}
func GetSessionDBName(sessionId string) string {
func GetScSessionsDir() (string, error) {
scHome, err := GetScHomeDir()
if err != nil {
panic(err)
return "", err
}
return path.Join(scHome, SessionDBName)
return path.Join(scHome, SessionsDirBaseName), nil
}
func GetSessionDBName(sessionId string) (string, error) {
scHome, err := GetScHomeDir()
if err != nil {
return "", err
}
return path.Join(scHome, SessionDBName), nil
}
// SH OSC Escapes (code 198, S=19, H=8)

View File

@ -11,13 +11,16 @@ import (
"encoding/json"
"fmt"
"io"
"sync"
)
const RunPacketStr = "run"
const PingPacketStr = "ping"
const DonePacketStr = "done"
const ErrorPacketStr = "error"
const OkCmdPacketStr = "okcmd"
const MessagePacketStr = "message"
const CmdStartPacketStr = "cmdstart"
const CmdDonePacketStr = "cmddone"
type PingPacketType struct {
Type string `json:"type"`
@ -31,6 +34,19 @@ func MakePingPacket() *PingPacketType {
return &PingPacketType{Type: PingPacketStr}
}
type MessagePacketType struct {
Type string `json:"type"`
Message string `json:"message"`
}
func (*MessagePacketType) GetType() string {
return MessagePacketStr
}
func MakeMessagePacket(message string) *MessagePacketType {
return &MessagePacketType{Type: MessagePacketStr, Message: message}
}
type DonePacketType struct {
Type string `json:"type"`
}
@ -43,27 +59,44 @@ func MakeDonePacket() *DonePacketType {
return &DonePacketType{Type: DonePacketStr}
}
type OkCmdPacketType struct {
Type string `json:"type"`
Message string `json:"message"`
CmdId string `json:"cmdid"`
Pid int `json:"pid"`
type CmdDonePacketType struct {
Type string `json:"type"`
Ts int64 `json:"ts"`
CmdId string `json:"cmdid"`
ExitCode int `json:"exitcode"`
DurationMs int64 `json:"durationms"`
}
func (*OkCmdPacketType) GetType() string {
return OkCmdPacketStr
func (*CmdDonePacketType) GetType() string {
return CmdDonePacketStr
}
func MakeOkCmdPacket(message string, cmdId string, pid int) *OkCmdPacketType {
return &OkCmdPacketType{Type: OkCmdPacketStr, Message: message, CmdId: cmdId, Pid: pid}
func MakeCmdDonePacket() *CmdDonePacketType {
return &CmdDonePacketType{Type: CmdDonePacketStr}
}
type CmdStartPacketType struct {
Type string `json:"type"`
Ts int64 `json:"ts"`
CmdId string `json:"cmdid"`
Pid int `json:"pid"`
RunnerPid int `json:"runnerpid"`
}
func (*CmdStartPacketType) GetType() string {
return CmdStartPacketStr
}
func MakeCmdStartPacket() *CmdStartPacketType {
return &CmdStartPacketType{Type: CmdStartPacketStr}
}
type RunPacketType struct {
Type string `json:"type"`
SessionId string `json:"sessionid"`
CmdId string `json:"cmdid"`
ChDir string `json:"chdir"`
Env map[string]string `json:"env"`
ChDir string `json:"chdir,omitempty"`
Env map[string]string `json:"env,omitempty"`
Command string `json:"command"`
}
@ -76,6 +109,7 @@ type BarePacketType struct {
}
type ErrorPacketType struct {
Id string `json:"id,omitempty"`
Type string `json:"type"`
Error string `json:"error"`
}
@ -88,6 +122,10 @@ func MakeErrorPacket(errorStr string) *ErrorPacketType {
return &ErrorPacketType{Type: ErrorPacketStr, Error: errorStr}
}
func MakeIdErrorPacket(id string, errorStr string) *ErrorPacketType {
return &ErrorPacketType{Type: ErrorPacketStr, Id: id, Error: errorStr}
}
type PacketType interface {
GetType() string
}
@ -123,13 +161,21 @@ func ParseJsonPacket(jsonBuf []byte) (PacketType, error) {
}
return &errorPacket, nil
}
if bareCmd.Type == OkCmdPacketStr {
var okPacket OkCmdPacketType
err = json.Unmarshal(jsonBuf, &okPacket)
if bareCmd.Type == CmdStartPacketStr {
var startPacket CmdStartPacketType
err = json.Unmarshal(jsonBuf, &startPacket)
if err != nil {
return nil, err
}
return &okPacket, nil
return &startPacket, nil
}
if bareCmd.Type == CmdDonePacketStr {
var donePacket CmdDonePacketType
err = json.Unmarshal(jsonBuf, &donePacket)
if err != nil {
return nil, err
}
return &donePacket, nil
}
return nil, fmt.Errorf("invalid packet-type '%s'", bareCmd.Type)
}
@ -154,6 +200,73 @@ func SendErrorPacket(w io.Writer, errorStr string) error {
return SendPacket(w, MakeErrorPacket(errorStr))
}
type PacketSender struct {
Lock *sync.Mutex
SendCh chan PacketType
Err error
Done bool
DoneCh chan bool
}
func MakePacketSender(output io.Writer) *PacketSender {
sender := &PacketSender{
Lock: &sync.Mutex{},
SendCh: make(chan PacketType),
DoneCh: make(chan bool),
}
go func() {
defer func() {
sender.Lock.Lock()
sender.Done = true
sender.Lock.Unlock()
close(sender.DoneCh)
}()
for pk := range sender.SendCh {
err := SendPacket(output, pk)
if err != nil {
sender.Lock.Lock()
sender.Err = err
sender.Lock.Unlock()
return
}
}
}()
return sender
}
func (sender *PacketSender) CloseSendCh() {
close(sender.SendCh)
}
func (sender *PacketSender) WaitForDone() {
<-sender.DoneCh
}
func (sender *PacketSender) checkStatus() error {
sender.Lock.Lock()
defer sender.Lock.Unlock()
if sender.Done {
return fmt.Errorf("cannot send packet, sender write loop is closed")
}
if sender.Err != nil {
return fmt.Errorf("cannot send packet, sender had error: %w", sender.Err)
}
return nil
}
func (sender *PacketSender) SendPacket(pk PacketType) error {
err := sender.checkStatus()
if err != nil {
return err
}
sender.SendCh <- pk
return nil
}
func (sender *PacketSender) SendErrorPacket(errVal string) error {
return sender.SendPacket(MakeErrorPacket(errVal))
}
func PacketParser(input io.Reader) chan PacketType {
bufReader := bufio.NewReader(input)
rtnCh := make(chan PacketType)

View File

@ -7,7 +7,6 @@
package shexec
import (
"encoding/json"
"errors"
"fmt"
"io"
@ -24,11 +23,6 @@ import (
"github.com/scripthaus-dev/sh2-runner/pkg/packet"
)
type DoneData struct {
DurationMs int64 `json:"durationms"`
ExitCode int `json:"exitcode"`
}
type ShExecType struct {
FileNames *base.CommandFileNames
Cmd *exec.Cmd
@ -95,6 +89,15 @@ func MakeExecCmd(pk *packet.RunPacketType, cmdTty *os.File) *exec.Cmd {
return ecmd
}
func MakeRunnerExec(cmdId string) (*exec.Cmd, error) {
runnerPath, err := base.GetScRunnerPath()
if err != nil {
return nil, err
}
ecmd := exec.Command(runnerPath, cmdId)
return ecmd, nil
}
// this will never return (unless there is an error creating/opening the file), as fifoFile will never EOF
func MakeAndCopyStdinFifo(dst *os.File, fifoName string) error {
os.Remove(fifoName)
@ -147,8 +150,8 @@ func ValidateRunPacket(pk *packet.RunPacketType) error {
return nil
}
// returning nil error means the process has successfully been kicked-off
func RunCommand(pk *packet.RunPacketType) (*ShExecType, error) {
// when err is nil, the command will have already been started
func RunCommand(pk *packet.RunPacketType, sender *packet.PacketSender) (*ShExecType, error) {
if pk.CmdId == "" {
pk.CmdId = uuid.New().String()
}
@ -184,14 +187,14 @@ func RunCommand(pk *packet.RunPacketType) (*ShExecType, error) {
// copy pty output to .ptyout file
_, copyErr := io.Copy(ptyOutFd, cmdPty)
if copyErr != nil {
base.WriteErrorMsg(fileNames.PtyOutFile, fmt.Sprintf("copying pty output to ptyout file: %v", copyErr))
sender.SendErrorPacket(fmt.Sprintf("copying pty output to ptyout file: %v", copyErr))
}
}()
go func() {
// copy .stdin fifo contents to pty input
copyFifoErr := MakeAndCopyStdinFifo(cmdPty, fileNames.StdinFifo)
if copyFifoErr != nil {
base.WriteErrorMsg(fileNames.PtyOutFile, fmt.Sprintf("reading from stdin fifo: %v", copyFifoErr))
sender.SendErrorPacket(fmt.Sprintf("reading from stdin fifo: %v", copyFifoErr))
}
}()
return &ShExecType{
@ -202,9 +205,10 @@ func RunCommand(pk *packet.RunPacketType) (*ShExecType, error) {
}, nil
}
func (c *ShExecType) WaitForCommand() {
func (c *ShExecType) WaitForCommand(cmdId string) *packet.CmdDonePacketType {
err := c.Cmd.Wait()
cmdDuration := time.Since(c.StartTs)
endTs := time.Now()
cmdDuration := endTs.Sub(c.StartTs)
exitCode := 0
if err != nil {
exitErr, ok := err.(*exec.ExitError)
@ -212,15 +216,11 @@ func (c *ShExecType) WaitForCommand() {
exitCode = exitErr.ExitCode()
}
}
doneData := DoneData{
DurationMs: int64(cmdDuration / time.Millisecond),
ExitCode: exitCode,
}
doneDataBytes, _ := json.Marshal(doneData)
doneDataBytes = append(doneDataBytes, '\n')
err = os.WriteFile(c.FileNames.DoneFile, doneDataBytes, 0600)
if err != nil {
base.WriteErrorMsg(c.FileNames.PtyOutFile, fmt.Sprintf("reading from stdin fifo: %v", err))
}
return
donePacket := packet.MakeCmdDonePacket()
donePacket.Ts = endTs.UnixMilli()
donePacket.CmdId = cmdId
donePacket.ExitCode = exitCode
donePacket.DurationMs = int64(cmdDuration / time.Millisecond)
os.Remove(c.FileNames.StdinFifo) // best effort (no need to check error)
return donePacket
}