package main

import (
	"context"
	"encoding/json"
	"errors"
	"fmt"
	"io"
	"io/fs"
	"net/http"
	"os"
	"strings"
	"sync"
	"time"

	"github.com/google/uuid"
	"github.com/gorilla/mux"

	"github.com/scripthaus-dev/mshell/pkg/base"
	"github.com/scripthaus-dev/mshell/pkg/cmdtail"
	"github.com/scripthaus-dev/mshell/pkg/packet"
	"github.com/scripthaus-dev/sh2-server/pkg/remote"
	"github.com/scripthaus-dev/sh2-server/pkg/sstore"
	"github.com/scripthaus-dev/sh2-server/pkg/wsshell"
)

// HTTP server tuning shared by the main API server and the websocket server.
const HttpReadTimeout = 5 * time.Second
const HttpWriteTimeout = 21 * time.Second
const HttpMaxHeaderBytes = 60000
const HttpTimeoutDuration = 21 * time.Second

// Listen addresses for the two servers started by main.
const WebSocketServerAddr = "localhost:8081"
const MainServerAddr = "localhost:8080"

// WSStateReconnectTime is how long a WSState is kept after its websocket
// handler returns, so that a client can reconnect to the same state.
const WSStateReconnectTime = 30 * time.Second

// WSStatePacketChSize is the buffer size of each WSState's PacketCh.
const WSStatePacketChSize = 20

// MaxInputDataSize is the maximum length of InputData accepted by sendCmdInput.
const MaxInputDataSize = 1000

// GlobalMShellProc is the runner process launched by main; HTTP handlers send
// packets to it.
var GlobalMShellProc *remote.MShellProc

// GlobalLock guards WSStateMap.
var GlobalLock = &sync.Mutex{}

// WSStateMap maps clientid -> WSState.
var WSStateMap = make(map[string]*WSState)

// setWSState registers state in WSStateMap under its ClientId, replacing any
// previous entry for that client.
func setWSState(state *WSState) {
	GlobalLock.Lock()
	defer GlobalLock.Unlock()
	WSStateMap[state.ClientId] = state
}

// getWSState returns the WSState registered for clientId, or nil if none.
func getWSState(clientId string) *WSState {
	GlobalLock.Lock()
	defer GlobalLock.Unlock()
	return WSStateMap[clientId]
}

// removeWSStateAfterTimeout starts a goroutine that sleeps for waitDuration and
// then removes the state for clientId from WSStateMap and closes its tailer.
// The removal is skipped when the state is gone or when its connect time no
// longer matches connectTime (i.e. the client reconnected in the meantime).
func removeWSStateAfterTimeout(clientId string, connectTime time.Time, waitDuration time.Duration) {
	go func() {
		time.Sleep(waitDuration)
		GlobalLock.Lock()
		defer GlobalLock.Unlock()
		state := WSStateMap[clientId]
		// ConnectTime is written under state.Lock by updateConnectTime, so it
		// must be read through the locked accessor to avoid a data race.
		if state == nil || !state.getConnectTime().Equal(connectTime) {
			return
		}
		delete(WSStateMap, clientId)
		if err := state.CloseTailer(); err != nil {
			fmt.Printf("[error] closing tailer on ws %v\n", err)
		}
	}()
}

// WSState is the per-client websocket state. It survives individual websocket
// connections so that a reconnecting client picks up the same tailer.
type WSState struct {
	Lock        *sync.Mutex // guards ConnectTime and Shell
	ClientId    string
	ConnectTime time.Time
	Shell       *wsshell.WSShell
	Tailer      *cmdtail.Tailer
	PacketCh    chan packet.PacketType // packets produced by Tailer
}

