checkpoint for tightened runtime semantics for calls -- always send response packets, make sure correct response ids are set, etc.

This commit is contained in:
sawka 2022-07-05 23:14:14 -07:00
parent 0c204e8b2b
commit 96123c8e1a
6 changed files with 236 additions and 123 deletions

View File

@ -13,6 +13,7 @@ import (
"strings" "strings"
"github.com/scripthaus-dev/mshell/pkg/base" "github.com/scripthaus-dev/mshell/pkg/base"
"github.com/scripthaus-dev/mshell/pkg/cmdtail"
"github.com/scripthaus-dev/mshell/pkg/packet" "github.com/scripthaus-dev/mshell/pkg/packet"
"github.com/scripthaus-dev/mshell/pkg/server" "github.com/scripthaus-dev/mshell/pkg/server"
"github.com/scripthaus-dev/mshell/pkg/shexec" "github.com/scripthaus-dev/mshell/pkg/shexec"
@ -73,13 +74,13 @@ import (
// }() // }()
// } // }
// func doGetCmd(tailer *cmdtail.Tailer, pk *packet.GetCmdPacketType, sender *packet.PacketSender) error { func doGetCmd(tailer *cmdtail.Tailer, pk *packet.GetCmdPacketType, sender *packet.PacketSender) error {
// err := tailer.AddWatch(pk) err := tailer.AddWatch(pk)
// if err != nil { if err != nil {
// return err return err
// } }
// return nil return nil
// } }
// func doMain() { // func doMain() {
// homeDir := base.GetHomeDir() // homeDir := base.GetHomeDir()
@ -176,11 +177,16 @@ func handleSingle() {
return return
} }
if runPacket.Detached { if runPacket.Detached {
err := shexec.RunCommandDetached(runPacket, sender) cmd, startPk, err := shexec.RunCommandDetached(runPacket, sender)
if err != nil { if err != nil {
sender.SendErrorResponse(runPacket.ReqId, err) sender.SendErrorResponse(runPacket.ReqId, err)
return return
} }
sender.SendPacket(startPk)
sender.Close()
sender.WaitForDone()
cmd.DetachedWait(startPk)
return
} else { } else {
cmd, err := shexec.RunCommandSimple(runPacket, sender) cmd, err := shexec.RunCommandSimple(runPacket, sender)
if err != nil { if err != nil {
@ -188,7 +194,7 @@ func handleSingle() {
return return
} }
defer cmd.Close() defer cmd.Close()
startPacket := cmd.MakeCmdStartPacket() startPacket := cmd.MakeCmdStartPacket(runPacket.ReqId)
sender.SendPacket(startPacket) sender.SendPacket(startPacket)
cmd.RunRemoteIOAndWait(packetParser, sender) cmd.RunRemoteIOAndWait(packetParser, sender)
return return

View File

@ -7,6 +7,7 @@
package cmdtail package cmdtail
import ( import (
"encoding/base64"
"fmt" "fmt"
"io" "io"
"os" "os"
@ -89,6 +90,12 @@ func (t *Tailer) updateTailPos_nolock(cmdKey base.CommandKey, reqId string, pos
t.WatchList[cmdKey] = entry t.WatchList[cmdKey] = entry
} }
func (t *Tailer) removeTailPos(cmdKey base.CommandKey, reqId string) {
t.Lock.Lock()
defer t.Lock.Unlock()
t.removeTailPos_nolock(cmdKey, reqId)
}
func (t *Tailer) removeTailPos_nolock(cmdKey base.CommandKey, reqId string) { func (t *Tailer) removeTailPos_nolock(cmdKey base.CommandKey, reqId string) {
entry, found := t.WatchList[cmdKey] entry, found := t.WatchList[cmdKey]
if !found { if !found {
@ -107,16 +114,6 @@ func (t *Tailer) removeTailPos_nolock(cmdKey base.CommandKey, reqId string) {
t.Watcher.Remove(fileNames.RunnerOutFile) t.Watcher.Remove(fileNames.RunnerOutFile)
} }
func (t *Tailer) updateEntrySizes_nolock(cmdKey base.CommandKey, ptyLen int64, runLen int64) {
entry, found := t.WatchList[cmdKey]
if !found {
return
}
entry.FilePtyLen = ptyLen
entry.FileRunLen = runLen
t.WatchList[cmdKey] = entry
}
func (t *Tailer) getEntryAndPos_nolock(cmdKey base.CommandKey, reqId string) (CmdWatchEntry, TailPos, bool) { func (t *Tailer) getEntryAndPos_nolock(cmdKey base.CommandKey, reqId string) (CmdWatchEntry, TailPos, bool) {
entry, found := t.WatchList[cmdKey] entry, found := t.WatchList[cmdKey]
if !found { if !found {
@ -159,90 +156,98 @@ func (t *Tailer) readDataFromFile(fileName string, pos int64, maxBytes int) ([]b
return buf[0:nr], nil return buf[0:nr], nil
} }
func (t *Tailer) makeCmdDataPacket(fileNames *base.CommandFileNames, entry CmdWatchEntry, pos TailPos) *packet.CmdDataPacketType { func (t *Tailer) makeCmdDataPacket(fileNames *base.CommandFileNames, entry CmdWatchEntry, pos TailPos) (*packet.CmdDataPacketType, error) {
dataPacket := packet.MakeCmdDataPacket() dataPacket := packet.MakeCmdDataPacket(pos.ReqId)
dataPacket.RespId = pos.ReqId
dataPacket.CK = entry.CmdKey dataPacket.CK = entry.CmdKey
dataPacket.PtyPos = pos.TailPtyPos dataPacket.PtyPos = pos.TailPtyPos
dataPacket.RunPos = pos.TailRunPos dataPacket.RunPos = pos.TailRunPos
if entry.FilePtyLen > pos.TailPtyPos { if entry.FilePtyLen > pos.TailPtyPos {
ptyData, err := t.readDataFromFile(fileNames.PtyOutFile, pos.TailPtyPos, MaxDataBytes) ptyData, err := t.readDataFromFile(fileNames.PtyOutFile, pos.TailPtyPos, MaxDataBytes)
if err != nil { if err != nil {
dataPacket.Error = err.Error() return nil, err
return dataPacket
} }
dataPacket.PtyData = string(ptyData) dataPacket.PtyData64 = base64.StdEncoding.EncodeToString(ptyData)
dataPacket.PtyDataLen = len(ptyData) dataPacket.PtyDataLen = len(ptyData)
} }
if entry.FileRunLen > pos.TailRunPos { if entry.FileRunLen > pos.TailRunPos {
runData, err := t.readDataFromFile(fileNames.RunnerOutFile, pos.TailRunPos, MaxDataBytes) runData, err := t.readDataFromFile(fileNames.RunnerOutFile, pos.TailRunPos, MaxDataBytes)
if err != nil { if err != nil {
dataPacket.Error = err.Error() return nil, err
return dataPacket
} }
dataPacket.RunData = string(runData) dataPacket.RunData64 = base64.StdEncoding.EncodeToString(runData)
dataPacket.RunDataLen = len(runData) dataPacket.RunDataLen = len(runData)
} }
return dataPacket return dataPacket, nil
} }
// returns (data-packet, keepRunning) // returns (data-packet, keepRunning)
func (t *Tailer) runSingleDataTransfer(key base.CommandKey, reqId string) (*packet.CmdDataPacketType, bool) { func (t *Tailer) runSingleDataTransfer(key base.CommandKey, reqId string) (*packet.CmdDataPacketType, bool, error) {
t.Lock.Lock() t.Lock.Lock()
entry, pos, foundPos := t.getEntryAndPos_nolock(key, reqId) entry, pos, foundPos := t.getEntryAndPos_nolock(key, reqId)
t.Lock.Unlock() t.Lock.Unlock()
if !foundPos { if !foundPos {
return nil, false return nil, false, nil
} }
fileNames := base.MakeCommandFileNamesWithHome(t.MHomeDir, key) fileNames := base.MakeCommandFileNamesWithHome(t.MHomeDir, key)
dataPacket := t.makeCmdDataPacket(fileNames, entry, pos) dataPacket, dataErr := t.makeCmdDataPacket(fileNames, entry, pos)
t.Lock.Lock() t.Lock.Lock()
defer t.Lock.Unlock() defer t.Lock.Unlock()
entry, pos, foundPos = t.getEntryAndPos_nolock(key, reqId) entry, pos, foundPos = t.getEntryAndPos_nolock(key, reqId)
if !foundPos { if !foundPos {
return nil, false return nil, false, nil
} }
// pos was updated between first and second get, throw out data-packet and re-run // pos was updated between first and second get, throw out data-packet and re-run
if pos.TailPtyPos != dataPacket.PtyPos || pos.TailRunPos != dataPacket.RunPos { if pos.TailPtyPos != dataPacket.PtyPos || pos.TailRunPos != dataPacket.RunPos {
return nil, true return nil, true, nil
} }
if dataPacket.Error != "" { if dataErr != nil {
// error, so return error packet, and stop running // error, so return error packet, and stop running
pos.Running = false pos.Running = false
t.updateTailPos_nolock(key, reqId, pos) t.updateTailPos_nolock(key, reqId, pos)
return dataPacket, false return nil, false, dataErr
} }
pos.TailPtyPos += int64(len(dataPacket.PtyData)) pos.TailPtyPos += int64(dataPacket.PtyDataLen)
pos.TailRunPos += int64(len(dataPacket.RunData)) pos.TailRunPos += int64(dataPacket.RunDataLen)
if pos.TailPtyPos >= entry.FilePtyLen && pos.TailRunPos >= entry.FileRunLen { if pos.TailPtyPos >= entry.FilePtyLen && pos.TailRunPos >= entry.FileRunLen {
// we caught up, tail position equals file length // we caught up, tail position equals file length
pos.Running = false pos.Running = false
} }
t.updateTailPos_nolock(key, reqId, pos) t.updateTailPos_nolock(key, reqId, pos)
return dataPacket, pos.Running return dataPacket, pos.Running, nil
} }
func (t *Tailer) checkRemoveNoFollow(cmdKey base.CommandKey, reqId string) { // returns (removed)
func (t *Tailer) checkRemoveNoFollow(cmdKey base.CommandKey, reqId string) bool {
t.Lock.Lock() t.Lock.Lock()
defer t.Lock.Unlock() defer t.Lock.Unlock()
_, pos, foundPos := t.getEntryAndPos_nolock(cmdKey, reqId) _, pos, foundPos := t.getEntryAndPos_nolock(cmdKey, reqId)
if !foundPos { if !foundPos {
return return false
} }
if !pos.Follow { if !pos.Follow {
t.removeTailPos_nolock(cmdKey, reqId) t.removeTailPos_nolock(cmdKey, reqId)
return true
} }
return false
} }
func (t *Tailer) RunDataTransfer(key base.CommandKey, reqId string) { func (t *Tailer) RunDataTransfer(key base.CommandKey, reqId string) {
for { for {
dataPacket, keepRunning := t.runSingleDataTransfer(key, reqId) dataPacket, keepRunning, err := t.runSingleDataTransfer(key, reqId)
if dataPacket != nil { if dataPacket != nil {
t.Sender.SendPacket(dataPacket) t.Sender.SendPacket(dataPacket)
} }
if err != nil {
t.removeTailPos(key, reqId)
t.Sender.SendErrorResponse(reqId, err)
break
}
if !keepRunning { if !keepRunning {
t.checkRemoveNoFollow(key, reqId) removed := t.checkRemoveNoFollow(key, reqId)
if removed {
t.Sender.SendResponse(reqId, true)
}
break break
} }
time.Sleep(10 * time.Millisecond) time.Sleep(10 * time.Millisecond)
@ -254,7 +259,6 @@ func (t *Tailer) tryStartRun_nolock(entry CmdWatchEntry, pos TailPos) {
return return
} }
if pos.IsCurrent(entry) { if pos.IsCurrent(entry) {
return return
} }
pos.Running = true pos.Running = true
@ -344,6 +348,19 @@ func (t *Tailer) RemoveWatch(pk *packet.UntailCmdPacketType) {
t.removeTailPos_nolock(pk.CK, pk.ReqId) t.removeTailPos_nolock(pk.CK, pk.ReqId)
} }
func (t *Tailer) AddFileWatches_nolock(fileNames *base.CommandFileNames) error {
err := t.Watcher.Add(fileNames.PtyOutFile)
if err != nil {
return err
}
err = t.Watcher.Add(fileNames.RunnerOutFile)
if err != nil {
t.Watcher.Remove(fileNames.PtyOutFile) // best effort clean up
return err
}
return nil
}
func (t *Tailer) AddWatch(getPacket *packet.GetCmdPacketType) error { func (t *Tailer) AddWatch(getPacket *packet.GetCmdPacketType) error {
if err := getPacket.CK.Validate("getcmd"); err != nil { if err := getPacket.CK.Validate("getcmd"); err != nil {
return err return err
@ -357,16 +374,7 @@ func (t *Tailer) AddWatch(getPacket *packet.GetCmdPacketType) error {
key := getPacket.CK key := getPacket.CK
entry, foundEntry := t.WatchList[key] entry, foundEntry := t.WatchList[key]
if !foundEntry { if !foundEntry {
// add watches, initialize entry // initialize entry, add watches
err := t.Watcher.Add(fileNames.PtyOutFile)
if err != nil {
return err
}
err = t.Watcher.Add(fileNames.RunnerOutFile)
if err != nil {
t.Watcher.Remove(fileNames.PtyOutFile) // best effort clean up
return err
}
entry = CmdWatchEntry{CmdKey: key} entry = CmdWatchEntry{CmdKey: key}
entry.fillFilePos(t.MHomeDir) entry.fillFilePos(t.MHomeDir)
} }
@ -387,6 +395,14 @@ func (t *Tailer) AddWatch(getPacket *packet.GetCmdPacketType) error {
pos.TailRunPos = max(0, entry.FileRunLen+pos.TailRunPos) // + because negative pos.TailRunPos = max(0, entry.FileRunLen+pos.TailRunPos) // + because negative
} }
entry.updateTailPos(pos.ReqId, pos) entry.updateTailPos(pos.ReqId, pos)
if !pos.Follow && pos.IsCurrent(entry) {
// don't add to t.WatchList, don't t.AddFileWatches_nolock, send rpc response
go func() { t.Sender.SendResponse(getPacket.ReqId, true) }()
return nil
}
if !foundEntry {
t.AddFileWatches_nolock(fileNames)
}
t.WatchList[key] = entry t.WatchList[key] = entry
t.tryStartRun_nolock(entry, pos) t.tryStartRun_nolock(entry, pos)
return nil return nil

View File

@ -109,12 +109,10 @@ type CmdDataPacketType struct {
PtyLen int64 `json:"ptylen"` PtyLen int64 `json:"ptylen"`
RunPos int64 `json:"runpos"` RunPos int64 `json:"runpos"`
RunLen int64 `json:"runlen"` RunLen int64 `json:"runlen"`
PtyData string `json:"ptydata"` PtyData64 string `json:"ptydata64"`
PtyDataLen int `json:"ptydatalen"` PtyDataLen int `json:"ptydatalen"`
RunData string `json:"rundata"` RunData64 string `json:"rundata64"`
RunDataLen int `json:"rundatalen"` RunDataLen int `json:"rundatalen"`
Error string `json:"error"`
NotFound bool `json:"notfound,omitempty"`
} }
func (*CmdDataPacketType) GetType() string { func (*CmdDataPacketType) GetType() string {
@ -125,8 +123,12 @@ func (p *CmdDataPacketType) GetResponseId() string {
return p.RespId return p.RespId
} }
func MakeCmdDataPacket() *CmdDataPacketType { func (*CmdDataPacketType) GetResponseDone() bool {
return &CmdDataPacketType{Type: CmdDataPacketStr} return false
}
func MakeCmdDataPacket(reqId string) *CmdDataPacketType {
return &CmdDataPacketType{Type: CmdDataPacketStr, RespId: reqId}
} }
type PingPacketType struct { type PingPacketType struct {
@ -326,6 +328,10 @@ func (p *ResponsePacketType) GetResponseId() string {
return p.RespId return p.RespId
} }
func (*ResponsePacketType) GetResponseDone() bool {
return true
}
func MakeErrorResponsePacket(reqId string, err error) *ResponsePacketType { func MakeErrorResponsePacket(reqId string, err error) *ResponsePacketType {
return &ResponsePacketType{Type: ResponsePacketStr, RespId: reqId, Error: err.Error()} return &ResponsePacketType{Type: ResponsePacketStr, RespId: reqId, Error: err.Error()}
} }
@ -421,8 +427,8 @@ func (p *CmdDonePacketType) GetCK() base.CommandKey {
return p.CK return p.CK
} }
func MakeCmdDonePacket() *CmdDonePacketType { func MakeCmdDonePacket(ck base.CommandKey) *CmdDonePacketType {
return &CmdDonePacketType{Type: CmdDonePacketStr} return &CmdDonePacketType{Type: CmdDonePacketStr, CK: ck}
} }
type CmdStartPacketType struct { type CmdStartPacketType struct {
@ -442,8 +448,12 @@ func (p *CmdStartPacketType) GetResponseId() string {
return p.RespId return p.RespId
} }
func MakeCmdStartPacket() *CmdStartPacketType { func (*CmdStartPacketType) GetResponseDone() bool {
return &CmdStartPacketType{Type: CmdStartPacketStr} return true
}
func MakeCmdStartPacket(reqId string) *CmdStartPacketType {
return &CmdStartPacketType{Type: CmdStartPacketStr, RespId: reqId}
} }
type TermSize struct { type TermSize struct {
@ -534,6 +544,7 @@ type RpcPacketType interface {
type RpcResponsePacketType interface { type RpcResponsePacketType interface {
GetType() string GetType() string
GetResponseId() string GetResponseId() string
GetResponseDone() bool
} }
type CommandPacketType interface { type CommandPacketType interface {

View File

@ -8,6 +8,7 @@ package packet
import ( import (
"bufio" "bufio"
"context"
"io" "io"
"strconv" "strconv"
"strings" "strings"
@ -17,9 +18,15 @@ import (
type PacketParser struct { type PacketParser struct {
Lock *sync.Mutex Lock *sync.Mutex
MainCh chan PacketType MainCh chan PacketType
RpcMap map[string]*RpcEntry
Err error Err error
} }
type RpcEntry struct {
ReqId string
RespCh chan RpcResponsePacketType
}
func CombinePacketParsers(p1 *PacketParser, p2 *PacketParser) *PacketParser { func CombinePacketParsers(p1 *PacketParser, p2 *PacketParser) *PacketParser {
rtnParser := &PacketParser{ rtnParser := &PacketParser{
Lock: &sync.Mutex{}, Lock: &sync.Mutex{},
@ -46,6 +53,70 @@ func CombinePacketParsers(p1 *PacketParser, p2 *PacketParser) *PacketParser {
return rtnParser return rtnParser
} }
// should have already registered rpc
func (p *PacketParser) WaitForResponse(ctx context.Context, reqId string) RpcResponsePacketType {
entry := p.getRpcEntry(reqId, false)
if entry == nil {
return nil
}
defer p.UnRegisterRpc(reqId)
select {
case resp := <-entry.RespCh:
return resp
case <-ctx.Done():
return nil
}
}
func (p *PacketParser) UnRegisterRpc(reqId string) {
p.Lock.Lock()
defer p.Lock.Unlock()
entry := p.RpcMap[reqId]
if entry != nil {
close(entry.RespCh)
delete(p.RpcMap, reqId)
}
}
func (p *PacketParser) RegisterRpc(reqId string, queueSize int) chan RpcResponsePacketType {
p.Lock.Lock()
defer p.Lock.Unlock()
ch := make(chan RpcResponsePacketType, queueSize)
entry := &RpcEntry{ReqId: reqId, RespCh: ch}
p.RpcMap[reqId] = entry
return ch
}
func (p *PacketParser) getRpcEntry(reqId string, remove bool) *RpcEntry {
p.Lock.Lock()
defer p.Lock.Unlock()
entry := p.RpcMap[reqId]
if entry != nil && remove {
delete(p.RpcMap, reqId)
close(entry.RespCh)
}
return entry
}
func (p *PacketParser) trySendRpcResponse(respPk RpcResponsePacketType) bool {
p.Lock.Lock()
defer p.Lock.Unlock()
entry := p.RpcMap[respPk.GetResponseId()]
if entry == nil {
return false
}
// nonblocking send
select {
case entry.RespCh <- respPk:
default:
}
if respPk.GetResponseDone() {
delete(p.RpcMap, respPk.GetResponseId())
close(entry.RespCh)
}
return true
}
func (p *PacketParser) GetErr() error { func (p *PacketParser) GetErr() error {
p.Lock.Lock() p.Lock.Lock()
defer p.Lock.Unlock() defer p.Lock.Unlock()
@ -108,6 +179,12 @@ func MakePacketParser(input io.Reader) *PacketParser {
if pk.GetType() == PingPacketStr { if pk.GetType() == PingPacketStr {
continue continue
} }
if respPk, ok := pk.(RpcResponsePacketType); ok {
sent := parser.trySendRpcResponse(respPk)
if sent {
continue
}
}
parser.MainCh <- pk parser.MainCh <- pk
} }
}() }()

View File

@ -161,6 +161,8 @@ func RunServer() (int, error) {
if server.Debug { if server.Debug {
fmt.Printf("PK> %s\n", packet.AsString(pk)) fmt.Printf("PK> %s\n", packet.AsString(pk))
} }
// run-start combo
ok, runPacket := builder.ProcessPacket(pk) ok, runPacket := builder.ProcessPacket(pk)
if server.Debug { if server.Debug {
fmt.Printf("PP> %s | %v\n", pk.GetType(), ok) fmt.Printf("PP> %s | %v\n", pk.GetType(), ok)
@ -179,6 +181,8 @@ func RunServer() (int, error) {
server.Sender.SendPacket(startPk) server.Sender.SendPacket(startPk)
continue continue
} }
// command packet
if cmdPk, ok := pk.(packet.CommandPacketType); ok { if cmdPk, ok := pk.(packet.CommandPacketType); ok {
server.ProcessCommandPacket(cmdPk) server.ProcessCommandPacket(cmdPk)
continue continue

View File

@ -129,8 +129,8 @@ func (c *ShExecType) Close() {
} }
} }
func (c *ShExecType) MakeCmdStartPacket() *packet.CmdStartPacketType { func (c *ShExecType) MakeCmdStartPacket(reqId string) *packet.CmdStartPacketType {
startPacket := packet.MakeCmdStartPacket() startPacket := packet.MakeCmdStartPacket(reqId)
startPacket.Ts = time.Now().UnixMilli() startPacket.Ts = time.Now().UnixMilli()
startPacket.CK = c.CK startPacket.CK = c.CK
startPacket.Pid = c.Cmd.Process.Pid startPacket.Pid = c.Cmd.Process.Pid
@ -848,21 +848,67 @@ func SetupSignalsForDetach() {
}() }()
} }
func RunCommandDetached(pk *packet.RunPacketType, sender *packet.PacketSender) error { func (cmd *ShExecType) DetachedWait(startPacket *packet.CmdStartPacketType) {
// after Start(), any output/errors must go to DetachedOutput
// close stdin/stdout/stderr, but wait for cmdstart packet to get sent
nullFd, err := os.OpenFile("/dev/null", os.O_RDWR, 0)
if err != nil {
cmd.DetachedOutput.SendCmdError(cmd.CK, fmt.Errorf("cannot open /dev/null: %w", err))
}
if nullFd != nil {
err := unix.Dup2(int(nullFd.Fd()), int(os.Stdin.Fd()))
if err != nil {
cmd.DetachedOutput.SendCmdError(cmd.CK, fmt.Errorf("cannot dup2 stdin to /dev/null: %w", err))
}
err = unix.Dup2(int(nullFd.Fd()), int(os.Stdout.Fd()))
if err != nil {
cmd.DetachedOutput.SendCmdError(cmd.CK, fmt.Errorf("cannot dup2 stdin to /dev/null: %w", err))
}
err = unix.Dup2(int(nullFd.Fd()), int(os.Stderr.Fd()))
if err != nil {
cmd.DetachedOutput.SendCmdError(cmd.CK, fmt.Errorf("cannot dup2 stdin to /dev/null: %w", err))
}
}
cmd.DetachedOutput.SendPacket(startPacket)
ptyOutFd, err := os.OpenFile(cmd.FileNames.PtyOutFile, os.O_TRUNC|os.O_CREATE|os.O_WRONLY|os.O_APPEND, 0600)
if err != nil {
cmd.DetachedOutput.SendCmdError(cmd.CK, fmt.Errorf("cannot open ptyout file '%s': %w", cmd.FileNames.PtyOutFile, err))
// don't return (command is already running)
}
go func() {
// copy pty output to .ptyout file
_, copyErr := io.Copy(ptyOutFd, cmd.CmdPty)
if copyErr != nil {
cmd.DetachedOutput.SendCmdError(cmd.CK, fmt.Errorf("copying pty output to ptyout file: %w", copyErr))
}
}()
go func() {
// copy .stdin fifo contents to pty input
copyFifoErr := MakeAndCopyStdinFifo(cmd.CmdPty, cmd.FileNames.StdinFifo)
if copyFifoErr != nil {
cmd.DetachedOutput.SendCmdError(cmd.CK, fmt.Errorf("reading from stdin fifo: %w", copyFifoErr))
}
}()
donePacket := cmd.WaitForCommand()
cmd.DetachedOutput.SendPacket(donePacket)
return
}
func RunCommandDetached(pk *packet.RunPacketType, sender *packet.PacketSender) (*ShExecType, *packet.CmdStartPacketType, error) {
fileNames, err := base.GetCommandFileNames(pk.CK) fileNames, err := base.GetCommandFileNames(pk.CK)
if err != nil { if err != nil {
return err return nil, nil, err
} }
ptyOutInfo, err := os.Stat(fileNames.PtyOutFile) ptyOutInfo, err := os.Stat(fileNames.PtyOutFile)
if err == nil { // non-nil error will be caught by regular OpenFile below if err == nil { // non-nil error will be caught by regular OpenFile below
// must have size 0 // must have size 0
if ptyOutInfo.Size() != 0 { if ptyOutInfo.Size() != 0 {
return fmt.Errorf("cmdkey '%s' was already used (ptyout len=%d)", pk.CK, ptyOutInfo.Size()) return nil, nil, fmt.Errorf("cmdkey '%s' was already used (ptyout len=%d)", pk.CK, ptyOutInfo.Size())
} }
} }
cmdPty, cmdTty, err := pty.Open() cmdPty, cmdTty, err := pty.Open()
if err != nil { if err != nil {
return fmt.Errorf("opening new pty: %w", err) return nil, nil, fmt.Errorf("opening new pty: %w", err)
} }
pty.Setsize(cmdPty, GetWinsize(pk)) pty.Setsize(cmdPty, GetWinsize(pk))
defer func() { defer func() {
@ -874,72 +920,26 @@ func RunCommandDetached(pk *packet.RunPacketType, sender *packet.PacketSender) e
cmd.Detached = true cmd.Detached = true
cmd.RunnerOutFd, err = os.OpenFile(fileNames.RunnerOutFile, os.O_TRUNC|os.O_CREATE|os.O_WRONLY|os.O_APPEND, 0600) cmd.RunnerOutFd, err = os.OpenFile(fileNames.RunnerOutFile, os.O_TRUNC|os.O_CREATE|os.O_WRONLY|os.O_APPEND, 0600)
if err != nil { if err != nil {
return fmt.Errorf("cannot open runout file '%s': %w", fileNames.RunnerOutFile, err) return nil, nil, fmt.Errorf("cannot open runout file '%s': %w", fileNames.RunnerOutFile, err)
}
nullFd, err := os.OpenFile("/dev/null", os.O_RDWR, 0)
if err != nil {
return fmt.Errorf("cannot open /dev/null: %w", err)
} }
cmd.DetachedOutput = packet.MakePacketSender(cmd.RunnerOutFd) cmd.DetachedOutput = packet.MakePacketSender(cmd.RunnerOutFd)
ecmd, err := MakeDetachedExecCmd(pk, cmdTty) ecmd, err := MakeDetachedExecCmd(pk, cmdTty)
if err != nil { if err != nil {
return err return nil, nil, err
} }
cmd.Cmd = ecmd cmd.Cmd = ecmd
SetupSignalsForDetach() SetupSignalsForDetach()
err = ecmd.Start() err = ecmd.Start()
if err != nil { if err != nil {
return fmt.Errorf("starting command: %w", err) return nil, nil, fmt.Errorf("starting command: %w", err)
} }
for _, fd := range ecmd.ExtraFiles { for _, fd := range ecmd.ExtraFiles {
if fd != cmdTty { if fd != cmdTty {
fd.Close() fd.Close()
} }
} }
// after Start(), any errors must go to DetachedOutput startPacket := cmd.MakeCmdStartPacket(pk.ReqId)
// close stdin/stdout/stderr, but wait for cmdstart packet to get sent return cmd, startPacket, nil
startPacket := cmd.MakeCmdStartPacket()
go func() {
sender.SendPacket(startPacket)
sender.Close()
sender.WaitForDone()
fmt.Printf("sender done! start: %v\n", startPacket)
err = unix.Dup2(int(nullFd.Fd()), int(os.Stdin.Fd()))
if err != nil {
cmd.DetachedOutput.SendCmdError(pk.CK, fmt.Errorf("cannot dup2 stdin to /dev/null: %w", err))
}
err = unix.Dup2(int(nullFd.Fd()), int(os.Stdout.Fd()))
if err != nil {
cmd.DetachedOutput.SendCmdError(pk.CK, fmt.Errorf("cannot dup2 stdin to /dev/null: %w", err))
}
err = unix.Dup2(int(nullFd.Fd()), int(os.Stderr.Fd()))
if err != nil {
cmd.DetachedOutput.SendCmdError(pk.CK, fmt.Errorf("cannot dup2 stdin to /dev/null: %w", err))
}
cmd.DetachedOutput.SendPacket(startPacket)
}()
ptyOutFd, err := os.OpenFile(fileNames.PtyOutFile, os.O_TRUNC|os.O_CREATE|os.O_WRONLY|os.O_APPEND, 0600)
if err != nil {
cmd.DetachedOutput.SendCmdError(pk.CK, fmt.Errorf("cannot open ptyout file '%s': %w", fileNames.PtyOutFile, err))
// don't return (command is already running)
}
go func() {
// copy pty output to .ptyout file
_, copyErr := io.Copy(ptyOutFd, cmdPty)
if copyErr != nil {
cmd.DetachedOutput.SendCmdError(pk.CK, fmt.Errorf("copying pty output to ptyout file: %w", copyErr))
}
}()
go func() {
// copy .stdin fifo contents to pty input
copyFifoErr := MakeAndCopyStdinFifo(cmdPty, fileNames.StdinFifo)
if copyFifoErr != nil {
cmd.DetachedOutput.SendCmdError(pk.CK, fmt.Errorf("reading from stdin fifo: %w", copyFifoErr))
}
}()
donePacket := cmd.WaitForCommand()
cmd.DetachedOutput.SendPacket(donePacket)
return nil
} }
func GetExitCode(err error) int { func GetExitCode(err error) int {
@ -958,9 +958,8 @@ func (c *ShExecType) WaitForCommand() *packet.CmdDonePacketType {
endTs := time.Now() endTs := time.Now()
cmdDuration := endTs.Sub(c.StartTs) cmdDuration := endTs.Sub(c.StartTs)
exitCode := GetExitCode(exitErr) exitCode := GetExitCode(exitErr)
donePacket := packet.MakeCmdDonePacket() donePacket := packet.MakeCmdDonePacket(c.CK)
donePacket.Ts = endTs.UnixMilli() donePacket.Ts = endTs.UnixMilli()
donePacket.CK = c.CK
donePacket.ExitCode = exitCode donePacket.ExitCode = exitCode
donePacket.DurationMs = int64(cmdDuration / time.Millisecond) donePacket.DurationMs = int64(cmdDuration / time.Millisecond)
if c.FileNames != nil { if c.FileNames != nil {