mshell single writes ping packets to detect when the server has died. sends SIGHUP to children

This commit is contained in:
sawka 2022-12-05 22:26:13 -08:00
parent 39e5e6c729
commit f010758b36
5 changed files with 85 additions and 12 deletions

View File

@ -12,6 +12,7 @@ import (
"os"
"strconv"
"strings"
"time"
"github.com/scripthaus-dev/mshell/pkg/base"
"github.com/scripthaus-dev/mshell/pkg/packet"
@ -194,6 +195,16 @@ func handleSingle(fromServer bool) {
cmd.DetachedWait(startPk)
return
} else {
shexec.IgnoreSigPipe()
ticker := time.NewTicker(1 * time.Minute)
go func() {
for range ticker.C {
// this will let the command detect when the server has gone away
// that will then trigger cmd.SendHup() to send SIGHUP to the exec'ed process
sender.SendPacket(packet.MakePingPacket())
}
}()
defer ticker.Stop()
cmd, err := shexec.RunCommandSimple(runPacket, sender, true)
if err != nil {
sender.SendErrorResponse(runPacket.ReqId, fmt.Errorf("error running command: %w", err))
@ -202,6 +213,13 @@ func handleSingle(fromServer bool) {
defer cmd.Close()
startPacket := cmd.MakeCmdStartPacket(runPacket.ReqId)
sender.SendPacket(startPacket)
go func() {
exitErr := sender.WaitForDone()
if exitErr != nil {
base.Logf("I/O error talking to server, sending SIGHUP to children\n")
cmd.SendHup()
}
}()
cmd.RunRemoteIOAndWait(packetParser, sender)
return
}

View File

@ -80,6 +80,7 @@ func InitDebugLog(prefix string) {
return
}
DebugLogger = log.New(fd, prefix+" ", log.LstdFlags)
Logf("logger initialized\n")
}
func SetEnableDebugLog(enable bool) {

View File

@ -722,7 +722,7 @@ func (e *SendError) Error() string {
if e.IsMarshalError {
return fmt.Sprintf("SendPacket marshal-error '%s' packet: %v", e.PacketType, e.Err)
} else if e.IsWriteError {
return fmt.Sprintf("SendPacket write-error: %v", e.Err)
return fmt.Sprintf("SendPacket write-error packet[%s]: %v", e.PacketType, e.Err)
} else {
return e.Err.Error()
}
@ -742,7 +742,7 @@ func SendPacket(w io.Writer, packet PacketType) error {
outBuf.Write(jsonBytes)
outBuf.WriteByte('\n')
if GlobalDebug {
fmt.Printf("SEND> %s\n", AsString(packet))
base.Logf("SEND> %s\n", AsString(packet))
}
outBytes := outBuf.Bytes()
sanitizeBytes(outBytes)
@ -763,6 +763,7 @@ type PacketSender struct {
Done bool
DoneCh chan bool
ErrHandler func(*PacketSender, PacketType, error)
ExitErr error
}
func MakePacketSender(output io.Writer, errHandler func(*PacketSender, PacketType, error)) *PacketSender {
@ -784,6 +785,9 @@ func MakePacketSender(output io.Writer, errHandler func(*PacketSender, PacketTyp
continue
}
// write errors are not recoverable
sender.Lock.Lock()
sender.ExitErr = err
sender.Lock.Unlock()
return
}
}
@ -825,10 +829,18 @@ func (sender *PacketSender) Close() {
close(sender.SendCh)
}
func (sender *PacketSender) WaitForDone() {
// returns ExitErr if set
func (sender *PacketSender) WaitForDone() error {
<-sender.DoneCh
sender.Lock.Lock()
defer sender.Lock.Unlock()
return sender.ExitErr
}
// this is "advisory", as there is a race condition between the loop closing and setting Done.
// that's okay because that's an impossible race condition anyway (you could enqueue the packet
// and then the connection dies, or it dies half way, etc.). this just stops blindly adding
// packets forever when the loop is done.
func (sender *PacketSender) checkStatus() error {
sender.Lock.Lock()
defer sender.Lock.Unlock()

View File

@ -9,7 +9,6 @@ import (
"strings"
"github.com/alessio/shellescape"
"github.com/scripthaus-dev/mshell/pkg/base"
"github.com/scripthaus-dev/mshell/pkg/packet"
"github.com/scripthaus-dev/mshell/pkg/simpleexpand"
"github.com/scripthaus-dev/mshell/pkg/statediff"
@ -453,10 +452,6 @@ func ParseShellStateOutput(outputBytes []byte) (*packet.ShellState, error) {
rtn.Aliases = strings.ReplaceAll(string(fields[3]), "\r\n", "\n")
rtn.Funcs = strings.ReplaceAll(string(fields[4]), "\r\n", "\n")
rtn.Funcs = removeFunc(rtn.Funcs, "_scripthaus_exittrap")
lines := strings.Split(rtn.Funcs, "\n")
for _, line := range lines {
base.Logf("func-line: [%s]\n", line)
}
return rtn, nil
}

View File

@ -102,6 +102,7 @@ func MakeReturnStateBuf() *ReturnStateBuf {
}
type ShExecType struct {
Lock *sync.Mutex // only locks "Exited" field
StartTs time.Time
CK base.CommandKey
FileNames *base.CommandFileNames
@ -114,6 +115,7 @@ type ShExecType struct {
RunnerOutFd *os.File
MsgSender *packet.PacketSender // where to send out-of-band messages back to calling proceess
ReturnState *ReturnStateBuf
Exited bool // locked via Lock
}
type StdContext struct{}
@ -195,6 +197,7 @@ func (s ShExecUPR) UnknownPacket(pk packet.PacketType) {
func MakeShExec(ck base.CommandKey, upr packet.UnknownPacketReporter) *ShExecType {
return &ShExecType{
Lock: &sync.Mutex{},
StartTs: time.Now(),
CK: ck,
Multiplexer: mpio.MakeMultiplexer(ck, upr),
@ -1000,6 +1003,25 @@ trap _scripthaus_exittrap EXIT
return fmt.Sprintf(fmtStr, stateCmd)
}
func (s *ShExecType) SendHup() {
base.Logf("sendhup start\n")
if s.Cmd == nil || s.Cmd.Process == nil || s.IsExited() {
return
}
pgroup := false
if s.Cmd.SysProcAttr != nil && (s.Cmd.SysProcAttr.Setsid || s.Cmd.SysProcAttr.Setpgid) {
pgroup = true
}
pid := s.Cmd.Process.Pid
if pgroup {
base.Logf("sendhup %d (pgroup)\n", -pid)
syscall.Kill(-pid, syscall.SIGHUP)
} else {
base.Logf("sendhup %d (normal)\n", pid)
syscall.Kill(pid, syscall.SIGHUP)
}
}
func RunCommandSimple(pk *packet.RunPacketType, sender *packet.PacketSender, fromServer bool) (rtnShExec *ShExecType, rtnErr error) {
state := pk.State
if state == nil {
@ -1172,7 +1194,7 @@ func (rs *ReturnStateBuf) Run() {
// since we want mshell to persist even if the mshell --server is terminated
func SetupSignalsForDetach() {
sigCh := make(chan os.Signal, 1)
signal.Notify(sigCh, syscall.SIGINT, syscall.SIGTERM, syscall.SIGHUP)
signal.Notify(sigCh, syscall.SIGINT, syscall.SIGTERM, syscall.SIGHUP, syscall.SIGPIPE)
go func() {
for range sigCh {
// do nothing
@ -1180,6 +1202,18 @@ func SetupSignalsForDetach() {
}()
}
// in detached run mode, we don't want mshell to die from signals
// since we want mshell to persist even if the mshell --server is terminated
func IgnoreSigPipe() {
sigCh := make(chan os.Signal, 1)
signal.Notify(sigCh, syscall.SIGPIPE)
go func() {
for sig := range sigCh {
base.Logf("ignoring signal %v\n", sig)
}
}()
}
func copyToCirFile(dest *cirfile.File, src io.Reader) error {
buf := make([]byte, 64*1024)
for {
@ -1308,9 +1342,23 @@ func GetExitCode(err error) int {
}
}
func (c *ShExecType) ProcWait() error {
exitErr := c.Cmd.Wait()
c.Lock.Lock()
c.Exited = true
c.Lock.Unlock()
return exitErr
}
func (c *ShExecType) IsExited() bool {
c.Lock.Lock()
defer c.Lock.Unlock()
return c.Exited
}
func (c *ShExecType) WaitForCommand() *packet.CmdDonePacketType {
donePacket := packet.MakeCmdDonePacket(c.CK)
exitErr := c.Cmd.Wait()
exitErr := c.ProcWait()
if c.ReturnState != nil {
<-c.ReturnState.DoneCh
state, _ := ParseShellStateOutput(c.ReturnState.Buf) // TODO what to do with error?
@ -1318,9 +1366,8 @@ func (c *ShExecType) WaitForCommand() *packet.CmdDonePacketType {
}
endTs := time.Now()
cmdDuration := endTs.Sub(c.StartTs)
exitCode := GetExitCode(exitErr)
donePacket.Ts = endTs.UnixMilli()
donePacket.ExitCode = exitCode
donePacket.ExitCode = GetExitCode(exitErr)
donePacket.DurationMs = int64(cmdDuration / time.Millisecond)
if c.FileNames != nil {
os.Remove(c.FileNames.StdinFifo) // best effort (no need to check error)