diff --git a/main-mshell.go b/main-mshell.go index fd9f03fd4..64ff09488 100644 --- a/main-mshell.go +++ b/main-mshell.go @@ -15,7 +15,6 @@ import ( "syscall" "time" - "github.com/google/uuid" "github.com/scripthaus-dev/mshell/pkg/base" "github.com/scripthaus-dev/mshell/pkg/cmdtail" "github.com/scripthaus-dev/mshell/pkg/packet" @@ -38,7 +37,7 @@ func setupSingleSignals(cmd *shexec.ShExecType) { }() } -func doSingle(cmdId string) { +func doSingle(ck base.CommandKey) { packetCh := packet.PacketParser(os.Stdin) sender := packet.MakePacketSender(os.Stdout) var runPacket *packet.RunPacketType @@ -57,11 +56,11 @@ func doSingle(cmdId string) { sender.SendErrorPacket("did not receive a 'run' packet") return } - if runPacket.CmdId == "" { - runPacket.CmdId = cmdId + if runPacket.CK.IsEmpty() { + runPacket.CK = ck } - if runPacket.CmdId != cmdId { - sender.SendErrorPacket(fmt.Sprintf("run packet cmdid[%s] did not match arg[%s]", runPacket.CmdId, cmdId)) + if runPacket.CK != ck { + sender.SendErrorPacket(fmt.Sprintf("run packet cmdid[%s] did not match arg[%s]", runPacket.CK, ck)) return } cmd, err := shexec.RunCommand(runPacket, sender) @@ -79,39 +78,36 @@ func doSingle(cmdId string) { } func doMainRun(pk *packet.RunPacketType, sender *packet.PacketSender) { - if pk.CmdId == "" { - pk.CmdId = uuid.New().String() - } err := shexec.ValidateRunPacket(pk) if err != nil { - sender.SendPacket(packet.MakeIdErrorPacket(pk.CmdId, fmt.Sprintf("invalid run packet: %v", err))) + sender.SendPacket(packet.MakeCKErrorPacket(pk.CK, fmt.Sprintf("invalid run packet: %v", err))) return } - fileNames, err := base.GetCommandFileNames(pk.SessionId, pk.CmdId) + fileNames, err := base.GetCommandFileNames(pk.CK) if err != nil { - sender.SendPacket(packet.MakeIdErrorPacket(pk.CmdId, fmt.Sprintf("cannot get command file names: %v", err))) + sender.SendPacket(packet.MakeCKErrorPacket(pk.CK, fmt.Sprintf("cannot get command file names: %v", err))) return } - cmd, err := shexec.MakeRunnerExec(pk.CmdId) + cmd, err := shexec.MakeRunnerExec(pk.CK) if err != nil { - sender.SendPacket(packet.MakeIdErrorPacket(pk.CmdId, fmt.Sprintf("cannot make mshell command: %v", err))) + sender.SendPacket(packet.MakeCKErrorPacket(pk.CK, fmt.Sprintf("cannot make mshell command: %v", err))) return } cmdStdin, err := cmd.StdinPipe() if err != nil { - sender.SendPacket(packet.MakeIdErrorPacket(pk.CmdId, fmt.Sprintf("cannot pipe stdin to command: %v", err))) + sender.SendPacket(packet.MakeCKErrorPacket(pk.CK, fmt.Sprintf("cannot pipe stdin to command: %v", err))) return } // touch ptyout file (should exist for tailer to work correctly) ptyOutFd, err := os.OpenFile(fileNames.PtyOutFile, os.O_CREATE|os.O_TRUNC|os.O_APPEND|os.O_WRONLY, 0600) if err != nil { - sender.SendPacket(packet.MakeIdErrorPacket(pk.CmdId, fmt.Sprintf("cannot open pty out file '%s': %v", fileNames.PtyOutFile, err))) + sender.SendPacket(packet.MakeCKErrorPacket(pk.CK, fmt.Sprintf("cannot open pty out file '%s': %v", fileNames.PtyOutFile, err))) return } ptyOutFd.Close() // just opened to create the file, can close right after runnerOutFd, err := os.OpenFile(fileNames.RunnerOutFile, os.O_CREATE|os.O_TRUNC|os.O_APPEND|os.O_WRONLY, 0600) if err != nil { - sender.SendPacket(packet.MakeIdErrorPacket(pk.CmdId, fmt.Sprintf("cannot open runner out file '%s': %v", fileNames.RunnerOutFile, err))) + sender.SendPacket(packet.MakeCKErrorPacket(pk.CK, fmt.Sprintf("cannot open runner out file '%s': %v", fileNames.RunnerOutFile, err))) return } defer runnerOutFd.Close() @@ -119,13 +115,13 @@ func doMainRun(pk *packet.RunPacketType, sender *packet.PacketSender) { cmd.Stderr = runnerOutFd err = cmd.Start() if err != nil { - sender.SendPacket(packet.MakeIdErrorPacket(pk.CmdId, fmt.Sprintf("error starting command: %v", err))) + sender.SendPacket(packet.MakeCKErrorPacket(pk.CK, fmt.Sprintf("error starting command: %v", err))) return } go func() { err = packet.SendPacket(cmdStdin, pk) if err != nil { - sender.SendPacket(packet.MakeIdErrorPacket(pk.CmdId, fmt.Sprintf("error sending forked runner command: %v", err))) + sender.SendPacket(packet.MakeCKErrorPacket(pk.CK, fmt.Sprintf("error sending forked runner command: %v", err))) return } cmdStdin.Close() @@ -451,12 +447,12 @@ func main() { } if len(os.Args) >= 2 { - cmdId, err := uuid.Parse(os.Args[1]) - if err != nil { - packet.SendErrorPacket(os.Stdout, fmt.Sprintf("invalid non-cmdid passed to mshell", err)) + ck := base.CommandKey(os.Args[1]) + if err := ck.Validate("mshell arg"); err != nil { + packet.SendErrorPacket(os.Stdout, err.Error()) return } - doSingle(cmdId.String()) + doSingle(ck) time.Sleep(100 * time.Millisecond) return } else { diff --git a/pkg/base/base.go b/pkg/base/base.go index eae9a8ccb..39398e014 100644 --- a/pkg/base/base.go +++ b/pkg/base/base.go @@ -15,6 +15,8 @@ import ( "path" "path/filepath" "strings" + + "github.com/google/uuid" ) const DefaultMShellPath = "mshell" @@ -37,6 +39,68 @@ type CommandFileNames struct { RunnerOutFile string } +type CommandKey string + +func MakeCommandKey(sessionId string, cmdId string) CommandKey { + if sessionId == "" && cmdId == "" { + return CommandKey("") + } + return CommandKey(fmt.Sprintf("%s/%s", sessionId, cmdId)) +} + +func (ckey CommandKey) IsEmpty() bool { + return string(ckey) == "" +} + +func (ckey CommandKey) GetSessionId() string { + slashIdx := strings.Index(string(ckey), "/") + if slashIdx == -1 { + return "" + } + return string(ckey[0:slashIdx]) +} + +func (ckey CommandKey) GetCmdId() string { + slashIdx := strings.Index(string(ckey), "/") + if slashIdx == -1 { + return "" + } + return string(ckey[slashIdx+1:]) +} + +func (ckey CommandKey) Split() (string, string) { + fields := strings.SplitN(string(ckey), "/", 2) + if len(fields) < 2 { + return "", "" + } + return fields[0], fields[1] +} + +func (ckey CommandKey) Validate(typeStr string) error { + if typeStr == "" { + typeStr = "ck" + } + if ckey == "" { + return fmt.Errorf("%s has empty commandkey", typeStr) + } + sessionId, cmdId := ckey.Split() + if sessionId == "" { + return fmt.Errorf("%s does not have sessionid", typeStr) + } + _, err := uuid.Parse(sessionId) + if err != nil { + return fmt.Errorf("%s has invalid sessionid '%s'", typeStr, sessionId) + } + if cmdId == "" { + return fmt.Errorf("%s does not have cmdid", typeStr) + } + _, err = uuid.Parse(cmdId) + if err != nil { + return fmt.Errorf("%s has invalid cmdid '%s'", typeStr, cmdId) + } + return nil +} + func GetHomeDir() string { homeVar := os.Getenv(HomeVarName) if homeVar == "" { @@ -57,10 +121,11 @@ func GetScHomeDir() (string, error) { return scHome, nil } -func GetCommandFileNames(sessionId string, cmdId string) (*CommandFileNames, error) { - if sessionId == "" || cmdId == "" { - return nil, fmt.Errorf("cannot get command-files when sessionid or cmdid is empty") +func GetCommandFileNames(ck CommandKey) (*CommandFileNames, error) { + if err := ck.Validate("ck"); err != nil { + return nil, fmt.Errorf("cannot get command files: %w", err) } + sessionId, cmdId := ck.Split() sdir, err := EnsureSessionDir(sessionId) if err != nil { return nil, err @@ -73,8 +138,8 @@ func GetCommandFileNames(sessionId string, cmdId string) (*CommandFileNames, err }, nil } -func MakeCommandFileNamesWithHome(scHome string, sessionId string, cmdId string) *CommandFileNames { - base := path.Join(scHome, SessionsDirBaseName, sessionId, cmdId) +func MakeCommandFileNamesWithHome(scHome string, ck CommandKey) *CommandFileNames { + base := path.Join(scHome, SessionsDirBaseName, ck.GetSessionId(), ck.GetCmdId()) return &CommandFileNames{ PtyOutFile: base + ".ptyout", StdinFifo: base + ".stdin", diff --git a/pkg/cmdtail/cmdtail.go b/pkg/cmdtail/cmdtail.go index 51cdbfd8a..517b35f8c 100644 --- a/pkg/cmdtail/cmdtail.go +++ b/pkg/cmdtail/cmdtail.go @@ -15,7 +15,6 @@ import ( "time" "github.com/fsnotify/fsnotify" - "github.com/google/uuid" "github.com/scripthaus-dev/mshell/pkg/base" "github.com/scripthaus-dev/mshell/pkg/packet" ) @@ -33,7 +32,7 @@ type TailPos struct { } type CmdWatchEntry struct { - CmdKey CmdKey + CmdKey base.CommandKey FilePtyLen int64 FileRunLen int64 Tails []TailPos @@ -73,20 +72,15 @@ func (pos TailPos) IsCurrent(entry CmdWatchEntry) bool { return pos.TailPtyPos >= entry.FilePtyLen && pos.TailRunPos >= entry.FileRunLen } -type CmdKey struct { - SessionId string - CmdId string -} - type Tailer struct { Lock *sync.Mutex - WatchList map[CmdKey]CmdWatchEntry + WatchList map[base.CommandKey]CmdWatchEntry ScHomeDir string Watcher *fsnotify.Watcher SendCh chan packet.PacketType } -func (t *Tailer) updateTailPos_nolock(cmdKey CmdKey, reqId string, pos TailPos) { +func (t *Tailer) updateTailPos_nolock(cmdKey base.CommandKey, reqId string, pos TailPos) { entry, found := t.WatchList[cmdKey] if !found { return @@ -95,7 +89,7 @@ func (t *Tailer) updateTailPos_nolock(cmdKey CmdKey, reqId string, pos TailPos) t.WatchList[cmdKey] = entry } -func (t *Tailer) removeTailPos_nolock(cmdKey CmdKey, reqId string) { +func (t *Tailer) removeTailPos_nolock(cmdKey base.CommandKey, reqId string) { entry, found := t.WatchList[cmdKey] if !found { return @@ -107,13 +101,13 @@ func (t *Tailer) removeTailPos_nolock(cmdKey CmdKey, reqId string) { } // delete from watchlist, remove watches - fileNames := base.MakeCommandFileNamesWithHome(t.ScHomeDir, cmdKey.SessionId, cmdKey.CmdId) + fileNames := base.MakeCommandFileNamesWithHome(t.ScHomeDir, cmdKey) delete(t.WatchList, cmdKey) t.Watcher.Remove(fileNames.PtyOutFile) t.Watcher.Remove(fileNames.RunnerOutFile) } -func (t *Tailer) updateEntrySizes_nolock(cmdKey CmdKey, ptyLen int64, runLen int64) { +func (t *Tailer) updateEntrySizes_nolock(cmdKey base.CommandKey, ptyLen int64, runLen int64) { entry, found := t.WatchList[cmdKey] if !found { return @@ -123,7 +117,7 @@ func (t *Tailer) updateEntrySizes_nolock(cmdKey CmdKey, ptyLen int64, runLen int t.WatchList[cmdKey] = entry } -func (t *Tailer) getEntryAndPos_nolock(cmdKey CmdKey, reqId string) (CmdWatchEntry, TailPos, bool) { +func (t *Tailer) getEntryAndPos_nolock(cmdKey base.CommandKey, reqId string) (CmdWatchEntry, TailPos, bool) { entry, found := t.WatchList[cmdKey] if !found { return CmdWatchEntry{}, TailPos{}, false @@ -142,7 +136,7 @@ func MakeTailer(sendCh chan packet.PacketType) (*Tailer, error) { } rtn := &Tailer{ Lock: &sync.Mutex{}, - WatchList: make(map[CmdKey]CmdWatchEntry), + WatchList: make(map[base.CommandKey]CmdWatchEntry), ScHomeDir: scHomeDir, SendCh: sendCh, } @@ -170,8 +164,7 @@ func (t *Tailer) readDataFromFile(fileName string, pos int64, maxBytes int) ([]b func (t *Tailer) makeCmdDataPacket(fileNames *base.CommandFileNames, entry CmdWatchEntry, pos TailPos) *packet.CmdDataPacketType { dataPacket := packet.MakeCmdDataPacket() dataPacket.ReqId = pos.ReqId - dataPacket.SessionId = entry.CmdKey.SessionId - dataPacket.CmdId = entry.CmdKey.CmdId + dataPacket.CK = entry.CmdKey dataPacket.PtyPos = pos.TailPtyPos dataPacket.RunPos = pos.TailRunPos if entry.FilePtyLen > pos.TailPtyPos { @@ -196,14 +189,14 @@ func (t *Tailer) makeCmdDataPacket(fileNames *base.CommandFileNames, entry CmdWa } // returns (data-packet, keepRunning) -func (t *Tailer) runSingleDataTransfer(key CmdKey, reqId string) (*packet.CmdDataPacketType, bool) { +func (t *Tailer) runSingleDataTransfer(key base.CommandKey, reqId string) (*packet.CmdDataPacketType, bool) { t.Lock.Lock() entry, pos, foundPos := t.getEntryAndPos_nolock(key, reqId) t.Lock.Unlock() if !foundPos { return nil, false } - fileNames := base.MakeCommandFileNamesWithHome(t.ScHomeDir, key.SessionId, key.CmdId) + fileNames := base.MakeCommandFileNamesWithHome(t.ScHomeDir, key) dataPacket := t.makeCmdDataPacket(fileNames, entry, pos) t.Lock.Lock() @@ -232,7 +225,7 @@ func (t *Tailer) runSingleDataTransfer(key CmdKey, reqId string) (*packet.CmdDat return dataPacket, pos.Running } -func (t *Tailer) checkRemoveNoFollow(cmdKey CmdKey, reqId string) { +func (t *Tailer) checkRemoveNoFollow(cmdKey base.CommandKey, reqId string) { t.Lock.Lock() defer t.Lock.Unlock() _, pos, foundPos := t.getEntryAndPos_nolock(cmdKey, reqId) @@ -244,7 +237,7 @@ func (t *Tailer) checkRemoveNoFollow(cmdKey CmdKey, reqId string) { } } -func (t *Tailer) RunDataTransfer(key CmdKey, reqId string) { +func (t *Tailer) RunDataTransfer(key base.CommandKey, reqId string) { for { dataPacket, keepRunning := t.runSingleDataTransfer(key, reqId) if dataPacket != nil { @@ -283,7 +276,7 @@ func (t *Tailer) updateFile(relFileName string) { t.SendCh <- packet.FmtMessagePacket("error trying to stat file '%s': %v", relFileName, err) return } - cmdKey := CmdKey{SessionId: m[1], CmdId: m[2]} + cmdKey := base.MakeCommandKey(m[1], m[2]) t.Lock.Lock() defer t.Lock.Unlock() entry, foundEntry := t.WatchList[cmdKey] @@ -336,7 +329,7 @@ func max(v1 int64, v2 int64) int64 { } func (entry *CmdWatchEntry) fillFilePos(scHomeDir string) { - fileNames := base.MakeCommandFileNamesWithHome(scHomeDir, entry.CmdKey.SessionId, entry.CmdKey.CmdId) + fileNames := base.MakeCommandFileNamesWithHome(scHomeDir, entry.CmdKey) ptyInfo, _ := os.Stat(fileNames.PtyOutFile) if ptyInfo != nil { entry.FilePtyLen = ptyInfo.Size() @@ -350,30 +343,24 @@ func (entry *CmdWatchEntry) fillFilePos(scHomeDir string) { func (t *Tailer) RemoveWatch(pk *packet.UntailCmdPacketType) { t.Lock.Lock() defer t.Lock.Unlock() - key := CmdKey{pk.SessionId, pk.CmdId} - t.removeTailPos_nolock(key, pk.ReqId) + t.removeTailPos_nolock(pk.CK, pk.ReqId) } func (t *Tailer) AddWatch(getPacket *packet.GetCmdPacketType) error { - _, err := uuid.Parse(getPacket.SessionId) - if err != nil { - return fmt.Errorf("getcmd, bad sessionid '%s': %w", getPacket.SessionId, err) - } - _, err = uuid.Parse(getPacket.CmdId) - if err != nil { - return fmt.Errorf("getcmd, bad cmdid '%s': %w", getPacket.CmdId, err) + if err := getPacket.CK.Validate("getcmd"); err != nil { + return err } if getPacket.ReqId == "" { return fmt.Errorf("getcmd, no reqid specified") } - fileNames := base.MakeCommandFileNamesWithHome(t.ScHomeDir, getPacket.SessionId, getPacket.CmdId) + fileNames := base.MakeCommandFileNamesWithHome(t.ScHomeDir, getPacket.CK) t.Lock.Lock() defer t.Lock.Unlock() - key := CmdKey{getPacket.SessionId, getPacket.CmdId} + key := getPacket.CK entry, foundEntry := t.WatchList[key] if !foundEntry { // add watches, initialize entry - err = t.Watcher.Add(fileNames.PtyOutFile) + err := t.Watcher.Add(fileNames.PtyOutFile) if err != nil { return err } diff --git a/pkg/mpio/mpio.go b/pkg/mpio/mpio.go index a1b93bd4f..ef56355f2 100644 --- a/pkg/mpio/mpio.go +++ b/pkg/mpio/mpio.go @@ -12,6 +12,7 @@ import ( "os" "sync" + "github.com/scripthaus-dev/mshell/pkg/base" "github.com/scripthaus-dev/mshell/pkg/packet" ) @@ -21,8 +22,7 @@ const MaxSingleWriteSize = 4 * 1024 type Multiplexer struct { Lock *sync.Mutex - SessionId string - CmdId string + CK base.CommandKey FdReaders map[int]*FdReader // synchronized FdWriters map[int]*FdWriter // synchronized CloseAfterStart []*os.File // synchronized @@ -34,11 +34,10 @@ type Multiplexer struct { Debug bool } -func MakeMultiplexer(sessionId string, cmdId string) *Multiplexer { +func MakeMultiplexer(ck base.CommandKey) *Multiplexer { return &Multiplexer{ Lock: &sync.Mutex{}, - SessionId: sessionId, - CmdId: cmdId, + CK: ck, FdReaders: make(map[int]*FdReader), FdWriters: make(map[int]*FdWriter), } @@ -126,8 +125,7 @@ func (m *Multiplexer) MakeRawFdWriter(fdNum int, fd *os.File, shouldClose bool) func (m *Multiplexer) makeDataAckPacket(fdNum int, ackLen int, err error) *packet.DataAckPacketType { ack := packet.MakeDataAckPacket() - ack.SessionId = m.SessionId - ack.CmdId = m.CmdId + ack.CK = m.CK ack.FdNum = fdNum ack.AckLen = ackLen if err != nil { @@ -138,8 +136,7 @@ func (m *Multiplexer) makeDataAckPacket(fdNum int, ackLen int, err error) *packe func (m *Multiplexer) makeDataPacket(fdNum int, data []byte, err error) *packet.DataPacketType { pk := packet.MakeDataPacket() - pk.SessionId = m.SessionId - pk.CmdId = m.CmdId + pk.CK = m.CK pk.FdNum = fdNum pk.Data64 = base64.StdEncoding.EncodeToString(data) if err != nil { diff --git a/pkg/packet/packet.go b/pkg/packet/packet.go index 3c6d9dcec..d996ae9a0 100644 --- a/pkg/packet/packet.go +++ b/pkg/packet/packet.go @@ -16,6 +16,8 @@ import ( "strconv" "strings" "sync" + + "github.com/scripthaus-dev/mshell/pkg/base" ) // remote: init, run, ping, data, cmdstart, cmddone @@ -80,26 +82,29 @@ func MakePacket(packetType string) (PacketType, error) { } type CmdDataPacketType struct { - Type string `json:"type"` - ReqId string `json:"reqid"` - SessionId string `json:"sessionid"` - CmdId string `json:"cmdid"` - PtyPos int64 `json:"ptypos"` - PtyLen int64 `json:"ptylen"` - RunPos int64 `json:"runpos"` - RunLen int64 `json:"runlen"` - PtyData string `json:"ptydata"` - PtyDataLen int `json:"ptydatalen"` - RunData string `json:"rundata"` - RunDataLen int `json:"rundatalen"` - Error string `json:"error"` - NotFound bool `json:"notfound,omitempty"` + Type string `json:"type"` + ReqId string `json:"reqid"` + CK base.CommandKey `json:"ck"` + PtyPos int64 `json:"ptypos"` + PtyLen int64 `json:"ptylen"` + RunPos int64 `json:"runpos"` + RunLen int64 `json:"runlen"` + PtyData string `json:"ptydata"` + PtyDataLen int `json:"ptydatalen"` + RunData string `json:"rundata"` + RunDataLen int `json:"rundatalen"` + Error string `json:"error"` + NotFound bool `json:"notfound,omitempty"` } func (*CmdDataPacketType) GetType() string { return CmdDataPacketStr } +func (p *CmdDataPacketType) GetCK() base.CommandKey { + return p.CK +} + func MakeCmdDataPacket() *CmdDataPacketType { return &CmdDataPacketType{Type: CmdDataPacketStr} } @@ -117,19 +122,22 @@ func MakePingPacket() *PingPacketType { } type DataPacketType struct { - Type string `json:"type"` - SessionId string `json:"sessionid,omitempty"` - CmdId string `json:"cmdid,omitempty"` - FdNum int `json:"fdnum"` - Data64 string `json:"data64"` // base64 encoded - Eof bool `json:"eof,omitempty"` - Error string `json:"error,omitempty"` + Type string `json:"type"` + CK base.CommandKey `json:"ck"` + FdNum int `json:"fdnum"` + Data64 string `json:"data64"` // base64 encoded + Eof bool `json:"eof,omitempty"` + Error string `json:"error,omitempty"` } func (*DataPacketType) GetType() string { return DataPacketStr } +func (p *DataPacketType) GetCK() base.CommandKey { + return p.CK +} + func B64DecodedLen(b64 string) int { if len(b64) < 4 { return 0 // we use padded strings, so < 4 is always 0 @@ -161,18 +169,21 @@ func MakeDataPacket() *DataPacketType { } type DataAckPacketType struct { - Type string `json:"type"` - SessionId string `json:"sessionid,omitempty"` - CmdId string `json:"cmdid,omitempty"` - FdNum int `json:"fdnum"` - AckLen int `json:"acklen"` - Error string `json:"error,omitempty"` + Type string `json:"type"` + CK base.CommandKey `json:"ck"` + FdNum int `json:"fdnum"` + AckLen int `json:"acklen"` + Error string `json:"error,omitempty"` } func (*DataAckPacketType) GetType() string { return DataAckPacketStr } +func (p *DataAckPacketType) GetCK() base.CommandKey { + return p.CK +} + func (p *DataAckPacketType) String() string { errStr := "" if p.Error != "" { @@ -189,52 +200,61 @@ func MakeDataAckPacket() *DataAckPacketType { // SigNum gets sent to process via a signal // WinSize, if set, will run TIOCSWINSZ to set size, and then send SIGWINCH type InputPacketType struct { - Type string `json:"type"` - SessionId string `json:"sessionid"` - CmdId string `json:"cmdid"` - InputData string `json:"inputdata"` - SigNum int `json:"signum,omitempty"` - WinSizeRows int `json:"winsizerows"` - WinSizeCols int `json:"winsizecols"` + Type string `json:"type"` + CK base.CommandKey `json:"ck"` + InputData string `json:"inputdata"` + SigNum int `json:"signum,omitempty"` + WinSizeRows int `json:"winsizerows"` + WinSizeCols int `json:"winsizecols"` } func (*InputPacketType) GetType() string { return InputPacketStr } +func (p *InputPacketType) GetCK() base.CommandKey { + return p.CK +} + func MakeInputPacket() *InputPacketType { return &InputPacketType{Type: InputPacketStr} } type UntailCmdPacketType struct { - Type string `json:"type"` - ReqId string `json:"reqid"` - SessionId string `json:"sessionid"` - CmdId string `json:"cmdid"` + Type string `json:"type"` + ReqId string `json:"reqid"` + CK base.CommandKey `json:"ck"` } func (*UntailCmdPacketType) GetType() string { return UntailCmdPacketStr } +func (p *UntailCmdPacketType) GetCK() base.CommandKey { + return p.CK +} + func MakeUntailCmdPacket() *UntailCmdPacketType { return &UntailCmdPacketType{Type: UntailCmdPacketStr} } type GetCmdPacketType struct { - Type string `json:"type"` - ReqId string `json:"reqid"` - SessionId string `json:"sessionid"` - CmdId string `json:"cmdid"` - PtyPos int64 `json:"ptypos"` - RunPos int64 `json:"runpos"` - Tail bool `json:"tail,omitempty"` + Type string `json:"type"` + ReqId string `json:"reqid"` + CK base.CommandKey `json:"ck"` + PtyPos int64 `json:"ptypos"` + RunPos int64 `json:"runpos"` + Tail bool `json:"tail,omitempty"` } func (*GetCmdPacketType) GetType() string { return GetCmdPacketStr } +func (p *GetCmdPacketType) GetCK() base.CommandKey { + return p.CK +} + func MakeGetCmdPacket() *GetCmdPacketType { return &GetCmdPacketType{Type: GetCmdPacketStr} } @@ -346,35 +366,41 @@ func MakeDonePacket() *DonePacketType { } type CmdDonePacketType struct { - Type string `json:"type"` - Ts int64 `json:"ts"` - SessionId string `json:"sessionid,omitempty"` - CmdId string `json:"cmdid,omitempty"` - ExitCode int `json:"exitcode"` - DurationMs int64 `json:"durationms"` + Type string `json:"type"` + Ts int64 `json:"ts"` + CK base.CommandKey `json:"ck"` + ExitCode int `json:"exitcode"` + DurationMs int64 `json:"durationms"` } func (*CmdDonePacketType) GetType() string { return CmdDonePacketStr } +func (p *CmdDonePacketType) GetCK() base.CommandKey { + return p.CK +} + func MakeCmdDonePacket() *CmdDonePacketType { return &CmdDonePacketType{Type: CmdDonePacketStr} } type CmdStartPacketType struct { - Type string `json:"type"` - Ts int64 `json:"ts"` - SessionId string `json:"sessionid,omitempty"` - CmdId string `json:"cmdid,omitempty"` - Pid int `json:"pid"` - MShellPid int `json:"mshellpid"` + Type string `json:"type"` + Ts int64 `json:"ts"` + CK base.CommandKey `json:"ck"` + Pid int `json:"pid"` + MShellPid int `json:"mshellpid"` } func (*CmdStartPacketType) GetType() string { return CmdStartPacketStr } +func (p *CmdStartPacketType) GetCK() base.CommandKey { + return p.CK +} + func MakeCmdStartPacket() *CmdStartPacketType { return &CmdStartPacketType{Type: CmdStartPacketStr} } @@ -393,21 +419,24 @@ type RemoteFd struct { } type RunPacketType struct { - Type string `json:"type"` - SessionId string `json:"sessionid,omitempty"` - CmdId string `json:"cmdid,omitempty"` - Command string `json:"command"` - Cwd string `json:"cwd,omitempty"` - Env map[string]string `json:"env,omitempty"` - TermSize *TermSize `json:"termsize,omitempty"` - Fds []RemoteFd `json:"fds,omitempty"` - Detached bool `json:"detached,omitempty"` + Type string `json:"type"` + CK base.CommandKey `json:"ck"` + Command string `json:"command"` + Cwd string `json:"cwd,omitempty"` + Env map[string]string `json:"env,omitempty"` + TermSize *TermSize `json:"termsize,omitempty"` + Fds []RemoteFd `json:"fds,omitempty"` + Detached bool `json:"detached,omitempty"` } func (*RunPacketType) GetType() string { return RunPacketStr } +func (p *RunPacketType) GetCK() base.CommandKey { + return p.CK +} + func MakeRunPacket() *RunPacketType { return &RunPacketType{Type: RunPacketStr} } @@ -417,9 +446,9 @@ type BarePacketType struct { } type ErrorPacketType struct { - Id string `json:"id,omitempty"` - Type string `json:"type"` - Error string `json:"error"` + CK base.CommandKey `json:"ck,omitempty"` + Type string `json:"type"` + Error string `json:"error"` } func (et *ErrorPacketType) GetType() string { @@ -430,8 +459,8 @@ func MakeErrorPacket(errorStr string) *ErrorPacketType { return &ErrorPacketType{Type: ErrorPacketStr, Error: errorStr} } -func MakeIdErrorPacket(id string, errorStr string) *ErrorPacketType { - return &ErrorPacketType{Type: ErrorPacketStr, Id: id, Error: errorStr} +func MakeCKErrorPacket(ck base.CommandKey, errorStr string) *ErrorPacketType { + return &ErrorPacketType{Type: ErrorPacketStr, CK: ck, Error: errorStr} } type PacketType interface { @@ -450,6 +479,11 @@ type RpcPacketType interface { GetPacketId() string } +type CommandPacketType interface { + GetType() string + GetCK() base.CommandKey +} + func ParseJsonPacket(jsonBuf []byte) (PacketType, error) { var bareCmd BarePacketType err := json.Unmarshal(jsonBuf, &bareCmd) diff --git a/pkg/shexec/shexec.go b/pkg/shexec/shexec.go index 73afa925c..5ea68d2b3 100644 --- a/pkg/shexec/shexec.go +++ b/pkg/shexec/shexec.go @@ -17,7 +17,6 @@ import ( "time" "github.com/creack/pty" - "github.com/google/uuid" "github.com/scripthaus-dev/mshell/pkg/base" "github.com/scripthaus-dev/mshell/pkg/mpio" "github.com/scripthaus-dev/mshell/pkg/packet" @@ -39,21 +38,19 @@ const RemoteSudoPasswordCommandFmt = `cat /dev/fd/%d | sudo -S -C %d bash -c "ec type ShExecType struct { Lock *sync.Mutex StartTs time.Time - SessionId string - CmdId string + CK base.CommandKey FileNames *base.CommandFileNames Cmd *exec.Cmd CmdPty *os.File Multiplexer *mpio.Multiplexer } -func MakeShExec(sessionId string, cmdId string) *ShExecType { +func MakeShExec(ck base.CommandKey) *ShExecType { return &ShExecType{ Lock: &sync.Mutex{}, StartTs: time.Now(), - SessionId: sessionId, - CmdId: cmdId, - Multiplexer: mpio.MakeMultiplexer(sessionId, cmdId), + CK: ck, + Multiplexer: mpio.MakeMultiplexer(ck), } } @@ -67,8 +64,7 @@ func (c *ShExecType) Close() { func (c *ShExecType) MakeCmdStartPacket() *packet.CmdStartPacketType { startPacket := packet.MakeCmdStartPacket() startPacket.Ts = time.Now().UnixMilli() - startPacket.SessionId = c.SessionId - startPacket.CmdId = c.CmdId + startPacket.CK = c.CK startPacket.Pid = c.Cmd.Process.Pid startPacket.MShellPid = os.Getpid() return startPacket @@ -129,12 +125,12 @@ func MakeExecCmd(pk *packet.RunPacketType, cmdTty *os.File) *exec.Cmd { return ecmd } -func MakeRunnerExec(cmdId string) (*exec.Cmd, error) { +func MakeRunnerExec(ck base.CommandKey) (*exec.Cmd, error) { msPath, err := base.GetMShellPath() if err != nil { return nil, err } - ecmd := exec.Command(msPath, cmdId) + ecmd := exec.Command(msPath, string(ck)) return ecmd, nil } @@ -165,19 +161,9 @@ func ValidateRunPacket(pk *packet.RunPacketType) error { return fmt.Errorf("run packet has wrong type: %s", pk.Type) } if pk.Detached { - if pk.SessionId == "" { - return fmt.Errorf("run packet does not have sessionid") - } - _, err := uuid.Parse(pk.SessionId) + err := pk.CK.Validate("run packet") if err != nil { - return fmt.Errorf("invalid sessionid '%s' for command", pk.SessionId) - } - if pk.CmdId == "" { - return fmt.Errorf("run packet does not have cmdid") - } - _, err = uuid.Parse(pk.CmdId) - if err != nil { - return fmt.Errorf("invalid cmdid '%s' for command", pk.CmdId) + return err } } if pk.Cwd != "" { @@ -337,7 +323,7 @@ func RunClientSSHCommandAndWait(opts *ClientOpts) (*packet.CmdDonePacketType, er if err != nil { return nil, err } - cmd := MakeShExec("", "") + cmd := MakeShExec("") var fullSshOpts []string fullSshOpts = append(fullSshOpts, opts.SSHOpts...) fullSshOpts = append(fullSshOpts, SSHRemoteCommand) @@ -430,7 +416,7 @@ func (cmd *ShExecType) RunRemoteIOAndWait(packetCh chan packet.PacketType, sende } func runCommandSimple(pk *packet.RunPacketType, sender *packet.PacketSender) (*ShExecType, error) { - cmd := MakeShExec(pk.SessionId, pk.CmdId) + cmd := MakeShExec(pk.CK) cmd.Cmd = exec.Command("bash", "-c", pk.Command) UpdateCmdEnv(cmd.Cmd, pk.Env) if pk.Cwd != "" { @@ -491,7 +477,7 @@ func runCommandSimple(pk *packet.RunPacketType, sender *packet.PacketSender) (*S } func runCommandDetached(pk *packet.RunPacketType, sender *packet.PacketSender) (*ShExecType, error) { - fileNames, err := base.GetCommandFileNames(pk.SessionId, pk.CmdId) + fileNames, err := base.GetCommandFileNames(pk.CK) if err != nil { return nil, err } @@ -499,7 +485,7 @@ func runCommandDetached(pk *packet.RunPacketType, sender *packet.PacketSender) ( if err == nil { // non-nil error will be caught by regular OpenFile below // must have size 0 if ptyOutInfo.Size() != 0 { - return nil, fmt.Errorf("cmdid '%s' was already used (ptyout len=%d)", pk.CmdId, ptyOutInfo.Size()) + return nil, fmt.Errorf("cmdkey '%s' was already used (ptyout len=%d)", pk.CK, ptyOutInfo.Size()) } } cmdPty, cmdTty, err := pty.Open() @@ -510,7 +496,7 @@ func runCommandDetached(pk *packet.RunPacketType, sender *packet.PacketSender) ( defer func() { cmdTty.Close() }() - rtn := MakeShExec(pk.SessionId, pk.CmdId) + rtn := MakeShExec(pk.CK) ecmd := MakeExecCmd(pk, cmdTty) err = ecmd.Start() if err != nil { @@ -558,8 +544,7 @@ func (c *ShExecType) WaitForCommand() *packet.CmdDonePacketType { exitCode := GetExitCode(exitErr) donePacket := packet.MakeCmdDonePacket() donePacket.Ts = endTs.UnixMilli() - donePacket.SessionId = c.SessionId - donePacket.CmdId = c.CmdId + donePacket.CK = c.CK donePacket.ExitCode = exitCode donePacket.DurationMs = int64(cmdDuration / time.Millisecond) if c.FileNames != nil {