mirror of
https://github.com/wavetermdev/waveterm.git
synced 2025-01-21 21:32:13 +01:00
checkpoint. transfer binary data as base64. handle cwd. detect open fds. working to transfer data in non-error cases.
This commit is contained in:
parent
5223760a76
commit
e6776bd974
@ -20,6 +20,7 @@ import (
|
|||||||
"github.com/scripthaus-dev/mshell/pkg/cmdtail"
|
"github.com/scripthaus-dev/mshell/pkg/cmdtail"
|
||||||
"github.com/scripthaus-dev/mshell/pkg/packet"
|
"github.com/scripthaus-dev/mshell/pkg/packet"
|
||||||
"github.com/scripthaus-dev/mshell/pkg/shexec"
|
"github.com/scripthaus-dev/mshell/pkg/shexec"
|
||||||
|
"golang.org/x/sys/unix"
|
||||||
)
|
)
|
||||||
|
|
||||||
const MShellVersion = "0.1.0"
|
const MShellVersion = "0.1.0"
|
||||||
@ -256,8 +257,26 @@ func handleRemote() {
|
|||||||
func handleServer() {
|
func handleServer() {
|
||||||
}
|
}
|
||||||
|
|
||||||
func detectOpenFds() {
|
func detectOpenFds() ([]packet.RemoteFd, error) {
|
||||||
|
var fds []packet.RemoteFd
|
||||||
|
for fdNum := 3; fdNum <= 64; fdNum++ {
|
||||||
|
flags, err := unix.FcntlInt(uintptr(fdNum), unix.F_GETFL, 0)
|
||||||
|
if err != nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
flags = flags & 3
|
||||||
|
rfd := packet.RemoteFd{FdNum: fdNum}
|
||||||
|
if flags&2 == 2 {
|
||||||
|
return nil, fmt.Errorf("invalid fd=%d, mshell does not support fds open for reading and writing", fdNum)
|
||||||
|
}
|
||||||
|
if flags&1 == 1 {
|
||||||
|
rfd.Write = true
|
||||||
|
} else {
|
||||||
|
rfd.Read = true
|
||||||
|
}
|
||||||
|
fds = append(fds, rfd)
|
||||||
|
}
|
||||||
|
return fds, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func parseClientOpts() (*shexec.ClientOpts, error) {
|
func parseClientOpts() (*shexec.ClientOpts, error) {
|
||||||
@ -272,6 +291,13 @@ func parseClientOpts() (*shexec.ClientOpts, error) {
|
|||||||
opts.IsSSH = true
|
opts.IsSSH = true
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
|
if argStr == "--cwd" {
|
||||||
|
if !iter.HasNext() {
|
||||||
|
return nil, fmt.Errorf("'--cwd [dir]' missing directory")
|
||||||
|
}
|
||||||
|
opts.Cwd = iter.Next()
|
||||||
|
continue
|
||||||
|
}
|
||||||
}
|
}
|
||||||
if opts.IsSSH {
|
if opts.IsSSH {
|
||||||
// parse SSH opts
|
// parse SSH opts
|
||||||
@ -281,11 +307,6 @@ func parseClientOpts() (*shexec.ClientOpts, error) {
|
|||||||
opts.SSHOptsTerm = true
|
opts.SSHOptsTerm = true
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
if argStr == "--cwd" {
|
|
||||||
if !iter.HasNext() {
|
|
||||||
return nil, fmt.Errorf("'--cwd [dir]' missing directory")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
opts.SSHOpts = append(opts.SSHOpts, argStr)
|
opts.SSHOpts = append(opts.SSHOpts, argStr)
|
||||||
}
|
}
|
||||||
if !opts.SSHOptsTerm {
|
if !opts.SSHOptsTerm {
|
||||||
@ -310,6 +331,11 @@ func handleClient() (int, error) {
|
|||||||
if !opts.IsSSH {
|
if !opts.IsSSH {
|
||||||
return 1, fmt.Errorf("when running in client mode '--ssh' option must be present")
|
return 1, fmt.Errorf("when running in client mode '--ssh' option must be present")
|
||||||
}
|
}
|
||||||
|
fds, err := detectOpenFds()
|
||||||
|
if err != nil {
|
||||||
|
return 1, err
|
||||||
|
}
|
||||||
|
opts.Fds = fds
|
||||||
donePacket, err := shexec.RunClientSSHCommandAndWait(opts)
|
donePacket, err := shexec.RunClientSSHCommandAndWait(opts)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return 1, err
|
return 1, err
|
||||||
|
@ -14,6 +14,7 @@ import (
|
|||||||
"os/exec"
|
"os/exec"
|
||||||
"path"
|
"path"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
|
"strings"
|
||||||
)
|
)
|
||||||
|
|
||||||
const DefaultMShellPath = "mshell"
|
const DefaultMShellPath = "mshell"
|
||||||
@ -176,3 +177,14 @@ func WriteErrorMsg(fileName string, errVal string) error {
|
|||||||
_, writeErr := fd.Write([]byte(oscEsc))
|
_, writeErr := fd.Write([]byte(oscEsc))
|
||||||
return writeErr
|
return writeErr
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func ExpandHomeDir(pathStr string) string {
|
||||||
|
if pathStr != "~" && !strings.HasPrefix(pathStr, "~/") {
|
||||||
|
return pathStr
|
||||||
|
}
|
||||||
|
homeDir := GetHomeDir()
|
||||||
|
if pathStr == "~" {
|
||||||
|
return homeDir
|
||||||
|
}
|
||||||
|
return path.Join(homeDir, pathStr[2:])
|
||||||
|
}
|
||||||
|
@ -48,6 +48,12 @@ func (r *FdReader) Close() {
|
|||||||
r.CVar.Broadcast()
|
r.CVar.Broadcast()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (r *FdReader) GetBufSize() int {
|
||||||
|
r.CVar.L.Lock()
|
||||||
|
defer r.CVar.L.Unlock()
|
||||||
|
return r.BufSize
|
||||||
|
}
|
||||||
|
|
||||||
func (r *FdReader) NotifyAck(ackLen int) {
|
func (r *FdReader) NotifyAck(ackLen int) {
|
||||||
r.CVar.L.Lock()
|
r.CVar.L.Lock()
|
||||||
defer r.CVar.L.Unlock()
|
defer r.CVar.L.Unlock()
|
||||||
|
@ -64,15 +64,15 @@ func (w *FdWriter) WaitForData() ([]byte, bool) {
|
|||||||
func (w *FdWriter) AddData(data []byte, eof bool) error {
|
func (w *FdWriter) AddData(data []byte, eof bool) error {
|
||||||
w.CVar.L.Lock()
|
w.CVar.L.Lock()
|
||||||
defer w.CVar.L.Unlock()
|
defer w.CVar.L.Unlock()
|
||||||
if w.Closed {
|
if w.Closed || w.Eof {
|
||||||
return fmt.Errorf("write to closed file")
|
if len(data) == 0 {
|
||||||
}
|
return nil
|
||||||
if w.Eof {
|
}
|
||||||
return fmt.Errorf("write to closed file (eof)")
|
return fmt.Errorf("write to closed file eof[%v]", w.Eof)
|
||||||
}
|
}
|
||||||
if len(data) > 0 {
|
if len(data) > 0 {
|
||||||
if len(data)+len(w.Buffer) > WriteBufSize {
|
if len(data)+len(w.Buffer) > WriteBufSize {
|
||||||
return fmt.Errorf("write exceeds buffer size")
|
return fmt.Errorf("write exceeds buffer size bufsize=%d (max=%d)", len(data)+len(w.Buffer), WriteBufSize)
|
||||||
}
|
}
|
||||||
w.Buffer = append(w.Buffer, data...)
|
w.Buffer = append(w.Buffer, data...)
|
||||||
}
|
}
|
||||||
|
@ -7,6 +7,7 @@
|
|||||||
package mpio
|
package mpio
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"encoding/base64"
|
||||||
"fmt"
|
"fmt"
|
||||||
"os"
|
"os"
|
||||||
"sync"
|
"sync"
|
||||||
@ -14,8 +15,8 @@ import (
|
|||||||
"github.com/scripthaus-dev/mshell/pkg/packet"
|
"github.com/scripthaus-dev/mshell/pkg/packet"
|
||||||
)
|
)
|
||||||
|
|
||||||
const ReadBufSize = 128 * 1024
|
const ReadBufSize = 32 * 1024
|
||||||
const WriteBufSize = 128 * 1024
|
const WriteBufSize = 32 * 1024
|
||||||
const MaxSingleWriteSize = 4 * 1024
|
const MaxSingleWriteSize = 4 * 1024
|
||||||
|
|
||||||
type Multiplexer struct {
|
type Multiplexer struct {
|
||||||
@ -29,6 +30,8 @@ type Multiplexer struct {
|
|||||||
Sender *packet.PacketSender
|
Sender *packet.PacketSender
|
||||||
Input chan packet.PacketType
|
Input chan packet.PacketType
|
||||||
Started bool
|
Started bool
|
||||||
|
|
||||||
|
Debug bool
|
||||||
}
|
}
|
||||||
|
|
||||||
func MakeMultiplexer(sessionId string, cmdId string) *Multiplexer {
|
func MakeMultiplexer(sessionId string, cmdId string) *Multiplexer {
|
||||||
@ -97,16 +100,16 @@ func (m *Multiplexer) MakeWriterPipe(fdNum int) (*os.File, error) {
|
|||||||
return pr, nil
|
return pr, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *Multiplexer) MakeRawFdReader(fdNum int, fd *os.File) {
|
func (m *Multiplexer) MakeRawFdReader(fdNum int, fd *os.File, shouldClose bool) {
|
||||||
m.Lock.Lock()
|
m.Lock.Lock()
|
||||||
defer m.Lock.Unlock()
|
defer m.Lock.Unlock()
|
||||||
m.FdReaders[fdNum] = MakeFdReader(m, fd, fdNum, false)
|
m.FdReaders[fdNum] = MakeFdReader(m, fd, fdNum, shouldClose)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *Multiplexer) MakeRawFdWriter(fdNum int, fd *os.File) {
|
func (m *Multiplexer) MakeRawFdWriter(fdNum int, fd *os.File, shouldClose bool) {
|
||||||
m.Lock.Lock()
|
m.Lock.Lock()
|
||||||
defer m.Lock.Unlock()
|
defer m.Lock.Unlock()
|
||||||
m.FdWriters[fdNum] = MakeFdWriter(m, fd, fdNum, false)
|
m.FdWriters[fdNum] = MakeFdWriter(m, fd, fdNum, shouldClose)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *Multiplexer) makeDataAckPacket(fdNum int, ackLen int, err error) *packet.DataAckPacketType {
|
func (m *Multiplexer) makeDataAckPacket(fdNum int, ackLen int, err error) *packet.DataAckPacketType {
|
||||||
@ -126,7 +129,7 @@ func (m *Multiplexer) makeDataPacket(fdNum int, data []byte, err error) *packet.
|
|||||||
pk.SessionId = m.SessionId
|
pk.SessionId = m.SessionId
|
||||||
pk.CmdId = m.CmdId
|
pk.CmdId = m.CmdId
|
||||||
pk.FdNum = fdNum
|
pk.FdNum = fdNum
|
||||||
pk.Data = string(data)
|
pk.Data64 = base64.StdEncoding.EncodeToString(data)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
pk.Error = err.Error()
|
pk.Error = err.Error()
|
||||||
}
|
}
|
||||||
@ -173,6 +176,9 @@ func (m *Multiplexer) startIO(packetCh chan packet.PacketType, sender *packet.Pa
|
|||||||
func (m *Multiplexer) runPacketInputLoop() *packet.CmdDonePacketType {
|
func (m *Multiplexer) runPacketInputLoop() *packet.CmdDonePacketType {
|
||||||
defer m.HandleInputDone()
|
defer m.HandleInputDone()
|
||||||
for pk := range m.Input {
|
for pk := range m.Input {
|
||||||
|
if m.Debug {
|
||||||
|
fmt.Printf("PK> %s\n", packet.AsString(pk))
|
||||||
|
}
|
||||||
if pk.GetType() == packet.DataPacketStr {
|
if pk.GetType() == packet.DataPacketStr {
|
||||||
dataPacket := pk.(*packet.DataPacketType)
|
dataPacket := pk.(*packet.DataPacketType)
|
||||||
err := m.processDataPacket(dataPacket)
|
err := m.processDataPacket(dataPacket)
|
||||||
@ -191,12 +197,26 @@ func (m *Multiplexer) runPacketInputLoop() *packet.CmdDonePacketType {
|
|||||||
donePacket := pk.(*packet.CmdDonePacketType)
|
donePacket := pk.(*packet.CmdDonePacketType)
|
||||||
return donePacket
|
return donePacket
|
||||||
}
|
}
|
||||||
// other packet types are ignored
|
if pk.GetType() == packet.ErrorPacketStr {
|
||||||
|
errPacket := pk.(*packet.ErrorPacketType)
|
||||||
|
// at this point, just send the error packet to stderr rather than try to do something special
|
||||||
|
fmt.Fprintf(os.Stderr, "%s\n", errPacket.Error)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
if pk.GetType() == packet.RawPacketStr {
|
||||||
|
rawPacket := pk.(*packet.RawPacketType)
|
||||||
|
fmt.Fprintf(os.Stderr, "%s\n", rawPacket.Data)
|
||||||
|
continue
|
||||||
|
}
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *Multiplexer) processDataPacket(dataPacket *packet.DataPacketType) error {
|
func (m *Multiplexer) processDataPacket(dataPacket *packet.DataPacketType) error {
|
||||||
|
realData, err := base64.StdEncoding.DecodeString(dataPacket.Data64)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("decoding base64 data: %w", err)
|
||||||
|
}
|
||||||
m.Lock.Lock()
|
m.Lock.Lock()
|
||||||
defer m.Lock.Unlock()
|
defer m.Lock.Unlock()
|
||||||
fw := m.FdWriters[dataPacket.FdNum]
|
fw := m.FdWriters[dataPacket.FdNum]
|
||||||
@ -205,9 +225,9 @@ func (m *Multiplexer) processDataPacket(dataPacket *packet.DataPacketType) error
|
|||||||
fw := MakeFdWriter(m, nil, dataPacket.FdNum, false)
|
fw := MakeFdWriter(m, nil, dataPacket.FdNum, false)
|
||||||
fw.Close()
|
fw.Close()
|
||||||
m.FdWriters[dataPacket.FdNum] = fw
|
m.FdWriters[dataPacket.FdNum] = fw
|
||||||
return fmt.Errorf("write to closed file")
|
return fmt.Errorf("write to closed file (no fd)")
|
||||||
}
|
}
|
||||||
err := fw.AddData([]byte(dataPacket.Data), dataPacket.Eof)
|
err = fw.AddData(realData, dataPacket.Eof)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
fw.Close()
|
fw.Close()
|
||||||
return err
|
return err
|
||||||
|
@ -121,7 +121,7 @@ type DataPacketType struct {
|
|||||||
SessionId string `json:"sessionid,omitempty"`
|
SessionId string `json:"sessionid,omitempty"`
|
||||||
CmdId string `json:"cmdid,omitempty"`
|
CmdId string `json:"cmdid,omitempty"`
|
||||||
FdNum int `json:"fdnum"`
|
FdNum int `json:"fdnum"`
|
||||||
Data string `json:"data"`
|
Data64 string `json:"data64"` // base64 encoded
|
||||||
Eof bool `json:"eof,omitempty"`
|
Eof bool `json:"eof,omitempty"`
|
||||||
Error string `json:"error,omitempty"`
|
Error string `json:"error,omitempty"`
|
||||||
}
|
}
|
||||||
@ -130,6 +130,32 @@ func (*DataPacketType) GetType() string {
|
|||||||
return DataPacketStr
|
return DataPacketStr
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func B64DecodedLen(b64 string) int {
|
||||||
|
if len(b64) < 4 {
|
||||||
|
return 0 // we use padded strings, so < 4 is always 0
|
||||||
|
}
|
||||||
|
realLen := 3 * (len(b64) / 4)
|
||||||
|
if b64[len(b64)-1] == '=' {
|
||||||
|
realLen--
|
||||||
|
}
|
||||||
|
if b64[len(b64)-2] == '=' {
|
||||||
|
realLen--
|
||||||
|
}
|
||||||
|
return realLen
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *DataPacketType) String() string {
|
||||||
|
eofStr := ""
|
||||||
|
if p.Eof {
|
||||||
|
eofStr = ", eof"
|
||||||
|
}
|
||||||
|
errStr := ""
|
||||||
|
if p.Error != "" {
|
||||||
|
errStr = fmt.Sprintf(", err=%s", p.Error)
|
||||||
|
}
|
||||||
|
return fmt.Sprintf("data[fd=%d, len=%d%s%s]", p.FdNum, B64DecodedLen(p.Data64), eofStr, errStr)
|
||||||
|
}
|
||||||
|
|
||||||
func MakeDataPacket() *DataPacketType {
|
func MakeDataPacket() *DataPacketType {
|
||||||
return &DataPacketType{Type: DataPacketStr}
|
return &DataPacketType{Type: DataPacketStr}
|
||||||
}
|
}
|
||||||
@ -140,13 +166,21 @@ type DataAckPacketType struct {
|
|||||||
CmdId string `json:"cmdid,omitempty"`
|
CmdId string `json:"cmdid,omitempty"`
|
||||||
FdNum int `json:"fdnum"`
|
FdNum int `json:"fdnum"`
|
||||||
AckLen int `json:"acklen"`
|
AckLen int `json:"acklen"`
|
||||||
Error string `json:"error"`
|
Error string `json:"error,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
func (*DataAckPacketType) GetType() string {
|
func (*DataAckPacketType) GetType() string {
|
||||||
return DataAckPacketStr
|
return DataAckPacketStr
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (p *DataAckPacketType) String() string {
|
||||||
|
errStr := ""
|
||||||
|
if p.Error != "" {
|
||||||
|
errStr = fmt.Sprintf(" err=%s", p.Error)
|
||||||
|
}
|
||||||
|
return fmt.Sprintf("ack[fd=%d, acklen=%d%s]", p.FdNum, p.AckLen, errStr)
|
||||||
|
}
|
||||||
|
|
||||||
func MakeDataAckPacket() *DataAckPacketType {
|
func MakeDataAckPacket() *DataAckPacketType {
|
||||||
return &DataAckPacketType{Type: DataAckPacketStr}
|
return &DataAckPacketType{Type: DataAckPacketStr}
|
||||||
}
|
}
|
||||||
@ -252,6 +286,10 @@ func (*RawPacketType) GetType() string {
|
|||||||
return RawPacketStr
|
return RawPacketStr
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (p *RawPacketType) String() string {
|
||||||
|
return fmt.Sprintf("raw[%s]", p.Data)
|
||||||
|
}
|
||||||
|
|
||||||
func MakeRawPacket(val string) *RawPacketType {
|
func MakeRawPacket(val string) *RawPacketType {
|
||||||
return &RawPacketType{Type: RawPacketStr, Data: val}
|
return &RawPacketType{Type: RawPacketStr, Data: val}
|
||||||
}
|
}
|
||||||
@ -265,6 +303,10 @@ func (*MessagePacketType) GetType() string {
|
|||||||
return MessagePacketStr
|
return MessagePacketStr
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (p *MessagePacketType) String() string {
|
||||||
|
return fmt.Sprintf("messsage[%s]", p.Message)
|
||||||
|
}
|
||||||
|
|
||||||
func MakeMessagePacket(message string) *MessagePacketType {
|
func MakeMessagePacket(message string) *MessagePacketType {
|
||||||
return &MessagePacketType{Type: MessagePacketStr, Message: message}
|
return &MessagePacketType{Type: MessagePacketStr, Message: message}
|
||||||
}
|
}
|
||||||
@ -394,6 +436,13 @@ type PacketType interface {
|
|||||||
GetType() string
|
GetType() string
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func AsString(pk PacketType) string {
|
||||||
|
if s, ok := pk.(fmt.Stringer); ok {
|
||||||
|
return s.String()
|
||||||
|
}
|
||||||
|
return fmt.Sprintf("%s[]", pk.GetType())
|
||||||
|
}
|
||||||
|
|
||||||
type RpcPacketType interface {
|
type RpcPacketType interface {
|
||||||
GetType() string
|
GetType() string
|
||||||
GetPacketId() string
|
GetPacketId() string
|
||||||
@ -433,8 +482,7 @@ func SendPacket(w io.Writer, packet PacketType) error {
|
|||||||
outBuf.Write(jsonBytes)
|
outBuf.Write(jsonBytes)
|
||||||
outBuf.WriteByte('\n')
|
outBuf.WriteByte('\n')
|
||||||
if GlobalDebug {
|
if GlobalDebug {
|
||||||
outBytes := outBuf.Bytes()
|
fmt.Printf("SEND> %s\n", AsString(packet))
|
||||||
fmt.Printf("SEND>%s", string(outBytes[1:]))
|
|
||||||
}
|
}
|
||||||
_, err = w.Write(outBuf.Bytes())
|
_, err = w.Write(outBuf.Bytes())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -519,12 +567,14 @@ func (sender *PacketSender) SendMessage(fmtStr string, args ...interface{}) erro
|
|||||||
}
|
}
|
||||||
|
|
||||||
func PacketParser(input io.Reader) chan PacketType {
|
func PacketParser(input io.Reader) chan PacketType {
|
||||||
bufReader := bufio.NewReader(input)
|
|
||||||
rtnCh := make(chan PacketType)
|
rtnCh := make(chan PacketType)
|
||||||
|
PacketParserAttach(input, rtnCh)
|
||||||
|
return rtnCh
|
||||||
|
}
|
||||||
|
|
||||||
|
func PacketParserAttach(input io.Reader, rtnCh chan PacketType) {
|
||||||
|
bufReader := bufio.NewReader(input)
|
||||||
go func() {
|
go func() {
|
||||||
defer func() {
|
|
||||||
close(rtnCh)
|
|
||||||
}()
|
|
||||||
for {
|
for {
|
||||||
line, err := bufReader.ReadString('\n')
|
line, err := bufReader.ReadString('\n')
|
||||||
if err == io.EOF {
|
if err == io.EOF {
|
||||||
@ -562,7 +612,6 @@ func PacketParser(input io.Reader) chan PacketType {
|
|||||||
rtnCh <- pk
|
rtnCh <- pk
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
return rtnCh
|
|
||||||
}
|
}
|
||||||
|
|
||||||
type ErrorReporter interface {
|
type ErrorReporter interface {
|
||||||
|
@ -111,7 +111,7 @@ func MakeExecCmd(pk *packet.RunPacketType, cmdTty *os.File) *exec.Cmd {
|
|||||||
ecmd := exec.Command("bash", "-c", pk.Command)
|
ecmd := exec.Command("bash", "-c", pk.Command)
|
||||||
UpdateCmdEnv(ecmd, pk.Env)
|
UpdateCmdEnv(ecmd, pk.Env)
|
||||||
if pk.Cwd != "" {
|
if pk.Cwd != "" {
|
||||||
ecmd.Dir = pk.Cwd
|
ecmd.Dir = base.ExpandHomeDir(pk.Cwd)
|
||||||
}
|
}
|
||||||
ecmd.Stdin = cmdTty
|
ecmd.Stdin = cmdTty
|
||||||
ecmd.Stdout = cmdTty
|
ecmd.Stdout = cmdTty
|
||||||
@ -175,12 +175,13 @@ func ValidateRunPacket(pk *packet.RunPacketType) error {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
if pk.Cwd != "" {
|
if pk.Cwd != "" {
|
||||||
dirInfo, err := os.Stat(pk.Cwd)
|
realCwd := base.ExpandHomeDir(pk.Cwd)
|
||||||
|
dirInfo, err := os.Stat(realCwd)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("invalid cwd '%s' for command: %v", pk.Cwd, err)
|
return fmt.Errorf("invalid cwd '%s' for command: %v", realCwd, err)
|
||||||
}
|
}
|
||||||
if !dirInfo.IsDir() {
|
if !dirInfo.IsDir() {
|
||||||
return fmt.Errorf("invalid cwd '%s' for command, not a directory", pk.Cwd)
|
return fmt.Errorf("invalid cwd '%s' for command, not a directory", realCwd)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
@ -228,8 +229,37 @@ func (opts *ClientOpts) MakeRunPacket() *packet.RunPacketType {
|
|||||||
return runPacket
|
return runPacket
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func ValidateRemoteFds(rfds []packet.RemoteFd) error {
|
||||||
|
dupMap := make(map[int]bool)
|
||||||
|
for _, rfd := range rfds {
|
||||||
|
if rfd.FdNum < 0 {
|
||||||
|
return fmt.Errorf("mshell negative fd numbers fd=%d", rfd.FdNum)
|
||||||
|
}
|
||||||
|
if rfd.FdNum < FirstExtraFilesFdNum {
|
||||||
|
return fmt.Errorf("mshell does not support re-opening fd=%d (0, 1, and 2, are always open)", rfd.FdNum)
|
||||||
|
}
|
||||||
|
if rfd.FdNum > MaxFdNum {
|
||||||
|
return fmt.Errorf("mshell does not support opening fd numbers above %d", MaxFdNum)
|
||||||
|
}
|
||||||
|
if dupMap[rfd.FdNum] {
|
||||||
|
return fmt.Errorf("mshell got duplicate entries for fd=%d", rfd.FdNum)
|
||||||
|
}
|
||||||
|
if rfd.Read && rfd.Write {
|
||||||
|
return fmt.Errorf("mshell does not support opening fd numbers for reading and writing, fd=%d", rfd.FdNum)
|
||||||
|
}
|
||||||
|
if !rfd.Read && !rfd.Write {
|
||||||
|
return fmt.Errorf("invalid fd=%d, neither reading or writing mode specified", rfd.FdNum)
|
||||||
|
}
|
||||||
|
dupMap[rfd.FdNum] = true
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
func RunClientSSHCommandAndWait(opts *ClientOpts) (*packet.CmdDonePacketType, error) {
|
func RunClientSSHCommandAndWait(opts *ClientOpts) (*packet.CmdDonePacketType, error) {
|
||||||
// packet.GlobalDebug = true
|
err := ValidateRemoteFds(opts.Fds)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
cmd := MakeShExec("", "")
|
cmd := MakeShExec("", "")
|
||||||
sshRemoteCommand := `PATH=$PATH:~/.mshell; mshell --remote`
|
sshRemoteCommand := `PATH=$PATH:~/.mshell; mshell --remote`
|
||||||
var fullSshOpts []string
|
var fullSshOpts []string
|
||||||
@ -249,15 +279,27 @@ func RunClientSSHCommandAndWait(opts *ClientOpts) (*packet.CmdDonePacketType, er
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("creating stderr pipe: %v", err)
|
return nil, fmt.Errorf("creating stderr pipe: %v", err)
|
||||||
}
|
}
|
||||||
|
cmd.Multiplexer.MakeRawFdReader(0, os.Stdin, false)
|
||||||
|
cmd.Multiplexer.MakeRawFdWriter(1, os.Stdout, false)
|
||||||
|
cmd.Multiplexer.MakeRawFdWriter(2, os.Stderr, false)
|
||||||
|
for _, rfd := range opts.Fds {
|
||||||
|
fd := os.NewFile(uintptr(rfd.FdNum), fmt.Sprintf("/dev/fd/%d", rfd.FdNum))
|
||||||
|
if fd == nil {
|
||||||
|
return nil, fmt.Errorf("cannot open fd %d", rfd.FdNum)
|
||||||
|
}
|
||||||
|
if rfd.Read {
|
||||||
|
cmd.Multiplexer.MakeRawFdReader(rfd.FdNum, fd, true)
|
||||||
|
} else if rfd.Write {
|
||||||
|
cmd.Multiplexer.MakeRawFdWriter(rfd.FdNum, fd, true)
|
||||||
|
}
|
||||||
|
}
|
||||||
err = ecmd.Start()
|
err = ecmd.Start()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("running ssh command: %w", err)
|
return nil, fmt.Errorf("running ssh command: %w", err)
|
||||||
}
|
}
|
||||||
defer cmd.Close()
|
defer cmd.Close()
|
||||||
packetCh := packet.PacketParser(stdoutReader)
|
packetCh := packet.PacketParser(stdoutReader)
|
||||||
go func() {
|
packet.PacketParserAttach(stderrReader, packetCh)
|
||||||
io.Copy(os.Stderr, stderrReader)
|
|
||||||
}()
|
|
||||||
sender := packet.MakePacketSender(inputWriter)
|
sender := packet.MakePacketSender(inputWriter)
|
||||||
for pk := range packetCh {
|
for pk := range packetCh {
|
||||||
if pk.GetType() == packet.RawPacketStr {
|
if pk.GetType() == packet.RawPacketStr {
|
||||||
@ -275,9 +317,6 @@ func RunClientSSHCommandAndWait(opts *ClientOpts) (*packet.CmdDonePacketType, er
|
|||||||
}
|
}
|
||||||
runPacket := opts.MakeRunPacket()
|
runPacket := opts.MakeRunPacket()
|
||||||
sender.SendPacket(runPacket)
|
sender.SendPacket(runPacket)
|
||||||
cmd.Multiplexer.MakeRawFdReader(0, os.Stdin)
|
|
||||||
cmd.Multiplexer.MakeRawFdWriter(1, os.Stdout)
|
|
||||||
cmd.Multiplexer.MakeRawFdWriter(2, os.Stderr)
|
|
||||||
remoteDonePacket := cmd.Multiplexer.RunIOAndWait(packetCh, sender, false, true, true)
|
remoteDonePacket := cmd.Multiplexer.RunIOAndWait(packetCh, sender, false, true, true)
|
||||||
donePacket := cmd.WaitForCommand()
|
donePacket := cmd.WaitForCommand()
|
||||||
if remoteDonePacket != nil {
|
if remoteDonePacket != nil {
|
||||||
@ -287,6 +326,7 @@ func RunClientSSHCommandAndWait(opts *ClientOpts) (*packet.CmdDonePacketType, er
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (cmd *ShExecType) RunRemoteIOAndWait(packetCh chan packet.PacketType, sender *packet.PacketSender) {
|
func (cmd *ShExecType) RunRemoteIOAndWait(packetCh chan packet.PacketType, sender *packet.PacketSender) {
|
||||||
|
defer cmd.Close()
|
||||||
cmd.Multiplexer.RunIOAndWait(packetCh, sender, true, false, false)
|
cmd.Multiplexer.RunIOAndWait(packetCh, sender, true, false, false)
|
||||||
donePacket := cmd.WaitForCommand()
|
donePacket := cmd.WaitForCommand()
|
||||||
sender.SendPacket(donePacket)
|
sender.SendPacket(donePacket)
|
||||||
@ -297,9 +337,13 @@ func runCommandSimple(pk *packet.RunPacketType, sender *packet.PacketSender) (*S
|
|||||||
cmd.Cmd = exec.Command("bash", "-c", pk.Command)
|
cmd.Cmd = exec.Command("bash", "-c", pk.Command)
|
||||||
UpdateCmdEnv(cmd.Cmd, pk.Env)
|
UpdateCmdEnv(cmd.Cmd, pk.Env)
|
||||||
if pk.Cwd != "" {
|
if pk.Cwd != "" {
|
||||||
cmd.Cmd.Dir = pk.Cwd
|
cmd.Cmd.Dir = base.ExpandHomeDir(pk.Cwd)
|
||||||
|
}
|
||||||
|
err := ValidateRemoteFds(pk.Fds)
|
||||||
|
if err != nil {
|
||||||
|
cmd.Close()
|
||||||
|
return nil, err
|
||||||
}
|
}
|
||||||
var err error
|
|
||||||
cmd.Cmd.Stdin, err = cmd.Multiplexer.MakeWriterPipe(0)
|
cmd.Cmd.Stdin, err = cmd.Multiplexer.MakeWriterPipe(0)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
cmd.Close()
|
cmd.Close()
|
||||||
@ -317,33 +361,9 @@ func runCommandSimple(pk *packet.RunPacketType, sender *packet.PacketSender) (*S
|
|||||||
}
|
}
|
||||||
extraFiles := make([]*os.File, 0, MaxFdNum+1)
|
extraFiles := make([]*os.File, 0, MaxFdNum+1)
|
||||||
for _, rfd := range pk.Fds {
|
for _, rfd := range pk.Fds {
|
||||||
if rfd.FdNum < 0 {
|
|
||||||
cmd.Close()
|
|
||||||
return nil, fmt.Errorf("mshell negative fd numbers fd=%d", rfd.FdNum)
|
|
||||||
}
|
|
||||||
if rfd.FdNum < FirstExtraFilesFdNum {
|
|
||||||
cmd.Close()
|
|
||||||
return nil, fmt.Errorf("mshell does not support re-opening fd=%d (0, 1, and 2, are always open)", rfd.FdNum)
|
|
||||||
}
|
|
||||||
if rfd.FdNum > MaxFdNum {
|
|
||||||
cmd.Close()
|
|
||||||
return nil, fmt.Errorf("mshell does not support opening fd numbers above %d", MaxFdNum)
|
|
||||||
}
|
|
||||||
if rfd.FdNum >= len(extraFiles) {
|
if rfd.FdNum >= len(extraFiles) {
|
||||||
extraFiles = extraFiles[:rfd.FdNum+1]
|
extraFiles = extraFiles[:rfd.FdNum+1]
|
||||||
}
|
}
|
||||||
if extraFiles[rfd.FdNum] != nil {
|
|
||||||
cmd.Close()
|
|
||||||
return nil, fmt.Errorf("mshell got duplicate entries for fd=%d", rfd.FdNum)
|
|
||||||
}
|
|
||||||
if rfd.Read && rfd.Write {
|
|
||||||
cmd.Close()
|
|
||||||
return nil, fmt.Errorf("mshell does not support opening fd numbers for reading and writing, fd=%d", rfd.FdNum)
|
|
||||||
}
|
|
||||||
if !rfd.Read && !rfd.Write {
|
|
||||||
cmd.Close()
|
|
||||||
return nil, fmt.Errorf("invalid fd=%d, neither reading or writing mode specified", rfd.FdNum)
|
|
||||||
}
|
|
||||||
if rfd.Read {
|
if rfd.Read {
|
||||||
// client file is open for reading, so we make a writer pipe
|
// client file is open for reading, so we make a writer pipe
|
||||||
extraFiles[rfd.FdNum], err = cmd.Multiplexer.MakeWriterPipe(rfd.FdNum)
|
extraFiles[rfd.FdNum], err = cmd.Multiplexer.MakeWriterPipe(rfd.FdNum)
|
||||||
|
Loading…
Reference in New Issue
Block a user