mirror of
https://github.com/wavetermdev/waveterm.git
synced 2025-02-08 00:21:23 +01:00
mshell single writes ping packets to detect when the server has died. sends SIGHUP to children
This commit is contained in:
parent
39e5e6c729
commit
f010758b36
@ -12,6 +12,7 @@ import (
|
||||
"os"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/scripthaus-dev/mshell/pkg/base"
|
||||
"github.com/scripthaus-dev/mshell/pkg/packet"
|
||||
@ -194,6 +195,16 @@ func handleSingle(fromServer bool) {
|
||||
cmd.DetachedWait(startPk)
|
||||
return
|
||||
} else {
|
||||
shexec.IgnoreSigPipe()
|
||||
ticker := time.NewTicker(1 * time.Minute)
|
||||
go func() {
|
||||
for range ticker.C {
|
||||
// this will let the command detect when the server has gone away
|
||||
// that will then trigger cmd.SendHup() to send SIGHUP to the exec'ed process
|
||||
sender.SendPacket(packet.MakePingPacket())
|
||||
}
|
||||
}()
|
||||
defer ticker.Stop()
|
||||
cmd, err := shexec.RunCommandSimple(runPacket, sender, true)
|
||||
if err != nil {
|
||||
sender.SendErrorResponse(runPacket.ReqId, fmt.Errorf("error running command: %w", err))
|
||||
@ -202,6 +213,13 @@ func handleSingle(fromServer bool) {
|
||||
defer cmd.Close()
|
||||
startPacket := cmd.MakeCmdStartPacket(runPacket.ReqId)
|
||||
sender.SendPacket(startPacket)
|
||||
go func() {
|
||||
exitErr := sender.WaitForDone()
|
||||
if exitErr != nil {
|
||||
base.Logf("I/O error talking to server, sending SIGHUP to children\n")
|
||||
cmd.SendHup()
|
||||
}
|
||||
}()
|
||||
cmd.RunRemoteIOAndWait(packetParser, sender)
|
||||
return
|
||||
}
|
||||
|
@ -80,6 +80,7 @@ func InitDebugLog(prefix string) {
|
||||
return
|
||||
}
|
||||
DebugLogger = log.New(fd, prefix+" ", log.LstdFlags)
|
||||
Logf("logger initialized\n")
|
||||
}
|
||||
|
||||
func SetEnableDebugLog(enable bool) {
|
||||
|
@ -722,7 +722,7 @@ func (e *SendError) Error() string {
|
||||
if e.IsMarshalError {
|
||||
return fmt.Sprintf("SendPacket marshal-error '%s' packet: %v", e.PacketType, e.Err)
|
||||
} else if e.IsWriteError {
|
||||
return fmt.Sprintf("SendPacket write-error: %v", e.Err)
|
||||
return fmt.Sprintf("SendPacket write-error packet[%s]: %v", e.PacketType, e.Err)
|
||||
} else {
|
||||
return e.Err.Error()
|
||||
}
|
||||
@ -742,7 +742,7 @@ func SendPacket(w io.Writer, packet PacketType) error {
|
||||
outBuf.Write(jsonBytes)
|
||||
outBuf.WriteByte('\n')
|
||||
if GlobalDebug {
|
||||
fmt.Printf("SEND> %s\n", AsString(packet))
|
||||
base.Logf("SEND> %s\n", AsString(packet))
|
||||
}
|
||||
outBytes := outBuf.Bytes()
|
||||
sanitizeBytes(outBytes)
|
||||
@ -763,6 +763,7 @@ type PacketSender struct {
|
||||
Done bool
|
||||
DoneCh chan bool
|
||||
ErrHandler func(*PacketSender, PacketType, error)
|
||||
ExitErr error
|
||||
}
|
||||
|
||||
func MakePacketSender(output io.Writer, errHandler func(*PacketSender, PacketType, error)) *PacketSender {
|
||||
@ -784,6 +785,9 @@ func MakePacketSender(output io.Writer, errHandler func(*PacketSender, PacketTyp
|
||||
continue
|
||||
}
|
||||
// write errors are not recoverable
|
||||
sender.Lock.Lock()
|
||||
sender.ExitErr = err
|
||||
sender.Lock.Unlock()
|
||||
return
|
||||
}
|
||||
}
|
||||
@ -825,10 +829,18 @@ func (sender *PacketSender) Close() {
|
||||
close(sender.SendCh)
|
||||
}
|
||||
|
||||
func (sender *PacketSender) WaitForDone() {
|
||||
// returns ExitErr if set
|
||||
func (sender *PacketSender) WaitForDone() error {
|
||||
<-sender.DoneCh
|
||||
sender.Lock.Lock()
|
||||
defer sender.Lock.Unlock()
|
||||
return sender.ExitErr
|
||||
}
|
||||
|
||||
// this is "advisory", as there is a race condition between the loop closing and setting Done.
|
||||
// that's okay because that's an impossible race condition anyway (you could enqueue the packet
|
||||
// and then the connection dies, or it dies half way, etc.). this just stops blindly adding
|
||||
// packets forever when the loop is done.
|
||||
func (sender *PacketSender) checkStatus() error {
|
||||
sender.Lock.Lock()
|
||||
defer sender.Lock.Unlock()
|
||||
|
@ -9,7 +9,6 @@ import (
|
||||
"strings"
|
||||
|
||||
"github.com/alessio/shellescape"
|
||||
"github.com/scripthaus-dev/mshell/pkg/base"
|
||||
"github.com/scripthaus-dev/mshell/pkg/packet"
|
||||
"github.com/scripthaus-dev/mshell/pkg/simpleexpand"
|
||||
"github.com/scripthaus-dev/mshell/pkg/statediff"
|
||||
@ -453,10 +452,6 @@ func ParseShellStateOutput(outputBytes []byte) (*packet.ShellState, error) {
|
||||
rtn.Aliases = strings.ReplaceAll(string(fields[3]), "\r\n", "\n")
|
||||
rtn.Funcs = strings.ReplaceAll(string(fields[4]), "\r\n", "\n")
|
||||
rtn.Funcs = removeFunc(rtn.Funcs, "_scripthaus_exittrap")
|
||||
lines := strings.Split(rtn.Funcs, "\n")
|
||||
for _, line := range lines {
|
||||
base.Logf("func-line: [%s]\n", line)
|
||||
}
|
||||
return rtn, nil
|
||||
}
|
||||
|
||||
|
@ -102,6 +102,7 @@ func MakeReturnStateBuf() *ReturnStateBuf {
|
||||
}
|
||||
|
||||
type ShExecType struct {
|
||||
Lock *sync.Mutex // only locks "Exited" field
|
||||
StartTs time.Time
|
||||
CK base.CommandKey
|
||||
FileNames *base.CommandFileNames
|
||||
@ -114,6 +115,7 @@ type ShExecType struct {
|
||||
RunnerOutFd *os.File
|
||||
MsgSender *packet.PacketSender // where to send out-of-band messages back to calling proceess
|
||||
ReturnState *ReturnStateBuf
|
||||
Exited bool // locked via Lock
|
||||
}
|
||||
|
||||
type StdContext struct{}
|
||||
@ -195,6 +197,7 @@ func (s ShExecUPR) UnknownPacket(pk packet.PacketType) {
|
||||
|
||||
func MakeShExec(ck base.CommandKey, upr packet.UnknownPacketReporter) *ShExecType {
|
||||
return &ShExecType{
|
||||
Lock: &sync.Mutex{},
|
||||
StartTs: time.Now(),
|
||||
CK: ck,
|
||||
Multiplexer: mpio.MakeMultiplexer(ck, upr),
|
||||
@ -1000,6 +1003,25 @@ trap _scripthaus_exittrap EXIT
|
||||
return fmt.Sprintf(fmtStr, stateCmd)
|
||||
}
|
||||
|
||||
func (s *ShExecType) SendHup() {
|
||||
base.Logf("sendhup start\n")
|
||||
if s.Cmd == nil || s.Cmd.Process == nil || s.IsExited() {
|
||||
return
|
||||
}
|
||||
pgroup := false
|
||||
if s.Cmd.SysProcAttr != nil && (s.Cmd.SysProcAttr.Setsid || s.Cmd.SysProcAttr.Setpgid) {
|
||||
pgroup = true
|
||||
}
|
||||
pid := s.Cmd.Process.Pid
|
||||
if pgroup {
|
||||
base.Logf("sendhup %d (pgroup)\n", -pid)
|
||||
syscall.Kill(-pid, syscall.SIGHUP)
|
||||
} else {
|
||||
base.Logf("sendhup %d (normal)\n", pid)
|
||||
syscall.Kill(pid, syscall.SIGHUP)
|
||||
}
|
||||
}
|
||||
|
||||
func RunCommandSimple(pk *packet.RunPacketType, sender *packet.PacketSender, fromServer bool) (rtnShExec *ShExecType, rtnErr error) {
|
||||
state := pk.State
|
||||
if state == nil {
|
||||
@ -1172,7 +1194,7 @@ func (rs *ReturnStateBuf) Run() {
|
||||
// since we want mshell to persist even if the mshell --server is terminated
|
||||
func SetupSignalsForDetach() {
|
||||
sigCh := make(chan os.Signal, 1)
|
||||
signal.Notify(sigCh, syscall.SIGINT, syscall.SIGTERM, syscall.SIGHUP)
|
||||
signal.Notify(sigCh, syscall.SIGINT, syscall.SIGTERM, syscall.SIGHUP, syscall.SIGPIPE)
|
||||
go func() {
|
||||
for range sigCh {
|
||||
// do nothing
|
||||
@ -1180,6 +1202,18 @@ func SetupSignalsForDetach() {
|
||||
}()
|
||||
}
|
||||
|
||||
// in detached run mode, we don't want mshell to die from signals
|
||||
// since we want mshell to persist even if the mshell --server is terminated
|
||||
func IgnoreSigPipe() {
|
||||
sigCh := make(chan os.Signal, 1)
|
||||
signal.Notify(sigCh, syscall.SIGPIPE)
|
||||
go func() {
|
||||
for sig := range sigCh {
|
||||
base.Logf("ignoring signal %v\n", sig)
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
func copyToCirFile(dest *cirfile.File, src io.Reader) error {
|
||||
buf := make([]byte, 64*1024)
|
||||
for {
|
||||
@ -1308,9 +1342,23 @@ func GetExitCode(err error) int {
|
||||
}
|
||||
}
|
||||
|
||||
func (c *ShExecType) ProcWait() error {
|
||||
exitErr := c.Cmd.Wait()
|
||||
c.Lock.Lock()
|
||||
c.Exited = true
|
||||
c.Lock.Unlock()
|
||||
return exitErr
|
||||
}
|
||||
|
||||
func (c *ShExecType) IsExited() bool {
|
||||
c.Lock.Lock()
|
||||
defer c.Lock.Unlock()
|
||||
return c.Exited
|
||||
}
|
||||
|
||||
func (c *ShExecType) WaitForCommand() *packet.CmdDonePacketType {
|
||||
donePacket := packet.MakeCmdDonePacket(c.CK)
|
||||
exitErr := c.Cmd.Wait()
|
||||
exitErr := c.ProcWait()
|
||||
if c.ReturnState != nil {
|
||||
<-c.ReturnState.DoneCh
|
||||
state, _ := ParseShellStateOutput(c.ReturnState.Buf) // TODO what to do with error?
|
||||
@ -1318,9 +1366,8 @@ func (c *ShExecType) WaitForCommand() *packet.CmdDonePacketType {
|
||||
}
|
||||
endTs := time.Now()
|
||||
cmdDuration := endTs.Sub(c.StartTs)
|
||||
exitCode := GetExitCode(exitErr)
|
||||
donePacket.Ts = endTs.UnixMilli()
|
||||
donePacket.ExitCode = exitCode
|
||||
donePacket.ExitCode = GetExitCode(exitErr)
|
||||
donePacket.DurationMs = int64(cmdDuration / time.Millisecond)
|
||||
if c.FileNames != nil {
|
||||
os.Remove(c.FileNames.StdinFifo) // best effort (no need to check error)
|
||||
|
Loading…
Reference in New Issue
Block a user