move static files from remotefd content to 'rundata'. send all rundata before command start. parse rundata before command start. compatible with detached commands

This commit is contained in:
sawka 2022-06-28 21:57:30 -07:00
parent 9054c3cdcc
commit c73691ac24
6 changed files with 214 additions and 80 deletions

View File

@ -42,9 +42,6 @@ func doSingle(ck base.CommandKey) {
sender := packet.MakePacketSender(os.Stdout)
var runPacket *packet.RunPacketType
for pk := range packetParser.MainCh {
if pk.GetType() == packet.PingPacketStr {
continue
}
if pk.GetType() == packet.RunPacketStr {
runPacket, _ = pk.(*packet.RunPacketType)
break
@ -173,9 +170,6 @@ func doMain() {
}
sender.SendPacket(initPacket)
for pk := range packetParser.MainCh {
if pk.GetType() == packet.PingPacketStr {
continue
}
if pk.GetType() == packet.RunPacketStr {
doMainRun(pk.(*packet.RunPacketType), sender)
continue
@ -211,6 +205,20 @@ func doMain() {
}
}
func readFullRunPacket(packetParser *packet.PacketParser) (*packet.RunPacketType, error) {
rpb := packet.MakeRunPacketBuilder()
for pk := range packetParser.MainCh {
ok, runPacket := rpb.ProcessPacket(pk)
if runPacket != nil {
return runPacket, nil
}
if !ok {
return nil, fmt.Errorf("invalid packet '%s' sent to mshell", pk.GetType())
}
}
return nil, fmt.Errorf("no run packet received")
}
func handleSingle() {
packetParser := packet.MakePacketParser(os.Stdin)
sender := packet.MakePacketSender(os.Stdout)
@ -228,20 +236,13 @@ func handleSingle() {
initPacket := packet.MakeInitPacket()
initPacket.Version = base.MShellVersion
sender.SendPacket(initPacket)
var runPacket *packet.RunPacketType
for pk := range packetParser.MainCh {
if pk.GetType() == packet.PingPacketStr {
continue
runPacket, err := readFullRunPacket(packetParser)
if err != nil {
ck := base.CommandKey("")
if runPacket != nil {
ck = runPacket.CK
}
if pk.GetType() == packet.RunPacketStr {
runPacket, _ = pk.(*packet.RunPacketType)
break
}
sender.SendErrorPacket(fmt.Sprintf("invalid packet '%s' sent to mshell", pk.GetType()))
return
}
if runPacket == nil {
sender.SendErrorPacket(fmt.Sprintf("no run packet received"))
sender.SendCKErrorPacket(ck, err.Error())
return
}
cmd, err := shexec.RunCommand(runPacket, sender)

View File

@ -20,12 +20,14 @@ import (
const ReadBufSize = 128 * 1024
const WriteBufSize = 128 * 1024
const MaxSingleWriteSize = 4 * 1024
const MaxTotalRunDataSize = 10 * ReadBufSize
type Multiplexer struct {
Lock *sync.Mutex
CK base.CommandKey
FdReaders map[int]*FdReader // synchronized
FdWriters map[int]*FdWriter // synchronized
RunData map[int]*FdReader // synchronized
CloseAfterStart []*os.File // synchronized
Sender *packet.PacketSender
@ -105,16 +107,22 @@ func (m *Multiplexer) MakeWriterPipe(fdNum int) (*os.File, error) {
return pr, nil
}
func (m *Multiplexer) MakeStringFdReader(fdNum int, contents string) error {
pw, err := m.MakeReaderPipe(fdNum)
// returns the *reader* to connect to process, writer is put in FdWriters
func (m *Multiplexer) MakeStaticWriterPipe(fdNum int, data []byte) (*os.File, error) {
pr, pw, err := os.Pipe()
if err != nil {
return err
return nil, err
}
go func() {
pw.Write([]byte(contents))
pw.Close()
}()
return nil
m.Lock.Lock()
defer m.Lock.Unlock()
fdWriter := MakeFdWriter(m, pw, fdNum, true)
err = fdWriter.AddData(data, true)
if err != nil {
return nil, err
}
m.FdWriters[fdNum] = fdWriter
m.CloseAfterStart = append(m.CloseAfterStart, pr)
return pr, nil
}
func (m *Multiplexer) MakeRawFdReader(fdNum int, fd io.ReadCloser, shouldClose bool) {
@ -212,6 +220,10 @@ func (m *Multiplexer) runPacketInputLoop() *packet.CmdDonePacketType {
donePacket := pk.(*packet.CmdDonePacketType)
return donePacket
}
if pk.GetType() == packet.CmdStartPacketStr {
// nothing
continue
}
m.UPR.UnknownPacket(pk)
}
return nil

View File

@ -8,6 +8,7 @@ package packet
import (
"bytes"
"encoding/base64"
"encoding/json"
"fmt"
"io"
@ -33,6 +34,7 @@ const (
DataAckPacketStr = "dataack"
CmdStartPacketStr = "cmdstart"
CmdDonePacketStr = "cmddone"
DataEndPacketStr = "dataend"
ResponsePacketStr = "resp"
DonePacketStr = "done"
ErrorPacketStr = "error"
@ -68,6 +70,7 @@ func init() {
TypeStrToFactory[InputPacketStr] = reflect.TypeOf(InputPacketType{})
TypeStrToFactory[DataPacketStr] = reflect.TypeOf(DataPacketType{})
TypeStrToFactory[DataAckPacketStr] = reflect.TypeOf(DataAckPacketType{})
TypeStrToFactory[DataEndPacketStr] = reflect.TypeOf(DataEndPacketType{})
}
func MakePacket(packetType string) (PacketType, error) {
@ -166,6 +169,23 @@ func MakeDataPacket() *DataPacketType {
return &DataPacketType{Type: DataPacketStr}
}
type DataEndPacketType struct {
Type string `json:"type"`
CK base.CommandKey `json:"ck"`
}
func MakeDataEndPacket(ck base.CommandKey) *DataEndPacketType {
return &DataEndPacketType{Type: DataEndPacketStr, CK: ck}
}
func (*DataEndPacketType) GetType() string {
return DataEndPacketStr
}
func (p *DataEndPacketType) GetCK() base.CommandKey {
return p.CK
}
type DataAckPacketType struct {
Type string `json:"type"`
CK base.CommandKey `json:"ck"`
@ -411,11 +431,16 @@ type TermSize struct {
}
type RemoteFd struct {
FdNum int `json:"fdnum"`
Read bool `json:"read"`
Write bool `json:"write"`
Content string `json:"-"`
DupStdin bool `json:"-"`
FdNum int `json:"fdnum"`
Read bool `json:"read"`
Write bool `json:"write"`
DupStdin bool `json:"-"`
}
type RunDataType struct {
FdNum int `json:"fdnum"`
DataLen int `json:"datalen"`
Data []byte `json:"-"`
}
type RunPacketType struct {
@ -426,6 +451,7 @@ type RunPacketType struct {
Env map[string]string `json:"env,omitempty"`
TermSize *TermSize `json:"termsize,omitempty"`
Fds []RemoteFd `json:"fds,omitempty"`
RunData []RunDataType `json:"rundata,omitempty"`
Detached bool `json:"detached,omitempty"`
}
@ -637,3 +663,47 @@ func (DefaultUPR) UnknownPacket(pk PacketType) {
}
}
// todo: clean hanging entries in RunMap when in server mode
type RunPacketBuilder struct {
RunMap map[base.CommandKey]*RunPacketType
}
func MakeRunPacketBuilder() *RunPacketBuilder {
return &RunPacketBuilder{
RunMap: make(map[base.CommandKey]*RunPacketType),
}
}
// returns (consumed, fullRunPacket)
func (b *RunPacketBuilder) ProcessPacket(pk PacketType) (bool, *RunPacketType) {
if pk.GetType() == RunPacketStr {
runPacket := pk.(*RunPacketType)
b.RunMap[runPacket.CK] = runPacket
return true, nil
}
if pk.GetType() == DataEndPacketStr {
endPacket := pk.(*DataEndPacketType)
runPacket := b.RunMap[endPacket.CK] // might be nil
delete(b.RunMap, endPacket.CK)
return true, runPacket
}
if pk.GetType() == DataPacketStr {
dataPacket := pk.(*DataPacketType)
runPacket := b.RunMap[dataPacket.CK]
if runPacket == nil {
return false, nil
}
for idx, runData := range runPacket.RunData {
if runData.FdNum == dataPacket.FdNum {
// can ignore error, will get caught later with RunData.DataLen check
realData, _ := base64.StdEncoding.DecodeString(dataPacket.Data64)
runData.Data = append(runData.Data, realData...)
runPacket.RunData[idx] = runData
break
}
}
return true, nil
}
return false, nil
}

View File

@ -93,6 +93,9 @@ func MakePacketParser(input io.Reader) *PacketParser {
if pk.GetType() == DonePacketStr {
return
}
if pk.GetType() == PingPacketStr {
continue
}
parser.MainCh <- pk
}
}()

View File

@ -151,9 +151,6 @@ func RunServer() (int, error) {
if server.Debug {
fmt.Printf("PK> %s\n", packet.AsString(pk))
}
if pk.GetType() == packet.PingPacketStr {
continue
}
if pk.GetType() == packet.RunPacketStr {
runPacket := pk.(*packet.RunPacketType)
server.runCommand(runPacket)

View File

@ -7,6 +7,7 @@
package shexec
import (
"encoding/base64"
"fmt"
"io"
"os"
@ -223,15 +224,20 @@ func ValidateRunPacket(pk *packet.RunPacketType) error {
if rfd.Write {
return fmt.Errorf("cannot detach command with writable remote files fd=%d", rfd.FdNum)
}
if rfd.Read {
if rfd.Content == "" {
return fmt.Errorf("cannot detach command with readable remote files fd=%d", rfd.FdNum)
}
if len(rfd.Content) > mpio.ReadBufSize {
return fmt.Errorf("cannot detach command, constant readable input too large fd=%d, len=%d, max=%d", rfd.FdNum, len(rfd.Content), mpio.ReadBufSize)
}
if rfd.Read && !rfd.DupStdin {
return fmt.Errorf("cannot detach command with readable remote files fd=%d", rfd.FdNum)
}
}
totalRunData := 0
for _, rd := range pk.RunData {
if rd.DataLen > mpio.ReadBufSize {
return fmt.Errorf("cannot detach command, constant rundata input too large fd=%d, len=%d, max=%d", rd.FdNum, rd.DataLen, mpio.ReadBufSize)
}
totalRunData += rd.DataLen
}
if totalRunData > mpio.MaxTotalRunDataSize {
return fmt.Errorf("cannot detach command, constant rundata input too large len=%d, max=%d", totalRunData, mpio.MaxTotalRunDataSize)
}
}
if pk.Cwd != "" {
realCwd := base.ExpandHomeDir(pk.Cwd)
@ -243,6 +249,11 @@ func ValidateRunPacket(pk *packet.RunPacketType) error {
return fmt.Errorf("invalid cwd '%s' for command, not a directory", realCwd)
}
}
for _, runData := range pk.RunData {
if runData.DataLen != len(runData.Data) {
return fmt.Errorf("rundata length mismatch, fd=%d, datalen=%d, expected=%d", runData.FdNum, len(runData.Data), runData.DataLen)
}
}
return nil
}
@ -286,16 +297,15 @@ type InstallOpts struct {
}
type ClientOpts struct {
SSHOpts SSHOpts
Command string
Fds []packet.RemoteFd
Cwd string
Debug bool
Sudo bool
SudoWithPass bool
SudoPw string
CommandStdinFdNum int
Detach bool
SSHOpts SSHOpts
Command string
Fds []packet.RemoteFd
Cwd string
Debug bool
Sudo bool
SudoWithPass bool
SudoPw string
Detach bool
}
func (opts SSHOpts) MakeSSHExecCmd(remoteCommand string) *exec.Cmd {
@ -352,48 +362,55 @@ func (opts *ClientOpts) MakeRunPacket() (*packet.RunPacketType, error) {
return runPacket, nil
}
if opts.SudoWithPass {
pwFdNum, err := opts.NextFreeFdNum()
pwFdNum, err := AddRunData(runPacket, opts.SudoPw, "sudo pw")
if err != nil {
return nil, err
}
pwRfd := packet.RemoteFd{FdNum: pwFdNum, Read: true, Content: opts.SudoPw}
opts.Fds = append(opts.Fds, pwRfd)
commandFdNum, err := opts.NextFreeFdNum()
commandFdNum, err := AddRunData(runPacket, opts.Command, "command")
if err != nil {
return nil, err
}
commandRfd := packet.RemoteFd{FdNum: commandFdNum, Read: true, Content: opts.Command}
opts.Fds = append(opts.Fds, commandRfd)
commandStdinFdNum, err := opts.NextFreeFdNum()
commandStdinFdNum, err := NextFreeFdNum(runPacket)
if err != nil {
return nil, err
}
commandStdinRfd := packet.RemoteFd{FdNum: commandStdinFdNum, Read: true, DupStdin: true}
opts.Fds = append(opts.Fds, commandStdinRfd)
opts.CommandStdinFdNum = commandStdinFdNum
maxFdNum := opts.MaxFdNum()
runPacket.Fds = append(runPacket.Fds, commandStdinRfd)
maxFdNum := MaxFdNumInPacket(runPacket)
runPacket.Command = fmt.Sprintf(RunSudoPasswordCommandFmt, pwFdNum, maxFdNum+1, pwFdNum, commandFdNum, commandStdinFdNum)
runPacket.Fds = opts.Fds
return runPacket, nil
} else {
commandFdNum, err := opts.NextFreeFdNum()
commandFdNum, err := AddRunData(runPacket, opts.Command, "command")
if err != nil {
return nil, err
}
rfd := packet.RemoteFd{FdNum: commandFdNum, Read: true, Content: opts.Command}
opts.Fds = append(opts.Fds, rfd)
maxFdNum := opts.MaxFdNum()
maxFdNum := MaxFdNumInPacket(runPacket)
runPacket.Command = fmt.Sprintf(RunSudoCommandFmt, maxFdNum+1, commandFdNum)
runPacket.Fds = opts.Fds
return runPacket, nil
}
}
func (opts *ClientOpts) NextFreeFdNum() (int, error) {
func AddRunData(pk *packet.RunPacketType, data string, dataType string) (int, error) {
if len(data) > mpio.ReadBufSize {
return 0, fmt.Errorf("%s too large, exceeds read buffer size", dataType)
}
fdNum, err := NextFreeFdNum(pk)
if err != nil {
return 0, err
}
runData := packet.RunDataType{FdNum: fdNum, DataLen: len(data), Data: []byte(data)}
pk.RunData = append(pk.RunData, runData)
return fdNum, nil
}
func NextFreeFdNum(pk *packet.RunPacketType) (int, error) {
fdMap := make(map[int]bool)
for _, fd := range opts.Fds {
for _, fd := range pk.Fds {
fdMap[fd.FdNum] = true
}
for _, rd := range pk.RunData {
fdMap[rd.FdNum] = true
}
for i := 3; i <= MaxFdNum; i++ {
if !fdMap[i] {
return i, nil
@ -402,13 +419,18 @@ func (opts *ClientOpts) NextFreeFdNum() (int, error) {
return 0, fmt.Errorf("reached maximum number of fds, all fds between 3-%d are in use", MaxFdNum)
}
func (opts *ClientOpts) MaxFdNum() int {
func MaxFdNumInPacket(pk *packet.RunPacketType) int {
maxFdNum := 3
for _, fd := range opts.Fds {
for _, fd := range pk.Fds {
if fd.FdNum > maxFdNum {
maxFdNum = fd.FdNum
}
}
for _, rd := range pk.RunData {
if rd.FdNum > maxFdNum {
maxFdNum = rd.FdNum
}
}
return maxFdNum
}
@ -546,13 +568,6 @@ func RunClientSSHCommandAndWait(runPacket *packet.RunPacketType, fdContext FdCon
cmd.Multiplexer.MakeRawFdWriter(1, fdContext.GetWriter(1), false)
cmd.Multiplexer.MakeRawFdWriter(2, fdContext.GetWriter(2), false)
for _, rfd := range runPacket.Fds {
if rfd.Read && rfd.Content != "" {
err = cmd.Multiplexer.MakeStringFdReader(rfd.FdNum, rfd.Content)
if err != nil {
return nil, fmt.Errorf("creating content fd %d", rfd.FdNum)
}
continue
}
if rfd.Read && rfd.DupStdin {
cmd.Multiplexer.MakeRawFdReader(rfd.FdNum, fdContext.GetReader(0), false)
continue
@ -610,7 +625,7 @@ func RunClientSSHCommandAndWait(runPacket *packet.RunPacketType, fdContext FdCon
if !versionOk {
return nil, fmt.Errorf("did not receive version from remote mshell")
}
sender.SendPacket(runPacket)
SendRunPacketAndRunData(sender, runPacket)
if debug {
cmd.Multiplexer.Debug = true
}
@ -622,6 +637,32 @@ func RunClientSSHCommandAndWait(runPacket *packet.RunPacketType, fdContext FdCon
return donePacket, nil
}
func min(v1 int, v2 int) int {
if v1 <= v2 {
return v1
}
return v2
}
func SendRunPacketAndRunData(sender *packet.PacketSender, runPacket *packet.RunPacketType) {
sender.SendPacket(runPacket)
for _, runData := range runPacket.RunData {
sendBuf := runData.Data
for len(sendBuf) > 0 {
chunkSize := min(len(sendBuf), mpio.MaxSingleWriteSize)
chunk := sendBuf[0:chunkSize]
dataPk := packet.MakeDataPacket()
dataPk.CK = runPacket.CK
dataPk.FdNum = runData.FdNum
dataPk.Data64 = base64.StdEncoding.EncodeToString(chunk)
dataPk.Eof = (len(chunk) == len(sendBuf))
sendBuf = sendBuf[chunkSize:]
sender.SendPacket(dataPk)
}
}
sender.SendPacket(packet.MakeDataEndPacket(runPacket.CK))
}
func DetectGoArch(uname string) (string, string, error) {
fields := strings.SplitN(uname, "|", 2)
if len(fields) != 2 {
@ -683,6 +724,16 @@ func runCommandSimple(pk *packet.RunPacketType, sender *packet.PacketSender) (*S
return nil, err
}
extraFiles := make([]*os.File, 0, MaxFdNum+1)
for _, runData := range pk.RunData {
if runData.FdNum >= len(extraFiles) {
extraFiles = extraFiles[:runData.FdNum+1]
}
extraFiles[runData.FdNum], err = cmd.Multiplexer.MakeStaticWriterPipe(runData.FdNum, runData.Data)
if err != nil {
cmd.Close()
return nil, err
}
}
for _, rfd := range pk.Fds {
if rfd.FdNum >= len(extraFiles) {
extraFiles = extraFiles[:rfd.FdNum+1]