got streaming ptyout via websockets working

This commit is contained in:
sawka 2022-06-16 22:22:47 -07:00
parent 862014bd82
commit 2a5cde908a
2 changed files with 191 additions and 53 deletions

View File

@ -2,8 +2,10 @@ package main
import ( import (
"encoding/json" "encoding/json"
"errors"
"fmt" "fmt"
"io" "io"
"io/fs"
"net/http" "net/http"
"os" "os"
"os/exec" "os/exec"
@ -29,35 +31,137 @@ const HttpTimeoutDuration = 21 * time.Second
const WebSocketServerAddr = "localhost:8081" const WebSocketServerAddr = "localhost:8081"
const MainServerAddr = "localhost:8080" const MainServerAddr = "localhost:8080"
const WSStateReconnectTime = 30 * time.Second
const WSStatePacketChSize = 20
var GlobalRunnerProc *RunnerProc var GlobalRunnerProc *RunnerProc
var GlobalLock = &sync.Mutex{}
var WSStateMap = make(map[string]*WSState) // clientid -> WsState
type WsConnType struct { func setWSState(state *WSState) {
Id string GlobalLock.Lock()
Shell *wsshell.WSShell defer GlobalLock.Unlock()
Tailer *cmdtail.Tailer WSStateMap[state.ClientId] = state
}
func getWSState(clientId string) *WSState {
GlobalLock.Lock()
defer GlobalLock.Unlock()
return WSStateMap[clientId]
}
func removeWSStateAfterTimeout(clientId string, connectTime time.Time, waitDuration time.Duration) {
go func() {
time.Sleep(waitDuration)
GlobalLock.Lock()
defer GlobalLock.Unlock()
state := WSStateMap[clientId]
if state == nil || state.ConnectTime != connectTime {
return
}
delete(WSStateMap, clientId)
err := state.CloseTailer()
if err != nil {
fmt.Printf("[error] closing tailer on ws %v\n", err)
}
}()
}
type WSState struct {
Lock *sync.Mutex
ClientId string
ConnectTime time.Time
Shell *wsshell.WSShell
Tailer *cmdtail.Tailer
PacketCh chan packet.PacketType
}
func MakeWSState(clientId string) (*WSState, error) {
var err error
rtn := &WSState{}
rtn.Lock = &sync.Mutex{}
rtn.ClientId = clientId
rtn.ConnectTime = time.Now()
rtn.PacketCh = make(chan packet.PacketType, WSStatePacketChSize)
rtn.Tailer, err = cmdtail.MakeTailer(rtn.PacketCh)
if err != nil {
return nil, err
}
go func() {
defer close(rtn.PacketCh)
rtn.Tailer.Run()
}()
go rtn.runTailerToWS()
return rtn, nil
}
func (ws *WSState) CloseTailer() error {
return ws.Tailer.Close()
}
func (ws *WSState) getShell() *wsshell.WSShell {
ws.Lock.Lock()
defer ws.Lock.Unlock()
return ws.Shell
}
func (ws *WSState) runTailerToWS() {
for pk := range ws.PacketCh {
if pk.GetType() == "cmddata" {
dataPacket := pk.(*packet.CmdDataPacketType)
err := ws.writePacket(dataPacket)
if err != nil {
fmt.Printf("[error] writing packet to ws: %v\n", err)
}
continue
}
fmt.Printf("tailer-to-ws, bad packet %v\n", pk.GetType())
}
}
func (ws *WSState) writePacket(pk packet.PacketType) error {
shell := ws.getShell()
if shell == nil || shell.IsClosed() {
return fmt.Errorf("cannot write packet, empty or closed wsshell")
}
err := shell.WriteJson(pk)
if err != nil {
return err
}
return nil
}
func (ws *WSState) getConnectTime() time.Time {
ws.Lock.Lock()
defer ws.Lock.Unlock()
return ws.ConnectTime
}
func (ws *WSState) updateConnectTime() {
ws.Lock.Lock()
defer ws.Lock.Unlock()
ws.ConnectTime = time.Now()
}
func (ws *WSState) replaceExistingShell(shell *wsshell.WSShell) {
ws.Lock.Lock()
defer ws.Lock.Unlock()
if ws.Shell == nil {
ws.Shell = shell
return
}
ws.Shell.Conn.Close()
ws.Shell = shell
return
} }
type RunnerProc struct { type RunnerProc struct {
Lock *sync.Mutex Lock *sync.Mutex
Cmd *exec.Cmd Cmd *exec.Cmd
Input *packet.PacketSender Input *packet.PacketSender
Output chan packet.PacketType Output chan packet.PacketType
WsConnMap map[string]*WsConnType Local bool
Local bool DoneCh chan bool
DoneCh chan bool
}
func (rp *RunnerProc) AddWsConn(ws *WsConnType) {
rp.Lock.Lock()
defer rp.Lock.Unlock()
rp.WsConnMap[ws.Id] = ws
}
func (rp *RunnerProc) RemoveWsConn(ws *WsConnType) {
rp.Lock.Lock()
defer rp.Lock.Unlock()
delete(rp.WsConnMap, ws.Id)
} }
func HandleWs(w http.ResponseWriter, r *http.Request) { func HandleWs(w http.ResponseWriter, r *http.Request) {
@ -66,20 +170,46 @@ func HandleWs(w http.ResponseWriter, r *http.Request) {
fmt.Printf("WebSocket Upgrade Failed %T: %v\n", w, err) fmt.Printf("WebSocket Upgrade Failed %T: %v\n", w, err)
return return
} }
wsConn := &WsConnType{Id: uuid.New().String(), Shell: shell} defer shell.Conn.Close()
GlobalRunnerProc.AddWsConn(wsConn) clientId := r.URL.Query().Get("clientid")
if clientId == "" {
close(shell.WriteChan)
return
}
state := getWSState(clientId)
if state == nil {
state, err = MakeWSState(clientId)
if err != nil {
fmt.Printf("cannot make wsstate: %v\n", err)
close(shell.WriteChan)
return
}
state.replaceExistingShell(shell)
setWSState(state)
} else {
state.updateConnectTime()
state.replaceExistingShell(shell)
}
stateConnectTime := state.getConnectTime()
defer func() { defer func() {
GlobalRunnerProc.RemoveWsConn(wsConn) removeWSStateAfterTimeout(clientId, stateConnectTime, WSStateReconnectTime)
wsConn.Shell.Conn.Close()
}() }()
fmt.Printf("WebSocket opened %s\n", shell.RemoteAddr) shell.WriteJson(map[string]interface{}{"type": "hello"}) // let client know we accepted this connection, ignore error
fmt.Printf("WebSocket opened %s %s\n", shell.RemoteAddr, state.ClientId)
for msgBytes := range shell.ReadChan { for msgBytes := range shell.ReadChan {
pk, err := packet.ParseJsonPacket(msgBytes) pk, err := packet.ParseJsonPacket(msgBytes)
if err != nil { if err != nil {
fmt.Printf("error unmarshalling ws message: %v\n", err) fmt.Printf("error unmarshalling ws message: %v\n", err)
continue continue
} }
fmt.Printf("got ws message: %v\n", pk) if pk.GetType() == "getcmd" {
err = state.Tailer.AddWatch(pk.(*packet.GetCmdPacketType))
if err != nil {
fmt.Printf("error adding watch to tailer: %v\n", err)
}
continue
}
fmt.Printf("got ws bad message: %v\n", pk.GetType())
} }
} }
@ -104,6 +234,10 @@ func GetPtyOut(w http.ResponseWriter, r *http.Request) {
pathStr := GetPtyOutFile(sessionId, cmdId) pathStr := GetPtyOutFile(sessionId, cmdId)
fd, err := os.Open(pathStr) fd, err := os.Open(pathStr)
if err != nil { if err != nil {
if errors.Is(err, fs.ErrNotExist) {
w.WriteHeader(http.StatusOK)
return
}
w.WriteHeader(500) w.WriteHeader(500)
w.Write([]byte(fmt.Sprintf("cannot open file '%s': %v", pathStr, err))) w.Write([]byte(fmt.Sprintf("cannot open file '%s': %v", pathStr, err)))
return return
@ -290,7 +424,7 @@ func LaunchRunnerProc() (*RunnerProc, error) {
} }
ecmd.Stderr = ecmd.Stdout // /dev/null ecmd.Stderr = ecmd.Stdout // /dev/null
ecmd.Start() ecmd.Start()
rtn := &RunnerProc{Lock: &sync.Mutex{}, Local: true, Cmd: ecmd, WsConnMap: make(map[string]*WsConnType)} rtn := &RunnerProc{Lock: &sync.Mutex{}, Local: true, Cmd: ecmd}
rtn.Output = packet.PacketParser(outputReader) rtn.Output = packet.PacketParser(outputReader)
rtn.Input = packet.MakePacketSender(inputWriter) rtn.Input = packet.MakePacketSender(inputWriter)
rtn.DoneCh = make(chan bool) rtn.DoneCh = make(chan bool)
@ -303,31 +437,10 @@ func LaunchRunnerProc() (*RunnerProc, error) {
return rtn, nil return rtn, nil
} }
func (runner *RunnerProc) ForwardDataPacket(pk *packet.CmdDataPacketType) int {
barr, err := json.Marshal(pk)
if err != nil {
fmt.Printf("cannot marshal cmddata packet %s/%s: %v)\n", pk.SessionId, pk.CmdId, err)
return 0
}
runner.Lock.Lock()
defer runner.Lock.Unlock()
numSent := 0
for _, ws := range runner.WsConnMap {
ok := ws.Shell.NonBlockingWrite(barr)
if !ok {
fmt.Printf("write was dropped, no queue space in '%s'\n", ws.Id)
continue
}
numSent++
}
return numSent
}
func (runner *RunnerProc) ProcessPackets() { func (runner *RunnerProc) ProcessPackets() {
for pk := range runner.Output { for pk := range runner.Output {
if pk.GetType() == packet.CmdDataPacketStr { if pk.GetType() == packet.CmdDataPacketStr {
dataPacket := pk.(*packet.CmdDataPacketType) dataPacket := pk.(*packet.CmdDataPacketType)
runner.ForwardDataPacket(dataPacket)
fmt.Printf("cmd-data %s/%s pty=%d run=%d\n", dataPacket.SessionId, dataPacket.CmdId, len(dataPacket.PtyData), len(dataPacket.RunData)) fmt.Printf("cmd-data %s/%s pty=%d run=%d\n", dataPacket.SessionId, dataPacket.CmdId, len(dataPacket.PtyData), len(dataPacket.RunData))
continue continue
} }

View File

@ -15,6 +15,7 @@ import (
const readWaitTimeout = 15 * time.Second const readWaitTimeout = 15 * time.Second
const writeWaitTimeout = 10 * time.Second const writeWaitTimeout = 10 * time.Second
const pingPeriodTickTime = 10 * time.Second const pingPeriodTickTime = 10 * time.Second
const initialPingTime = 1 * time.Second
var upgrader = websocket.Upgrader{ var upgrader = websocket.Upgrader{
ReadBufferSize: 4 * 1024, ReadBufferSize: 4 * 1024,
@ -63,13 +64,25 @@ func (ws *WSShell) WritePing() error {
return nil return nil
} }
func (ws *WSShell) WriteJson(val interface{}) error {
barr, err := json.Marshal(val)
if err != nil {
return err
}
ws.WriteChan <- barr
return nil
}
func (ws *WSShell) WritePump() { func (ws *WSShell) WritePump() {
ticker := time.NewTicker(pingPeriodTickTime) ticker := time.NewTicker(pingPeriodTickTime)
defer func() { defer func() {
ticker.Stop() ticker.Stop()
ws.Conn.Close() ws.Conn.Close()
}() }()
ws.WritePing() go func() {
time.Sleep(initialPingTime)
ws.WritePing()
}()
for { for {
select { select {
case <-ticker.C: case <-ticker.C:
@ -79,7 +92,10 @@ func (ws *WSShell) WritePump() {
return return
} }
case msgBytes := <-ws.WriteChan: case msgBytes, ok := <-ws.WriteChan:
if !ok {
return
}
_ = ws.Conn.SetWriteDeadline(time.Now().Add(writeWaitTimeout)) // no error _ = ws.Conn.SetWriteDeadline(time.Now().Add(writeWaitTimeout)) // no error
err := ws.Conn.WriteMessage(websocket.TextMessage, msgBytes) err := ws.Conn.WriteMessage(websocket.TextMessage, msgBytes)
if err != nil { if err != nil {
@ -124,7 +140,16 @@ func (ws *WSShell) ReadPump() {
} }
ws.ReadChan <- message ws.ReadChan <- message
} }
}
func (ws *WSShell) IsClosed() bool {
select {
case <-ws.CloseChan:
return true
default:
return false
}
} }
func StartWS(w http.ResponseWriter, r *http.Request) (*WSShell, error) { func StartWS(w http.ResponseWriter, r *http.Request) (*WSShell, error) {