From 2a5cde908a4da725544dfa2b51e306f54c3a8bcc Mon Sep 17 00:00:00 2001 From: sawka Date: Thu, 16 Jun 2022 22:22:47 -0700 Subject: [PATCH] got streaming ptyout via websockets working --- cmd/main-server.go | 215 +++++++++++++++++++++++++++++++---------- pkg/wsshell/wsshell.go | 29 +++++- 2 files changed, 191 insertions(+), 53 deletions(-) diff --git a/cmd/main-server.go b/cmd/main-server.go index 736d35f0c..3dca07eb2 100644 --- a/cmd/main-server.go +++ b/cmd/main-server.go @@ -2,8 +2,10 @@ package main import ( "encoding/json" + "errors" "fmt" "io" + "io/fs" "net/http" "os" "os/exec" @@ -29,35 +31,137 @@ const HttpTimeoutDuration = 21 * time.Second const WebSocketServerAddr = "localhost:8081" const MainServerAddr = "localhost:8080" +const WSStateReconnectTime = 30 * time.Second +const WSStatePacketChSize = 20 var GlobalRunnerProc *RunnerProc +var GlobalLock = &sync.Mutex{} +var WSStateMap = make(map[string]*WSState) // clientid -> WsState -type WsConnType struct { - Id string - Shell *wsshell.WSShell - Tailer *cmdtail.Tailer +func setWSState(state *WSState) { + GlobalLock.Lock() + defer GlobalLock.Unlock() + WSStateMap[state.ClientId] = state +} + +func getWSState(clientId string) *WSState { + GlobalLock.Lock() + defer GlobalLock.Unlock() + return WSStateMap[clientId] +} + +func removeWSStateAfterTimeout(clientId string, connectTime time.Time, waitDuration time.Duration) { + go func() { + time.Sleep(waitDuration) + GlobalLock.Lock() + defer GlobalLock.Unlock() + state := WSStateMap[clientId] + if state == nil || state.ConnectTime != connectTime { + return + } + delete(WSStateMap, clientId) + err := state.CloseTailer() + if err != nil { + fmt.Printf("[error] closing tailer on ws %v\n", err) + } + }() +} + +type WSState struct { + Lock *sync.Mutex + ClientId string + ConnectTime time.Time + Shell *wsshell.WSShell + Tailer *cmdtail.Tailer + PacketCh chan packet.PacketType +} + +func MakeWSState(clientId string) (*WSState, error) { + var err error + rtn := &WSState{} + rtn.Lock = &sync.Mutex{} + rtn.ClientId = clientId + rtn.ConnectTime = time.Now() + rtn.PacketCh = make(chan packet.PacketType, WSStatePacketChSize) + rtn.Tailer, err = cmdtail.MakeTailer(rtn.PacketCh) + if err != nil { + return nil, err + } + go func() { + defer close(rtn.PacketCh) + rtn.Tailer.Run() + }() + go rtn.runTailerToWS() + return rtn, nil +} + +func (ws *WSState) CloseTailer() error { + return ws.Tailer.Close() +} + +func (ws *WSState) getShell() *wsshell.WSShell { + ws.Lock.Lock() + defer ws.Lock.Unlock() + return ws.Shell +} + +func (ws *WSState) runTailerToWS() { + for pk := range ws.PacketCh { + if pk.GetType() == "cmddata" { + dataPacket := pk.(*packet.CmdDataPacketType) + err := ws.writePacket(dataPacket) + if err != nil { + fmt.Printf("[error] writing packet to ws: %v\n", err) + } + continue + } + fmt.Printf("tailer-to-ws, bad packet %v\n", pk.GetType()) + } +} + +func (ws *WSState) writePacket(pk packet.PacketType) error { + shell := ws.getShell() + if shell == nil || shell.IsClosed() { + return fmt.Errorf("cannot write packet, empty or closed wsshell") + } + err := shell.WriteJson(pk) + if err != nil { + return err + } + return nil +} + +func (ws *WSState) getConnectTime() time.Time { + ws.Lock.Lock() + defer ws.Lock.Unlock() + return ws.ConnectTime +} + +func (ws *WSState) updateConnectTime() { + ws.Lock.Lock() + defer ws.Lock.Unlock() + ws.ConnectTime = time.Now() +} + +func (ws *WSState) replaceExistingShell(shell *wsshell.WSShell) { + ws.Lock.Lock() + defer ws.Lock.Unlock() + if ws.Shell == nil { + ws.Shell = shell + return + } + ws.Shell.Conn.Close() + ws.Shell = shell + return } type RunnerProc struct { - Lock *sync.Mutex - Cmd *exec.Cmd - Input *packet.PacketSender - Output chan packet.PacketType - WsConnMap map[string]*WsConnType - Local bool - DoneCh chan bool -} - -func (rp *RunnerProc) AddWsConn(ws *WsConnType) { - rp.Lock.Lock() - defer rp.Lock.Unlock() - rp.WsConnMap[ws.Id] = ws -} - -func (rp *RunnerProc) RemoveWsConn(ws *WsConnType) { - rp.Lock.Lock() - defer rp.Lock.Unlock() - delete(rp.WsConnMap, ws.Id) + Lock *sync.Mutex + Cmd *exec.Cmd + Input *packet.PacketSender + Output chan packet.PacketType + Local bool + DoneCh chan bool } func HandleWs(w http.ResponseWriter, r *http.Request) { @@ -66,20 +170,46 @@ func HandleWs(w http.ResponseWriter, r *http.Request) { fmt.Printf("WebSocket Upgrade Failed %T: %v\n", w, err) return } - wsConn := &WsConnType{Id: uuid.New().String(), Shell: shell} - GlobalRunnerProc.AddWsConn(wsConn) + defer shell.Conn.Close() + clientId := r.URL.Query().Get("clientid") + if clientId == "" { + close(shell.WriteChan) + return + } + state := getWSState(clientId) + if state == nil { + state, err = MakeWSState(clientId) + if err != nil { + fmt.Printf("cannot make wsstate: %v\n", err) + close(shell.WriteChan) + return + } + state.replaceExistingShell(shell) + setWSState(state) + } else { + state.updateConnectTime() + state.replaceExistingShell(shell) + } + stateConnectTime := state.getConnectTime() defer func() { - GlobalRunnerProc.RemoveWsConn(wsConn) - wsConn.Shell.Conn.Close() + removeWSStateAfterTimeout(clientId, stateConnectTime, WSStateReconnectTime) }() - fmt.Printf("WebSocket opened %s\n", shell.RemoteAddr) + shell.WriteJson(map[string]interface{}{"type": "hello"}) // let client know we accepted this connection, ignore error + fmt.Printf("WebSocket opened %s %s\n", shell.RemoteAddr, state.ClientId) for msgBytes := range shell.ReadChan { pk, err := packet.ParseJsonPacket(msgBytes) if err != nil { fmt.Printf("error unmarshalling ws message: %v\n", err) continue } - fmt.Printf("got ws message: %v\n", pk) + if pk.GetType() == "getcmd" { + err = state.Tailer.AddWatch(pk.(*packet.GetCmdPacketType)) + if err != nil { + fmt.Printf("error adding watch to tailer: %v\n", err) + } + continue + } + fmt.Printf("got ws bad message: %v\n", pk.GetType()) } } @@ -104,6 +234,10 @@ func GetPtyOut(w http.ResponseWriter, r *http.Request) { pathStr := GetPtyOutFile(sessionId, cmdId) fd, err := os.Open(pathStr) if err != nil { + if errors.Is(err, fs.ErrNotExist) { + w.WriteHeader(http.StatusOK) + return + } w.WriteHeader(500) w.Write([]byte(fmt.Sprintf("cannot open file '%s': %v", pathStr, err))) return @@ -290,7 +424,7 @@ func LaunchRunnerProc() (*RunnerProc, error) { } ecmd.Stderr = ecmd.Stdout // /dev/null ecmd.Start() - rtn := &RunnerProc{Lock: &sync.Mutex{}, Local: true, Cmd: ecmd, WsConnMap: make(map[string]*WsConnType)} + rtn := &RunnerProc{Lock: &sync.Mutex{}, Local: true, Cmd: ecmd} rtn.Output = packet.PacketParser(outputReader) rtn.Input = packet.MakePacketSender(inputWriter) rtn.DoneCh = make(chan bool) @@ -303,31 +437,10 @@ func LaunchRunnerProc() (*RunnerProc, error) { return rtn, nil } -func (runner *RunnerProc) ForwardDataPacket(pk *packet.CmdDataPacketType) int { - barr, err := json.Marshal(pk) - if err != nil { - fmt.Printf("cannot marshal cmddata packet %s/%s: %v)\n", pk.SessionId, pk.CmdId, err) - return 0 - } - runner.Lock.Lock() - defer runner.Lock.Unlock() - numSent := 0 - for _, ws := range runner.WsConnMap { - ok := ws.Shell.NonBlockingWrite(barr) - if !ok { - fmt.Printf("write was dropped, no queue space in '%s'\n", ws.Id) - continue - } - numSent++ - } - return numSent -} - func (runner *RunnerProc) ProcessPackets() { for pk := range runner.Output { if pk.GetType() == packet.CmdDataPacketStr { dataPacket := pk.(*packet.CmdDataPacketType) - runner.ForwardDataPacket(dataPacket) fmt.Printf("cmd-data %s/%s pty=%d run=%d\n", dataPacket.SessionId, dataPacket.CmdId, len(dataPacket.PtyData), len(dataPacket.RunData)) continue } diff --git a/pkg/wsshell/wsshell.go b/pkg/wsshell/wsshell.go index a3d586987..2d9abb935 100644 --- a/pkg/wsshell/wsshell.go +++ b/pkg/wsshell/wsshell.go @@ -15,6 +15,7 @@ import ( const readWaitTimeout = 15 * time.Second const writeWaitTimeout = 10 * time.Second const pingPeriodTickTime = 10 * time.Second +const initialPingTime = 1 * time.Second var upgrader = websocket.Upgrader{ ReadBufferSize: 4 * 1024, @@ -63,13 +64,25 @@ func (ws *WSShell) WritePing() error { return nil } +func (ws *WSShell) WriteJson(val interface{}) error { + barr, err := json.Marshal(val) + if err != nil { + return err + } + ws.WriteChan <- barr + return nil +} + func (ws *WSShell) WritePump() { ticker := time.NewTicker(pingPeriodTickTime) defer func() { ticker.Stop() ws.Conn.Close() }() - ws.WritePing() + go func() { + time.Sleep(initialPingTime) + ws.WritePing() + }() for { select { case <-ticker.C: @@ -79,7 +92,10 @@ func (ws *WSShell) WritePump() { return } - case msgBytes := <-ws.WriteChan: + case msgBytes, ok := <-ws.WriteChan: + if !ok { + return + } _ = ws.Conn.SetWriteDeadline(time.Now().Add(writeWaitTimeout)) // no error err := ws.Conn.WriteMessage(websocket.TextMessage, msgBytes) if err != nil { @@ -124,7 +140,16 @@ func (ws *WSShell) ReadPump() { } ws.ReadChan <- message } +} +func (ws *WSShell) IsClosed() bool { + select { + case <-ws.CloseChan: + return true + + default: + return false + } } func StartWS(w http.ResponseWriter, r *http.Request) (*WSShell, error) {