mirror of
https://github.com/wavetermdev/waveterm.git
synced 2025-01-03 18:47:56 +01:00
combine sessionid and cmdid into one field ck (commandkey)
This commit is contained in:
parent
1ea8393844
commit
2a6791bcd6
@ -15,7 +15,6 @@ import (
|
||||
"syscall"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/scripthaus-dev/mshell/pkg/base"
|
||||
"github.com/scripthaus-dev/mshell/pkg/cmdtail"
|
||||
"github.com/scripthaus-dev/mshell/pkg/packet"
|
||||
@ -38,7 +37,7 @@ func setupSingleSignals(cmd *shexec.ShExecType) {
|
||||
}()
|
||||
}
|
||||
|
||||
func doSingle(cmdId string) {
|
||||
func doSingle(ck base.CommandKey) {
|
||||
packetCh := packet.PacketParser(os.Stdin)
|
||||
sender := packet.MakePacketSender(os.Stdout)
|
||||
var runPacket *packet.RunPacketType
|
||||
@ -57,11 +56,11 @@ func doSingle(cmdId string) {
|
||||
sender.SendErrorPacket("did not receive a 'run' packet")
|
||||
return
|
||||
}
|
||||
if runPacket.CmdId == "" {
|
||||
runPacket.CmdId = cmdId
|
||||
if runPacket.CK.IsEmpty() {
|
||||
runPacket.CK = ck
|
||||
}
|
||||
if runPacket.CmdId != cmdId {
|
||||
sender.SendErrorPacket(fmt.Sprintf("run packet cmdid[%s] did not match arg[%s]", runPacket.CmdId, cmdId))
|
||||
if runPacket.CK != ck {
|
||||
sender.SendErrorPacket(fmt.Sprintf("run packet cmdid[%s] did not match arg[%s]", runPacket.CK, ck))
|
||||
return
|
||||
}
|
||||
cmd, err := shexec.RunCommand(runPacket, sender)
|
||||
@ -79,39 +78,36 @@ func doSingle(cmdId string) {
|
||||
}
|
||||
|
||||
func doMainRun(pk *packet.RunPacketType, sender *packet.PacketSender) {
|
||||
if pk.CmdId == "" {
|
||||
pk.CmdId = uuid.New().String()
|
||||
}
|
||||
err := shexec.ValidateRunPacket(pk)
|
||||
if err != nil {
|
||||
sender.SendPacket(packet.MakeIdErrorPacket(pk.CmdId, fmt.Sprintf("invalid run packet: %v", err)))
|
||||
sender.SendPacket(packet.MakeCKErrorPacket(pk.CK, fmt.Sprintf("invalid run packet: %v", err)))
|
||||
return
|
||||
}
|
||||
fileNames, err := base.GetCommandFileNames(pk.SessionId, pk.CmdId)
|
||||
fileNames, err := base.GetCommandFileNames(pk.CK)
|
||||
if err != nil {
|
||||
sender.SendPacket(packet.MakeIdErrorPacket(pk.CmdId, fmt.Sprintf("cannot get command file names: %v", err)))
|
||||
sender.SendPacket(packet.MakeCKErrorPacket(pk.CK, fmt.Sprintf("cannot get command file names: %v", err)))
|
||||
return
|
||||
}
|
||||
cmd, err := shexec.MakeRunnerExec(pk.CmdId)
|
||||
cmd, err := shexec.MakeRunnerExec(pk.CK)
|
||||
if err != nil {
|
||||
sender.SendPacket(packet.MakeIdErrorPacket(pk.CmdId, fmt.Sprintf("cannot make mshell command: %v", err)))
|
||||
sender.SendPacket(packet.MakeCKErrorPacket(pk.CK, fmt.Sprintf("cannot make mshell command: %v", err)))
|
||||
return
|
||||
}
|
||||
cmdStdin, err := cmd.StdinPipe()
|
||||
if err != nil {
|
||||
sender.SendPacket(packet.MakeIdErrorPacket(pk.CmdId, fmt.Sprintf("cannot pipe stdin to command: %v", err)))
|
||||
sender.SendPacket(packet.MakeCKErrorPacket(pk.CK, fmt.Sprintf("cannot pipe stdin to command: %v", err)))
|
||||
return
|
||||
}
|
||||
// touch ptyout file (should exist for tailer to work correctly)
|
||||
ptyOutFd, err := os.OpenFile(fileNames.PtyOutFile, os.O_CREATE|os.O_TRUNC|os.O_APPEND|os.O_WRONLY, 0600)
|
||||
if err != nil {
|
||||
sender.SendPacket(packet.MakeIdErrorPacket(pk.CmdId, fmt.Sprintf("cannot open pty out file '%s': %v", fileNames.PtyOutFile, err)))
|
||||
sender.SendPacket(packet.MakeCKErrorPacket(pk.CK, fmt.Sprintf("cannot open pty out file '%s': %v", fileNames.PtyOutFile, err)))
|
||||
return
|
||||
}
|
||||
ptyOutFd.Close() // just opened to create the file, can close right after
|
||||
runnerOutFd, err := os.OpenFile(fileNames.RunnerOutFile, os.O_CREATE|os.O_TRUNC|os.O_APPEND|os.O_WRONLY, 0600)
|
||||
if err != nil {
|
||||
sender.SendPacket(packet.MakeIdErrorPacket(pk.CmdId, fmt.Sprintf("cannot open runner out file '%s': %v", fileNames.RunnerOutFile, err)))
|
||||
sender.SendPacket(packet.MakeCKErrorPacket(pk.CK, fmt.Sprintf("cannot open runner out file '%s': %v", fileNames.RunnerOutFile, err)))
|
||||
return
|
||||
}
|
||||
defer runnerOutFd.Close()
|
||||
@ -119,13 +115,13 @@ func doMainRun(pk *packet.RunPacketType, sender *packet.PacketSender) {
|
||||
cmd.Stderr = runnerOutFd
|
||||
err = cmd.Start()
|
||||
if err != nil {
|
||||
sender.SendPacket(packet.MakeIdErrorPacket(pk.CmdId, fmt.Sprintf("error starting command: %v", err)))
|
||||
sender.SendPacket(packet.MakeCKErrorPacket(pk.CK, fmt.Sprintf("error starting command: %v", err)))
|
||||
return
|
||||
}
|
||||
go func() {
|
||||
err = packet.SendPacket(cmdStdin, pk)
|
||||
if err != nil {
|
||||
sender.SendPacket(packet.MakeIdErrorPacket(pk.CmdId, fmt.Sprintf("error sending forked runner command: %v", err)))
|
||||
sender.SendPacket(packet.MakeCKErrorPacket(pk.CK, fmt.Sprintf("error sending forked runner command: %v", err)))
|
||||
return
|
||||
}
|
||||
cmdStdin.Close()
|
||||
@ -451,12 +447,12 @@ func main() {
|
||||
}
|
||||
|
||||
if len(os.Args) >= 2 {
|
||||
cmdId, err := uuid.Parse(os.Args[1])
|
||||
if err != nil {
|
||||
packet.SendErrorPacket(os.Stdout, fmt.Sprintf("invalid non-cmdid passed to mshell", err))
|
||||
ck := base.CommandKey(os.Args[1])
|
||||
if err := ck.Validate("mshell arg"); err != nil {
|
||||
packet.SendErrorPacket(os.Stdout, err.Error())
|
||||
return
|
||||
}
|
||||
doSingle(cmdId.String())
|
||||
doSingle(ck)
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
return
|
||||
} else {
|
||||
|
@ -15,6 +15,8 @@ import (
|
||||
"path"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
const DefaultMShellPath = "mshell"
|
||||
@ -37,6 +39,68 @@ type CommandFileNames struct {
|
||||
RunnerOutFile string
|
||||
}
|
||||
|
||||
type CommandKey string
|
||||
|
||||
func MakeCommandKey(sessionId string, cmdId string) CommandKey {
|
||||
if sessionId == "" && cmdId == "" {
|
||||
return CommandKey("")
|
||||
}
|
||||
return CommandKey(fmt.Sprintf("%s/%s", sessionId, cmdId))
|
||||
}
|
||||
|
||||
func (ckey CommandKey) IsEmpty() bool {
|
||||
return string(ckey) == ""
|
||||
}
|
||||
|
||||
func (ckey CommandKey) GetSessionId() string {
|
||||
slashIdx := strings.Index(string(ckey), "/")
|
||||
if slashIdx == -1 {
|
||||
return ""
|
||||
}
|
||||
return string(ckey[0:slashIdx])
|
||||
}
|
||||
|
||||
func (ckey CommandKey) GetCmdId() string {
|
||||
slashIdx := strings.Index(string(ckey), "/")
|
||||
if slashIdx == -1 {
|
||||
return ""
|
||||
}
|
||||
return string(ckey[slashIdx+1:])
|
||||
}
|
||||
|
||||
func (ckey CommandKey) Split() (string, string) {
|
||||
fields := strings.SplitN(string(ckey), "/", 2)
|
||||
if len(fields) < 2 {
|
||||
return "", ""
|
||||
}
|
||||
return fields[0], fields[1]
|
||||
}
|
||||
|
||||
func (ckey CommandKey) Validate(typeStr string) error {
|
||||
if typeStr == "" {
|
||||
typeStr = "ck"
|
||||
}
|
||||
if ckey == "" {
|
||||
return fmt.Errorf("%s has empty commandkey", typeStr)
|
||||
}
|
||||
sessionId, cmdId := ckey.Split()
|
||||
if sessionId == "" {
|
||||
return fmt.Errorf("%s does not have sessionid", typeStr)
|
||||
}
|
||||
_, err := uuid.Parse(sessionId)
|
||||
if err != nil {
|
||||
return fmt.Errorf("%s has invalid sessionid '%s'", typeStr, sessionId)
|
||||
}
|
||||
if cmdId == "" {
|
||||
return fmt.Errorf("%s does not have cmdid", typeStr)
|
||||
}
|
||||
_, err = uuid.Parse(cmdId)
|
||||
if err != nil {
|
||||
return fmt.Errorf("%s has invalid cmdid '%s'", typeStr, cmdId)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func GetHomeDir() string {
|
||||
homeVar := os.Getenv(HomeVarName)
|
||||
if homeVar == "" {
|
||||
@ -57,10 +121,11 @@ func GetScHomeDir() (string, error) {
|
||||
return scHome, nil
|
||||
}
|
||||
|
||||
func GetCommandFileNames(sessionId string, cmdId string) (*CommandFileNames, error) {
|
||||
if sessionId == "" || cmdId == "" {
|
||||
return nil, fmt.Errorf("cannot get command-files when sessionid or cmdid is empty")
|
||||
func GetCommandFileNames(ck CommandKey) (*CommandFileNames, error) {
|
||||
if err := ck.Validate("ck"); err != nil {
|
||||
return nil, fmt.Errorf("cannot get command files: %w", err)
|
||||
}
|
||||
sessionId, cmdId := ck.Split()
|
||||
sdir, err := EnsureSessionDir(sessionId)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@ -73,8 +138,8 @@ func GetCommandFileNames(sessionId string, cmdId string) (*CommandFileNames, err
|
||||
}, nil
|
||||
}
|
||||
|
||||
func MakeCommandFileNamesWithHome(scHome string, sessionId string, cmdId string) *CommandFileNames {
|
||||
base := path.Join(scHome, SessionsDirBaseName, sessionId, cmdId)
|
||||
func MakeCommandFileNamesWithHome(scHome string, ck CommandKey) *CommandFileNames {
|
||||
base := path.Join(scHome, SessionsDirBaseName, ck.GetSessionId(), ck.GetCmdId())
|
||||
return &CommandFileNames{
|
||||
PtyOutFile: base + ".ptyout",
|
||||
StdinFifo: base + ".stdin",
|
||||
|
@ -15,7 +15,6 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/fsnotify/fsnotify"
|
||||
"github.com/google/uuid"
|
||||
"github.com/scripthaus-dev/mshell/pkg/base"
|
||||
"github.com/scripthaus-dev/mshell/pkg/packet"
|
||||
)
|
||||
@ -33,7 +32,7 @@ type TailPos struct {
|
||||
}
|
||||
|
||||
type CmdWatchEntry struct {
|
||||
CmdKey CmdKey
|
||||
CmdKey base.CommandKey
|
||||
FilePtyLen int64
|
||||
FileRunLen int64
|
||||
Tails []TailPos
|
||||
@ -73,20 +72,15 @@ func (pos TailPos) IsCurrent(entry CmdWatchEntry) bool {
|
||||
return pos.TailPtyPos >= entry.FilePtyLen && pos.TailRunPos >= entry.FileRunLen
|
||||
}
|
||||
|
||||
type CmdKey struct {
|
||||
SessionId string
|
||||
CmdId string
|
||||
}
|
||||
|
||||
type Tailer struct {
|
||||
Lock *sync.Mutex
|
||||
WatchList map[CmdKey]CmdWatchEntry
|
||||
WatchList map[base.CommandKey]CmdWatchEntry
|
||||
ScHomeDir string
|
||||
Watcher *fsnotify.Watcher
|
||||
SendCh chan packet.PacketType
|
||||
}
|
||||
|
||||
func (t *Tailer) updateTailPos_nolock(cmdKey CmdKey, reqId string, pos TailPos) {
|
||||
func (t *Tailer) updateTailPos_nolock(cmdKey base.CommandKey, reqId string, pos TailPos) {
|
||||
entry, found := t.WatchList[cmdKey]
|
||||
if !found {
|
||||
return
|
||||
@ -95,7 +89,7 @@ func (t *Tailer) updateTailPos_nolock(cmdKey CmdKey, reqId string, pos TailPos)
|
||||
t.WatchList[cmdKey] = entry
|
||||
}
|
||||
|
||||
func (t *Tailer) removeTailPos_nolock(cmdKey CmdKey, reqId string) {
|
||||
func (t *Tailer) removeTailPos_nolock(cmdKey base.CommandKey, reqId string) {
|
||||
entry, found := t.WatchList[cmdKey]
|
||||
if !found {
|
||||
return
|
||||
@ -107,13 +101,13 @@ func (t *Tailer) removeTailPos_nolock(cmdKey CmdKey, reqId string) {
|
||||
}
|
||||
|
||||
// delete from watchlist, remove watches
|
||||
fileNames := base.MakeCommandFileNamesWithHome(t.ScHomeDir, cmdKey.SessionId, cmdKey.CmdId)
|
||||
fileNames := base.MakeCommandFileNamesWithHome(t.ScHomeDir, cmdKey)
|
||||
delete(t.WatchList, cmdKey)
|
||||
t.Watcher.Remove(fileNames.PtyOutFile)
|
||||
t.Watcher.Remove(fileNames.RunnerOutFile)
|
||||
}
|
||||
|
||||
func (t *Tailer) updateEntrySizes_nolock(cmdKey CmdKey, ptyLen int64, runLen int64) {
|
||||
func (t *Tailer) updateEntrySizes_nolock(cmdKey base.CommandKey, ptyLen int64, runLen int64) {
|
||||
entry, found := t.WatchList[cmdKey]
|
||||
if !found {
|
||||
return
|
||||
@ -123,7 +117,7 @@ func (t *Tailer) updateEntrySizes_nolock(cmdKey CmdKey, ptyLen int64, runLen int
|
||||
t.WatchList[cmdKey] = entry
|
||||
}
|
||||
|
||||
func (t *Tailer) getEntryAndPos_nolock(cmdKey CmdKey, reqId string) (CmdWatchEntry, TailPos, bool) {
|
||||
func (t *Tailer) getEntryAndPos_nolock(cmdKey base.CommandKey, reqId string) (CmdWatchEntry, TailPos, bool) {
|
||||
entry, found := t.WatchList[cmdKey]
|
||||
if !found {
|
||||
return CmdWatchEntry{}, TailPos{}, false
|
||||
@ -142,7 +136,7 @@ func MakeTailer(sendCh chan packet.PacketType) (*Tailer, error) {
|
||||
}
|
||||
rtn := &Tailer{
|
||||
Lock: &sync.Mutex{},
|
||||
WatchList: make(map[CmdKey]CmdWatchEntry),
|
||||
WatchList: make(map[base.CommandKey]CmdWatchEntry),
|
||||
ScHomeDir: scHomeDir,
|
||||
SendCh: sendCh,
|
||||
}
|
||||
@ -170,8 +164,7 @@ func (t *Tailer) readDataFromFile(fileName string, pos int64, maxBytes int) ([]b
|
||||
func (t *Tailer) makeCmdDataPacket(fileNames *base.CommandFileNames, entry CmdWatchEntry, pos TailPos) *packet.CmdDataPacketType {
|
||||
dataPacket := packet.MakeCmdDataPacket()
|
||||
dataPacket.ReqId = pos.ReqId
|
||||
dataPacket.SessionId = entry.CmdKey.SessionId
|
||||
dataPacket.CmdId = entry.CmdKey.CmdId
|
||||
dataPacket.CK = entry.CmdKey
|
||||
dataPacket.PtyPos = pos.TailPtyPos
|
||||
dataPacket.RunPos = pos.TailRunPos
|
||||
if entry.FilePtyLen > pos.TailPtyPos {
|
||||
@ -196,14 +189,14 @@ func (t *Tailer) makeCmdDataPacket(fileNames *base.CommandFileNames, entry CmdWa
|
||||
}
|
||||
|
||||
// returns (data-packet, keepRunning)
|
||||
func (t *Tailer) runSingleDataTransfer(key CmdKey, reqId string) (*packet.CmdDataPacketType, bool) {
|
||||
func (t *Tailer) runSingleDataTransfer(key base.CommandKey, reqId string) (*packet.CmdDataPacketType, bool) {
|
||||
t.Lock.Lock()
|
||||
entry, pos, foundPos := t.getEntryAndPos_nolock(key, reqId)
|
||||
t.Lock.Unlock()
|
||||
if !foundPos {
|
||||
return nil, false
|
||||
}
|
||||
fileNames := base.MakeCommandFileNamesWithHome(t.ScHomeDir, key.SessionId, key.CmdId)
|
||||
fileNames := base.MakeCommandFileNamesWithHome(t.ScHomeDir, key)
|
||||
dataPacket := t.makeCmdDataPacket(fileNames, entry, pos)
|
||||
|
||||
t.Lock.Lock()
|
||||
@ -232,7 +225,7 @@ func (t *Tailer) runSingleDataTransfer(key CmdKey, reqId string) (*packet.CmdDat
|
||||
return dataPacket, pos.Running
|
||||
}
|
||||
|
||||
func (t *Tailer) checkRemoveNoFollow(cmdKey CmdKey, reqId string) {
|
||||
func (t *Tailer) checkRemoveNoFollow(cmdKey base.CommandKey, reqId string) {
|
||||
t.Lock.Lock()
|
||||
defer t.Lock.Unlock()
|
||||
_, pos, foundPos := t.getEntryAndPos_nolock(cmdKey, reqId)
|
||||
@ -244,7 +237,7 @@ func (t *Tailer) checkRemoveNoFollow(cmdKey CmdKey, reqId string) {
|
||||
}
|
||||
}
|
||||
|
||||
func (t *Tailer) RunDataTransfer(key CmdKey, reqId string) {
|
||||
func (t *Tailer) RunDataTransfer(key base.CommandKey, reqId string) {
|
||||
for {
|
||||
dataPacket, keepRunning := t.runSingleDataTransfer(key, reqId)
|
||||
if dataPacket != nil {
|
||||
@ -283,7 +276,7 @@ func (t *Tailer) updateFile(relFileName string) {
|
||||
t.SendCh <- packet.FmtMessagePacket("error trying to stat file '%s': %v", relFileName, err)
|
||||
return
|
||||
}
|
||||
cmdKey := CmdKey{SessionId: m[1], CmdId: m[2]}
|
||||
cmdKey := base.MakeCommandKey(m[1], m[2])
|
||||
t.Lock.Lock()
|
||||
defer t.Lock.Unlock()
|
||||
entry, foundEntry := t.WatchList[cmdKey]
|
||||
@ -336,7 +329,7 @@ func max(v1 int64, v2 int64) int64 {
|
||||
}
|
||||
|
||||
func (entry *CmdWatchEntry) fillFilePos(scHomeDir string) {
|
||||
fileNames := base.MakeCommandFileNamesWithHome(scHomeDir, entry.CmdKey.SessionId, entry.CmdKey.CmdId)
|
||||
fileNames := base.MakeCommandFileNamesWithHome(scHomeDir, entry.CmdKey)
|
||||
ptyInfo, _ := os.Stat(fileNames.PtyOutFile)
|
||||
if ptyInfo != nil {
|
||||
entry.FilePtyLen = ptyInfo.Size()
|
||||
@ -350,30 +343,24 @@ func (entry *CmdWatchEntry) fillFilePos(scHomeDir string) {
|
||||
func (t *Tailer) RemoveWatch(pk *packet.UntailCmdPacketType) {
|
||||
t.Lock.Lock()
|
||||
defer t.Lock.Unlock()
|
||||
key := CmdKey{pk.SessionId, pk.CmdId}
|
||||
t.removeTailPos_nolock(key, pk.ReqId)
|
||||
t.removeTailPos_nolock(pk.CK, pk.ReqId)
|
||||
}
|
||||
|
||||
func (t *Tailer) AddWatch(getPacket *packet.GetCmdPacketType) error {
|
||||
_, err := uuid.Parse(getPacket.SessionId)
|
||||
if err != nil {
|
||||
return fmt.Errorf("getcmd, bad sessionid '%s': %w", getPacket.SessionId, err)
|
||||
}
|
||||
_, err = uuid.Parse(getPacket.CmdId)
|
||||
if err != nil {
|
||||
return fmt.Errorf("getcmd, bad cmdid '%s': %w", getPacket.CmdId, err)
|
||||
if err := getPacket.CK.Validate("getcmd"); err != nil {
|
||||
return err
|
||||
}
|
||||
if getPacket.ReqId == "" {
|
||||
return fmt.Errorf("getcmd, no reqid specified")
|
||||
}
|
||||
fileNames := base.MakeCommandFileNamesWithHome(t.ScHomeDir, getPacket.SessionId, getPacket.CmdId)
|
||||
fileNames := base.MakeCommandFileNamesWithHome(t.ScHomeDir, getPacket.CK)
|
||||
t.Lock.Lock()
|
||||
defer t.Lock.Unlock()
|
||||
key := CmdKey{getPacket.SessionId, getPacket.CmdId}
|
||||
key := getPacket.CK
|
||||
entry, foundEntry := t.WatchList[key]
|
||||
if !foundEntry {
|
||||
// add watches, initialize entry
|
||||
err = t.Watcher.Add(fileNames.PtyOutFile)
|
||||
err := t.Watcher.Add(fileNames.PtyOutFile)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -12,6 +12,7 @@ import (
|
||||
"os"
|
||||
"sync"
|
||||
|
||||
"github.com/scripthaus-dev/mshell/pkg/base"
|
||||
"github.com/scripthaus-dev/mshell/pkg/packet"
|
||||
)
|
||||
|
||||
@ -21,8 +22,7 @@ const MaxSingleWriteSize = 4 * 1024
|
||||
|
||||
type Multiplexer struct {
|
||||
Lock *sync.Mutex
|
||||
SessionId string
|
||||
CmdId string
|
||||
CK base.CommandKey
|
||||
FdReaders map[int]*FdReader // synchronized
|
||||
FdWriters map[int]*FdWriter // synchronized
|
||||
CloseAfterStart []*os.File // synchronized
|
||||
@ -34,11 +34,10 @@ type Multiplexer struct {
|
||||
Debug bool
|
||||
}
|
||||
|
||||
func MakeMultiplexer(sessionId string, cmdId string) *Multiplexer {
|
||||
func MakeMultiplexer(ck base.CommandKey) *Multiplexer {
|
||||
return &Multiplexer{
|
||||
Lock: &sync.Mutex{},
|
||||
SessionId: sessionId,
|
||||
CmdId: cmdId,
|
||||
CK: ck,
|
||||
FdReaders: make(map[int]*FdReader),
|
||||
FdWriters: make(map[int]*FdWriter),
|
||||
}
|
||||
@ -126,8 +125,7 @@ func (m *Multiplexer) MakeRawFdWriter(fdNum int, fd *os.File, shouldClose bool)
|
||||
|
||||
func (m *Multiplexer) makeDataAckPacket(fdNum int, ackLen int, err error) *packet.DataAckPacketType {
|
||||
ack := packet.MakeDataAckPacket()
|
||||
ack.SessionId = m.SessionId
|
||||
ack.CmdId = m.CmdId
|
||||
ack.CK = m.CK
|
||||
ack.FdNum = fdNum
|
||||
ack.AckLen = ackLen
|
||||
if err != nil {
|
||||
@ -138,8 +136,7 @@ func (m *Multiplexer) makeDataAckPacket(fdNum int, ackLen int, err error) *packe
|
||||
|
||||
func (m *Multiplexer) makeDataPacket(fdNum int, data []byte, err error) *packet.DataPacketType {
|
||||
pk := packet.MakeDataPacket()
|
||||
pk.SessionId = m.SessionId
|
||||
pk.CmdId = m.CmdId
|
||||
pk.CK = m.CK
|
||||
pk.FdNum = fdNum
|
||||
pk.Data64 = base64.StdEncoding.EncodeToString(data)
|
||||
if err != nil {
|
||||
|
@ -16,6 +16,8 @@ import (
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"github.com/scripthaus-dev/mshell/pkg/base"
|
||||
)
|
||||
|
||||
// remote: init, run, ping, data, cmdstart, cmddone
|
||||
@ -82,8 +84,7 @@ func MakePacket(packetType string) (PacketType, error) {
|
||||
type CmdDataPacketType struct {
|
||||
Type string `json:"type"`
|
||||
ReqId string `json:"reqid"`
|
||||
SessionId string `json:"sessionid"`
|
||||
CmdId string `json:"cmdid"`
|
||||
CK base.CommandKey `json:"ck"`
|
||||
PtyPos int64 `json:"ptypos"`
|
||||
PtyLen int64 `json:"ptylen"`
|
||||
RunPos int64 `json:"runpos"`
|
||||
@ -100,6 +101,10 @@ func (*CmdDataPacketType) GetType() string {
|
||||
return CmdDataPacketStr
|
||||
}
|
||||
|
||||
func (p *CmdDataPacketType) GetCK() base.CommandKey {
|
||||
return p.CK
|
||||
}
|
||||
|
||||
func MakeCmdDataPacket() *CmdDataPacketType {
|
||||
return &CmdDataPacketType{Type: CmdDataPacketStr}
|
||||
}
|
||||
@ -118,8 +123,7 @@ func MakePingPacket() *PingPacketType {
|
||||
|
||||
type DataPacketType struct {
|
||||
Type string `json:"type"`
|
||||
SessionId string `json:"sessionid,omitempty"`
|
||||
CmdId string `json:"cmdid,omitempty"`
|
||||
CK base.CommandKey `json:"ck"`
|
||||
FdNum int `json:"fdnum"`
|
||||
Data64 string `json:"data64"` // base64 encoded
|
||||
Eof bool `json:"eof,omitempty"`
|
||||
@ -130,6 +134,10 @@ func (*DataPacketType) GetType() string {
|
||||
return DataPacketStr
|
||||
}
|
||||
|
||||
func (p *DataPacketType) GetCK() base.CommandKey {
|
||||
return p.CK
|
||||
}
|
||||
|
||||
func B64DecodedLen(b64 string) int {
|
||||
if len(b64) < 4 {
|
||||
return 0 // we use padded strings, so < 4 is always 0
|
||||
@ -162,8 +170,7 @@ func MakeDataPacket() *DataPacketType {
|
||||
|
||||
type DataAckPacketType struct {
|
||||
Type string `json:"type"`
|
||||
SessionId string `json:"sessionid,omitempty"`
|
||||
CmdId string `json:"cmdid,omitempty"`
|
||||
CK base.CommandKey `json:"ck"`
|
||||
FdNum int `json:"fdnum"`
|
||||
AckLen int `json:"acklen"`
|
||||
Error string `json:"error,omitempty"`
|
||||
@ -173,6 +180,10 @@ func (*DataAckPacketType) GetType() string {
|
||||
return DataAckPacketStr
|
||||
}
|
||||
|
||||
func (p *DataAckPacketType) GetCK() base.CommandKey {
|
||||
return p.CK
|
||||
}
|
||||
|
||||
func (p *DataAckPacketType) String() string {
|
||||
errStr := ""
|
||||
if p.Error != "" {
|
||||
@ -190,8 +201,7 @@ func MakeDataAckPacket() *DataAckPacketType {
|
||||
// WinSize, if set, will run TIOCSWINSZ to set size, and then send SIGWINCH
|
||||
type InputPacketType struct {
|
||||
Type string `json:"type"`
|
||||
SessionId string `json:"sessionid"`
|
||||
CmdId string `json:"cmdid"`
|
||||
CK base.CommandKey `json:"ck"`
|
||||
InputData string `json:"inputdata"`
|
||||
SigNum int `json:"signum,omitempty"`
|
||||
WinSizeRows int `json:"winsizerows"`
|
||||
@ -202,6 +212,10 @@ func (*InputPacketType) GetType() string {
|
||||
return InputPacketStr
|
||||
}
|
||||
|
||||
func (p *InputPacketType) GetCK() base.CommandKey {
|
||||
return p.CK
|
||||
}
|
||||
|
||||
func MakeInputPacket() *InputPacketType {
|
||||
return &InputPacketType{Type: InputPacketStr}
|
||||
}
|
||||
@ -209,14 +223,17 @@ func MakeInputPacket() *InputPacketType {
|
||||
type UntailCmdPacketType struct {
|
||||
Type string `json:"type"`
|
||||
ReqId string `json:"reqid"`
|
||||
SessionId string `json:"sessionid"`
|
||||
CmdId string `json:"cmdid"`
|
||||
CK base.CommandKey `json:"ck"`
|
||||
}
|
||||
|
||||
func (*UntailCmdPacketType) GetType() string {
|
||||
return UntailCmdPacketStr
|
||||
}
|
||||
|
||||
func (p *UntailCmdPacketType) GetCK() base.CommandKey {
|
||||
return p.CK
|
||||
}
|
||||
|
||||
func MakeUntailCmdPacket() *UntailCmdPacketType {
|
||||
return &UntailCmdPacketType{Type: UntailCmdPacketStr}
|
||||
}
|
||||
@ -224,8 +241,7 @@ func MakeUntailCmdPacket() *UntailCmdPacketType {
|
||||
type GetCmdPacketType struct {
|
||||
Type string `json:"type"`
|
||||
ReqId string `json:"reqid"`
|
||||
SessionId string `json:"sessionid"`
|
||||
CmdId string `json:"cmdid"`
|
||||
CK base.CommandKey `json:"ck"`
|
||||
PtyPos int64 `json:"ptypos"`
|
||||
RunPos int64 `json:"runpos"`
|
||||
Tail bool `json:"tail,omitempty"`
|
||||
@ -235,6 +251,10 @@ func (*GetCmdPacketType) GetType() string {
|
||||
return GetCmdPacketStr
|
||||
}
|
||||
|
||||
func (p *GetCmdPacketType) GetCK() base.CommandKey {
|
||||
return p.CK
|
||||
}
|
||||
|
||||
func MakeGetCmdPacket() *GetCmdPacketType {
|
||||
return &GetCmdPacketType{Type: GetCmdPacketStr}
|
||||
}
|
||||
@ -348,8 +368,7 @@ func MakeDonePacket() *DonePacketType {
|
||||
type CmdDonePacketType struct {
|
||||
Type string `json:"type"`
|
||||
Ts int64 `json:"ts"`
|
||||
SessionId string `json:"sessionid,omitempty"`
|
||||
CmdId string `json:"cmdid,omitempty"`
|
||||
CK base.CommandKey `json:"ck"`
|
||||
ExitCode int `json:"exitcode"`
|
||||
DurationMs int64 `json:"durationms"`
|
||||
}
|
||||
@ -358,6 +377,10 @@ func (*CmdDonePacketType) GetType() string {
|
||||
return CmdDonePacketStr
|
||||
}
|
||||
|
||||
func (p *CmdDonePacketType) GetCK() base.CommandKey {
|
||||
return p.CK
|
||||
}
|
||||
|
||||
func MakeCmdDonePacket() *CmdDonePacketType {
|
||||
return &CmdDonePacketType{Type: CmdDonePacketStr}
|
||||
}
|
||||
@ -365,8 +388,7 @@ func MakeCmdDonePacket() *CmdDonePacketType {
|
||||
type CmdStartPacketType struct {
|
||||
Type string `json:"type"`
|
||||
Ts int64 `json:"ts"`
|
||||
SessionId string `json:"sessionid,omitempty"`
|
||||
CmdId string `json:"cmdid,omitempty"`
|
||||
CK base.CommandKey `json:"ck"`
|
||||
Pid int `json:"pid"`
|
||||
MShellPid int `json:"mshellpid"`
|
||||
}
|
||||
@ -375,6 +397,10 @@ func (*CmdStartPacketType) GetType() string {
|
||||
return CmdStartPacketStr
|
||||
}
|
||||
|
||||
func (p *CmdStartPacketType) GetCK() base.CommandKey {
|
||||
return p.CK
|
||||
}
|
||||
|
||||
func MakeCmdStartPacket() *CmdStartPacketType {
|
||||
return &CmdStartPacketType{Type: CmdStartPacketStr}
|
||||
}
|
||||
@ -394,8 +420,7 @@ type RemoteFd struct {
|
||||
|
||||
type RunPacketType struct {
|
||||
Type string `json:"type"`
|
||||
SessionId string `json:"sessionid,omitempty"`
|
||||
CmdId string `json:"cmdid,omitempty"`
|
||||
CK base.CommandKey `json:"ck"`
|
||||
Command string `json:"command"`
|
||||
Cwd string `json:"cwd,omitempty"`
|
||||
Env map[string]string `json:"env,omitempty"`
|
||||
@ -408,6 +433,10 @@ func (*RunPacketType) GetType() string {
|
||||
return RunPacketStr
|
||||
}
|
||||
|
||||
func (p *RunPacketType) GetCK() base.CommandKey {
|
||||
return p.CK
|
||||
}
|
||||
|
||||
func MakeRunPacket() *RunPacketType {
|
||||
return &RunPacketType{Type: RunPacketStr}
|
||||
}
|
||||
@ -417,7 +446,7 @@ type BarePacketType struct {
|
||||
}
|
||||
|
||||
type ErrorPacketType struct {
|
||||
Id string `json:"id,omitempty"`
|
||||
CK base.CommandKey `json:"ck,omitempty"`
|
||||
Type string `json:"type"`
|
||||
Error string `json:"error"`
|
||||
}
|
||||
@ -430,8 +459,8 @@ func MakeErrorPacket(errorStr string) *ErrorPacketType {
|
||||
return &ErrorPacketType{Type: ErrorPacketStr, Error: errorStr}
|
||||
}
|
||||
|
||||
func MakeIdErrorPacket(id string, errorStr string) *ErrorPacketType {
|
||||
return &ErrorPacketType{Type: ErrorPacketStr, Id: id, Error: errorStr}
|
||||
func MakeCKErrorPacket(ck base.CommandKey, errorStr string) *ErrorPacketType {
|
||||
return &ErrorPacketType{Type: ErrorPacketStr, CK: ck, Error: errorStr}
|
||||
}
|
||||
|
||||
type PacketType interface {
|
||||
@ -450,6 +479,11 @@ type RpcPacketType interface {
|
||||
GetPacketId() string
|
||||
}
|
||||
|
||||
type CommandPacketType interface {
|
||||
GetType() string
|
||||
GetCK() base.CommandKey
|
||||
}
|
||||
|
||||
func ParseJsonPacket(jsonBuf []byte) (PacketType, error) {
|
||||
var bareCmd BarePacketType
|
||||
err := json.Unmarshal(jsonBuf, &bareCmd)
|
||||
|
@ -17,7 +17,6 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/creack/pty"
|
||||
"github.com/google/uuid"
|
||||
"github.com/scripthaus-dev/mshell/pkg/base"
|
||||
"github.com/scripthaus-dev/mshell/pkg/mpio"
|
||||
"github.com/scripthaus-dev/mshell/pkg/packet"
|
||||
@ -39,21 +38,19 @@ const RemoteSudoPasswordCommandFmt = `cat /dev/fd/%d | sudo -S -C %d bash -c "ec
|
||||
type ShExecType struct {
|
||||
Lock *sync.Mutex
|
||||
StartTs time.Time
|
||||
SessionId string
|
||||
CmdId string
|
||||
CK base.CommandKey
|
||||
FileNames *base.CommandFileNames
|
||||
Cmd *exec.Cmd
|
||||
CmdPty *os.File
|
||||
Multiplexer *mpio.Multiplexer
|
||||
}
|
||||
|
||||
func MakeShExec(sessionId string, cmdId string) *ShExecType {
|
||||
func MakeShExec(ck base.CommandKey) *ShExecType {
|
||||
return &ShExecType{
|
||||
Lock: &sync.Mutex{},
|
||||
StartTs: time.Now(),
|
||||
SessionId: sessionId,
|
||||
CmdId: cmdId,
|
||||
Multiplexer: mpio.MakeMultiplexer(sessionId, cmdId),
|
||||
CK: ck,
|
||||
Multiplexer: mpio.MakeMultiplexer(ck),
|
||||
}
|
||||
}
|
||||
|
||||
@ -67,8 +64,7 @@ func (c *ShExecType) Close() {
|
||||
func (c *ShExecType) MakeCmdStartPacket() *packet.CmdStartPacketType {
|
||||
startPacket := packet.MakeCmdStartPacket()
|
||||
startPacket.Ts = time.Now().UnixMilli()
|
||||
startPacket.SessionId = c.SessionId
|
||||
startPacket.CmdId = c.CmdId
|
||||
startPacket.CK = c.CK
|
||||
startPacket.Pid = c.Cmd.Process.Pid
|
||||
startPacket.MShellPid = os.Getpid()
|
||||
return startPacket
|
||||
@ -129,12 +125,12 @@ func MakeExecCmd(pk *packet.RunPacketType, cmdTty *os.File) *exec.Cmd {
|
||||
return ecmd
|
||||
}
|
||||
|
||||
func MakeRunnerExec(cmdId string) (*exec.Cmd, error) {
|
||||
func MakeRunnerExec(ck base.CommandKey) (*exec.Cmd, error) {
|
||||
msPath, err := base.GetMShellPath()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
ecmd := exec.Command(msPath, cmdId)
|
||||
ecmd := exec.Command(msPath, string(ck))
|
||||
return ecmd, nil
|
||||
}
|
||||
|
||||
@ -165,19 +161,9 @@ func ValidateRunPacket(pk *packet.RunPacketType) error {
|
||||
return fmt.Errorf("run packet has wrong type: %s", pk.Type)
|
||||
}
|
||||
if pk.Detached {
|
||||
if pk.SessionId == "" {
|
||||
return fmt.Errorf("run packet does not have sessionid")
|
||||
}
|
||||
_, err := uuid.Parse(pk.SessionId)
|
||||
err := pk.CK.Validate("run packet")
|
||||
if err != nil {
|
||||
return fmt.Errorf("invalid sessionid '%s' for command", pk.SessionId)
|
||||
}
|
||||
if pk.CmdId == "" {
|
||||
return fmt.Errorf("run packet does not have cmdid")
|
||||
}
|
||||
_, err = uuid.Parse(pk.CmdId)
|
||||
if err != nil {
|
||||
return fmt.Errorf("invalid cmdid '%s' for command", pk.CmdId)
|
||||
return err
|
||||
}
|
||||
}
|
||||
if pk.Cwd != "" {
|
||||
@ -337,7 +323,7 @@ func RunClientSSHCommandAndWait(opts *ClientOpts) (*packet.CmdDonePacketType, er
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
cmd := MakeShExec("", "")
|
||||
cmd := MakeShExec("")
|
||||
var fullSshOpts []string
|
||||
fullSshOpts = append(fullSshOpts, opts.SSHOpts...)
|
||||
fullSshOpts = append(fullSshOpts, SSHRemoteCommand)
|
||||
@ -430,7 +416,7 @@ func (cmd *ShExecType) RunRemoteIOAndWait(packetCh chan packet.PacketType, sende
|
||||
}
|
||||
|
||||
func runCommandSimple(pk *packet.RunPacketType, sender *packet.PacketSender) (*ShExecType, error) {
|
||||
cmd := MakeShExec(pk.SessionId, pk.CmdId)
|
||||
cmd := MakeShExec(pk.CK)
|
||||
cmd.Cmd = exec.Command("bash", "-c", pk.Command)
|
||||
UpdateCmdEnv(cmd.Cmd, pk.Env)
|
||||
if pk.Cwd != "" {
|
||||
@ -491,7 +477,7 @@ func runCommandSimple(pk *packet.RunPacketType, sender *packet.PacketSender) (*S
|
||||
}
|
||||
|
||||
func runCommandDetached(pk *packet.RunPacketType, sender *packet.PacketSender) (*ShExecType, error) {
|
||||
fileNames, err := base.GetCommandFileNames(pk.SessionId, pk.CmdId)
|
||||
fileNames, err := base.GetCommandFileNames(pk.CK)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@ -499,7 +485,7 @@ func runCommandDetached(pk *packet.RunPacketType, sender *packet.PacketSender) (
|
||||
if err == nil { // non-nil error will be caught by regular OpenFile below
|
||||
// must have size 0
|
||||
if ptyOutInfo.Size() != 0 {
|
||||
return nil, fmt.Errorf("cmdid '%s' was already used (ptyout len=%d)", pk.CmdId, ptyOutInfo.Size())
|
||||
return nil, fmt.Errorf("cmdkey '%s' was already used (ptyout len=%d)", pk.CK, ptyOutInfo.Size())
|
||||
}
|
||||
}
|
||||
cmdPty, cmdTty, err := pty.Open()
|
||||
@ -510,7 +496,7 @@ func runCommandDetached(pk *packet.RunPacketType, sender *packet.PacketSender) (
|
||||
defer func() {
|
||||
cmdTty.Close()
|
||||
}()
|
||||
rtn := MakeShExec(pk.SessionId, pk.CmdId)
|
||||
rtn := MakeShExec(pk.CK)
|
||||
ecmd := MakeExecCmd(pk, cmdTty)
|
||||
err = ecmd.Start()
|
||||
if err != nil {
|
||||
@ -558,8 +544,7 @@ func (c *ShExecType) WaitForCommand() *packet.CmdDonePacketType {
|
||||
exitCode := GetExitCode(exitErr)
|
||||
donePacket := packet.MakeCmdDonePacket()
|
||||
donePacket.Ts = endTs.UnixMilli()
|
||||
donePacket.SessionId = c.SessionId
|
||||
donePacket.CmdId = c.CmdId
|
||||
donePacket.CK = c.CK
|
||||
donePacket.ExitCode = exitCode
|
||||
donePacket.DurationMs = int64(cmdDuration / time.Millisecond)
|
||||
if c.FileNames != nil {
|
||||
|
Loading…
Reference in New Issue
Block a user