combine sessionid and cmdid into one field ck (commandkey)

This commit is contained in:
sawka 2022-06-27 12:03:47 -07:00
parent 1ea8393844
commit 2a6791bcd6
6 changed files with 236 additions and 172 deletions

View File

@ -15,7 +15,6 @@ import (
"syscall" "syscall"
"time" "time"
"github.com/google/uuid"
"github.com/scripthaus-dev/mshell/pkg/base" "github.com/scripthaus-dev/mshell/pkg/base"
"github.com/scripthaus-dev/mshell/pkg/cmdtail" "github.com/scripthaus-dev/mshell/pkg/cmdtail"
"github.com/scripthaus-dev/mshell/pkg/packet" "github.com/scripthaus-dev/mshell/pkg/packet"
@ -38,7 +37,7 @@ func setupSingleSignals(cmd *shexec.ShExecType) {
}() }()
} }
func doSingle(cmdId string) { func doSingle(ck base.CommandKey) {
packetCh := packet.PacketParser(os.Stdin) packetCh := packet.PacketParser(os.Stdin)
sender := packet.MakePacketSender(os.Stdout) sender := packet.MakePacketSender(os.Stdout)
var runPacket *packet.RunPacketType var runPacket *packet.RunPacketType
@ -57,11 +56,11 @@ func doSingle(cmdId string) {
sender.SendErrorPacket("did not receive a 'run' packet") sender.SendErrorPacket("did not receive a 'run' packet")
return return
} }
if runPacket.CmdId == "" { if runPacket.CK.IsEmpty() {
runPacket.CmdId = cmdId runPacket.CK = ck
} }
if runPacket.CmdId != cmdId { if runPacket.CK != ck {
sender.SendErrorPacket(fmt.Sprintf("run packet cmdid[%s] did not match arg[%s]", runPacket.CmdId, cmdId)) sender.SendErrorPacket(fmt.Sprintf("run packet cmdid[%s] did not match arg[%s]", runPacket.CK, ck))
return return
} }
cmd, err := shexec.RunCommand(runPacket, sender) cmd, err := shexec.RunCommand(runPacket, sender)
@ -79,39 +78,36 @@ func doSingle(cmdId string) {
} }
func doMainRun(pk *packet.RunPacketType, sender *packet.PacketSender) { func doMainRun(pk *packet.RunPacketType, sender *packet.PacketSender) {
if pk.CmdId == "" {
pk.CmdId = uuid.New().String()
}
err := shexec.ValidateRunPacket(pk) err := shexec.ValidateRunPacket(pk)
if err != nil { if err != nil {
sender.SendPacket(packet.MakeIdErrorPacket(pk.CmdId, fmt.Sprintf("invalid run packet: %v", err))) sender.SendPacket(packet.MakeCKErrorPacket(pk.CK, fmt.Sprintf("invalid run packet: %v", err)))
return return
} }
fileNames, err := base.GetCommandFileNames(pk.SessionId, pk.CmdId) fileNames, err := base.GetCommandFileNames(pk.CK)
if err != nil { if err != nil {
sender.SendPacket(packet.MakeIdErrorPacket(pk.CmdId, fmt.Sprintf("cannot get command file names: %v", err))) sender.SendPacket(packet.MakeCKErrorPacket(pk.CK, fmt.Sprintf("cannot get command file names: %v", err)))
return return
} }
cmd, err := shexec.MakeRunnerExec(pk.CmdId) cmd, err := shexec.MakeRunnerExec(pk.CK)
if err != nil { if err != nil {
sender.SendPacket(packet.MakeIdErrorPacket(pk.CmdId, fmt.Sprintf("cannot make mshell command: %v", err))) sender.SendPacket(packet.MakeCKErrorPacket(pk.CK, fmt.Sprintf("cannot make mshell command: %v", err)))
return return
} }
cmdStdin, err := cmd.StdinPipe() cmdStdin, err := cmd.StdinPipe()
if err != nil { if err != nil {
sender.SendPacket(packet.MakeIdErrorPacket(pk.CmdId, fmt.Sprintf("cannot pipe stdin to command: %v", err))) sender.SendPacket(packet.MakeCKErrorPacket(pk.CK, fmt.Sprintf("cannot pipe stdin to command: %v", err)))
return return
} }
// touch ptyout file (should exist for tailer to work correctly) // touch ptyout file (should exist for tailer to work correctly)
ptyOutFd, err := os.OpenFile(fileNames.PtyOutFile, os.O_CREATE|os.O_TRUNC|os.O_APPEND|os.O_WRONLY, 0600) ptyOutFd, err := os.OpenFile(fileNames.PtyOutFile, os.O_CREATE|os.O_TRUNC|os.O_APPEND|os.O_WRONLY, 0600)
if err != nil { if err != nil {
sender.SendPacket(packet.MakeIdErrorPacket(pk.CmdId, fmt.Sprintf("cannot open pty out file '%s': %v", fileNames.PtyOutFile, err))) sender.SendPacket(packet.MakeCKErrorPacket(pk.CK, fmt.Sprintf("cannot open pty out file '%s': %v", fileNames.PtyOutFile, err)))
return return
} }
ptyOutFd.Close() // just opened to create the file, can close right after ptyOutFd.Close() // just opened to create the file, can close right after
runnerOutFd, err := os.OpenFile(fileNames.RunnerOutFile, os.O_CREATE|os.O_TRUNC|os.O_APPEND|os.O_WRONLY, 0600) runnerOutFd, err := os.OpenFile(fileNames.RunnerOutFile, os.O_CREATE|os.O_TRUNC|os.O_APPEND|os.O_WRONLY, 0600)
if err != nil { if err != nil {
sender.SendPacket(packet.MakeIdErrorPacket(pk.CmdId, fmt.Sprintf("cannot open runner out file '%s': %v", fileNames.RunnerOutFile, err))) sender.SendPacket(packet.MakeCKErrorPacket(pk.CK, fmt.Sprintf("cannot open runner out file '%s': %v", fileNames.RunnerOutFile, err)))
return return
} }
defer runnerOutFd.Close() defer runnerOutFd.Close()
@ -119,13 +115,13 @@ func doMainRun(pk *packet.RunPacketType, sender *packet.PacketSender) {
cmd.Stderr = runnerOutFd cmd.Stderr = runnerOutFd
err = cmd.Start() err = cmd.Start()
if err != nil { if err != nil {
sender.SendPacket(packet.MakeIdErrorPacket(pk.CmdId, fmt.Sprintf("error starting command: %v", err))) sender.SendPacket(packet.MakeCKErrorPacket(pk.CK, fmt.Sprintf("error starting command: %v", err)))
return return
} }
go func() { go func() {
err = packet.SendPacket(cmdStdin, pk) err = packet.SendPacket(cmdStdin, pk)
if err != nil { if err != nil {
sender.SendPacket(packet.MakeIdErrorPacket(pk.CmdId, fmt.Sprintf("error sending forked runner command: %v", err))) sender.SendPacket(packet.MakeCKErrorPacket(pk.CK, fmt.Sprintf("error sending forked runner command: %v", err)))
return return
} }
cmdStdin.Close() cmdStdin.Close()
@ -451,12 +447,12 @@ func main() {
} }
if len(os.Args) >= 2 { if len(os.Args) >= 2 {
cmdId, err := uuid.Parse(os.Args[1]) ck := base.CommandKey(os.Args[1])
if err != nil { if err := ck.Validate("mshell arg"); err != nil {
packet.SendErrorPacket(os.Stdout, fmt.Sprintf("invalid non-cmdid passed to mshell", err)) packet.SendErrorPacket(os.Stdout, err.Error())
return return
} }
doSingle(cmdId.String()) doSingle(ck)
time.Sleep(100 * time.Millisecond) time.Sleep(100 * time.Millisecond)
return return
} else { } else {

View File

@ -15,6 +15,8 @@ import (
"path" "path"
"path/filepath" "path/filepath"
"strings" "strings"
"github.com/google/uuid"
) )
const DefaultMShellPath = "mshell" const DefaultMShellPath = "mshell"
@ -37,6 +39,68 @@ type CommandFileNames struct {
RunnerOutFile string RunnerOutFile string
} }
type CommandKey string
func MakeCommandKey(sessionId string, cmdId string) CommandKey {
if sessionId == "" && cmdId == "" {
return CommandKey("")
}
return CommandKey(fmt.Sprintf("%s/%s", sessionId, cmdId))
}
func (ckey CommandKey) IsEmpty() bool {
return string(ckey) == ""
}
func (ckey CommandKey) GetSessionId() string {
slashIdx := strings.Index(string(ckey), "/")
if slashIdx == -1 {
return ""
}
return string(ckey[0:slashIdx])
}
func (ckey CommandKey) GetCmdId() string {
slashIdx := strings.Index(string(ckey), "/")
if slashIdx == -1 {
return ""
}
return string(ckey[slashIdx+1:])
}
func (ckey CommandKey) Split() (string, string) {
fields := strings.SplitN(string(ckey), "/", 2)
if len(fields) < 2 {
return "", ""
}
return fields[0], fields[1]
}
func (ckey CommandKey) Validate(typeStr string) error {
if typeStr == "" {
typeStr = "ck"
}
if ckey == "" {
return fmt.Errorf("%s has empty commandkey", typeStr)
}
sessionId, cmdId := ckey.Split()
if sessionId == "" {
return fmt.Errorf("%s does not have sessionid", typeStr)
}
_, err := uuid.Parse(sessionId)
if err != nil {
return fmt.Errorf("%s has invalid sessionid '%s'", typeStr, sessionId)
}
if cmdId == "" {
return fmt.Errorf("%s does not have cmdid", typeStr)
}
_, err = uuid.Parse(cmdId)
if err != nil {
return fmt.Errorf("%s has invalid cmdid '%s'", typeStr, cmdId)
}
return nil
}
func GetHomeDir() string { func GetHomeDir() string {
homeVar := os.Getenv(HomeVarName) homeVar := os.Getenv(HomeVarName)
if homeVar == "" { if homeVar == "" {
@ -57,10 +121,11 @@ func GetScHomeDir() (string, error) {
return scHome, nil return scHome, nil
} }
func GetCommandFileNames(sessionId string, cmdId string) (*CommandFileNames, error) { func GetCommandFileNames(ck CommandKey) (*CommandFileNames, error) {
if sessionId == "" || cmdId == "" { if err := ck.Validate("ck"); err != nil {
return nil, fmt.Errorf("cannot get command-files when sessionid or cmdid is empty") return nil, fmt.Errorf("cannot get command files: %w", err)
} }
sessionId, cmdId := ck.Split()
sdir, err := EnsureSessionDir(sessionId) sdir, err := EnsureSessionDir(sessionId)
if err != nil { if err != nil {
return nil, err return nil, err
@ -73,8 +138,8 @@ func GetCommandFileNames(sessionId string, cmdId string) (*CommandFileNames, err
}, nil }, nil
} }
func MakeCommandFileNamesWithHome(scHome string, sessionId string, cmdId string) *CommandFileNames { func MakeCommandFileNamesWithHome(scHome string, ck CommandKey) *CommandFileNames {
base := path.Join(scHome, SessionsDirBaseName, sessionId, cmdId) base := path.Join(scHome, SessionsDirBaseName, ck.GetSessionId(), ck.GetCmdId())
return &CommandFileNames{ return &CommandFileNames{
PtyOutFile: base + ".ptyout", PtyOutFile: base + ".ptyout",
StdinFifo: base + ".stdin", StdinFifo: base + ".stdin",

View File

@ -15,7 +15,6 @@ import (
"time" "time"
"github.com/fsnotify/fsnotify" "github.com/fsnotify/fsnotify"
"github.com/google/uuid"
"github.com/scripthaus-dev/mshell/pkg/base" "github.com/scripthaus-dev/mshell/pkg/base"
"github.com/scripthaus-dev/mshell/pkg/packet" "github.com/scripthaus-dev/mshell/pkg/packet"
) )
@ -33,7 +32,7 @@ type TailPos struct {
} }
type CmdWatchEntry struct { type CmdWatchEntry struct {
CmdKey CmdKey CmdKey base.CommandKey
FilePtyLen int64 FilePtyLen int64
FileRunLen int64 FileRunLen int64
Tails []TailPos Tails []TailPos
@ -73,20 +72,15 @@ func (pos TailPos) IsCurrent(entry CmdWatchEntry) bool {
return pos.TailPtyPos >= entry.FilePtyLen && pos.TailRunPos >= entry.FileRunLen return pos.TailPtyPos >= entry.FilePtyLen && pos.TailRunPos >= entry.FileRunLen
} }
type CmdKey struct {
SessionId string
CmdId string
}
type Tailer struct { type Tailer struct {
Lock *sync.Mutex Lock *sync.Mutex
WatchList map[CmdKey]CmdWatchEntry WatchList map[base.CommandKey]CmdWatchEntry
ScHomeDir string ScHomeDir string
Watcher *fsnotify.Watcher Watcher *fsnotify.Watcher
SendCh chan packet.PacketType SendCh chan packet.PacketType
} }
func (t *Tailer) updateTailPos_nolock(cmdKey CmdKey, reqId string, pos TailPos) { func (t *Tailer) updateTailPos_nolock(cmdKey base.CommandKey, reqId string, pos TailPos) {
entry, found := t.WatchList[cmdKey] entry, found := t.WatchList[cmdKey]
if !found { if !found {
return return
@ -95,7 +89,7 @@ func (t *Tailer) updateTailPos_nolock(cmdKey CmdKey, reqId string, pos TailPos)
t.WatchList[cmdKey] = entry t.WatchList[cmdKey] = entry
} }
func (t *Tailer) removeTailPos_nolock(cmdKey CmdKey, reqId string) { func (t *Tailer) removeTailPos_nolock(cmdKey base.CommandKey, reqId string) {
entry, found := t.WatchList[cmdKey] entry, found := t.WatchList[cmdKey]
if !found { if !found {
return return
@ -107,13 +101,13 @@ func (t *Tailer) removeTailPos_nolock(cmdKey CmdKey, reqId string) {
} }
// delete from watchlist, remove watches // delete from watchlist, remove watches
fileNames := base.MakeCommandFileNamesWithHome(t.ScHomeDir, cmdKey.SessionId, cmdKey.CmdId) fileNames := base.MakeCommandFileNamesWithHome(t.ScHomeDir, cmdKey)
delete(t.WatchList, cmdKey) delete(t.WatchList, cmdKey)
t.Watcher.Remove(fileNames.PtyOutFile) t.Watcher.Remove(fileNames.PtyOutFile)
t.Watcher.Remove(fileNames.RunnerOutFile) t.Watcher.Remove(fileNames.RunnerOutFile)
} }
func (t *Tailer) updateEntrySizes_nolock(cmdKey CmdKey, ptyLen int64, runLen int64) { func (t *Tailer) updateEntrySizes_nolock(cmdKey base.CommandKey, ptyLen int64, runLen int64) {
entry, found := t.WatchList[cmdKey] entry, found := t.WatchList[cmdKey]
if !found { if !found {
return return
@ -123,7 +117,7 @@ func (t *Tailer) updateEntrySizes_nolock(cmdKey CmdKey, ptyLen int64, runLen int
t.WatchList[cmdKey] = entry t.WatchList[cmdKey] = entry
} }
func (t *Tailer) getEntryAndPos_nolock(cmdKey CmdKey, reqId string) (CmdWatchEntry, TailPos, bool) { func (t *Tailer) getEntryAndPos_nolock(cmdKey base.CommandKey, reqId string) (CmdWatchEntry, TailPos, bool) {
entry, found := t.WatchList[cmdKey] entry, found := t.WatchList[cmdKey]
if !found { if !found {
return CmdWatchEntry{}, TailPos{}, false return CmdWatchEntry{}, TailPos{}, false
@ -142,7 +136,7 @@ func MakeTailer(sendCh chan packet.PacketType) (*Tailer, error) {
} }
rtn := &Tailer{ rtn := &Tailer{
Lock: &sync.Mutex{}, Lock: &sync.Mutex{},
WatchList: make(map[CmdKey]CmdWatchEntry), WatchList: make(map[base.CommandKey]CmdWatchEntry),
ScHomeDir: scHomeDir, ScHomeDir: scHomeDir,
SendCh: sendCh, SendCh: sendCh,
} }
@ -170,8 +164,7 @@ func (t *Tailer) readDataFromFile(fileName string, pos int64, maxBytes int) ([]b
func (t *Tailer) makeCmdDataPacket(fileNames *base.CommandFileNames, entry CmdWatchEntry, pos TailPos) *packet.CmdDataPacketType { func (t *Tailer) makeCmdDataPacket(fileNames *base.CommandFileNames, entry CmdWatchEntry, pos TailPos) *packet.CmdDataPacketType {
dataPacket := packet.MakeCmdDataPacket() dataPacket := packet.MakeCmdDataPacket()
dataPacket.ReqId = pos.ReqId dataPacket.ReqId = pos.ReqId
dataPacket.SessionId = entry.CmdKey.SessionId dataPacket.CK = entry.CmdKey
dataPacket.CmdId = entry.CmdKey.CmdId
dataPacket.PtyPos = pos.TailPtyPos dataPacket.PtyPos = pos.TailPtyPos
dataPacket.RunPos = pos.TailRunPos dataPacket.RunPos = pos.TailRunPos
if entry.FilePtyLen > pos.TailPtyPos { if entry.FilePtyLen > pos.TailPtyPos {
@ -196,14 +189,14 @@ func (t *Tailer) makeCmdDataPacket(fileNames *base.CommandFileNames, entry CmdWa
} }
// returns (data-packet, keepRunning) // returns (data-packet, keepRunning)
func (t *Tailer) runSingleDataTransfer(key CmdKey, reqId string) (*packet.CmdDataPacketType, bool) { func (t *Tailer) runSingleDataTransfer(key base.CommandKey, reqId string) (*packet.CmdDataPacketType, bool) {
t.Lock.Lock() t.Lock.Lock()
entry, pos, foundPos := t.getEntryAndPos_nolock(key, reqId) entry, pos, foundPos := t.getEntryAndPos_nolock(key, reqId)
t.Lock.Unlock() t.Lock.Unlock()
if !foundPos { if !foundPos {
return nil, false return nil, false
} }
fileNames := base.MakeCommandFileNamesWithHome(t.ScHomeDir, key.SessionId, key.CmdId) fileNames := base.MakeCommandFileNamesWithHome(t.ScHomeDir, key)
dataPacket := t.makeCmdDataPacket(fileNames, entry, pos) dataPacket := t.makeCmdDataPacket(fileNames, entry, pos)
t.Lock.Lock() t.Lock.Lock()
@ -232,7 +225,7 @@ func (t *Tailer) runSingleDataTransfer(key CmdKey, reqId string) (*packet.CmdDat
return dataPacket, pos.Running return dataPacket, pos.Running
} }
func (t *Tailer) checkRemoveNoFollow(cmdKey CmdKey, reqId string) { func (t *Tailer) checkRemoveNoFollow(cmdKey base.CommandKey, reqId string) {
t.Lock.Lock() t.Lock.Lock()
defer t.Lock.Unlock() defer t.Lock.Unlock()
_, pos, foundPos := t.getEntryAndPos_nolock(cmdKey, reqId) _, pos, foundPos := t.getEntryAndPos_nolock(cmdKey, reqId)
@ -244,7 +237,7 @@ func (t *Tailer) checkRemoveNoFollow(cmdKey CmdKey, reqId string) {
} }
} }
func (t *Tailer) RunDataTransfer(key CmdKey, reqId string) { func (t *Tailer) RunDataTransfer(key base.CommandKey, reqId string) {
for { for {
dataPacket, keepRunning := t.runSingleDataTransfer(key, reqId) dataPacket, keepRunning := t.runSingleDataTransfer(key, reqId)
if dataPacket != nil { if dataPacket != nil {
@ -283,7 +276,7 @@ func (t *Tailer) updateFile(relFileName string) {
t.SendCh <- packet.FmtMessagePacket("error trying to stat file '%s': %v", relFileName, err) t.SendCh <- packet.FmtMessagePacket("error trying to stat file '%s': %v", relFileName, err)
return return
} }
cmdKey := CmdKey{SessionId: m[1], CmdId: m[2]} cmdKey := base.MakeCommandKey(m[1], m[2])
t.Lock.Lock() t.Lock.Lock()
defer t.Lock.Unlock() defer t.Lock.Unlock()
entry, foundEntry := t.WatchList[cmdKey] entry, foundEntry := t.WatchList[cmdKey]
@ -336,7 +329,7 @@ func max(v1 int64, v2 int64) int64 {
} }
func (entry *CmdWatchEntry) fillFilePos(scHomeDir string) { func (entry *CmdWatchEntry) fillFilePos(scHomeDir string) {
fileNames := base.MakeCommandFileNamesWithHome(scHomeDir, entry.CmdKey.SessionId, entry.CmdKey.CmdId) fileNames := base.MakeCommandFileNamesWithHome(scHomeDir, entry.CmdKey)
ptyInfo, _ := os.Stat(fileNames.PtyOutFile) ptyInfo, _ := os.Stat(fileNames.PtyOutFile)
if ptyInfo != nil { if ptyInfo != nil {
entry.FilePtyLen = ptyInfo.Size() entry.FilePtyLen = ptyInfo.Size()
@ -350,30 +343,24 @@ func (entry *CmdWatchEntry) fillFilePos(scHomeDir string) {
func (t *Tailer) RemoveWatch(pk *packet.UntailCmdPacketType) { func (t *Tailer) RemoveWatch(pk *packet.UntailCmdPacketType) {
t.Lock.Lock() t.Lock.Lock()
defer t.Lock.Unlock() defer t.Lock.Unlock()
key := CmdKey{pk.SessionId, pk.CmdId} t.removeTailPos_nolock(pk.CK, pk.ReqId)
t.removeTailPos_nolock(key, pk.ReqId)
} }
func (t *Tailer) AddWatch(getPacket *packet.GetCmdPacketType) error { func (t *Tailer) AddWatch(getPacket *packet.GetCmdPacketType) error {
_, err := uuid.Parse(getPacket.SessionId) if err := getPacket.CK.Validate("getcmd"); err != nil {
if err != nil { return err
return fmt.Errorf("getcmd, bad sessionid '%s': %w", getPacket.SessionId, err)
}
_, err = uuid.Parse(getPacket.CmdId)
if err != nil {
return fmt.Errorf("getcmd, bad cmdid '%s': %w", getPacket.CmdId, err)
} }
if getPacket.ReqId == "" { if getPacket.ReqId == "" {
return fmt.Errorf("getcmd, no reqid specified") return fmt.Errorf("getcmd, no reqid specified")
} }
fileNames := base.MakeCommandFileNamesWithHome(t.ScHomeDir, getPacket.SessionId, getPacket.CmdId) fileNames := base.MakeCommandFileNamesWithHome(t.ScHomeDir, getPacket.CK)
t.Lock.Lock() t.Lock.Lock()
defer t.Lock.Unlock() defer t.Lock.Unlock()
key := CmdKey{getPacket.SessionId, getPacket.CmdId} key := getPacket.CK
entry, foundEntry := t.WatchList[key] entry, foundEntry := t.WatchList[key]
if !foundEntry { if !foundEntry {
// add watches, initialize entry // add watches, initialize entry
err = t.Watcher.Add(fileNames.PtyOutFile) err := t.Watcher.Add(fileNames.PtyOutFile)
if err != nil { if err != nil {
return err return err
} }

View File

@ -12,6 +12,7 @@ import (
"os" "os"
"sync" "sync"
"github.com/scripthaus-dev/mshell/pkg/base"
"github.com/scripthaus-dev/mshell/pkg/packet" "github.com/scripthaus-dev/mshell/pkg/packet"
) )
@ -21,8 +22,7 @@ const MaxSingleWriteSize = 4 * 1024
type Multiplexer struct { type Multiplexer struct {
Lock *sync.Mutex Lock *sync.Mutex
SessionId string CK base.CommandKey
CmdId string
FdReaders map[int]*FdReader // synchronized FdReaders map[int]*FdReader // synchronized
FdWriters map[int]*FdWriter // synchronized FdWriters map[int]*FdWriter // synchronized
CloseAfterStart []*os.File // synchronized CloseAfterStart []*os.File // synchronized
@ -34,11 +34,10 @@ type Multiplexer struct {
Debug bool Debug bool
} }
func MakeMultiplexer(sessionId string, cmdId string) *Multiplexer { func MakeMultiplexer(ck base.CommandKey) *Multiplexer {
return &Multiplexer{ return &Multiplexer{
Lock: &sync.Mutex{}, Lock: &sync.Mutex{},
SessionId: sessionId, CK: ck,
CmdId: cmdId,
FdReaders: make(map[int]*FdReader), FdReaders: make(map[int]*FdReader),
FdWriters: make(map[int]*FdWriter), FdWriters: make(map[int]*FdWriter),
} }
@ -126,8 +125,7 @@ func (m *Multiplexer) MakeRawFdWriter(fdNum int, fd *os.File, shouldClose bool)
func (m *Multiplexer) makeDataAckPacket(fdNum int, ackLen int, err error) *packet.DataAckPacketType { func (m *Multiplexer) makeDataAckPacket(fdNum int, ackLen int, err error) *packet.DataAckPacketType {
ack := packet.MakeDataAckPacket() ack := packet.MakeDataAckPacket()
ack.SessionId = m.SessionId ack.CK = m.CK
ack.CmdId = m.CmdId
ack.FdNum = fdNum ack.FdNum = fdNum
ack.AckLen = ackLen ack.AckLen = ackLen
if err != nil { if err != nil {
@ -138,8 +136,7 @@ func (m *Multiplexer) makeDataAckPacket(fdNum int, ackLen int, err error) *packe
func (m *Multiplexer) makeDataPacket(fdNum int, data []byte, err error) *packet.DataPacketType { func (m *Multiplexer) makeDataPacket(fdNum int, data []byte, err error) *packet.DataPacketType {
pk := packet.MakeDataPacket() pk := packet.MakeDataPacket()
pk.SessionId = m.SessionId pk.CK = m.CK
pk.CmdId = m.CmdId
pk.FdNum = fdNum pk.FdNum = fdNum
pk.Data64 = base64.StdEncoding.EncodeToString(data) pk.Data64 = base64.StdEncoding.EncodeToString(data)
if err != nil { if err != nil {

View File

@ -16,6 +16,8 @@ import (
"strconv" "strconv"
"strings" "strings"
"sync" "sync"
"github.com/scripthaus-dev/mshell/pkg/base"
) )
// remote: init, run, ping, data, cmdstart, cmddone // remote: init, run, ping, data, cmdstart, cmddone
@ -80,26 +82,29 @@ func MakePacket(packetType string) (PacketType, error) {
} }
type CmdDataPacketType struct { type CmdDataPacketType struct {
Type string `json:"type"` Type string `json:"type"`
ReqId string `json:"reqid"` ReqId string `json:"reqid"`
SessionId string `json:"sessionid"` CK base.CommandKey `json:"ck"`
CmdId string `json:"cmdid"` PtyPos int64 `json:"ptypos"`
PtyPos int64 `json:"ptypos"` PtyLen int64 `json:"ptylen"`
PtyLen int64 `json:"ptylen"` RunPos int64 `json:"runpos"`
RunPos int64 `json:"runpos"` RunLen int64 `json:"runlen"`
RunLen int64 `json:"runlen"` PtyData string `json:"ptydata"`
PtyData string `json:"ptydata"` PtyDataLen int `json:"ptydatalen"`
PtyDataLen int `json:"ptydatalen"` RunData string `json:"rundata"`
RunData string `json:"rundata"` RunDataLen int `json:"rundatalen"`
RunDataLen int `json:"rundatalen"` Error string `json:"error"`
Error string `json:"error"` NotFound bool `json:"notfound,omitempty"`
NotFound bool `json:"notfound,omitempty"`
} }
func (*CmdDataPacketType) GetType() string { func (*CmdDataPacketType) GetType() string {
return CmdDataPacketStr return CmdDataPacketStr
} }
func (p *CmdDataPacketType) GetCK() base.CommandKey {
return p.CK
}
func MakeCmdDataPacket() *CmdDataPacketType { func MakeCmdDataPacket() *CmdDataPacketType {
return &CmdDataPacketType{Type: CmdDataPacketStr} return &CmdDataPacketType{Type: CmdDataPacketStr}
} }
@ -117,19 +122,22 @@ func MakePingPacket() *PingPacketType {
} }
type DataPacketType struct { type DataPacketType struct {
Type string `json:"type"` Type string `json:"type"`
SessionId string `json:"sessionid,omitempty"` CK base.CommandKey `json:"ck"`
CmdId string `json:"cmdid,omitempty"` FdNum int `json:"fdnum"`
FdNum int `json:"fdnum"` Data64 string `json:"data64"` // base64 encoded
Data64 string `json:"data64"` // base64 encoded Eof bool `json:"eof,omitempty"`
Eof bool `json:"eof,omitempty"` Error string `json:"error,omitempty"`
Error string `json:"error,omitempty"`
} }
func (*DataPacketType) GetType() string { func (*DataPacketType) GetType() string {
return DataPacketStr return DataPacketStr
} }
func (p *DataPacketType) GetCK() base.CommandKey {
return p.CK
}
func B64DecodedLen(b64 string) int { func B64DecodedLen(b64 string) int {
if len(b64) < 4 { if len(b64) < 4 {
return 0 // we use padded strings, so < 4 is always 0 return 0 // we use padded strings, so < 4 is always 0
@ -161,18 +169,21 @@ func MakeDataPacket() *DataPacketType {
} }
type DataAckPacketType struct { type DataAckPacketType struct {
Type string `json:"type"` Type string `json:"type"`
SessionId string `json:"sessionid,omitempty"` CK base.CommandKey `json:"ck"`
CmdId string `json:"cmdid,omitempty"` FdNum int `json:"fdnum"`
FdNum int `json:"fdnum"` AckLen int `json:"acklen"`
AckLen int `json:"acklen"` Error string `json:"error,omitempty"`
Error string `json:"error,omitempty"`
} }
func (*DataAckPacketType) GetType() string { func (*DataAckPacketType) GetType() string {
return DataAckPacketStr return DataAckPacketStr
} }
func (p *DataAckPacketType) GetCK() base.CommandKey {
return p.CK
}
func (p *DataAckPacketType) String() string { func (p *DataAckPacketType) String() string {
errStr := "" errStr := ""
if p.Error != "" { if p.Error != "" {
@ -189,52 +200,61 @@ func MakeDataAckPacket() *DataAckPacketType {
// SigNum gets sent to process via a signal // SigNum gets sent to process via a signal
// WinSize, if set, will run TIOCSWINSZ to set size, and then send SIGWINCH // WinSize, if set, will run TIOCSWINSZ to set size, and then send SIGWINCH
type InputPacketType struct { type InputPacketType struct {
Type string `json:"type"` Type string `json:"type"`
SessionId string `json:"sessionid"` CK base.CommandKey `json:"ck"`
CmdId string `json:"cmdid"` InputData string `json:"inputdata"`
InputData string `json:"inputdata"` SigNum int `json:"signum,omitempty"`
SigNum int `json:"signum,omitempty"` WinSizeRows int `json:"winsizerows"`
WinSizeRows int `json:"winsizerows"` WinSizeCols int `json:"winsizecols"`
WinSizeCols int `json:"winsizecols"`
} }
func (*InputPacketType) GetType() string { func (*InputPacketType) GetType() string {
return InputPacketStr return InputPacketStr
} }
func (p *InputPacketType) GetCK() base.CommandKey {
return p.CK
}
func MakeInputPacket() *InputPacketType { func MakeInputPacket() *InputPacketType {
return &InputPacketType{Type: InputPacketStr} return &InputPacketType{Type: InputPacketStr}
} }
type UntailCmdPacketType struct { type UntailCmdPacketType struct {
Type string `json:"type"` Type string `json:"type"`
ReqId string `json:"reqid"` ReqId string `json:"reqid"`
SessionId string `json:"sessionid"` CK base.CommandKey `json:"ck"`
CmdId string `json:"cmdid"`
} }
func (*UntailCmdPacketType) GetType() string { func (*UntailCmdPacketType) GetType() string {
return UntailCmdPacketStr return UntailCmdPacketStr
} }
func (p *UntailCmdPacketType) GetCK() base.CommandKey {
return p.CK
}
func MakeUntailCmdPacket() *UntailCmdPacketType { func MakeUntailCmdPacket() *UntailCmdPacketType {
return &UntailCmdPacketType{Type: UntailCmdPacketStr} return &UntailCmdPacketType{Type: UntailCmdPacketStr}
} }
type GetCmdPacketType struct { type GetCmdPacketType struct {
Type string `json:"type"` Type string `json:"type"`
ReqId string `json:"reqid"` ReqId string `json:"reqid"`
SessionId string `json:"sessionid"` CK base.CommandKey `json:"ck"`
CmdId string `json:"cmdid"` PtyPos int64 `json:"ptypos"`
PtyPos int64 `json:"ptypos"` RunPos int64 `json:"runpos"`
RunPos int64 `json:"runpos"` Tail bool `json:"tail,omitempty"`
Tail bool `json:"tail,omitempty"`
} }
func (*GetCmdPacketType) GetType() string { func (*GetCmdPacketType) GetType() string {
return GetCmdPacketStr return GetCmdPacketStr
} }
func (p *GetCmdPacketType) GetCK() base.CommandKey {
return p.CK
}
func MakeGetCmdPacket() *GetCmdPacketType { func MakeGetCmdPacket() *GetCmdPacketType {
return &GetCmdPacketType{Type: GetCmdPacketStr} return &GetCmdPacketType{Type: GetCmdPacketStr}
} }
@ -346,35 +366,41 @@ func MakeDonePacket() *DonePacketType {
} }
type CmdDonePacketType struct { type CmdDonePacketType struct {
Type string `json:"type"` Type string `json:"type"`
Ts int64 `json:"ts"` Ts int64 `json:"ts"`
SessionId string `json:"sessionid,omitempty"` CK base.CommandKey `json:"ck"`
CmdId string `json:"cmdid,omitempty"` ExitCode int `json:"exitcode"`
ExitCode int `json:"exitcode"` DurationMs int64 `json:"durationms"`
DurationMs int64 `json:"durationms"`
} }
func (*CmdDonePacketType) GetType() string { func (*CmdDonePacketType) GetType() string {
return CmdDonePacketStr return CmdDonePacketStr
} }
func (p *CmdDonePacketType) GetCK() base.CommandKey {
return p.CK
}
func MakeCmdDonePacket() *CmdDonePacketType { func MakeCmdDonePacket() *CmdDonePacketType {
return &CmdDonePacketType{Type: CmdDonePacketStr} return &CmdDonePacketType{Type: CmdDonePacketStr}
} }
type CmdStartPacketType struct { type CmdStartPacketType struct {
Type string `json:"type"` Type string `json:"type"`
Ts int64 `json:"ts"` Ts int64 `json:"ts"`
SessionId string `json:"sessionid,omitempty"` CK base.CommandKey `json:"ck"`
CmdId string `json:"cmdid,omitempty"` Pid int `json:"pid"`
Pid int `json:"pid"` MShellPid int `json:"mshellpid"`
MShellPid int `json:"mshellpid"`
} }
func (*CmdStartPacketType) GetType() string { func (*CmdStartPacketType) GetType() string {
return CmdStartPacketStr return CmdStartPacketStr
} }
func (p *CmdStartPacketType) GetCK() base.CommandKey {
return p.CK
}
func MakeCmdStartPacket() *CmdStartPacketType { func MakeCmdStartPacket() *CmdStartPacketType {
return &CmdStartPacketType{Type: CmdStartPacketStr} return &CmdStartPacketType{Type: CmdStartPacketStr}
} }
@ -393,21 +419,24 @@ type RemoteFd struct {
} }
type RunPacketType struct { type RunPacketType struct {
Type string `json:"type"` Type string `json:"type"`
SessionId string `json:"sessionid,omitempty"` CK base.CommandKey `json:"ck"`
CmdId string `json:"cmdid,omitempty"` Command string `json:"command"`
Command string `json:"command"` Cwd string `json:"cwd,omitempty"`
Cwd string `json:"cwd,omitempty"` Env map[string]string `json:"env,omitempty"`
Env map[string]string `json:"env,omitempty"` TermSize *TermSize `json:"termsize,omitempty"`
TermSize *TermSize `json:"termsize,omitempty"` Fds []RemoteFd `json:"fds,omitempty"`
Fds []RemoteFd `json:"fds,omitempty"` Detached bool `json:"detached,omitempty"`
Detached bool `json:"detached,omitempty"`
} }
func (*RunPacketType) GetType() string { func (*RunPacketType) GetType() string {
return RunPacketStr return RunPacketStr
} }
func (p *RunPacketType) GetCK() base.CommandKey {
return p.CK
}
func MakeRunPacket() *RunPacketType { func MakeRunPacket() *RunPacketType {
return &RunPacketType{Type: RunPacketStr} return &RunPacketType{Type: RunPacketStr}
} }
@ -417,9 +446,9 @@ type BarePacketType struct {
} }
type ErrorPacketType struct { type ErrorPacketType struct {
Id string `json:"id,omitempty"` CK base.CommandKey `json:"ck,omitempty"`
Type string `json:"type"` Type string `json:"type"`
Error string `json:"error"` Error string `json:"error"`
} }
func (et *ErrorPacketType) GetType() string { func (et *ErrorPacketType) GetType() string {
@ -430,8 +459,8 @@ func MakeErrorPacket(errorStr string) *ErrorPacketType {
return &ErrorPacketType{Type: ErrorPacketStr, Error: errorStr} return &ErrorPacketType{Type: ErrorPacketStr, Error: errorStr}
} }
func MakeIdErrorPacket(id string, errorStr string) *ErrorPacketType { func MakeCKErrorPacket(ck base.CommandKey, errorStr string) *ErrorPacketType {
return &ErrorPacketType{Type: ErrorPacketStr, Id: id, Error: errorStr} return &ErrorPacketType{Type: ErrorPacketStr, CK: ck, Error: errorStr}
} }
type PacketType interface { type PacketType interface {
@ -450,6 +479,11 @@ type RpcPacketType interface {
GetPacketId() string GetPacketId() string
} }
type CommandPacketType interface {
GetType() string
GetCK() base.CommandKey
}
func ParseJsonPacket(jsonBuf []byte) (PacketType, error) { func ParseJsonPacket(jsonBuf []byte) (PacketType, error) {
var bareCmd BarePacketType var bareCmd BarePacketType
err := json.Unmarshal(jsonBuf, &bareCmd) err := json.Unmarshal(jsonBuf, &bareCmd)

View File

@ -17,7 +17,6 @@ import (
"time" "time"
"github.com/creack/pty" "github.com/creack/pty"
"github.com/google/uuid"
"github.com/scripthaus-dev/mshell/pkg/base" "github.com/scripthaus-dev/mshell/pkg/base"
"github.com/scripthaus-dev/mshell/pkg/mpio" "github.com/scripthaus-dev/mshell/pkg/mpio"
"github.com/scripthaus-dev/mshell/pkg/packet" "github.com/scripthaus-dev/mshell/pkg/packet"
@ -39,21 +38,19 @@ const RemoteSudoPasswordCommandFmt = `cat /dev/fd/%d | sudo -S -C %d bash -c "ec
type ShExecType struct { type ShExecType struct {
Lock *sync.Mutex Lock *sync.Mutex
StartTs time.Time StartTs time.Time
SessionId string CK base.CommandKey
CmdId string
FileNames *base.CommandFileNames FileNames *base.CommandFileNames
Cmd *exec.Cmd Cmd *exec.Cmd
CmdPty *os.File CmdPty *os.File
Multiplexer *mpio.Multiplexer Multiplexer *mpio.Multiplexer
} }
func MakeShExec(sessionId string, cmdId string) *ShExecType { func MakeShExec(ck base.CommandKey) *ShExecType {
return &ShExecType{ return &ShExecType{
Lock: &sync.Mutex{}, Lock: &sync.Mutex{},
StartTs: time.Now(), StartTs: time.Now(),
SessionId: sessionId, CK: ck,
CmdId: cmdId, Multiplexer: mpio.MakeMultiplexer(ck),
Multiplexer: mpio.MakeMultiplexer(sessionId, cmdId),
} }
} }
@ -67,8 +64,7 @@ func (c *ShExecType) Close() {
func (c *ShExecType) MakeCmdStartPacket() *packet.CmdStartPacketType { func (c *ShExecType) MakeCmdStartPacket() *packet.CmdStartPacketType {
startPacket := packet.MakeCmdStartPacket() startPacket := packet.MakeCmdStartPacket()
startPacket.Ts = time.Now().UnixMilli() startPacket.Ts = time.Now().UnixMilli()
startPacket.SessionId = c.SessionId startPacket.CK = c.CK
startPacket.CmdId = c.CmdId
startPacket.Pid = c.Cmd.Process.Pid startPacket.Pid = c.Cmd.Process.Pid
startPacket.MShellPid = os.Getpid() startPacket.MShellPid = os.Getpid()
return startPacket return startPacket
@ -129,12 +125,12 @@ func MakeExecCmd(pk *packet.RunPacketType, cmdTty *os.File) *exec.Cmd {
return ecmd return ecmd
} }
func MakeRunnerExec(cmdId string) (*exec.Cmd, error) { func MakeRunnerExec(ck base.CommandKey) (*exec.Cmd, error) {
msPath, err := base.GetMShellPath() msPath, err := base.GetMShellPath()
if err != nil { if err != nil {
return nil, err return nil, err
} }
ecmd := exec.Command(msPath, cmdId) ecmd := exec.Command(msPath, string(ck))
return ecmd, nil return ecmd, nil
} }
@ -165,19 +161,9 @@ func ValidateRunPacket(pk *packet.RunPacketType) error {
return fmt.Errorf("run packet has wrong type: %s", pk.Type) return fmt.Errorf("run packet has wrong type: %s", pk.Type)
} }
if pk.Detached { if pk.Detached {
if pk.SessionId == "" { err := pk.CK.Validate("run packet")
return fmt.Errorf("run packet does not have sessionid")
}
_, err := uuid.Parse(pk.SessionId)
if err != nil { if err != nil {
return fmt.Errorf("invalid sessionid '%s' for command", pk.SessionId) return err
}
if pk.CmdId == "" {
return fmt.Errorf("run packet does not have cmdid")
}
_, err = uuid.Parse(pk.CmdId)
if err != nil {
return fmt.Errorf("invalid cmdid '%s' for command", pk.CmdId)
} }
} }
if pk.Cwd != "" { if pk.Cwd != "" {
@ -337,7 +323,7 @@ func RunClientSSHCommandAndWait(opts *ClientOpts) (*packet.CmdDonePacketType, er
if err != nil { if err != nil {
return nil, err return nil, err
} }
cmd := MakeShExec("", "") cmd := MakeShExec("")
var fullSshOpts []string var fullSshOpts []string
fullSshOpts = append(fullSshOpts, opts.SSHOpts...) fullSshOpts = append(fullSshOpts, opts.SSHOpts...)
fullSshOpts = append(fullSshOpts, SSHRemoteCommand) fullSshOpts = append(fullSshOpts, SSHRemoteCommand)
@ -430,7 +416,7 @@ func (cmd *ShExecType) RunRemoteIOAndWait(packetCh chan packet.PacketType, sende
} }
func runCommandSimple(pk *packet.RunPacketType, sender *packet.PacketSender) (*ShExecType, error) { func runCommandSimple(pk *packet.RunPacketType, sender *packet.PacketSender) (*ShExecType, error) {
cmd := MakeShExec(pk.SessionId, pk.CmdId) cmd := MakeShExec(pk.CK)
cmd.Cmd = exec.Command("bash", "-c", pk.Command) cmd.Cmd = exec.Command("bash", "-c", pk.Command)
UpdateCmdEnv(cmd.Cmd, pk.Env) UpdateCmdEnv(cmd.Cmd, pk.Env)
if pk.Cwd != "" { if pk.Cwd != "" {
@ -491,7 +477,7 @@ func runCommandSimple(pk *packet.RunPacketType, sender *packet.PacketSender) (*S
} }
func runCommandDetached(pk *packet.RunPacketType, sender *packet.PacketSender) (*ShExecType, error) { func runCommandDetached(pk *packet.RunPacketType, sender *packet.PacketSender) (*ShExecType, error) {
fileNames, err := base.GetCommandFileNames(pk.SessionId, pk.CmdId) fileNames, err := base.GetCommandFileNames(pk.CK)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -499,7 +485,7 @@ func runCommandDetached(pk *packet.RunPacketType, sender *packet.PacketSender) (
if err == nil { // non-nil error will be caught by regular OpenFile below if err == nil { // non-nil error will be caught by regular OpenFile below
// must have size 0 // must have size 0
if ptyOutInfo.Size() != 0 { if ptyOutInfo.Size() != 0 {
return nil, fmt.Errorf("cmdid '%s' was already used (ptyout len=%d)", pk.CmdId, ptyOutInfo.Size()) return nil, fmt.Errorf("cmdkey '%s' was already used (ptyout len=%d)", pk.CK, ptyOutInfo.Size())
} }
} }
cmdPty, cmdTty, err := pty.Open() cmdPty, cmdTty, err := pty.Open()
@ -510,7 +496,7 @@ func runCommandDetached(pk *packet.RunPacketType, sender *packet.PacketSender) (
defer func() { defer func() {
cmdTty.Close() cmdTty.Close()
}() }()
rtn := MakeShExec(pk.SessionId, pk.CmdId) rtn := MakeShExec(pk.CK)
ecmd := MakeExecCmd(pk, cmdTty) ecmd := MakeExecCmd(pk, cmdTty)
err = ecmd.Start() err = ecmd.Start()
if err != nil { if err != nil {
@ -558,8 +544,7 @@ func (c *ShExecType) WaitForCommand() *packet.CmdDonePacketType {
exitCode := GetExitCode(exitErr) exitCode := GetExitCode(exitErr)
donePacket := packet.MakeCmdDonePacket() donePacket := packet.MakeCmdDonePacket()
donePacket.Ts = endTs.UnixMilli() donePacket.Ts = endTs.UnixMilli()
donePacket.SessionId = c.SessionId donePacket.CK = c.CK
donePacket.CmdId = c.CmdId
donePacket.ExitCode = exitCode donePacket.ExitCode = exitCode
donePacket.DurationMs = int64(cmdDuration / time.Millisecond) donePacket.DurationMs = int64(cmdDuration / time.Millisecond)
if c.FileNames != nil { if c.FileNames != nil {