waveterm/pkg/wshutil/wshrpc.go

664 lines
16 KiB
Go

// Copyright 2024, Command Line Inc.
// SPDX-License-Identifier: Apache-2.0
package wshutil
import (
"context"
"encoding/json"
"errors"
"fmt"
"log"
"reflect"
"runtime/debug"
"sync"
"sync/atomic"
"time"
"github.com/google/uuid"
"github.com/wavetermdev/thenextwave/pkg/wshrpc"
)
// Defaults shared by the RPC code in this file.
const (
	// DefaultTimeoutMs is the timeout, in milliseconds, substituted whenever
	// a request asks for a timeout <= 0.
	DefaultTimeoutMs = 5000
	// RespChSize is the buffer capacity of each per-request response channel.
	RespChSize = 32
	// DefaultMessageChSize is a default capacity for message channels.
	DefaultMessageChSize = 32
)
// ResponseFnType is the signature of a callback that takes a response value
// and returns an error.
type ResponseFnType = func(any) error
// CommandHandlerFnType is the signature of a command handler. It returns true
// if the handler is complete, false for an async handler.
type CommandHandlerFnType = func(*RpcResponseHandler) bool
// ServerImpl is implemented by types that declare a WshServerImpl method.
// validateServerImpl inspects values of this type via reflection before they
// are installed on a WshRpc.
type ServerImpl interface {
WshServerImpl()
}
// AbstractRpcClient is the minimal send/receive surface of an RPC endpoint.
// WshRpc provides both methods.
type AbstractRpcClient interface {
// SendRpcMessage hands one encoded message to the endpoint.
SendRpcMessage(msg []byte)
// RecvRpcMessage returns the next encoded message and a flag that is false
// once no more messages will arrive.
RecvRpcMessage() ([]byte, bool) // blocking
}
// WshRpc is a bidirectional RPC endpoint. It decodes JSON messages arriving on
// InputCh (dispatching requests to ServerImpl and routing responses to pending
// callers) and writes its own requests and responses to OutputCh.
type WshRpc struct {
// Lock is held around accesses to RpcMap and ResponseHandlerMap, and by
// SetServerImpl when it replaces ServerImpl.
Lock *sync.Mutex
// clientId is a UUID string assigned in MakeWshRpc.
clientId string
// InputCh carries raw JSON messages consumed by runServer.
InputCh chan []byte
// OutputCh carries raw JSON messages produced by this endpoint; runServer
// closes it once InputCh has been closed and drained.
OutputCh chan []byte
// RpcContext holds the current wshrpc.RpcContext (see GetRpcContext and
// SetRpcContext).
RpcContext *atomic.Pointer[wshrpc.RpcContext]
// RpcMap tracks outstanding outbound requests, keyed by request id.
RpcMap map[string]*rpcData
// ServerImpl is handed to serverImplAdapter to obtain the handler that
// services inbound requests.
ServerImpl ServerImpl
// ResponseHandlerMap tracks inbound requests currently being handled so that
// a cancel message can be matched to its handler.
ResponseHandlerMap map[string]*RpcResponseHandler // reqId => handler
}
// wshRpcContextKey is the context key under which withWshRpcContext stores the
// *WshRpc.
type wshRpcContextKey struct{}
// wshRpcRespHandlerContextKey is the context key under which withRespHandler
// stores the *RpcResponseHandler.
type wshRpcRespHandlerContextKey struct{}
// withWshRpcContext returns a child of ctx that carries wshRpc, retrievable
// with GetWshRpcFromContext.
func withWshRpcContext(ctx context.Context, wshRpc *WshRpc) context.Context {
	var key wshRpcContextKey
	return context.WithValue(ctx, key, wshRpc)
}
// withRespHandler returns a child of ctx that carries handler, retrievable
// with GetRpcResponseHandlerFromContext.
func withRespHandler(ctx context.Context, handler *RpcResponseHandler) context.Context {
	var key wshRpcRespHandlerContextKey
	return context.WithValue(ctx, key, handler)
}
// GetWshRpcFromContext returns the *WshRpc stored in ctx by withWshRpcContext,
// or nil when ctx carries none.
func GetWshRpcFromContext(ctx context.Context) *WshRpc {
	if val := ctx.Value(wshRpcContextKey{}); val != nil {
		return val.(*WshRpc)
	}
	return nil
}
// GetRpcSourceFromContext returns the source of the response handler stored in
// ctx, or "" when ctx carries no handler.
func GetRpcSourceFromContext(ctx context.Context) string {
	val := ctx.Value(wshRpcRespHandlerContextKey{})
	if val == nil {
		return ""
	}
	handler := val.(*RpcResponseHandler)
	return handler.GetSource()
}
// GetRpcResponseHandlerFromContext returns the *RpcResponseHandler stored in
// ctx by withRespHandler, or nil when ctx carries none.
func GetRpcResponseHandlerFromContext(ctx context.Context) *RpcResponseHandler {
	if val := ctx.Value(wshRpcRespHandlerContextKey{}); val != nil {
		return val.(*RpcResponseHandler)
	}
	return nil
}
// SendRpcMessage queues msg on InputCh for processing by this endpoint. The
// send blocks while InputCh has no free capacity.
func (w *WshRpc) SendRpcMessage(msg []byte) {
	inputCh := w.InputCh
	inputCh <- msg
}
// RecvRpcMessage blocks until a message is available on OutputCh. The boolean
// is false once OutputCh has been closed and drained.
func (w *WshRpc) RecvRpcMessage() ([]byte, bool) {
	data, open := <-w.OutputCh
	return data, open
}
// RpcMessage is the JSON envelope exchanged between RPC endpoints. A message
// is a command (Command set), a follow-up request packet (ReqId only), a
// response (ResId set), or a cancel (Cancel set).
type RpcMessage struct {
	Command  string `json:"command,omitempty"`
	ReqId    string `json:"reqid,omitempty"`
	ResId    string `json:"resid,omitempty"`
	Timeout  int    `json:"timeout,omitempty"`
	Route    string `json:"route,omitempty"`  // to route/forward requests to alternate servers
	Source   string `json:"source,omitempty"` // source route id
	Cont     bool   `json:"cont,omitempty"`   // flag if additional requests/responses are forthcoming
	Cancel   bool   `json:"cancel,omitempty"` // used to cancel a streaming request or response (sent from the side that is not streaming)
	Error    string `json:"error,omitempty"`
	DataType string `json:"datatype,omitempty"`
	Data     any    `json:"data,omitempty"`
}

