From a1b82349544d9aa7ae6b9025e13e9ecd50f46192 Mon Sep 17 00:00:00 2001 From: sawka Date: Wed, 6 Jul 2022 12:16:37 -0700 Subject: [PATCH] fix tty TERM for ssh connections when usepty is set. also ignore pty read errors --- pkg/mpio/bufreader.go | 8 +++++++- pkg/mpio/mpio.go | 6 +++--- pkg/shexec/shexec.go | 13 +++++++------ 3 files changed, 17 insertions(+), 10 deletions(-) diff --git a/pkg/mpio/bufreader.go b/pkg/mpio/bufreader.go index bcda317e4..68af4cb52 100644 --- a/pkg/mpio/bufreader.go +++ b/pkg/mpio/bufreader.go @@ -21,9 +21,10 @@ type FdReader struct { BufSize int Closed bool ShouldCloseFd bool + IsPty bool } -func MakeFdReader(m *Multiplexer, fd io.ReadCloser, fdNum int, shouldCloseFd bool) *FdReader { +func MakeFdReader(m *Multiplexer, fd io.ReadCloser, fdNum int, shouldCloseFd bool, isPty bool) *FdReader { fr := &FdReader{ CVar: sync.NewCond(&sync.Mutex{}), M: m, @@ -31,6 +32,7 @@ func MakeFdReader(m *Multiplexer, fd io.ReadCloser, fdNum int, shouldCloseFd boo Fd: fd, BufSize: 0, ShouldCloseFd: shouldCloseFd, + IsPty: isPty, } return fr } @@ -136,6 +138,10 @@ func (r *FdReader) ReadLoop(wg *sync.WaitGroup) { } } if err != nil { + if r.IsPty { + r.WriteWait(nil, true) + return + } errPk := r.M.makeDataPacket(r.FdNum, nil, err) r.M.sendPacket(errPk) return diff --git a/pkg/mpio/mpio.go b/pkg/mpio/mpio.go index 4fcea20c0..f16a411c0 100644 --- a/pkg/mpio/mpio.go +++ b/pkg/mpio/mpio.go @@ -89,7 +89,7 @@ func (m *Multiplexer) MakeReaderPipe(fdNum int) (*os.File, error) { } m.Lock.Lock() defer m.Lock.Unlock() - m.FdReaders[fdNum] = MakeFdReader(m, pr, fdNum, true) + m.FdReaders[fdNum] = MakeFdReader(m, pr, fdNum, true, false) m.CloseAfterStart = append(m.CloseAfterStart, pw) return pw, nil } @@ -125,10 +125,10 @@ func (m *Multiplexer) MakeStaticWriterPipe(fdNum int, data []byte) (*os.File, er return pr, nil } -func (m *Multiplexer) MakeRawFdReader(fdNum int, fd io.ReadCloser, shouldClose bool) { +func (m *Multiplexer) MakeRawFdReader(fdNum int, fd io.ReadCloser, shouldClose bool, isPty bool) { m.Lock.Lock() defer m.Lock.Unlock() - m.FdReaders[fdNum] = MakeFdReader(m, fd, fdNum, shouldClose) + m.FdReaders[fdNum] = MakeFdReader(m, fd, fdNum, shouldClose, isPty) } func (m *Multiplexer) MakeRawFdWriter(fdNum int, fd io.WriteCloser, shouldClose bool) { diff --git a/pkg/shexec/shexec.go b/pkg/shexec/shexec.go index 96a7d3e1d..bf935e7bf 100644 --- a/pkg/shexec/shexec.go +++ b/pkg/shexec/shexec.go @@ -150,7 +150,7 @@ func UpdateCmdEnv(cmd *exec.Cmd, envVars map[string]string) { if len(envVars) == 0 { return } - if cmd.Env != nil { + if cmd.Env == nil { cmd.Env = os.Environ() } found := make(map[string]bool) @@ -631,18 +631,18 @@ func RunClientSSHCommandAndWait(runPacket *packet.RunPacketType, fdContext FdCon return nil, fmt.Errorf("creating stderr pipe: %v", err) } if !HasDupStdin(runPacket.Fds) { - cmd.Multiplexer.MakeRawFdReader(0, fdContext.GetReader(0), false) + cmd.Multiplexer.MakeRawFdReader(0, fdContext.GetReader(0), false, false) } cmd.Multiplexer.MakeRawFdWriter(1, fdContext.GetWriter(1), false) cmd.Multiplexer.MakeRawFdWriter(2, fdContext.GetWriter(2), false) for _, rfd := range runPacket.Fds { if rfd.Read && rfd.DupStdin { - cmd.Multiplexer.MakeRawFdReader(rfd.FdNum, fdContext.GetReader(0), false) + cmd.Multiplexer.MakeRawFdReader(rfd.FdNum, fdContext.GetReader(0), false, false) continue } if rfd.Read { fd := fdContext.GetReader(rfd.FdNum) - cmd.Multiplexer.MakeRawFdReader(rfd.FdNum, fd, false) + cmd.Multiplexer.MakeRawFdReader(rfd.FdNum, fd, false, false) } else if rfd.Write { fd := fdContext.GetWriter(rfd.FdNum) cmd.Multiplexer.MakeRawFdWriter(rfd.FdNum, fd, true) @@ -791,6 +791,7 @@ func RunCommandSimple(pk *packet.RunPacketType, sender *packet.PacketSender) (*S cmdTty.Close() }() cmd.CmdPty = cmdPty + UpdateCmdEnv(cmd.Cmd, map[string]string{"TERM": "xterm-256color"}) } if cmdTty != nil { cmd.Cmd.Stdin = cmdTty @@ -801,13 +802,13 @@ func RunCommandSimple(pk *packet.RunPacketType, sender *packet.PacketSender) (*S Setctty: true, } cmd.Multiplexer.MakeRawFdWriter(0, cmdPty, false) - cmd.Multiplexer.MakeRawFdReader(1, cmdPty, false) + cmd.Multiplexer.MakeRawFdReader(1, cmdPty, false, true) nullFd, err := os.Open("/dev/null") if err != nil { cmd.Close() return nil, fmt.Errorf("cannot open /dev/null: %w", err) } - cmd.Multiplexer.MakeRawFdReader(2, nullFd, true) + cmd.Multiplexer.MakeRawFdReader(2, nullFd, true, false) } else { cmd.Cmd.Stdin, err = cmd.Multiplexer.MakeWriterPipe(0) if err != nil {