// Source: waveterm/wavesrv/pkg/wsshell/wsshell.go (snapshot 2023-10-16 21:31:13 -07:00; original listing: 193 lines, 4.0 KiB, Go)

// Copyright 2023, Command Line Inc.
// SPDX-License-Identifier: Apache-2.0
package wsshell
import (
"encoding/json"
"fmt"
"log"
"net/http"
"net/url"
"sync"
"time"
"github.com/google/uuid"
"github.com/gorilla/websocket"
)
// Timing parameters shared by the websocket read and write pumps.
const (
	// readWaitTimeout is the read deadline ReadPump sets at startup and
	// re-arms after every successfully decoded message.
	readWaitTimeout = 15 * time.Second
	// writeWaitTimeout is the write deadline applied before each frame
	// written by WritePing and WritePump.
	writeWaitTimeout = 10 * time.Second
	// pingPeriodTickTime is the interval between pings after the first one.
	pingPeriodTickTime = 10 * time.Second
	// initialPingTime is the delay before WritePump sends its first ping.
	initialPingTime = 1 * time.Second
)
// upgrader converts incoming HTTP requests into websocket connections for
// StartWS. It uses a 4 KiB read buffer and a 32 KiB write buffer, and the
// opening handshake must complete within one second.
//
// NOTE(review): CheckOrigin accepts every Origin header, which disables
// gorilla/websocket's default same-origin check. Presumably acceptable because
// this server is meant for a local client — confirm before exposing it more widely.
var upgrader = websocket.Upgrader{
	HandshakeTimeout: 1 * time.Second,
	ReadBufferSize:   4 * 1024,
	WriteBufferSize:  32 * 1024,
	CheckOrigin: func(*http.Request) bool {
		return true
	},
}
// WSShell wraps one upgraded websocket connection created by StartWS. It holds
// metadata copied from the originating HTTP request and the channels used to
// exchange messages with its read and write pumps.
//
// NOTE(review): NumPings/LastPing are written from the WritePump goroutine and
// LastRecv from the ReadPump goroutine without any locking; reading them from
// another goroutine is a data race — confirm callers tolerate this.
type WSShell struct {
// Conn is the underlying websocket connection; WritePump and ReadPump each
// close it when they return.
Conn *websocket.Conn
// RemoteAddr is the originating request's RemoteAddr (used in log messages).
RemoteAddr string
// ConnId is a UUID string generated by StartWS via uuid.New().
ConnId string
// Query holds the originating request's URL query parameters.
Query url.Values
// OpenTime is when StartWS created the shell.
OpenTime time.Time
// NumPings counts ping writes attempted by WritePing (incremented even when
// the write fails).
NumPings int
// LastPing is the time of the most recent WritePing attempt.
LastPing time.Time
// LastRecv is set by ReadPump each time a message is read and decoded as JSON.
LastRecv time.Time
// Header holds the originating request's HTTP headers.
Header http.Header
// CloseChan is closed by StartWS once both pumps have exited; see IsClosed.
CloseChan chan bool
// WriteChan queues outgoing text frames for WritePump (StartWS gives it a
// buffer of 10).
WriteChan chan []byte
// ReadChan delivers incoming messages other than ping/pong from ReadPump
// (buffer of 10); StartWS closes it after both pumps have exited.
ReadChan chan []byte
}
// NonBlockingWrite tries to queue data on WriteChan without waiting.
// It reports whether the data was queued; on false the data was dropped
// because no buffer slot (or receiver) was immediately available.
func (ws *WSShell) NonBlockingWrite(data []byte) bool {
	queued := false
	select {
	case ws.WriteChan <- data:
		queued = true
	default:
		// queue full: drop rather than block the caller
	}
	return queued
}
// WritePing writes a {"type":"ping","stime":<unix seconds>} JSON text frame
// straight to Conn (not through WriteChan), with a writeWaitTimeout deadline.
// NumPings and LastPing are updated whether or not the write succeeds, and the
// write error (if any) is returned.
//
// NOTE: this writes to Conn directly; in this file it is only called from
// WritePump, which is the connection's sole writer.
func (ws *WSShell) WritePing() error {
	sentAt := time.Now()
	// Marshaling a map of a string and an int64 cannot fail.
	payload, _ := json.Marshal(map[string]interface{}{
		"type":  "ping",
		"stime": sentAt.Unix(),
	})
	_ = ws.Conn.SetWriteDeadline(time.Now().Add(writeWaitTimeout)) // no error
	writeErr := ws.Conn.WriteMessage(websocket.TextMessage, payload)
	ws.NumPings++
	ws.LastPing = sentAt
	return writeErr
}
// WriteJson marshals val to JSON and queues it on WriteChan for WritePump.
//
// It returns an error if ws is nil, if the shell is already closed, or if val
// cannot be marshaled. When the queue is full it waits until either a slot
// frees up or CloseChan is closed, instead of blocking unconditionally. A plain
// send could hang forever once WritePump had exited and stopped draining the queue.
func (ws *WSShell) WriteJson(val interface{}) error {
	// ws == nil is checked first: IsClosed would otherwise dereference nil.
	if ws == nil || ws.IsClosed() {
		return fmt.Errorf("cannot write packet, empty or closed wsshell")
	}
	barr, err := json.Marshal(val)
	if err != nil {
		return err
	}
	select {
	case ws.WriteChan <- barr:
		return nil
	case <-ws.CloseChan:
		// Both pumps exited while we were waiting for queue space.
		return fmt.Errorf("cannot write packet, empty or closed wsshell")
	}
}
// WritePump is the sole writer to Conn. It pings once after initialPingTime
// and then every pingPeriodTickTime. It writes each message received on
// WriteChan as a text frame. It returns on the first write error or when
// WriteChan is closed, stopping the ticker and closing Conn on the way out.
func (ws *WSShell) WritePump() {
	pingTicker := time.NewTicker(initialPingTime)
	defer func() {
		pingTicker.Stop()
		ws.Conn.Close()
	}()
	firstPingPending := true
	for {
		var err error
		select {
		case <-pingTicker.C:
			err = ws.WritePing()
			// After the quick first ping, switch to the regular cadence.
			if err == nil && firstPingPending {
				firstPingPending = false
				pingTicker.Reset(pingPeriodTickTime)
			}
		case msgBytes, ok := <-ws.WriteChan:
			if !ok {
				return
			}
			_ = ws.Conn.SetWriteDeadline(time.Now().Add(writeWaitTimeout)) // no error
			err = ws.Conn.WriteMessage(websocket.TextMessage, msgBytes)
		}
		if err != nil {
			log.Printf("WritePump %s err: %v\n", ws.RemoteAddr, err)
			return
		}
	}
}
// ReadPump reads messages from Conn until a read error, a JSON decode error,
// or a stuck write queue, then closes Conn.
//
// Every incoming message must decode as a JSON object; one that does not ends
// the pump. Each decoded message re-arms the read deadline (readWaitTimeout)
// and updates LastRecv. Messages by "type":
//   - "pong": consumed here.
//   - "ping": answered by queueing a pong on WriteChan.
//   - anything else: forwarded, as the raw bytes, to ReadChan.
//
// The pong enqueue is bounded by writeWaitTimeout. Previously it was an
// unconditional send that could block forever if WritePump had already exited
// with WriteChan full. That kept ReadPump alive, so StartWS never closed
// CloseChan/ReadChan. Now ReadPump gives up and returns instead.
func (ws *WSShell) ReadPump() {
	readWait := readWaitTimeout
	defer ws.Conn.Close()
	ws.Conn.SetReadLimit(4096)
	_ = ws.Conn.SetReadDeadline(time.Now().Add(readWait)) // error surfaces on the next read
	for {
		_, message, err := ws.Conn.ReadMessage()
		if err != nil {
			log.Printf("ReadPump %s Err: %v\n", ws.RemoteAddr, err)
			return
		}
		jmsg := map[string]interface{}{}
		if err := json.Unmarshal(message, &jmsg); err != nil {
			log.Printf("Error unmarshalling json: %v\n", err)
			return
		}
		_ = ws.Conn.SetReadDeadline(time.Now().Add(readWait)) // error surfaces on the next read
		ws.LastRecv = time.Now()
		// A missing or non-string "type" yields "" and falls through to default.
		msgType, _ := jmsg["type"].(string)
		switch msgType {
		case "pong":
			// nothing to do; receiving it already refreshed the deadline
		case "ping":
			pongMessage := map[string]interface{}{"type": "pong", "stime": time.Now().Unix()}
			jsonVal, _ := json.Marshal(pongMessage) // cannot fail for string/int64 values
			select {
			case ws.WriteChan <- jsonVal:
			case <-time.After(writeWaitTimeout):
				log.Printf("ReadPump %s err: timed out queueing pong\n", ws.RemoteAddr)
				return
			}
		default:
			ws.ReadChan <- message
		}
	}
}
// IsClosed reports, without blocking, whether CloseChan is ready to receive.
// StartWS closes CloseChan once both the read and write pumps have exited.
func (ws *WSShell) IsClosed() bool {
	closed := false
	select {
	case <-ws.CloseChan:
		closed = true
	default:
	}
	return closed
}
// StartWS upgrades r to a websocket connection and starts the connection's
// WritePump and ReadPump in their own goroutines. The returned WSShell records
// the request's remote address, query parameters and headers. When both pumps
// have returned, a third goroutine closes CloseChan and then ReadChan.
// If the upgrade fails, its error is returned and no WSShell is created.
func StartWS(w http.ResponseWriter, r *http.Request) (*WSShell, error) {
	conn, err := upgrader.Upgrade(w, r, nil)
	if err != nil {
		return nil, err
	}
	shell := &WSShell{
		Conn:       conn,
		ConnId:     uuid.New().String(),
		OpenTime:   time.Now(),
		RemoteAddr: r.RemoteAddr,
		Query:      r.URL.Query(),
		Header:     r.Header,
		CloseChan:  make(chan bool),
		WriteChan:  make(chan []byte, 10),
		ReadChan:   make(chan []byte, 10),
	}

	var pumps sync.WaitGroup
	pumps.Add(2)
	go func() {
		defer pumps.Done()
		shell.WritePump()
	}()
	go func() {
		defer pumps.Done()
		shell.ReadPump()
	}()
	// ReadPump is ReadChan's only sender, so closing it after both pumps
	// finish cannot race with a send.
	go func() {
		pumps.Wait()
		close(shell.CloseChan)
		close(shell.ReadChan)
	}()
	return shell, nil
}