// IsRpcRequest reports whether the message is a request, i.e. it carries a
// command or a request id.
func (r *RpcMessage) IsRpcRequest() bool {
	return r.Command != "" || r.ReqId != ""
}

// Validate checks that the combination of fields set on the message is
// consistent with exactly one packet kind and returns a descriptive error
// otherwise.
//
// A packet may never carry both reqid and resid. Consequently a response
// packet is identified by resid alone and a non-command request packet by
// reqid alone; neither is required to carry the other id.
func (r *RpcMessage) Validate() error {
	if r.ReqId != "" && r.ResId != "" {
		return fmt.Errorf("request packets may not have both reqid and resid set")
	}
	if r.Cancel {
		if r.Command != "" {
			return fmt.Errorf("cancel packets may not have command set")
		}
		if r.ReqId == "" && r.ResId == "" {
			return fmt.Errorf("cancel packets must have reqid or resid set")
		}
		if r.Data != nil {
			return fmt.Errorf("cancel packets may not have data set")
		}
		return nil
	}
	if r.Command != "" {
		if r.ResId != "" {
			return fmt.Errorf("command packets may not have resid set")
		}
		if r.Error != "" {
			return fmt.Errorf("command packets may not have error set")
		}
		if r.DataType != "" {
			return fmt.Errorf("command packets may not have datatype set")
		}
		return nil
	}
	if r.ReqId != "" {
		// ResId is known to be empty here (checked at the top), so only the
		// timeout restriction remains.
		if r.Timeout != 0 {
			return fmt.Errorf("non-command request packets may not have timeout set")
		}
		return nil
	}
	if r.ResId != "" {
		// Command and ReqId are known to be empty here.
		if r.Timeout != 0 {
			return fmt.Errorf("response packets may not have timeout set")
		}
		return nil
	}
	return fmt.Errorf("invalid packet: must have command, reqid, or resid set")
}
// rpcData is the bookkeeping entry kept in WshRpc.RpcMap for one outbound
// request.
type rpcData struct {
// ResCh receives the responses for the request; unregisterRpc closes it.
ResCh chan *RpcMessage
// Ctx is the request context that was supplied to registerRpc.
Ctx context.Context
}
// validateServerImpl panics unless serverImpl is nil or a pointer to a struct.
//
// The two conditions are joined with ||: a non-pointer value must be rejected
// before Elem is consulted (Elem itself panics on kinds that have no element
// type), and a pointer must additionally point at a struct.
func validateServerImpl(serverImpl ServerImpl) {
	if serverImpl == nil {
		return
	}
	serverType := reflect.TypeOf(serverImpl)
	if serverType.Kind() != reflect.Pointer || serverType.Elem().Kind() != reflect.Struct {
		panic(fmt.Sprintf("serverImpl must be a pointer to struct, got %v", serverType))
	}
}
// MakeWshRpc builds a WshRpc around the given channels and starts its receive
// loop in a new goroutine. A nil inputCh or outputCh is replaced by a freshly
// allocated buffered channel. serverImpl is validated first and may be nil.
//
// closes outputCh when inputCh is closed/done
func MakeWshRpc(inputCh chan []byte, outputCh chan []byte, rpcCtx wshrpc.RpcContext, serverImpl ServerImpl) *WshRpc {
	if inputCh == nil {
		inputCh = make(chan []byte, DefaultInputChSize)
	}
	if outputCh == nil {
		outputCh = make(chan []byte, DefaultOutputChSize)
	}
	validateServerImpl(serverImpl)

	ctxHolder := &atomic.Pointer[wshrpc.RpcContext]{}
	client := &WshRpc{
		Lock:               &sync.Mutex{},
		clientId:           uuid.New().String(),
		InputCh:            inputCh,
		OutputCh:           outputCh,
		RpcMap:             make(map[string]*rpcData),
		RpcContext:         ctxHolder,
		ServerImpl:         serverImpl,
		ResponseHandlerMap: make(map[string]*RpcResponseHandler),
	}
	ctxHolder.Store(&rpcCtx)
	go client.runServer()
	return client
}
// ClientId returns the UUID string assigned to this endpoint when it was
// created by MakeWshRpc.
func (w *WshRpc) ClientId() string {
	return w.clientId
}
// SendEvent wraps event in an event-publish command and writes it to
// OutputCh. A marshalling failure is logged and the event is dropped.
func (w *WshRpc) SendEvent(event wshrpc.WaveEvent) {
	// for wps compatibility
	payload, err := json.Marshal(&RpcMessage{
		Command: wshrpc.Command_EventPublish,
		Data:    event,
	})
	if err != nil {
		log.Printf("error marshalling event: %v\n", err)
		return
	}
	w.OutputCh <- payload
}
// GetRpcContext returns a copy of the RpcContext currently stored on the
// endpoint.
func (w *WshRpc) GetRpcContext() wshrpc.RpcContext {
	return *w.RpcContext.Load()
}
// SetRpcContext atomically replaces the endpoint's RpcContext with ctx.
func (w *WshRpc) SetRpcContext(ctx wshrpc.RpcContext) {
	next := &ctx
	w.RpcContext.Store(next)
}
// registerResponseHandler records handler under reqId so that a later cancel
// message for that request can find it.
func (w *WshRpc) registerResponseHandler(reqId string, handler *RpcResponseHandler) {
	mu := w.Lock
	mu.Lock()
	defer mu.Unlock()
	w.ResponseHandlerMap[reqId] = handler
}
// unregisterResponseHandler removes the handler recorded under reqId, if any.
func (w *WshRpc) unregisterResponseHandler(reqId string) {
	w.Lock.Lock()
	delete(w.ResponseHandlerMap, reqId)
	w.Lock.Unlock()
}
// cancelRequest marks the in-flight handler for reqId as canceled by the
// requestor. Unknown or empty request ids are ignored.
func (w *WshRpc) cancelRequest(reqId string) {
	if reqId == "" {
		return
	}
	w.Lock.Lock()
	defer w.Lock.Unlock()
	if handler, found := w.ResponseHandlerMap[reqId]; found && handler != nil {
		handler.canceled.Store(true)
	}
}
// handleRequest services one inbound request on the calling goroutine.
//
// It creates a timeout context (req.Timeout milliseconds, or DefaultTimeoutMs
// when that is <= 0), builds and registers an RpcResponseHandler, and invokes
// the handler function derived from ServerImpl. If the handler function
// returns true the request is finalized immediately; if it returns false
// (async), finalization is deferred to a goroutine that waits for the context
// to end. A panic raised by the handler is logged and reported to the
// requestor as an error response.
func (w *WshRpc) handleRequest(req *RpcMessage) {
	var respHandler *RpcResponseHandler
	// Outer safety net: catches panics raised before the inner deferred
	// function below is registered, or raised by that function itself.
	defer func() {
		if r := recover(); r != nil {
			log.Printf("panic in handleRequest: %v\n", r)
			debug.PrintStack()
			if respHandler != nil {
				respHandler.SendResponseError(fmt.Errorf("panic: %v", r))
			}
		}
	}()
	timeoutMs := req.Timeout
	if timeoutMs <= 0 {
		timeoutMs = DefaultTimeoutMs
	}
	ctx, cancelFn := context.WithTimeout(context.Background(), time.Duration(timeoutMs)*time.Millisecond)
	ctx = withWshRpcContext(ctx, w)
	respHandler = &RpcResponseHandler{
		w:               w,
		ctx:             ctx,
		reqId:           req.ReqId,
		command:         req.Command,
		commandData:     req.Data,
		source:          req.Source,
		done:            &atomic.Bool{},
		canceled:        &atomic.Bool{},
		contextCancelFn: &atomic.Pointer[context.CancelFunc]{},
		rpcCtx:          w.GetRpcContext(),
	}
	respHandler.contextCancelFn.Store(&cancelFn)
	// The handler's own context additionally carries the handler itself.
	respHandler.ctx = withRespHandler(ctx, respHandler)
	w.registerResponseHandler(req.ReqId, respHandler)
	isAsync := false
	defer func() {
		if r := recover(); r != nil {
			log.Printf("panic in handleRequest: %v\n", r)
			debug.PrintStack()
			respHandler.SendResponseError(fmt.Errorf("panic: %v", r))
		}
		if isAsync {
			// Async handlers keep running; clean up once the context ends
			// (timeout, or cancellation when the handler closes).
			go func() {
				<-ctx.Done()
				respHandler.Finalize()
			}()
		} else {
			cancelFn()
			respHandler.Finalize()
		}
	}()
	// SetServerImpl writes ServerImpl while holding Lock, so take the same
	// lock for the read to avoid a data race.
	w.Lock.Lock()
	serverImpl := w.ServerImpl
	w.Lock.Unlock()
	handlerFn := serverImplAdapter(serverImpl)
	isAsync = !handlerFn(respHandler)
}
// runServer is the endpoint's receive loop. It decodes each message arriving
// on InputCh and dispatches it: cancel messages flag the matching in-flight
// handler, requests are handled inline via handleRequest, and everything else
// is treated as a response and delivered to the waiting caller. Undecodable
// messages are logged and skipped. OutputCh is closed when InputCh is closed
// and drained.
func (w *WshRpc) runServer() {
	defer close(w.OutputCh)
	for msgBytes := range w.InputCh {
		var msg RpcMessage
		if err := json.Unmarshal(msgBytes, &msg); err != nil {
			log.Printf("wshrpc received bad message: %v\n", err)
			continue
		}
		switch {
		case msg.Cancel:
			if msg.ReqId != "" {
				w.cancelRequest(msg.ReqId)
			}
		case msg.IsRpcRequest():
			w.handleRequest(&msg)
		default:
			w.deliverResponse(&msg)
		}
	}
}

