rpc checkpoint

This commit is contained in:
sawka 2024-06-03 14:10:36 -07:00
parent f90554e87e
commit dcd6d04b0b
2 changed files with 147 additions and 16 deletions

View File

@ -14,7 +14,7 @@ import (
) )
type SimpleCommandHandlerFn func(context.Context, *RpcServer, string, any) (any, error) type SimpleCommandHandlerFn func(context.Context, *RpcServer, string, any) (any, error)
type StreamCommandHandlerFn func(context.Context, *RpcServer, string, any) error type StreamCommandHandlerFn func(context.Context, *RpcServer, *RpcPacket) error
type RpcServer struct { type RpcServer struct {
CVar *sync.Cond CVar *sync.Cond
@ -40,20 +40,47 @@ func MakeRpcServer(sendCh chan *RpcPacket, recvCh chan *RpcPacket) *RpcServer {
RpcReqs: make(map[string]*RpcInfo), RpcReqs: make(map[string]*RpcInfo),
SendCh: sendCh, SendCh: sendCh,
RecvCh: recvCh, RecvCh: recvCh,
SimpleCommandHandlers: make(map[string]SimpleCommandHandlerFn),
StreamCommandHandlers: make(map[string]StreamCommandHandlerFn),
} }
go rtn.runRecvLoop() go rtn.runRecvLoop()
return rtn return rtn
} }
func (s *RpcServer) shouldUseStreamHandler(command string) bool {
s.CVar.L.Lock()
defer s.CVar.L.Unlock()
_, ok := s.StreamCommandHandlers[command]
return ok
}
func (s *RpcServer) getStreamHandler(command string) StreamCommandHandlerFn {
s.CVar.L.Lock()
defer s.CVar.L.Unlock()
return s.StreamCommandHandlers[command]
}
func (s *RpcServer) getSimpleHandler(command string) SimpleCommandHandlerFn {
s.CVar.L.Lock()
defer s.CVar.L.Unlock()
return s.SimpleCommandHandlers[command]
}
func (s *RpcServer) RegisterSimpleCommandHandler(command string, handler SimpleCommandHandlerFn) { func (s *RpcServer) RegisterSimpleCommandHandler(command string, handler SimpleCommandHandlerFn) {
s.CVar.L.Lock() s.CVar.L.Lock()
defer s.CVar.L.Unlock() defer s.CVar.L.Unlock()
if s.StreamCommandHandlers[command] != nil {
panic(fmt.Errorf("command %q already registered as a stream handler", command))
}
s.SimpleCommandHandlers[command] = handler s.SimpleCommandHandlers[command] = handler
} }
func (s *RpcServer) RegisterStreamCommandHandler(command string, handler StreamCommandHandlerFn) { func (s *RpcServer) RegisterStreamCommandHandler(command string, handler StreamCommandHandlerFn) {
s.CVar.L.Lock() s.CVar.L.Lock()
defer s.CVar.L.Unlock() defer s.CVar.L.Unlock()
if s.SimpleCommandHandlers[command] != nil {
panic(fmt.Errorf("command %q already registered as a simple handler", command))
}
s.StreamCommandHandlers[command] = handler s.StreamCommandHandlers[command] = handler
} }
@ -67,10 +94,10 @@ func (s *RpcServer) runRecvLoop() {
for pk := range s.RecvCh { for pk := range s.RecvCh {
s.handleAcks(pk.Acks) s.handleAcks(pk.Acks)
if pk.RpcType == RpcType_Req { if pk.RpcType == RpcType_Req {
if pk.ReqDone { if s.shouldUseStreamHandler(pk.Command) {
s.handleSimpleReq(pk)
} else {
s.handleStreamReq(pk) s.handleStreamReq(pk)
} else {
s.handleSimpleReq(pk)
} }
continue continue
} }
@ -163,8 +190,9 @@ func (s *RpcServer) handleAcks(acks []int64) {
func (s *RpcServer) handleSimpleReq(pk *RpcPacket) { func (s *RpcServer) handleSimpleReq(pk *RpcPacket) {
s.ackResp(pk.SeqNum) s.ackResp(pk.SeqNum)
handler, ok := s.SimpleCommandHandlers[pk.Command] handler := s.getSimpleHandler(pk.Command)
if !ok { if handler == nil {
s.sendErrorResp(pk, fmt.Errorf("unknown command: %s", pk.Command))
log.Printf("RpcServer.handleReq() unknown command: %s", pk.Command) log.Printf("RpcServer.handleReq() unknown command: %s", pk.Command)
return return
} }
@ -201,11 +229,35 @@ func (s *RpcServer) grabAcks_nolock() []int64 {
return acks return acks
} }
func (s *RpcServer) sendErrorResp(pk *RpcPacket, err error) {
respPk := &RpcPacket{
Command: pk.Command,
RpcId: pk.RpcId,
RpcType: RpcType_Resp,
SeqNum: s.NextSeqNum.Add(1),
RespDone: true,
Error: err.Error(),
}
s.waitForSend(context.Background(), respPk)
}
func (s *RpcServer) makeRespPk(pk *RpcPacket, data any, done bool) *RpcPacket {
return &RpcPacket{
Command: pk.Command,
RpcId: pk.RpcId,
RpcType: RpcType_Resp,
SeqNum: s.NextSeqNum.Add(1),
RespDone: done,
Data: data,
}
}
func (s *RpcServer) handleStreamReq(pk *RpcPacket) { func (s *RpcServer) handleStreamReq(pk *RpcPacket) {
s.ackResp(pk.SeqNum) s.ackResp(pk.SeqNum)
handler, ok := s.StreamCommandHandlers[pk.Command] handler := s.getStreamHandler(pk.Command)
if !ok { if handler == nil {
s.ackResp(pk.SeqNum) s.ackResp(pk.SeqNum)
s.sendErrorResp(pk, fmt.Errorf("unknown command: %s", pk.Command))
log.Printf("RpcServer.handleStreamReq() unknown command: %s", pk.Command) log.Printf("RpcServer.handleStreamReq() unknown command: %s", pk.Command)
return return
} }
@ -229,7 +281,7 @@ func (s *RpcServer) handleStreamReq(pk *RpcPacket) {
}() }()
ctx, cancelFn := makeContextFromTimeout(pk.Timeout) ctx, cancelFn := makeContextFromTimeout(pk.Timeout)
defer cancelFn() defer cancelFn()
err := handler(ctx, s, pk.Command, pk.Data) err := handler(ctx, s, pk)
if err != nil { if err != nil {
respPk := &RpcPacket{ respPk := &RpcPacket{
Command: pk.Command, Command: pk.Command,

View File

@ -5,6 +5,8 @@ package wshprc
import ( import (
"context" "context"
"fmt"
"log"
"sync" "sync"
"testing" "testing"
"time" "time"
@ -113,3 +115,80 @@ func TestStream(t *testing.T) {
}() }()
wg.Wait() wg.Wait()
} }
func TestSimpleClientServer(t *testing.T) {
sendCh := make(chan *RpcPacket, MaxInFlightPackets)
recvCh := make(chan *RpcPacket, MaxInFlightPackets)
client := MakeRpcClient(sendCh, recvCh)
server := MakeRpcServer(recvCh, sendCh)
ctx, cancelFn := context.WithTimeout(context.Background(), 2*time.Second)
defer cancelFn()
server.RegisterSimpleCommandHandler("test", func(ctx context.Context, s *RpcServer, cmd string, data any) (any, error) {
if data != "hello" {
return nil, fmt.Errorf("expected 'hello', got '%s'", data)
}
return "world", nil
})
var wg sync.WaitGroup
wg.Add(1)
go func() {
defer wg.Done()
resp, err := client.SimpleReq(ctx, "test", "hello")
if err != nil {
t.Errorf("SimpleReq() failed: %v", err)
return
}
if resp != "world" {
t.Errorf("SimpleReq() failed: expected 'world', got '%s'", resp)
}
}()
wg.Wait()
}
func TestStreamClientServer(t *testing.T) {
sendCh := make(chan *RpcPacket, MaxInFlightPackets)
recvCh := make(chan *RpcPacket, MaxInFlightPackets)
client := MakeRpcClient(sendCh, recvCh)
server := MakeRpcServer(recvCh, sendCh)
ctx, cancelFn := context.WithTimeout(context.Background(), 2*time.Second)
defer cancelFn()
server.RegisterStreamCommandHandler("test", func(ctx context.Context, s *RpcServer, req *RpcPacket) error {
pk1 := s.makeRespPk(req, "one", false)
pk2 := s.makeRespPk(req, "two", false)
pk3 := s.makeRespPk(req, "three", true)
s.SendResponse(ctx, pk1)
s.SendResponse(ctx, pk2)
s.SendResponse(ctx, pk3)
return nil
})
var wg sync.WaitGroup
wg.Add(1)
go func() {
defer wg.Done()
respCh, err := client.StreamReq(ctx, "test", "hello", 2*time.Second)
if err != nil {
t.Errorf("StreamReq() failed: %v", err)
return
}
var result []string
for respPk := range respCh {
if respPk.Error != "" {
t.Errorf("StreamReq() failed: %v", respPk.Error)
return
}
log.Printf("got response: %#v", respPk)
result = append(result, respPk.Data.(string))
}
if len(result) != 3 {
t.Errorf("expected 3 responses, got %d", len(result))
return
}
if result[0] != "one" || result[1] != "two" || result[2] != "three" {
t.Errorf("expected 'one', 'two', 'three', got %v", result)
return
}
}()
wg.Wait()
}