mirror of
https://github.com/wavetermdev/waveterm.git
synced 2025-02-01 23:21:59 +01:00
working on rpc server
This commit is contained in:
parent
45f20bb5c3
commit
394b9dce23
@ -28,10 +28,10 @@ type RpcClient struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
type RpcInfo struct {
|
type RpcInfo struct {
|
||||||
CloseSync *sync.Once
|
CloseSync *sync.Once
|
||||||
RpcId string
|
RpcId string
|
||||||
ReqPacketsInFlight map[int64]bool // seqnum -> bool
|
PacketsInFlight map[int64]bool // seqnum -> bool (for clients this is for requests, for servers it is for responses)
|
||||||
RespCh chan *RpcPacket
|
PkCh chan *RpcPacket // for clients this is for responses, for servers it is for requests
|
||||||
}
|
}
|
||||||
|
|
||||||
func MakeRpcClient(sendCh chan *RpcPacket, recvCh chan *RpcPacket) *RpcClient {
|
func MakeRpcClient(sendCh chan *RpcPacket, recvCh chan *RpcPacket) *RpcClient {
|
||||||
@ -88,7 +88,7 @@ func (c *RpcClient) handleResp(pk *RpcPacket) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
select {
|
select {
|
||||||
case rpcInfo.RespCh <- pk:
|
case rpcInfo.PkCh <- pk:
|
||||||
default:
|
default:
|
||||||
log.Printf("RpcClient.handleResp() respCh full, dropping packet")
|
log.Printf("RpcClient.handleResp() respCh full, dropping packet")
|
||||||
}
|
}
|
||||||
@ -131,7 +131,7 @@ func (c *RpcClient) waitForReq(ctx context.Context, req *RpcPacket) (*RpcInfo, e
|
|||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
if rpcInfo, ok := c.RpcReqs[req.RpcId]; ok {
|
if rpcInfo, ok := c.RpcReqs[req.RpcId]; ok {
|
||||||
if len(rpcInfo.ReqPacketsInFlight) >= MaxUnackedPerRpc {
|
if len(rpcInfo.PacketsInFlight) >= MaxUnackedPerRpc {
|
||||||
c.CVar.Wait()
|
c.CVar.Wait()
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
@ -147,12 +147,12 @@ func (c *RpcClient) waitForReq(ctx context.Context, req *RpcPacket) (*RpcInfo, e
|
|||||||
rpcInfo := c.RpcReqs[req.RpcId]
|
rpcInfo := c.RpcReqs[req.RpcId]
|
||||||
if rpcInfo == nil {
|
if rpcInfo == nil {
|
||||||
rpcInfo = &RpcInfo{
|
rpcInfo = &RpcInfo{
|
||||||
CloseSync: &sync.Once{},
|
CloseSync: &sync.Once{},
|
||||||
RpcId: req.RpcId,
|
RpcId: req.RpcId,
|
||||||
ReqPacketsInFlight: make(map[int64]bool),
|
PacketsInFlight: make(map[int64]bool),
|
||||||
RespCh: make(chan *RpcPacket, MaxUnackedPerRpc),
|
PkCh: make(chan *RpcPacket, MaxUnackedPerRpc),
|
||||||
}
|
}
|
||||||
rpcInfo.ReqPacketsInFlight[req.SeqNum] = true
|
rpcInfo.PacketsInFlight[req.SeqNum] = true
|
||||||
c.RpcReqs[req.RpcId] = rpcInfo
|
c.RpcReqs[req.RpcId] = rpcInfo
|
||||||
}
|
}
|
||||||
return rpcInfo, nil
|
return rpcInfo, nil
|
||||||
@ -171,7 +171,7 @@ func (c *RpcClient) handleAcks(acks []int64) {
|
|||||||
}
|
}
|
||||||
rpcInfo := c.RpcReqs[rpcId]
|
rpcInfo := c.RpcReqs[rpcId]
|
||||||
if rpcInfo != nil {
|
if rpcInfo != nil {
|
||||||
delete(rpcInfo.ReqPacketsInFlight, ack)
|
delete(rpcInfo.PacketsInFlight, ack)
|
||||||
}
|
}
|
||||||
delete(c.ReqPacketsInFlight, ack)
|
delete(c.ReqPacketsInFlight, ack)
|
||||||
}
|
}
|
||||||
@ -188,12 +188,12 @@ func (c *RpcClient) removeReqInfo(rpcId string, clearSend bool) {
|
|||||||
// unblock the recv loop if it happens to be waiting
|
// unblock the recv loop if it happens to be waiting
|
||||||
// because the delete has already happens, it will not be able to send again on the channel
|
// because the delete has already happens, it will not be able to send again on the channel
|
||||||
select {
|
select {
|
||||||
case <-rpcInfo.RespCh:
|
case <-rpcInfo.PkCh:
|
||||||
default:
|
default:
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
rpcInfo.CloseSync.Do(func() {
|
rpcInfo.CloseSync.Do(func() {
|
||||||
close(rpcInfo.RespCh)
|
close(rpcInfo.PkCh)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -225,7 +225,7 @@ func (c *RpcClient) SimpleReq(ctx context.Context, command string, data any) (an
|
|||||||
select {
|
select {
|
||||||
case <-ctx.Done():
|
case <-ctx.Done():
|
||||||
return nil, ctx.Err()
|
return nil, ctx.Err()
|
||||||
case rtnPacket = <-rpcInfo.RespCh:
|
case rtnPacket = <-rpcInfo.PkCh:
|
||||||
// fallthrough
|
// fallthrough
|
||||||
}
|
}
|
||||||
if rtnPacket.Error != "" {
|
if rtnPacket.Error != "" {
|
||||||
@ -256,7 +256,7 @@ func (c *RpcClient) StreamReq(ctx context.Context, command string, data any, res
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
return rpcInfo.RespCh, nil
|
return rpcInfo.PkCh, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *RpcClient) EndStreamReq(rpcId string) {
|
func (c *RpcClient) EndStreamReq(rpcId string) {
|
||||||
|
247
pkg/wshrpc/rpc_server.go
Normal file
247
pkg/wshrpc/rpc_server.go
Normal file
@ -0,0 +1,247 @@
|
|||||||
|
// Copyright 2024, Command Line Inc.
|
||||||
|
// SPDX-License-Identifier: Apache-2.0
|
||||||
|
|
||||||
|
package wshprc
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"log"
|
||||||
|
"runtime/debug"
|
||||||
|
"sync"
|
||||||
|
"sync/atomic"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
type SimpleCommandHandlerFn func(context.Context, *RpcServer, string, any) (any, error)
|
||||||
|
type StreamCommandHandlerFn func(context.Context, *RpcServer, string, any) error
|
||||||
|
|
||||||
|
type RpcServer struct {
|
||||||
|
CVar *sync.Cond
|
||||||
|
NextSeqNum *atomic.Int64
|
||||||
|
RespPacketsInFlight map[int64]string // seqnum -> rpcId
|
||||||
|
AckList []int64
|
||||||
|
RpcReqs map[string]*RpcInfo
|
||||||
|
SendCh chan *RpcPacket
|
||||||
|
RecvCh chan *RpcPacket
|
||||||
|
SimpleCommandHandlers map[string]SimpleCommandHandlerFn
|
||||||
|
StreamCommandHandlers map[string]StreamCommandHandlerFn
|
||||||
|
}
|
||||||
|
|
||||||
|
func MakeRpcServer(sendCh chan *RpcPacket, recvCh chan *RpcPacket) *RpcServer {
|
||||||
|
if cap(sendCh) < MaxInFlightPackets {
|
||||||
|
panic(fmt.Errorf("sendCh buffer size must be at least MaxInFlightPackets(%d)", MaxInFlightPackets))
|
||||||
|
}
|
||||||
|
rtn := &RpcServer{
|
||||||
|
CVar: sync.NewCond(&sync.Mutex{}),
|
||||||
|
NextSeqNum: &atomic.Int64{},
|
||||||
|
RespPacketsInFlight: make(map[int64]string),
|
||||||
|
AckList: nil,
|
||||||
|
RpcReqs: make(map[string]*RpcInfo),
|
||||||
|
SendCh: sendCh,
|
||||||
|
RecvCh: recvCh,
|
||||||
|
}
|
||||||
|
go rtn.runRecvLoop()
|
||||||
|
return rtn
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *RpcServer) RegisterSimpleCommandHandler(command string, handler SimpleCommandHandlerFn) {
|
||||||
|
s.CVar.L.Lock()
|
||||||
|
defer s.CVar.L.Unlock()
|
||||||
|
s.SimpleCommandHandlers[command] = handler
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *RpcServer) RegisterStreamCommandHandler(command string, handler StreamCommandHandlerFn) {
|
||||||
|
s.CVar.L.Lock()
|
||||||
|
defer s.CVar.L.Unlock()
|
||||||
|
s.StreamCommandHandlers[command] = handler
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *RpcServer) runRecvLoop() {
|
||||||
|
defer func() {
|
||||||
|
if r := recover(); r != nil {
|
||||||
|
log.Printf("RpcServer.runRecvLoop() panic: %v", r)
|
||||||
|
debug.PrintStack()
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
for pk := range s.RecvCh {
|
||||||
|
s.handleAcks(pk.Acks)
|
||||||
|
if pk.RpcType == RpcType_Req {
|
||||||
|
if pk.ReqDone {
|
||||||
|
s.handleSimpleReq(pk)
|
||||||
|
} else {
|
||||||
|
s.handleStreamReq(pk)
|
||||||
|
}
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
log.Printf("RpcClient.runRecvLoop() bad packet type: %v", pk)
|
||||||
|
}
|
||||||
|
log.Printf("RpcServer.runRecvLoop() normal exit")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *RpcServer) ackResp(seqNum int64) {
|
||||||
|
if seqNum == 0 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
s.CVar.L.Lock()
|
||||||
|
defer s.CVar.L.Unlock()
|
||||||
|
s.AckList = append(s.AckList, seqNum)
|
||||||
|
}
|
||||||
|
|
||||||
|
func makeContextFromTimeout(timeout *TimeoutInfo) (context.Context, context.CancelFunc) {
|
||||||
|
if timeout == nil {
|
||||||
|
return context.Background(), func() {}
|
||||||
|
}
|
||||||
|
return context.WithDeadline(context.Background(), time.UnixMilli(timeout.Deadline))
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *RpcServer) SendResponse(ctx context.Context, pk *RpcPacket) error {
|
||||||
|
return s.waitForSend(ctx, pk)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *RpcServer) waitForSend(ctx context.Context, pk *RpcPacket) error {
|
||||||
|
s.CVar.L.Lock()
|
||||||
|
defer s.CVar.L.Unlock()
|
||||||
|
for {
|
||||||
|
if ctx.Err() != nil {
|
||||||
|
return ctx.Err()
|
||||||
|
}
|
||||||
|
if len(s.RespPacketsInFlight) >= MaxInFlightPackets {
|
||||||
|
s.CVar.Wait()
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
rpcInfo := s.RpcReqs[pk.RpcId]
|
||||||
|
if rpcInfo != nil {
|
||||||
|
if len(rpcInfo.PacketsInFlight) >= MaxUnackedPerRpc {
|
||||||
|
s.CVar.Wait()
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
}
|
||||||
|
break
|
||||||
|
}
|
||||||
|
s.RespPacketsInFlight[pk.SeqNum] = pk.RpcId
|
||||||
|
pk.Acks = s.grabAcks_nolock()
|
||||||
|
s.SendCh <- pk
|
||||||
|
rpcInfo := s.RpcReqs[pk.RpcId]
|
||||||
|
if !pk.RespDone && rpcInfo != nil {
|
||||||
|
rpcInfo = &RpcInfo{
|
||||||
|
CloseSync: &sync.Once{},
|
||||||
|
RpcId: pk.RpcId,
|
||||||
|
PkCh: make(chan *RpcPacket, MaxUnackedPerRpc),
|
||||||
|
PacketsInFlight: make(map[int64]bool),
|
||||||
|
}
|
||||||
|
s.RpcReqs[pk.RpcId] = rpcInfo
|
||||||
|
}
|
||||||
|
if rpcInfo != nil {
|
||||||
|
rpcInfo.PacketsInFlight[pk.SeqNum] = true
|
||||||
|
}
|
||||||
|
if pk.RespDone {
|
||||||
|
delete(s.RpcReqs, pk.RpcId)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *RpcServer) handleAcks(acks []int64) {
|
||||||
|
if len(acks) == 0 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
s.CVar.L.Lock()
|
||||||
|
defer s.CVar.L.Unlock()
|
||||||
|
for _, ack := range acks {
|
||||||
|
rpcId, ok := s.RespPacketsInFlight[ack]
|
||||||
|
if !ok {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
rpcInfo := s.RpcReqs[rpcId]
|
||||||
|
if rpcInfo != nil {
|
||||||
|
delete(rpcInfo.PacketsInFlight, ack)
|
||||||
|
}
|
||||||
|
delete(s.RespPacketsInFlight, ack)
|
||||||
|
}
|
||||||
|
s.CVar.Broadcast()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *RpcServer) handleSimpleReq(pk *RpcPacket) {
|
||||||
|
s.ackResp(pk.SeqNum)
|
||||||
|
handler, ok := s.SimpleCommandHandlers[pk.Command]
|
||||||
|
if !ok {
|
||||||
|
log.Printf("RpcServer.handleReq() unknown command: %s", pk.Command)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
go func() {
|
||||||
|
defer func() {
|
||||||
|
if r := recover(); r != nil {
|
||||||
|
log.Printf("RpcServer.handleReq(%q) panic: %v", pk.Command, r)
|
||||||
|
debug.PrintStack()
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
ctx, cancelFn := makeContextFromTimeout(pk.Timeout)
|
||||||
|
defer cancelFn()
|
||||||
|
data, err := handler(ctx, s, pk.Command, pk.Data)
|
||||||
|
seqNum := s.NextSeqNum.Add(1)
|
||||||
|
respPk := &RpcPacket{
|
||||||
|
Command: pk.Command,
|
||||||
|
RpcId: pk.RpcId,
|
||||||
|
RpcType: RpcType_Resp,
|
||||||
|
SeqNum: seqNum,
|
||||||
|
RespDone: true,
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
respPk.Error = err.Error()
|
||||||
|
} else {
|
||||||
|
respPk.Data = data
|
||||||
|
}
|
||||||
|
s.waitForSend(ctx, respPk)
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *RpcServer) grabAcks_nolock() []int64 {
|
||||||
|
acks := s.AckList
|
||||||
|
s.AckList = nil
|
||||||
|
return acks
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *RpcServer) handleStreamReq(pk *RpcPacket) {
|
||||||
|
s.ackResp(pk.SeqNum)
|
||||||
|
handler, ok := s.StreamCommandHandlers[pk.Command]
|
||||||
|
if !ok {
|
||||||
|
s.ackResp(pk.SeqNum)
|
||||||
|
log.Printf("RpcServer.handleStreamReq() unknown command: %s", pk.Command)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
go func() {
|
||||||
|
defer func() {
|
||||||
|
r := recover()
|
||||||
|
if r == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
log.Printf("RpcServer.handleStreamReq(%q) panic: %v", pk.Command, r)
|
||||||
|
debug.PrintStack()
|
||||||
|
respPk := &RpcPacket{
|
||||||
|
Command: pk.Command,
|
||||||
|
RpcId: pk.RpcId,
|
||||||
|
RpcType: RpcType_Resp,
|
||||||
|
SeqNum: s.NextSeqNum.Add(1),
|
||||||
|
RespDone: true,
|
||||||
|
Error: fmt.Sprintf("panic: %v", r),
|
||||||
|
}
|
||||||
|
s.waitForSend(context.Background(), respPk)
|
||||||
|
}()
|
||||||
|
ctx, cancelFn := makeContextFromTimeout(pk.Timeout)
|
||||||
|
defer cancelFn()
|
||||||
|
err := handler(ctx, s, pk.Command, pk.Data)
|
||||||
|
if err != nil {
|
||||||
|
respPk := &RpcPacket{
|
||||||
|
Command: pk.Command,
|
||||||
|
RpcId: pk.RpcId,
|
||||||
|
RpcType: RpcType_Resp,
|
||||||
|
SeqNum: s.NextSeqNum.Add(1),
|
||||||
|
RespDone: true,
|
||||||
|
Error: err.Error(),
|
||||||
|
}
|
||||||
|
s.waitForSend(ctx, respPk)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
// check if RespDone has been set, if not, send it here
|
||||||
|
}()
|
||||||
|
}
|
Loading…
Reference in New Issue
Block a user