mirror of
https://github.com/wavetermdev/waveterm.git
synced 2025-02-01 23:21:59 +01:00
working on rpc server
This commit is contained in:
parent
45f20bb5c3
commit
394b9dce23
@ -28,10 +28,10 @@ type RpcClient struct {
|
||||
}
|
||||
|
||||
type RpcInfo struct {
|
||||
CloseSync *sync.Once
|
||||
RpcId string
|
||||
ReqPacketsInFlight map[int64]bool // seqnum -> bool
|
||||
RespCh chan *RpcPacket
|
||||
CloseSync *sync.Once
|
||||
RpcId string
|
||||
PacketsInFlight map[int64]bool // seqnum -> bool (for clients this is for requests, for servers it is for responses)
|
||||
PkCh chan *RpcPacket // for clients this is for responses, for servers it is for requests
|
||||
}
|
||||
|
||||
func MakeRpcClient(sendCh chan *RpcPacket, recvCh chan *RpcPacket) *RpcClient {
|
||||
@ -88,7 +88,7 @@ func (c *RpcClient) handleResp(pk *RpcPacket) {
|
||||
return
|
||||
}
|
||||
select {
|
||||
case rpcInfo.RespCh <- pk:
|
||||
case rpcInfo.PkCh <- pk:
|
||||
default:
|
||||
log.Printf("RpcClient.handleResp() respCh full, dropping packet")
|
||||
}
|
||||
@ -131,7 +131,7 @@ func (c *RpcClient) waitForReq(ctx context.Context, req *RpcPacket) (*RpcInfo, e
|
||||
continue
|
||||
}
|
||||
if rpcInfo, ok := c.RpcReqs[req.RpcId]; ok {
|
||||
if len(rpcInfo.ReqPacketsInFlight) >= MaxUnackedPerRpc {
|
||||
if len(rpcInfo.PacketsInFlight) >= MaxUnackedPerRpc {
|
||||
c.CVar.Wait()
|
||||
continue
|
||||
}
|
||||
@ -147,12 +147,12 @@ func (c *RpcClient) waitForReq(ctx context.Context, req *RpcPacket) (*RpcInfo, e
|
||||
rpcInfo := c.RpcReqs[req.RpcId]
|
||||
if rpcInfo == nil {
|
||||
rpcInfo = &RpcInfo{
|
||||
CloseSync: &sync.Once{},
|
||||
RpcId: req.RpcId,
|
||||
ReqPacketsInFlight: make(map[int64]bool),
|
||||
RespCh: make(chan *RpcPacket, MaxUnackedPerRpc),
|
||||
CloseSync: &sync.Once{},
|
||||
RpcId: req.RpcId,
|
||||
PacketsInFlight: make(map[int64]bool),
|
||||
PkCh: make(chan *RpcPacket, MaxUnackedPerRpc),
|
||||
}
|
||||
rpcInfo.ReqPacketsInFlight[req.SeqNum] = true
|
||||
rpcInfo.PacketsInFlight[req.SeqNum] = true
|
||||
c.RpcReqs[req.RpcId] = rpcInfo
|
||||
}
|
||||
return rpcInfo, nil
|
||||
@ -171,7 +171,7 @@ func (c *RpcClient) handleAcks(acks []int64) {
|
||||
}
|
||||
rpcInfo := c.RpcReqs[rpcId]
|
||||
if rpcInfo != nil {
|
||||
delete(rpcInfo.ReqPacketsInFlight, ack)
|
||||
delete(rpcInfo.PacketsInFlight, ack)
|
||||
}
|
||||
delete(c.ReqPacketsInFlight, ack)
|
||||
}
|
||||
@ -188,12 +188,12 @@ func (c *RpcClient) removeReqInfo(rpcId string, clearSend bool) {
|
||||
// unblock the recv loop if it happens to be waiting
|
||||
// because the delete has already happens, it will not be able to send again on the channel
|
||||
select {
|
||||
case <-rpcInfo.RespCh:
|
||||
case <-rpcInfo.PkCh:
|
||||
default:
|
||||
}
|
||||
}
|
||||
rpcInfo.CloseSync.Do(func() {
|
||||
close(rpcInfo.RespCh)
|
||||
close(rpcInfo.PkCh)
|
||||
})
|
||||
}
|
||||
}
|
||||
@ -225,7 +225,7 @@ func (c *RpcClient) SimpleReq(ctx context.Context, command string, data any) (an
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return nil, ctx.Err()
|
||||
case rtnPacket = <-rpcInfo.RespCh:
|
||||
case rtnPacket = <-rpcInfo.PkCh:
|
||||
// fallthrough
|
||||
}
|
||||
if rtnPacket.Error != "" {
|
||||
@ -256,7 +256,7 @@ func (c *RpcClient) StreamReq(ctx context.Context, command string, data any, res
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return rpcInfo.RespCh, nil
|
||||
return rpcInfo.PkCh, nil
|
||||
}
|
||||
|
||||
func (c *RpcClient) EndStreamReq(rpcId string) {
|
||||
|
247
pkg/wshrpc/rpc_server.go
Normal file
247
pkg/wshrpc/rpc_server.go
Normal file
@ -0,0 +1,247 @@
|
||||
// Copyright 2024, Command Line Inc.
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
package wshprc
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"log"
|
||||
"runtime/debug"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
)
|
||||
|
||||
type SimpleCommandHandlerFn func(context.Context, *RpcServer, string, any) (any, error)
|
||||
type StreamCommandHandlerFn func(context.Context, *RpcServer, string, any) error
|
||||
|
||||
type RpcServer struct {
|
||||
CVar *sync.Cond
|
||||
NextSeqNum *atomic.Int64
|
||||
RespPacketsInFlight map[int64]string // seqnum -> rpcId
|
||||
AckList []int64
|
||||
RpcReqs map[string]*RpcInfo
|
||||
SendCh chan *RpcPacket
|
||||
RecvCh chan *RpcPacket
|
||||
SimpleCommandHandlers map[string]SimpleCommandHandlerFn
|
||||
StreamCommandHandlers map[string]StreamCommandHandlerFn
|
||||
}
|
||||
|
||||
func MakeRpcServer(sendCh chan *RpcPacket, recvCh chan *RpcPacket) *RpcServer {
|
||||
if cap(sendCh) < MaxInFlightPackets {
|
||||
panic(fmt.Errorf("sendCh buffer size must be at least MaxInFlightPackets(%d)", MaxInFlightPackets))
|
||||
}
|
||||
rtn := &RpcServer{
|
||||
CVar: sync.NewCond(&sync.Mutex{}),
|
||||
NextSeqNum: &atomic.Int64{},
|
||||
RespPacketsInFlight: make(map[int64]string),
|
||||
AckList: nil,
|
||||
RpcReqs: make(map[string]*RpcInfo),
|
||||
SendCh: sendCh,
|
||||
RecvCh: recvCh,
|
||||
}
|
||||
go rtn.runRecvLoop()
|
||||
return rtn
|
||||
}
|
||||
|
||||
func (s *RpcServer) RegisterSimpleCommandHandler(command string, handler SimpleCommandHandlerFn) {
|
||||
s.CVar.L.Lock()
|
||||
defer s.CVar.L.Unlock()
|
||||
s.SimpleCommandHandlers[command] = handler
|
||||
}
|
||||
|
||||
func (s *RpcServer) RegisterStreamCommandHandler(command string, handler StreamCommandHandlerFn) {
|
||||
s.CVar.L.Lock()
|
||||
defer s.CVar.L.Unlock()
|
||||
s.StreamCommandHandlers[command] = handler
|
||||
}
|
||||
|
||||
func (s *RpcServer) runRecvLoop() {
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
log.Printf("RpcServer.runRecvLoop() panic: %v", r)
|
||||
debug.PrintStack()
|
||||
}
|
||||
}()
|
||||
for pk := range s.RecvCh {
|
||||
s.handleAcks(pk.Acks)
|
||||
if pk.RpcType == RpcType_Req {
|
||||
if pk.ReqDone {
|
||||
s.handleSimpleReq(pk)
|
||||
} else {
|
||||
s.handleStreamReq(pk)
|
||||
}
|
||||
continue
|
||||
}
|
||||
log.Printf("RpcClient.runRecvLoop() bad packet type: %v", pk)
|
||||
}
|
||||
log.Printf("RpcServer.runRecvLoop() normal exit")
|
||||
}
|
||||
|
||||
func (s *RpcServer) ackResp(seqNum int64) {
|
||||
if seqNum == 0 {
|
||||
return
|
||||
}
|
||||
s.CVar.L.Lock()
|
||||
defer s.CVar.L.Unlock()
|
||||
s.AckList = append(s.AckList, seqNum)
|
||||
}
|
||||
|
||||
func makeContextFromTimeout(timeout *TimeoutInfo) (context.Context, context.CancelFunc) {
|
||||
if timeout == nil {
|
||||
return context.Background(), func() {}
|
||||
}
|
||||
return context.WithDeadline(context.Background(), time.UnixMilli(timeout.Deadline))
|
||||
}
|
||||
|
||||
func (s *RpcServer) SendResponse(ctx context.Context, pk *RpcPacket) error {
|
||||
return s.waitForSend(ctx, pk)
|
||||
}
|
||||
|
||||
func (s *RpcServer) waitForSend(ctx context.Context, pk *RpcPacket) error {
|
||||
s.CVar.L.Lock()
|
||||
defer s.CVar.L.Unlock()
|
||||
for {
|
||||
if ctx.Err() != nil {
|
||||
return ctx.Err()
|
||||
}
|
||||
if len(s.RespPacketsInFlight) >= MaxInFlightPackets {
|
||||
s.CVar.Wait()
|
||||
continue
|
||||
}
|
||||
rpcInfo := s.RpcReqs[pk.RpcId]
|
||||
if rpcInfo != nil {
|
||||
if len(rpcInfo.PacketsInFlight) >= MaxUnackedPerRpc {
|
||||
s.CVar.Wait()
|
||||
continue
|
||||
}
|
||||
}
|
||||
break
|
||||
}
|
||||
s.RespPacketsInFlight[pk.SeqNum] = pk.RpcId
|
||||
pk.Acks = s.grabAcks_nolock()
|
||||
s.SendCh <- pk
|
||||
rpcInfo := s.RpcReqs[pk.RpcId]
|
||||
if !pk.RespDone && rpcInfo != nil {
|
||||
rpcInfo = &RpcInfo{
|
||||
CloseSync: &sync.Once{},
|
||||
RpcId: pk.RpcId,
|
||||
PkCh: make(chan *RpcPacket, MaxUnackedPerRpc),
|
||||
PacketsInFlight: make(map[int64]bool),
|
||||
}
|
||||
s.RpcReqs[pk.RpcId] = rpcInfo
|
||||
}
|
||||
if rpcInfo != nil {
|
||||
rpcInfo.PacketsInFlight[pk.SeqNum] = true
|
||||
}
|
||||
if pk.RespDone {
|
||||
delete(s.RpcReqs, pk.RpcId)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *RpcServer) handleAcks(acks []int64) {
|
||||
if len(acks) == 0 {
|
||||
return
|
||||
}
|
||||
s.CVar.L.Lock()
|
||||
defer s.CVar.L.Unlock()
|
||||
for _, ack := range acks {
|
||||
rpcId, ok := s.RespPacketsInFlight[ack]
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
rpcInfo := s.RpcReqs[rpcId]
|
||||
if rpcInfo != nil {
|
||||
delete(rpcInfo.PacketsInFlight, ack)
|
||||
}
|
||||
delete(s.RespPacketsInFlight, ack)
|
||||
}
|
||||
s.CVar.Broadcast()
|
||||
}
|
||||
|
||||
func (s *RpcServer) handleSimpleReq(pk *RpcPacket) {
|
||||
s.ackResp(pk.SeqNum)
|
||||
handler, ok := s.SimpleCommandHandlers[pk.Command]
|
||||
if !ok {
|
||||
log.Printf("RpcServer.handleReq() unknown command: %s", pk.Command)
|
||||
return
|
||||
}
|
||||
go func() {
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
log.Printf("RpcServer.handleReq(%q) panic: %v", pk.Command, r)
|
||||
debug.PrintStack()
|
||||
}
|
||||
}()
|
||||
ctx, cancelFn := makeContextFromTimeout(pk.Timeout)
|
||||
defer cancelFn()
|
||||
data, err := handler(ctx, s, pk.Command, pk.Data)
|
||||
seqNum := s.NextSeqNum.Add(1)
|
||||
respPk := &RpcPacket{
|
||||
Command: pk.Command,
|
||||
RpcId: pk.RpcId,
|
||||
RpcType: RpcType_Resp,
|
||||
SeqNum: seqNum,
|
||||
RespDone: true,
|
||||
}
|
||||
if err != nil {
|
||||
respPk.Error = err.Error()
|
||||
} else {
|
||||
respPk.Data = data
|
||||
}
|
||||
s.waitForSend(ctx, respPk)
|
||||
}()
|
||||
}
|
||||
|
||||
func (s *RpcServer) grabAcks_nolock() []int64 {
|
||||
acks := s.AckList
|
||||
s.AckList = nil
|
||||
return acks
|
||||
}
|
||||
|
||||
func (s *RpcServer) handleStreamReq(pk *RpcPacket) {
|
||||
s.ackResp(pk.SeqNum)
|
||||
handler, ok := s.StreamCommandHandlers[pk.Command]
|
||||
if !ok {
|
||||
s.ackResp(pk.SeqNum)
|
||||
log.Printf("RpcServer.handleStreamReq() unknown command: %s", pk.Command)
|
||||
return
|
||||
}
|
||||
go func() {
|
||||
defer func() {
|
||||
r := recover()
|
||||
if r == nil {
|
||||
return
|
||||
}
|
||||
log.Printf("RpcServer.handleStreamReq(%q) panic: %v", pk.Command, r)
|
||||
debug.PrintStack()
|
||||
respPk := &RpcPacket{
|
||||
Command: pk.Command,
|
||||
RpcId: pk.RpcId,
|
||||
RpcType: RpcType_Resp,
|
||||
SeqNum: s.NextSeqNum.Add(1),
|
||||
RespDone: true,
|
||||
Error: fmt.Sprintf("panic: %v", r),
|
||||
}
|
||||
s.waitForSend(context.Background(), respPk)
|
||||
}()
|
||||
ctx, cancelFn := makeContextFromTimeout(pk.Timeout)
|
||||
defer cancelFn()
|
||||
err := handler(ctx, s, pk.Command, pk.Data)
|
||||
if err != nil {
|
||||
respPk := &RpcPacket{
|
||||
Command: pk.Command,
|
||||
RpcId: pk.RpcId,
|
||||
RpcType: RpcType_Resp,
|
||||
SeqNum: s.NextSeqNum.Add(1),
|
||||
RespDone: true,
|
||||
Error: err.Error(),
|
||||
}
|
||||
s.waitForSend(ctx, respPk)
|
||||
return
|
||||
}
|
||||
// check if RespDone has been set, if not, send it here
|
||||
}()
|
||||
}
|
Loading…
Reference in New Issue
Block a user