From dcd6d04b0bfafdb64146a5ba87c843a8cf9bed27 Mon Sep 17 00:00:00 2001 From: sawka Date: Mon, 3 Jun 2024 14:10:36 -0700 Subject: [PATCH] rpc checkpoint --- pkg/wshrpc/rpc_server.go | 84 ++++++++++++++++++++++++++++++++-------- pkg/wshrpc/rpc_test.go | 79 +++++++++++++++++++++++++++++++++++++ 2 files changed, 147 insertions(+), 16 deletions(-) diff --git a/pkg/wshrpc/rpc_server.go b/pkg/wshrpc/rpc_server.go index 8ba3a1bb0..91a4d7aa6 100644 --- a/pkg/wshrpc/rpc_server.go +++ b/pkg/wshrpc/rpc_server.go @@ -14,7 +14,7 @@ import ( ) type SimpleCommandHandlerFn func(context.Context, *RpcServer, string, any) (any, error) -type StreamCommandHandlerFn func(context.Context, *RpcServer, string, any) error +type StreamCommandHandlerFn func(context.Context, *RpcServer, *RpcPacket) error type RpcServer struct { CVar *sync.Cond @@ -33,27 +33,54 @@ func MakeRpcServer(sendCh chan *RpcPacket, recvCh chan *RpcPacket) *RpcServer { panic(fmt.Errorf("sendCh buffer size must be at least MaxInFlightPackets(%d)", MaxInFlightPackets)) } rtn := &RpcServer{ - CVar: sync.NewCond(&sync.Mutex{}), - NextSeqNum: &atomic.Int64{}, - RespPacketsInFlight: make(map[int64]string), - AckList: nil, - RpcReqs: make(map[string]*RpcInfo), - SendCh: sendCh, - RecvCh: recvCh, + CVar: sync.NewCond(&sync.Mutex{}), + NextSeqNum: &atomic.Int64{}, + RespPacketsInFlight: make(map[int64]string), + AckList: nil, + RpcReqs: make(map[string]*RpcInfo), + SendCh: sendCh, + RecvCh: recvCh, + SimpleCommandHandlers: make(map[string]SimpleCommandHandlerFn), + StreamCommandHandlers: make(map[string]StreamCommandHandlerFn), } go rtn.runRecvLoop() return rtn } +func (s *RpcServer) shouldUseStreamHandler(command string) bool { + s.CVar.L.Lock() + defer s.CVar.L.Unlock() + _, ok := s.StreamCommandHandlers[command] + return ok +} + +func (s *RpcServer) getStreamHandler(command string) StreamCommandHandlerFn { + s.CVar.L.Lock() + defer s.CVar.L.Unlock() + return s.StreamCommandHandlers[command] +} + +func (s *RpcServer) getSimpleHandler(command string) SimpleCommandHandlerFn { + s.CVar.L.Lock() + defer s.CVar.L.Unlock() + return s.SimpleCommandHandlers[command] +} + func (s *RpcServer) RegisterSimpleCommandHandler(command string, handler SimpleCommandHandlerFn) { s.CVar.L.Lock() defer s.CVar.L.Unlock() + if s.StreamCommandHandlers[command] != nil { + panic(fmt.Errorf("command %q already registered as a stream handler", command)) + } s.SimpleCommandHandlers[command] = handler } func (s *RpcServer) RegisterStreamCommandHandler(command string, handler StreamCommandHandlerFn) { s.CVar.L.Lock() defer s.CVar.L.Unlock() + if s.SimpleCommandHandlers[command] != nil { + panic(fmt.Errorf("command %q already registered as a simple handler", command)) + } s.StreamCommandHandlers[command] = handler } @@ -67,10 +94,10 @@ func (s *RpcServer) runRecvLoop() { for pk := range s.RecvCh { s.handleAcks(pk.Acks) if pk.RpcType == RpcType_Req { - if pk.ReqDone { - s.handleSimpleReq(pk) - } else { + if s.shouldUseStreamHandler(pk.Command) { s.handleStreamReq(pk) + } else { + s.handleSimpleReq(pk) } continue } @@ -163,8 +190,9 @@ func (s *RpcServer) handleAcks(acks []int64) { func (s *RpcServer) handleSimpleReq(pk *RpcPacket) { s.ackResp(pk.SeqNum) - handler, ok := s.SimpleCommandHandlers[pk.Command] - if !ok { + handler := s.getSimpleHandler(pk.Command) + if handler == nil { + s.sendErrorResp(pk, fmt.Errorf("unknown command: %s", pk.Command)) log.Printf("RpcServer.handleReq() unknown command: %s", pk.Command) return } @@ -201,11 +229,35 @@ func (s *RpcServer) grabAcks_nolock() []int64 { return acks } +func (s *RpcServer) sendErrorResp(pk *RpcPacket, err error) { + respPk := &RpcPacket{ + Command: pk.Command, + RpcId: pk.RpcId, + RpcType: RpcType_Resp, + SeqNum: s.NextSeqNum.Add(1), + RespDone: true, + Error: err.Error(), + } + s.waitForSend(context.Background(), respPk) +} + +func (s *RpcServer) makeRespPk(pk *RpcPacket, data any, done bool) *RpcPacket { + return &RpcPacket{ + Command: pk.Command, + RpcId: pk.RpcId, + RpcType: RpcType_Resp, + SeqNum: s.NextSeqNum.Add(1), + RespDone: done, + Data: data, + } +} + func (s *RpcServer) handleStreamReq(pk *RpcPacket) { s.ackResp(pk.SeqNum) - handler, ok := s.StreamCommandHandlers[pk.Command] - if !ok { + handler := s.getStreamHandler(pk.Command) + if handler == nil { s.ackResp(pk.SeqNum) + s.sendErrorResp(pk, fmt.Errorf("unknown command: %s", pk.Command)) log.Printf("RpcServer.handleStreamReq() unknown command: %s", pk.Command) return } @@ -229,7 +281,7 @@ func (s *RpcServer) handleStreamReq(pk *RpcPacket) { }() ctx, cancelFn := makeContextFromTimeout(pk.Timeout) defer cancelFn() - err := handler(ctx, s, pk.Command, pk.Data) + err := handler(ctx, s, pk) if err != nil { respPk := &RpcPacket{ Command: pk.Command, diff --git a/pkg/wshrpc/rpc_test.go b/pkg/wshrpc/rpc_test.go index 4f8ed783b..cb2f77418 100644 --- a/pkg/wshrpc/rpc_test.go +++ b/pkg/wshrpc/rpc_test.go @@ -5,6 +5,8 @@ package wshprc import ( "context" + "fmt" + "log" "sync" "testing" "time" @@ -113,3 +115,80 @@ func TestStream(t *testing.T) { }() wg.Wait() } + +func TestSimpleClientServer(t *testing.T) { + sendCh := make(chan *RpcPacket, MaxInFlightPackets) + recvCh := make(chan *RpcPacket, MaxInFlightPackets) + client := MakeRpcClient(sendCh, recvCh) + server := MakeRpcServer(recvCh, sendCh) + ctx, cancelFn := context.WithTimeout(context.Background(), 2*time.Second) + defer cancelFn() + server.RegisterSimpleCommandHandler("test", func(ctx context.Context, s *RpcServer, cmd string, data any) (any, error) { + if data != "hello" { + return nil, fmt.Errorf("expected 'hello', got '%s'", data) + } + return "world", nil + }) + var wg sync.WaitGroup + wg.Add(1) + go func() { + defer wg.Done() + resp, err := client.SimpleReq(ctx, "test", "hello") + if err != nil { + t.Errorf("SimpleReq() failed: %v", err) + return + } + if resp != "world" { + t.Errorf("SimpleReq() failed: expected 'world', got '%s'", resp) + } + }() + wg.Wait() + +} + +func TestStreamClientServer(t *testing.T) { + sendCh := make(chan *RpcPacket, MaxInFlightPackets) + recvCh := make(chan *RpcPacket, MaxInFlightPackets) + client := MakeRpcClient(sendCh, recvCh) + server := MakeRpcServer(recvCh, sendCh) + ctx, cancelFn := context.WithTimeout(context.Background(), 2*time.Second) + defer cancelFn() + server.RegisterStreamCommandHandler("test", func(ctx context.Context, s *RpcServer, req *RpcPacket) error { + pk1 := s.makeRespPk(req, "one", false) + pk2 := s.makeRespPk(req, "two", false) + pk3 := s.makeRespPk(req, "three", true) + s.SendResponse(ctx, pk1) + s.SendResponse(ctx, pk2) + s.SendResponse(ctx, pk3) + return nil + }) + var wg sync.WaitGroup + wg.Add(1) + go func() { + defer wg.Done() + respCh, err := client.StreamReq(ctx, "test", "hello", 2*time.Second) + if err != nil { + t.Errorf("StreamReq() failed: %v", err) + return + } + var result []string + for respPk := range respCh { + if respPk.Error != "" { + t.Errorf("StreamReq() failed: %v", respPk.Error) + return + } + log.Printf("got response: %#v", respPk) + result = append(result, respPk.Data.(string)) + } + if len(result) != 3 { + t.Errorf("expected 3 responses, got %d", len(result)) + return + } + if result[0] != "one" || result[1] != "two" || result[2] != "three" { + t.Errorf("expected 'one', 'two', 'three', got %v", result) + return + } + }() + wg.Wait() + +}