checkpoint. transfer binary data as base64. handle cwd. detect open fds. working to transfer data in non-error cases.

This commit is contained in:
sawka 2022-06-24 23:42:00 -07:00
parent 5223760a76
commit e6776bd974
7 changed files with 202 additions and 69 deletions

View File

@ -20,6 +20,7 @@ import (
"github.com/scripthaus-dev/mshell/pkg/cmdtail"
"github.com/scripthaus-dev/mshell/pkg/packet"
"github.com/scripthaus-dev/mshell/pkg/shexec"
"golang.org/x/sys/unix"
)
const MShellVersion = "0.1.0"
@ -256,8 +257,26 @@ func handleRemote() {
func handleServer() {
}
func detectOpenFds() {
func detectOpenFds() ([]packet.RemoteFd, error) {
var fds []packet.RemoteFd
for fdNum := 3; fdNum <= 64; fdNum++ {
flags, err := unix.FcntlInt(uintptr(fdNum), unix.F_GETFL, 0)
if err != nil {
continue
}
flags = flags & 3
rfd := packet.RemoteFd{FdNum: fdNum}
if flags&2 == 2 {
return nil, fmt.Errorf("invalid fd=%d, mshell does not support fds open for reading and writing", fdNum)
}
if flags&1 == 1 {
rfd.Write = true
} else {
rfd.Read = true
}
fds = append(fds, rfd)
}
return fds, nil
}
func parseClientOpts() (*shexec.ClientOpts, error) {
@ -272,6 +291,13 @@ func parseClientOpts() (*shexec.ClientOpts, error) {
opts.IsSSH = true
break
}
if argStr == "--cwd" {
if !iter.HasNext() {
return nil, fmt.Errorf("'--cwd [dir]' missing directory")
}
opts.Cwd = iter.Next()
continue
}
}
if opts.IsSSH {
// parse SSH opts
@ -281,11 +307,6 @@ func parseClientOpts() (*shexec.ClientOpts, error) {
opts.SSHOptsTerm = true
break
}
if argStr == "--cwd" {
if !iter.HasNext() {
return nil, fmt.Errorf("'--cwd [dir]' missing directory")
}
}
opts.SSHOpts = append(opts.SSHOpts, argStr)
}
if !opts.SSHOptsTerm {
@ -310,6 +331,11 @@ func handleClient() (int, error) {
if !opts.IsSSH {
return 1, fmt.Errorf("when running in client mode '--ssh' option must be present")
}
fds, err := detectOpenFds()
if err != nil {
return 1, err
}
opts.Fds = fds
donePacket, err := shexec.RunClientSSHCommandAndWait(opts)
if err != nil {
return 1, err

View File

@ -14,6 +14,7 @@ import (
"os/exec"
"path"
"path/filepath"
"strings"
)
const DefaultMShellPath = "mshell"
@ -176,3 +177,14 @@ func WriteErrorMsg(fileName string, errVal string) error {
_, writeErr := fd.Write([]byte(oscEsc))
return writeErr
}
func ExpandHomeDir(pathStr string) string {
if pathStr != "~" && !strings.HasPrefix(pathStr, "~/") {
return pathStr
}
homeDir := GetHomeDir()
if pathStr == "~" {
return homeDir
}
return path.Join(homeDir, pathStr[2:])
}

View File

@ -48,6 +48,12 @@ func (r *FdReader) Close() {
r.CVar.Broadcast()
}
func (r *FdReader) GetBufSize() int {
r.CVar.L.Lock()
defer r.CVar.L.Unlock()
return r.BufSize
}
func (r *FdReader) NotifyAck(ackLen int) {
r.CVar.L.Lock()
defer r.CVar.L.Unlock()

View File

@ -64,15 +64,15 @@ func (w *FdWriter) WaitForData() ([]byte, bool) {
func (w *FdWriter) AddData(data []byte, eof bool) error {
w.CVar.L.Lock()
defer w.CVar.L.Unlock()
if w.Closed {
return fmt.Errorf("write to closed file")
}
if w.Eof {
return fmt.Errorf("write to closed file (eof)")
if w.Closed || w.Eof {
if len(data) == 0 {
return nil
}
return fmt.Errorf("write to closed file eof[%v]", w.Eof)
}
if len(data) > 0 {
if len(data)+len(w.Buffer) > WriteBufSize {
return fmt.Errorf("write exceeds buffer size")
return fmt.Errorf("write exceeds buffer size bufsize=%d (max=%d)", len(data)+len(w.Buffer), WriteBufSize)
}
w.Buffer = append(w.Buffer, data...)
}

View File

@ -7,6 +7,7 @@
package mpio
import (
"encoding/base64"
"fmt"
"os"
"sync"
@ -14,8 +15,8 @@ import (
"github.com/scripthaus-dev/mshell/pkg/packet"
)
const ReadBufSize = 128 * 1024
const WriteBufSize = 128 * 1024
const ReadBufSize = 32 * 1024
const WriteBufSize = 32 * 1024
const MaxSingleWriteSize = 4 * 1024
type Multiplexer struct {
@ -29,6 +30,8 @@ type Multiplexer struct {
Sender *packet.PacketSender
Input chan packet.PacketType
Started bool
Debug bool
}
func MakeMultiplexer(sessionId string, cmdId string) *Multiplexer {
@ -97,16 +100,16 @@ func (m *Multiplexer) MakeWriterPipe(fdNum int) (*os.File, error) {
return pr, nil
}
func (m *Multiplexer) MakeRawFdReader(fdNum int, fd *os.File) {
func (m *Multiplexer) MakeRawFdReader(fdNum int, fd *os.File, shouldClose bool) {
m.Lock.Lock()
defer m.Lock.Unlock()
m.FdReaders[fdNum] = MakeFdReader(m, fd, fdNum, false)
m.FdReaders[fdNum] = MakeFdReader(m, fd, fdNum, shouldClose)
}
func (m *Multiplexer) MakeRawFdWriter(fdNum int, fd *os.File) {
func (m *Multiplexer) MakeRawFdWriter(fdNum int, fd *os.File, shouldClose bool) {
m.Lock.Lock()
defer m.Lock.Unlock()
m.FdWriters[fdNum] = MakeFdWriter(m, fd, fdNum, false)
m.FdWriters[fdNum] = MakeFdWriter(m, fd, fdNum, shouldClose)
}
func (m *Multiplexer) makeDataAckPacket(fdNum int, ackLen int, err error) *packet.DataAckPacketType {
@ -126,7 +129,7 @@ func (m *Multiplexer) makeDataPacket(fdNum int, data []byte, err error) *packet.
pk.SessionId = m.SessionId
pk.CmdId = m.CmdId
pk.FdNum = fdNum
pk.Data = string(data)
pk.Data64 = base64.StdEncoding.EncodeToString(data)
if err != nil {
pk.Error = err.Error()
}
@ -173,6 +176,9 @@ func (m *Multiplexer) startIO(packetCh chan packet.PacketType, sender *packet.Pa
func (m *Multiplexer) runPacketInputLoop() *packet.CmdDonePacketType {
defer m.HandleInputDone()
for pk := range m.Input {
if m.Debug {
fmt.Printf("PK> %s\n", packet.AsString(pk))
}
if pk.GetType() == packet.DataPacketStr {
dataPacket := pk.(*packet.DataPacketType)
err := m.processDataPacket(dataPacket)
@ -191,12 +197,26 @@ func (m *Multiplexer) runPacketInputLoop() *packet.CmdDonePacketType {
donePacket := pk.(*packet.CmdDonePacketType)
return donePacket
}
// other packet types are ignored
if pk.GetType() == packet.ErrorPacketStr {
errPacket := pk.(*packet.ErrorPacketType)
// at this point, just send the error packet to stderr rather than try to do something special
fmt.Fprintf(os.Stderr, "%s\n", errPacket.Error)
return nil
}
if pk.GetType() == packet.RawPacketStr {
rawPacket := pk.(*packet.RawPacketType)
fmt.Fprintf(os.Stderr, "%s\n", rawPacket.Data)
continue
}
}
return nil
}
func (m *Multiplexer) processDataPacket(dataPacket *packet.DataPacketType) error {
realData, err := base64.StdEncoding.DecodeString(dataPacket.Data64)
if err != nil {
return fmt.Errorf("decoding base64 data: %w", err)
}
m.Lock.Lock()
defer m.Lock.Unlock()
fw := m.FdWriters[dataPacket.FdNum]
@ -205,9 +225,9 @@ func (m *Multiplexer) processDataPacket(dataPacket *packet.DataPacketType) error
fw := MakeFdWriter(m, nil, dataPacket.FdNum, false)
fw.Close()
m.FdWriters[dataPacket.FdNum] = fw
return fmt.Errorf("write to closed file")
return fmt.Errorf("write to closed file (no fd)")
}
err := fw.AddData([]byte(dataPacket.Data), dataPacket.Eof)
err = fw.AddData(realData, dataPacket.Eof)
if err != nil {
fw.Close()
return err

View File

@ -121,7 +121,7 @@ type DataPacketType struct {
SessionId string `json:"sessionid,omitempty"`
CmdId string `json:"cmdid,omitempty"`
FdNum int `json:"fdnum"`
Data string `json:"data"`
Data64 string `json:"data64"` // base64 encoded
Eof bool `json:"eof,omitempty"`
Error string `json:"error,omitempty"`
}
@ -130,6 +130,32 @@ func (*DataPacketType) GetType() string {
return DataPacketStr
}
func B64DecodedLen(b64 string) int {
if len(b64) < 4 {
return 0 // we use padded strings, so < 4 is always 0
}
realLen := 3 * (len(b64) / 4)
if b64[len(b64)-1] == '=' {
realLen--
}
if b64[len(b64)-2] == '=' {
realLen--
}
return realLen
}
func (p *DataPacketType) String() string {
eofStr := ""
if p.Eof {
eofStr = ", eof"
}
errStr := ""
if p.Error != "" {
errStr = fmt.Sprintf(", err=%s", p.Error)
}
return fmt.Sprintf("data[fd=%d, len=%d%s%s]", p.FdNum, B64DecodedLen(p.Data64), eofStr, errStr)
}
func MakeDataPacket() *DataPacketType {
return &DataPacketType{Type: DataPacketStr}
}
@ -140,13 +166,21 @@ type DataAckPacketType struct {
CmdId string `json:"cmdid,omitempty"`
FdNum int `json:"fdnum"`
AckLen int `json:"acklen"`
Error string `json:"error"`
Error string `json:"error,omitempty"`
}
func (*DataAckPacketType) GetType() string {
return DataAckPacketStr
}
func (p *DataAckPacketType) String() string {
errStr := ""
if p.Error != "" {
errStr = fmt.Sprintf(" err=%s", p.Error)
}
return fmt.Sprintf("ack[fd=%d, acklen=%d%s]", p.FdNum, p.AckLen, errStr)
}
func MakeDataAckPacket() *DataAckPacketType {
return &DataAckPacketType{Type: DataAckPacketStr}
}
@ -252,6 +286,10 @@ func (*RawPacketType) GetType() string {
return RawPacketStr
}
func (p *RawPacketType) String() string {
return fmt.Sprintf("raw[%s]", p.Data)
}
func MakeRawPacket(val string) *RawPacketType {
return &RawPacketType{Type: RawPacketStr, Data: val}
}
@ -265,6 +303,10 @@ func (*MessagePacketType) GetType() string {
return MessagePacketStr
}
func (p *MessagePacketType) String() string {
return fmt.Sprintf("messsage[%s]", p.Message)
}
func MakeMessagePacket(message string) *MessagePacketType {
return &MessagePacketType{Type: MessagePacketStr, Message: message}
}
@ -394,6 +436,13 @@ type PacketType interface {
GetType() string
}
func AsString(pk PacketType) string {
if s, ok := pk.(fmt.Stringer); ok {
return s.String()
}
return fmt.Sprintf("%s[]", pk.GetType())
}
type RpcPacketType interface {
GetType() string
GetPacketId() string
@ -433,8 +482,7 @@ func SendPacket(w io.Writer, packet PacketType) error {
outBuf.Write(jsonBytes)
outBuf.WriteByte('\n')
if GlobalDebug {
outBytes := outBuf.Bytes()
fmt.Printf("SEND>%s", string(outBytes[1:]))
fmt.Printf("SEND> %s\n", AsString(packet))
}
_, err = w.Write(outBuf.Bytes())
if err != nil {
@ -519,12 +567,14 @@ func (sender *PacketSender) SendMessage(fmtStr string, args ...interface{}) erro
}
func PacketParser(input io.Reader) chan PacketType {
bufReader := bufio.NewReader(input)
rtnCh := make(chan PacketType)
PacketParserAttach(input, rtnCh)
return rtnCh
}
func PacketParserAttach(input io.Reader, rtnCh chan PacketType) {
bufReader := bufio.NewReader(input)
go func() {
defer func() {
close(rtnCh)
}()
for {
line, err := bufReader.ReadString('\n')
if err == io.EOF {
@ -562,7 +612,6 @@ func PacketParser(input io.Reader) chan PacketType {
rtnCh <- pk
}
}()
return rtnCh
}
type ErrorReporter interface {

View File

@ -111,7 +111,7 @@ func MakeExecCmd(pk *packet.RunPacketType, cmdTty *os.File) *exec.Cmd {
ecmd := exec.Command("bash", "-c", pk.Command)
UpdateCmdEnv(ecmd, pk.Env)
if pk.Cwd != "" {
ecmd.Dir = pk.Cwd
ecmd.Dir = base.ExpandHomeDir(pk.Cwd)
}
ecmd.Stdin = cmdTty
ecmd.Stdout = cmdTty
@ -175,12 +175,13 @@ func ValidateRunPacket(pk *packet.RunPacketType) error {
}
}
if pk.Cwd != "" {
dirInfo, err := os.Stat(pk.Cwd)
realCwd := base.ExpandHomeDir(pk.Cwd)
dirInfo, err := os.Stat(realCwd)
if err != nil {
return fmt.Errorf("invalid cwd '%s' for command: %v", pk.Cwd, err)
return fmt.Errorf("invalid cwd '%s' for command: %v", realCwd, err)
}
if !dirInfo.IsDir() {
return fmt.Errorf("invalid cwd '%s' for command, not a directory", pk.Cwd)
return fmt.Errorf("invalid cwd '%s' for command, not a directory", realCwd)
}
}
return nil
@ -228,8 +229,37 @@ func (opts *ClientOpts) MakeRunPacket() *packet.RunPacketType {
return runPacket
}
func ValidateRemoteFds(rfds []packet.RemoteFd) error {
dupMap := make(map[int]bool)
for _, rfd := range rfds {
if rfd.FdNum < 0 {
return fmt.Errorf("mshell negative fd numbers fd=%d", rfd.FdNum)
}
if rfd.FdNum < FirstExtraFilesFdNum {
return fmt.Errorf("mshell does not support re-opening fd=%d (0, 1, and 2, are always open)", rfd.FdNum)
}
if rfd.FdNum > MaxFdNum {
return fmt.Errorf("mshell does not support opening fd numbers above %d", MaxFdNum)
}
if dupMap[rfd.FdNum] {
return fmt.Errorf("mshell got duplicate entries for fd=%d", rfd.FdNum)
}
if rfd.Read && rfd.Write {
return fmt.Errorf("mshell does not support opening fd numbers for reading and writing, fd=%d", rfd.FdNum)
}
if !rfd.Read && !rfd.Write {
return fmt.Errorf("invalid fd=%d, neither reading or writing mode specified", rfd.FdNum)
}
dupMap[rfd.FdNum] = true
}
return nil
}
func RunClientSSHCommandAndWait(opts *ClientOpts) (*packet.CmdDonePacketType, error) {
// packet.GlobalDebug = true
err := ValidateRemoteFds(opts.Fds)
if err != nil {
return nil, err
}
cmd := MakeShExec("", "")
sshRemoteCommand := `PATH=$PATH:~/.mshell; mshell --remote`
var fullSshOpts []string
@ -249,15 +279,27 @@ func RunClientSSHCommandAndWait(opts *ClientOpts) (*packet.CmdDonePacketType, er
if err != nil {
return nil, fmt.Errorf("creating stderr pipe: %v", err)
}
cmd.Multiplexer.MakeRawFdReader(0, os.Stdin, false)
cmd.Multiplexer.MakeRawFdWriter(1, os.Stdout, false)
cmd.Multiplexer.MakeRawFdWriter(2, os.Stderr, false)
for _, rfd := range opts.Fds {
fd := os.NewFile(uintptr(rfd.FdNum), fmt.Sprintf("/dev/fd/%d", rfd.FdNum))
if fd == nil {
return nil, fmt.Errorf("cannot open fd %d", rfd.FdNum)
}
if rfd.Read {
cmd.Multiplexer.MakeRawFdReader(rfd.FdNum, fd, true)
} else if rfd.Write {
cmd.Multiplexer.MakeRawFdWriter(rfd.FdNum, fd, true)
}
}
err = ecmd.Start()
if err != nil {
return nil, fmt.Errorf("running ssh command: %w", err)
}
defer cmd.Close()
packetCh := packet.PacketParser(stdoutReader)
go func() {
io.Copy(os.Stderr, stderrReader)
}()
packet.PacketParserAttach(stderrReader, packetCh)
sender := packet.MakePacketSender(inputWriter)
for pk := range packetCh {
if pk.GetType() == packet.RawPacketStr {
@ -275,9 +317,6 @@ func RunClientSSHCommandAndWait(opts *ClientOpts) (*packet.CmdDonePacketType, er
}
runPacket := opts.MakeRunPacket()
sender.SendPacket(runPacket)
cmd.Multiplexer.MakeRawFdReader(0, os.Stdin)
cmd.Multiplexer.MakeRawFdWriter(1, os.Stdout)
cmd.Multiplexer.MakeRawFdWriter(2, os.Stderr)
remoteDonePacket := cmd.Multiplexer.RunIOAndWait(packetCh, sender, false, true, true)
donePacket := cmd.WaitForCommand()
if remoteDonePacket != nil {
@ -287,6 +326,7 @@ func RunClientSSHCommandAndWait(opts *ClientOpts) (*packet.CmdDonePacketType, er
}
func (cmd *ShExecType) RunRemoteIOAndWait(packetCh chan packet.PacketType, sender *packet.PacketSender) {
defer cmd.Close()
cmd.Multiplexer.RunIOAndWait(packetCh, sender, true, false, false)
donePacket := cmd.WaitForCommand()
sender.SendPacket(donePacket)
@ -297,9 +337,13 @@ func runCommandSimple(pk *packet.RunPacketType, sender *packet.PacketSender) (*S
cmd.Cmd = exec.Command("bash", "-c", pk.Command)
UpdateCmdEnv(cmd.Cmd, pk.Env)
if pk.Cwd != "" {
cmd.Cmd.Dir = pk.Cwd
cmd.Cmd.Dir = base.ExpandHomeDir(pk.Cwd)
}
err := ValidateRemoteFds(pk.Fds)
if err != nil {
cmd.Close()
return nil, err
}
var err error
cmd.Cmd.Stdin, err = cmd.Multiplexer.MakeWriterPipe(0)
if err != nil {
cmd.Close()
@ -317,33 +361,9 @@ func runCommandSimple(pk *packet.RunPacketType, sender *packet.PacketSender) (*S
}
extraFiles := make([]*os.File, 0, MaxFdNum+1)
for _, rfd := range pk.Fds {
if rfd.FdNum < 0 {
cmd.Close()
return nil, fmt.Errorf("mshell negative fd numbers fd=%d", rfd.FdNum)
}
if rfd.FdNum < FirstExtraFilesFdNum {
cmd.Close()
return nil, fmt.Errorf("mshell does not support re-opening fd=%d (0, 1, and 2, are always open)", rfd.FdNum)
}
if rfd.FdNum > MaxFdNum {
cmd.Close()
return nil, fmt.Errorf("mshell does not support opening fd numbers above %d", MaxFdNum)
}
if rfd.FdNum >= len(extraFiles) {
extraFiles = extraFiles[:rfd.FdNum+1]
}
if extraFiles[rfd.FdNum] != nil {
cmd.Close()
return nil, fmt.Errorf("mshell got duplicate entries for fd=%d", rfd.FdNum)
}
if rfd.Read && rfd.Write {
cmd.Close()
return nil, fmt.Errorf("mshell does not support opening fd numbers for reading and writing, fd=%d", rfd.FdNum)
}
if !rfd.Read && !rfd.Write {
cmd.Close()
return nil, fmt.Errorf("invalid fd=%d, neither reading or writing mode specified", rfd.FdNum)
}
if rfd.Read {
// client file is open for reading, so we make a writer pipe
extraFiles[rfd.FdNum], err = cmd.Multiplexer.MakeWriterPipe(rfd.FdNum)