// MakeWSState creates a WSState for clientId with a new tailer, and starts two
// goroutines: one running the tailer (closing PacketCh when Run returns) and
// one forwarding packets from PacketCh to the current websocket shell.
func MakeWSState(clientId string) (*WSState, error) {
	var err error
	rtn := &WSState{}
	rtn.Lock = &sync.Mutex{}
	rtn.ClientId = clientId
	rtn.ConnectTime = time.Now()
	rtn.PacketCh = make(chan packet.PacketType, WSStatePacketChSize)
	chSender := packet.MakeChannelPacketSender(rtn.PacketCh)
	rtn.Tailer, err = cmdtail.MakeTailer(chSender)
	if err != nil {
		return nil, err
	}
	go func() {
		// closing PacketCh ends the range loop in runTailerToWS
		defer close(rtn.PacketCh)
		rtn.Tailer.Run()
	}()
	go rtn.runTailerToWS()
	return rtn, nil
}

// CloseTailer closes the state's tailer.
func (ws *WSState) CloseTailer() error {
	return ws.Tailer.Close()
}

// getShell returns the currently attached websocket shell (may be nil).
func (ws *WSState) getShell() *wsshell.WSShell {
	ws.Lock.Lock()
	defer ws.Lock.Unlock()
	return ws.Shell
}

// runTailerToWS forwards "cmddata" packets from PacketCh to the websocket until
// PacketCh is closed. Other packet types are logged and dropped.
func (ws *WSState) runTailerToWS() {
	for pk := range ws.PacketCh {
		dataPacket, ok := pk.(*packet.CmdDataPacketType)
		if pk.GetType() != "cmddata" || !ok {
			fmt.Printf("tailer-to-ws, bad packet %v\n", pk.GetType())
			continue
		}
		if err := ws.writePacket(dataPacket); err != nil {
			fmt.Printf("[error] writing packet to ws: %v\n", err)
		}
	}
}

// writePacket writes pk as JSON to the attached shell. It returns an error when
// no shell is attached, the shell is closed, or the write fails.
func (ws *WSState) writePacket(pk packet.PacketType) error {
	shell := ws.getShell()
	if shell == nil || shell.IsClosed() {
		return fmt.Errorf("cannot write packet, empty or closed wsshell")
	}
	return shell.WriteJson(pk)
}

// getConnectTime returns the time of the most recent (re)connect.
func (ws *WSState) getConnectTime() time.Time {
	ws.Lock.Lock()
	defer ws.Lock.Unlock()
	return ws.ConnectTime
}

// updateConnectTime sets the connect time to now.
func (ws *WSState) updateConnectTime() {
	ws.Lock.Lock()
	defer ws.Lock.Unlock()
	ws.ConnectTime = time.Now()
}

// replaceExistingShell attaches shell to the state, closing the connection of
// any previously attached shell first.
func (ws *WSState) replaceExistingShell(shell *wsshell.WSShell) {
	ws.Lock.Lock()
	defer ws.Lock.Unlock()
	if ws.Shell != nil {
		ws.Shell.Conn.Close()
	}
	ws.Shell = shell
}

