diff --git a/pkg/server/server.go b/pkg/server/server.go index ad4b3367b..d0cc680fd 100644 --- a/pkg/server/server.go +++ b/pkg/server/server.go @@ -166,7 +166,7 @@ func (m *MServer) runCommand(runPacket *packet.RunPacketType) { m.Sender.SendErrorResponse(runPacket.ReqId, fmt.Errorf("server run packets require valid ck: %s", err)) return } - cproc, _, err := shexec.MakeClientProc(ecmd) + cproc, _, err := shexec.MakeClientProc(context.Background(), ecmd) if err != nil { m.Sender.SendErrorResponse(runPacket.ReqId, fmt.Errorf("starting mshell client: %s", err)) return diff --git a/pkg/shexec/client.go b/pkg/shexec/client.go index fb6e6e4e4..c14b17e59 100644 --- a/pkg/shexec/client.go +++ b/pkg/shexec/client.go @@ -1,6 +1,7 @@ package shexec import ( + "context" "fmt" "io" "os/exec" @@ -24,7 +25,7 @@ type ClientProc struct { } // returns (clientproc, uname, error) -func MakeClientProc(ecmd *exec.Cmd) (*ClientProc, string, error) { +func MakeClientProc(ctx context.Context, ecmd *exec.Cmd) (*ClientProc, string, error) { inputWriter, err := ecmd.StdinPipe() if err != nil { return nil, "", fmt.Errorf("creating stdin pipe: %v", err) @@ -55,7 +56,15 @@ func MakeClientProc(ecmd *exec.Cmd) (*ClientProc, string, error) { Input: sender, Output: packetParser, } - for pk := range packetParser.MainCh { + + var pk packet.PacketType + select { + case pk = <-packetParser.MainCh: + case <-ctx.Done(): + cproc.Close() + return nil, "", ctx.Err() + } + if pk != nil { if pk.GetType() != packet.InitPacketStr { cproc.Close() return nil, "", fmt.Errorf("invalid packet received from mshell client: %s", packet.AsString(pk)) @@ -70,7 +79,6 @@ func MakeClientProc(ecmd *exec.Cmd) (*ClientProc, string, error) { return nil, initPk.UName, fmt.Errorf("invalid remote mshell version 'v%s', must be v%s", initPk.Version, base.MShellVersion) } cproc.InitPk = initPk - break } if cproc.InitPk == nil { cproc.Close()