From ecceb67f2063353ee94ba7d90be13b731d5d6b40 Mon Sep 17 00:00:00 2001 From: sawka Date: Tue, 14 Jun 2022 22:16:58 -0700 Subject: [PATCH] got session/command tailing working. server can send getcmd packets, and client responds with cmddata packets --- go.mod | 2 + go.sum | 4 + main-runner.go | 37 ++++++- pkg/base/base.go | 19 +++- pkg/cmdtail/cmdtail.go | 220 +++++++++++++++++++++++++++++++++++++++-- pkg/packet/packet.go | 80 ++++++++++++++- 6 files changed, 350 insertions(+), 12 deletions(-) diff --git a/go.mod b/go.mod index a794172c0..e64ed5405 100644 --- a/go.mod +++ b/go.mod @@ -4,5 +4,7 @@ go 1.17 require ( github.com/creack/pty v1.1.18 // indirect + github.com/fsnotify/fsnotify v1.5.4 // indirect github.com/google/uuid v1.3.0 // indirect + golang.org/x/sys v0.0.0-20220412211240-33da011f77ad // indirect ) diff --git a/go.sum b/go.sum index 12d5eadca..fd7b5ca7f 100644 --- a/go.sum +++ b/go.sum @@ -1,4 +1,8 @@ github.com/creack/pty v1.1.18 h1:n56/Zwd5o6whRC5PMGretI4IdRLlmBXYNjScPaBgsbY= github.com/creack/pty v1.1.18/go.mod h1:MOBLtS5ELjhRRrroQr9kyvTxUAFNvYEK993ew/Vr4O4= +github.com/fsnotify/fsnotify v1.5.4 h1:jRbGcIw6P2Meqdwuo0H1p6JVLbL5DHKAKlYndzMwVZI= +github.com/fsnotify/fsnotify v1.5.4/go.mod h1:OVB6XrOHzAwXMpEM7uPOzcehqUV2UqJxmVXmkdnm1bU= github.com/google/uuid v1.3.0 h1:t6JiXgmwXMjEs8VusXIJk2BXHsn+wx8BZdTaoZ5fu7I= github.com/google/uuid v1.3.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= +golang.org/x/sys v0.0.0-20220412211240-33da011f77ad h1:ntjMns5wyP/fN65tdBD4g8J5w8n015+iIIs9rtjXkY0= +golang.org/x/sys v0.0.0-20220412211240-33da011f77ad/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= diff --git a/main-runner.go b/main-runner.go index 19b0ba3c2..43a32904f 100644 --- a/main-runner.go +++ b/main-runner.go @@ -15,6 +15,7 @@ import ( "github.com/google/uuid" "github.com/scripthaus-dev/sh2-runner/pkg/base" + "github.com/scripthaus-dev/sh2-runner/pkg/cmdtail" "github.com/scripthaus-dev/sh2-runner/pkg/packet" "github.com/scripthaus-dev/sh2-runner/pkg/shexec" ) @@ -125,15 +126,26 @@ func doMainRun(pk *packet.RunPacketType, sender *packet.PacketSender) { }() } +func doGetCmd(tailer *cmdtail.Tailer, pk *packet.GetCmdPacketType, sender *packet.PacketSender) error { + // non-tail packets? + sender.SendPacket(packet.MakeMessagePacket(fmt.Sprintf("getcmd %s", pk.CmdId))) + err := tailer.AddWatch(pk) + if err != nil { + return err + } + return nil +} + func doMain() { - homeDir, err := base.GetScHomeDir() + scHomeDir, err := base.GetScHomeDir() if err != nil { packet.SendErrorPacket(os.Stdout, err.Error()) return } + homeDir := base.GetHomeDir() err = os.Chdir(homeDir) if err != nil { - packet.SendErrorPacket(os.Stdout, fmt.Sprintf("cannot change directory to scripthaus home '%s': %v", homeDir, err)) + packet.SendErrorPacket(os.Stdout, fmt.Sprintf("cannot change directory to $HOME '%s': %v", homeDir, err)) return } err = base.EnsureRunnerPath() @@ -143,7 +155,18 @@ func doMain() { } packetCh := packet.PacketParser(os.Stdin) sender := packet.MakePacketSender(os.Stdout) - sender.SendPacket(packet.MakeMessagePacket(fmt.Sprintf("starting scripthaus runner @ %s", homeDir))) + tailer, err := cmdtail.MakeTailer(sender) + if err != nil { + packet.SendErrorPacket(os.Stdout, err.Error()) + return + } + go tailer.Run() + sender.SendPacket(packet.MakeMessagePacket(fmt.Sprintf("starting scripthaus runner @ %s", scHomeDir))) + initPacket := packet.MakeRunnerInitPacket() + initPacket.Env = os.Environ() + initPacket.HomeDir = homeDir + initPacket.ScHomeDir = scHomeDir + sender.SendPacket(initPacket) for pk := range packetCh { if pk.GetType() == packet.PingPacketStr { continue @@ -152,6 +175,14 @@ func doMain() { doMainRun(pk.(*packet.RunPacketType), sender) continue } + if pk.GetType() == packet.GetCmdPacketStr { + err = doGetCmd(tailer, pk.(*packet.GetCmdPacketType), sender) + if err != nil { + errPk := packet.MakeErrorPacket(err.Error()) + sender.SendPacket(errPk) + } + continue + } if pk.GetType() == packet.ErrorPacketStr { errPk := pk.(*packet.ErrorPacketType) errPk.Error = "invalid packet sent to runner: " + errPk.Error diff --git a/pkg/base/base.go b/pkg/base/base.go index fd2d64ed0..5da63c01e 100644 --- a/pkg/base/base.go +++ b/pkg/base/base.go @@ -32,6 +32,14 @@ type CommandFileNames struct { RunnerOutFile string } +func GetHomeDir() string { + homeVar := os.Getenv(HomeVarName) + if homeVar == "" { + return "/" + } + return homeVar +} + func GetScHomeDir() (string, error) { scHome := os.Getenv(ScHomeVarName) if scHome == "" { @@ -60,6 +68,15 @@ func GetCommandFileNames(sessionId string, cmdId string) (*CommandFileNames, err }, nil } +func MakeCommandFileNamesWithHome(scHome string, sessionId string, cmdId string) *CommandFileNames { + base := path.Join(scHome, SessionsDirBaseName, sessionId, cmdId) + return &CommandFileNames{ + PtyOutFile: base + ".ptyout", + StdinFifo: base + ".stdin", + RunnerOutFile: base + ".runout", + } +} + func CleanUpCmdFiles(sessionId string, cmdId string) error { if cmdId == "" { return fmt.Errorf("bad cmdid, cannot clean up") @@ -90,7 +107,7 @@ func EnsureSessionDir(sessionId string) (string, error) { if err != nil { return "", err } - sdir := path.Join(shhome, ".sessions", sessionId) + sdir := path.Join(shhome, SessionsDirBaseName, sessionId) info, err := os.Stat(sdir) if errors.Is(err, fs.ErrNotExist) { err = os.MkdirAll(sdir, 0777) diff --git a/pkg/cmdtail/cmdtail.go b/pkg/cmdtail/cmdtail.go index f95c5d822..219ec2369 100644 --- a/pkg/cmdtail/cmdtail.go +++ b/pkg/cmdtail/cmdtail.go @@ -7,17 +7,30 @@ package cmdtail import ( + "fmt" + "io" + "os" + "path" + "regexp" "sync" + "time" "github.com/fsnotify/fsnotify" + "github.com/google/uuid" + "github.com/scripthaus-dev/sh2-runner/pkg/base" "github.com/scripthaus-dev/sh2-runner/pkg/packet" ) +const MaxDataBytes = 4096 + type TailPos struct { - CmdKey CmdKey - Pos int - RunOut bool - RunOutPos int + CmdKey CmdKey + Running bool // an active tailer sending data + Version int + FilePtyLen int64 + FileRunLen int64 + TailPtyPos int64 + TailRunPos int64 } type CmdKey struct { @@ -30,15 +43,22 @@ type Tailer struct { WatchList map[CmdKey]TailPos Sessions map[string]bool Watcher *fsnotify.Watcher + ScHomeDir string + Sender *packet.PacketSender } -func MakeTailer() (*Tailer, error) { +func MakeTailer(sender *packet.PacketSender) (*Tailer, error) { + scHomeDir, err := base.GetScHomeDir() + if err != nil { + return nil, err + } rtn := &Tailer{ Lock: &sync.Mutex{}, WatchList: make(map[CmdKey]TailPos), Sessions: make(map[string]bool), + ScHomeDir: scHomeDir, + Sender: sender, } - var err error rtn.Watcher, err = fsnotify.NewWatcher() if err != nil { return nil, err @@ -46,6 +66,192 @@ func MakeTailer() (*Tailer, error) { return rtn, nil } -func AddWatch(getPacket *packet.GetCmdPacketType) error { +func (t *Tailer) readDataFromFile(fileName string, pos int64, maxBytes int) ([]byte, error) { + fd, err := os.Open(fileName) + defer fd.Close() + if err != nil { + return nil, err + } + buf := make([]byte, maxBytes) + nr, err := fd.ReadAt(buf, pos) + if err != nil && err != io.EOF { // ignore EOF error + return nil, err + } + return buf[0:nr], nil +} + +func (t *Tailer) makeCmdDataPacket(fileNames *base.CommandFileNames, pos TailPos) *packet.CmdDataPacketType { + dataPacket := packet.MakeCmdDataPacket() + dataPacket.SessionId = pos.CmdKey.SessionId + dataPacket.CmdId = pos.CmdKey.CmdId + dataPacket.PtyPos = pos.TailPtyPos + dataPacket.RunPos = pos.TailRunPos + if pos.FilePtyLen > pos.TailPtyPos { + ptyData, err := t.readDataFromFile(fileNames.PtyOutFile, pos.TailPtyPos, MaxDataBytes) + if err != nil { + dataPacket.Error = err.Error() + return dataPacket + } + dataPacket.PtyData = string(ptyData) + } + if pos.FileRunLen > pos.TailRunPos { + runData, err := t.readDataFromFile(fileNames.RunnerOutFile, pos.TailRunPos, MaxDataBytes) + if err != nil { + dataPacket.Error = err.Error() + return dataPacket + } + dataPacket.RunData = string(runData) + } + return dataPacket +} + +var updateFileRe = regexp.MustCompile("/([a-z0-9-]+)/([a-z0-9-]+)\\.(ptyout|runout)$") + +// returns (data-packet, keepRunning) +func (t *Tailer) runSingleDataTransfer(key CmdKey) (*packet.CmdDataPacketType, bool) { + t.Lock.Lock() + pos, foundPos := t.WatchList[key] + t.Lock.Unlock() + if !foundPos { + return nil, false + } + fileNames := base.MakeCommandFileNamesWithHome(t.ScHomeDir, key.SessionId, key.CmdId) + dataPacket := t.makeCmdDataPacket(fileNames, pos) + + t.Lock.Lock() + defer t.Lock.Unlock() + pos, foundPos = t.WatchList[key] + if !foundPos { + return nil, false + } + // pos was updated between first and second get, throw out data-packet and re-run + if pos.TailPtyPos != dataPacket.PtyPos || pos.TailRunPos != dataPacket.RunPos { + return nil, true + } + if dataPacket.Error != "" { + // error, so return error packet, and stop running + pos.Running = false + t.WatchList[key] = pos + return dataPacket, false + } + pos.TailPtyPos += int64(len(dataPacket.PtyData)) + pos.TailRunPos += int64(len(dataPacket.RunData)) + if pos.TailPtyPos > pos.FilePtyLen { + pos.FilePtyLen = pos.TailPtyPos + } + if pos.TailRunPos > pos.FileRunLen { + pos.FileRunLen = pos.TailRunPos + } + if pos.TailPtyPos >= pos.FilePtyLen && pos.TailRunPos >= pos.FileRunLen { + // we caught up, tail position equals file length + pos.Running = false + } + t.WatchList[key] = pos + return dataPacket, pos.Running +} + +func (t *Tailer) RunDataTransfer(key CmdKey) { + for { + dataPacket, keepRunning := t.runSingleDataTransfer(key) + if dataPacket != nil { + t.Sender.SendPacket(dataPacket) + } + if !keepRunning { + break + } + time.Sleep(10 * time.Millisecond) + } +} + +func (t *Tailer) UpdateFile(relFileName string) { + m := updateFileRe.FindStringSubmatch(relFileName) + if m == nil { + return + } + finfo, err := os.Stat(relFileName) + if err != nil { + t.Sender.SendMessage("error stating file '%s': %w", relFileName, err) + return + } + isPtyFile := m[3] == "ptyout" + cmdKey := CmdKey{m[1], m[2]} + fileSize := finfo.Size() + t.Lock.Lock() + defer t.Lock.Unlock() + pos, foundPos := t.WatchList[cmdKey] + if !foundPos { + return + } + if isPtyFile { + pos.FilePtyLen = fileSize + } else { + pos.FileRunLen = fileSize + } + t.WatchList[cmdKey] = pos + if !pos.Running && (pos.FilePtyLen > pos.TailPtyPos || pos.FileRunLen > pos.TailRunPos) { + go t.RunDataTransfer(cmdKey) + } +} + +func (t *Tailer) Run() { + for { + select { + case event, ok := <-t.Watcher.Events: + if !ok { + return + } + if (event.Op&fsnotify.Write == fsnotify.Write) || (event.Op&fsnotify.Create == fsnotify.Create) { + t.UpdateFile(event.Name) + } + + case err, ok := <-t.Watcher.Errors: + if !ok { + return + } + // what to do with watcher error? + t.Sender.SendMessage("error in tailer '%v'", err) + } + } +} + +func (tp *TailPos) fillFilePos(scHomeDir string) { + fileNames := base.MakeCommandFileNamesWithHome(scHomeDir, tp.CmdKey.SessionId, tp.CmdKey.CmdId) + ptyInfo, _ := os.Stat(fileNames.PtyOutFile) + if ptyInfo != nil { + tp.FilePtyLen = ptyInfo.Size() + } + runoutInfo, _ := os.Stat(fileNames.RunnerOutFile) + if runoutInfo != nil { + tp.FileRunLen = runoutInfo.Size() + } +} + +func (t *Tailer) AddWatch(getPacket *packet.GetCmdPacketType) error { + if !getPacket.Tail { + return fmt.Errorf("cannot add a watch for non-tail packet") + } + _, err := uuid.Parse(getPacket.SessionId) + if err != nil { + return fmt.Errorf("getcmd, bad sessionid '%s': %w", getPacket.SessionId, err) + } + _, err = uuid.Parse(getPacket.CmdId) + if err != nil { + return fmt.Errorf("getcmd, bad cmdid '%s': %w", getPacket.CmdId, err) + } + t.Lock.Lock() + defer t.Lock.Unlock() + key := CmdKey{getPacket.SessionId, getPacket.CmdId} + if !t.Sessions[getPacket.SessionId] { + sessionDir := path.Join(t.ScHomeDir, base.SessionsDirBaseName, getPacket.SessionId) + err = t.Watcher.Add(sessionDir) + if err != nil { + return fmt.Errorf("error adding watcher for session dir '%s': %v", sessionDir, err) + } + t.Sessions[getPacket.SessionId] = true + } + oldPos := t.WatchList[key] + pos := TailPos{CmdKey: key, TailPtyPos: getPacket.PtyPos, TailRunPos: getPacket.RunPos, Version: oldPos.Version + 1} + pos.fillFilePos(t.ScHomeDir) + t.WatchList[key] = pos return nil } diff --git a/pkg/packet/packet.go b/pkg/packet/packet.go index 0a5f24065..987650eb1 100644 --- a/pkg/packet/packet.go +++ b/pkg/packet/packet.go @@ -24,6 +24,10 @@ const CmdStartPacketStr = "cmdstart" const CmdDonePacketStr = "cmddone" const ListCmdPacketStr = "lscmd" const GetCmdPacketStr = "getcmd" +const RunnerInitPacketStr = "runnerinit" +const CdPacketStr = "cd" +const CdResponseStr = "cdresp" +const CmdDataPacketStr = "cmddata" var TypeStrToFactory map[string]reflect.Type @@ -38,6 +42,11 @@ func init() { TypeStrToFactory[CmdDonePacketStr] = reflect.TypeOf(CmdDonePacketType{}) TypeStrToFactory[ListCmdPacketStr] = reflect.TypeOf(ListCmdPacketType{}) TypeStrToFactory[GetCmdPacketStr] = reflect.TypeOf(GetCmdPacketType{}) + TypeStrToFactory[RunnerInitPacketStr] = reflect.TypeOf(RunnerInitPacketType{}) + TypeStrToFactory[CdPacketStr] = reflect.TypeOf(CdPacketType{}) + TypeStrToFactory[CdResponseStr] = reflect.TypeOf(CdResponseType{}) + TypeStrToFactory[CmdDataPacketStr] = reflect.TypeOf(CmdDataPacketType{}) + } func MakePacket(packetType string) (PacketType, error) { @@ -49,6 +58,26 @@ func MakePacket(packetType string) (PacketType, error) { return rtn.Interface().(PacketType), nil } +type CmdDataPacketType struct { + Type string `json:"type"` + SessionId string `json:"sessionid"` + CmdId string `json:"cmdid"` + PtyPos int64 `json:"ptypos"` + RunPos int64 `json:"runpos"` + PtyData string `json:"ptydata"` + RunData string `json:"rundata"` + Done bool `json:"done"` + Error string `json:"error"` +} + +func (*CmdDataPacketType) GetType() string { + return CmdDataPacketStr +} + +func MakeCmdDataPacket() *CmdDataPacketType { + return &CmdDataPacketType{Type: CmdDataPacketStr} +} + type PingPacketType struct { Type string `json:"type"` } @@ -65,8 +94,9 @@ type GetCmdPacketType struct { Type string `json:"type"` SessionId string `json:"sessionid"` CmdId string `json:"cmdid"` + PtyPos int64 `json:"ptypos"` + RunPos int64 `json:"runpos"` Tail bool `json:"tail,omitempty"` - RunOut bool `json:"runout,omitempty"` } func (*GetCmdPacketType) GetType() string { @@ -90,6 +120,35 @@ func MakeListCmdPacket(sessionId string) *ListCmdPacketType { return &ListCmdPacketType{Type: ListCmdPacketStr, SessionId: sessionId} } +type CdPacketType struct { + Type string `json:"type"` + PacketId string `json:"packetid"` + Dir string `json:"dir"` +} + +func (*CdPacketType) GetType() string { + return CdPacketStr +} + +func MakeCdPacket() *CdPacketType { + return &CdPacketType{Type: CdPacketStr} +} + +type CdResponseType struct { + Type string `json:"type"` + PacketId string `json:"packetid"` + Success bool `json:"success"` + Error string `json:"error"` +} + +func (*CdResponseType) GetType() string { + return CdResponseStr +} + +func MakeCdResponse() *CdResponseType { + return &CdResponseType{Type: CdResponseStr} +} + type MessagePacketType struct { Type string `json:"type"` Message string `json:"message"` @@ -103,6 +162,21 @@ func MakeMessagePacket(message string) *MessagePacketType { return &MessagePacketType{Type: MessagePacketStr, Message: message} } +type RunnerInitPacketType struct { + Type string `json:"type"` + ScHomeDir string `json:"schomedir"` + HomeDir string `json:"homedir"` + Env []string `json:"env"` +} + +func (*RunnerInitPacketType) GetType() string { + return RunnerInitPacketStr +} + +func MakeRunnerInitPacket() *RunnerInitPacketType { + return &RunnerInitPacketType{Type: RunnerInitPacketStr} +} + type DonePacketType struct { Type string `json:"type"` } @@ -297,6 +371,10 @@ func (sender *PacketSender) SendErrorPacket(errVal string) error { return sender.SendPacket(MakeErrorPacket(errVal)) } +func (sender *PacketSender) SendMessage(fmtStr string, args ...interface{}) error { + return sender.SendPacket(MakeMessagePacket(fmt.Sprintf(fmtStr, args...))) +} + func PacketParser(input io.Reader) chan PacketType { bufReader := bufio.NewReader(input) rtnCh := make(chan PacketType)