// HandleWs upgrades the request to a websocket, attaches it to the WSState for
// the "clientid" query parameter (creating the state if needed), and then
// processes incoming messages until the shell's ReadChan is closed:
//   - "getcmd" packets add a watch to the state's tailer
//   - "input" packets are written to the command's stdin fifo (asynchronously)
//
// When the handler returns, the state is scheduled for removal after
// WSStateReconnectTime unless the client reconnects first.
func HandleWs(w http.ResponseWriter, r *http.Request) {
	shell, err := wsshell.StartWS(w, r)
	if err != nil {
		fmt.Printf("WebSocket Upgrade Failed %T: %v\n", w, err)
		return
	}
	defer shell.Conn.Close()
	clientId := r.URL.Query().Get("clientid")
	if clientId == "" {
		close(shell.WriteChan)
		return
	}
	state := getWSState(clientId)
	if state == nil {
		state, err = MakeWSState(clientId)
		if err != nil {
			fmt.Printf("cannot make wsstate: %v\n", err)
			close(shell.WriteChan)
			return
		}
		state.replaceExistingShell(shell)
		setWSState(state)
	} else {
		state.updateConnectTime()
		state.replaceExistingShell(shell)
	}
	// Captured now so the delayed removal can tell whether a newer connection
	// has taken over this state in the meantime.
	stateConnectTime := state.getConnectTime()
	defer func() {
		removeWSStateAfterTimeout(clientId, stateConnectTime, WSStateReconnectTime)
	}()
	shell.WriteJson(map[string]interface{}{"type": "hello"}) // let client know we accepted this connection, ignore error
	fmt.Printf("WebSocket opened %s %s\n", shell.RemoteAddr, state.ClientId)
	for msgBytes := range shell.ReadChan {
		pk, err := packet.ParseJsonPacket(msgBytes)
		if err != nil {
			fmt.Printf("error unmarshalling ws message: %v\n", err)
			continue
		}
		switch pk.GetType() {
		case "getcmd":
			getPk, ok := pk.(*packet.GetCmdPacketType)
			if !ok {
				fmt.Printf("got ws bad message: %v\n", pk.GetType())
				continue
			}
			if err := state.Tailer.AddWatch(getPk); err != nil {
				fmt.Printf("[error] adding watch to tailer: %v\n", err)
			}
		case "input":
			// Assert the type here rather than inside the goroutine: a panic in
			// a spawned goroutine would take down the whole process.
			inputPk, ok := pk.(*packet.InputPacketType)
			if !ok {
				fmt.Printf("got ws bad message: %v\n", pk.GetType())
				continue
			}
			// Writing to the fifo can block, so do it off the read loop.
			go func() {
				if err := sendCmdInput(inputPk); err != nil {
					fmt.Printf("[error] sending command input: %v\n", err)
				}
			}()
		default:
			fmt.Printf("got ws bad message: %v\n", pk.GetType())
		}
	}
}

// writeToFifo writes data to the named fifo.
//
// todo: sync multiple writes to the same fifoName into a single go-routine and do liveness checking on fifo
// if this returns an error, likely the fifo is dead and the cmd should be marked as 'done'
func writeToFifo(fifoName string, data []byte) error {
	// Holding the fifo open read-write guarantees the write-only open below
	// does not block waiting for a reader.
	rwfd, err := os.OpenFile(fifoName, os.O_RDWR, 0600)
	if err != nil {
		return err
	}
	defer rwfd.Close()
	fifoWriter, err := os.OpenFile(fifoName, os.O_WRONLY, 0600) // blocking open (open won't block because of rwfd)
	if err != nil {
		return err
	}
	defer fifoWriter.Close()
	// this *could* block if the fifo buffer is full
	// unlikely because if the reader is dead, and len(data) < pipe size, then the buffer will be empty and will clear after rwfd is closed
	_, err = fifoWriter.Write(data)
	return err
}

// sendCmdInput validates pk and writes its InputData to the stdin fifo of the
// command identified by pk.CK. Input longer than MaxInputDataSize is rejected.
func sendCmdInput(pk *packet.InputPacketType) error {
	if err := pk.CK.Validate("input packet"); err != nil {
		return err
	}
	if len(pk.InputData) > MaxInputDataSize {
		return fmt.Errorf("input data size too large, len=%d (max=%d)", len(pk.InputData), MaxInputDataSize)
	}
	fileNames, err := base.GetCommandFileNames(pk.CK)
	if err != nil {
		return err
	}
	return writeToFifo(fileNames.StdinFifo, []byte(pk.InputData))
}

// GetPtyOutFile returns the path of the .ptyout file for the given session and
// command.
//
// NOTE(review): the sessions directory is hard-coded to a developer's home
// directory; this should come from configuration.
func GetPtyOutFile(sessionId string, cmdId string) string {
	pathStr := fmt.Sprintf("/Users/mike/scripthaus/.sessions/%s/%s.ptyout", sessionId, cmdId)
	return pathStr
}

