mirror of
https://github.com/wavetermdev/waveterm.git
synced 2024-12-22 16:48:23 +01:00
8acda3525b
* Break update code out of sstore * add license disclaimers * missed one * add another * fix regression in openai updates, remove unnecessary functions * another copyright * update casts * fix issue with variadic updates * remove logs * remove log * remove unnecessary log * save work * moved a bunch of stuff to scbus * make modelupdate an object * fix new screen not updating active screen * add comment * make updates into packet types * different cast * update comments, remove unused methods * add one more comment * add an IsEmpty() on model updates to prevent sending empty updates to client
348 lines
9.0 KiB
Go
348 lines
9.0 KiB
Go
// Copyright 2023, Command Line Inc.
|
|
// SPDX-License-Identifier: Apache-2.0
|
|
|
|
package scws
|
|
|
|
import (
|
|
"context"
|
|
"fmt"
|
|
"log"
|
|
"runtime/debug"
|
|
"sync"
|
|
"time"
|
|
|
|
"github.com/google/uuid"
|
|
"github.com/wavetermdev/waveterm/waveshell/pkg/packet"
|
|
"github.com/wavetermdev/waveterm/wavesrv/pkg/mapqueue"
|
|
"github.com/wavetermdev/waveterm/wavesrv/pkg/remote"
|
|
"github.com/wavetermdev/waveterm/wavesrv/pkg/scbus"
|
|
"github.com/wavetermdev/waveterm/wavesrv/pkg/scpacket"
|
|
"github.com/wavetermdev/waveterm/wavesrv/pkg/sstore"
|
|
"github.com/wavetermdev/waveterm/wavesrv/pkg/userinput"
|
|
"github.com/wavetermdev/waveterm/wavesrv/pkg/wsshell"
|
|
)
|
|
|
|
// Tuning constants for websocket client handling.
const (
	// WSStatePacketChSize is presumably a buffer size for per-client packet
	// channels. NOTE(review): it is not referenced in this file — confirm usage.
	WSStatePacketChSize = 20

	// MaxInputDataSize is the largest decoded size of a frontend input
	// packet's InputData64 payload that sendCmdInput will forward.
	MaxInputDataSize = 1000

	// RemoteInputQueueSize is the size argument passed to
	// mapqueue.MakeMapQueue when building RemoteInputMapQueue.
	RemoteInputQueueSize = 100
)
|
|
|
|
var RemoteInputMapQueue *mapqueue.MapQueue
|
|
|
|
func init() {
|
|
RemoteInputMapQueue = mapqueue.MakeMapQueue(RemoteInputQueueSize)
|
|
}
|
|
|
|
type WSState struct {
|
|
Lock *sync.Mutex
|
|
ClientId string
|
|
ConnectTime time.Time
|
|
Shell *wsshell.WSShell
|
|
UpdateCh chan scbus.UpdatePacket
|
|
UpdateQueue []any
|
|
Authenticated bool
|
|
AuthKey string
|
|
|
|
SessionId string
|
|
ScreenId string
|
|
}
|
|
|
|
// MakeWSState returns a new, unauthenticated WSState for clientId that expects
// authKey, with its connect time set to now. No shell is attached; Shell is
// nil until ReplaceShell installs one.
func MakeWSState(clientId string, authKey string) *WSState {
	return &WSState{
		Lock:        &sync.Mutex{},
		ClientId:    clientId,
		ConnectTime: time.Now(),
		AuthKey:     authKey,
	}
}
|
|
|
|
func (ws *WSState) SetAuthenticated(authVal bool) {
|
|
ws.Lock.Lock()
|
|
defer ws.Lock.Unlock()
|
|
ws.Authenticated = authVal
|
|
}
|
|
|
|
func (ws *WSState) IsAuthenticated() bool {
|
|
ws.Lock.Lock()
|
|
defer ws.Lock.Unlock()
|
|
return ws.Authenticated
|
|
}
|
|
|
|
func (ws *WSState) GetShell() *wsshell.WSShell {
|
|
ws.Lock.Lock()
|
|
defer ws.Lock.Unlock()
|
|
return ws.Shell
|
|
}
|
|
|
|
// WriteUpdate writes update as JSON to the client's current shell.
// It returns an error when no shell is attached, or the error from the write.
func (ws *WSState) WriteUpdate(update any) error {
	if shell := ws.GetShell(); shell != nil {
		if err := shell.WriteJson(update); err != nil {
			return err
		}
		return nil
	}
	return fmt.Errorf("cannot write update, empty shell")
}
|
|
|
|
// UpdateConnectTime sets the client's connect time to now, under ws.Lock.
func (ws *WSState) UpdateConnectTime() {
	ws.Lock.Lock()
	ws.ConnectTime = time.Now()
	ws.Lock.Unlock()
}
|
|
|
|
// GetConnectTime returns the client's recorded connect time, read under ws.Lock.
func (ws *WSState) GetConnectTime() time.Time {
	ws.Lock.Lock()
	connectedAt := ws.ConnectTime
	ws.Lock.Unlock()
	return connectedAt
}
|
|
|
|
// WatchScreen subscribes this client to updates for the given session/screen.
//
// If that exact session/screen is already being watched, it does nothing.
// Otherwise, holding ws.Lock, it:
//   - records the new ids;
//   - registers a new update channel with scbus.MainUpdateBus under this
//     client's id for screenId;
//   - starts a RunUpdates goroutine to drain that channel.
func (ws *WSState) WatchScreen(sessionId string, screenId string) {
	ws.Lock.Lock()
	defer ws.Lock.Unlock()
	if ws.SessionId != sessionId || ws.ScreenId != screenId {
		ws.SessionId, ws.ScreenId = sessionId, screenId
		ws.UpdateCh = scbus.MainUpdateBus.RegisterChannel(ws.ClientId, &scbus.UpdateChannel{ScreenId: screenId})
		newCh := ws.UpdateCh
		log.Printf("[ws] watch screen clientid=%s sessionid=%s screenid=%s, updateCh=%v\n", ws.ClientId, sessionId, screenId, newCh)
		go ws.RunUpdates(newCh)
	}
}
|
|
|
|
// UnWatchScreen unregisters this client's update channel from
// scbus.MainUpdateBus and clears the watched session/screen ids, under ws.Lock.
// NOTE(review): UpdateCh itself is left unchanged; presumably UnregisterChannel
// closes it so the RunUpdates goroutine exits — confirm in scbus.
func (ws *WSState) UnWatchScreen() {
	ws.Lock.Lock()
	defer ws.Lock.Unlock()
	scbus.MainUpdateBus.UnregisterChannel(ws.ClientId)
	ws.SessionId, ws.ScreenId = "", ""
	log.Printf("[ws] unwatch screen clientid=%s\n", ws.ClientId)
}
|
|
|
|
func (ws *WSState) RunUpdates(updateCh chan scbus.UpdatePacket) {
|
|
if updateCh == nil {
|
|
panic("invalid nil updateCh passed to RunUpdates")
|
|
}
|
|
for update := range updateCh {
|
|
shell := ws.GetShell()
|
|
if shell != nil {
|
|
writeJsonProtected(shell, update)
|
|
}
|
|
}
|
|
}
|
|
|
|
// writeJsonProtected writes update to shell as JSON. A panic raised during the
// write is recovered and logged, so the calling loop (RunUpdates) keeps
// running. The error result of WriteJson is ignored; delivery is best-effort.
func writeJsonProtected(shell *wsshell.WSShell, update any) {
	defer func() {
		if r := recover(); r != nil {
			log.Printf("[error] in scws RunUpdates WriteJson: %v\n", r)
		}
	}()
	_ = shell.WriteJson(update)
}
|
|
|
|
func (ws *WSState) ReplaceShell(shell *wsshell.WSShell) {
|
|
ws.Lock.Lock()
|
|
defer ws.Lock.Unlock()
|
|
if ws.Shell == nil {
|
|
ws.Shell = shell
|
|
return
|
|
}
|
|
ws.Shell.Conn.Close()
|
|
ws.Shell = shell
|
|
}
|
|
|
|
// handleConnection sends the client all state required to display the current
// UI. The state is one update packet containing:
//   - the connect update loaded from sstore (with a 5s timeout);
//   - the runtime state of all remotes;
//   - the current per-screen status indicators and running-command counts.
//
// The packet is written through WriteUpdate. That reads the shell under
// ws.Lock, so it does not race with ReplaceShell, and a missing shell yields an
// error rather than a nil-pointer panic.
func (ws *WSState) handleConnection() error {
	ctx, cancelFn := context.WithTimeout(context.Background(), 5*time.Second)
	defer cancelFn()
	connectUpdate, err := sstore.GetConnectUpdate(ctx)
	if err != nil {
		return fmt.Errorf("getting sessions: %w", err)
	}
	connectUpdate.Remotes = remote.GetAllRemoteRuntimeState()
	// restore status indicators
	connectUpdate.ScreenStatusIndicators, connectUpdate.ScreenNumRunningCommands = sstore.GetCurrentIndicatorState()
	mu := scbus.MakeUpdatePacket()
	mu.AddUpdate(*connectUpdate)
	return ws.WriteUpdate(mu)
}
|
|
|
|
// handleWatchScreen processes a watchscreen packet in four steps:
//
//  1. Any non-empty session/screen id must parse as a UUID.
//  2. The packet's auth key must be non-empty and equal ws.AuthKey. On
//     failure the client is marked unauthenticated and an error is returned.
//     On success it is marked authenticated.
//  3. If both ids are set, that screen is watched; otherwise watching stops.
//  4. If the packet sets Connect, the full connect state is sent via
//     handleConnection.
//
// NOTE(review): the auth key comparison is a plain string compare, not
// constant-time; consider crypto/subtle.ConstantTimeCompare.
func (ws *WSState) handleWatchScreen(wsPk *scpacket.WatchScreenPacketType) error {
	// ids are optional, but when present they must be well-formed UUIDs
	if wsPk.SessionId != "" {
		if _, parseErr := uuid.Parse(wsPk.SessionId); parseErr != nil {
			return fmt.Errorf("invalid watchscreen sessionid: %w", parseErr)
		}
	}
	if wsPk.ScreenId != "" {
		if _, parseErr := uuid.Parse(wsPk.ScreenId); parseErr != nil {
			return fmt.Errorf("invalid watchscreen screenid: %w", parseErr)
		}
	}

	switch {
	case wsPk.AuthKey == "":
		ws.SetAuthenticated(false)
		return fmt.Errorf("invalid watchscreen, no authkey")
	case wsPk.AuthKey != ws.AuthKey:
		ws.SetAuthenticated(false)
		return fmt.Errorf("invalid watchscreen, invalid authkey")
	}
	ws.SetAuthenticated(true)

	if wsPk.SessionId != "" && wsPk.ScreenId != "" {
		ws.WatchScreen(wsPk.SessionId, wsPk.ScreenId)
		log.Printf("[ws %s] watchscreen %s/%s\n", ws.ClientId, wsPk.SessionId, wsPk.ScreenId)
	} else {
		ws.UnWatchScreen()
	}

	if !wsPk.Connect {
		return nil
	}
	if connErr := ws.handleConnection(); connErr != nil {
		return fmt.Errorf("connect: %w", connErr)
	}
	return nil
}
|
|
|
|
// processMessage parses one raw websocket message and dispatches it by packet
// type.
//
// A watchscreen packet is always processed, because it carries the auth key.
// Every other packet type requires the client to already be authenticated:
//   - FeInput: validated, then queued on RemoteInputMapQueue (keyed by remote
//     id) to be sent with sendCmdInput.
//   - RemoteInput: validated, then sent on a new goroutine via
//     remote.SendRemoteInput.
//   - CmdInputText: stored synchronously with sstore.ScreenMemSetCmdInputText.
//   - UserInputResponse: handed, without blocking, to the rpc channel
//     registered for its request id.
//
// Errors from the asynchronous sends are logged, since they happen after this
// function has returned. A panic during processing is recovered and logged
// with a stack trace; processMessage then returns nil.
func (ws *WSState) processMessage(msgBytes []byte) error {
	defer func() {
		r := recover()
		if r == nil {
			return
		}
		log.Printf("[scws] panic in processMessage: %v\n", r)
		debug.PrintStack()
	}()

	pk, err := packet.ParseJsonPacket(msgBytes)
	if err != nil {
		return fmt.Errorf("error unmarshalling ws message: %w", err)
	}
	pkType := pk.GetType()

	if pkType == scpacket.WatchScreenPacketStr {
		wsPk := pk.(*scpacket.WatchScreenPacketType)
		if watchErr := ws.handleWatchScreen(wsPk); watchErr != nil {
			return fmt.Errorf("client:%s error %w", ws.ClientId, watchErr)
		}
		return nil
	}

	if !ws.IsAuthenticated() {
		return fmt.Errorf("cannot process ws-packet[%s], not authenticated", pkType)
	}

	switch pkType {
	case scpacket.FeInputPacketStr:
		feInputPk := pk.(*scpacket.FeInputPacketType)
		if feInputPk.Remote.OwnerId != "" {
			return fmt.Errorf("error cannot send input to remote with ownerid")
		}
		if feInputPk.Remote.RemoteId == "" {
			return fmt.Errorf("error invalid input packet, remoteid is not set")
		}
		enqueueErr := RemoteInputMapQueue.Enqueue(feInputPk.Remote.RemoteId, func() {
			// log sendErr itself; any outer err is always nil by the time this runs
			if sendErr := sendCmdInput(feInputPk); sendErr != nil {
				log.Printf("[scws] sending command input: %v\n", sendErr)
			}
		})
		if enqueueErr != nil {
			return fmt.Errorf("[error] could not queue sendCmdInput: %w", enqueueErr)
		}
		return nil

	case scpacket.RemoteInputPacketStr:
		inputPk := pk.(*scpacket.RemoteInputPacketType)
		if inputPk.RemoteId == "" {
			return fmt.Errorf("error invalid remoteinput packet, remoteid is not set")
		}
		go func() {
			if sendErr := remote.SendRemoteInput(inputPk); sendErr != nil {
				log.Printf("[scws] error processing remote input: %v\n", sendErr)
			}
		}()
		return nil

	case scpacket.CmdInputTextPacketStr:
		cmdInputPk := pk.(*scpacket.CmdInputTextPacketType)
		if cmdInputPk.ScreenId == "" {
			return fmt.Errorf("error invalid cmdinput packet, screenid is not set")
		}
		// no need for goroutine for memory ops
		sstore.ScreenMemSetCmdInputText(cmdInputPk.ScreenId, cmdInputPk.Text, cmdInputPk.SeqNum)
		return nil

	case userinput.UserInputResponsePacketStr:
		userInputRespPk := pk.(*userinput.UserInputResponsePacketType)
		uich, ok := scbus.MainRpcBus.GetRpcChannel(userInputRespPk.RequestId)
		if !ok {
			return fmt.Errorf("received User Input Response with invalid Id (%s)", userInputRespPk.RequestId)
		}
		// non-blocking send: if the channel cannot accept the response
		// immediately, it is dropped rather than stalling the read loop
		select {
		case uich <- userInputRespPk:
		default:
		}
		return nil

	default:
		return fmt.Errorf("got ws bad message: %v", pkType)
	}
}
|
|
|
|
// RunWSRead is the read loop for a client connection. It returns immediately
// if no shell is attached. Otherwise it sends the client a {"type": "hello"}
// message, then passes each message from the shell's ReadChan to
// processMessage until that channel is closed. Processing errors are logged,
// not returned.
func (ws *WSState) RunWSRead() {
	shell := ws.GetShell()
	if shell == nil {
		return
	}
	// let client know we accepted this connection, ignore error
	_ = shell.WriteJson(map[string]any{"type": "hello"})
	for msgBytes := range shell.ReadChan {
		// TODO send errors back to client? likely unrecoverable
		if procErr := ws.processMessage(msgBytes); procErr != nil {
			log.Printf("[scws] %v\n", procErr)
		}
	}
}
|
|
|
|
// sendCmdInput delivers one frontend input packet to the remote named by
// pk.Remote.RemoteId.
//
// It validates pk.CK, requires a remote id, and looks up that remote. Then:
//   - If InputData64 is non-empty, its decoded length must not exceed
//     MaxInputDataSize, and it is sent as a data packet on fd 0 (stdin).
//   - If a signal name or window size is set, a special-input packet carrying
//     them is sent.
//
// Both packets may be sent for the same input. The first error encountered is
// returned.
func sendCmdInput(pk *scpacket.FeInputPacketType) error {
	if err := pk.CK.Validate("input packet"); err != nil {
		return err
	}
	remoteId := pk.Remote.RemoteId
	if remoteId == "" {
		return fmt.Errorf("input must set remoteid")
	}
	target := remote.GetRemoteById(remoteId)
	if target == nil {
		return fmt.Errorf("remote %s not found", remoteId)
	}

	if len(pk.InputData64) > 0 {
		if decodedLen := packet.B64DecodedLen(pk.InputData64); decodedLen > MaxInputDataSize {
			return fmt.Errorf("input data size too large, len=%d (max=%d)", decodedLen, MaxInputDataSize)
		}
		dataPk := packet.MakeDataPacket()
		dataPk.CK, dataPk.FdNum, dataPk.Data64 = pk.CK, 0, pk.InputData64 // fd 0 is stdin
		if err := target.SendInput(dataPk); err != nil {
			return err
		}
	}

	if pk.SigName == "" && pk.WinSize == nil {
		return nil
	}
	siPk := packet.MakeSpecialInputPacket()
	siPk.CK, siPk.SigName, siPk.WinSize = pk.CK, pk.SigName, pk.WinSize
	if err := target.SendSpecialInput(siPk); err != nil {
		return err
	}
	return nil
}
|