// deliverResponse forwards a response message to the channel registered for
// msg.ResId and unregisters the request when no further responses will follow.
// Responses for unknown ids are dropped.
func (w *WshRpc) deliverResponse(msg *RpcMessage) {
	respCh := w.getResponseCh(msg.ResId)
	if respCh == nil {
		return
	}
	// The channel is looked up under the lock but sent to after the lock is
	// released, so unregisterRpc (e.g. from the timeout goroutine) may close it
	// in between. Contain that panic so it cannot kill the receive loop.
	sent := func() (ok bool) {
		defer func() {
			if r := recover(); r != nil {
				// this is likely a write to closed channel
				log.Printf("panic delivering rpc response: %v\n", r)
				ok = false
			}
		}()
		respCh <- msg
		return true
	}()
	if !sent {
		return
	}
	if !msg.Cont {
		w.unregisterRpc(msg.ResId, nil)
	}
}
// getResponseCh returns the response channel registered for resId, or nil if
// resId is empty or no request is registered under it.
func (w *WshRpc) getResponseCh(resId string) chan *RpcMessage {
	if resId == "" {
		return nil
	}
	w.Lock.Lock()
	defer w.Lock.Unlock()
	if pending, found := w.RpcMap[resId]; found && pending != nil {
		return pending.ResCh
	}
	return nil
}
// SetServerImpl validates serverImpl (panicking if it is not nil or a pointer
// to struct) and installs it on the endpoint under the lock.
func (w *WshRpc) SetServerImpl(serverImpl ServerImpl) {
	validateServerImpl(serverImpl)
	w.Lock.Lock()
	w.ServerImpl = serverImpl
	w.Lock.Unlock()
}
// registerRpc records a pending outbound request under reqId and returns the
// buffered channel on which its responses will be delivered. A goroutine is
// started that waits for ctx to end and then unregisters the request with a
// timeout error.
func (w *WshRpc) registerRpc(ctx context.Context, reqId string) chan *RpcMessage {
	respCh := make(chan *RpcMessage, RespChSize)
	entry := &rpcData{ResCh: respCh, Ctx: ctx}

	w.Lock.Lock()
	defer w.Lock.Unlock()
	w.RpcMap[reqId] = entry
	go func() {
		<-ctx.Done()
		w.unregisterRpc(reqId, fmt.Errorf("EC-TIME: timeout waiting for response"))
	}()
	return respCh
}
// unregisterRpc removes the pending request registered under reqId and closes
// its response channel. When err is non-nil an error response carrying
// err.Error() is queued on the channel first. Calling it for an id that is not
// registered is a no-op, so it is safe to call more than once.
func (w *WshRpc) unregisterRpc(reqId string, err error) {
	w.Lock.Lock()
	defer w.Lock.Unlock()
	rd := w.RpcMap[reqId]
	if rd == nil {
		return
	}
	delete(w.RpcMap, reqId)
	if err != nil {
		errResp := &RpcMessage{
			ResId: reqId,
			Error: err.Error(),
		}
		// Never block while holding Lock: if the buffer is full and nobody is
		// reading, a blocking send would stall every other user of the lock.
		// When the error cannot be queued the reader still observes the close.
		select {
		case rd.ResCh <- errResp:
		default:
		}
	}
	close(rd.ResCh)
}
// SendCommand sends command with data as a fire-and-forget request: the
// options are copied, NoResponse is forced on and Timeout is reset, and the
// request handle is finalized immediately after sending.
//
// no response
func (w *WshRpc) SendCommand(command string, data any, opts *wshrpc.RpcOpts) error {
	cmdOpts := wshrpc.RpcOpts{}
	if opts != nil {
		cmdOpts = *opts
	}
	cmdOpts.NoResponse = true
	cmdOpts.Timeout = 0
	handler, err := w.SendComplexRequest(command, data, &cmdOpts)
	if err == nil {
		handler.finalize()
	}
	return err
}
// SendRpcRequest sends command with data and waits for the first response,
// returning its data or error. The request handle is finalized before
// returning.
//
// single response
func (w *WshRpc) SendRpcRequest(command string, data any, opts *wshrpc.RpcOpts) (any, error) {
	reqOpts := wshrpc.RpcOpts{}
	if opts != nil {
		reqOpts = *opts
	}
	reqOpts.NoResponse = false
	handler, err := w.SendComplexRequest(command, data, &reqOpts)
	if err != nil {
		return nil, err
	}
	defer handler.finalize()
	resp, respErr := handler.NextResponse()
	return resp, respErr
}
// RpcRequestHandler is the caller-side handle for a request issued with
// SendComplexRequest.
type RpcRequestHandler struct {
// w is the endpoint the request was sent through.
w *WshRpc
// ctx bounds the request's lifetime; SendComplexRequest creates it with a
// timeout.
ctx context.Context
// ctxCancelFn holds the cancel function for ctx; finalize calls it and then
// stores nil.
ctxCancelFn *atomic.Pointer[context.CancelFunc]
// reqId is empty when no response is expected.
reqId string
// respCh delivers responses; it is the channel returned by registerRpc.
respCh chan *RpcMessage
// cachedResp holds a message read ahead by ResponseDone until NextResponse
// consumes it.
cachedResp *RpcMessage
}
// Context returns the context that bounds this request.
func (handler *RpcRequestHandler) Context() context.Context {
	return handler.ctx
}
// SendCancel writes a cancel message for this request to the endpoint's
// OutputCh and then finalizes the handle. A panic during the send is logged
// and swallowed.
func (handler *RpcRequestHandler) SendCancel() {
	defer func() {
		r := recover()
		if r == nil {
			return
		}
		// this is likely a write to closed channel
		log.Printf("panic in SendCancel: %v\n", r)
	}()
	cancelMsg := RpcMessage{
		Cancel: true,
		ReqId:  handler.reqId,
	}
	payload, _ := json.Marshal(&cancelMsg) // will never fail
	handler.w.OutputCh <- payload
	handler.finalize()
}
// ResponseDone reports, without blocking, whether the response channel has
// been closed. If a message is already waiting it is cached for the next call
// to NextResponse and false is returned.
func (handler *RpcRequestHandler) ResponseDone() bool {
	if handler.cachedResp != nil {
		return false
	}
	select {
	case msg, open := <-handler.respCh:
		if open {
			handler.cachedResp = msg
		}
		return !open
	default:
	}
	return false
}
// NextResponse returns the data of the next response, consuming a message
// cached by ResponseDone first and otherwise blocking on the response channel.
// It returns an error when the channel has been closed or when the response
// carries an error string.
func (handler *RpcRequestHandler) NextResponse() (any, error) {
	resp := handler.cachedResp
	if resp != nil {
		handler.cachedResp = nil
	} else {
		resp = <-handler.respCh
	}
	switch {
	case resp == nil:
		return nil, errors.New("response channel closed")
	case resp.Error != "":
		return nil, errors.New(resp.Error)
	}
	return resp.Data, nil
}
// finalize releases the request's resources: it cancels the request context
// (once) and, when the request has an id, unregisters it from the endpoint.
func (handler *RpcRequestHandler) finalize() {
	if ptr := handler.ctxCancelFn.Load(); ptr != nil && *ptr != nil {
		cancel := *ptr
		cancel()
		handler.ctxCancelFn.Store(nil)
	}
	if handler.reqId == "" {
		return
	}
	handler.w.unregisterRpc(handler.reqId, nil)
}
// RpcResponseHandler is the server-side handle for one inbound request. It is
// created by handleRequest and passed to the command handler, which uses it to
// read the request and send responses.
type RpcResponseHandler struct {
// w is the endpoint the request arrived on; responses go to its OutputCh.
w *WshRpc
// ctx is the request context; it also carries the endpoint and this handler.
ctx context.Context
// contextCancelFn holds the cancel function for ctx; close calls it and then
// stores nil.
contextCancelFn *atomic.Pointer[context.CancelFunc]
// reqId is the request id; empty means no response is expected.
reqId string
// source, command and commandData are copied from the request message.
source string
command string
commandData any
// rpcCtx is the endpoint's RpcContext captured when the request arrived.
rpcCtx wshrpc.RpcContext
canceled *atomic.Bool // canceled by requestor
// done is set once the final response has been sent or the handler closed.
done *atomic.Bool
}
// Context returns the context associated with the request being handled.
func (handler *RpcResponseHandler) Context() context.Context {
	return handler.ctx
}
// GetCommand returns the command name of the request being handled.
func (handler *RpcResponseHandler) GetCommand() string {
	return handler.command
}
// GetCommandRawData returns the request's data payload exactly as it was
// decoded from the message.
func (handler *RpcResponseHandler) GetCommandRawData() any {
	return handler.commandData
}
// GetRpcContext returns the RpcContext that was captured from the endpoint
// when this request arrived.
func (handler *RpcResponseHandler) GetRpcContext() wshrpc.RpcContext {
	return handler.rpcCtx
}
// GetSource returns the source route id copied from the request message.
func (handler *RpcResponseHandler) GetSource() string {
	return handler.source
}
// NeedsResponse reports whether the requestor expects a response, i.e. the
// request carried a request id.
func (handler *RpcResponseHandler) NeedsResponse() bool {
	return len(handler.reqId) > 0
}
// SendMessage writes a message command carrying msg to the endpoint's
// OutputCh. It is sent as a standalone command, not as a response to the
// current request.
func (handler *RpcResponseHandler) SendMessage(msg string) {
	payload := wshrpc.CommandMessageData{Message: msg}
	encoded, _ := json.Marshal(&RpcMessage{
		Command: wshrpc.Command_Message,
		Data:    payload,
	}) // will never fail
	handler.w.OutputCh <- encoded
}
// SendResponse sends data as a response to the current request. done marks
// the response as final; the handler is then closed when the call returns.
// It is a no-op returning nil when the request expects no response, and
// returns an error if the handler is already done or data cannot be
// marshalled. A panic during the send is logged, closes the handler, and
// results in a nil return.
func (handler *RpcResponseHandler) SendResponse(data any, done bool) error {
	defer func() {
		r := recover()
		if r == nil {
			return
		}
		// this is likely a write to closed channel
		log.Printf("panic in SendResponse: %v\n", r)
		handler.close()
	}()
	switch {
	case handler.reqId == "":
		return nil // no response expected
	case handler.done.Load():
		return fmt.Errorf("request already done, cannot send additional response")
	}
	if done {
		defer handler.close()
	}
	payload, err := json.Marshal(&RpcMessage{
		ResId: handler.reqId,
		Data:  data,
		Cont:  !done,
	})
	if err != nil {
		return err
	}
	handler.w.OutputCh <- payload
	return nil
}
// SendResponseError sends err as the final response to the current request
// and closes the handler. It does nothing when no response is expected or the
// handler is already done. A panic during the send is logged and the handler
// is closed.
func (handler *RpcResponseHandler) SendResponseError(err error) {
	defer func() {
		r := recover()
		if r == nil {
			return
		}
		// this is likely a write to closed channel
		log.Printf("panic in SendResponseError: %v\n", r)
		handler.close()
	}()
	if handler.reqId == "" {
		return
	}
	if handler.done.Load() {
		return
	}
	defer handler.close()
	payload, _ := json.Marshal(&RpcMessage{
		ResId: handler.reqId,
		Error: err.Error(),
	}) // will never fail
	handler.w.OutputCh <- payload
}
// IsCanceled reports whether the requestor has sent a cancel for this request.
func (handler *RpcResponseHandler) IsCanceled() bool {
	canceled := handler.canceled.Load()
	return canceled
}
// close cancels the request context (once) and marks the handler as done.
func (handler *RpcResponseHandler) close() {
	if ptr := handler.contextCancelFn.Load(); ptr != nil && *ptr != nil {
		cancel := *ptr
		cancel()
		handler.contextCancelFn.Store(nil)
	}
	handler.done.Store(true)
}
// Finalize completes the request. If a response is expected and none has been
// marked final yet, an empty final response is sent and the handler is closed.
// In every case the handler is removed from the endpoint's
// ResponseHandlerMap, including when the final response was already sent, so
// that completed requests do not accumulate in the map.
//
// if async, caller must call finalize
func (handler *RpcResponseHandler) Finalize() {
	// Deleting an absent key is harmless, so unregister unconditionally.
	defer handler.w.unregisterResponseHandler(handler.reqId)
	if handler.reqId == "" || handler.done.Load() {
		return
	}
	// Best effort: SendResponse reports failures by return value only and
	// there is no caller left to hand the error to.
	handler.SendResponse(nil, true)
	handler.close()
}
// IsDone reports whether the handler has been closed, i.e. the final response
// was sent or the handler was finalized.
func (handler *RpcResponseHandler) IsDone() bool {
	finished := handler.done.Load()
	return finished
}
// SendComplexRequest sends command with data and returns a handle for reading
// the responses.
//
// opts may be nil. A Timeout <= 0 is replaced by DefaultTimeoutMs. Unless
// opts.NoResponse is set, a fresh request id is generated so responses can be
// matched to this request. The returned handler owns a timeout context; the
// caller is responsible for releasing it (the handler's finalize or
// SendCancel do so).
//
// An error is returned when command is empty, when the request cannot be
// marshalled, or when a panic occurs while sending (for example writing to a
// closed OutputCh); in that last case the returned handler is nil.
func (w *WshRpc) SendComplexRequest(command string, data any, opts *wshrpc.RpcOpts) (rtnHandler *RpcRequestHandler, rtnErr error) {
	if opts == nil {
		opts = &wshrpc.RpcOpts{}
	}
	timeoutMs := opts.Timeout
	if timeoutMs <= 0 {
		timeoutMs = DefaultTimeoutMs
	}
	defer func() {
		if r := recover(); r != nil {
			log.Printf("panic in SendComplexRequest: %v\n", r)
			rtnErr = fmt.Errorf("panic: %v", r)
		}
	}()
	if command == "" {
		return nil, fmt.Errorf("command cannot be empty")
	}
	handler := &RpcRequestHandler{
		w:           w,
		ctxCancelFn: &atomic.Pointer[context.CancelFunc]{},
	}
	var cancelFn context.CancelFunc
	handler.ctx, cancelFn = context.WithTimeout(context.Background(), time.Duration(timeoutMs)*time.Millisecond)
	handler.ctxCancelFn.Store(&cancelFn)
	if !opts.NoResponse {
		handler.reqId = uuid.New().String()
	}
	req := &RpcMessage{
		Command: command,
		ReqId:   handler.reqId,
		Data:    data,
		Timeout: timeoutMs,
		Route:   opts.Route,
	}
	barr, err := json.Marshal(req)
	if err != nil {
		// The handler is discarded on this path, so release the timeout
		// context now instead of leaving its timer armed until the deadline.
		cancelFn()
		return nil, err
	}
	handler.respCh = w.registerRpc(handler.ctx, handler.reqId)
	w.OutputCh <- barr
	return handler, nil
}