// isSafePathComponent reports whether id can be used as a single path element:
// non-empty, not "." or "..", and free of path separators.
func isSafePathComponent(id string) bool {
	if id == "" || id == "." || id == ".." {
		return false
	}
	return !strings.ContainsAny(id, `/\`)
}

// GetPtyOut handles /api/ptyout. It streams the contents of the ptyout file
// for the "sessionid" and "cmdid" query parameters. A file that does not exist
// yields an empty 200 response.
func GetPtyOut(w http.ResponseWriter, r *http.Request) {
	w.Header().Set("Access-Control-Allow-Origin", r.Header.Get("Origin"))
	w.Header().Set("Access-Control-Allow-Credentials", "true")
	w.Header().Set("Vary", "Origin")
	w.Header().Set("Cache-Control", "no-cache")
	qvals := r.URL.Query()
	sessionId := qvals.Get("sessionid")
	cmdId := qvals.Get("cmdid")
	if sessionId == "" || cmdId == "" {
		w.WriteHeader(500)
		w.Write([]byte("must specify sessionid and cmdid"))
		return
	}
	// Both ids come straight from the query string and are interpolated into a
	// filesystem path, so refuse anything that could escape the sessions dir.
	if !isSafePathComponent(sessionId) || !isSafePathComponent(cmdId) {
		w.WriteHeader(http.StatusBadRequest)
		w.Write([]byte("invalid sessionid or cmdid"))
		return
	}
	pathStr := GetPtyOutFile(sessionId, cmdId)
	fd, err := os.Open(pathStr)
	if err != nil {
		if errors.Is(err, fs.ErrNotExist) {
			// no output written yet: report success with an empty body
			w.WriteHeader(http.StatusOK)
			return
		}
		w.WriteHeader(500)
		w.Write([]byte(fmt.Sprintf("cannot open file '%s': %v", pathStr, err)))
		return
	}
	defer fd.Close()
	w.WriteHeader(http.StatusOK)
	if _, err := io.Copy(w, fd); err != nil {
		// headers are already sent, so the failure can only be logged
		fmt.Printf("[error] streaming ptyout file: %v\n", err)
	}
}

// WriteJsonError writes a 500 response whose JSON body is {"error": <message>}.
func WriteJsonError(w http.ResponseWriter, errVal error) {
	w.Header().Set("Content-Type", "application/json")
	w.WriteHeader(500)
	errMap := make(map[string]interface{})
	errMap["error"] = errVal.Error()
	// marshalling a map holding a single string value cannot fail
	barr, _ := json.Marshal(errMap)
	w.Write(barr)
}

// WriteJsonSuccess writes a 200 response whose JSON body is
// {"success": true, "data": <data>}; "data" is omitted when data is nil. If
// data cannot be marshalled, a JSON error response is written instead.
func WriteJsonSuccess(w http.ResponseWriter, data interface{}) {
	w.Header().Set("Content-Type", "application/json")
	rtnMap := make(map[string]interface{})
	rtnMap["success"] = true
	if data != nil {
		rtnMap["data"] = data
	}
	barr, err := json.Marshal(rtnMap)
	if err != nil {
		WriteJsonError(w, err)
		return
	}
	w.WriteHeader(200)
	w.Write(barr)
}

// runCommandParams is the JSON request body of /api/run-command.
type runCommandParams struct {
	SessionId string `json:"sessionid"`
	WindowId  string `json:"windowid"`
	Command   string `json:"command"`
}

// runCommandResponse is the "data" payload returned by /api/run-command.
type runCommandResponse struct {
	Line *sstore.LineType `json:"line"`
}

// HandleRunCommand handles /api/run-command. GET and OPTIONS requests get an
// empty 200 (CORS preflight). Otherwise the JSON body is decoded into
// runCommandParams and dispatched on the command text:
//   - "/comment <text>" creates a text line and returns it
//   - "cd <dir>" sends a cd packet to the runner
//   - anything else creates a command line, returns it, and sends a run packet
//     to the runner in the background
func HandleRunCommand(w http.ResponseWriter, r *http.Request) {
	w.Header().Set("Access-Control-Allow-Origin", r.Header.Get("Origin"))
	w.Header().Set("Access-Control-Allow-Credentials", "true")
	w.Header().Set("Access-Control-Allow-Methods", "POST, OPTIONS")
	w.Header().Set("Vary", "Origin")
	w.Header().Set("Cache-Control", "no-cache")
	if r.Method == "GET" || r.Method == "OPTIONS" {
		w.WriteHeader(200)
		return
	}
	decoder := json.NewDecoder(r.Body)
	var params runCommandParams
	err := decoder.Decode(&params)
	if err != nil {
		WriteJsonError(w, fmt.Errorf("error decoding json: %w", err))
		return
	}
	if _, err = uuid.Parse(params.SessionId); err != nil {
		WriteJsonError(w, fmt.Errorf("invalid sessionid '%s': %w", params.SessionId, err))
		return
	}
	commandStr := strings.TrimSpace(params.Command)
	if commandStr == "" {
		WriteJsonError(w, fmt.Errorf("invalid empty command"))
		return
	}
	if strings.HasPrefix(commandStr, "/comment ") {
		text := strings.TrimSpace(commandStr[len("/comment "):])
		rtnLine := sstore.MakeNewLineText(params.SessionId, params.WindowId, text)
		WriteJsonSuccess(w, &runCommandResponse{Line: rtnLine})
		return
	}
	if strings.HasPrefix(commandStr, "cd ") {
		newDir := strings.TrimSpace(commandStr[len("cd "):])
		cdPacket := packet.MakeCdPacket()
		cdPacket.PacketId = uuid.New().String()
		cdPacket.Dir = newDir
		GlobalMShellProc.Input.SendPacket(cdPacket)
		// NOTE(review): no response body is written for "cd"; the client gets
		// an empty 200 rather than a JSON success object -- confirm intended.
		return
	}
	rtnLine := sstore.MakeNewLineCmd(params.SessionId, params.WindowId)
	// rtnLine.CmdText = commandStr
	runPacket := packet.MakeRunPacket()
	runPacket.CK = base.MakeCommandKey(params.SessionId, rtnLine.CmdId)
	runPacket.Cwd = ""
	runPacket.Env = nil
	runPacket.Command = commandStr
	fmt.Printf("run-packet %v\n", runPacket)
	WriteJsonSuccess(w, &runCommandResponse{Line: rtnLine})
	// The response has already been written; hand the packets to the runner
	// without holding up the request.
	go func() {
		GlobalMShellProc.Input.SendPacket(runPacket)
		if !GlobalMShellProc.Local {
			getPacket := packet.MakeGetCmdPacket()
			getPacket.CK = runPacket.CK
			getPacket.Tail = true
			GlobalMShellProc.Input.SendPacket(getPacket)
		}
	}()
}

// API design notes. Only /api/ptyout and /api/run-command are registered in
// main below; the remaining entries describe endpoints and messages that are
// not wired up in this file.
//
// /api/start-session
//   returns:
//     * userid
//     * sessionid
//
// /api/ptyout (pos=[position]) - returns contents of ptyout file
//   params:
//     * sessionid
//     * cmdid
//     * pos
//   returns:
//     * stream of ptyout file (text, utf-8)
//
// POST /api/run-command
//   params
//     * userid
//     * sessionid
//   returns
//     * cmdid
//
// /api/refresh-session
//   params
//     * sessionid
//     * start -- can be negative
//     * numlines
//   returns
//     * permissions (readonly, comment, command)
//     * lines
//       * lineid
//       * ts
//       * userid
//       * linetype
//       * text
//       * cmdid

// /ws
//   ->watch-session:
//     * sessionid
//   ->watch:
//     * sessionid
//     * cmdid
//   ->focus:
//     * sessionid
//     * cmdid
//   ->input:
//     * sessionid
//     * cmdid
//     * data
//   ->signal:
//     * sessionid
//     * cmdid
//     * data
//   <-data:
//     * sessionid
//     * cmdid
//     * pos
//     * data
//   <-session-data:
//     * sessionid
//     * line

// session-doc
//   timestamp | user | cmd-type | data
//   cmd-type = comment
//   cmd-type = command, commandid=ABC

// how to know if command is still executing? is command done?

// local -- .ptyout, .stdin
// remote -- transfer controller program
//   controller-startcmd -- start command (with options) => returns cmdid
//   controller-watchsession [sessionid]
//     transfer [cmdid:pos] pairs. streams back anything new written to ptyout on stdout
//     stdin-packet [cmdid:user:data]
//   startcmd will figure out the correct
//

// runWebSocketServer serves the /ws endpoint on WebSocketServerAddr. It blocks
// until the server stops and logs the resulting error.
func runWebSocketServer() {
	gr := mux.NewRouter()
	gr.HandleFunc("/ws", HandleWs)
	server := &http.Server{
		Addr:           WebSocketServerAddr,
		ReadTimeout:    HttpReadTimeout,
		WriteTimeout:   HttpWriteTimeout,
		MaxHeaderBytes: HttpMaxHeaderBytes,
		Handler:        gr,
	}
	server.SetKeepAlivesEnabled(false)
	fmt.Printf("Running websocket server on %s\n", WebSocketServerAddr)
	err := server.ListenAndServe()
	if err != nil {
		fmt.Printf("[error] trying to run websocket server: %v\n", err)
	}
}

// main either runs a migration command (when the first argument starts with
// "--migrate") or prepares the database, launches the local runner process,
// and serves the websocket and main HTTP APIs.
func main() {
	if len(os.Args) >= 2 && strings.HasPrefix(os.Args[1], "--migrate") {
		err := sstore.MigrateCommandOpts(os.Args[1:])
		if err != nil {
			fmt.Printf("[error] %v\n", err)
		}
		return
	}
	err := sstore.TryMigrateUp()
	if err != nil {
		fmt.Printf("[error] %v\n", err)
		return
	}
	numSessions, err := sstore.NumSessions(context.Background())
	if err != nil {
		fmt.Printf("[error] getting num sessions: %v\n", err)
		return
	}
	err = sstore.EnsureLocalRemote(context.Background())
	if err != nil {
		fmt.Printf("[error] ensuring local remote: %v\n", err)
		return
	}
	fmt.Printf("[db] sessions count=%d\n", numSessions)
	if numSessions == 0 {
		sstore.CreateInitialSession(context.Background())
	}
	// A stray unconditional `return` used to sit here, which made the runner
	// launch and both servers below unreachable.
	runnerProc, err := remote.LaunchMShell()
	if err != nil {
		fmt.Printf("error launching runner-proc: %v\n", err)
		return
	}
	GlobalMShellProc = runnerProc
	go runnerProc.ProcessPackets()
	fmt.Printf("Started local runner pid[%d]\n", runnerProc.Cmd.Process.Pid)
	go runWebSocketServer()
	gr := mux.NewRouter()
	gr.HandleFunc("/api/ptyout", GetPtyOut)
	gr.HandleFunc("/api/run-command", HandleRunCommand).Methods("GET", "POST", "OPTIONS")
	server := &http.Server{
		Addr:           MainServerAddr,
		ReadTimeout:    HttpReadTimeout,
		WriteTimeout:   HttpWriteTimeout,
		MaxHeaderBytes: HttpMaxHeaderBytes,
		Handler:        http.TimeoutHandler(gr, HttpTimeoutDuration, "Timeout"),
	}
	server.SetKeepAlivesEnabled(false)
	fmt.Printf("Running main server on %s\n", MainServerAddr)
	err = server.ListenAndServe()
	if err != nil {
		fmt.Printf("ERROR: %v\n", err)
	}
}