waveterm/wavesrv/pkg/scws/scws.go
Sylvie Crowe 167277ec11
Rename Waveshell First Pass (#632)
This begins the process of renaming mshell to waveshell everywhere by
making the most simple changes. There will need to be additional changes
in the future, but the hope is to merge simple changes in now to reduce
the number of future merge conflicts.
2024-05-02 14:16:00 -07:00

335 lines
8.8 KiB
Go

// Copyright 2023, Command Line Inc.
// SPDX-License-Identifier: Apache-2.0
package scws
import (
	"context"
	"crypto/subtle"
	"fmt"
	"log"
	"runtime/debug"
	"sync"
	"time"

	"github.com/google/uuid"
	"github.com/wavetermdev/waveterm/waveshell/pkg/packet"
	"github.com/wavetermdev/waveterm/wavesrv/pkg/configstore"
	"github.com/wavetermdev/waveterm/wavesrv/pkg/mapqueue"
	"github.com/wavetermdev/waveterm/wavesrv/pkg/remote"
	"github.com/wavetermdev/waveterm/wavesrv/pkg/scbus"
	"github.com/wavetermdev/waveterm/wavesrv/pkg/scpacket"
	"github.com/wavetermdev/waveterm/wavesrv/pkg/sstore"
	"github.com/wavetermdev/waveterm/wavesrv/pkg/telemetry"
	"github.com/wavetermdev/waveterm/wavesrv/pkg/userinput"
	"github.com/wavetermdev/waveterm/wavesrv/pkg/wsshell"
)
const (
	// WSStatePacketChSize is presumably a packet channel buffer size for
	// websocket state — TODO confirm where it is consumed.
	WSStatePacketChSize = 20
	// RemoteInputQueueSize is the size handed to mapqueue.MakeMapQueue when
	// building RemoteInputMapQueue.
	RemoteInputQueueSize = 100
)

// RemoteInputMapQueue receives the sendCmdInput work that processMessage
// enqueues for frontend input packets, keyed by remote id.
// NOTE(review): ordering/concurrency guarantees come from mapqueue — confirm there.
var RemoteInputMapQueue = mapqueue.MakeMapQueue(RemoteInputQueueSize)
// WSState is the server-side state for one websocket client connection.
// Lock guards the mutable fields: the accessor methods in this file take it
// before reading or writing Authenticated, Shell, ConnectTime, SessionId,
// ScreenId and UpdateCh.
type WSState struct {
	// Lock protects the fields below; held by the accessor methods.
	Lock *sync.Mutex
	// ClientId identifies this client; used as the key when registering or
	// unregistering its update channel on scbus.MainUpdateBus.
	ClientId string
	// ConnectTime is set at construction and refreshed by UpdateConnectTime.
	ConnectTime time.Time
	// Shell is the current websocket connection wrapper; swapped by ReplaceShell.
	Shell *wsshell.WSShell
	// UpdateCh is the channel registered on the update bus by WatchScreen and
	// drained by RunUpdates.
	UpdateCh chan scbus.UpdatePacket
	// UpdateQueue is not referenced by the code in this file.
	// NOTE(review): confirm whether it is still used anywhere.
	UpdateQueue []any
	// Authenticated is set by handleWatchScreen once a packet presents the
	// matching AuthKey; non-watchscreen packets are rejected while it is false.
	Authenticated bool
	// AuthKey is the secret a watchscreen packet must present.
	AuthKey string
	// SessionId and ScreenId record the screen currently being watched;
	// both are cleared by UnWatchScreen.
	SessionId string
	ScreenId  string
}
// MakeWSState builds a fresh, unauthenticated WSState for clientId with a new
// mutex, the connect time set to now, and authKey as the expected key.
func MakeWSState(clientId string, authKey string) *WSState {
	return &WSState{
		Lock:        &sync.Mutex{},
		ClientId:    clientId,
		ConnectTime: time.Now(),
		AuthKey:     authKey,
	}
}
// SetAuthenticated records, under the state lock, whether this client has
// presented a valid auth key.
func (ws *WSState) SetAuthenticated(authVal bool) {
	ws.Lock.Lock()
	ws.Authenticated = authVal
	ws.Lock.Unlock()
}
// IsAuthenticated reports, under the state lock, whether this client has
// presented a valid auth key.
func (ws *WSState) IsAuthenticated() bool {
	ws.Lock.Lock()
	authed := ws.Authenticated
	ws.Lock.Unlock()
	return authed
}
// GetShell returns the current websocket shell (possibly nil), read under the
// state lock.
func (ws *WSState) GetShell() *wsshell.WSShell {
	ws.Lock.Lock()
	current := ws.Shell
	ws.Lock.Unlock()
	return current
}
// WriteUpdate serializes update as JSON onto the current shell. It returns an
// error when no shell is attached or when the write itself fails.
func (ws *WSState) WriteUpdate(update any) error {
	if shell := ws.GetShell(); shell != nil {
		if writeErr := shell.WriteJson(update); writeErr != nil {
			return writeErr
		}
		return nil
	}
	return fmt.Errorf("cannot write update, empty shell")
}
// UpdateConnectTime resets ConnectTime to the current time, under the state lock.
func (ws *WSState) UpdateConnectTime() {
	ws.Lock.Lock()
	ws.ConnectTime = time.Now()
	ws.Lock.Unlock()
}
// GetConnectTime returns the most recently recorded connect time, read under
// the state lock.
func (ws *WSState) GetConnectTime() time.Time {
	ws.Lock.Lock()
	connectedAt := ws.ConnectTime
	ws.Lock.Unlock()
	return connectedAt
}
// WatchScreen points this client at sessionId/screenId. If that screen is
// already being watched it does nothing; otherwise it registers a new update
// channel for the screen on scbus.MainUpdateBus (keyed by ClientId) and starts
// a RunUpdates goroutine that forwards that channel's packets to the shell.
func (ws *WSState) WatchScreen(sessionId string, screenId string) {
	ws.Lock.Lock()
	defer ws.Lock.Unlock()
	alreadyWatching := ws.SessionId == sessionId && ws.ScreenId == screenId
	if alreadyWatching {
		return
	}
	ws.SessionId, ws.ScreenId = sessionId, screenId
	newCh := scbus.MainUpdateBus.RegisterChannel(ws.ClientId, &scbus.UpdateChannel{ScreenId: screenId})
	ws.UpdateCh = newCh
	log.Printf("[ws] watch screen clientid=%s sessionid=%s screenid=%s, updateCh=%v\n", ws.ClientId, sessionId, screenId, newCh)
	// NOTE(review): any goroutine reading a previously registered channel is
	// presumably stopped by the bus closing that channel — confirm in scbus.
	go ws.RunUpdates(newCh)
}
// UnWatchScreen unregisters this client's update channel from
// scbus.MainUpdateBus and clears the watched session/screen ids.
func (ws *WSState) UnWatchScreen() {
	ws.Lock.Lock()
	defer ws.Lock.Unlock()
	scbus.MainUpdateBus.UnregisterChannel(ws.ClientId)
	ws.SessionId, ws.ScreenId = "", ""
	log.Printf("[ws] unwatch screen clientid=%s\n", ws.ClientId)
}
// RunUpdates forwards every packet received on updateCh to the current shell
// (if one is attached) until updateCh is closed. A nil channel is a programmer
// error and panics.
func (ws *WSState) RunUpdates(updateCh chan scbus.UpdatePacket) {
	if updateCh == nil {
		panic("invalid nil updateCh passed to RunUpdates")
	}
	for {
		update, open := <-updateCh
		if !open {
			return
		}
		// the shell is re-read per packet because ReplaceShell may swap it
		if shell := ws.GetShell(); shell != nil {
			writeJsonProtected(shell, update)
		}
	}
}
// writeJsonProtected writes update to shell as JSON. A panic during the write
// is recovered and logged so it cannot kill the RunUpdates goroutine; the
// write's returned error is discarded.
func writeJsonProtected(shell *wsshell.WSShell, update any) {
	defer func() {
		if r := recover(); r != nil {
			log.Printf("[error] in scws RunUpdates WriteJson: %v\n", r)
		}
	}()
	_ = shell.WriteJson(update)
}
// ReplaceShell installs shell as the active connection, first closing the
// underlying Conn of any previously attached shell. Runs under the state lock.
func (ws *WSState) ReplaceShell(shell *wsshell.WSShell) {
	ws.Lock.Lock()
	defer ws.Lock.Unlock()
	if prev := ws.Shell; prev != nil {
		prev.Conn.Close()
	}
	ws.Shell = shell
}
// handleConnection sends the client one update packet carrying all state
// required to display the current UI: the sstore connect update, the runtime
// state of every remote, the current screen status indicators / running
// command counts, and the scanned configs (as TermThemes).
//
// Loading sstore state is bounded by a 5 second timeout. Returns an error if
// any load fails, if no shell is attached, or if writing to the shell fails.
func (ws *WSState) handleConnection() error {
	ctx, cancelFn := context.WithTimeout(context.Background(), 5*time.Second)
	defer cancelFn()
	connectUpdate, err := sstore.GetConnectUpdate(ctx)
	if err != nil {
		return fmt.Errorf("getting sessions: %w", err)
	}
	connectUpdate.Remotes = remote.GetAllRemoteRuntimeState()
	// restore status indicators
	connectUpdate.ScreenStatusIndicators, connectUpdate.ScreenNumRunningCommands = sstore.GetCurrentIndicatorState()
	configs, err := configstore.ScanConfigs()
	if err != nil {
		return fmt.Errorf("getting configs: %w", err)
	}
	connectUpdate.TermThemes = &configs
	mu := scbus.MakeUpdatePacket()
	mu.AddUpdate(*connectUpdate)
	// Go through WriteUpdate instead of ws.Shell: ws.Shell is swapped under
	// ws.Lock by ReplaceShell, so an unlocked read here races, and the field
	// may be nil when no shell is attached.
	return ws.WriteUpdate(mu)
}
// handleWatchScreen processes a watchscreen packet, which is also how a
// client authenticates.
//
// Non-empty SessionId/ScreenId values must parse as UUIDs. The packet's
// AuthKey must be non-empty and equal ws.AuthKey; on failure the client is
// marked unauthenticated and an error is returned. On success the client is
// marked authenticated. It then unwatches when either id is empty, or watches
// the given screen otherwise. If wsPk.Connect is set, the full connection
// state is also sent via handleConnection.
func (ws *WSState) handleWatchScreen(wsPk *scpacket.WatchScreenPacketType) error {
	if wsPk.SessionId != "" {
		if _, err := uuid.Parse(wsPk.SessionId); err != nil {
			return fmt.Errorf("invalid watchscreen sessionid: %w", err)
		}
	}
	if wsPk.ScreenId != "" {
		if _, err := uuid.Parse(wsPk.ScreenId); err != nil {
			return fmt.Errorf("invalid watchscreen screenid: %w", err)
		}
	}
	if wsPk.AuthKey == "" {
		ws.SetAuthenticated(false)
		return fmt.Errorf("invalid watchscreen, no authkey")
	}
	// The auth key is a secret supplied by the remote peer; compare in
	// constant time so response timing does not reveal how much of it matched.
	if subtle.ConstantTimeCompare([]byte(wsPk.AuthKey), []byte(ws.AuthKey)) != 1 {
		ws.SetAuthenticated(false)
		return fmt.Errorf("invalid watchscreen, invalid authkey")
	}
	ws.SetAuthenticated(true)
	if wsPk.SessionId == "" || wsPk.ScreenId == "" {
		ws.UnWatchScreen()
	} else {
		ws.WatchScreen(wsPk.SessionId, wsPk.ScreenId)
		log.Printf("[ws %s] watchscreen %s/%s\n", ws.ClientId, wsPk.SessionId, wsPk.ScreenId)
	}
	if wsPk.Connect {
		if err := ws.handleConnection(); err != nil {
			return fmt.Errorf("connect: %w", err)
		}
	}
	return nil
}
// processMessage decodes one websocket message from the client and dispatches
// it by packet type.
//
// A watchscreen packet is always accepted, because it is the packet that
// carries the auth key. Every other packet type is rejected until the client
// has been authenticated. Unknown packet types produce an error.
//
// A panic while handling a message is recovered, logged with a stack trace,
// and the function then returns nil, so one bad message cannot stop the read
// loop.
func (ws *WSState) processMessage(msgBytes []byte) error {
	defer func() {
		r := recover()
		if r == nil {
			return
		}
		log.Printf("[scws] panic in processMessage: %v\n", r)
		debug.PrintStack()
	}()
	pk, err := packet.ParseJsonPacket(msgBytes)
	if err != nil {
		return fmt.Errorf("error unmarshalling ws message: %w", err)
	}
	if wsPk, ok := pk.(*scpacket.WatchScreenPacketType); ok {
		if err := ws.handleWatchScreen(wsPk); err != nil {
			return fmt.Errorf("client:%s error %w", ws.ClientId, err)
		}
		return nil
	}
	if !ws.IsAuthenticated() {
		return fmt.Errorf("cannot process ws-packet[%s], not authenticated", pk.GetType())
	}
	// A type switch (rather than comparing GetType() and then asserting)
	// cannot panic on a mismatched concrete type.
	switch typedPk := pk.(type) {
	case *scpacket.FeInputPacketType:
		if typedPk.Remote.OwnerId != "" {
			return fmt.Errorf("error cannot send input to remote with ownerid")
		}
		if typedPk.Remote.RemoteId == "" {
			return fmt.Errorf("error invalid input packet, remoteid is not set")
		}
		err := RemoteInputMapQueue.Enqueue(typedPk.Remote.RemoteId, func() {
			if sendErr := sendCmdInput(typedPk); sendErr != nil {
				log.Printf("[scws] sending command input: %v\n", sendErr)
			}
		})
		if err != nil {
			return fmt.Errorf("[error] could not queue sendCmdInput: %w", err)
		}
		return nil
	case *scpacket.RemoteInputPacketType:
		if typedPk.RemoteId == "" {
			return fmt.Errorf("error invalid remoteinput packet, remoteid is not set")
		}
		go func() {
			if sendErr := remote.SendRemoteInput(typedPk); sendErr != nil {
				log.Printf("[scws] error processing remote input: %v\n", sendErr)
			}
		}()
		return nil
	case *scpacket.CmdInputTextPacketType:
		if typedPk.ScreenId == "" {
			return fmt.Errorf("error invalid cmdinput packet, screenid is not set")
		}
		// no need for goroutine for memory ops
		sstore.ScreenMemSetCmdInputText(typedPk.ScreenId, typedPk.Text, typedPk.SeqNum)
		return nil
	case *userinput.UserInputResponsePacketType:
		uich, ok := scbus.MainRpcBus.GetRpcChannel(typedPk.RequestId)
		if !ok {
			// the previous message appended the (always nil) parse error here
			return fmt.Errorf("received User Input Response with invalid Id (%s)", typedPk.RequestId)
		}
		// Non-blocking send: if the receiver cannot accept right now the
		// response is dropped.
		select {
		case uich <- typedPk:
		default:
		}
		return nil
	case *scpacket.FeActivityPacketType:
		telemetry.UpdateFeActivityWrap(typedPk)
		return nil
	default:
		return fmt.Errorf("got ws bad message: %v", pk.GetType())
	}
}
// RunWSRead greets the client with a {"type":"hello"} message and then handles
// every message arriving on the shell's ReadChan until that channel closes.
// It returns immediately if no shell is attached. Per-message errors are only
// logged.
func (ws *WSState) RunWSRead() {
	shell := ws.GetShell()
	if shell == nil {
		return
	}
	// let client know we accepted this connection; the write error is ignored
	shell.WriteJson(map[string]any{"type": "hello"})
	incoming := shell.ReadChan
	for msgBytes := range incoming {
		if procErr := ws.processMessage(msgBytes); procErr != nil {
			// TODO send errors back to client? likely unrecoverable
			log.Printf("[scws] %v\n", procErr)
		}
	}
}
// sendCmdInput validates a frontend input packet and hands it to the remote
// named by pk.Remote.RemoteId. It errors if the packet's CK fails validation,
// the remote id is empty, or no remote with that id exists.
func sendCmdInput(pk *scpacket.FeInputPacketType) error {
	if err := pk.CK.Validate("input packet"); err != nil {
		return err
	}
	remoteId := pk.Remote.RemoteId
	if remoteId == "" {
		return fmt.Errorf("input must set remoteid")
	}
	target := remote.GetRemoteById(remoteId)
	if target == nil {
		return fmt.Errorf("remote %s not found", remoteId)
	}
	return target.HandleFeInput(pk)
}