checkpoint. transfer binary data as base64. handle cwd. detect open fds. working to transfer data in non-error cases.

This commit is contained in:
sawka 2022-06-24 23:42:00 -07:00
parent 5223760a76
commit e6776bd974
7 changed files with 202 additions and 69 deletions

View File

@ -20,6 +20,7 @@ import (
"github.com/scripthaus-dev/mshell/pkg/cmdtail" "github.com/scripthaus-dev/mshell/pkg/cmdtail"
"github.com/scripthaus-dev/mshell/pkg/packet" "github.com/scripthaus-dev/mshell/pkg/packet"
"github.com/scripthaus-dev/mshell/pkg/shexec" "github.com/scripthaus-dev/mshell/pkg/shexec"
"golang.org/x/sys/unix"
) )
const MShellVersion = "0.1.0" const MShellVersion = "0.1.0"
@ -256,8 +257,26 @@ func handleRemote() {
func handleServer() { func handleServer() {
} }
func detectOpenFds() { func detectOpenFds() ([]packet.RemoteFd, error) {
var fds []packet.RemoteFd
for fdNum := 3; fdNum <= 64; fdNum++ {
flags, err := unix.FcntlInt(uintptr(fdNum), unix.F_GETFL, 0)
if err != nil {
continue
}
flags = flags & 3
rfd := packet.RemoteFd{FdNum: fdNum}
if flags&2 == 2 {
return nil, fmt.Errorf("invalid fd=%d, mshell does not support fds open for reading and writing", fdNum)
}
if flags&1 == 1 {
rfd.Write = true
} else {
rfd.Read = true
}
fds = append(fds, rfd)
}
return fds, nil
} }
func parseClientOpts() (*shexec.ClientOpts, error) { func parseClientOpts() (*shexec.ClientOpts, error) {
@ -272,6 +291,13 @@ func parseClientOpts() (*shexec.ClientOpts, error) {
opts.IsSSH = true opts.IsSSH = true
break break
} }
if argStr == "--cwd" {
if !iter.HasNext() {
return nil, fmt.Errorf("'--cwd [dir]' missing directory")
}
opts.Cwd = iter.Next()
continue
}
} }
if opts.IsSSH { if opts.IsSSH {
// parse SSH opts // parse SSH opts
@ -281,11 +307,6 @@ func parseClientOpts() (*shexec.ClientOpts, error) {
opts.SSHOptsTerm = true opts.SSHOptsTerm = true
break break
} }
if argStr == "--cwd" {
if !iter.HasNext() {
return nil, fmt.Errorf("'--cwd [dir]' missing directory")
}
}
opts.SSHOpts = append(opts.SSHOpts, argStr) opts.SSHOpts = append(opts.SSHOpts, argStr)
} }
if !opts.SSHOptsTerm { if !opts.SSHOptsTerm {
@ -310,6 +331,11 @@ func handleClient() (int, error) {
if !opts.IsSSH { if !opts.IsSSH {
return 1, fmt.Errorf("when running in client mode '--ssh' option must be present") return 1, fmt.Errorf("when running in client mode '--ssh' option must be present")
} }
fds, err := detectOpenFds()
if err != nil {
return 1, err
}
opts.Fds = fds
donePacket, err := shexec.RunClientSSHCommandAndWait(opts) donePacket, err := shexec.RunClientSSHCommandAndWait(opts)
if err != nil { if err != nil {
return 1, err return 1, err

View File

@ -14,6 +14,7 @@ import (
"os/exec" "os/exec"
"path" "path"
"path/filepath" "path/filepath"
"strings"
) )
const DefaultMShellPath = "mshell" const DefaultMShellPath = "mshell"
@ -176,3 +177,14 @@ func WriteErrorMsg(fileName string, errVal string) error {
_, writeErr := fd.Write([]byte(oscEsc)) _, writeErr := fd.Write([]byte(oscEsc))
return writeErr return writeErr
} }
func ExpandHomeDir(pathStr string) string {
if pathStr != "~" && !strings.HasPrefix(pathStr, "~/") {
return pathStr
}
homeDir := GetHomeDir()
if pathStr == "~" {
return homeDir
}
return path.Join(homeDir, pathStr[2:])
}

View File

@ -48,6 +48,12 @@ func (r *FdReader) Close() {
r.CVar.Broadcast() r.CVar.Broadcast()
} }
func (r *FdReader) GetBufSize() int {
r.CVar.L.Lock()
defer r.CVar.L.Unlock()
return r.BufSize
}
func (r *FdReader) NotifyAck(ackLen int) { func (r *FdReader) NotifyAck(ackLen int) {
r.CVar.L.Lock() r.CVar.L.Lock()
defer r.CVar.L.Unlock() defer r.CVar.L.Unlock()

View File

@ -64,15 +64,15 @@ func (w *FdWriter) WaitForData() ([]byte, bool) {
func (w *FdWriter) AddData(data []byte, eof bool) error { func (w *FdWriter) AddData(data []byte, eof bool) error {
w.CVar.L.Lock() w.CVar.L.Lock()
defer w.CVar.L.Unlock() defer w.CVar.L.Unlock()
if w.Closed { if w.Closed || w.Eof {
return fmt.Errorf("write to closed file") if len(data) == 0 {
} return nil
if w.Eof { }
return fmt.Errorf("write to closed file (eof)") return fmt.Errorf("write to closed file eof[%v]", w.Eof)
} }
if len(data) > 0 { if len(data) > 0 {
if len(data)+len(w.Buffer) > WriteBufSize { if len(data)+len(w.Buffer) > WriteBufSize {
return fmt.Errorf("write exceeds buffer size") return fmt.Errorf("write exceeds buffer size bufsize=%d (max=%d)", len(data)+len(w.Buffer), WriteBufSize)
} }
w.Buffer = append(w.Buffer, data...) w.Buffer = append(w.Buffer, data...)
} }

View File

@ -7,6 +7,7 @@
package mpio package mpio
import ( import (
"encoding/base64"
"fmt" "fmt"
"os" "os"
"sync" "sync"
@ -14,8 +15,8 @@ import (
"github.com/scripthaus-dev/mshell/pkg/packet" "github.com/scripthaus-dev/mshell/pkg/packet"
) )
const ReadBufSize = 128 * 1024 const ReadBufSize = 32 * 1024
const WriteBufSize = 128 * 1024 const WriteBufSize = 32 * 1024
const MaxSingleWriteSize = 4 * 1024 const MaxSingleWriteSize = 4 * 1024
type Multiplexer struct { type Multiplexer struct {
@ -29,6 +30,8 @@ type Multiplexer struct {
Sender *packet.PacketSender Sender *packet.PacketSender
Input chan packet.PacketType Input chan packet.PacketType
Started bool Started bool
Debug bool
} }
func MakeMultiplexer(sessionId string, cmdId string) *Multiplexer { func MakeMultiplexer(sessionId string, cmdId string) *Multiplexer {
@ -97,16 +100,16 @@ func (m *Multiplexer) MakeWriterPipe(fdNum int) (*os.File, error) {
return pr, nil return pr, nil
} }
func (m *Multiplexer) MakeRawFdReader(fdNum int, fd *os.File) { func (m *Multiplexer) MakeRawFdReader(fdNum int, fd *os.File, shouldClose bool) {
m.Lock.Lock() m.Lock.Lock()
defer m.Lock.Unlock() defer m.Lock.Unlock()
m.FdReaders[fdNum] = MakeFdReader(m, fd, fdNum, false) m.FdReaders[fdNum] = MakeFdReader(m, fd, fdNum, shouldClose)
} }
func (m *Multiplexer) MakeRawFdWriter(fdNum int, fd *os.File) { func (m *Multiplexer) MakeRawFdWriter(fdNum int, fd *os.File, shouldClose bool) {
m.Lock.Lock() m.Lock.Lock()
defer m.Lock.Unlock() defer m.Lock.Unlock()
m.FdWriters[fdNum] = MakeFdWriter(m, fd, fdNum, false) m.FdWriters[fdNum] = MakeFdWriter(m, fd, fdNum, shouldClose)
} }
func (m *Multiplexer) makeDataAckPacket(fdNum int, ackLen int, err error) *packet.DataAckPacketType { func (m *Multiplexer) makeDataAckPacket(fdNum int, ackLen int, err error) *packet.DataAckPacketType {
@ -126,7 +129,7 @@ func (m *Multiplexer) makeDataPacket(fdNum int, data []byte, err error) *packet.
pk.SessionId = m.SessionId pk.SessionId = m.SessionId
pk.CmdId = m.CmdId pk.CmdId = m.CmdId
pk.FdNum = fdNum pk.FdNum = fdNum
pk.Data = string(data) pk.Data64 = base64.StdEncoding.EncodeToString(data)
if err != nil { if err != nil {
pk.Error = err.Error() pk.Error = err.Error()
} }
@ -173,6 +176,9 @@ func (m *Multiplexer) startIO(packetCh chan packet.PacketType, sender *packet.Pa
func (m *Multiplexer) runPacketInputLoop() *packet.CmdDonePacketType { func (m *Multiplexer) runPacketInputLoop() *packet.CmdDonePacketType {
defer m.HandleInputDone() defer m.HandleInputDone()
for pk := range m.Input { for pk := range m.Input {
if m.Debug {
fmt.Printf("PK> %s\n", packet.AsString(pk))
}
if pk.GetType() == packet.DataPacketStr { if pk.GetType() == packet.DataPacketStr {
dataPacket := pk.(*packet.DataPacketType) dataPacket := pk.(*packet.DataPacketType)
err := m.processDataPacket(dataPacket) err := m.processDataPacket(dataPacket)
@ -191,12 +197,26 @@ func (m *Multiplexer) runPacketInputLoop() *packet.CmdDonePacketType {
donePacket := pk.(*packet.CmdDonePacketType) donePacket := pk.(*packet.CmdDonePacketType)
return donePacket return donePacket
} }
// other packet types are ignored if pk.GetType() == packet.ErrorPacketStr {
errPacket := pk.(*packet.ErrorPacketType)
// at this point, just send the error packet to stderr rather than try to do something special
fmt.Fprintf(os.Stderr, "%s\n", errPacket.Error)
return nil
}
if pk.GetType() == packet.RawPacketStr {
rawPacket := pk.(*packet.RawPacketType)
fmt.Fprintf(os.Stderr, "%s\n", rawPacket.Data)
continue
}
} }
return nil return nil
} }
func (m *Multiplexer) processDataPacket(dataPacket *packet.DataPacketType) error { func (m *Multiplexer) processDataPacket(dataPacket *packet.DataPacketType) error {
realData, err := base64.StdEncoding.DecodeString(dataPacket.Data64)
if err != nil {
return fmt.Errorf("decoding base64 data: %w", err)
}
m.Lock.Lock() m.Lock.Lock()
defer m.Lock.Unlock() defer m.Lock.Unlock()
fw := m.FdWriters[dataPacket.FdNum] fw := m.FdWriters[dataPacket.FdNum]
@ -205,9 +225,9 @@ func (m *Multiplexer) processDataPacket(dataPacket *packet.DataPacketType) error
fw := MakeFdWriter(m, nil, dataPacket.FdNum, false) fw := MakeFdWriter(m, nil, dataPacket.FdNum, false)
fw.Close() fw.Close()
m.FdWriters[dataPacket.FdNum] = fw m.FdWriters[dataPacket.FdNum] = fw
return fmt.Errorf("write to closed file") return fmt.Errorf("write to closed file (no fd)")
} }
err := fw.AddData([]byte(dataPacket.Data), dataPacket.Eof) err = fw.AddData(realData, dataPacket.Eof)
if err != nil { if err != nil {
fw.Close() fw.Close()
return err return err

View File

@ -121,7 +121,7 @@ type DataPacketType struct {
SessionId string `json:"sessionid,omitempty"` SessionId string `json:"sessionid,omitempty"`
CmdId string `json:"cmdid,omitempty"` CmdId string `json:"cmdid,omitempty"`
FdNum int `json:"fdnum"` FdNum int `json:"fdnum"`
Data string `json:"data"` Data64 string `json:"data64"` // base64 encoded
Eof bool `json:"eof,omitempty"` Eof bool `json:"eof,omitempty"`
Error string `json:"error,omitempty"` Error string `json:"error,omitempty"`
} }
@ -130,6 +130,32 @@ func (*DataPacketType) GetType() string {
return DataPacketStr return DataPacketStr
} }
func B64DecodedLen(b64 string) int {
if len(b64) < 4 {
return 0 // we use padded strings, so < 4 is always 0
}
realLen := 3 * (len(b64) / 4)
if b64[len(b64)-1] == '=' {
realLen--
}
if b64[len(b64)-2] == '=' {
realLen--
}
return realLen
}
func (p *DataPacketType) String() string {
eofStr := ""
if p.Eof {
eofStr = ", eof"
}
errStr := ""
if p.Error != "" {
errStr = fmt.Sprintf(", err=%s", p.Error)
}
return fmt.Sprintf("data[fd=%d, len=%d%s%s]", p.FdNum, B64DecodedLen(p.Data64), eofStr, errStr)
}
func MakeDataPacket() *DataPacketType { func MakeDataPacket() *DataPacketType {
return &DataPacketType{Type: DataPacketStr} return &DataPacketType{Type: DataPacketStr}
} }
@ -140,13 +166,21 @@ type DataAckPacketType struct {
CmdId string `json:"cmdid,omitempty"` CmdId string `json:"cmdid,omitempty"`
FdNum int `json:"fdnum"` FdNum int `json:"fdnum"`
AckLen int `json:"acklen"` AckLen int `json:"acklen"`
Error string `json:"error"` Error string `json:"error,omitempty"`
} }
func (*DataAckPacketType) GetType() string { func (*DataAckPacketType) GetType() string {
return DataAckPacketStr return DataAckPacketStr
} }
func (p *DataAckPacketType) String() string {
errStr := ""
if p.Error != "" {
errStr = fmt.Sprintf(" err=%s", p.Error)
}
return fmt.Sprintf("ack[fd=%d, acklen=%d%s]", p.FdNum, p.AckLen, errStr)
}
func MakeDataAckPacket() *DataAckPacketType { func MakeDataAckPacket() *DataAckPacketType {
return &DataAckPacketType{Type: DataAckPacketStr} return &DataAckPacketType{Type: DataAckPacketStr}
} }
@ -252,6 +286,10 @@ func (*RawPacketType) GetType() string {
return RawPacketStr return RawPacketStr
} }
func (p *RawPacketType) String() string {
return fmt.Sprintf("raw[%s]", p.Data)
}
func MakeRawPacket(val string) *RawPacketType { func MakeRawPacket(val string) *RawPacketType {
return &RawPacketType{Type: RawPacketStr, Data: val} return &RawPacketType{Type: RawPacketStr, Data: val}
} }
@ -265,6 +303,10 @@ func (*MessagePacketType) GetType() string {
return MessagePacketStr return MessagePacketStr
} }
func (p *MessagePacketType) String() string {
return fmt.Sprintf("messsage[%s]", p.Message)
}
func MakeMessagePacket(message string) *MessagePacketType { func MakeMessagePacket(message string) *MessagePacketType {
return &MessagePacketType{Type: MessagePacketStr, Message: message} return &MessagePacketType{Type: MessagePacketStr, Message: message}
} }
@ -394,6 +436,13 @@ type PacketType interface {
GetType() string GetType() string
} }
func AsString(pk PacketType) string {
if s, ok := pk.(fmt.Stringer); ok {
return s.String()
}
return fmt.Sprintf("%s[]", pk.GetType())
}
type RpcPacketType interface { type RpcPacketType interface {
GetType() string GetType() string
GetPacketId() string GetPacketId() string
@ -433,8 +482,7 @@ func SendPacket(w io.Writer, packet PacketType) error {
outBuf.Write(jsonBytes) outBuf.Write(jsonBytes)
outBuf.WriteByte('\n') outBuf.WriteByte('\n')
if GlobalDebug { if GlobalDebug {
outBytes := outBuf.Bytes() fmt.Printf("SEND> %s\n", AsString(packet))
fmt.Printf("SEND>%s", string(outBytes[1:]))
} }
_, err = w.Write(outBuf.Bytes()) _, err = w.Write(outBuf.Bytes())
if err != nil { if err != nil {
@ -519,12 +567,14 @@ func (sender *PacketSender) SendMessage(fmtStr string, args ...interface{}) erro
} }
func PacketParser(input io.Reader) chan PacketType { func PacketParser(input io.Reader) chan PacketType {
bufReader := bufio.NewReader(input)
rtnCh := make(chan PacketType) rtnCh := make(chan PacketType)
PacketParserAttach(input, rtnCh)
return rtnCh
}
func PacketParserAttach(input io.Reader, rtnCh chan PacketType) {
bufReader := bufio.NewReader(input)
go func() { go func() {
defer func() {
close(rtnCh)
}()
for { for {
line, err := bufReader.ReadString('\n') line, err := bufReader.ReadString('\n')
if err == io.EOF { if err == io.EOF {
@ -562,7 +612,6 @@ func PacketParser(input io.Reader) chan PacketType {
rtnCh <- pk rtnCh <- pk
} }
}() }()
return rtnCh
} }
type ErrorReporter interface { type ErrorReporter interface {

View File

@ -111,7 +111,7 @@ func MakeExecCmd(pk *packet.RunPacketType, cmdTty *os.File) *exec.Cmd {
ecmd := exec.Command("bash", "-c", pk.Command) ecmd := exec.Command("bash", "-c", pk.Command)
UpdateCmdEnv(ecmd, pk.Env) UpdateCmdEnv(ecmd, pk.Env)
if pk.Cwd != "" { if pk.Cwd != "" {
ecmd.Dir = pk.Cwd ecmd.Dir = base.ExpandHomeDir(pk.Cwd)
} }
ecmd.Stdin = cmdTty ecmd.Stdin = cmdTty
ecmd.Stdout = cmdTty ecmd.Stdout = cmdTty
@ -175,12 +175,13 @@ func ValidateRunPacket(pk *packet.RunPacketType) error {
} }
} }
if pk.Cwd != "" { if pk.Cwd != "" {
dirInfo, err := os.Stat(pk.Cwd) realCwd := base.ExpandHomeDir(pk.Cwd)
dirInfo, err := os.Stat(realCwd)
if err != nil { if err != nil {
return fmt.Errorf("invalid cwd '%s' for command: %v", pk.Cwd, err) return fmt.Errorf("invalid cwd '%s' for command: %v", realCwd, err)
} }
if !dirInfo.IsDir() { if !dirInfo.IsDir() {
return fmt.Errorf("invalid cwd '%s' for command, not a directory", pk.Cwd) return fmt.Errorf("invalid cwd '%s' for command, not a directory", realCwd)
} }
} }
return nil return nil
@ -228,8 +229,37 @@ func (opts *ClientOpts) MakeRunPacket() *packet.RunPacketType {
return runPacket return runPacket
} }
func ValidateRemoteFds(rfds []packet.RemoteFd) error {
dupMap := make(map[int]bool)
for _, rfd := range rfds {
if rfd.FdNum < 0 {
return fmt.Errorf("mshell negative fd numbers fd=%d", rfd.FdNum)
}
if rfd.FdNum < FirstExtraFilesFdNum {
return fmt.Errorf("mshell does not support re-opening fd=%d (0, 1, and 2, are always open)", rfd.FdNum)
}
if rfd.FdNum > MaxFdNum {
return fmt.Errorf("mshell does not support opening fd numbers above %d", MaxFdNum)
}
if dupMap[rfd.FdNum] {
return fmt.Errorf("mshell got duplicate entries for fd=%d", rfd.FdNum)
}
if rfd.Read && rfd.Write {
return fmt.Errorf("mshell does not support opening fd numbers for reading and writing, fd=%d", rfd.FdNum)
}
if !rfd.Read && !rfd.Write {
return fmt.Errorf("invalid fd=%d, neither reading or writing mode specified", rfd.FdNum)
}
dupMap[rfd.FdNum] = true
}
return nil
}
func RunClientSSHCommandAndWait(opts *ClientOpts) (*packet.CmdDonePacketType, error) { func RunClientSSHCommandAndWait(opts *ClientOpts) (*packet.CmdDonePacketType, error) {
// packet.GlobalDebug = true err := ValidateRemoteFds(opts.Fds)
if err != nil {
return nil, err
}
cmd := MakeShExec("", "") cmd := MakeShExec("", "")
sshRemoteCommand := `PATH=$PATH:~/.mshell; mshell --remote` sshRemoteCommand := `PATH=$PATH:~/.mshell; mshell --remote`
var fullSshOpts []string var fullSshOpts []string
@ -249,15 +279,27 @@ func RunClientSSHCommandAndWait(opts *ClientOpts) (*packet.CmdDonePacketType, er
if err != nil { if err != nil {
return nil, fmt.Errorf("creating stderr pipe: %v", err) return nil, fmt.Errorf("creating stderr pipe: %v", err)
} }
cmd.Multiplexer.MakeRawFdReader(0, os.Stdin, false)
cmd.Multiplexer.MakeRawFdWriter(1, os.Stdout, false)
cmd.Multiplexer.MakeRawFdWriter(2, os.Stderr, false)
for _, rfd := range opts.Fds {
fd := os.NewFile(uintptr(rfd.FdNum), fmt.Sprintf("/dev/fd/%d", rfd.FdNum))
if fd == nil {
return nil, fmt.Errorf("cannot open fd %d", rfd.FdNum)
}
if rfd.Read {
cmd.Multiplexer.MakeRawFdReader(rfd.FdNum, fd, true)
} else if rfd.Write {
cmd.Multiplexer.MakeRawFdWriter(rfd.FdNum, fd, true)
}
}
err = ecmd.Start() err = ecmd.Start()
if err != nil { if err != nil {
return nil, fmt.Errorf("running ssh command: %w", err) return nil, fmt.Errorf("running ssh command: %w", err)
} }
defer cmd.Close() defer cmd.Close()
packetCh := packet.PacketParser(stdoutReader) packetCh := packet.PacketParser(stdoutReader)
go func() { packet.PacketParserAttach(stderrReader, packetCh)
io.Copy(os.Stderr, stderrReader)
}()
sender := packet.MakePacketSender(inputWriter) sender := packet.MakePacketSender(inputWriter)
for pk := range packetCh { for pk := range packetCh {
if pk.GetType() == packet.RawPacketStr { if pk.GetType() == packet.RawPacketStr {
@ -275,9 +317,6 @@ func RunClientSSHCommandAndWait(opts *ClientOpts) (*packet.CmdDonePacketType, er
} }
runPacket := opts.MakeRunPacket() runPacket := opts.MakeRunPacket()
sender.SendPacket(runPacket) sender.SendPacket(runPacket)
cmd.Multiplexer.MakeRawFdReader(0, os.Stdin)
cmd.Multiplexer.MakeRawFdWriter(1, os.Stdout)
cmd.Multiplexer.MakeRawFdWriter(2, os.Stderr)
remoteDonePacket := cmd.Multiplexer.RunIOAndWait(packetCh, sender, false, true, true) remoteDonePacket := cmd.Multiplexer.RunIOAndWait(packetCh, sender, false, true, true)
donePacket := cmd.WaitForCommand() donePacket := cmd.WaitForCommand()
if remoteDonePacket != nil { if remoteDonePacket != nil {
@ -287,6 +326,7 @@ func RunClientSSHCommandAndWait(opts *ClientOpts) (*packet.CmdDonePacketType, er
} }
func (cmd *ShExecType) RunRemoteIOAndWait(packetCh chan packet.PacketType, sender *packet.PacketSender) { func (cmd *ShExecType) RunRemoteIOAndWait(packetCh chan packet.PacketType, sender *packet.PacketSender) {
defer cmd.Close()
cmd.Multiplexer.RunIOAndWait(packetCh, sender, true, false, false) cmd.Multiplexer.RunIOAndWait(packetCh, sender, true, false, false)
donePacket := cmd.WaitForCommand() donePacket := cmd.WaitForCommand()
sender.SendPacket(donePacket) sender.SendPacket(donePacket)
@ -297,9 +337,13 @@ func runCommandSimple(pk *packet.RunPacketType, sender *packet.PacketSender) (*S
cmd.Cmd = exec.Command("bash", "-c", pk.Command) cmd.Cmd = exec.Command("bash", "-c", pk.Command)
UpdateCmdEnv(cmd.Cmd, pk.Env) UpdateCmdEnv(cmd.Cmd, pk.Env)
if pk.Cwd != "" { if pk.Cwd != "" {
cmd.Cmd.Dir = pk.Cwd cmd.Cmd.Dir = base.ExpandHomeDir(pk.Cwd)
}
err := ValidateRemoteFds(pk.Fds)
if err != nil {
cmd.Close()
return nil, err
} }
var err error
cmd.Cmd.Stdin, err = cmd.Multiplexer.MakeWriterPipe(0) cmd.Cmd.Stdin, err = cmd.Multiplexer.MakeWriterPipe(0)
if err != nil { if err != nil {
cmd.Close() cmd.Close()
@ -317,33 +361,9 @@ func runCommandSimple(pk *packet.RunPacketType, sender *packet.PacketSender) (*S
} }
extraFiles := make([]*os.File, 0, MaxFdNum+1) extraFiles := make([]*os.File, 0, MaxFdNum+1)
for _, rfd := range pk.Fds { for _, rfd := range pk.Fds {
if rfd.FdNum < 0 {
cmd.Close()
return nil, fmt.Errorf("mshell negative fd numbers fd=%d", rfd.FdNum)
}
if rfd.FdNum < FirstExtraFilesFdNum {
cmd.Close()
return nil, fmt.Errorf("mshell does not support re-opening fd=%d (0, 1, and 2, are always open)", rfd.FdNum)
}
if rfd.FdNum > MaxFdNum {
cmd.Close()
return nil, fmt.Errorf("mshell does not support opening fd numbers above %d", MaxFdNum)
}
if rfd.FdNum >= len(extraFiles) { if rfd.FdNum >= len(extraFiles) {
extraFiles = extraFiles[:rfd.FdNum+1] extraFiles = extraFiles[:rfd.FdNum+1]
} }
if extraFiles[rfd.FdNum] != nil {
cmd.Close()
return nil, fmt.Errorf("mshell got duplicate entries for fd=%d", rfd.FdNum)
}
if rfd.Read && rfd.Write {
cmd.Close()
return nil, fmt.Errorf("mshell does not support opening fd numbers for reading and writing, fd=%d", rfd.FdNum)
}
if !rfd.Read && !rfd.Write {
cmd.Close()
return nil, fmt.Errorf("invalid fd=%d, neither reading or writing mode specified", rfd.FdNum)
}
if rfd.Read { if rfd.Read {
// client file is open for reading, so we make a writer pipe // client file is open for reading, so we make a writer pipe
extraFiles[rfd.FdNum], err = cmd.Multiplexer.MakeWriterPipe(rfd.FdNum) extraFiles[rfd.FdNum], err = cmd.Multiplexer.MakeWriterPipe(rfd.FdNum)