diff --git a/pkg/wshrpc/rpc_client.go b/pkg/wshrpc/rpc_client.go index f0388d1e7..425fae7c6 100644 --- a/pkg/wshrpc/rpc_client.go +++ b/pkg/wshrpc/rpc_client.go @@ -28,10 +28,10 @@ type RpcClient struct { } type RpcInfo struct { - CloseSync *sync.Once - RpcId string - ReqPacketsInFlight map[int64]bool // seqnum -> bool - RespCh chan *RpcPacket + CloseSync *sync.Once + RpcId string + PacketsInFlight map[int64]bool // seqnum -> bool (for clients this is for requests, for servers it is for responses) + PkCh chan *RpcPacket // for clients this is for responses, for servers it is for requests } func MakeRpcClient(sendCh chan *RpcPacket, recvCh chan *RpcPacket) *RpcClient { @@ -88,7 +88,7 @@ func (c *RpcClient) handleResp(pk *RpcPacket) { return } select { - case rpcInfo.RespCh <- pk: + case rpcInfo.PkCh <- pk: default: log.Printf("RpcClient.handleResp() respCh full, dropping packet") } @@ -131,7 +131,7 @@ func (c *RpcClient) waitForReq(ctx context.Context, req *RpcPacket) (*RpcInfo, e continue } if rpcInfo, ok := c.RpcReqs[req.RpcId]; ok { - if len(rpcInfo.ReqPacketsInFlight) >= MaxUnackedPerRpc { + if len(rpcInfo.PacketsInFlight) >= MaxUnackedPerRpc { c.CVar.Wait() continue } @@ -147,12 +147,12 @@ func (c *RpcClient) waitForReq(ctx context.Context, req *RpcPacket) (*RpcInfo, e rpcInfo := c.RpcReqs[req.RpcId] if rpcInfo == nil { rpcInfo = &RpcInfo{ - CloseSync: &sync.Once{}, - RpcId: req.RpcId, - ReqPacketsInFlight: make(map[int64]bool), - RespCh: make(chan *RpcPacket, MaxUnackedPerRpc), + CloseSync: &sync.Once{}, + RpcId: req.RpcId, + PacketsInFlight: make(map[int64]bool), + PkCh: make(chan *RpcPacket, MaxUnackedPerRpc), } - rpcInfo.ReqPacketsInFlight[req.SeqNum] = true + rpcInfo.PacketsInFlight[req.SeqNum] = true c.RpcReqs[req.RpcId] = rpcInfo } return rpcInfo, nil @@ -171,7 +171,7 @@ func (c *RpcClient) handleAcks(acks []int64) { } rpcInfo := c.RpcReqs[rpcId] if rpcInfo != nil { - delete(rpcInfo.ReqPacketsInFlight, ack) + delete(rpcInfo.PacketsInFlight, ack) } 
delete(c.ReqPacketsInFlight, ack) } @@ -188,12 +188,12 @@ func (c *RpcClient) removeReqInfo(rpcId string, clearSend bool) { // unblock the recv loop if it happens to be waiting // because the delete has already happens, it will not be able to send again on the channel select { - case <-rpcInfo.RespCh: + case <-rpcInfo.PkCh: default: } } rpcInfo.CloseSync.Do(func() { - close(rpcInfo.RespCh) + close(rpcInfo.PkCh) }) } } @@ -225,7 +225,7 @@ func (c *RpcClient) SimpleReq(ctx context.Context, command string, data any) (an select { case <-ctx.Done(): return nil, ctx.Err() - case rtnPacket = <-rpcInfo.RespCh: + case rtnPacket = <-rpcInfo.PkCh: // fallthrough } if rtnPacket.Error != "" { @@ -256,7 +256,7 @@ func (c *RpcClient) StreamReq(ctx context.Context, command string, data any, res if err != nil { return nil, err } - return rpcInfo.RespCh, nil + return rpcInfo.PkCh, nil } func (c *RpcClient) EndStreamReq(rpcId string) { diff --git a/pkg/wshrpc/rpc_server.go b/pkg/wshrpc/rpc_server.go new file mode 100644 index 000000000..8ba3a1bb0 --- /dev/null +++ b/pkg/wshrpc/rpc_server.go @@ -0,0 +1,247 @@ +// Copyright 2024, Command Line Inc. 
+// SPDX-License-Identifier: Apache-2.0
+
+package wshprc
+
+import (
+	"context"
+	"fmt"
+	"log"
+	"runtime/debug"
+	"sync"
+	"sync/atomic"
+	"time"
+)
+
+// SimpleCommandHandlerFn handles a request that is answered with exactly one
+// response packet built from the returned data (or error).
+type SimpleCommandHandlerFn func(context.Context, *RpcServer, string, any) (any, error)
+
+// StreamCommandHandlerFn handles a streaming request. The server only sends a
+// packet on the handler's behalf when it returns an error or panics.
+type StreamCommandHandlerFn func(context.Context, *RpcServer, string, any) error
+
+// RpcServer reads request packets from RecvCh, dispatches them to the
+// registered command handlers, and writes response packets to SendCh.
+// The maps and AckList are guarded by CVar.L.
+type RpcServer struct {
+	CVar                  *sync.Cond
+	NextSeqNum            *atomic.Int64
+	RespPacketsInFlight   map[int64]string // seqnum -> rpcId
+	AckList               []int64          // request seqnums waiting to be attached to the next outgoing packet
+	RpcReqs               map[string]*RpcInfo
+	SendCh                chan *RpcPacket
+	RecvCh                chan *RpcPacket
+	SimpleCommandHandlers map[string]SimpleCommandHandlerFn
+	StreamCommandHandlers map[string]StreamCommandHandlerFn
+}
+
+// MakeRpcServer creates a server over the given channels and starts its
+// receive loop in a new goroutine. It panics if sendCh is buffered for fewer
+// than MaxInFlightPackets packets.
+func MakeRpcServer(sendCh chan *RpcPacket, recvCh chan *RpcPacket) *RpcServer {
+	if cap(sendCh) < MaxInFlightPackets {
+		panic(fmt.Errorf("sendCh buffer size must be at least MaxInFlightPackets(%d)", MaxInFlightPackets))
+	}
+	// the handler maps must be non-nil: the Register*CommandHandler methods write into them
+	rtn := &RpcServer{
+		CVar:                  sync.NewCond(&sync.Mutex{}),
+		NextSeqNum:            &atomic.Int64{},
+		RespPacketsInFlight:   make(map[int64]string),
+		AckList:               nil,
+		RpcReqs:               make(map[string]*RpcInfo),
+		SendCh:                sendCh,
+		RecvCh:                recvCh,
+		SimpleCommandHandlers: make(map[string]SimpleCommandHandlerFn),
+		StreamCommandHandlers: make(map[string]StreamCommandHandlerFn),
+	}
+	go rtn.runRecvLoop()
+	return rtn
+}
+
+// RegisterSimpleCommandHandler installs handler for single-response requests
+// whose Command equals command, replacing any previous registration.
+func (s *RpcServer) RegisterSimpleCommandHandler(command string, handler SimpleCommandHandlerFn) {
+	s.CVar.L.Lock()
+	defer s.CVar.L.Unlock()
+	s.SimpleCommandHandlers[command] = handler
+}
+
+// RegisterStreamCommandHandler installs handler for streaming requests whose
+// Command equals command, replacing any previous registration.
+func (s *RpcServer) RegisterStreamCommandHandler(command string, handler StreamCommandHandlerFn) {
+	s.CVar.L.Lock()
+	defer s.CVar.L.Unlock()
+	s.StreamCommandHandlers[command] = handler
+}
+
+// runRecvLoop processes packets from RecvCh until the channel is closed.
+// Acks carried by every packet are applied first; a request packet with
+// ReqDone set is dispatched as a simple request, otherwise as a stream request.
+func (s *RpcServer) runRecvLoop() {
+	defer func() {
+		if r := recover(); r != nil {
+			log.Printf("RpcServer.runRecvLoop() panic: %v", r)
+			debug.PrintStack()
+		}
+	}()
+	for pk := range s.RecvCh {
+		s.handleAcks(pk.Acks)
+		if pk.RpcType == RpcType_Req {
+			if pk.ReqDone {
+				s.handleSimpleReq(pk)
+			} else {
+				s.handleStreamReq(pk)
+			}
+			continue
+		}
+		log.Printf("RpcServer.runRecvLoop() bad packet type: %v", pk)
+	}
+	log.Printf("RpcServer.runRecvLoop() normal exit")
+}
+
+// ackResp queues seqNum so it is acknowledged on the next outgoing packet.
+// A seqNum of 0 is ignored.
+func (s *RpcServer) ackResp(seqNum int64) {
+	if seqNum == 0 {
+		return
+	}
+	s.CVar.L.Lock()
+	defer s.CVar.L.Unlock()
+	s.AckList = append(s.AckList, seqNum)
+}
+
+// makeContextFromTimeout returns a context that expires at timeout.Deadline
+// (unix milliseconds), or a background context with a no-op cancel function
+// when timeout is nil.
+func makeContextFromTimeout(timeout *TimeoutInfo) (context.Context, context.CancelFunc) {
+	if timeout == nil {
+		return context.Background(), func() {}
+	}
+	return context.WithDeadline(context.Background(), time.UnixMilli(timeout.Deadline))
+}
+
+// SendResponse sends a response packet, blocking while the in-flight limits
+// are exceeded. See waitForSend for the exact conditions.
+func (s *RpcServer) SendResponse(ctx context.Context, pk *RpcPacket) error {
+	return s.waitForSend(ctx, pk)
+}
+
+// waitForSend blocks until pk fits under both the global limit
+// (MaxInFlightPackets) and the per-rpc limit (MaxUnackedPerRpc) of unacked
+// response packets, then records it as in flight, attaches the pending acks,
+// and writes it to SendCh. It returns ctx.Err() if ctx is done when the
+// limits are checked.
+//
+// NOTE(review): ctx expiring does not by itself wake CVar.Wait(); a blocked
+// sender only re-checks ctx after handleAcks broadcasts.
+func (s *RpcServer) waitForSend(ctx context.Context, pk *RpcPacket) error {
+	s.CVar.L.Lock()
+	defer s.CVar.L.Unlock()
+	for {
+		if ctx.Err() != nil {
+			return ctx.Err()
+		}
+		if len(s.RespPacketsInFlight) >= MaxInFlightPackets {
+			s.CVar.Wait()
+			continue
+		}
+		rpcInfo := s.RpcReqs[pk.RpcId]
+		if rpcInfo != nil && len(rpcInfo.PacketsInFlight) >= MaxUnackedPerRpc {
+			s.CVar.Wait()
+			continue
+		}
+		break
+	}
+	s.RespPacketsInFlight[pk.SeqNum] = pk.RpcId
+	pk.Acks = s.grabAcks_nolock()
+	s.SendCh <- pk
+	rpcInfo := s.RpcReqs[pk.RpcId]
+	if !pk.RespDone && rpcInfo == nil {
+		// first packet of a multi-packet response: start tracking unacked packets for this rpc
+		rpcInfo = &RpcInfo{
+			CloseSync:       &sync.Once{},
+			RpcId:           pk.RpcId,
+			PkCh:            make(chan *RpcPacket, MaxUnackedPerRpc),
+			PacketsInFlight: make(map[int64]bool),
+		}
+		s.RpcReqs[pk.RpcId] = rpcInfo
+	}
+	if rpcInfo != nil {
+		rpcInfo.PacketsInFlight[pk.SeqNum] = true
+	}
+	if pk.RespDone {
+		// the final packet ends per-rpc tracking; its own ack is still cleared via RespPacketsInFlight
+		delete(s.RpcReqs, pk.RpcId)
+	}
+	return nil
+}
+
+// handleAcks removes the acknowledged seqnums from the global and per-rpc
+// in-flight tables and wakes every sender blocked in waitForSend.
+func (s *RpcServer) handleAcks(acks []int64) {
+	if len(acks) == 0 {
+		return
+	}
+	s.CVar.L.Lock()
+	defer s.CVar.L.Unlock()
+	for _, ack := range acks {
+		rpcId, ok := s.RespPacketsInFlight[ack]
+		if !ok {
+			continue
+		}
+		rpcInfo := s.RpcReqs[rpcId]
+		if rpcInfo != nil {
+			delete(rpcInfo.PacketsInFlight, ack)
+		}
+		delete(s.RespPacketsInFlight, ack)
+	}
+	s.CVar.Broadcast()
+}
+
+// handleSimpleReq acks pk and runs its registered simple handler in a new
+// goroutine, sending the handler's data or error back as one RespDone packet.
+// Requests for unregistered commands are logged and dropped.
+func (s *RpcServer) handleSimpleReq(pk *RpcPacket) {
+	s.ackResp(pk.SeqNum)
+	// registration writes the handler map under CVar.L, so the lookup must hold it too
+	s.CVar.L.Lock()
+	handler, ok := s.SimpleCommandHandlers[pk.Command]
+	s.CVar.L.Unlock()
+	if !ok {
+		log.Printf("RpcServer.handleReq() unknown command: %s", pk.Command)
+		return
+	}
+	go func() {
+		defer func() {
+			if r := recover(); r != nil {
+				log.Printf("RpcServer.handleReq(%q) panic: %v", pk.Command, r)
+				debug.PrintStack()
+			}
+		}()
+		ctx, cancelFn := makeContextFromTimeout(pk.Timeout)
+		defer cancelFn()
+		data, err := handler(ctx, s, pk.Command, pk.Data)
+		respPk := &RpcPacket{
+			Command:  pk.Command,
+			RpcId:    pk.RpcId,
+			RpcType:  RpcType_Resp,
+			SeqNum:   s.NextSeqNum.Add(1),
+			RespDone: true,
+		}
+		if err != nil {
+			respPk.Error = err.Error()
+		} else {
+			respPk.Data = data
+		}
+		if sendErr := s.waitForSend(ctx, respPk); sendErr != nil {
+			log.Printf("RpcServer.handleReq(%q) error sending response: %v", pk.Command, sendErr)
+		}
+	}()
+}
+
+// grabAcks_nolock returns the queued acks and resets the queue.
+// The caller must hold CVar.L.
+func (s *RpcServer) grabAcks_nolock() []int64 {
+	acks := s.AckList
+	s.AckList = nil
+	return acks
+}
+
+// handleStreamReq acks pk and runs its registered stream handler in a new
+// goroutine. If the handler returns an error or panics, a RespDone packet
+// carrying the error text is sent. Requests for unregistered commands are
+// logged and dropped.
+func (s *RpcServer) handleStreamReq(pk *RpcPacket) {
+	s.ackResp(pk.SeqNum)
+	// registration writes the handler map under CVar.L, so the lookup must hold it too
+	s.CVar.L.Lock()
+	handler, ok := s.StreamCommandHandlers[pk.Command]
+	s.CVar.L.Unlock()
+	if !ok {
+		// pk was already acked above; queuing its seqnum again would send a duplicate ack
+		log.Printf("RpcServer.handleStreamReq() unknown command: %s", pk.Command)
+		return
+	}
+	go func() {
+		defer func() {
+			r := recover()
+			if r == nil {
+				return
+			}
+			log.Printf("RpcServer.handleStreamReq(%q) panic: %v", pk.Command, r)
+			debug.PrintStack()
+			respPk := &RpcPacket{
+				Command:  pk.Command,
+				RpcId:    pk.RpcId,
+				RpcType:  RpcType_Resp,
+				SeqNum:   s.NextSeqNum.Add(1),
+				RespDone: true,
+				Error:    fmt.Sprintf("panic: %v", r),
+			}
+			// the request context may already be expired, so report the panic on a fresh one
+			s.waitForSend(context.Background(), respPk)
+		}()
+		ctx, cancelFn := makeContextFromTimeout(pk.Timeout)
+		defer cancelFn()
+		err := handler(ctx, s, pk.Command, pk.Data)
+		if err != nil {
+			respPk := &RpcPacket{
+				Command:  pk.Command,
+				RpcId:    pk.RpcId,
+				RpcType:  RpcType_Resp,
+				SeqNum:   s.NextSeqNum.Add(1),
+				RespDone: true,
+				Error:    err.Error(),
+			}
+			if sendErr := s.waitForSend(ctx, respPk); sendErr != nil {
+				log.Printf("RpcServer.handleStreamReq(%q) error sending response: %v", pk.Command, sendErr)
+			}
+			return
+		}
+		// TODO: check if RespDone has been set, if not, send it here
+	}()
+}