working on server mode. extract fdcontext as interface. create packet writer/reader for mpio. hook up to serverFdContext.

This commit is contained in:
sawka 2022-06-28 17:20:01 -07:00
parent d7eb2526f0
commit 1d44afc10e
8 changed files with 276 additions and 29 deletions

View File

@ -376,7 +376,7 @@ func parseClientOpts() (*shexec.ClientOpts, error) {
return nil, fmt.Errorf("'--sudo-with-password [pw]', missing password") return nil, fmt.Errorf("'--sudo-with-password [pw]', missing password")
} }
opts.Sudo = true opts.Sudo = true
opts.SSHOpts.SudoWithPass = true opts.SudoWithPass = true
opts.SudoPw = iter.Next() opts.SudoPw = iter.Next()
continue continue
} }
@ -385,7 +385,7 @@ func parseClientOpts() (*shexec.ClientOpts, error) {
return nil, fmt.Errorf("'--sudo-with-passfile [file]', missing file") return nil, fmt.Errorf("'--sudo-with-passfile [file]', missing file")
} }
opts.Sudo = true opts.Sudo = true
opts.SSHOpts.SudoWithPass = true opts.SudoWithPass = true
fileName := iter.Next() fileName := iter.Next()
contents, err := os.ReadFile(fileName) contents, err := os.ReadFile(fileName)
if err != nil { if err != nil {
@ -433,7 +433,7 @@ func handleClient() (int, error) {
if err != nil { if err != nil {
return 1, err return 1, err
} }
donePacket, err := shexec.RunClientSSHCommandAndWait(runPacket, opts.SSHOpts, opts.Debug) donePacket, err := shexec.RunClientSSHCommandAndWait(runPacket, shexec.StdContext{}, opts.SSHOpts, opts.Debug)
if err != nil { if err != nil {
return 1, err return 1, err
} }

View File

@ -8,7 +8,6 @@ package mpio
import ( import (
"io" "io"
"os"
"sync" "sync"
"github.com/scripthaus-dev/mshell/pkg/packet" "github.com/scripthaus-dev/mshell/pkg/packet"
@ -18,13 +17,13 @@ type FdReader struct {
CVar *sync.Cond CVar *sync.Cond
M *Multiplexer M *Multiplexer
FdNum int FdNum int
Fd *os.File Fd io.ReadCloser
BufSize int BufSize int
Closed bool Closed bool
ShouldCloseFd bool ShouldCloseFd bool
} }
func MakeFdReader(m *Multiplexer, fd *os.File, fdNum int, shouldCloseFd bool) *FdReader { func MakeFdReader(m *Multiplexer, fd io.ReadCloser, fdNum int, shouldCloseFd bool) *FdReader {
fr := &FdReader{ fr := &FdReader{
CVar: sync.NewCond(&sync.Mutex{}), CVar: sync.NewCond(&sync.Mutex{}),
M: m, M: m,

View File

@ -8,7 +8,7 @@ package mpio
import ( import (
"fmt" "fmt"
"os" "io"
"sync" "sync"
) )
@ -17,13 +17,13 @@ type FdWriter struct {
M *Multiplexer M *Multiplexer
FdNum int FdNum int
Buffer []byte Buffer []byte
Fd *os.File Fd io.WriteCloser
Eof bool Eof bool
Closed bool Closed bool
ShouldCloseFd bool ShouldCloseFd bool
} }
func MakeFdWriter(m *Multiplexer, fd *os.File, fdNum int, shouldCloseFd bool) *FdWriter { func MakeFdWriter(m *Multiplexer, fd io.WriteCloser, fdNum int, shouldCloseFd bool) *FdWriter {
fw := &FdWriter{ fw := &FdWriter{
CVar: sync.NewCond(&sync.Mutex{}), CVar: sync.NewCond(&sync.Mutex{}),
Fd: fd, Fd: fd,

View File

@ -9,6 +9,7 @@ package mpio
import ( import (
"encoding/base64" "encoding/base64"
"fmt" "fmt"
"io"
"os" "os"
"sync" "sync"
@ -111,13 +112,13 @@ func (m *Multiplexer) MakeStringFdReader(fdNum int, contents string) error {
return nil return nil
} }
func (m *Multiplexer) MakeRawFdReader(fdNum int, fd *os.File, shouldClose bool) { func (m *Multiplexer) MakeRawFdReader(fdNum int, fd io.ReadCloser, shouldClose bool) {
m.Lock.Lock() m.Lock.Lock()
defer m.Lock.Unlock() defer m.Lock.Unlock()
m.FdReaders[fdNum] = MakeFdReader(m, fd, fdNum, shouldClose) m.FdReaders[fdNum] = MakeFdReader(m, fd, fdNum, shouldClose)
} }
func (m *Multiplexer) MakeRawFdWriter(fdNum int, fd *os.File, shouldClose bool) { func (m *Multiplexer) MakeRawFdWriter(fdNum int, fd io.WriteCloser, shouldClose bool) {
m.Lock.Lock() m.Lock.Lock()
defer m.Lock.Unlock() defer m.Lock.Unlock()
m.FdWriters[fdNum] = MakeFdWriter(m, fd, fdNum, shouldClose) m.FdWriters[fdNum] = MakeFdWriter(m, fd, fdNum, shouldClose)

96
pkg/mpio/packetreader.go Normal file
View File

@ -0,0 +1,96 @@
// Copyright 2022 Dashborg Inc
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this
// file, You can obtain one at https://mozilla.org/MPL/2.0/.
package mpio
import (
"encoding/base64"
"errors"
"io"
"sync"
"github.com/scripthaus-dev/mshell/pkg/packet"
)
type PacketReader struct {
CVar *sync.Cond
FdNum int
Buf []byte
Eof bool
Err error
}
func MakePacketReader(fdNum int) *PacketReader {
return &PacketReader{
CVar: sync.NewCond(&sync.Mutex{}),
FdNum: fdNum,
}
}
func (pr *PacketReader) AddData(pk *packet.DataPacketType) {
pr.CVar.L.Lock()
defer pr.CVar.L.Unlock()
defer pr.CVar.Broadcast()
if pr.Eof || pr.Err != nil {
return
}
if pk.Data64 != "" {
realData, err := base64.StdEncoding.DecodeString(pk.Data64)
if err != nil {
pr.Err = err
return
}
pr.Buf = append(pr.Buf, realData...)
}
pr.Eof = pk.Eof
if pk.Error != "" {
pr.Err = errors.New(pk.Error)
}
return
}
func (pr *PacketReader) Read(buf []byte) (int, error) {
pr.CVar.L.Lock()
defer pr.CVar.L.Unlock()
for {
if pr.Err != nil {
return 0, pr.Err
}
if pr.Eof {
return 0, io.EOF
}
if len(pr.Buf) == 0 {
pr.CVar.Wait()
continue
}
nr := copy(buf, pr.Buf)
pr.Buf = pr.Buf[nr:]
if len(pr.Buf) == 0 {
pr.Buf = nil
}
return nr, nil
}
}
func (pr *PacketReader) Close() error {
pr.CVar.L.Lock()
defer pr.CVar.L.Unlock()
defer pr.CVar.Broadcast()
if pr.Err == nil {
pr.Err = io.ErrClosedPipe
}
return nil
}
type NullReader struct{}
func (NullReader) Read(buf []byte) (int, error) {
return 0, io.EOF
}
func (NullReader) Close() error {
return nil
}

40
pkg/mpio/packetwriter.go Normal file
View File

@ -0,0 +1,40 @@
// Copyright 2022 Dashborg Inc
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this
// file, You can obtain one at https://mozilla.org/MPL/2.0/.
package mpio
import (
"encoding/base64"
"github.com/scripthaus-dev/mshell/pkg/base"
"github.com/scripthaus-dev/mshell/pkg/packet"
)
type PacketWriter struct {
FdNum int
Sender *packet.PacketSender
CK base.CommandKey
}
func MakePacketWriter(fdNum int, sender *packet.PacketSender, ck base.CommandKey) *PacketWriter {
return &PacketWriter{FdNum: fdNum, Sender: sender, CK: ck}
}
func (pw *PacketWriter) Write(data []byte) (int, error) {
pk := packet.MakeDataPacket()
pk.CK = pw.CK
pk.FdNum = pw.FdNum
pk.Data64 = base64.StdEncoding.EncodeToString(data)
return len(data), pw.Sender.SendPacket(pk)
}
func (pw *PacketWriter) Close() error {
pk := packet.MakeDataPacket()
pk.CK = pw.CK
pk.FdNum = pw.FdNum
pk.Eof = true
return pw.Sender.SendPacket(pk)
}

View File

@ -8,17 +8,21 @@ package server
import ( import (
"fmt" "fmt"
"io"
"os" "os"
"sync" "sync"
"github.com/scripthaus-dev/mshell/pkg/base" "github.com/scripthaus-dev/mshell/pkg/base"
"github.com/scripthaus-dev/mshell/pkg/mpio"
"github.com/scripthaus-dev/mshell/pkg/packet" "github.com/scripthaus-dev/mshell/pkg/packet"
"github.com/scripthaus-dev/mshell/pkg/shexec"
) )
type MServer struct { type MServer struct {
Lock *sync.Mutex Lock *sync.Mutex
MainInput *packet.PacketParser MainInput *packet.PacketParser
Sender *packet.PacketSender Sender *packet.PacketSender
FdContext *serverFdContext
} }
func (m *MServer) Close() { func (m *MServer) Close() {
@ -26,13 +30,73 @@ func (m *MServer) Close() {
m.Sender.WaitForDone() m.Sender.WaitForDone()
} }
type serverFdContext struct {
M *MServer
Lock *sync.Mutex
Sender *packet.PacketSender
CK base.CommandKey
Readers map[int]*mpio.PacketReader
}
func (m *MServer) MakeServerFdContext(ck base.CommandKey) *serverFdContext {
rtn := &serverFdContext{
M: m,
Lock: &sync.Mutex{},
Sender: m.Sender,
CK: ck,
Readers: make(map[int]*mpio.PacketReader),
}
return rtn
}
func (c *serverFdContext) processDataPacket(pk *packet.DataPacketType) {
c.Lock.Lock()
reader := c.Readers[pk.FdNum]
c.Lock.Unlock()
if reader == nil {
ackPacket := packet.MakeDataAckPacket()
ackPacket.CK = c.CK
ackPacket.FdNum = pk.FdNum
ackPacket.Error = "write to closed file (no fd)"
c.M.Sender.SendPacket(ackPacket)
return
}
reader.AddData(pk)
return
}
func (c *serverFdContext) GetWriter(fdNum int) io.WriteCloser {
return mpio.MakePacketWriter(fdNum, c.Sender, c.CK)
}
func (c *serverFdContext) GetReader(fdNum int) io.ReadCloser {
c.Lock.Lock()
defer c.Lock.Unlock()
reader := mpio.MakePacketReader(fdNum)
c.Readers[fdNum] = reader
return reader
}
func (m *MServer) runCommand(runPacket *packet.RunPacketType) {
fdContext := m.MakeServerFdContext(runPacket.CK)
m.Lock.Lock()
m.FdContext = fdContext
m.Lock.Unlock()
go func() {
donePk, err := shexec.RunClientSSHCommandAndWait(runPacket, fdContext, shexec.SSHOpts{}, true)
fmt.Printf("done: err:%v, %v\n", err, donePk)
}()
}
func RunServer() (int, error) { func RunServer() (int, error) {
server := &MServer{ server := &MServer{
Lock: &sync.Mutex{}, Lock: &sync.Mutex{},
} }
packet.GlobalDebug = true
server.MainInput = packet.MakePacketParser(os.Stdin) server.MainInput = packet.MakePacketParser(os.Stdin)
server.Sender = packet.MakePacketSender(os.Stdout) server.Sender = packet.MakePacketSender(os.Stdout)
defer server.Close() defer server.Close()
defer fmt.Printf("runserver done\n")
initPacket := packet.MakeInitPacket() initPacket := packet.MakeInitPacket()
initPacket.Version = base.MShellVersion initPacket.Version = base.MShellVersion
server.Sender.SendPacket(initPacket) server.Sender.SendPacket(initPacket)
@ -43,7 +107,12 @@ func RunServer() (int, error) {
} }
if pk.GetType() == packet.RunPacketStr { if pk.GetType() == packet.RunPacketStr {
runPacket := pk.(*packet.RunPacketType) runPacket := pk.(*packet.RunPacketType)
fmt.Printf("RUN> %s\n", runPacket) server.runCommand(runPacket)
continue
}
if pk.GetType() == packet.DataPacketStr {
dataPacket := pk.(*packet.DataPacketType)
server.FdContext.processDataPacket(dataPacket)
continue continue
} }
server.Sender.SendErrorPacket(fmt.Sprintf("invalid packet '%s' sent to mshell", packet.AsExtType(pk))) server.Sender.SendErrorPacket(fmt.Sprintf("invalid packet '%s' sent to mshell", packet.AsExtType(pk)))

View File

@ -64,6 +64,41 @@ type ShExecType struct {
Multiplexer *mpio.Multiplexer Multiplexer *mpio.Multiplexer
} }
type StdContext struct{}
func (StdContext) GetWriter(fdNum int) io.WriteCloser {
if fdNum == 0 {
return os.Stdin
}
if fdNum == 1 {
return os.Stdout
}
if fdNum == 2 {
return os.Stderr
}
fd := os.NewFile(uintptr(fdNum), fmt.Sprintf("/dev/fd/%d", fdNum))
return fd
}
func (StdContext) GetReader(fdNum int) io.ReadCloser {
if fdNum == 0 {
return os.Stdin
}
if fdNum == 1 {
return os.Stdout
}
if fdNum == 2 {
return os.Stdout
}
fd := os.NewFile(uintptr(fdNum), fmt.Sprintf("/dev/fd/%d", fdNum))
return fd
}
type FdContext interface {
GetWriter(fdNum int) io.WriteCloser
GetReader(fdNum int) io.ReadCloser
}
func MakeShExec(ck base.CommandKey) *ShExecType { func MakeShExec(ck base.CommandKey) *ShExecType {
return &ShExecType{ return &ShExecType{
Lock: &sync.Mutex{}, Lock: &sync.Mutex{},
@ -241,7 +276,6 @@ type SSHOpts struct {
SSHOptsStr string SSHOptsStr string
SSHIdentity string SSHIdentity string
SSHUser string SSHUser string
SudoWithPass bool
} }
type InstallOpts struct { type InstallOpts struct {
@ -258,6 +292,7 @@ type ClientOpts struct {
Cwd string Cwd string
Debug bool Debug bool
Sudo bool Sudo bool
SudoWithPass bool
SudoPw string SudoPw string
CommandStdinFdNum int CommandStdinFdNum int
Detach bool Detach bool
@ -316,7 +351,7 @@ func (opts *ClientOpts) MakeRunPacket() (*packet.RunPacketType, error) {
runPacket.Command = fmt.Sprintf(RunCommandFmt, opts.Command) runPacket.Command = fmt.Sprintf(RunCommandFmt, opts.Command)
return runPacket, nil return runPacket, nil
} }
if opts.SSHOpts.SudoWithPass { if opts.SudoWithPass {
pwFdNum, err := opts.NextFreeFdNum() pwFdNum, err := opts.NextFreeFdNum()
if err != nil { if err != nil {
return nil, err return nil, err
@ -480,7 +515,16 @@ func RunInstallSSHCommand(opts *InstallOpts) error {
return fmt.Errorf("did not receive version string from client, install not successful") return fmt.Errorf("did not receive version string from client, install not successful")
} }
func RunClientSSHCommandAndWait(runPacket *packet.RunPacketType, sshOpts SSHOpts, debug bool) (*packet.CmdDonePacketType, error) { func HasDupStdin(fds []packet.RemoteFd) bool {
for _, rfd := range fds {
if rfd.Read && rfd.DupStdin {
return true
}
}
return false
}
func RunClientSSHCommandAndWait(runPacket *packet.RunPacketType, fdContext FdContext, sshOpts SSHOpts, debug bool) (*packet.CmdDonePacketType, error) {
cmd := MakeShExec("") cmd := MakeShExec("")
ecmd := sshOpts.MakeSSHExecCmd(ClientCommand) ecmd := sshOpts.MakeSSHExecCmd(ClientCommand)
cmd.Cmd = ecmd cmd.Cmd = ecmd
@ -496,11 +540,11 @@ func RunClientSSHCommandAndWait(runPacket *packet.RunPacketType, sshOpts SSHOpts
if err != nil { if err != nil {
return nil, fmt.Errorf("creating stderr pipe: %v", err) return nil, fmt.Errorf("creating stderr pipe: %v", err)
} }
if !sshOpts.SudoWithPass { if !HasDupStdin(runPacket.Fds) {
cmd.Multiplexer.MakeRawFdReader(0, os.Stdin, false) cmd.Multiplexer.MakeRawFdReader(0, fdContext.GetReader(0), false)
} }
cmd.Multiplexer.MakeRawFdWriter(1, os.Stdout, false) cmd.Multiplexer.MakeRawFdWriter(1, fdContext.GetWriter(1), false)
cmd.Multiplexer.MakeRawFdWriter(2, os.Stderr, false) cmd.Multiplexer.MakeRawFdWriter(2, fdContext.GetWriter(2), false)
for _, rfd := range runPacket.Fds { for _, rfd := range runPacket.Fds {
if rfd.Read && rfd.Content != "" { if rfd.Read && rfd.Content != "" {
err = cmd.Multiplexer.MakeStringFdReader(rfd.FdNum, rfd.Content) err = cmd.Multiplexer.MakeStringFdReader(rfd.FdNum, rfd.Content)
@ -510,16 +554,14 @@ func RunClientSSHCommandAndWait(runPacket *packet.RunPacketType, sshOpts SSHOpts
continue continue
} }
if rfd.Read && rfd.DupStdin { if rfd.Read && rfd.DupStdin {
cmd.Multiplexer.MakeRawFdReader(rfd.FdNum, os.Stdin, false) cmd.Multiplexer.MakeRawFdReader(rfd.FdNum, fdContext.GetReader(0), false)
continue continue
} }
fd := os.NewFile(uintptr(rfd.FdNum), fmt.Sprintf("/dev/fd/%d", rfd.FdNum))
if fd == nil {
return nil, fmt.Errorf("cannot open fd %d", rfd.FdNum)
}
if rfd.Read { if rfd.Read {
cmd.Multiplexer.MakeRawFdReader(rfd.FdNum, fd, true) fd := fdContext.GetReader(rfd.FdNum)
cmd.Multiplexer.MakeRawFdReader(rfd.FdNum, fd, false)
} else if rfd.Write { } else if rfd.Write {
fd := fdContext.GetWriter(rfd.FdNum)
cmd.Multiplexer.MakeRawFdWriter(rfd.FdNum, fd, true) cmd.Multiplexer.MakeRawFdWriter(rfd.FdNum, fd, true)
} }
} }