combine sessionid and cmdid into one field ck (commandkey)

This commit is contained in:
sawka 2022-06-27 12:03:47 -07:00
parent 1ea8393844
commit 2a6791bcd6
6 changed files with 236 additions and 172 deletions

View File

@ -15,7 +15,6 @@ import (
"syscall"
"time"
"github.com/google/uuid"
"github.com/scripthaus-dev/mshell/pkg/base"
"github.com/scripthaus-dev/mshell/pkg/cmdtail"
"github.com/scripthaus-dev/mshell/pkg/packet"
@ -38,7 +37,7 @@ func setupSingleSignals(cmd *shexec.ShExecType) {
}()
}
func doSingle(cmdId string) {
func doSingle(ck base.CommandKey) {
packetCh := packet.PacketParser(os.Stdin)
sender := packet.MakePacketSender(os.Stdout)
var runPacket *packet.RunPacketType
@ -57,11 +56,11 @@ func doSingle(cmdId string) {
sender.SendErrorPacket("did not receive a 'run' packet")
return
}
if runPacket.CmdId == "" {
runPacket.CmdId = cmdId
if runPacket.CK.IsEmpty() {
runPacket.CK = ck
}
if runPacket.CmdId != cmdId {
sender.SendErrorPacket(fmt.Sprintf("run packet cmdid[%s] did not match arg[%s]", runPacket.CmdId, cmdId))
if runPacket.CK != ck {
sender.SendErrorPacket(fmt.Sprintf("run packet cmdid[%s] did not match arg[%s]", runPacket.CK, ck))
return
}
cmd, err := shexec.RunCommand(runPacket, sender)
@ -79,39 +78,36 @@ func doSingle(cmdId string) {
}
func doMainRun(pk *packet.RunPacketType, sender *packet.PacketSender) {
if pk.CmdId == "" {
pk.CmdId = uuid.New().String()
}
err := shexec.ValidateRunPacket(pk)
if err != nil {
sender.SendPacket(packet.MakeIdErrorPacket(pk.CmdId, fmt.Sprintf("invalid run packet: %v", err)))
sender.SendPacket(packet.MakeCKErrorPacket(pk.CK, fmt.Sprintf("invalid run packet: %v", err)))
return
}
fileNames, err := base.GetCommandFileNames(pk.SessionId, pk.CmdId)
fileNames, err := base.GetCommandFileNames(pk.CK)
if err != nil {
sender.SendPacket(packet.MakeIdErrorPacket(pk.CmdId, fmt.Sprintf("cannot get command file names: %v", err)))
sender.SendPacket(packet.MakeCKErrorPacket(pk.CK, fmt.Sprintf("cannot get command file names: %v", err)))
return
}
cmd, err := shexec.MakeRunnerExec(pk.CmdId)
cmd, err := shexec.MakeRunnerExec(pk.CK)
if err != nil {
sender.SendPacket(packet.MakeIdErrorPacket(pk.CmdId, fmt.Sprintf("cannot make mshell command: %v", err)))
sender.SendPacket(packet.MakeCKErrorPacket(pk.CK, fmt.Sprintf("cannot make mshell command: %v", err)))
return
}
cmdStdin, err := cmd.StdinPipe()
if err != nil {
sender.SendPacket(packet.MakeIdErrorPacket(pk.CmdId, fmt.Sprintf("cannot pipe stdin to command: %v", err)))
sender.SendPacket(packet.MakeCKErrorPacket(pk.CK, fmt.Sprintf("cannot pipe stdin to command: %v", err)))
return
}
// touch ptyout file (should exist for tailer to work correctly)
ptyOutFd, err := os.OpenFile(fileNames.PtyOutFile, os.O_CREATE|os.O_TRUNC|os.O_APPEND|os.O_WRONLY, 0600)
if err != nil {
sender.SendPacket(packet.MakeIdErrorPacket(pk.CmdId, fmt.Sprintf("cannot open pty out file '%s': %v", fileNames.PtyOutFile, err)))
sender.SendPacket(packet.MakeCKErrorPacket(pk.CK, fmt.Sprintf("cannot open pty out file '%s': %v", fileNames.PtyOutFile, err)))
return
}
ptyOutFd.Close() // just opened to create the file, can close right after
runnerOutFd, err := os.OpenFile(fileNames.RunnerOutFile, os.O_CREATE|os.O_TRUNC|os.O_APPEND|os.O_WRONLY, 0600)
if err != nil {
sender.SendPacket(packet.MakeIdErrorPacket(pk.CmdId, fmt.Sprintf("cannot open runner out file '%s': %v", fileNames.RunnerOutFile, err)))
sender.SendPacket(packet.MakeCKErrorPacket(pk.CK, fmt.Sprintf("cannot open runner out file '%s': %v", fileNames.RunnerOutFile, err)))
return
}
defer runnerOutFd.Close()
@ -119,13 +115,13 @@ func doMainRun(pk *packet.RunPacketType, sender *packet.PacketSender) {
cmd.Stderr = runnerOutFd
err = cmd.Start()
if err != nil {
sender.SendPacket(packet.MakeIdErrorPacket(pk.CmdId, fmt.Sprintf("error starting command: %v", err)))
sender.SendPacket(packet.MakeCKErrorPacket(pk.CK, fmt.Sprintf("error starting command: %v", err)))
return
}
go func() {
err = packet.SendPacket(cmdStdin, pk)
if err != nil {
sender.SendPacket(packet.MakeIdErrorPacket(pk.CmdId, fmt.Sprintf("error sending forked runner command: %v", err)))
sender.SendPacket(packet.MakeCKErrorPacket(pk.CK, fmt.Sprintf("error sending forked runner command: %v", err)))
return
}
cmdStdin.Close()
@ -451,12 +447,12 @@ func main() {
}
if len(os.Args) >= 2 {
cmdId, err := uuid.Parse(os.Args[1])
if err != nil {
packet.SendErrorPacket(os.Stdout, fmt.Sprintf("invalid non-cmdid passed to mshell", err))
ck := base.CommandKey(os.Args[1])
if err := ck.Validate("mshell arg"); err != nil {
packet.SendErrorPacket(os.Stdout, err.Error())
return
}
doSingle(cmdId.String())
doSingle(ck)
time.Sleep(100 * time.Millisecond)
return
} else {

View File

@ -15,6 +15,8 @@ import (
"path"
"path/filepath"
"strings"
"github.com/google/uuid"
)
const DefaultMShellPath = "mshell"
@ -37,6 +39,68 @@ type CommandFileNames struct {
RunnerOutFile string
}
type CommandKey string
func MakeCommandKey(sessionId string, cmdId string) CommandKey {
if sessionId == "" && cmdId == "" {
return CommandKey("")
}
return CommandKey(fmt.Sprintf("%s/%s", sessionId, cmdId))
}
func (ckey CommandKey) IsEmpty() bool {
return string(ckey) == ""
}
func (ckey CommandKey) GetSessionId() string {
slashIdx := strings.Index(string(ckey), "/")
if slashIdx == -1 {
return ""
}
return string(ckey[0:slashIdx])
}
func (ckey CommandKey) GetCmdId() string {
slashIdx := strings.Index(string(ckey), "/")
if slashIdx == -1 {
return ""
}
return string(ckey[slashIdx+1:])
}
func (ckey CommandKey) Split() (string, string) {
fields := strings.SplitN(string(ckey), "/", 2)
if len(fields) < 2 {
return "", ""
}
return fields[0], fields[1]
}
func (ckey CommandKey) Validate(typeStr string) error {
if typeStr == "" {
typeStr = "ck"
}
if ckey == "" {
return fmt.Errorf("%s has empty commandkey", typeStr)
}
sessionId, cmdId := ckey.Split()
if sessionId == "" {
return fmt.Errorf("%s does not have sessionid", typeStr)
}
_, err := uuid.Parse(sessionId)
if err != nil {
return fmt.Errorf("%s has invalid sessionid '%s'", typeStr, sessionId)
}
if cmdId == "" {
return fmt.Errorf("%s does not have cmdid", typeStr)
}
_, err = uuid.Parse(cmdId)
if err != nil {
return fmt.Errorf("%s has invalid cmdid '%s'", typeStr, cmdId)
}
return nil
}
func GetHomeDir() string {
homeVar := os.Getenv(HomeVarName)
if homeVar == "" {
@ -57,10 +121,11 @@ func GetScHomeDir() (string, error) {
return scHome, nil
}
func GetCommandFileNames(sessionId string, cmdId string) (*CommandFileNames, error) {
if sessionId == "" || cmdId == "" {
return nil, fmt.Errorf("cannot get command-files when sessionid or cmdid is empty")
func GetCommandFileNames(ck CommandKey) (*CommandFileNames, error) {
if err := ck.Validate("ck"); err != nil {
return nil, fmt.Errorf("cannot get command files: %w", err)
}
sessionId, cmdId := ck.Split()
sdir, err := EnsureSessionDir(sessionId)
if err != nil {
return nil, err
@ -73,8 +138,8 @@ func GetCommandFileNames(sessionId string, cmdId string) (*CommandFileNames, err
}, nil
}
func MakeCommandFileNamesWithHome(scHome string, sessionId string, cmdId string) *CommandFileNames {
base := path.Join(scHome, SessionsDirBaseName, sessionId, cmdId)
func MakeCommandFileNamesWithHome(scHome string, ck CommandKey) *CommandFileNames {
base := path.Join(scHome, SessionsDirBaseName, ck.GetSessionId(), ck.GetCmdId())
return &CommandFileNames{
PtyOutFile: base + ".ptyout",
StdinFifo: base + ".stdin",

View File

@ -15,7 +15,6 @@ import (
"time"
"github.com/fsnotify/fsnotify"
"github.com/google/uuid"
"github.com/scripthaus-dev/mshell/pkg/base"
"github.com/scripthaus-dev/mshell/pkg/packet"
)
@ -33,7 +32,7 @@ type TailPos struct {
}
type CmdWatchEntry struct {
CmdKey CmdKey
CmdKey base.CommandKey
FilePtyLen int64
FileRunLen int64
Tails []TailPos
@ -73,20 +72,15 @@ func (pos TailPos) IsCurrent(entry CmdWatchEntry) bool {
return pos.TailPtyPos >= entry.FilePtyLen && pos.TailRunPos >= entry.FileRunLen
}
type CmdKey struct {
SessionId string
CmdId string
}
type Tailer struct {
Lock *sync.Mutex
WatchList map[CmdKey]CmdWatchEntry
WatchList map[base.CommandKey]CmdWatchEntry
ScHomeDir string
Watcher *fsnotify.Watcher
SendCh chan packet.PacketType
}
func (t *Tailer) updateTailPos_nolock(cmdKey CmdKey, reqId string, pos TailPos) {
func (t *Tailer) updateTailPos_nolock(cmdKey base.CommandKey, reqId string, pos TailPos) {
entry, found := t.WatchList[cmdKey]
if !found {
return
@ -95,7 +89,7 @@ func (t *Tailer) updateTailPos_nolock(cmdKey CmdKey, reqId string, pos TailPos)
t.WatchList[cmdKey] = entry
}
func (t *Tailer) removeTailPos_nolock(cmdKey CmdKey, reqId string) {
func (t *Tailer) removeTailPos_nolock(cmdKey base.CommandKey, reqId string) {
entry, found := t.WatchList[cmdKey]
if !found {
return
@ -107,13 +101,13 @@ func (t *Tailer) removeTailPos_nolock(cmdKey CmdKey, reqId string) {
}
// delete from watchlist, remove watches
fileNames := base.MakeCommandFileNamesWithHome(t.ScHomeDir, cmdKey.SessionId, cmdKey.CmdId)
fileNames := base.MakeCommandFileNamesWithHome(t.ScHomeDir, cmdKey)
delete(t.WatchList, cmdKey)
t.Watcher.Remove(fileNames.PtyOutFile)
t.Watcher.Remove(fileNames.RunnerOutFile)
}
func (t *Tailer) updateEntrySizes_nolock(cmdKey CmdKey, ptyLen int64, runLen int64) {
func (t *Tailer) updateEntrySizes_nolock(cmdKey base.CommandKey, ptyLen int64, runLen int64) {
entry, found := t.WatchList[cmdKey]
if !found {
return
@ -123,7 +117,7 @@ func (t *Tailer) updateEntrySizes_nolock(cmdKey CmdKey, ptyLen int64, runLen int
t.WatchList[cmdKey] = entry
}
func (t *Tailer) getEntryAndPos_nolock(cmdKey CmdKey, reqId string) (CmdWatchEntry, TailPos, bool) {
func (t *Tailer) getEntryAndPos_nolock(cmdKey base.CommandKey, reqId string) (CmdWatchEntry, TailPos, bool) {
entry, found := t.WatchList[cmdKey]
if !found {
return CmdWatchEntry{}, TailPos{}, false
@ -142,7 +136,7 @@ func MakeTailer(sendCh chan packet.PacketType) (*Tailer, error) {
}
rtn := &Tailer{
Lock: &sync.Mutex{},
WatchList: make(map[CmdKey]CmdWatchEntry),
WatchList: make(map[base.CommandKey]CmdWatchEntry),
ScHomeDir: scHomeDir,
SendCh: sendCh,
}
@ -170,8 +164,7 @@ func (t *Tailer) readDataFromFile(fileName string, pos int64, maxBytes int) ([]b
func (t *Tailer) makeCmdDataPacket(fileNames *base.CommandFileNames, entry CmdWatchEntry, pos TailPos) *packet.CmdDataPacketType {
dataPacket := packet.MakeCmdDataPacket()
dataPacket.ReqId = pos.ReqId
dataPacket.SessionId = entry.CmdKey.SessionId
dataPacket.CmdId = entry.CmdKey.CmdId
dataPacket.CK = entry.CmdKey
dataPacket.PtyPos = pos.TailPtyPos
dataPacket.RunPos = pos.TailRunPos
if entry.FilePtyLen > pos.TailPtyPos {
@ -196,14 +189,14 @@ func (t *Tailer) makeCmdDataPacket(fileNames *base.CommandFileNames, entry CmdWa
}
// returns (data-packet, keepRunning)
func (t *Tailer) runSingleDataTransfer(key CmdKey, reqId string) (*packet.CmdDataPacketType, bool) {
func (t *Tailer) runSingleDataTransfer(key base.CommandKey, reqId string) (*packet.CmdDataPacketType, bool) {
t.Lock.Lock()
entry, pos, foundPos := t.getEntryAndPos_nolock(key, reqId)
t.Lock.Unlock()
if !foundPos {
return nil, false
}
fileNames := base.MakeCommandFileNamesWithHome(t.ScHomeDir, key.SessionId, key.CmdId)
fileNames := base.MakeCommandFileNamesWithHome(t.ScHomeDir, key)
dataPacket := t.makeCmdDataPacket(fileNames, entry, pos)
t.Lock.Lock()
@ -232,7 +225,7 @@ func (t *Tailer) runSingleDataTransfer(key CmdKey, reqId string) (*packet.CmdDat
return dataPacket, pos.Running
}
func (t *Tailer) checkRemoveNoFollow(cmdKey CmdKey, reqId string) {
func (t *Tailer) checkRemoveNoFollow(cmdKey base.CommandKey, reqId string) {
t.Lock.Lock()
defer t.Lock.Unlock()
_, pos, foundPos := t.getEntryAndPos_nolock(cmdKey, reqId)
@ -244,7 +237,7 @@ func (t *Tailer) checkRemoveNoFollow(cmdKey CmdKey, reqId string) {
}
}
func (t *Tailer) RunDataTransfer(key CmdKey, reqId string) {
func (t *Tailer) RunDataTransfer(key base.CommandKey, reqId string) {
for {
dataPacket, keepRunning := t.runSingleDataTransfer(key, reqId)
if dataPacket != nil {
@ -283,7 +276,7 @@ func (t *Tailer) updateFile(relFileName string) {
t.SendCh <- packet.FmtMessagePacket("error trying to stat file '%s': %v", relFileName, err)
return
}
cmdKey := CmdKey{SessionId: m[1], CmdId: m[2]}
cmdKey := base.MakeCommandKey(m[1], m[2])
t.Lock.Lock()
defer t.Lock.Unlock()
entry, foundEntry := t.WatchList[cmdKey]
@ -336,7 +329,7 @@ func max(v1 int64, v2 int64) int64 {
}
func (entry *CmdWatchEntry) fillFilePos(scHomeDir string) {
fileNames := base.MakeCommandFileNamesWithHome(scHomeDir, entry.CmdKey.SessionId, entry.CmdKey.CmdId)
fileNames := base.MakeCommandFileNamesWithHome(scHomeDir, entry.CmdKey)
ptyInfo, _ := os.Stat(fileNames.PtyOutFile)
if ptyInfo != nil {
entry.FilePtyLen = ptyInfo.Size()
@ -350,30 +343,24 @@ func (entry *CmdWatchEntry) fillFilePos(scHomeDir string) {
func (t *Tailer) RemoveWatch(pk *packet.UntailCmdPacketType) {
t.Lock.Lock()
defer t.Lock.Unlock()
key := CmdKey{pk.SessionId, pk.CmdId}
t.removeTailPos_nolock(key, pk.ReqId)
t.removeTailPos_nolock(pk.CK, pk.ReqId)
}
func (t *Tailer) AddWatch(getPacket *packet.GetCmdPacketType) error {
_, err := uuid.Parse(getPacket.SessionId)
if err != nil {
return fmt.Errorf("getcmd, bad sessionid '%s': %w", getPacket.SessionId, err)
}
_, err = uuid.Parse(getPacket.CmdId)
if err != nil {
return fmt.Errorf("getcmd, bad cmdid '%s': %w", getPacket.CmdId, err)
if err := getPacket.CK.Validate("getcmd"); err != nil {
return err
}
if getPacket.ReqId == "" {
return fmt.Errorf("getcmd, no reqid specified")
}
fileNames := base.MakeCommandFileNamesWithHome(t.ScHomeDir, getPacket.SessionId, getPacket.CmdId)
fileNames := base.MakeCommandFileNamesWithHome(t.ScHomeDir, getPacket.CK)
t.Lock.Lock()
defer t.Lock.Unlock()
key := CmdKey{getPacket.SessionId, getPacket.CmdId}
key := getPacket.CK
entry, foundEntry := t.WatchList[key]
if !foundEntry {
// add watches, initialize entry
err = t.Watcher.Add(fileNames.PtyOutFile)
err := t.Watcher.Add(fileNames.PtyOutFile)
if err != nil {
return err
}

View File

@ -12,6 +12,7 @@ import (
"os"
"sync"
"github.com/scripthaus-dev/mshell/pkg/base"
"github.com/scripthaus-dev/mshell/pkg/packet"
)
@ -21,8 +22,7 @@ const MaxSingleWriteSize = 4 * 1024
type Multiplexer struct {
Lock *sync.Mutex
SessionId string
CmdId string
CK base.CommandKey
FdReaders map[int]*FdReader // synchronized
FdWriters map[int]*FdWriter // synchronized
CloseAfterStart []*os.File // synchronized
@ -34,11 +34,10 @@ type Multiplexer struct {
Debug bool
}
func MakeMultiplexer(sessionId string, cmdId string) *Multiplexer {
func MakeMultiplexer(ck base.CommandKey) *Multiplexer {
return &Multiplexer{
Lock: &sync.Mutex{},
SessionId: sessionId,
CmdId: cmdId,
CK: ck,
FdReaders: make(map[int]*FdReader),
FdWriters: make(map[int]*FdWriter),
}
@ -126,8 +125,7 @@ func (m *Multiplexer) MakeRawFdWriter(fdNum int, fd *os.File, shouldClose bool)
func (m *Multiplexer) makeDataAckPacket(fdNum int, ackLen int, err error) *packet.DataAckPacketType {
ack := packet.MakeDataAckPacket()
ack.SessionId = m.SessionId
ack.CmdId = m.CmdId
ack.CK = m.CK
ack.FdNum = fdNum
ack.AckLen = ackLen
if err != nil {
@ -138,8 +136,7 @@ func (m *Multiplexer) makeDataAckPacket(fdNum int, ackLen int, err error) *packe
func (m *Multiplexer) makeDataPacket(fdNum int, data []byte, err error) *packet.DataPacketType {
pk := packet.MakeDataPacket()
pk.SessionId = m.SessionId
pk.CmdId = m.CmdId
pk.CK = m.CK
pk.FdNum = fdNum
pk.Data64 = base64.StdEncoding.EncodeToString(data)
if err != nil {

View File

@ -16,6 +16,8 @@ import (
"strconv"
"strings"
"sync"
"github.com/scripthaus-dev/mshell/pkg/base"
)
// remote: init, run, ping, data, cmdstart, cmddone
@ -82,8 +84,7 @@ func MakePacket(packetType string) (PacketType, error) {
type CmdDataPacketType struct {
Type string `json:"type"`
ReqId string `json:"reqid"`
SessionId string `json:"sessionid"`
CmdId string `json:"cmdid"`
CK base.CommandKey `json:"ck"`
PtyPos int64 `json:"ptypos"`
PtyLen int64 `json:"ptylen"`
RunPos int64 `json:"runpos"`
@ -100,6 +101,10 @@ func (*CmdDataPacketType) GetType() string {
return CmdDataPacketStr
}
func (p *CmdDataPacketType) GetCK() base.CommandKey {
return p.CK
}
func MakeCmdDataPacket() *CmdDataPacketType {
return &CmdDataPacketType{Type: CmdDataPacketStr}
}
@ -118,8 +123,7 @@ func MakePingPacket() *PingPacketType {
type DataPacketType struct {
Type string `json:"type"`
SessionId string `json:"sessionid,omitempty"`
CmdId string `json:"cmdid,omitempty"`
CK base.CommandKey `json:"ck"`
FdNum int `json:"fdnum"`
Data64 string `json:"data64"` // base64 encoded
Eof bool `json:"eof,omitempty"`
@ -130,6 +134,10 @@ func (*DataPacketType) GetType() string {
return DataPacketStr
}
func (p *DataPacketType) GetCK() base.CommandKey {
return p.CK
}
func B64DecodedLen(b64 string) int {
if len(b64) < 4 {
return 0 // we use padded strings, so < 4 is always 0
@ -162,8 +170,7 @@ func MakeDataPacket() *DataPacketType {
type DataAckPacketType struct {
Type string `json:"type"`
SessionId string `json:"sessionid,omitempty"`
CmdId string `json:"cmdid,omitempty"`
CK base.CommandKey `json:"ck"`
FdNum int `json:"fdnum"`
AckLen int `json:"acklen"`
Error string `json:"error,omitempty"`
@ -173,6 +180,10 @@ func (*DataAckPacketType) GetType() string {
return DataAckPacketStr
}
func (p *DataAckPacketType) GetCK() base.CommandKey {
return p.CK
}
func (p *DataAckPacketType) String() string {
errStr := ""
if p.Error != "" {
@ -190,8 +201,7 @@ func MakeDataAckPacket() *DataAckPacketType {
// WinSize, if set, will run TIOCSWINSZ to set size, and then send SIGWINCH
type InputPacketType struct {
Type string `json:"type"`
SessionId string `json:"sessionid"`
CmdId string `json:"cmdid"`
CK base.CommandKey `json:"ck"`
InputData string `json:"inputdata"`
SigNum int `json:"signum,omitempty"`
WinSizeRows int `json:"winsizerows"`
@ -202,6 +212,10 @@ func (*InputPacketType) GetType() string {
return InputPacketStr
}
func (p *InputPacketType) GetCK() base.CommandKey {
return p.CK
}
func MakeInputPacket() *InputPacketType {
return &InputPacketType{Type: InputPacketStr}
}
@ -209,14 +223,17 @@ func MakeInputPacket() *InputPacketType {
type UntailCmdPacketType struct {
Type string `json:"type"`
ReqId string `json:"reqid"`
SessionId string `json:"sessionid"`
CmdId string `json:"cmdid"`
CK base.CommandKey `json:"ck"`
}
func (*UntailCmdPacketType) GetType() string {
return UntailCmdPacketStr
}
func (p *UntailCmdPacketType) GetCK() base.CommandKey {
return p.CK
}
func MakeUntailCmdPacket() *UntailCmdPacketType {
return &UntailCmdPacketType{Type: UntailCmdPacketStr}
}
@ -224,8 +241,7 @@ func MakeUntailCmdPacket() *UntailCmdPacketType {
type GetCmdPacketType struct {
Type string `json:"type"`
ReqId string `json:"reqid"`
SessionId string `json:"sessionid"`
CmdId string `json:"cmdid"`
CK base.CommandKey `json:"ck"`
PtyPos int64 `json:"ptypos"`
RunPos int64 `json:"runpos"`
Tail bool `json:"tail,omitempty"`
@ -235,6 +251,10 @@ func (*GetCmdPacketType) GetType() string {
return GetCmdPacketStr
}
func (p *GetCmdPacketType) GetCK() base.CommandKey {
return p.CK
}
func MakeGetCmdPacket() *GetCmdPacketType {
return &GetCmdPacketType{Type: GetCmdPacketStr}
}
@ -348,8 +368,7 @@ func MakeDonePacket() *DonePacketType {
type CmdDonePacketType struct {
Type string `json:"type"`
Ts int64 `json:"ts"`
SessionId string `json:"sessionid,omitempty"`
CmdId string `json:"cmdid,omitempty"`
CK base.CommandKey `json:"ck"`
ExitCode int `json:"exitcode"`
DurationMs int64 `json:"durationms"`
}
@ -358,6 +377,10 @@ func (*CmdDonePacketType) GetType() string {
return CmdDonePacketStr
}
func (p *CmdDonePacketType) GetCK() base.CommandKey {
return p.CK
}
func MakeCmdDonePacket() *CmdDonePacketType {
return &CmdDonePacketType{Type: CmdDonePacketStr}
}
@ -365,8 +388,7 @@ func MakeCmdDonePacket() *CmdDonePacketType {
type CmdStartPacketType struct {
Type string `json:"type"`
Ts int64 `json:"ts"`
SessionId string `json:"sessionid,omitempty"`
CmdId string `json:"cmdid,omitempty"`
CK base.CommandKey `json:"ck"`
Pid int `json:"pid"`
MShellPid int `json:"mshellpid"`
}
@ -375,6 +397,10 @@ func (*CmdStartPacketType) GetType() string {
return CmdStartPacketStr
}
func (p *CmdStartPacketType) GetCK() base.CommandKey {
return p.CK
}
func MakeCmdStartPacket() *CmdStartPacketType {
return &CmdStartPacketType{Type: CmdStartPacketStr}
}
@ -394,8 +420,7 @@ type RemoteFd struct {
type RunPacketType struct {
Type string `json:"type"`
SessionId string `json:"sessionid,omitempty"`
CmdId string `json:"cmdid,omitempty"`
CK base.CommandKey `json:"ck"`
Command string `json:"command"`
Cwd string `json:"cwd,omitempty"`
Env map[string]string `json:"env,omitempty"`
@ -408,6 +433,10 @@ func (*RunPacketType) GetType() string {
return RunPacketStr
}
func (p *RunPacketType) GetCK() base.CommandKey {
return p.CK
}
func MakeRunPacket() *RunPacketType {
return &RunPacketType{Type: RunPacketStr}
}
@ -417,7 +446,7 @@ type BarePacketType struct {
}
type ErrorPacketType struct {
Id string `json:"id,omitempty"`
CK base.CommandKey `json:"ck,omitempty"`
Type string `json:"type"`
Error string `json:"error"`
}
@ -430,8 +459,8 @@ func MakeErrorPacket(errorStr string) *ErrorPacketType {
return &ErrorPacketType{Type: ErrorPacketStr, Error: errorStr}
}
func MakeIdErrorPacket(id string, errorStr string) *ErrorPacketType {
return &ErrorPacketType{Type: ErrorPacketStr, Id: id, Error: errorStr}
func MakeCKErrorPacket(ck base.CommandKey, errorStr string) *ErrorPacketType {
return &ErrorPacketType{Type: ErrorPacketStr, CK: ck, Error: errorStr}
}
type PacketType interface {
@ -450,6 +479,11 @@ type RpcPacketType interface {
GetPacketId() string
}
type CommandPacketType interface {
GetType() string
GetCK() base.CommandKey
}
func ParseJsonPacket(jsonBuf []byte) (PacketType, error) {
var bareCmd BarePacketType
err := json.Unmarshal(jsonBuf, &bareCmd)

View File

@ -17,7 +17,6 @@ import (
"time"
"github.com/creack/pty"
"github.com/google/uuid"
"github.com/scripthaus-dev/mshell/pkg/base"
"github.com/scripthaus-dev/mshell/pkg/mpio"
"github.com/scripthaus-dev/mshell/pkg/packet"
@ -39,21 +38,19 @@ const RemoteSudoPasswordCommandFmt = `cat /dev/fd/%d | sudo -S -C %d bash -c "ec
type ShExecType struct {
Lock *sync.Mutex
StartTs time.Time
SessionId string
CmdId string
CK base.CommandKey
FileNames *base.CommandFileNames
Cmd *exec.Cmd
CmdPty *os.File
Multiplexer *mpio.Multiplexer
}
func MakeShExec(sessionId string, cmdId string) *ShExecType {
func MakeShExec(ck base.CommandKey) *ShExecType {
return &ShExecType{
Lock: &sync.Mutex{},
StartTs: time.Now(),
SessionId: sessionId,
CmdId: cmdId,
Multiplexer: mpio.MakeMultiplexer(sessionId, cmdId),
CK: ck,
Multiplexer: mpio.MakeMultiplexer(ck),
}
}
@ -67,8 +64,7 @@ func (c *ShExecType) Close() {
func (c *ShExecType) MakeCmdStartPacket() *packet.CmdStartPacketType {
startPacket := packet.MakeCmdStartPacket()
startPacket.Ts = time.Now().UnixMilli()
startPacket.SessionId = c.SessionId
startPacket.CmdId = c.CmdId
startPacket.CK = c.CK
startPacket.Pid = c.Cmd.Process.Pid
startPacket.MShellPid = os.Getpid()
return startPacket
@ -129,12 +125,12 @@ func MakeExecCmd(pk *packet.RunPacketType, cmdTty *os.File) *exec.Cmd {
return ecmd
}
func MakeRunnerExec(cmdId string) (*exec.Cmd, error) {
func MakeRunnerExec(ck base.CommandKey) (*exec.Cmd, error) {
msPath, err := base.GetMShellPath()
if err != nil {
return nil, err
}
ecmd := exec.Command(msPath, cmdId)
ecmd := exec.Command(msPath, string(ck))
return ecmd, nil
}
@ -165,19 +161,9 @@ func ValidateRunPacket(pk *packet.RunPacketType) error {
return fmt.Errorf("run packet has wrong type: %s", pk.Type)
}
if pk.Detached {
if pk.SessionId == "" {
return fmt.Errorf("run packet does not have sessionid")
}
_, err := uuid.Parse(pk.SessionId)
err := pk.CK.Validate("run packet")
if err != nil {
return fmt.Errorf("invalid sessionid '%s' for command", pk.SessionId)
}
if pk.CmdId == "" {
return fmt.Errorf("run packet does not have cmdid")
}
_, err = uuid.Parse(pk.CmdId)
if err != nil {
return fmt.Errorf("invalid cmdid '%s' for command", pk.CmdId)
return err
}
}
if pk.Cwd != "" {
@ -337,7 +323,7 @@ func RunClientSSHCommandAndWait(opts *ClientOpts) (*packet.CmdDonePacketType, er
if err != nil {
return nil, err
}
cmd := MakeShExec("", "")
cmd := MakeShExec("")
var fullSshOpts []string
fullSshOpts = append(fullSshOpts, opts.SSHOpts...)
fullSshOpts = append(fullSshOpts, SSHRemoteCommand)
@ -430,7 +416,7 @@ func (cmd *ShExecType) RunRemoteIOAndWait(packetCh chan packet.PacketType, sende
}
func runCommandSimple(pk *packet.RunPacketType, sender *packet.PacketSender) (*ShExecType, error) {
cmd := MakeShExec(pk.SessionId, pk.CmdId)
cmd := MakeShExec(pk.CK)
cmd.Cmd = exec.Command("bash", "-c", pk.Command)
UpdateCmdEnv(cmd.Cmd, pk.Env)
if pk.Cwd != "" {
@ -491,7 +477,7 @@ func runCommandSimple(pk *packet.RunPacketType, sender *packet.PacketSender) (*S
}
func runCommandDetached(pk *packet.RunPacketType, sender *packet.PacketSender) (*ShExecType, error) {
fileNames, err := base.GetCommandFileNames(pk.SessionId, pk.CmdId)
fileNames, err := base.GetCommandFileNames(pk.CK)
if err != nil {
return nil, err
}
@ -499,7 +485,7 @@ func runCommandDetached(pk *packet.RunPacketType, sender *packet.PacketSender) (
if err == nil { // non-nil error will be caught by regular OpenFile below
// must have size 0
if ptyOutInfo.Size() != 0 {
return nil, fmt.Errorf("cmdid '%s' was already used (ptyout len=%d)", pk.CmdId, ptyOutInfo.Size())
return nil, fmt.Errorf("cmdkey '%s' was already used (ptyout len=%d)", pk.CK, ptyOutInfo.Size())
}
}
cmdPty, cmdTty, err := pty.Open()
@ -510,7 +496,7 @@ func runCommandDetached(pk *packet.RunPacketType, sender *packet.PacketSender) (
defer func() {
cmdTty.Close()
}()
rtn := MakeShExec(pk.SessionId, pk.CmdId)
rtn := MakeShExec(pk.CK)
ecmd := MakeExecCmd(pk, cmdTty)
err = ecmd.Start()
if err != nil {
@ -558,8 +544,7 @@ func (c *ShExecType) WaitForCommand() *packet.CmdDonePacketType {
exitCode := GetExitCode(exitErr)
donePacket := packet.MakeCmdDonePacket()
donePacket.Ts = endTs.UnixMilli()
donePacket.SessionId = c.SessionId
donePacket.CmdId = c.CmdId
donePacket.CK = c.CK
donePacket.ExitCode = exitCode
donePacket.DurationMs = int64(cmdDuration / time.Millisecond)
if c.FileNames != nil {