// Copyright 2024, Command Line Inc. // SPDX-License-Identifier: Apache-2.0 package wshprc import ( "context" "errors" "fmt" "log" "runtime/debug" "sync" "sync/atomic" "time" "github.com/google/uuid" ) // there is a single go-routine that reads from RecvCh type RpcClient struct { CVar *sync.Cond NextSeqNum *atomic.Int64 ReqPacketsInFlight map[int64]string // seqnum -> rpcId AckList []int64 RpcReqs map[string]*RpcInfo SendCh chan *RpcPacket RecvCh chan *RpcPacket } type RpcInfo struct { CloseSync *sync.Once RpcId string PacketsInFlight map[int64]bool // seqnum -> bool (for clients this is for requests, for servers it is for responses) PkCh chan *RpcPacket // for clients this is for responses, for servers it is for requests } func MakeRpcClient(sendCh chan *RpcPacket, recvCh chan *RpcPacket) *RpcClient { if cap(sendCh) < MaxInFlightPackets { panic(fmt.Errorf("sendCh buffer size must be at least MaxInFlightPackets(%d)", MaxInFlightPackets)) } rtn := &RpcClient{ CVar: sync.NewCond(&sync.Mutex{}), NextSeqNum: &atomic.Int64{}, ReqPacketsInFlight: make(map[int64]string), AckList: nil, RpcReqs: make(map[string]*RpcInfo), SendCh: sendCh, RecvCh: recvCh, } go rtn.runRecvLoop() return rtn } func (c *RpcClient) runRecvLoop() { defer func() { if r := recover(); r != nil { log.Printf("RpcClient.runRecvLoop() panic: %v", r) debug.PrintStack() } }() for pk := range c.RecvCh { if pk.RpcType == RpcType_Resp { c.handleResp(pk) continue } log.Printf("RpcClient.runRecvLoop() bad packet type: %v", pk) } log.Printf("RpcClient.runRecvLoop() normal exit") } func (c *RpcClient) getRpcInfo(rpcId string) *RpcInfo { c.CVar.L.Lock() defer c.CVar.L.Unlock() return c.RpcReqs[rpcId] } func (c *RpcClient) handleResp(pk *RpcPacket) { c.handleAcks(pk.Acks) if pk.RpcId == "" { c.ackResp(pk.SeqNum) log.Printf("RpcClient.handleResp() missing rpcId: %v", pk) return } rpcInfo := c.getRpcInfo(pk.RpcId) if rpcInfo == nil { c.ackResp(pk.SeqNum) log.Printf("RpcClient.handleResp() unknown rpcId: %v", pk) return } select { case rpcInfo.PkCh <- pk: default: log.Printf("RpcClient.handleResp() respCh full, dropping packet") } if pk.RespDone { c.removeReqInfo(pk.RpcId, false) } } func (c *RpcClient) grabAcks() []int64 { c.CVar.L.Lock() defer c.CVar.L.Unlock() acks := c.AckList c.AckList = nil return acks } func (c *RpcClient) ackResp(seqNum int64) { if seqNum == 0 { return } c.CVar.L.Lock() defer c.CVar.L.Unlock() c.AckList = append(c.AckList, seqNum) } func (c *RpcClient) waitForReq(ctx context.Context, req *RpcPacket) (*RpcInfo, error) { c.CVar.L.Lock() defer c.CVar.L.Unlock() // issue with ctx timeout sync -- we need the cvar to be signaled fairly regularly so we can check ctx.Err() for { if ctx.Err() != nil { return nil, ctx.Err() } if len(c.RpcReqs) >= MaxOpenRpcs { c.CVar.Wait() continue } if len(c.ReqPacketsInFlight) >= MaxOpenRpcs { c.CVar.Wait() continue } if rpcInfo, ok := c.RpcReqs[req.RpcId]; ok { if len(rpcInfo.PacketsInFlight) >= MaxUnackedPerRpc { c.CVar.Wait() continue } } break } select { case c.SendCh <- req: default: return nil, errors.New("SendCh Full") } c.ReqPacketsInFlight[req.SeqNum] = req.RpcId rpcInfo := c.RpcReqs[req.RpcId] if rpcInfo == nil { rpcInfo = &RpcInfo{ CloseSync: &sync.Once{}, RpcId: req.RpcId, PacketsInFlight: make(map[int64]bool), PkCh: make(chan *RpcPacket, MaxUnackedPerRpc), } rpcInfo.PacketsInFlight[req.SeqNum] = true c.RpcReqs[req.RpcId] = rpcInfo } return rpcInfo, nil } func (c *RpcClient) handleAcks(acks []int64) { if len(acks) == 0 { return } c.CVar.L.Lock() defer c.CVar.L.Unlock() for _, ack := range acks { rpcId, ok := c.ReqPacketsInFlight[ack] if !ok { continue } rpcInfo := c.RpcReqs[rpcId] if rpcInfo != nil { delete(rpcInfo.PacketsInFlight, ack) } delete(c.ReqPacketsInFlight, ack) } c.CVar.Broadcast() } func (c *RpcClient) removeReqInfo(rpcId string, clearSend bool) { c.CVar.L.Lock() defer c.CVar.L.Unlock() rpcInfo := c.RpcReqs[rpcId] delete(c.RpcReqs, rpcId) if rpcInfo != nil { if clearSend { // unblock the recv loop if it happens to be waiting // because the delete has already happens, it will not be able to send again on the channel select { case <-rpcInfo.PkCh: default: } } rpcInfo.CloseSync.Do(func() { close(rpcInfo.PkCh) }) } } func (c *RpcClient) SimpleReq(ctx context.Context, command string, data any) (any, error) { rpcId := uuid.NewString() seqNum := c.NextSeqNum.Add(1) var timeoutInfo *TimeoutInfo deadline, ok := ctx.Deadline() if ok { timeoutInfo = &TimeoutInfo{Deadline: deadline.UnixMilli()} } req := &RpcPacket{ Command: command, RpcId: rpcId, RpcType: RpcType_Req, SeqNum: seqNum, ReqDone: true, Acks: c.grabAcks(), Timeout: timeoutInfo, Data: data, } rpcInfo, err := c.waitForReq(ctx, req) if err != nil { return nil, err } defer c.removeReqInfo(rpcId, true) var rtnPacket *RpcPacket select { case <-ctx.Done(): return nil, ctx.Err() case rtnPacket = <-rpcInfo.PkCh: // fallthrough } if rtnPacket.Error != "" { return nil, errors.New(rtnPacket.Error) } return rtnPacket.Data, nil } func (c *RpcClient) StreamReq(ctx context.Context, command string, data any, respTimeout time.Duration) (chan *RpcPacket, error) { rpcId := uuid.NewString() seqNum := c.NextSeqNum.Add(1) var timeoutInfo *TimeoutInfo = &TimeoutInfo{RespPacketTimeout: respTimeout.Milliseconds()} deadline, ok := ctx.Deadline() if ok { timeoutInfo.Deadline = deadline.UnixMilli() } req := &RpcPacket{ Command: command, RpcId: rpcId, RpcType: RpcType_Req, SeqNum: seqNum, ReqDone: true, Acks: c.grabAcks(), Timeout: timeoutInfo, Data: data, } rpcInfo, err := c.waitForReq(ctx, req) if err != nil { return nil, err } return rpcInfo.PkCh, nil } func (c *RpcClient) EndStreamReq(rpcId string) { c.removeReqInfo(rpcId, true) }