waveterm/pkg/wshrpc/rpc_server.go
2024-06-03 14:10:36 -07:00

300 lines
7.5 KiB
Go

// Copyright 2024, Command Line Inc.
// SPDX-License-Identifier: Apache-2.0
package wshprc
import (
"context"
"fmt"
"log"
"runtime/debug"
"sync"
"sync/atomic"
"time"
)
// SimpleCommandHandlerFn is the signature of a request/response handler
// registered with RegisterSimpleCommandHandler. handleSimpleReq invokes it
// with the request context, the server, the command name and the request
// data; the returned value (or the error text) is sent back as a single
// response packet with RespDone set.
type SimpleCommandHandlerFn func(context.Context, *RpcServer, string, any) (any, error)
// StreamCommandHandlerFn is the signature of a handler registered with
// RegisterStreamCommandHandler. handleStreamReq invokes it with the request
// context, the server and the full request packet. A non-nil error is
// reported to the peer as a final error response; on a nil return nothing
// further is sent by handleStreamReq.
type StreamCommandHandlerFn func(context.Context, *RpcServer, *RpcPacket) error
// RpcServer receives request packets on RecvCh, dispatches them to registered
// command handlers, and writes response packets to SendCh while limiting how
// many responses may be outstanding (sent but not yet acked).
type RpcServer struct {
// CVar.L is the lock taken around every access to the maps and AckList
// below; handleAcks broadcasts on CVar to wake senders blocked in waitForSend.
CVar *sync.Cond
// NextSeqNum is incremented (Add(1)) to allocate the SeqNum of each
// response packet built by the server.
NextSeqNum *atomic.Int64
// RespPacketsInFlight gains an entry in waitForSend when a response is
// sent and loses it in handleAcks when the matching ack arrives.
RespPacketsInFlight map[int64]string // seqnum -> rpcId
// AckList holds sequence numbers of received requests (see ackResp); it is
// drained into the Acks field of the next outgoing packet.
AckList []int64
// RpcReqs maps rpcId to per-RPC bookkeeping; waitForSend compares its
// PacketsInFlight size against MaxUnackedPerRpc.
RpcReqs map[string]*RpcInfo
// SendCh carries outgoing packets. MakeRpcServer panics unless its
// capacity is at least MaxInFlightPackets.
SendCh chan *RpcPacket
// RecvCh carries incoming packets; runRecvLoop ranges over it until it is
// closed.
RecvCh chan *RpcPacket
// SimpleCommandHandlers and StreamCommandHandlers map a command name to its
// handler. The Register* methods panic if a command is already present in
// the other map.
SimpleCommandHandlers map[string]SimpleCommandHandlerFn
StreamCommandHandlers map[string]StreamCommandHandlerFn
}
// MakeRpcServer builds an RpcServer that writes responses to sendCh and reads
// requests from recvCh, and starts the receive loop in its own goroutine.
//
// It panics when sendCh is buffered for fewer than MaxInFlightPackets packets.
// The receive loop ends when recvCh is closed.
func MakeRpcServer(sendCh chan *RpcPacket, recvCh chan *RpcPacket) *RpcServer {
	if sendCap := cap(sendCh); sendCap < MaxInFlightPackets {
		panic(fmt.Errorf("sendCh buffer size must be at least MaxInFlightPackets(%d)", MaxInFlightPackets))
	}
	server := new(RpcServer)
	server.CVar = sync.NewCond(new(sync.Mutex))
	server.NextSeqNum = new(atomic.Int64)
	server.RespPacketsInFlight = map[int64]string{}
	// AckList is intentionally left nil; append works on a nil slice.
	server.RpcReqs = map[string]*RpcInfo{}
	server.SendCh = sendCh
	server.RecvCh = recvCh
	server.SimpleCommandHandlers = map[string]SimpleCommandHandlerFn{}
	server.StreamCommandHandlers = map[string]StreamCommandHandlerFn{}
	go server.runRecvLoop()
	return server
}
// shouldUseStreamHandler reports whether command has an entry in the stream
// handler table. The lookup is done under the server lock.
func (s *RpcServer) shouldUseStreamHandler(command string) bool {
	s.CVar.L.Lock()
	_, registered := s.StreamCommandHandlers[command]
	s.CVar.L.Unlock()
	return registered
}
// getStreamHandler returns the stream handler registered for command, or nil
// when there is none. The lookup is done under the server lock.
func (s *RpcServer) getStreamHandler(command string) StreamCommandHandlerFn {
	s.CVar.L.Lock()
	streamFn := s.StreamCommandHandlers[command]
	s.CVar.L.Unlock()
	return streamFn
}
// getSimpleHandler returns the simple handler registered for command, or nil
// when there is none. The lookup is done under the server lock.
func (s *RpcServer) getSimpleHandler(command string) SimpleCommandHandlerFn {
	s.CVar.L.Lock()
	simpleFn := s.SimpleCommandHandlers[command]
	s.CVar.L.Unlock()
	return simpleFn
}
// RegisterSimpleCommandHandler installs handler as the simple
// (single-response) handler for command, replacing any simple handler that was
// previously registered under that name.
//
// It panics if command already has a non-nil stream handler.
func (s *RpcServer) RegisterSimpleCommandHandler(command string, handler SimpleCommandHandlerFn) {
	s.CVar.L.Lock()
	// Deferred so the lock is released even when the conflict check panics.
	defer s.CVar.L.Unlock()
	if conflicting := s.StreamCommandHandlers[command]; conflicting != nil {
		panic(fmt.Errorf("command %q already registered as a stream handler", command))
	}
	s.SimpleCommandHandlers[command] = handler
}
// RegisterStreamCommandHandler installs handler as the stream handler for
// command, replacing any stream handler that was previously registered under
// that name.
//
// It panics if command already has a non-nil simple handler.
func (s *RpcServer) RegisterStreamCommandHandler(command string, handler StreamCommandHandlerFn) {
	s.CVar.L.Lock()
	// Deferred so the lock is released even when the conflict check panics.
	defer s.CVar.L.Unlock()
	if conflicting := s.SimpleCommandHandlers[command]; conflicting != nil {
		panic(fmt.Errorf("command %q already registered as a simple handler", command))
	}
	s.StreamCommandHandlers[command] = handler
}
// runRecvLoop consumes packets from RecvCh until the channel is closed.
//
// For every packet it first applies the piggy-backed acks, then dispatches
// request packets to the stream or simple handler path depending on how the
// command was registered. Packets of any other type are logged and dropped.
// A panic inside the loop is logged (with stack) and terminates the loop.
func (s *RpcServer) runRecvLoop() {
	defer func() {
		if r := recover(); r != nil {
			log.Printf("RpcServer.runRecvLoop() panic: %v", r)
			debug.PrintStack()
		}
	}()
	for pk := range s.RecvCh {
		if pk == nil {
			// Dereferencing a nil packet would panic and permanently stop the
			// receive loop; skip it instead.
			log.Printf("RpcServer.runRecvLoop() ignoring nil packet")
			continue
		}
		s.handleAcks(pk.Acks)
		if pk.RpcType != RpcType_Req {
			log.Printf("RpcServer.runRecvLoop() bad packet type: %v", pk)
			continue
		}
		if s.shouldUseStreamHandler(pk.Command) {
			s.handleStreamReq(pk)
		} else {
			s.handleSimpleReq(pk)
		}
	}
	log.Printf("RpcServer.runRecvLoop() normal exit")
}
// ackResp queues seqNum on AckList so it is attached to the next outgoing
// packet. A zero seqNum is ignored.
func (s *RpcServer) ackResp(seqNum int64) {
	if seqNum != 0 {
		s.CVar.L.Lock()
		s.AckList = append(s.AckList, seqNum)
		s.CVar.L.Unlock()
	}
}
// makeContextFromTimeout derives a handler context from the request's timeout
// info. With a nil timeout it returns the background context and a no-op
// cancel function; otherwise the context expires at timeout.Deadline, which is
// interpreted as Unix milliseconds.
func makeContextFromTimeout(timeout *TimeoutInfo) (context.Context, context.CancelFunc) {
	base := context.Background()
	if timeout != nil {
		deadline := time.UnixMilli(timeout.Deadline)
		return context.WithDeadline(base, deadline)
	}
	noop := func() {}
	return base, noop
}
// SendResponse queues pk on the send channel, blocking while the in-flight
// limits are exceeded. It is the exported entry point to waitForSend and
// returns whatever error waitForSend reports.
func (s *RpcServer) SendResponse(ctx context.Context, pk *RpcPacket) error {
	sendErr := s.waitForSend(ctx, pk)
	return sendErr
}
// waitForSend blocks until pk may be sent without exceeding the flow-control
// limits, then records it as in flight and writes it to SendCh.
//
// Two limits apply: at most MaxInFlightPackets unacked responses overall, and
// at most MaxUnackedPerRpc unacked responses for pk.RpcId. Waiters are woken
// by the broadcast in handleAcks. Any pending acks are attached to pk before
// it is sent.
//
// It returns ctx.Err() if ctx is already done when the limits are
// (re)checked; otherwise nil.
//
// NOTE(review): a waiter parked in CVar.Wait() is only woken by an incoming
// ack, so a context that expires while waiting is noticed on the next
// wake-up, not immediately.
func (s *RpcServer) waitForSend(ctx context.Context, pk *RpcPacket) error {
	s.CVar.L.Lock()
	defer s.CVar.L.Unlock()
	for {
		if err := ctx.Err(); err != nil {
			return err
		}
		if len(s.RespPacketsInFlight) >= MaxInFlightPackets {
			s.CVar.Wait()
			continue
		}
		if info := s.RpcReqs[pk.RpcId]; info != nil && len(info.PacketsInFlight) >= MaxUnackedPerRpc {
			s.CVar.Wait()
			continue
		}
		break
	}
	s.RespPacketsInFlight[pk.SeqNum] = pk.RpcId
	pk.Acks = s.grabAcks_nolock()
	// SendCh has capacity >= MaxInFlightPackets (enforced by MakeRpcServer).
	s.SendCh <- pk
	rpcInfo := s.RpcReqs[pk.RpcId]
	// Create the per-RPC bookkeeping the first time a non-final response is
	// sent. The check must be "== nil": creating it only when it already
	// exists would both skip tracking for new RPCs and wipe the in-flight set
	// of existing ones, disabling the MaxUnackedPerRpc limit.
	if !pk.RespDone && rpcInfo == nil {
		rpcInfo = &RpcInfo{
			CloseSync:       &sync.Once{},
			RpcId:           pk.RpcId,
			PkCh:            make(chan *RpcPacket, MaxUnackedPerRpc),
			PacketsInFlight: make(map[int64]bool),
		}
		s.RpcReqs[pk.RpcId] = rpcInfo
	}
	if rpcInfo != nil {
		rpcInfo.PacketsInFlight[pk.SeqNum] = true
	}
	if pk.RespDone {
		// The final response ends the RPC; drop its bookkeeping.
		delete(s.RpcReqs, pk.RpcId)
	}
	return nil
}
// handleAcks marks the given response sequence numbers as acknowledged.
//
// Each known sequence number is removed from the global in-flight table and,
// when its RPC is still tracked, from that RPC's in-flight set. Unknown
// sequence numbers are skipped. Afterwards every sender blocked in
// waitForSend is woken so it can re-check the limits. An empty ack list is a
// no-op (no lock taken, no broadcast).
func (s *RpcServer) handleAcks(acks []int64) {
	if len(acks) == 0 {
		return
	}
	s.CVar.L.Lock()
	defer s.CVar.L.Unlock()
	for _, seqNum := range acks {
		if rpcId, inFlight := s.RespPacketsInFlight[seqNum]; inFlight {
			if info := s.RpcReqs[rpcId]; info != nil {
				delete(info.PacketsInFlight, seqNum)
			}
			delete(s.RespPacketsInFlight, seqNum)
		}
	}
	s.CVar.Broadcast()
}
// handleSimpleReq acknowledges a request packet and runs its simple handler
// in a new goroutine, sending the handler's result (or error text) back as a
// single response with RespDone set.
//
// If no simple handler is registered for pk.Command an "unknown command"
// error response is sent instead. If the handler panics, the panic is logged
// and reported to the peer as an error response so the caller is not left
// waiting. A response that cannot be sent (for example because the request's
// deadline has already passed) is logged.
func (s *RpcServer) handleSimpleReq(pk *RpcPacket) {
	s.ackResp(pk.SeqNum)
	handler := s.getSimpleHandler(pk.Command)
	if handler == nil {
		// Send from a separate goroutine: this function runs on the receive
		// loop, and a send that blocks on the in-flight limit can only be
		// released by acks that the receive loop itself processes.
		go s.sendErrorResp(pk, fmt.Errorf("unknown command: %s", pk.Command))
		log.Printf("RpcServer.handleReq() unknown command: %s", pk.Command)
		return
	}
	go func() {
		handlerReturned := false
		defer func() {
			r := recover()
			if r == nil {
				return
			}
			log.Printf("RpcServer.handleReq(%q) panic: %v", pk.Command, r)
			debug.PrintStack()
			if !handlerReturned {
				// Only report handler panics; if the panic came from the send
				// path itself, retrying the send would not help.
				s.sendErrorResp(pk, fmt.Errorf("panic: %v", r))
			}
		}()
		ctx, cancelFn := makeContextFromTimeout(pk.Timeout)
		defer cancelFn()
		data, err := handler(ctx, s, pk.Command, pk.Data)
		handlerReturned = true
		seqNum := s.NextSeqNum.Add(1)
		respPk := &RpcPacket{
			Command:  pk.Command,
			RpcId:    pk.RpcId,
			RpcType:  RpcType_Resp,
			SeqNum:   seqNum,
			RespDone: true,
		}
		if err != nil {
			respPk.Error = err.Error()
		} else {
			respPk.Data = data
		}
		if sendErr := s.waitForSend(ctx, respPk); sendErr != nil {
			log.Printf("RpcServer.handleReq(%q) response not sent: %v", pk.Command, sendErr)
		}
	}()
}
// grabAcks_nolock hands back the queued acks and resets AckList to nil.
// It takes no lock itself; the caller must already hold s.CVar.L.
func (s *RpcServer) grabAcks_nolock() []int64 {
	pending := s.AckList
	s.AckList = nil
	return pending
}
// sendErrorResp sends a final (RespDone) response for the request pk whose
// Error field carries err's message. The send uses the background context, so
// it blocks until the in-flight limits allow the packet to go out.
func (s *RpcServer) sendErrorResp(pk *RpcPacket, err error) {
	errResp := new(RpcPacket)
	errResp.Command = pk.Command
	errResp.RpcId = pk.RpcId
	errResp.RpcType = RpcType_Resp
	errResp.SeqNum = s.NextSeqNum.Add(1)
	errResp.RespDone = true
	errResp.Error = err.Error()
	s.waitForSend(context.Background(), errResp)
}
// makeRespPk builds a response packet for the request pk carrying data.
// It allocates a fresh sequence number and copies the request's Command and
// RpcId; done becomes the packet's RespDone flag. The packet is not sent.
func (s *RpcServer) makeRespPk(pk *RpcPacket, data any, done bool) *RpcPacket {
	resp := new(RpcPacket)
	resp.Command = pk.Command
	resp.RpcId = pk.RpcId
	resp.RpcType = RpcType_Resp
	resp.SeqNum = s.NextSeqNum.Add(1)
	resp.RespDone = done
	resp.Data = data
	return resp
}
// handleStreamReq acknowledges a request packet and runs its stream handler
// in a new goroutine. The handler receives the server and the request packet
// and is expected to send its own responses.
//
// If no stream handler is registered for pk.Command an "unknown command"
// error response is sent. If the handler returns an error, or panics, a final
// error response is sent on its behalf. When the handler returns nil nothing
// more is sent here.
func (s *RpcServer) handleStreamReq(pk *RpcPacket) {
	// Ack exactly once; queuing the same sequence number twice would put a
	// duplicate entry in the outgoing ack list.
	s.ackResp(pk.SeqNum)
	handler := s.getStreamHandler(pk.Command)
	if handler == nil {
		// Send from a separate goroutine: this function runs on the receive
		// loop, and a send that blocks on the in-flight limit can only be
		// released by acks that the receive loop itself processes.
		go s.sendErrorResp(pk, fmt.Errorf("unknown command: %s", pk.Command))
		log.Printf("RpcServer.handleStreamReq() unknown command: %s", pk.Command)
		return
	}
	go func() {
		defer func() {
			r := recover()
			if r == nil {
				return
			}
			log.Printf("RpcServer.handleStreamReq(%q) panic: %v", pk.Command, r)
			debug.PrintStack()
			// Report the panic as a final error response (background context).
			s.sendErrorResp(pk, fmt.Errorf("panic: %v", r))
		}()
		ctx, cancelFn := makeContextFromTimeout(pk.Timeout)
		defer cancelFn()
		err := handler(ctx, s, pk)
		if err == nil {
			// TODO: check if RespDone has been set, if not, send it here
			return
		}
		respPk := &RpcPacket{
			Command:  pk.Command,
			RpcId:    pk.RpcId,
			RpcType:  RpcType_Resp,
			SeqNum:   s.NextSeqNum.Add(1),
			RespDone: true,
			Error:    err.Error(),
		}
		if sendErr := s.waitForSend(ctx, respPk); sendErr != nil {
			log.Printf("RpcServer.handleStreamReq(%q) error response not sent: %v", pk.Command, sendErr)
		}
	}()
}