diff --git a/main-mshell.go b/main-mshell.go index d43aafd5f..b8b2b85f3 100644 --- a/main-mshell.go +++ b/main-mshell.go @@ -376,7 +376,7 @@ func parseClientOpts() (*shexec.ClientOpts, error) { return nil, fmt.Errorf("'--sudo-with-password [pw]', missing password") } opts.Sudo = true - opts.SSHOpts.SudoWithPass = true + opts.SudoWithPass = true opts.SudoPw = iter.Next() continue } @@ -385,7 +385,7 @@ func parseClientOpts() (*shexec.ClientOpts, error) { return nil, fmt.Errorf("'--sudo-with-passfile [file]', missing file") } opts.Sudo = true - opts.SSHOpts.SudoWithPass = true + opts.SudoWithPass = true fileName := iter.Next() contents, err := os.ReadFile(fileName) if err != nil { @@ -433,7 +433,7 @@ func handleClient() (int, error) { if err != nil { return 1, err } - donePacket, err := shexec.RunClientSSHCommandAndWait(runPacket, opts.SSHOpts, opts.Debug) + donePacket, err := shexec.RunClientSSHCommandAndWait(runPacket, shexec.StdContext{}, opts.SSHOpts, opts.Debug) if err != nil { return 1, err } diff --git a/pkg/mpio/bufreader.go b/pkg/mpio/bufreader.go index 654d8d680..bcda317e4 100644 --- a/pkg/mpio/bufreader.go +++ b/pkg/mpio/bufreader.go @@ -8,7 +8,6 @@ package mpio import ( "io" - "os" "sync" "github.com/scripthaus-dev/mshell/pkg/packet" @@ -18,13 +17,13 @@ type FdReader struct { CVar *sync.Cond M *Multiplexer FdNum int - Fd *os.File + Fd io.ReadCloser BufSize int Closed bool ShouldCloseFd bool } -func MakeFdReader(m *Multiplexer, fd *os.File, fdNum int, shouldCloseFd bool) *FdReader { +func MakeFdReader(m *Multiplexer, fd io.ReadCloser, fdNum int, shouldCloseFd bool) *FdReader { fr := &FdReader{ CVar: sync.NewCond(&sync.Mutex{}), M: m, diff --git a/pkg/mpio/bufwriter.go b/pkg/mpio/bufwriter.go index c1b36e2b4..977a2493c 100644 --- a/pkg/mpio/bufwriter.go +++ b/pkg/mpio/bufwriter.go @@ -8,7 +8,7 @@ package mpio import ( "fmt" - "os" + "io" "sync" ) @@ -17,13 +17,13 @@ type FdWriter struct { M *Multiplexer FdNum int Buffer []byte - Fd *os.File + Fd io.WriteCloser Eof bool Closed bool ShouldCloseFd bool } -func MakeFdWriter(m *Multiplexer, fd *os.File, fdNum int, shouldCloseFd bool) *FdWriter { +func MakeFdWriter(m *Multiplexer, fd io.WriteCloser, fdNum int, shouldCloseFd bool) *FdWriter { fw := &FdWriter{ CVar: sync.NewCond(&sync.Mutex{}), Fd: fd, diff --git a/pkg/mpio/mpio.go b/pkg/mpio/mpio.go index 6fce20616..ca0a0ffaa 100644 --- a/pkg/mpio/mpio.go +++ b/pkg/mpio/mpio.go @@ -9,6 +9,7 @@ package mpio import ( "encoding/base64" "fmt" + "io" "os" "sync" @@ -111,13 +112,13 @@ func (m *Multiplexer) MakeStringFdReader(fdNum int, contents string) error { return nil } -func (m *Multiplexer) MakeRawFdReader(fdNum int, fd *os.File, shouldClose bool) { +func (m *Multiplexer) MakeRawFdReader(fdNum int, fd io.ReadCloser, shouldClose bool) { m.Lock.Lock() defer m.Lock.Unlock() m.FdReaders[fdNum] = MakeFdReader(m, fd, fdNum, shouldClose) } -func (m *Multiplexer) MakeRawFdWriter(fdNum int, fd *os.File, shouldClose bool) { +func (m *Multiplexer) MakeRawFdWriter(fdNum int, fd io.WriteCloser, shouldClose bool) { m.Lock.Lock() defer m.Lock.Unlock() m.FdWriters[fdNum] = MakeFdWriter(m, fd, fdNum, shouldClose) diff --git a/pkg/mpio/packetreader.go b/pkg/mpio/packetreader.go new file mode 100644 index 000000000..8dfb30241 --- /dev/null +++ b/pkg/mpio/packetreader.go @@ -0,0 +1,96 @@ +// Copyright 2022 Dashborg Inc +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at https://mozilla.org/MPL/2.0/. + +package mpio + +import ( + "encoding/base64" + "errors" + "io" + "sync" + + "github.com/scripthaus-dev/mshell/pkg/packet" +) + +type PacketReader struct { + CVar *sync.Cond + FdNum int + Buf []byte + Eof bool + Err error +} + +func MakePacketReader(fdNum int) *PacketReader { + return &PacketReader{ + CVar: sync.NewCond(&sync.Mutex{}), + FdNum: fdNum, + } +} + +func (pr *PacketReader) AddData(pk *packet.DataPacketType) { + pr.CVar.L.Lock() + defer pr.CVar.L.Unlock() + defer pr.CVar.Broadcast() + if pr.Eof || pr.Err != nil { + return + } + if pk.Data64 != "" { + realData, err := base64.StdEncoding.DecodeString(pk.Data64) + if err != nil { + pr.Err = err + return + } + pr.Buf = append(pr.Buf, realData...) + } + pr.Eof = pk.Eof + if pk.Error != "" { + pr.Err = errors.New(pk.Error) + } + return +} + +func (pr *PacketReader) Read(buf []byte) (int, error) { + pr.CVar.L.Lock() + defer pr.CVar.L.Unlock() + for { + if pr.Err != nil { + return 0, pr.Err + } + if pr.Eof { + return 0, io.EOF + } + if len(pr.Buf) == 0 { + pr.CVar.Wait() + continue + } + nr := copy(buf, pr.Buf) + pr.Buf = pr.Buf[nr:] + if len(pr.Buf) == 0 { + pr.Buf = nil + } + return nr, nil + } +} + +func (pr *PacketReader) Close() error { + pr.CVar.L.Lock() + defer pr.CVar.L.Unlock() + defer pr.CVar.Broadcast() + if pr.Err == nil { + pr.Err = io.ErrClosedPipe + } + return nil +} + +type NullReader struct{} + +func (NullReader) Read(buf []byte) (int, error) { + return 0, io.EOF +} + +func (NullReader) Close() error { + return nil +} diff --git a/pkg/mpio/packetwriter.go b/pkg/mpio/packetwriter.go new file mode 100644 index 000000000..0665f0441 --- /dev/null +++ b/pkg/mpio/packetwriter.go @@ -0,0 +1,40 @@ +// Copyright 2022 Dashborg Inc +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at https://mozilla.org/MPL/2.0/. + +package mpio + +import ( + "encoding/base64" + + "github.com/scripthaus-dev/mshell/pkg/base" + "github.com/scripthaus-dev/mshell/pkg/packet" +) + +type PacketWriter struct { + FdNum int + Sender *packet.PacketSender + CK base.CommandKey +} + +func MakePacketWriter(fdNum int, sender *packet.PacketSender, ck base.CommandKey) *PacketWriter { + return &PacketWriter{FdNum: fdNum, Sender: sender, CK: ck} +} + +func (pw *PacketWriter) Write(data []byte) (int, error) { + pk := packet.MakeDataPacket() + pk.CK = pw.CK + pk.FdNum = pw.FdNum + pk.Data64 = base64.StdEncoding.EncodeToString(data) + return len(data), pw.Sender.SendPacket(pk) +} + +func (pw *PacketWriter) Close() error { + pk := packet.MakeDataPacket() + pk.CK = pw.CK + pk.FdNum = pw.FdNum + pk.Eof = true + return pw.Sender.SendPacket(pk) +} diff --git a/pkg/server/server.go b/pkg/server/server.go index 61bd17601..0400d1323 100644 --- a/pkg/server/server.go +++ b/pkg/server/server.go @@ -8,17 +8,21 @@ package server import ( "fmt" + "io" "os" "sync" "github.com/scripthaus-dev/mshell/pkg/base" + "github.com/scripthaus-dev/mshell/pkg/mpio" "github.com/scripthaus-dev/mshell/pkg/packet" + "github.com/scripthaus-dev/mshell/pkg/shexec" ) type MServer struct { Lock *sync.Mutex MainInput *packet.PacketParser Sender *packet.PacketSender + FdContext *serverFdContext } func (m *MServer) Close() { @@ -26,13 +30,73 @@ func (m *MServer) Close() { m.Sender.WaitForDone() } +type serverFdContext struct { + M *MServer + Lock *sync.Mutex + Sender *packet.PacketSender + CK base.CommandKey + Readers map[int]*mpio.PacketReader +} + +func (m *MServer) MakeServerFdContext(ck base.CommandKey) *serverFdContext { + rtn := &serverFdContext{ + M: m, + Lock: &sync.Mutex{}, + Sender: m.Sender, + CK: ck, + Readers: make(map[int]*mpio.PacketReader), + } + return rtn +} + +func (c *serverFdContext) processDataPacket(pk *packet.DataPacketType) { + c.Lock.Lock() + reader := c.Readers[pk.FdNum] + c.Lock.Unlock() + if reader == nil { + ackPacket := packet.MakeDataAckPacket() + ackPacket.CK = c.CK + ackPacket.FdNum = pk.FdNum + ackPacket.Error = "write to closed file (no fd)" + c.M.Sender.SendPacket(ackPacket) + return + } + reader.AddData(pk) + return +} + +func (c *serverFdContext) GetWriter(fdNum int) io.WriteCloser { + return mpio.MakePacketWriter(fdNum, c.Sender, c.CK) +} + +func (c *serverFdContext) GetReader(fdNum int) io.ReadCloser { + c.Lock.Lock() + defer c.Lock.Unlock() + reader := mpio.MakePacketReader(fdNum) + c.Readers[fdNum] = reader + return reader +} + +func (m *MServer) runCommand(runPacket *packet.RunPacketType) { + fdContext := m.MakeServerFdContext(runPacket.CK) + m.Lock.Lock() + m.FdContext = fdContext + m.Lock.Unlock() + go func() { + donePk, err := shexec.RunClientSSHCommandAndWait(runPacket, fdContext, shexec.SSHOpts{}, true) + fmt.Printf("done: err:%v, %v\n", err, donePk) + }() +} + func RunServer() (int, error) { server := &MServer{ Lock: &sync.Mutex{}, } + packet.GlobalDebug = true server.MainInput = packet.MakePacketParser(os.Stdin) server.Sender = packet.MakePacketSender(os.Stdout) defer server.Close() + defer fmt.Printf("runserver done\n") initPacket := packet.MakeInitPacket() initPacket.Version = base.MShellVersion server.Sender.SendPacket(initPacket) @@ -43,7 +107,12 @@ func RunServer() (int, error) { } if pk.GetType() == packet.RunPacketStr { runPacket := pk.(*packet.RunPacketType) - fmt.Printf("RUN> %s\n", runPacket) + server.runCommand(runPacket) + continue + } + if pk.GetType() == packet.DataPacketStr { + dataPacket := pk.(*packet.DataPacketType) + server.FdContext.processDataPacket(dataPacket) continue } server.Sender.SendErrorPacket(fmt.Sprintf("invalid packet '%s' sent to mshell", packet.AsExtType(pk))) diff --git a/pkg/shexec/shexec.go b/pkg/shexec/shexec.go index f0334e3f4..4bc6dadca 100644 --- a/pkg/shexec/shexec.go +++ b/pkg/shexec/shexec.go @@ -64,6 +64,41 @@ type ShExecType struct { Multiplexer *mpio.Multiplexer } +type StdContext struct{} + +func (StdContext) GetWriter(fdNum int) io.WriteCloser { + if fdNum == 0 { + return os.Stdin + } + if fdNum == 1 { + return os.Stdout + } + if fdNum == 2 { + return os.Stderr + } + fd := os.NewFile(uintptr(fdNum), fmt.Sprintf("/dev/fd/%d", fdNum)) + return fd +} + +func (StdContext) GetReader(fdNum int) io.ReadCloser { + if fdNum == 0 { + return os.Stdin + } + if fdNum == 1 { + return os.Stdout + } + if fdNum == 2 { + return os.Stdout + } + fd := os.NewFile(uintptr(fdNum), fmt.Sprintf("/dev/fd/%d", fdNum)) + return fd +} + +type FdContext interface { + GetWriter(fdNum int) io.WriteCloser + GetReader(fdNum int) io.ReadCloser +} + func MakeShExec(ck base.CommandKey) *ShExecType { return &ShExecType{ Lock: &sync.Mutex{}, @@ -237,11 +272,10 @@ func RunCommand(pk *packet.RunPacketType, sender *packet.PacketSender) (*ShExecT } type SSHOpts struct { - SSHHost string - SSHOptsStr string - SSHIdentity string - SSHUser string - SudoWithPass bool + SSHHost string + SSHOptsStr string + SSHIdentity string + SSHUser string } type InstallOpts struct { @@ -258,6 +292,7 @@ type ClientOpts struct { Cwd string Debug bool Sudo bool + SudoWithPass bool SudoPw string CommandStdinFdNum int Detach bool @@ -316,7 +351,7 @@ func (opts *ClientOpts) MakeRunPacket() (*packet.RunPacketType, error) { runPacket.Command = fmt.Sprintf(RunCommandFmt, opts.Command) return runPacket, nil } - if opts.SSHOpts.SudoWithPass { + if opts.SudoWithPass { pwFdNum, err := opts.NextFreeFdNum() if err != nil { return nil, err @@ -480,7 +515,16 @@ func RunInstallSSHCommand(opts *InstallOpts) error { return fmt.Errorf("did not receive version string from client, install not successful") } -func RunClientSSHCommandAndWait(runPacket *packet.RunPacketType, sshOpts SSHOpts, debug bool) (*packet.CmdDonePacketType, error) { +func HasDupStdin(fds []packet.RemoteFd) bool { + for _, rfd := range fds { + if rfd.Read && rfd.DupStdin { + return true + } + } + return false +} + +func RunClientSSHCommandAndWait(runPacket *packet.RunPacketType, fdContext FdContext, sshOpts SSHOpts, debug bool) (*packet.CmdDonePacketType, error) { cmd := MakeShExec("") ecmd := sshOpts.MakeSSHExecCmd(ClientCommand) cmd.Cmd = ecmd @@ -496,11 +540,11 @@ func RunClientSSHCommandAndWait(runPacket *packet.RunPacketType, sshOpts SSHOpts if err != nil { return nil, fmt.Errorf("creating stderr pipe: %v", err) } - if !sshOpts.SudoWithPass { - cmd.Multiplexer.MakeRawFdReader(0, os.Stdin, false) + if !HasDupStdin(runPacket.Fds) { + cmd.Multiplexer.MakeRawFdReader(0, fdContext.GetReader(0), false) } - cmd.Multiplexer.MakeRawFdWriter(1, os.Stdout, false) - cmd.Multiplexer.MakeRawFdWriter(2, os.Stderr, false) + cmd.Multiplexer.MakeRawFdWriter(1, fdContext.GetWriter(1), false) + cmd.Multiplexer.MakeRawFdWriter(2, fdContext.GetWriter(2), false) for _, rfd := range runPacket.Fds { if rfd.Read && rfd.Content != "" { err = cmd.Multiplexer.MakeStringFdReader(rfd.FdNum, rfd.Content) @@ -510,16 +554,14 @@ func RunClientSSHCommandAndWait(runPacket *packet.RunPacketType, sshOpts SSHOpts continue } if rfd.Read && rfd.DupStdin { - cmd.Multiplexer.MakeRawFdReader(rfd.FdNum, os.Stdin, false) + cmd.Multiplexer.MakeRawFdReader(rfd.FdNum, fdContext.GetReader(0), false) continue } - fd := os.NewFile(uintptr(rfd.FdNum), fmt.Sprintf("/dev/fd/%d", rfd.FdNum)) - if fd == nil { - return nil, fmt.Errorf("cannot open fd %d", rfd.FdNum) - } if rfd.Read { - cmd.Multiplexer.MakeRawFdReader(rfd.FdNum, fd, true) + fd := fdContext.GetReader(rfd.FdNum) + cmd.Multiplexer.MakeRawFdReader(rfd.FdNum, fd, false) } else if rfd.Write { + fd := fdContext.GetWriter(rfd.FdNum) cmd.Multiplexer.MakeRawFdWriter(rfd.FdNum, fd, true) } }