waveterm/pkg/wshrpc/rpc_client.go
Evan Simkowitz f12e246c15
Break layout node into its own Wave Object (#21)
I am updating the layout node setup to write to its own wave object. 

The existing setup requires me to plumb the layout updates through every
time the tab gets updated, which produces a lot of annoying and
unintuitive design patterns. With this new setup, the tab object doesn't
get written to when the layout changes, only the layout object will get
written to. This prevents collisions when both the tab object and the
layout node object are getting updated, such as when a new block is
added or deleted.
2024-06-05 17:21:40 -07:00


// Copyright 2024, Command Line Inc.
// SPDX-License-Identifier: Apache-2.0

package wshrpc

import (
	"context"
	"errors"
	"fmt"
	"log"
	"runtime/debug"
	"sync"
	"sync/atomic"
	"time"

	"github.com/google/uuid"
)

// RpcClient manages outgoing RPC requests over a send/recv packet channel pair.
// There is a single goroutine (runRecvLoop) that reads from RecvCh.
type RpcClient struct {
	CVar               *sync.Cond
	NextSeqNum         *atomic.Int64
	ReqPacketsInFlight map[int64]string // seqnum -> rpcId
	AckList            []int64
	RpcReqs            map[string]*RpcInfo
	SendCh             chan *RpcPacket
	RecvCh             chan *RpcPacket
}
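
// RpcInfo tracks the state of a single open RPC: its id, the packets that
// have not yet been acked, and the channel used to deliver its packets.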
type RpcInfo struct {
	CloseSync       *sync.Once
	RpcId           string
	PacketsInFlight map[int64]bool  // seqnum -> bool (for clients this is for requests, for servers it is for responses)
	PkCh            chan *RpcPacket // for clients this is for responses, for servers it is for requests
}
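
// MakeRpcClient creates an RpcClient over the given send/recv channels and
// starts the goroutine that processes incoming packets. sendCh must be
// buffered with at least MaxInFlightPackets slots.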
func MakeRpcClient(sendCh chan *RpcPacket, recvCh chan *RpcPacket) *RpcClient {
	if cap(sendCh) < MaxInFlightPackets {
		panic(fmt.Errorf("sendCh buffer size must be at least MaxInFlightPackets(%d)", MaxInFlightPackets))
	}
	rtn := &RpcClient{
		CVar:               sync.NewCond(&sync.Mutex{}),
		NextSeqNum:         &atomic.Int64{},
		ReqPacketsInFlight: make(map[int64]string),
		AckList:            nil,
		RpcReqs:            make(map[string]*RpcInfo),
		SendCh:             sendCh,
		RecvCh:             recvCh,
	}
	go rtn.runRecvLoop()
	return rtn
}
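
// runRecvLoop processes packets from RecvCh until the channel is closed,
// dispatching response packets to the RPCs waiting on them.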
func (c *RpcClient) runRecvLoop() {
	defer func() {
		if r := recover(); r != nil {
			log.Printf("RpcClient.runRecvLoop() panic: %v", r)
			debug.PrintStack()
		}
	}()
	for pk := range c.RecvCh {
		if pk.RpcType == RpcType_Resp {
			c.handleResp(pk)
			continue
		}
		log.Printf("RpcClient.runRecvLoop() bad packet type: %v", pk)
	}
	log.Printf("RpcClient.runRecvLoop() normal exit")
}
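
// getRpcInfo returns the RpcInfo registered for rpcId (or nil), holding the client lock.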
func (c *RpcClient) getRpcInfo(rpcId string) *RpcInfo {
	c.CVar.L.Lock()
	defer c.CVar.L.Unlock()
	return c.RpcReqs[rpcId]
}
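
// handleResp applies any acks piggybacked on the packet and routes the
// response to the RPC waiting for it; responses with a missing or unknown
// rpcId are acked and dropped.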
func (c *RpcClient) handleResp(pk *RpcPacket) {
	c.handleAcks(pk.Acks)
	if pk.RpcId == "" {
		c.ackResp(pk.SeqNum)
		log.Printf("RpcClient.handleResp() missing rpcId: %v", pk)
		return
	}
	rpcInfo := c.getRpcInfo(pk.RpcId)
	if rpcInfo == nil {
		c.ackResp(pk.SeqNum)
		log.Printf("RpcClient.handleResp() unknown rpcId: %v", pk)
		return
	}
	select {
	case rpcInfo.PkCh <- pk:
	default:
		log.Printf("RpcClient.handleResp() respCh full, dropping packet")
	}
	if pk.RespDone {
		c.removeReqInfo(pk.RpcId, false)
	}
}
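
// grabAcks returns the pending ack list and clears it, so the acks can be
// attached to the next outgoing request.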
func (c *RpcClient) grabAcks() []int64 {
	c.CVar.L.Lock()
	defer c.CVar.L.Unlock()
	acks := c.AckList
	c.AckList = nil
	return acks
}
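
// ackResp queues seqNum to be acked on a future outgoing request.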
func (c *RpcClient) ackResp(seqNum int64) {
	if seqNum == 0 {
		return
	}
	c.CVar.L.Lock()
	defer c.CVar.L.Unlock()
	c.AckList = append(c.AckList, seqNum)
}
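
// waitForReq blocks until flow-control limits (open RPCs, in-flight packets,
// unacked packets per RPC) allow the request to be sent, then queues it on
// SendCh and registers the request against its RpcInfo.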
func (c *RpcClient) waitForReq(ctx context.Context, req *RpcPacket) (*RpcInfo, error) {
	c.CVar.L.Lock()
	defer c.CVar.L.Unlock()
	// issue with ctx timeout sync -- we need the cvar to be signaled fairly regularly so we can check ctx.Err()
	for {
		if ctx.Err() != nil {
			return nil, ctx.Err()
		}
		if len(c.RpcReqs) >= MaxOpenRpcs {
			c.CVar.Wait()
			continue
		}
		if len(c.ReqPacketsInFlight) >= MaxOpenRpcs {
			c.CVar.Wait()
			continue
		}
		if rpcInfo, ok := c.RpcReqs[req.RpcId]; ok {
			if len(rpcInfo.PacketsInFlight) >= MaxUnackedPerRpc {
				c.CVar.Wait()
				continue
			}
		}
		break
	}
	select {
	case c.SendCh <- req:
	default:
		return nil, errors.New("SendCh Full")
	}
	c.ReqPacketsInFlight[req.SeqNum] = req.RpcId
	rpcInfo := c.RpcReqs[req.RpcId]
	if rpcInfo == nil {
		rpcInfo = &RpcInfo{
			CloseSync:       &sync.Once{},
			RpcId:           req.RpcId,
			PacketsInFlight: make(map[int64]bool),
			PkCh:            make(chan *RpcPacket, MaxUnackedPerRpc),
		}
		c.RpcReqs[req.RpcId] = rpcInfo
	}
	// record the packet in the per-RPC in-flight map for new and existing RPCs
	// so the MaxUnackedPerRpc check above sees it
	rpcInfo.PacketsInFlight[req.SeqNum] = true
	return rpcInfo, nil
}
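
// handleAcks removes acked seqnums from the in-flight maps and wakes any
// goroutines blocked in waitForReq.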
func (c *RpcClient) handleAcks(acks []int64) {
	if len(acks) == 0 {
		return
	}
	c.CVar.L.Lock()
	defer c.CVar.L.Unlock()
	for _, ack := range acks {
		rpcId, ok := c.ReqPacketsInFlight[ack]
		if !ok {
			continue
		}
		rpcInfo := c.RpcReqs[rpcId]
		if rpcInfo != nil {
			delete(rpcInfo.PacketsInFlight, ack)
		}
		delete(c.ReqPacketsInFlight, ack)
	}
	c.CVar.Broadcast()
}
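
// removeReqInfo unregisters an RPC and closes its packet channel. When
// clearSend is set, one pending packet is drained from the channel first.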
func (c *RpcClient) removeReqInfo(rpcId string, clearSend bool) {
	c.CVar.L.Lock()
	defer c.CVar.L.Unlock()
	rpcInfo := c.RpcReqs[rpcId]
	delete(c.RpcReqs, rpcId)
	if rpcInfo != nil {
		if clearSend {
			// unblock the recv loop if it happens to be waiting;
			// because the delete has already happened, it will not be able to send again on the channel
			select {
			case <-rpcInfo.PkCh:
			default:
			}
		}
		rpcInfo.CloseSync.Do(func() {
			close(rpcInfo.PkCh)
		})
	}
}
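
// SimpleReq sends a single request packet and waits for a single response
// (or for ctx to be canceled), returning the response data.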
func (c *RpcClient) SimpleReq(ctx context.Context, command string, data any) (any, error) {
	rpcId := uuid.NewString()
	seqNum := c.NextSeqNum.Add(1)
	var timeoutInfo *TimeoutInfo
	deadline, ok := ctx.Deadline()
	if ok {
		timeoutInfo = &TimeoutInfo{Deadline: deadline.UnixMilli()}
	}
	req := &RpcPacket{
		Command: command,
		RpcId:   rpcId,
		RpcType: RpcType_Req,
		SeqNum:  seqNum,
		ReqDone: true,
		Acks:    c.grabAcks(),
		Timeout: timeoutInfo,
		Data:    data,
	}
	rpcInfo, err := c.waitForReq(ctx, req)
	if err != nil {
		return nil, err
	}
	defer c.removeReqInfo(rpcId, true)
	var rtnPacket *RpcPacket
	select {
	case <-ctx.Done():
		return nil, ctx.Err()
	case rtnPacket = <-rpcInfo.PkCh:
		// fallthrough
	}
	if rtnPacket.Error != "" {
		return nil, errors.New(rtnPacket.Error)
	}
	return rtnPacket.Data, nil
}
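
// StreamReq sends a request whose responses are streamed back on the returned
// channel. respTimeout is passed to the peer as the per-response-packet
// timeout. The caller should call EndStreamReq when done with the stream.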
func (c *RpcClient) StreamReq(ctx context.Context, command string, data any, respTimeout time.Duration) (chan *RpcPacket, error) {
	rpcId := uuid.NewString()
	seqNum := c.NextSeqNum.Add(1)
	timeoutInfo := &TimeoutInfo{RespPacketTimeout: respTimeout.Milliseconds()}
	deadline, ok := ctx.Deadline()
	if ok {
		timeoutInfo.Deadline = deadline.UnixMilli()
	}
	req := &RpcPacket{
		Command: command,
		RpcId:   rpcId,
		RpcType: RpcType_Req,
		SeqNum:  seqNum,
		ReqDone: true,
		Acks:    c.grabAcks(),
		Timeout: timeoutInfo,
		Data:    data,
	}
	rpcInfo, err := c.waitForReq(ctx, req)
	if err != nil {
		return nil, err
	}
	return rpcInfo.PkCh, nil
}
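
// EndStreamReq cleans up the client-side state for a streaming request
// started with StreamReq.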
func (c *RpcClient) EndStreamReq(rpcId string) {
	c.removeReqInfo(rpcId, true)
}