mirror of
https://github.com/wavetermdev/waveterm.git
synced 2025-01-02 18:39:05 +01:00
rpc checkpoint
This commit is contained in:
parent
f90554e87e
commit
dcd6d04b0b
@ -14,7 +14,7 @@ import (
|
||||
)
|
||||
|
||||
type SimpleCommandHandlerFn func(context.Context, *RpcServer, string, any) (any, error)
|
||||
type StreamCommandHandlerFn func(context.Context, *RpcServer, string, any) error
|
||||
type StreamCommandHandlerFn func(context.Context, *RpcServer, *RpcPacket) error
|
||||
|
||||
type RpcServer struct {
|
||||
CVar *sync.Cond
|
||||
@ -40,20 +40,47 @@ func MakeRpcServer(sendCh chan *RpcPacket, recvCh chan *RpcPacket) *RpcServer {
|
||||
RpcReqs: make(map[string]*RpcInfo),
|
||||
SendCh: sendCh,
|
||||
RecvCh: recvCh,
|
||||
SimpleCommandHandlers: make(map[string]SimpleCommandHandlerFn),
|
||||
StreamCommandHandlers: make(map[string]StreamCommandHandlerFn),
|
||||
}
|
||||
go rtn.runRecvLoop()
|
||||
return rtn
|
||||
}
|
||||
|
||||
func (s *RpcServer) shouldUseStreamHandler(command string) bool {
|
||||
s.CVar.L.Lock()
|
||||
defer s.CVar.L.Unlock()
|
||||
_, ok := s.StreamCommandHandlers[command]
|
||||
return ok
|
||||
}
|
||||
|
||||
func (s *RpcServer) getStreamHandler(command string) StreamCommandHandlerFn {
|
||||
s.CVar.L.Lock()
|
||||
defer s.CVar.L.Unlock()
|
||||
return s.StreamCommandHandlers[command]
|
||||
}
|
||||
|
||||
func (s *RpcServer) getSimpleHandler(command string) SimpleCommandHandlerFn {
|
||||
s.CVar.L.Lock()
|
||||
defer s.CVar.L.Unlock()
|
||||
return s.SimpleCommandHandlers[command]
|
||||
}
|
||||
|
||||
func (s *RpcServer) RegisterSimpleCommandHandler(command string, handler SimpleCommandHandlerFn) {
|
||||
s.CVar.L.Lock()
|
||||
defer s.CVar.L.Unlock()
|
||||
if s.StreamCommandHandlers[command] != nil {
|
||||
panic(fmt.Errorf("command %q already registered as a stream handler", command))
|
||||
}
|
||||
s.SimpleCommandHandlers[command] = handler
|
||||
}
|
||||
|
||||
func (s *RpcServer) RegisterStreamCommandHandler(command string, handler StreamCommandHandlerFn) {
|
||||
s.CVar.L.Lock()
|
||||
defer s.CVar.L.Unlock()
|
||||
if s.SimpleCommandHandlers[command] != nil {
|
||||
panic(fmt.Errorf("command %q already registered as a simple handler", command))
|
||||
}
|
||||
s.StreamCommandHandlers[command] = handler
|
||||
}
|
||||
|
||||
@ -67,10 +94,10 @@ func (s *RpcServer) runRecvLoop() {
|
||||
for pk := range s.RecvCh {
|
||||
s.handleAcks(pk.Acks)
|
||||
if pk.RpcType == RpcType_Req {
|
||||
if pk.ReqDone {
|
||||
s.handleSimpleReq(pk)
|
||||
} else {
|
||||
if s.shouldUseStreamHandler(pk.Command) {
|
||||
s.handleStreamReq(pk)
|
||||
} else {
|
||||
s.handleSimpleReq(pk)
|
||||
}
|
||||
continue
|
||||
}
|
||||
@ -163,8 +190,9 @@ func (s *RpcServer) handleAcks(acks []int64) {
|
||||
|
||||
func (s *RpcServer) handleSimpleReq(pk *RpcPacket) {
|
||||
s.ackResp(pk.SeqNum)
|
||||
handler, ok := s.SimpleCommandHandlers[pk.Command]
|
||||
if !ok {
|
||||
handler := s.getSimpleHandler(pk.Command)
|
||||
if handler == nil {
|
||||
s.sendErrorResp(pk, fmt.Errorf("unknown command: %s", pk.Command))
|
||||
log.Printf("RpcServer.handleReq() unknown command: %s", pk.Command)
|
||||
return
|
||||
}
|
||||
@ -201,11 +229,35 @@ func (s *RpcServer) grabAcks_nolock() []int64 {
|
||||
return acks
|
||||
}
|
||||
|
||||
func (s *RpcServer) sendErrorResp(pk *RpcPacket, err error) {
|
||||
respPk := &RpcPacket{
|
||||
Command: pk.Command,
|
||||
RpcId: pk.RpcId,
|
||||
RpcType: RpcType_Resp,
|
||||
SeqNum: s.NextSeqNum.Add(1),
|
||||
RespDone: true,
|
||||
Error: err.Error(),
|
||||
}
|
||||
s.waitForSend(context.Background(), respPk)
|
||||
}
|
||||
|
||||
func (s *RpcServer) makeRespPk(pk *RpcPacket, data any, done bool) *RpcPacket {
|
||||
return &RpcPacket{
|
||||
Command: pk.Command,
|
||||
RpcId: pk.RpcId,
|
||||
RpcType: RpcType_Resp,
|
||||
SeqNum: s.NextSeqNum.Add(1),
|
||||
RespDone: done,
|
||||
Data: data,
|
||||
}
|
||||
}
|
||||
|
||||
func (s *RpcServer) handleStreamReq(pk *RpcPacket) {
|
||||
s.ackResp(pk.SeqNum)
|
||||
handler, ok := s.StreamCommandHandlers[pk.Command]
|
||||
if !ok {
|
||||
handler := s.getStreamHandler(pk.Command)
|
||||
if handler == nil {
|
||||
s.ackResp(pk.SeqNum)
|
||||
s.sendErrorResp(pk, fmt.Errorf("unknown command: %s", pk.Command))
|
||||
log.Printf("RpcServer.handleStreamReq() unknown command: %s", pk.Command)
|
||||
return
|
||||
}
|
||||
@ -229,7 +281,7 @@ func (s *RpcServer) handleStreamReq(pk *RpcPacket) {
|
||||
}()
|
||||
ctx, cancelFn := makeContextFromTimeout(pk.Timeout)
|
||||
defer cancelFn()
|
||||
err := handler(ctx, s, pk.Command, pk.Data)
|
||||
err := handler(ctx, s, pk)
|
||||
if err != nil {
|
||||
respPk := &RpcPacket{
|
||||
Command: pk.Command,
|
||||
|
@ -5,6 +5,8 @@ package wshprc
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"log"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
@ -113,3 +115,80 @@ func TestStream(t *testing.T) {
|
||||
}()
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
func TestSimpleClientServer(t *testing.T) {
|
||||
sendCh := make(chan *RpcPacket, MaxInFlightPackets)
|
||||
recvCh := make(chan *RpcPacket, MaxInFlightPackets)
|
||||
client := MakeRpcClient(sendCh, recvCh)
|
||||
server := MakeRpcServer(recvCh, sendCh)
|
||||
ctx, cancelFn := context.WithTimeout(context.Background(), 2*time.Second)
|
||||
defer cancelFn()
|
||||
server.RegisterSimpleCommandHandler("test", func(ctx context.Context, s *RpcServer, cmd string, data any) (any, error) {
|
||||
if data != "hello" {
|
||||
return nil, fmt.Errorf("expected 'hello', got '%s'", data)
|
||||
}
|
||||
return "world", nil
|
||||
})
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
resp, err := client.SimpleReq(ctx, "test", "hello")
|
||||
if err != nil {
|
||||
t.Errorf("SimpleReq() failed: %v", err)
|
||||
return
|
||||
}
|
||||
if resp != "world" {
|
||||
t.Errorf("SimpleReq() failed: expected 'world', got '%s'", resp)
|
||||
}
|
||||
}()
|
||||
wg.Wait()
|
||||
|
||||
}
|
||||
|
||||
func TestStreamClientServer(t *testing.T) {
|
||||
sendCh := make(chan *RpcPacket, MaxInFlightPackets)
|
||||
recvCh := make(chan *RpcPacket, MaxInFlightPackets)
|
||||
client := MakeRpcClient(sendCh, recvCh)
|
||||
server := MakeRpcServer(recvCh, sendCh)
|
||||
ctx, cancelFn := context.WithTimeout(context.Background(), 2*time.Second)
|
||||
defer cancelFn()
|
||||
server.RegisterStreamCommandHandler("test", func(ctx context.Context, s *RpcServer, req *RpcPacket) error {
|
||||
pk1 := s.makeRespPk(req, "one", false)
|
||||
pk2 := s.makeRespPk(req, "two", false)
|
||||
pk3 := s.makeRespPk(req, "three", true)
|
||||
s.SendResponse(ctx, pk1)
|
||||
s.SendResponse(ctx, pk2)
|
||||
s.SendResponse(ctx, pk3)
|
||||
return nil
|
||||
})
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
respCh, err := client.StreamReq(ctx, "test", "hello", 2*time.Second)
|
||||
if err != nil {
|
||||
t.Errorf("StreamReq() failed: %v", err)
|
||||
return
|
||||
}
|
||||
var result []string
|
||||
for respPk := range respCh {
|
||||
if respPk.Error != "" {
|
||||
t.Errorf("StreamReq() failed: %v", respPk.Error)
|
||||
return
|
||||
}
|
||||
log.Printf("got response: %#v", respPk)
|
||||
result = append(result, respPk.Data.(string))
|
||||
}
|
||||
if len(result) != 3 {
|
||||
t.Errorf("expected 3 responses, got %d", len(result))
|
||||
return
|
||||
}
|
||||
if result[0] != "one" || result[1] != "two" || result[2] != "three" {
|
||||
t.Errorf("expected 'one', 'two', 'three', got %v", result)
|
||||
return
|
||||
}
|
||||
}()
|
||||
wg.Wait()
|
||||
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user