diff --git a/main-runner.go b/main-runner.go index 12a1e5aa6..f8b8dc398 100644 --- a/main-runner.go +++ b/main-runner.go @@ -101,6 +101,13 @@ func doMainRun(pk *packet.RunPacketType, sender *packet.PacketSender) { sender.SendPacket(packet.MakeIdErrorPacket(pk.CmdId, fmt.Sprintf("cannot pipe stdin to command: %v", err))) return } + // touch ptyout file (should exist for tailer to work correctly) + ptyOutFd, err := os.OpenFile(fileNames.PtyOutFile, os.O_CREATE|os.O_TRUNC|os.O_APPEND|os.O_WRONLY, 0600) + if err != nil { + sender.SendPacket(packet.MakeIdErrorPacket(pk.CmdId, fmt.Sprintf("cannot open pty out file '%s': %v", fileNames.PtyOutFile, err))) + return + } + ptyOutFd.Close() // just opened to create the file, can close right after runnerOutFd, err := os.OpenFile(fileNames.RunnerOutFile, os.O_CREATE|os.O_TRUNC|os.O_APPEND|os.O_WRONLY, 0600) if err != nil { sender.SendPacket(packet.MakeIdErrorPacket(pk.CmdId, fmt.Sprintf("cannot open runner out file '%s': %v", fileNames.RunnerOutFile, err))) @@ -199,6 +206,7 @@ func main() { return } doSingle(cmdId.String()) + time.Sleep(100 * time.Millisecond) return } else { doMain() diff --git a/pkg/cmdtail/cmdtail.go b/pkg/cmdtail/cmdtail.go index 9bf1a6f54..44b44cc58 100644 --- a/pkg/cmdtail/cmdtail.go +++ b/pkg/cmdtail/cmdtail.go @@ -10,15 +10,19 @@ import ( "fmt" "io" "os" + "regexp" "sync" "time" + "github.com/fsnotify/fsnotify" "github.com/google/uuid" "github.com/scripthaus-dev/sh2-runner/pkg/base" "github.com/scripthaus-dev/sh2-runner/pkg/packet" ) const MaxDataBytes = 4096 +const FileTypePty = "ptyout" +const FileTypeRun = "runout" type TailPos struct { ReqId string @@ -78,7 +82,7 @@ type Tailer struct { Lock *sync.Mutex WatchList map[CmdKey]CmdWatchEntry ScHomeDir string - Watcher *SessionWatcher + Watcher *fsnotify.Watcher SendCh chan packet.PacketType } @@ -91,6 +95,24 @@ func (t *Tailer) updateTailPos_nolock(cmdKey CmdKey, reqId string, pos TailPos) t.WatchList[cmdKey] = entry } +func (t *Tailer) removeTailPos_nolock(cmdKey CmdKey, reqId string) { + entry, found := t.WatchList[cmdKey] + if !found { + return + } + entry.removeTailPos(reqId) + if len(entry.Tails) > 0 { + t.WatchList[cmdKey] = entry + return + } + + // delete from watchlist, remove watches + fileNames := base.MakeCommandFileNamesWithHome(t.ScHomeDir, cmdKey.SessionId, cmdKey.CmdId) + delete(t.WatchList, cmdKey) + t.Watcher.Remove(fileNames.PtyOutFile) + t.Watcher.Remove(fileNames.RunnerOutFile) +} + func (t *Tailer) updateEntrySizes_nolock(cmdKey CmdKey, ptyLen int64, runLen int64) { entry, found := t.WatchList[cmdKey] if !found { @@ -124,7 +146,7 @@ func MakeTailer(sendCh chan packet.PacketType) (*Tailer, error) { ScHomeDir: scHomeDir, SendCh: sendCh, } - rtn.Watcher, err = MakeSessionWatcher() + rtn.Watcher, err = fsnotify.NewWatcher() if err != nil { return nil, err } @@ -213,17 +235,12 @@ func (t *Tailer) runSingleDataTransfer(key CmdKey, reqId string) (*packet.CmdDat func (t *Tailer) checkRemoveNoFollow(cmdKey CmdKey, reqId string) { t.Lock.Lock() defer t.Lock.Unlock() - entry, pos, foundPos := t.getEntryAndPos_nolock(cmdKey, reqId) + _, pos, foundPos := t.getEntryAndPos_nolock(cmdKey, reqId) if !foundPos { return } if !pos.Follow { - entry.removeTailPos(reqId) - if len(entry.Tails) == 0 { - delete(t.WatchList, cmdKey) - } else { - t.WatchList[cmdKey] = entry - } + t.removeTailPos_nolock(cmdKey, reqId) } } @@ -241,9 +258,12 @@ func (t *Tailer) RunDataTransfer(key CmdKey, reqId string) { } } -// should already hold t.Lock func (t *Tailer) tryStartRun_nolock(entry CmdWatchEntry, pos TailPos) { - if pos.Running || pos.IsCurrent(entry) { + if pos.Running { + return + } + if pos.IsCurrent(entry) { + return } pos.Running = true @@ -251,22 +271,30 @@ func (t *Tailer) tryStartRun_nolock(entry CmdWatchEntry, pos TailPos) { go t.RunDataTransfer(entry.CmdKey, pos.ReqId) } -func (t *Tailer) updateFile(event FileUpdateEvent) { - if event.Err != nil { - t.SendCh <- packet.FmtMessagePacket("error in FileUpdateEvent %s/%s: %v", event.SessionId, event.CmdId, event.Err) +var updateFileRe = regexp.MustCompile("/([a-z0-9-]+)/([a-z0-9-]+)\\.(ptyout|runout)$") + +func (t *Tailer) updateFile(relFileName string) { + m := updateFileRe.FindStringSubmatch(relFileName) + if m == nil { return } - cmdKey := CmdKey{SessionId: event.SessionId, CmdId: event.CmdId} + finfo, err := os.Stat(relFileName) + if err != nil { + t.SendCh <- packet.FmtMessagePacket("error trying to stat file '%s': %v", relFileName, err) + return + } + cmdKey := CmdKey{SessionId: m[1], CmdId: m[2]} t.Lock.Lock() defer t.Lock.Unlock() entry, foundEntry := t.WatchList[cmdKey] if !foundEntry { return } - if event.FileType == FileTypePty { - entry.FilePtyLen = event.Size - } else if event.FileType == FileTypeRun { - entry.FileRunLen = event.Size + fileType := m[3] + if fileType == FileTypePty { + entry.FilePtyLen = finfo.Size() + } else if fileType == FileTypeRun { + entry.FileRunLen = finfo.Size() } t.WatchList[cmdKey] = entry for _, pos := range entry.Tails { @@ -274,14 +302,26 @@ func (t *Tailer) updateFile(event FileUpdateEvent) { } } -func (t *Tailer) Run() error { - go func() { - for event := range t.Watcher.EventCh { - t.updateFile(event) +func (t *Tailer) Run() { + for { + select { + case event, ok := <-t.Watcher.Events: + if !ok { + return + } + if event.Op&fsnotify.Write == fsnotify.Write { + t.updateFile(event.Name) + } + + case err, ok := <-t.Watcher.Errors: + if !ok { + return + } + // what to do with this error? just send a message + t.SendCh <- packet.FmtMessagePacket("error in tailer: %v", err) } - }() - err := t.Watcher.Run(nil) - return err + } + return } func (t *Tailer) Close() error { @@ -307,6 +347,13 @@ func (entry *CmdWatchEntry) fillFilePos(scHomeDir string) { } } +func (t *Tailer) RemoveWatch(pk *packet.UntailCmdPacketType) { + t.Lock.Lock() + defer t.Lock.Unlock() + key := CmdKey{pk.SessionId, pk.CmdId} + t.removeTailPos_nolock(key, pk.ReqId) +} + func (t *Tailer) AddWatch(getPacket *packet.GetCmdPacketType) error { _, err := uuid.Parse(getPacket.SessionId) if err != nil { @@ -319,19 +366,34 @@ func (t *Tailer) AddWatch(getPacket *packet.GetCmdPacketType) error { if getPacket.ReqId == "" { return fmt.Errorf("getcmd, no reqid specified") } + fileNames := base.MakeCommandFileNamesWithHome(t.ScHomeDir, getPacket.SessionId, getPacket.CmdId) t.Lock.Lock() defer t.Lock.Unlock() key := CmdKey{getPacket.SessionId, getPacket.CmdId} - err = t.Watcher.WatchSession(getPacket.SessionId) - if err != nil { - return fmt.Errorf("error trying to watch sesion '%s': %v", getPacket.SessionId, err) - } entry, foundEntry := t.WatchList[key] if !foundEntry { + // add watches, initialize entry + err = t.Watcher.Add(fileNames.PtyOutFile) + if err != nil { + return err + } + err = t.Watcher.Add(fileNames.RunnerOutFile) + if err != nil { + t.Watcher.Remove(fileNames.PtyOutFile) // best effort clean up + return err + } entry = CmdWatchEntry{CmdKey: key} entry.fillFilePos(t.ScHomeDir) } - pos := TailPos{ReqId: getPacket.ReqId, TailPtyPos: getPacket.PtyPos, TailRunPos: getPacket.RunPos, Follow: getPacket.Tail} + pos, foundPos := entry.getTailPos(getPacket.ReqId) + if !foundPos { + // initialize a new tailpos + pos = TailPos{ReqId: getPacket.ReqId} + } + // update tailpos with new values from getpacket + pos.TailPtyPos = getPacket.PtyPos + pos.TailRunPos = getPacket.RunPos + pos.Follow = getPacket.Tail // convert negative pos to positive if pos.TailPtyPos < 0 { pos.TailPtyPos = max(0, entry.FilePtyLen+pos.TailPtyPos) // + because negative diff --git a/pkg/cmdtail/sessionwatcher.go b/pkg/cmdtail/sessionwatcher.go deleted file mode 100644 index 1e736cb89..000000000 --- a/pkg/cmdtail/sessionwatcher.go +++ /dev/null @@ -1,165 +0,0 @@ -// Copyright 2022 Dashborg Inc -// -// This Source Code Form is subject to the terms of the Mozilla Public -// License, v. 2.0. If a copy of the MPL was not distributed with this -// file, You can obtain one at https://mozilla.org/MPL/2.0/. - -package cmdtail - -import ( - "fmt" - "os" - "path" - "regexp" - "sync" - - "github.com/fsnotify/fsnotify" - "github.com/google/uuid" - "github.com/scripthaus-dev/sh2-runner/pkg/base" -) - -const FileTypePty = "ptyout" -const FileTypeRun = "runout" -const eventChSize = 10 - -type FileUpdateEvent struct { - SessionId string - CmdId string - FileType string - Size int64 - Err error -} - -type SessionWatcher struct { - Lock *sync.Mutex - Sessions map[string]bool - ScHomeDir string - Watcher *fsnotify.Watcher - EventCh chan FileUpdateEvent - Err error - Running bool -} - -func MakeSessionWatcher() (*SessionWatcher, error) { - scHomeDir, err := base.GetScHomeDir() - if err != nil { - return nil, err - } - rtn := &SessionWatcher{ - Lock: &sync.Mutex{}, - Sessions: make(map[string]bool), - ScHomeDir: scHomeDir, - EventCh: make(chan FileUpdateEvent, eventChSize), - } - rtn.Watcher, err = fsnotify.NewWatcher() - if err != nil { - return nil, err - } - return rtn, nil -} - -func (w *SessionWatcher) Close() error { - return w.Watcher.Close() -} - -func (w *SessionWatcher) UnWatchSession(sessionId string) error { - _, err := uuid.Parse(sessionId) - if err != nil { - return fmt.Errorf("WatchSession, bad sessionid '%s': %w", sessionId, err) - } - w.Lock.Lock() - defer w.Lock.Unlock() - if !w.Sessions[sessionId] { - return nil - } - sessionDir := path.Join(w.ScHomeDir, base.SessionsDirBaseName, sessionId) - err = w.Watcher.Remove(sessionDir) - if err != nil { - return err - } - w.Sessions[sessionId] = false - return nil -} - -func (w *SessionWatcher) WatchSession(sessionId string) error { - _, err := uuid.Parse(sessionId) - if err != nil { - return fmt.Errorf("WatchSession, bad sessionid '%s': %w", sessionId, err) - } - - w.Lock.Lock() - defer w.Lock.Unlock() - if w.Sessions[sessionId] { - return nil - } - sessionDir := path.Join(w.ScHomeDir, base.SessionsDirBaseName, sessionId) - err = w.Watcher.Add(sessionDir) - if err != nil { - return err - } - w.Sessions[sessionId] = true - return nil -} - -func (w *SessionWatcher) setRunning() bool { - w.Lock.Lock() - defer w.Lock.Unlock() - if w.Running { - return false - } - w.Running = true - return true -} - -var swUpdateFileRe = regexp.MustCompile("/([a-z0-9-]+)/([a-z0-9-]+)\\.(ptyout|runout)$") - -func (w *SessionWatcher) updateFile(relFileName string) { - m := swUpdateFileRe.FindStringSubmatch(relFileName) - if m == nil { - return - } - event := FileUpdateEvent{SessionId: m[1], CmdId: m[2], FileType: m[3]} - finfo, err := os.Stat(relFileName) - if err != nil { - event.Err = err - w.EventCh <- event - return - } - event.Size = finfo.Size() - w.EventCh <- event - return -} - -func (w *SessionWatcher) Run(stopCh chan bool) error { - ok := w.setRunning() - if !ok { - return fmt.Errorf("Cannot run SessionWatcher (alreaady running)") - } - defer func() { - w.Lock.Lock() - defer w.Lock.Unlock() - w.Running = false - close(w.EventCh) - }() - for { - select { - case event, ok := <-w.Watcher.Events: - if !ok { - return nil - } - if (event.Op&fsnotify.Write == fsnotify.Write) || (event.Op&fsnotify.Create == fsnotify.Create) { - w.updateFile(event.Name) - } - - case err, ok := <-w.Watcher.Errors: - if !ok { - return nil - } - return fmt.Errorf("Got error in SessionWatcher: %w", err) - - case <-stopCh: - return nil - } - } - return nil -} diff --git a/pkg/shexec/shexec.go b/pkg/shexec/shexec.go index 3b0e5c201..e872e34b6 100644 --- a/pkg/shexec/shexec.go +++ b/pkg/shexec/shexec.go @@ -7,10 +7,8 @@ package shexec import ( - "errors" "fmt" "io" - "io/fs" "os" "os/exec" "strings" @@ -163,8 +161,12 @@ func RunCommand(pk *packet.RunPacketType, sender *packet.PacketSender) (*ShExecT if err != nil { return nil, err } - if _, err = os.Stat(fileNames.PtyOutFile); !errors.Is(err, fs.ErrNotExist) { - return nil, fmt.Errorf("cmdid '%s' was already used", pk.CmdId) + ptyOutInfo, err := os.Stat(fileNames.PtyOutFile) + if err == nil { // non-nil error will be caught by regular OpenFile below + // must have size 0 + if ptyOutInfo.Size() != 0 { + return nil, fmt.Errorf("cmdid '%s' was already used (ptyout len=%d)", pk.CmdId, ptyOutInfo.Size()) + } } cmdPty, cmdTty, err := pty.Open() if err != nil {