mirror of
https://github.com/wavetermdev/waveterm.git
synced 2024-12-22 16:48:23 +01:00
01b5d71709
lots of changes. new wshrpc implementation. unify websocket, web, blockcontroller, domain sockets, and terminal inputs to all use the new rpc system. lots of moving files around to deal with circular dependencies use new wshrpc as a client in wsh cmd
313 lines
7.8 KiB
Go
313 lines
7.8 KiB
Go
// Copyright 2024, Command Line Inc.
|
|
// SPDX-License-Identifier: Apache-2.0
|
|
|
|
package web
|
|
|
|
import (
|
|
"encoding/json"
|
|
"fmt"
|
|
"log"
|
|
"net/http"
|
|
"runtime/debug"
|
|
"sync"
|
|
"time"
|
|
|
|
"github.com/google/uuid"
|
|
"github.com/gorilla/mux"
|
|
"github.com/gorilla/websocket"
|
|
"github.com/wavetermdev/thenextwave/pkg/eventbus"
|
|
"github.com/wavetermdev/thenextwave/pkg/web/webcmd"
|
|
"github.com/wavetermdev/thenextwave/pkg/wshrpc"
|
|
"github.com/wavetermdev/thenextwave/pkg/wshutil"
|
|
)
|
|
|
|
// set by main-server.go (for dependency inversion)
|
|
var WshServerFactoryFn func(inputCh chan []byte, outputCh chan []byte, initialCtx wshutil.RpcContext) = nil
|
|
|
|
// Timing parameters for websocket connections.
const (
	// wsReadWaitTimeout is how far ReadLoop pushes the read deadline forward
	// at startup and after every successfully decoded message.
	wsReadWaitTimeout = 15 * time.Second

	// wsWriteWaitTimeout is the write deadline WritePing applies before
	// sending a ping frame.
	wsWriteWaitTimeout = 10 * time.Second

	// wsPingPeriodTickTime is the interval between pings once the first ping
	// has been sent.
	wsPingPeriodTickTime = 10 * time.Second

	// wsInitialPingTime is the delay before the first ping on a new
	// connection.
	wsInitialPingTime = 1 * time.Second

	// DefaultCommandTimeout is exported for use elsewhere.
	// NOTE(review): not referenced by the code visible in this file; confirm
	// its consumers before changing it.
	DefaultCommandTimeout = 2 * time.Second
)
|
|
|
|
func RunWebSocketServer() {
|
|
gr := mux.NewRouter()
|
|
gr.HandleFunc("/ws", HandleWs)
|
|
serverAddr := WebSocketServerDevAddr
|
|
server := &http.Server{
|
|
Addr: serverAddr,
|
|
ReadTimeout: HttpReadTimeout,
|
|
WriteTimeout: HttpWriteTimeout,
|
|
MaxHeaderBytes: HttpMaxHeaderBytes,
|
|
Handler: gr,
|
|
}
|
|
server.SetKeepAlivesEnabled(false)
|
|
log.Printf("Running websocket server on %s\n", serverAddr)
|
|
err := server.ListenAndServe()
|
|
if err != nil {
|
|
log.Printf("[error] trying to run websocket server: %v\n", err)
|
|
}
|
|
}
|
|
|
|
var WebSocketUpgrader = websocket.Upgrader{
|
|
ReadBufferSize: 4 * 1024,
|
|
WriteBufferSize: 32 * 1024,
|
|
HandshakeTimeout: 1 * time.Second,
|
|
CheckOrigin: func(r *http.Request) bool { return true },
|
|
}
|
|
|
|
func HandleWs(w http.ResponseWriter, r *http.Request) {
|
|
err := HandleWsInternal(w, r)
|
|
if err != nil {
|
|
http.Error(w, err.Error(), http.StatusInternalServerError)
|
|
}
|
|
}
|
|
|
|
// getMessageType returns the "type" field of a decoded JSON message, or the
// empty string when the field is missing or is not a string.
func getMessageType(jmsg map[string]any) string {
	// A failed assertion leaves msgType at its zero value, "".
	msgType, _ := jmsg["type"].(string)
	return msgType
}
|
|
|
|
// getStringFromMap returns jmsg[key] when it holds a string, and the empty
// string when the key is absent or its value has any other type.
func getStringFromMap(jmsg map[string]any, key string) string {
	// A failed assertion leaves val at its zero value, "".
	val, _ := jmsg[key].(string)
	return val
}
|
|
|
|
func processWSCommand(jmsg map[string]any, outputCh chan any, rpcInputCh chan []byte) {
|
|
var rtnErr error
|
|
defer func() {
|
|
r := recover()
|
|
if r != nil {
|
|
rtnErr = fmt.Errorf("panic: %v", r)
|
|
log.Printf("panic in processMessage: %v\n", r)
|
|
debug.PrintStack()
|
|
}
|
|
if rtnErr == nil {
|
|
return
|
|
}
|
|
rtn := map[string]any{"type": "error", "error": rtnErr.Error()}
|
|
outputCh <- rtn
|
|
}()
|
|
wsCommand, err := webcmd.ParseWSCommandMap(jmsg)
|
|
if err != nil {
|
|
rtnErr = fmt.Errorf("cannot parse wscommand: %v", err)
|
|
return
|
|
}
|
|
switch cmd := wsCommand.(type) {
|
|
case *webcmd.SetBlockTermSizeWSCommand:
|
|
data := wshrpc.CommandBlockInputData{
|
|
BlockId: cmd.BlockId,
|
|
TermSize: &cmd.TermSize,
|
|
}
|
|
rpcMsg := wshutil.RpcMessage{
|
|
Command: wshrpc.Command_BlockInput,
|
|
Data: data,
|
|
}
|
|
msgBytes, err := json.Marshal(rpcMsg)
|
|
if err != nil {
|
|
// this really should never fail since we just unmarshalled this value
|
|
log.Printf("error marshalling rpc message: %v\n", err)
|
|
return
|
|
}
|
|
rpcInputCh <- msgBytes
|
|
|
|
case *webcmd.BlockInputWSCommand:
|
|
data := wshrpc.CommandBlockInputData{
|
|
BlockId: cmd.BlockId,
|
|
InputData64: cmd.InputData64,
|
|
}
|
|
rpcMsg := wshutil.RpcMessage{
|
|
Command: wshrpc.Command_BlockInput,
|
|
Data: data,
|
|
}
|
|
msgBytes, err := json.Marshal(rpcMsg)
|
|
if err != nil {
|
|
// this really should never fail since we just unmarshalled this value
|
|
log.Printf("error marshalling rpc message: %v\n", err)
|
|
return
|
|
}
|
|
rpcInputCh <- msgBytes
|
|
|
|
case *webcmd.WSRpcCommand:
|
|
rpcMsg := cmd.Message
|
|
if rpcMsg == nil {
|
|
return
|
|
}
|
|
msgBytes, err := json.Marshal(rpcMsg)
|
|
if err != nil {
|
|
// this really should never fail since we just unmarshalled this value
|
|
return
|
|
}
|
|
rpcInputCh <- msgBytes
|
|
}
|
|
}
|
|
|
|
func processMessage(jmsg map[string]any, outputCh chan any, rpcInputCh chan []byte) {
|
|
wsCommand := getStringFromMap(jmsg, "wscommand")
|
|
if wsCommand != "" {
|
|
processWSCommand(jmsg, outputCh, rpcInputCh)
|
|
return
|
|
}
|
|
msgType := getMessageType(jmsg)
|
|
if msgType != "rpc" {
|
|
return
|
|
}
|
|
reqId := getStringFromMap(jmsg, "reqid")
|
|
var rtnErr error
|
|
defer func() {
|
|
r := recover()
|
|
if r != nil {
|
|
rtnErr = fmt.Errorf("panic: %v", r)
|
|
log.Printf("panic in processMessage: %v\n", r)
|
|
debug.PrintStack()
|
|
}
|
|
if rtnErr == nil {
|
|
return
|
|
}
|
|
rtn := map[string]any{"type": "rpcresp", "reqid": reqId, "error": rtnErr.Error()}
|
|
outputCh <- rtn
|
|
}()
|
|
method := getStringFromMap(jmsg, "method")
|
|
rtnErr = fmt.Errorf("unknown method %q", method)
|
|
}
|
|
|
|
func ReadLoop(conn *websocket.Conn, outputCh chan any, closeCh chan any, rpcInputCh chan []byte) {
|
|
readWait := wsReadWaitTimeout
|
|
conn.SetReadLimit(64 * 1024)
|
|
conn.SetReadDeadline(time.Now().Add(readWait))
|
|
defer close(closeCh)
|
|
for {
|
|
_, message, err := conn.ReadMessage()
|
|
if err != nil {
|
|
log.Printf("ReadPump error: %v\n", err)
|
|
break
|
|
}
|
|
jmsg := map[string]any{}
|
|
err = json.Unmarshal(message, &jmsg)
|
|
if err != nil {
|
|
log.Printf("Error unmarshalling json: %v\n", err)
|
|
break
|
|
}
|
|
conn.SetReadDeadline(time.Now().Add(readWait))
|
|
msgType := getMessageType(jmsg)
|
|
if msgType == "pong" {
|
|
// nothing
|
|
continue
|
|
}
|
|
if msgType == "ping" {
|
|
now := time.Now()
|
|
pongMessage := map[string]interface{}{"type": "pong", "stime": now.UnixMilli()}
|
|
outputCh <- pongMessage
|
|
continue
|
|
}
|
|
go processMessage(jmsg, outputCh, rpcInputCh)
|
|
}
|
|
}
|
|
|
|
func WritePing(conn *websocket.Conn) error {
|
|
now := time.Now()
|
|
pingMessage := map[string]interface{}{"type": "ping", "stime": now.UnixMilli()}
|
|
jsonVal, _ := json.Marshal(pingMessage)
|
|
_ = conn.SetWriteDeadline(time.Now().Add(wsWriteWaitTimeout)) // no error
|
|
err := conn.WriteMessage(websocket.TextMessage, jsonVal)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func WriteLoop(conn *websocket.Conn, outputCh chan any, closeCh chan any) {
|
|
ticker := time.NewTicker(wsInitialPingTime)
|
|
defer ticker.Stop()
|
|
initialPing := true
|
|
for {
|
|
select {
|
|
case msg := <-outputCh:
|
|
var barr []byte
|
|
var err error
|
|
if _, ok := msg.([]byte); ok {
|
|
barr = msg.([]byte)
|
|
} else {
|
|
barr, err = json.Marshal(msg)
|
|
if err != nil {
|
|
log.Printf("cannot marshal websocket message: %v\n", err)
|
|
// just loop again
|
|
break
|
|
}
|
|
}
|
|
err = conn.WriteMessage(websocket.TextMessage, barr)
|
|
if err != nil {
|
|
conn.Close()
|
|
log.Printf("WritePump error: %v\n", err)
|
|
return
|
|
}
|
|
|
|
case <-ticker.C:
|
|
err := WritePing(conn)
|
|
if err != nil {
|
|
log.Printf("WritePump error: %v\n", err)
|
|
return
|
|
}
|
|
if initialPing {
|
|
initialPing = false
|
|
ticker.Reset(wsPingPeriodTickTime)
|
|
}
|
|
|
|
case <-closeCh:
|
|
return
|
|
}
|
|
}
|
|
}
|
|
|
|
func HandleWsInternal(w http.ResponseWriter, r *http.Request) error {
|
|
windowId := r.URL.Query().Get("windowid")
|
|
if windowId == "" {
|
|
return fmt.Errorf("windowid is required")
|
|
}
|
|
conn, err := WebSocketUpgrader.Upgrade(w, r, nil)
|
|
if err != nil {
|
|
return fmt.Errorf("WebSocket Upgrade Failed: %v", err)
|
|
}
|
|
defer conn.Close()
|
|
wsConnId := uuid.New().String()
|
|
log.Printf("New websocket connection: windowid:%s connid:%s\n", windowId, wsConnId)
|
|
outputCh := make(chan any, 100)
|
|
closeCh := make(chan any)
|
|
rpcInputCh := make(chan []byte, 32)
|
|
rpcOutputCh := make(chan []byte, 32)
|
|
eventbus.RegisterWSChannel(wsConnId, windowId, outputCh)
|
|
defer eventbus.UnregisterWSChannel(wsConnId)
|
|
WshServerFactoryFn(rpcInputCh, rpcOutputCh, wshutil.RpcContext{WindowId: windowId})
|
|
wg := &sync.WaitGroup{}
|
|
wg.Add(2)
|
|
go func() {
|
|
// no waitgroup add here
|
|
// move values from rpcOutputCh to outputCh
|
|
for msgBytes := range rpcOutputCh {
|
|
rpcWSMsg := map[string]any{
|
|
"eventtype": "rpc", // TODO don't hard code this (but def is in eventbus)
|
|
"data": json.RawMessage(msgBytes),
|
|
}
|
|
outputCh <- rpcWSMsg
|
|
}
|
|
}()
|
|
go func() {
|
|
// read loop
|
|
defer wg.Done()
|
|
ReadLoop(conn, outputCh, closeCh, rpcInputCh)
|
|
}()
|
|
go func() {
|
|
// write loop
|
|
defer wg.Done()
|
|
WriteLoop(conn, outputCh, closeCh)
|
|
}()
|
|
wg.Wait()
|
|
close(rpcInputCh)
|
|
return nil
|
|
}
|