From ad2cab595d6881df0ec429560fa5b2d14f60ca20 Mon Sep 17 00:00:00 2001 From: sawka Date: Mon, 5 Dec 2022 15:38:44 -0800 Subject: [PATCH] kill server on I/O write error, and add a pinger to continually send ping packets to test connection --- pkg/server/server.go | 103 +++++++++++++++++++++++++++++-------------- 1 file changed, 71 insertions(+), 32 deletions(-) diff --git a/pkg/server/server.go b/pkg/server/server.go index 9ff06dff9..daf5e5bf0 100644 --- a/pkg/server/server.go +++ b/pkg/server/server.go @@ -24,13 +24,15 @@ import ( // TODO create unblockable packet-sender (backed by an array) for clientproc type MServer struct { - Lock *sync.Mutex - MainInput *packet.PacketParser - Sender *packet.PacketSender - ClientMap map[base.CommandKey]*shexec.ClientProc - Debug bool - StateMap map[string]*packet.ShellState // sha1->state - CurrentState string // sha1 + Lock *sync.Mutex + MainInput *packet.PacketParser + Sender *packet.PacketSender + ClientMap map[base.CommandKey]*shexec.ClientProc + Debug bool + StateMap map[string]*packet.ShellState // sha1->state + CurrentState string // sha1 + WriteErrorCh chan bool // closed if there is a I/O write error + WriteErrorChOnce *sync.Once } func (m *MServer) Close() { @@ -255,34 +257,16 @@ func (m *MServer) packetSenderErrorHandler(sender *packet.PacketSender, pk packe msg.CK = cpk.GetCK() } sender.SendPacket(msg) + return + } else { + // I/O error: close the WriteErrorCh to signal that we are dead (cannot continue if we can't write output) + m.WriteErrorChOnce.Do(func() { + close(m.WriteErrorCh) + }) } - // otherwise ignore (we can't output anything for a I/O error) } -func RunServer() (int, error) { - debug := false - if len(os.Args) >= 3 && os.Args[2] == "--debug" { - debug = true - } - server := &MServer{ - Lock: &sync.Mutex{}, - ClientMap: make(map[base.CommandKey]*shexec.ClientProc), - StateMap: make(map[string]*packet.ShellState), - Debug: debug, - } - if debug { - packet.GlobalDebug = true - } - server.MainInput = packet.MakePacketParser(os.Stdin) - server.Sender = packet.MakePacketSender(os.Stdout, server.packetSenderErrorHandler) - defer server.Close() - var err error - initPacket, err := shexec.MakeServerInitPacket() - if err != nil { - return 1, err - } - server.setCurrentState(initPacket.State) - server.Sender.SendPacket(initPacket) +func (server *MServer) runReadLoop() { builder := packet.MakeRunPacketBuilder() for pk := range server.MainInput.MainCh { if server.Debug { @@ -307,5 +291,60 @@ func RunServer() (int, error) { server.Sender.SendMessageFmt("invalid packet '%s' sent to mshell server", packet.AsString(pk)) continue } +} + +func RunServer() (int, error) { + debug := false + if len(os.Args) >= 3 && os.Args[2] == "--debug" { + debug = true + } + server := &MServer{ + Lock: &sync.Mutex{}, + ClientMap: make(map[base.CommandKey]*shexec.ClientProc), + StateMap: make(map[string]*packet.ShellState), + Debug: debug, + WriteErrorCh: make(chan bool), + WriteErrorChOnce: &sync.Once{}, + } + if debug { + packet.GlobalDebug = true + } + server.MainInput = packet.MakePacketParser(os.Stdin) + server.Sender = packet.MakePacketSender(os.Stdout, server.packetSenderErrorHandler) + defer server.Close() + var err error + initPacket, err := shexec.MakeServerInitPacket() + if err != nil { + return 1, err + } + server.setCurrentState(initPacket.State) + server.Sender.SendPacket(initPacket) + ticker := time.NewTicker(1 * time.Minute) + go func() { + for range ticker.C { + server.Sender.SendPacket(packet.MakePingPacket()) + } + }() + defer ticker.Stop() + readLoopDoneCh := make(chan bool) + + go func() { + defer close(readLoopDoneCh) + server.runReadLoop() + }() + + go func() { + time.Sleep(5 * time.Second) + respPk := packet.MakeResponsePacket("NA", make(chan bool)) + server.Sender.SendPacket(respPk) + }() + + select { + case <-readLoopDoneCh: + break + + case <-server.WriteErrorCh: + break + } return 0, nil }