diff --git a/cmd/main-server.go b/cmd/main-server.go index 20d44fbd1..a177f7e9d 100644 --- a/cmd/main-server.go +++ b/cmd/main-server.go @@ -85,7 +85,8 @@ func MakeWSState(clientId string) (*WSState, error) { rtn.ClientId = clientId rtn.ConnectTime = time.Now() rtn.PacketCh = make(chan packet.PacketType, WSStatePacketChSize) - rtn.Tailer, err = cmdtail.MakeTailer(rtn.PacketCh) + chSender := packet.MakeChannelPacketSender(rtn.PacketCh) + rtn.Tailer, err = cmdtail.MakeTailer(chSender) if err != nil { return nil, err } @@ -166,7 +167,7 @@ type MShellProc struct { Lock *sync.Mutex Cmd *exec.Cmd Input *packet.PacketSender - Output chan packet.PacketType + Output *packet.PacketParser Local bool DoneCh chan bool CurDir string @@ -271,17 +272,14 @@ func writeToFifo(fifoName string, data []byte) error { } func sendCmdInput(pk *packet.InputPacketType) error { - var err error - if _, err = uuid.Parse(pk.SessionId); err != nil { - return fmt.Errorf("invalid sessionid '%s': %w", pk.SessionId, err) - } - if _, err = uuid.Parse(pk.CmdId); err != nil { - return fmt.Errorf("invalid cmdid '%s': %w", pk.CmdId, err) + err := pk.CK.Validate("input packet") + if err != nil { + return err } if len(pk.InputData) > MaxInputDataSize { return fmt.Errorf("input data size too large, len=%d (max=%d)", len(pk.InputData), MaxInputDataSize) } - fileNames, err := base.GetCommandFileNames(pk.SessionId, pk.CmdId) + fileNames, err := base.GetCommandFileNames(pk.CK) if err != nil { return err } @@ -405,8 +403,7 @@ func HandleRunCommand(w http.ResponseWriter, r *http.Request) { rtnLine := sstore.MakeNewLineCmd(params.SessionId, params.WindowId) rtnLine.CmdText = commandStr runPacket := packet.MakeRunPacket() - runPacket.SessionId = params.SessionId - runPacket.CmdId = rtnLine.CmdId + runPacket.CK = base.MakeCommandKey(params.SessionId, rtnLine.CmdId) runPacket.Cwd = "" runPacket.Env = nil runPacket.Command = commandStr @@ -416,8 +413,7 @@ func HandleRunCommand(w http.ResponseWriter, r *http.Request) { GlobalMShellProc.Input.SendPacket(runPacket) if !GlobalMShellProc.Local { getPacket := packet.MakeGetCmdPacket() - getPacket.SessionId = runPacket.SessionId - getPacket.CmdId = runPacket.CmdId + getPacket.CK = runPacket.CK getPacket.Tail = true GlobalMShellProc.Input.SendPacket(getPacket) } @@ -503,7 +499,10 @@ func HandleRunCommand(w http.ResponseWriter, r *http.Request) { // func LaunchMShell() (*MShellProc, error) { - msPath := base.GetMShellPath() + msPath, err := base.GetMShellPath() + if err != nil { + return nil, err + } ecmd := exec.Command(msPath) inputWriter, err := ecmd.StdinPipe() if err != nil { @@ -513,13 +512,13 @@ func LaunchMShell() (*MShellProc, error) { if err != nil { return nil, err } - ecmd.Stderr = ecmd.Stdout // /dev/null + ecmd.Stderr = ecmd.Stdout err = ecmd.Start() if err != nil { return nil, err } rtn := &MShellProc{Lock: &sync.Mutex{}, Local: true, Cmd: ecmd} - rtn.Output = packet.PacketParser(outputReader) + rtn.Output = packet.MakePacketParser(outputReader) rtn.Input = packet.MakePacketSender(inputWriter) rtn.RpcMap = make(map[string]*RpcEntry) rtn.DoneCh = make(chan bool) @@ -559,7 +558,7 @@ func (runner *MShellProc) PacketRpc(pk packet.RpcPacketType, timeout time.Durati } func (runner *MShellProc) ProcessPackets() { - for pk := range runner.Output { + for pk := range runner.Output.MainCh { if rpcPk, ok := pk.(packet.RpcPacketType); ok { rpcId := rpcPk.GetPacketId() runner.Lock.Lock() @@ -576,12 +575,12 @@ func (runner *MShellProc) ProcessPackets() { } if pk.GetType() == packet.CmdDataPacketStr { dataPacket := pk.(*packet.CmdDataPacketType) - fmt.Printf("cmd-data %s/%s pty=%d run=%d\n", dataPacket.SessionId, dataPacket.CmdId, len(dataPacket.PtyData), len(dataPacket.RunData)) + fmt.Printf("cmd-data %s pty=%d run=%d\n", dataPacket.CK, len(dataPacket.PtyData), len(dataPacket.RunData)) continue } - if pk.GetType() == packet.RunnerInitPacketStr { - initPacket := pk.(*packet.RunnerInitPacketType) - fmt.Printf("runner-init %s user=%s dir=%s\n", initPacket.ScHomeDir, initPacket.User, initPacket.HomeDir) + if pk.GetType() == packet.InitPacketStr { + initPacket := pk.(*packet.InitPacketType) + fmt.Printf("runner-init %s user=%s dir=%s\n", initPacket.MShellHomeDir, initPacket.User, initPacket.HomeDir) runner.Lock.Lock() runner.Initialized = true runner.User = initPacket.User diff --git a/go.mod b/go.mod index 234a4331f..fb0ee179d 100644 --- a/go.mod +++ b/go.mod @@ -3,13 +3,17 @@ module github.com/scripthaus-dev/sh2-server go 1.17 require ( - github.com/creack/pty v1.1.18 // indirect - github.com/fsnotify/fsnotify v1.5.4 // indirect - github.com/google/uuid v1.3.0 // indirect - github.com/gorilla/mux v1.8.0 // indirect - github.com/gorilla/websocket v1.5.0 // indirect - golang.org/x/sys v0.0.0-20220412211240-33da011f77ad // indirect + github.com/google/uuid v1.3.0 + github.com/gorilla/mux v1.8.0 + github.com/gorilla/websocket v1.5.0 github.com/scripthaus-dev/mshell v0.0.0 ) -replace "github.com/scripthaus-dev/mshell" v0.0.0 => /Users/mike/work/gopath/src/github.com/scripthaus-dev/mshell/ +require ( + github.com/alessio/shellescape v1.4.1 // indirect + github.com/creack/pty v1.1.18 // indirect + github.com/fsnotify/fsnotify v1.5.4 // indirect + golang.org/x/sys v0.0.0-20220412211240-33da011f77ad // indirect +) + +replace github.com/scripthaus-dev/mshell v0.0.0 => /Users/mike/work/gopath/src/github.com/scripthaus-dev/mshell/ diff --git a/go.sum b/go.sum index a1b065d43..e0a22d41c 100644 --- a/go.sum +++ b/go.sum @@ -1,3 +1,5 @@ +github.com/alessio/shellescape v1.4.1 h1:V7yhSDDn8LP4lc4jS8pFkt0zCnzVJlG5JXy9BVKJUX0= +github.com/alessio/shellescape v1.4.1/go.mod h1:PZAiSCk0LJaZkiCSkPv8qIobYglO3FPpyFjDCtHLS30= github.com/creack/pty v1.1.18 h1:n56/Zwd5o6whRC5PMGretI4IdRLlmBXYNjScPaBgsbY= github.com/creack/pty v1.1.18/go.mod h1:MOBLtS5ELjhRRrroQr9kyvTxUAFNvYEK993ew/Vr4O4= github.com/fsnotify/fsnotify v1.5.4 h1:jRbGcIw6P2Meqdwuo0H1p6JVLbL5DHKAKlYndzMwVZI= diff --git a/pkg/scbase/scbase.go b/pkg/scbase/scbase.go new file mode 100644 index 000000000..830903b20 --- /dev/null +++ b/pkg/scbase/scbase.go @@ -0,0 +1,21 @@ +package scbase + +import ( + "os" + "path" +) + +const HomeVarName = "HOME" +const ScHomeVarName = "SCRIPTHAUS_HOME" + +func GetScHomeDir() string { + scHome := os.Getenv(ScHomeVarName) + if scHome == "" { + homeVar := os.Getenv(HomeVarName) + if homeVar == "" { + homeVar = "/" + } + scHome = path.Join(homeVar, "scripthaus") + } + return scHome +} diff --git a/pkg/sstore/sstore.go b/pkg/sstore/sstore.go index ddbbc10c8..8ed346223 100644 --- a/pkg/sstore/sstore.go +++ b/pkg/sstore/sstore.go @@ -1,10 +1,12 @@ package sstore import ( + "path" "sync" "time" "github.com/google/uuid" + "github.com/scripthaus-dev/sh2-server/pkg/scbase" ) var NextLineId = 10 @@ -12,6 +14,12 @@ var NextLineLock = &sync.Mutex{} const LineTypeCmd = "cmd" const LineTypeText = "text" +const DBFileName = "scripthaus.db" + +func GetSessionDBName(sessionId string) string { + scHome := scbase.GetScHomeDir() + return path.Join(scHome, DBFileName) +} type SessionType struct { SessionId string `json:"sessionid"`