mirror of
https://github.com/wavetermdev/waveterm.git
synced 2025-01-04 18:59:08 +01:00
rpc checkpoint
This commit is contained in:
parent
f90554e87e
commit
dcd6d04b0b
@ -14,7 +14,7 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
type SimpleCommandHandlerFn func(context.Context, *RpcServer, string, any) (any, error)
|
type SimpleCommandHandlerFn func(context.Context, *RpcServer, string, any) (any, error)
|
||||||
type StreamCommandHandlerFn func(context.Context, *RpcServer, string, any) error
|
type StreamCommandHandlerFn func(context.Context, *RpcServer, *RpcPacket) error
|
||||||
|
|
||||||
type RpcServer struct {
|
type RpcServer struct {
|
||||||
CVar *sync.Cond
|
CVar *sync.Cond
|
||||||
@ -33,27 +33,54 @@ func MakeRpcServer(sendCh chan *RpcPacket, recvCh chan *RpcPacket) *RpcServer {
|
|||||||
panic(fmt.Errorf("sendCh buffer size must be at least MaxInFlightPackets(%d)", MaxInFlightPackets))
|
panic(fmt.Errorf("sendCh buffer size must be at least MaxInFlightPackets(%d)", MaxInFlightPackets))
|
||||||
}
|
}
|
||||||
rtn := &RpcServer{
|
rtn := &RpcServer{
|
||||||
CVar: sync.NewCond(&sync.Mutex{}),
|
CVar: sync.NewCond(&sync.Mutex{}),
|
||||||
NextSeqNum: &atomic.Int64{},
|
NextSeqNum: &atomic.Int64{},
|
||||||
RespPacketsInFlight: make(map[int64]string),
|
RespPacketsInFlight: make(map[int64]string),
|
||||||
AckList: nil,
|
AckList: nil,
|
||||||
RpcReqs: make(map[string]*RpcInfo),
|
RpcReqs: make(map[string]*RpcInfo),
|
||||||
SendCh: sendCh,
|
SendCh: sendCh,
|
||||||
RecvCh: recvCh,
|
RecvCh: recvCh,
|
||||||
|
SimpleCommandHandlers: make(map[string]SimpleCommandHandlerFn),
|
||||||
|
StreamCommandHandlers: make(map[string]StreamCommandHandlerFn),
|
||||||
}
|
}
|
||||||
go rtn.runRecvLoop()
|
go rtn.runRecvLoop()
|
||||||
return rtn
|
return rtn
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *RpcServer) shouldUseStreamHandler(command string) bool {
|
||||||
|
s.CVar.L.Lock()
|
||||||
|
defer s.CVar.L.Unlock()
|
||||||
|
_, ok := s.StreamCommandHandlers[command]
|
||||||
|
return ok
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *RpcServer) getStreamHandler(command string) StreamCommandHandlerFn {
|
||||||
|
s.CVar.L.Lock()
|
||||||
|
defer s.CVar.L.Unlock()
|
||||||
|
return s.StreamCommandHandlers[command]
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *RpcServer) getSimpleHandler(command string) SimpleCommandHandlerFn {
|
||||||
|
s.CVar.L.Lock()
|
||||||
|
defer s.CVar.L.Unlock()
|
||||||
|
return s.SimpleCommandHandlers[command]
|
||||||
|
}
|
||||||
|
|
||||||
func (s *RpcServer) RegisterSimpleCommandHandler(command string, handler SimpleCommandHandlerFn) {
|
func (s *RpcServer) RegisterSimpleCommandHandler(command string, handler SimpleCommandHandlerFn) {
|
||||||
s.CVar.L.Lock()
|
s.CVar.L.Lock()
|
||||||
defer s.CVar.L.Unlock()
|
defer s.CVar.L.Unlock()
|
||||||
|
if s.StreamCommandHandlers[command] != nil {
|
||||||
|
panic(fmt.Errorf("command %q already registered as a stream handler", command))
|
||||||
|
}
|
||||||
s.SimpleCommandHandlers[command] = handler
|
s.SimpleCommandHandlers[command] = handler
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *RpcServer) RegisterStreamCommandHandler(command string, handler StreamCommandHandlerFn) {
|
func (s *RpcServer) RegisterStreamCommandHandler(command string, handler StreamCommandHandlerFn) {
|
||||||
s.CVar.L.Lock()
|
s.CVar.L.Lock()
|
||||||
defer s.CVar.L.Unlock()
|
defer s.CVar.L.Unlock()
|
||||||
|
if s.SimpleCommandHandlers[command] != nil {
|
||||||
|
panic(fmt.Errorf("command %q already registered as a simple handler", command))
|
||||||
|
}
|
||||||
s.StreamCommandHandlers[command] = handler
|
s.StreamCommandHandlers[command] = handler
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -67,10 +94,10 @@ func (s *RpcServer) runRecvLoop() {
|
|||||||
for pk := range s.RecvCh {
|
for pk := range s.RecvCh {
|
||||||
s.handleAcks(pk.Acks)
|
s.handleAcks(pk.Acks)
|
||||||
if pk.RpcType == RpcType_Req {
|
if pk.RpcType == RpcType_Req {
|
||||||
if pk.ReqDone {
|
if s.shouldUseStreamHandler(pk.Command) {
|
||||||
s.handleSimpleReq(pk)
|
|
||||||
} else {
|
|
||||||
s.handleStreamReq(pk)
|
s.handleStreamReq(pk)
|
||||||
|
} else {
|
||||||
|
s.handleSimpleReq(pk)
|
||||||
}
|
}
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
@ -163,8 +190,9 @@ func (s *RpcServer) handleAcks(acks []int64) {
|
|||||||
|
|
||||||
func (s *RpcServer) handleSimpleReq(pk *RpcPacket) {
|
func (s *RpcServer) handleSimpleReq(pk *RpcPacket) {
|
||||||
s.ackResp(pk.SeqNum)
|
s.ackResp(pk.SeqNum)
|
||||||
handler, ok := s.SimpleCommandHandlers[pk.Command]
|
handler := s.getSimpleHandler(pk.Command)
|
||||||
if !ok {
|
if handler == nil {
|
||||||
|
s.sendErrorResp(pk, fmt.Errorf("unknown command: %s", pk.Command))
|
||||||
log.Printf("RpcServer.handleReq() unknown command: %s", pk.Command)
|
log.Printf("RpcServer.handleReq() unknown command: %s", pk.Command)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@ -201,11 +229,35 @@ func (s *RpcServer) grabAcks_nolock() []int64 {
|
|||||||
return acks
|
return acks
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *RpcServer) sendErrorResp(pk *RpcPacket, err error) {
|
||||||
|
respPk := &RpcPacket{
|
||||||
|
Command: pk.Command,
|
||||||
|
RpcId: pk.RpcId,
|
||||||
|
RpcType: RpcType_Resp,
|
||||||
|
SeqNum: s.NextSeqNum.Add(1),
|
||||||
|
RespDone: true,
|
||||||
|
Error: err.Error(),
|
||||||
|
}
|
||||||
|
s.waitForSend(context.Background(), respPk)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *RpcServer) makeRespPk(pk *RpcPacket, data any, done bool) *RpcPacket {
|
||||||
|
return &RpcPacket{
|
||||||
|
Command: pk.Command,
|
||||||
|
RpcId: pk.RpcId,
|
||||||
|
RpcType: RpcType_Resp,
|
||||||
|
SeqNum: s.NextSeqNum.Add(1),
|
||||||
|
RespDone: done,
|
||||||
|
Data: data,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func (s *RpcServer) handleStreamReq(pk *RpcPacket) {
|
func (s *RpcServer) handleStreamReq(pk *RpcPacket) {
|
||||||
s.ackResp(pk.SeqNum)
|
s.ackResp(pk.SeqNum)
|
||||||
handler, ok := s.StreamCommandHandlers[pk.Command]
|
handler := s.getStreamHandler(pk.Command)
|
||||||
if !ok {
|
if handler == nil {
|
||||||
s.ackResp(pk.SeqNum)
|
s.ackResp(pk.SeqNum)
|
||||||
|
s.sendErrorResp(pk, fmt.Errorf("unknown command: %s", pk.Command))
|
||||||
log.Printf("RpcServer.handleStreamReq() unknown command: %s", pk.Command)
|
log.Printf("RpcServer.handleStreamReq() unknown command: %s", pk.Command)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@ -229,7 +281,7 @@ func (s *RpcServer) handleStreamReq(pk *RpcPacket) {
|
|||||||
}()
|
}()
|
||||||
ctx, cancelFn := makeContextFromTimeout(pk.Timeout)
|
ctx, cancelFn := makeContextFromTimeout(pk.Timeout)
|
||||||
defer cancelFn()
|
defer cancelFn()
|
||||||
err := handler(ctx, s, pk.Command, pk.Data)
|
err := handler(ctx, s, pk)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
respPk := &RpcPacket{
|
respPk := &RpcPacket{
|
||||||
Command: pk.Command,
|
Command: pk.Command,
|
||||||
|
@ -5,6 +5,8 @@ package wshprc
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"log"
|
||||||
"sync"
|
"sync"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
@ -113,3 +115,80 @@ func TestStream(t *testing.T) {
|
|||||||
}()
|
}()
|
||||||
wg.Wait()
|
wg.Wait()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestSimpleClientServer(t *testing.T) {
|
||||||
|
sendCh := make(chan *RpcPacket, MaxInFlightPackets)
|
||||||
|
recvCh := make(chan *RpcPacket, MaxInFlightPackets)
|
||||||
|
client := MakeRpcClient(sendCh, recvCh)
|
||||||
|
server := MakeRpcServer(recvCh, sendCh)
|
||||||
|
ctx, cancelFn := context.WithTimeout(context.Background(), 2*time.Second)
|
||||||
|
defer cancelFn()
|
||||||
|
server.RegisterSimpleCommandHandler("test", func(ctx context.Context, s *RpcServer, cmd string, data any) (any, error) {
|
||||||
|
if data != "hello" {
|
||||||
|
return nil, fmt.Errorf("expected 'hello', got '%s'", data)
|
||||||
|
}
|
||||||
|
return "world", nil
|
||||||
|
})
|
||||||
|
var wg sync.WaitGroup
|
||||||
|
wg.Add(1)
|
||||||
|
go func() {
|
||||||
|
defer wg.Done()
|
||||||
|
resp, err := client.SimpleReq(ctx, "test", "hello")
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("SimpleReq() failed: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if resp != "world" {
|
||||||
|
t.Errorf("SimpleReq() failed: expected 'world', got '%s'", resp)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
wg.Wait()
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestStreamClientServer(t *testing.T) {
|
||||||
|
sendCh := make(chan *RpcPacket, MaxInFlightPackets)
|
||||||
|
recvCh := make(chan *RpcPacket, MaxInFlightPackets)
|
||||||
|
client := MakeRpcClient(sendCh, recvCh)
|
||||||
|
server := MakeRpcServer(recvCh, sendCh)
|
||||||
|
ctx, cancelFn := context.WithTimeout(context.Background(), 2*time.Second)
|
||||||
|
defer cancelFn()
|
||||||
|
server.RegisterStreamCommandHandler("test", func(ctx context.Context, s *RpcServer, req *RpcPacket) error {
|
||||||
|
pk1 := s.makeRespPk(req, "one", false)
|
||||||
|
pk2 := s.makeRespPk(req, "two", false)
|
||||||
|
pk3 := s.makeRespPk(req, "three", true)
|
||||||
|
s.SendResponse(ctx, pk1)
|
||||||
|
s.SendResponse(ctx, pk2)
|
||||||
|
s.SendResponse(ctx, pk3)
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
var wg sync.WaitGroup
|
||||||
|
wg.Add(1)
|
||||||
|
go func() {
|
||||||
|
defer wg.Done()
|
||||||
|
respCh, err := client.StreamReq(ctx, "test", "hello", 2*time.Second)
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("StreamReq() failed: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
var result []string
|
||||||
|
for respPk := range respCh {
|
||||||
|
if respPk.Error != "" {
|
||||||
|
t.Errorf("StreamReq() failed: %v", respPk.Error)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
log.Printf("got response: %#v", respPk)
|
||||||
|
result = append(result, respPk.Data.(string))
|
||||||
|
}
|
||||||
|
if len(result) != 3 {
|
||||||
|
t.Errorf("expected 3 responses, got %d", len(result))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if result[0] != "one" || result[1] != "two" || result[2] != "three" {
|
||||||
|
t.Errorf("expected 'one', 'two', 'three', got %v", result)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
wg.Wait()
|
||||||
|
|
||||||
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user