mirror of
https://github.com/wavetermdev/waveterm.git
synced 2025-01-21 21:32:13 +01:00
checkpoint. transfer binary data as base64. handle cwd. detect open fds. working to transfer data in non-error cases.
This commit is contained in:
parent
5223760a76
commit
e6776bd974
@ -20,6 +20,7 @@ import (
|
||||
"github.com/scripthaus-dev/mshell/pkg/cmdtail"
|
||||
"github.com/scripthaus-dev/mshell/pkg/packet"
|
||||
"github.com/scripthaus-dev/mshell/pkg/shexec"
|
||||
"golang.org/x/sys/unix"
|
||||
)
|
||||
|
||||
const MShellVersion = "0.1.0"
|
||||
@ -256,8 +257,26 @@ func handleRemote() {
|
||||
func handleServer() {
|
||||
}
|
||||
|
||||
func detectOpenFds() {
|
||||
|
||||
func detectOpenFds() ([]packet.RemoteFd, error) {
|
||||
var fds []packet.RemoteFd
|
||||
for fdNum := 3; fdNum <= 64; fdNum++ {
|
||||
flags, err := unix.FcntlInt(uintptr(fdNum), unix.F_GETFL, 0)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
flags = flags & 3
|
||||
rfd := packet.RemoteFd{FdNum: fdNum}
|
||||
if flags&2 == 2 {
|
||||
return nil, fmt.Errorf("invalid fd=%d, mshell does not support fds open for reading and writing", fdNum)
|
||||
}
|
||||
if flags&1 == 1 {
|
||||
rfd.Write = true
|
||||
} else {
|
||||
rfd.Read = true
|
||||
}
|
||||
fds = append(fds, rfd)
|
||||
}
|
||||
return fds, nil
|
||||
}
|
||||
|
||||
func parseClientOpts() (*shexec.ClientOpts, error) {
|
||||
@ -272,6 +291,13 @@ func parseClientOpts() (*shexec.ClientOpts, error) {
|
||||
opts.IsSSH = true
|
||||
break
|
||||
}
|
||||
if argStr == "--cwd" {
|
||||
if !iter.HasNext() {
|
||||
return nil, fmt.Errorf("'--cwd [dir]' missing directory")
|
||||
}
|
||||
opts.Cwd = iter.Next()
|
||||
continue
|
||||
}
|
||||
}
|
||||
if opts.IsSSH {
|
||||
// parse SSH opts
|
||||
@ -281,11 +307,6 @@ func parseClientOpts() (*shexec.ClientOpts, error) {
|
||||
opts.SSHOptsTerm = true
|
||||
break
|
||||
}
|
||||
if argStr == "--cwd" {
|
||||
if !iter.HasNext() {
|
||||
return nil, fmt.Errorf("'--cwd [dir]' missing directory")
|
||||
}
|
||||
}
|
||||
opts.SSHOpts = append(opts.SSHOpts, argStr)
|
||||
}
|
||||
if !opts.SSHOptsTerm {
|
||||
@ -310,6 +331,11 @@ func handleClient() (int, error) {
|
||||
if !opts.IsSSH {
|
||||
return 1, fmt.Errorf("when running in client mode '--ssh' option must be present")
|
||||
}
|
||||
fds, err := detectOpenFds()
|
||||
if err != nil {
|
||||
return 1, err
|
||||
}
|
||||
opts.Fds = fds
|
||||
donePacket, err := shexec.RunClientSSHCommandAndWait(opts)
|
||||
if err != nil {
|
||||
return 1, err
|
||||
|
@ -14,6 +14,7 @@ import (
|
||||
"os/exec"
|
||||
"path"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
)
|
||||
|
||||
const DefaultMShellPath = "mshell"
|
||||
@ -176,3 +177,14 @@ func WriteErrorMsg(fileName string, errVal string) error {
|
||||
_, writeErr := fd.Write([]byte(oscEsc))
|
||||
return writeErr
|
||||
}
|
||||
|
||||
func ExpandHomeDir(pathStr string) string {
|
||||
if pathStr != "~" && !strings.HasPrefix(pathStr, "~/") {
|
||||
return pathStr
|
||||
}
|
||||
homeDir := GetHomeDir()
|
||||
if pathStr == "~" {
|
||||
return homeDir
|
||||
}
|
||||
return path.Join(homeDir, pathStr[2:])
|
||||
}
|
||||
|
@ -48,6 +48,12 @@ func (r *FdReader) Close() {
|
||||
r.CVar.Broadcast()
|
||||
}
|
||||
|
||||
func (r *FdReader) GetBufSize() int {
|
||||
r.CVar.L.Lock()
|
||||
defer r.CVar.L.Unlock()
|
||||
return r.BufSize
|
||||
}
|
||||
|
||||
func (r *FdReader) NotifyAck(ackLen int) {
|
||||
r.CVar.L.Lock()
|
||||
defer r.CVar.L.Unlock()
|
||||
|
@ -64,15 +64,15 @@ func (w *FdWriter) WaitForData() ([]byte, bool) {
|
||||
func (w *FdWriter) AddData(data []byte, eof bool) error {
|
||||
w.CVar.L.Lock()
|
||||
defer w.CVar.L.Unlock()
|
||||
if w.Closed {
|
||||
return fmt.Errorf("write to closed file")
|
||||
}
|
||||
if w.Eof {
|
||||
return fmt.Errorf("write to closed file (eof)")
|
||||
if w.Closed || w.Eof {
|
||||
if len(data) == 0 {
|
||||
return nil
|
||||
}
|
||||
return fmt.Errorf("write to closed file eof[%v]", w.Eof)
|
||||
}
|
||||
if len(data) > 0 {
|
||||
if len(data)+len(w.Buffer) > WriteBufSize {
|
||||
return fmt.Errorf("write exceeds buffer size")
|
||||
return fmt.Errorf("write exceeds buffer size bufsize=%d (max=%d)", len(data)+len(w.Buffer), WriteBufSize)
|
||||
}
|
||||
w.Buffer = append(w.Buffer, data...)
|
||||
}
|
||||
|
@ -7,6 +7,7 @@
|
||||
package mpio
|
||||
|
||||
import (
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"os"
|
||||
"sync"
|
||||
@ -14,8 +15,8 @@ import (
|
||||
"github.com/scripthaus-dev/mshell/pkg/packet"
|
||||
)
|
||||
|
||||
const ReadBufSize = 128 * 1024
|
||||
const WriteBufSize = 128 * 1024
|
||||
const ReadBufSize = 32 * 1024
|
||||
const WriteBufSize = 32 * 1024
|
||||
const MaxSingleWriteSize = 4 * 1024
|
||||
|
||||
type Multiplexer struct {
|
||||
@ -29,6 +30,8 @@ type Multiplexer struct {
|
||||
Sender *packet.PacketSender
|
||||
Input chan packet.PacketType
|
||||
Started bool
|
||||
|
||||
Debug bool
|
||||
}
|
||||
|
||||
func MakeMultiplexer(sessionId string, cmdId string) *Multiplexer {
|
||||
@ -97,16 +100,16 @@ func (m *Multiplexer) MakeWriterPipe(fdNum int) (*os.File, error) {
|
||||
return pr, nil
|
||||
}
|
||||
|
||||
func (m *Multiplexer) MakeRawFdReader(fdNum int, fd *os.File) {
|
||||
func (m *Multiplexer) MakeRawFdReader(fdNum int, fd *os.File, shouldClose bool) {
|
||||
m.Lock.Lock()
|
||||
defer m.Lock.Unlock()
|
||||
m.FdReaders[fdNum] = MakeFdReader(m, fd, fdNum, false)
|
||||
m.FdReaders[fdNum] = MakeFdReader(m, fd, fdNum, shouldClose)
|
||||
}
|
||||
|
||||
func (m *Multiplexer) MakeRawFdWriter(fdNum int, fd *os.File) {
|
||||
func (m *Multiplexer) MakeRawFdWriter(fdNum int, fd *os.File, shouldClose bool) {
|
||||
m.Lock.Lock()
|
||||
defer m.Lock.Unlock()
|
||||
m.FdWriters[fdNum] = MakeFdWriter(m, fd, fdNum, false)
|
||||
m.FdWriters[fdNum] = MakeFdWriter(m, fd, fdNum, shouldClose)
|
||||
}
|
||||
|
||||
func (m *Multiplexer) makeDataAckPacket(fdNum int, ackLen int, err error) *packet.DataAckPacketType {
|
||||
@ -126,7 +129,7 @@ func (m *Multiplexer) makeDataPacket(fdNum int, data []byte, err error) *packet.
|
||||
pk.SessionId = m.SessionId
|
||||
pk.CmdId = m.CmdId
|
||||
pk.FdNum = fdNum
|
||||
pk.Data = string(data)
|
||||
pk.Data64 = base64.StdEncoding.EncodeToString(data)
|
||||
if err != nil {
|
||||
pk.Error = err.Error()
|
||||
}
|
||||
@ -173,6 +176,9 @@ func (m *Multiplexer) startIO(packetCh chan packet.PacketType, sender *packet.Pa
|
||||
func (m *Multiplexer) runPacketInputLoop() *packet.CmdDonePacketType {
|
||||
defer m.HandleInputDone()
|
||||
for pk := range m.Input {
|
||||
if m.Debug {
|
||||
fmt.Printf("PK> %s\n", packet.AsString(pk))
|
||||
}
|
||||
if pk.GetType() == packet.DataPacketStr {
|
||||
dataPacket := pk.(*packet.DataPacketType)
|
||||
err := m.processDataPacket(dataPacket)
|
||||
@ -191,12 +197,26 @@ func (m *Multiplexer) runPacketInputLoop() *packet.CmdDonePacketType {
|
||||
donePacket := pk.(*packet.CmdDonePacketType)
|
||||
return donePacket
|
||||
}
|
||||
// other packet types are ignored
|
||||
if pk.GetType() == packet.ErrorPacketStr {
|
||||
errPacket := pk.(*packet.ErrorPacketType)
|
||||
// at this point, just send the error packet to stderr rather than try to do something special
|
||||
fmt.Fprintf(os.Stderr, "%s\n", errPacket.Error)
|
||||
return nil
|
||||
}
|
||||
if pk.GetType() == packet.RawPacketStr {
|
||||
rawPacket := pk.(*packet.RawPacketType)
|
||||
fmt.Fprintf(os.Stderr, "%s\n", rawPacket.Data)
|
||||
continue
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *Multiplexer) processDataPacket(dataPacket *packet.DataPacketType) error {
|
||||
realData, err := base64.StdEncoding.DecodeString(dataPacket.Data64)
|
||||
if err != nil {
|
||||
return fmt.Errorf("decoding base64 data: %w", err)
|
||||
}
|
||||
m.Lock.Lock()
|
||||
defer m.Lock.Unlock()
|
||||
fw := m.FdWriters[dataPacket.FdNum]
|
||||
@ -205,9 +225,9 @@ func (m *Multiplexer) processDataPacket(dataPacket *packet.DataPacketType) error
|
||||
fw := MakeFdWriter(m, nil, dataPacket.FdNum, false)
|
||||
fw.Close()
|
||||
m.FdWriters[dataPacket.FdNum] = fw
|
||||
return fmt.Errorf("write to closed file")
|
||||
return fmt.Errorf("write to closed file (no fd)")
|
||||
}
|
||||
err := fw.AddData([]byte(dataPacket.Data), dataPacket.Eof)
|
||||
err = fw.AddData(realData, dataPacket.Eof)
|
||||
if err != nil {
|
||||
fw.Close()
|
||||
return err
|
||||
|
@ -121,7 +121,7 @@ type DataPacketType struct {
|
||||
SessionId string `json:"sessionid,omitempty"`
|
||||
CmdId string `json:"cmdid,omitempty"`
|
||||
FdNum int `json:"fdnum"`
|
||||
Data string `json:"data"`
|
||||
Data64 string `json:"data64"` // base64 encoded
|
||||
Eof bool `json:"eof,omitempty"`
|
||||
Error string `json:"error,omitempty"`
|
||||
}
|
||||
@ -130,6 +130,32 @@ func (*DataPacketType) GetType() string {
|
||||
return DataPacketStr
|
||||
}
|
||||
|
||||
func B64DecodedLen(b64 string) int {
|
||||
if len(b64) < 4 {
|
||||
return 0 // we use padded strings, so < 4 is always 0
|
||||
}
|
||||
realLen := 3 * (len(b64) / 4)
|
||||
if b64[len(b64)-1] == '=' {
|
||||
realLen--
|
||||
}
|
||||
if b64[len(b64)-2] == '=' {
|
||||
realLen--
|
||||
}
|
||||
return realLen
|
||||
}
|
||||
|
||||
func (p *DataPacketType) String() string {
|
||||
eofStr := ""
|
||||
if p.Eof {
|
||||
eofStr = ", eof"
|
||||
}
|
||||
errStr := ""
|
||||
if p.Error != "" {
|
||||
errStr = fmt.Sprintf(", err=%s", p.Error)
|
||||
}
|
||||
return fmt.Sprintf("data[fd=%d, len=%d%s%s]", p.FdNum, B64DecodedLen(p.Data64), eofStr, errStr)
|
||||
}
|
||||
|
||||
func MakeDataPacket() *DataPacketType {
|
||||
return &DataPacketType{Type: DataPacketStr}
|
||||
}
|
||||
@ -140,13 +166,21 @@ type DataAckPacketType struct {
|
||||
CmdId string `json:"cmdid,omitempty"`
|
||||
FdNum int `json:"fdnum"`
|
||||
AckLen int `json:"acklen"`
|
||||
Error string `json:"error"`
|
||||
Error string `json:"error,omitempty"`
|
||||
}
|
||||
|
||||
func (*DataAckPacketType) GetType() string {
|
||||
return DataAckPacketStr
|
||||
}
|
||||
|
||||
func (p *DataAckPacketType) String() string {
|
||||
errStr := ""
|
||||
if p.Error != "" {
|
||||
errStr = fmt.Sprintf(" err=%s", p.Error)
|
||||
}
|
||||
return fmt.Sprintf("ack[fd=%d, acklen=%d%s]", p.FdNum, p.AckLen, errStr)
|
||||
}
|
||||
|
||||
func MakeDataAckPacket() *DataAckPacketType {
|
||||
return &DataAckPacketType{Type: DataAckPacketStr}
|
||||
}
|
||||
@ -252,6 +286,10 @@ func (*RawPacketType) GetType() string {
|
||||
return RawPacketStr
|
||||
}
|
||||
|
||||
func (p *RawPacketType) String() string {
|
||||
return fmt.Sprintf("raw[%s]", p.Data)
|
||||
}
|
||||
|
||||
func MakeRawPacket(val string) *RawPacketType {
|
||||
return &RawPacketType{Type: RawPacketStr, Data: val}
|
||||
}
|
||||
@ -265,6 +303,10 @@ func (*MessagePacketType) GetType() string {
|
||||
return MessagePacketStr
|
||||
}
|
||||
|
||||
func (p *MessagePacketType) String() string {
|
||||
return fmt.Sprintf("messsage[%s]", p.Message)
|
||||
}
|
||||
|
||||
func MakeMessagePacket(message string) *MessagePacketType {
|
||||
return &MessagePacketType{Type: MessagePacketStr, Message: message}
|
||||
}
|
||||
@ -394,6 +436,13 @@ type PacketType interface {
|
||||
GetType() string
|
||||
}
|
||||
|
||||
func AsString(pk PacketType) string {
|
||||
if s, ok := pk.(fmt.Stringer); ok {
|
||||
return s.String()
|
||||
}
|
||||
return fmt.Sprintf("%s[]", pk.GetType())
|
||||
}
|
||||
|
||||
type RpcPacketType interface {
|
||||
GetType() string
|
||||
GetPacketId() string
|
||||
@ -433,8 +482,7 @@ func SendPacket(w io.Writer, packet PacketType) error {
|
||||
outBuf.Write(jsonBytes)
|
||||
outBuf.WriteByte('\n')
|
||||
if GlobalDebug {
|
||||
outBytes := outBuf.Bytes()
|
||||
fmt.Printf("SEND>%s", string(outBytes[1:]))
|
||||
fmt.Printf("SEND> %s\n", AsString(packet))
|
||||
}
|
||||
_, err = w.Write(outBuf.Bytes())
|
||||
if err != nil {
|
||||
@ -519,12 +567,14 @@ func (sender *PacketSender) SendMessage(fmtStr string, args ...interface{}) erro
|
||||
}
|
||||
|
||||
func PacketParser(input io.Reader) chan PacketType {
|
||||
bufReader := bufio.NewReader(input)
|
||||
rtnCh := make(chan PacketType)
|
||||
PacketParserAttach(input, rtnCh)
|
||||
return rtnCh
|
||||
}
|
||||
|
||||
func PacketParserAttach(input io.Reader, rtnCh chan PacketType) {
|
||||
bufReader := bufio.NewReader(input)
|
||||
go func() {
|
||||
defer func() {
|
||||
close(rtnCh)
|
||||
}()
|
||||
for {
|
||||
line, err := bufReader.ReadString('\n')
|
||||
if err == io.EOF {
|
||||
@ -562,7 +612,6 @@ func PacketParser(input io.Reader) chan PacketType {
|
||||
rtnCh <- pk
|
||||
}
|
||||
}()
|
||||
return rtnCh
|
||||
}
|
||||
|
||||
type ErrorReporter interface {
|
||||
|
@ -111,7 +111,7 @@ func MakeExecCmd(pk *packet.RunPacketType, cmdTty *os.File) *exec.Cmd {
|
||||
ecmd := exec.Command("bash", "-c", pk.Command)
|
||||
UpdateCmdEnv(ecmd, pk.Env)
|
||||
if pk.Cwd != "" {
|
||||
ecmd.Dir = pk.Cwd
|
||||
ecmd.Dir = base.ExpandHomeDir(pk.Cwd)
|
||||
}
|
||||
ecmd.Stdin = cmdTty
|
||||
ecmd.Stdout = cmdTty
|
||||
@ -175,12 +175,13 @@ func ValidateRunPacket(pk *packet.RunPacketType) error {
|
||||
}
|
||||
}
|
||||
if pk.Cwd != "" {
|
||||
dirInfo, err := os.Stat(pk.Cwd)
|
||||
realCwd := base.ExpandHomeDir(pk.Cwd)
|
||||
dirInfo, err := os.Stat(realCwd)
|
||||
if err != nil {
|
||||
return fmt.Errorf("invalid cwd '%s' for command: %v", pk.Cwd, err)
|
||||
return fmt.Errorf("invalid cwd '%s' for command: %v", realCwd, err)
|
||||
}
|
||||
if !dirInfo.IsDir() {
|
||||
return fmt.Errorf("invalid cwd '%s' for command, not a directory", pk.Cwd)
|
||||
return fmt.Errorf("invalid cwd '%s' for command, not a directory", realCwd)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
@ -228,8 +229,37 @@ func (opts *ClientOpts) MakeRunPacket() *packet.RunPacketType {
|
||||
return runPacket
|
||||
}
|
||||
|
||||
func ValidateRemoteFds(rfds []packet.RemoteFd) error {
|
||||
dupMap := make(map[int]bool)
|
||||
for _, rfd := range rfds {
|
||||
if rfd.FdNum < 0 {
|
||||
return fmt.Errorf("mshell negative fd numbers fd=%d", rfd.FdNum)
|
||||
}
|
||||
if rfd.FdNum < FirstExtraFilesFdNum {
|
||||
return fmt.Errorf("mshell does not support re-opening fd=%d (0, 1, and 2, are always open)", rfd.FdNum)
|
||||
}
|
||||
if rfd.FdNum > MaxFdNum {
|
||||
return fmt.Errorf("mshell does not support opening fd numbers above %d", MaxFdNum)
|
||||
}
|
||||
if dupMap[rfd.FdNum] {
|
||||
return fmt.Errorf("mshell got duplicate entries for fd=%d", rfd.FdNum)
|
||||
}
|
||||
if rfd.Read && rfd.Write {
|
||||
return fmt.Errorf("mshell does not support opening fd numbers for reading and writing, fd=%d", rfd.FdNum)
|
||||
}
|
||||
if !rfd.Read && !rfd.Write {
|
||||
return fmt.Errorf("invalid fd=%d, neither reading or writing mode specified", rfd.FdNum)
|
||||
}
|
||||
dupMap[rfd.FdNum] = true
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func RunClientSSHCommandAndWait(opts *ClientOpts) (*packet.CmdDonePacketType, error) {
|
||||
// packet.GlobalDebug = true
|
||||
err := ValidateRemoteFds(opts.Fds)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
cmd := MakeShExec("", "")
|
||||
sshRemoteCommand := `PATH=$PATH:~/.mshell; mshell --remote`
|
||||
var fullSshOpts []string
|
||||
@ -249,15 +279,27 @@ func RunClientSSHCommandAndWait(opts *ClientOpts) (*packet.CmdDonePacketType, er
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("creating stderr pipe: %v", err)
|
||||
}
|
||||
cmd.Multiplexer.MakeRawFdReader(0, os.Stdin, false)
|
||||
cmd.Multiplexer.MakeRawFdWriter(1, os.Stdout, false)
|
||||
cmd.Multiplexer.MakeRawFdWriter(2, os.Stderr, false)
|
||||
for _, rfd := range opts.Fds {
|
||||
fd := os.NewFile(uintptr(rfd.FdNum), fmt.Sprintf("/dev/fd/%d", rfd.FdNum))
|
||||
if fd == nil {
|
||||
return nil, fmt.Errorf("cannot open fd %d", rfd.FdNum)
|
||||
}
|
||||
if rfd.Read {
|
||||
cmd.Multiplexer.MakeRawFdReader(rfd.FdNum, fd, true)
|
||||
} else if rfd.Write {
|
||||
cmd.Multiplexer.MakeRawFdWriter(rfd.FdNum, fd, true)
|
||||
}
|
||||
}
|
||||
err = ecmd.Start()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("running ssh command: %w", err)
|
||||
}
|
||||
defer cmd.Close()
|
||||
packetCh := packet.PacketParser(stdoutReader)
|
||||
go func() {
|
||||
io.Copy(os.Stderr, stderrReader)
|
||||
}()
|
||||
packet.PacketParserAttach(stderrReader, packetCh)
|
||||
sender := packet.MakePacketSender(inputWriter)
|
||||
for pk := range packetCh {
|
||||
if pk.GetType() == packet.RawPacketStr {
|
||||
@ -275,9 +317,6 @@ func RunClientSSHCommandAndWait(opts *ClientOpts) (*packet.CmdDonePacketType, er
|
||||
}
|
||||
runPacket := opts.MakeRunPacket()
|
||||
sender.SendPacket(runPacket)
|
||||
cmd.Multiplexer.MakeRawFdReader(0, os.Stdin)
|
||||
cmd.Multiplexer.MakeRawFdWriter(1, os.Stdout)
|
||||
cmd.Multiplexer.MakeRawFdWriter(2, os.Stderr)
|
||||
remoteDonePacket := cmd.Multiplexer.RunIOAndWait(packetCh, sender, false, true, true)
|
||||
donePacket := cmd.WaitForCommand()
|
||||
if remoteDonePacket != nil {
|
||||
@ -287,6 +326,7 @@ func RunClientSSHCommandAndWait(opts *ClientOpts) (*packet.CmdDonePacketType, er
|
||||
}
|
||||
|
||||
func (cmd *ShExecType) RunRemoteIOAndWait(packetCh chan packet.PacketType, sender *packet.PacketSender) {
|
||||
defer cmd.Close()
|
||||
cmd.Multiplexer.RunIOAndWait(packetCh, sender, true, false, false)
|
||||
donePacket := cmd.WaitForCommand()
|
||||
sender.SendPacket(donePacket)
|
||||
@ -297,9 +337,13 @@ func runCommandSimple(pk *packet.RunPacketType, sender *packet.PacketSender) (*S
|
||||
cmd.Cmd = exec.Command("bash", "-c", pk.Command)
|
||||
UpdateCmdEnv(cmd.Cmd, pk.Env)
|
||||
if pk.Cwd != "" {
|
||||
cmd.Cmd.Dir = pk.Cwd
|
||||
cmd.Cmd.Dir = base.ExpandHomeDir(pk.Cwd)
|
||||
}
|
||||
err := ValidateRemoteFds(pk.Fds)
|
||||
if err != nil {
|
||||
cmd.Close()
|
||||
return nil, err
|
||||
}
|
||||
var err error
|
||||
cmd.Cmd.Stdin, err = cmd.Multiplexer.MakeWriterPipe(0)
|
||||
if err != nil {
|
||||
cmd.Close()
|
||||
@ -317,33 +361,9 @@ func runCommandSimple(pk *packet.RunPacketType, sender *packet.PacketSender) (*S
|
||||
}
|
||||
extraFiles := make([]*os.File, 0, MaxFdNum+1)
|
||||
for _, rfd := range pk.Fds {
|
||||
if rfd.FdNum < 0 {
|
||||
cmd.Close()
|
||||
return nil, fmt.Errorf("mshell negative fd numbers fd=%d", rfd.FdNum)
|
||||
}
|
||||
if rfd.FdNum < FirstExtraFilesFdNum {
|
||||
cmd.Close()
|
||||
return nil, fmt.Errorf("mshell does not support re-opening fd=%d (0, 1, and 2, are always open)", rfd.FdNum)
|
||||
}
|
||||
if rfd.FdNum > MaxFdNum {
|
||||
cmd.Close()
|
||||
return nil, fmt.Errorf("mshell does not support opening fd numbers above %d", MaxFdNum)
|
||||
}
|
||||
if rfd.FdNum >= len(extraFiles) {
|
||||
extraFiles = extraFiles[:rfd.FdNum+1]
|
||||
}
|
||||
if extraFiles[rfd.FdNum] != nil {
|
||||
cmd.Close()
|
||||
return nil, fmt.Errorf("mshell got duplicate entries for fd=%d", rfd.FdNum)
|
||||
}
|
||||
if rfd.Read && rfd.Write {
|
||||
cmd.Close()
|
||||
return nil, fmt.Errorf("mshell does not support opening fd numbers for reading and writing, fd=%d", rfd.FdNum)
|
||||
}
|
||||
if !rfd.Read && !rfd.Write {
|
||||
cmd.Close()
|
||||
return nil, fmt.Errorf("invalid fd=%d, neither reading or writing mode specified", rfd.FdNum)
|
||||
}
|
||||
if rfd.Read {
|
||||
// client file is open for reading, so we make a writer pipe
|
||||
extraFiles[rfd.FdNum], err = cmd.Multiplexer.MakeWriterPipe(rfd.FdNum)
|
||||
|
Loading…
Reference in New Issue
Block a user