mirror of
https://github.com/wavetermdev/waveterm.git
synced 2024-12-21 16:38:23 +01:00
wsh rpc client
This commit is contained in:
parent
2472deb379
commit
45f20bb5c3
@ -8,8 +8,6 @@ import (
|
||||
"os"
|
||||
)
|
||||
|
||||
const WaveOSC = "23198"
|
||||
|
||||
func main() {
|
||||
barr, err := os.ReadFile("/Users/mike/Downloads/2.png")
|
||||
if err != nil {
|
||||
|
264
pkg/wshrpc/rpc_client.go
Normal file
264
pkg/wshrpc/rpc_client.go
Normal file
@ -0,0 +1,264 @@
|
||||
// Copyright 2024, Command Line Inc.
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
package wshprc
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"log"
|
||||
"runtime/debug"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
// there is a single go-routine that reads from RecvCh
|
||||
type RpcClient struct {
|
||||
CVar *sync.Cond
|
||||
NextSeqNum *atomic.Int64
|
||||
ReqPacketsInFlight map[int64]string // seqnum -> rpcId
|
||||
AckList []int64
|
||||
RpcReqs map[string]*RpcInfo
|
||||
SendCh chan *RpcPacket
|
||||
RecvCh chan *RpcPacket
|
||||
}
|
||||
|
||||
type RpcInfo struct {
|
||||
CloseSync *sync.Once
|
||||
RpcId string
|
||||
ReqPacketsInFlight map[int64]bool // seqnum -> bool
|
||||
RespCh chan *RpcPacket
|
||||
}
|
||||
|
||||
func MakeRpcClient(sendCh chan *RpcPacket, recvCh chan *RpcPacket) *RpcClient {
|
||||
if cap(sendCh) < MaxInFlightPackets {
|
||||
panic(fmt.Errorf("sendCh buffer size must be at least MaxInFlightPackets(%d)", MaxInFlightPackets))
|
||||
}
|
||||
rtn := &RpcClient{
|
||||
CVar: sync.NewCond(&sync.Mutex{}),
|
||||
NextSeqNum: &atomic.Int64{},
|
||||
ReqPacketsInFlight: make(map[int64]string),
|
||||
AckList: nil,
|
||||
RpcReqs: make(map[string]*RpcInfo),
|
||||
SendCh: sendCh,
|
||||
RecvCh: recvCh,
|
||||
}
|
||||
go rtn.runRecvLoop()
|
||||
return rtn
|
||||
}
|
||||
|
||||
func (c *RpcClient) runRecvLoop() {
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
log.Printf("RpcClient.runRecvLoop() panic: %v", r)
|
||||
debug.PrintStack()
|
||||
}
|
||||
}()
|
||||
for pk := range c.RecvCh {
|
||||
if pk.RpcType == RpcType_Resp {
|
||||
c.handleResp(pk)
|
||||
continue
|
||||
}
|
||||
log.Printf("RpcClient.runRecvLoop() bad packet type: %v", pk)
|
||||
}
|
||||
log.Printf("RpcClient.runRecvLoop() normal exit")
|
||||
}
|
||||
|
||||
func (c *RpcClient) getRpcInfo(rpcId string) *RpcInfo {
|
||||
c.CVar.L.Lock()
|
||||
defer c.CVar.L.Unlock()
|
||||
return c.RpcReqs[rpcId]
|
||||
}
|
||||
|
||||
func (c *RpcClient) handleResp(pk *RpcPacket) {
|
||||
c.handleAcks(pk.Acks)
|
||||
if pk.RpcId == "" {
|
||||
c.ackResp(pk.SeqNum)
|
||||
log.Printf("RpcClient.handleResp() missing rpcId: %v", pk)
|
||||
return
|
||||
}
|
||||
rpcInfo := c.getRpcInfo(pk.RpcId)
|
||||
if rpcInfo == nil {
|
||||
c.ackResp(pk.SeqNum)
|
||||
log.Printf("RpcClient.handleResp() unknown rpcId: %v", pk)
|
||||
return
|
||||
}
|
||||
select {
|
||||
case rpcInfo.RespCh <- pk:
|
||||
default:
|
||||
log.Printf("RpcClient.handleResp() respCh full, dropping packet")
|
||||
}
|
||||
if pk.RespDone {
|
||||
c.removeReqInfo(pk.RpcId, false)
|
||||
}
|
||||
}
|
||||
|
||||
func (c *RpcClient) grabAcks() []int64 {
|
||||
c.CVar.L.Lock()
|
||||
defer c.CVar.L.Unlock()
|
||||
acks := c.AckList
|
||||
c.AckList = nil
|
||||
return acks
|
||||
}
|
||||
|
||||
func (c *RpcClient) ackResp(seqNum int64) {
|
||||
if seqNum == 0 {
|
||||
return
|
||||
}
|
||||
c.CVar.L.Lock()
|
||||
defer c.CVar.L.Unlock()
|
||||
c.AckList = append(c.AckList, seqNum)
|
||||
}
|
||||
|
||||
func (c *RpcClient) waitForReq(ctx context.Context, req *RpcPacket) (*RpcInfo, error) {
|
||||
c.CVar.L.Lock()
|
||||
defer c.CVar.L.Unlock()
|
||||
// issue with ctx timeout sync -- we need the cvar to be signaled fairly regularly so we can check ctx.Err()
|
||||
for {
|
||||
if ctx.Err() != nil {
|
||||
return nil, ctx.Err()
|
||||
}
|
||||
if len(c.RpcReqs) >= MaxOpenRpcs {
|
||||
c.CVar.Wait()
|
||||
continue
|
||||
}
|
||||
if len(c.ReqPacketsInFlight) >= MaxOpenRpcs {
|
||||
c.CVar.Wait()
|
||||
continue
|
||||
}
|
||||
if rpcInfo, ok := c.RpcReqs[req.RpcId]; ok {
|
||||
if len(rpcInfo.ReqPacketsInFlight) >= MaxUnackedPerRpc {
|
||||
c.CVar.Wait()
|
||||
continue
|
||||
}
|
||||
}
|
||||
break
|
||||
}
|
||||
select {
|
||||
case c.SendCh <- req:
|
||||
default:
|
||||
return nil, errors.New("SendCh Full")
|
||||
}
|
||||
c.ReqPacketsInFlight[req.SeqNum] = req.RpcId
|
||||
rpcInfo := c.RpcReqs[req.RpcId]
|
||||
if rpcInfo == nil {
|
||||
rpcInfo = &RpcInfo{
|
||||
CloseSync: &sync.Once{},
|
||||
RpcId: req.RpcId,
|
||||
ReqPacketsInFlight: make(map[int64]bool),
|
||||
RespCh: make(chan *RpcPacket, MaxUnackedPerRpc),
|
||||
}
|
||||
rpcInfo.ReqPacketsInFlight[req.SeqNum] = true
|
||||
c.RpcReqs[req.RpcId] = rpcInfo
|
||||
}
|
||||
return rpcInfo, nil
|
||||
}
|
||||
|
||||
func (c *RpcClient) handleAcks(acks []int64) {
|
||||
if len(acks) == 0 {
|
||||
return
|
||||
}
|
||||
c.CVar.L.Lock()
|
||||
defer c.CVar.L.Unlock()
|
||||
for _, ack := range acks {
|
||||
rpcId, ok := c.ReqPacketsInFlight[ack]
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
rpcInfo := c.RpcReqs[rpcId]
|
||||
if rpcInfo != nil {
|
||||
delete(rpcInfo.ReqPacketsInFlight, ack)
|
||||
}
|
||||
delete(c.ReqPacketsInFlight, ack)
|
||||
}
|
||||
c.CVar.Broadcast()
|
||||
}
|
||||
|
||||
func (c *RpcClient) removeReqInfo(rpcId string, clearSend bool) {
|
||||
c.CVar.L.Lock()
|
||||
defer c.CVar.L.Unlock()
|
||||
rpcInfo := c.RpcReqs[rpcId]
|
||||
delete(c.RpcReqs, rpcId)
|
||||
if rpcInfo != nil {
|
||||
if clearSend {
|
||||
// unblock the recv loop if it happens to be waiting
|
||||
// because the delete has already happens, it will not be able to send again on the channel
|
||||
select {
|
||||
case <-rpcInfo.RespCh:
|
||||
default:
|
||||
}
|
||||
}
|
||||
rpcInfo.CloseSync.Do(func() {
|
||||
close(rpcInfo.RespCh)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func (c *RpcClient) SimpleReq(ctx context.Context, command string, data any) (any, error) {
|
||||
rpcId := uuid.New().String()
|
||||
seqNum := c.NextSeqNum.Add(1)
|
||||
var timeoutInfo *TimeoutInfo
|
||||
deadline, ok := ctx.Deadline()
|
||||
if ok {
|
||||
timeoutInfo = &TimeoutInfo{Deadline: deadline.UnixMilli()}
|
||||
}
|
||||
req := &RpcPacket{
|
||||
Command: command,
|
||||
RpcId: rpcId,
|
||||
RpcType: RpcType_Req,
|
||||
SeqNum: seqNum,
|
||||
ReqDone: true,
|
||||
Acks: c.grabAcks(),
|
||||
Timeout: timeoutInfo,
|
||||
Data: data,
|
||||
}
|
||||
rpcInfo, err := c.waitForReq(ctx, req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer c.removeReqInfo(rpcId, true)
|
||||
var rtnPacket *RpcPacket
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return nil, ctx.Err()
|
||||
case rtnPacket = <-rpcInfo.RespCh:
|
||||
// fallthrough
|
||||
}
|
||||
if rtnPacket.Error != "" {
|
||||
return nil, errors.New(rtnPacket.Error)
|
||||
}
|
||||
return rtnPacket.Data, nil
|
||||
}
|
||||
|
||||
func (c *RpcClient) StreamReq(ctx context.Context, command string, data any, respTimeout time.Duration) (chan *RpcPacket, error) {
|
||||
rpcId := uuid.New().String()
|
||||
seqNum := c.NextSeqNum.Add(1)
|
||||
var timeoutInfo *TimeoutInfo = &TimeoutInfo{RespPacketTimeout: respTimeout.Milliseconds()}
|
||||
deadline, ok := ctx.Deadline()
|
||||
if ok {
|
||||
timeoutInfo.Deadline = deadline.UnixMilli()
|
||||
}
|
||||
req := &RpcPacket{
|
||||
Command: command,
|
||||
RpcId: rpcId,
|
||||
RpcType: RpcType_Req,
|
||||
SeqNum: seqNum,
|
||||
ReqDone: true,
|
||||
Acks: c.grabAcks(),
|
||||
Timeout: timeoutInfo,
|
||||
Data: data,
|
||||
}
|
||||
rpcInfo, err := c.waitForReq(ctx, req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return rpcInfo.RespCh, nil
|
||||
}
|
||||
|
||||
func (c *RpcClient) EndStreamReq(rpcId string) {
|
||||
c.removeReqInfo(rpcId, true)
|
||||
}
|
115
pkg/wshrpc/rpc_test.go
Normal file
115
pkg/wshrpc/rpc_test.go
Normal file
@ -0,0 +1,115 @@
|
||||
// Copyright 2024, Command Line Inc.
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
package wshprc
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestSimple(t *testing.T) {
|
||||
sendCh := make(chan *RpcPacket, MaxInFlightPackets)
|
||||
recvCh := make(chan *RpcPacket, MaxInFlightPackets)
|
||||
client := MakeRpcClient(sendCh, recvCh)
|
||||
ctx, cancelFn := context.WithTimeout(context.Background(), 2*time.Second)
|
||||
defer cancelFn()
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(2)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
resp, err := client.SimpleReq(ctx, "test", "hello")
|
||||
if err != nil {
|
||||
t.Errorf("SimpleReq() failed: %v", err)
|
||||
return
|
||||
}
|
||||
if resp != "world" {
|
||||
t.Errorf("SimpleReq() failed: expected 'world', got '%s'", resp)
|
||||
}
|
||||
}()
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
req := <-sendCh
|
||||
if req.Command != "test" {
|
||||
t.Errorf("expected 'test', got '%s'", req.Command)
|
||||
}
|
||||
if req.Data != "hello" {
|
||||
t.Errorf("expected 'hello', got '%s'", req.Data)
|
||||
}
|
||||
resp := &RpcPacket{
|
||||
Command: "test",
|
||||
RpcId: req.RpcId,
|
||||
RpcType: RpcType_Resp,
|
||||
SeqNum: 1,
|
||||
RespDone: true,
|
||||
Acks: []int64{req.SeqNum},
|
||||
Data: "world",
|
||||
}
|
||||
recvCh <- resp
|
||||
}()
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
func makeRpcResp(req *RpcPacket, data any, seqNum int64, done bool) *RpcPacket {
|
||||
return &RpcPacket{
|
||||
Command: req.Command,
|
||||
RpcId: req.RpcId,
|
||||
RpcType: RpcType_Resp,
|
||||
SeqNum: seqNum,
|
||||
RespDone: done,
|
||||
Data: data,
|
||||
}
|
||||
}
|
||||
|
||||
func TestStream(t *testing.T) {
|
||||
sendCh := make(chan *RpcPacket, MaxInFlightPackets)
|
||||
recvCh := make(chan *RpcPacket, MaxInFlightPackets)
|
||||
client := MakeRpcClient(sendCh, recvCh)
|
||||
ctx, cancelFn := context.WithTimeout(context.Background(), 2*time.Second)
|
||||
defer cancelFn()
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(2)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
respCh, err := client.StreamReq(ctx, "test", "hello", 1000)
|
||||
if err != nil {
|
||||
t.Errorf("StreamReq() failed: %v", err)
|
||||
return
|
||||
}
|
||||
var output []string
|
||||
for resp := range respCh {
|
||||
if resp.Error != "" {
|
||||
t.Errorf("StreamReq() failed: %v", resp.Error)
|
||||
return
|
||||
}
|
||||
output = append(output, resp.Data.(string))
|
||||
}
|
||||
if len(output) != 3 {
|
||||
t.Errorf("expected 3 responses, got %d (%v)", len(output), output)
|
||||
return
|
||||
}
|
||||
if output[0] != "one" || output[1] != "two" || output[2] != "three" {
|
||||
t.Errorf("expected 'one', 'two', 'three', got %v", output)
|
||||
return
|
||||
}
|
||||
}()
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
req := <-sendCh
|
||||
if req.Command != "test" {
|
||||
t.Errorf("expected 'test', got '%s'", req.Command)
|
||||
}
|
||||
if req.Data != "hello" {
|
||||
t.Errorf("expected 'hello', got '%s'", req.Data)
|
||||
}
|
||||
resp := makeRpcResp(req, "one", 1, false)
|
||||
recvCh <- resp
|
||||
resp = makeRpcResp(req, "two", 2, false)
|
||||
recvCh <- resp
|
||||
resp = makeRpcResp(req, "three", 3, true)
|
||||
recvCh <- resp
|
||||
}()
|
||||
wg.Wait()
|
||||
}
|
58
pkg/wshrpc/wshrpc.go
Normal file
58
pkg/wshrpc/wshrpc.go
Normal file
@ -0,0 +1,58 @@
|
||||
// Copyright 2024, Command Line Inc.
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
package wshprc
|
||||
|
||||
import (
|
||||
"context"
|
||||
)
|
||||
|
||||
const (
|
||||
MaxOpenRpcs = 10
|
||||
MaxUnackedPerRpc = 10
|
||||
MaxInFlightPackets = MaxOpenRpcs * MaxUnackedPerRpc
|
||||
)
|
||||
|
||||
const (
|
||||
RpcType_Req = "req"
|
||||
RpcType_Resp = "resp"
|
||||
)
|
||||
|
||||
const (
|
||||
CommandType_Ack = ":ack"
|
||||
CommandType_Ping = ":ping"
|
||||
CommandType_Cancel = ":cancel"
|
||||
CommandType_Timeout = ":timeout"
|
||||
)
|
||||
|
||||
var rpcClientContextKey = struct{}{}
|
||||
|
||||
type TimeoutInfo struct {
|
||||
Deadline int64 `json:"deadline,omitempty"`
|
||||
ReqPacketTimeout int64 `json:"reqpackettimeout,omitempty"` // for streaming requests
|
||||
RespPacketTimeout int64 `json:"resppackettimeout,omitempty"` // for streaming responses
|
||||
}
|
||||
|
||||
type RpcPacket struct {
|
||||
Command string `json:"command"`
|
||||
RpcId string `json:"rpcid"`
|
||||
RpcType string `json:"rpctype"`
|
||||
SeqNum int64 `json:"seqnum"`
|
||||
ReqDone bool `json:"reqdone"`
|
||||
RespDone bool `json:"resdone"`
|
||||
Acks []int64 `json:"acks,omitempty"` // seqnums acked
|
||||
Timeout *TimeoutInfo `json:"timeout,omitempty"` // for initial request only
|
||||
Data any `json:"data"` // json data for command
|
||||
Error string `json:"error,omitempty"`
|
||||
}
|
||||
|
||||
func GetRpcClient(ctx context.Context) *RpcClient {
|
||||
if ctx == nil {
|
||||
return nil
|
||||
}
|
||||
val := ctx.Value(rpcClientContextKey)
|
||||
if val == nil {
|
||||
return nil
|
||||
}
|
||||
return val.(*RpcClient)
|
||||
}
|
@ -9,6 +9,7 @@ const (
|
||||
CommandSetView = "setview"
|
||||
CommandSetMeta = "setmeta"
|
||||
CommandBlockFileAppend = "blockfile:append"
|
||||
CommandStreamFile = "streamfile"
|
||||
)
|
||||
|
||||
var CommandToTypeMap = map[string]reflect.Type{
|
||||
@ -23,6 +24,8 @@ type Command interface {
|
||||
// for unmarshalling
|
||||
type baseCommand struct {
|
||||
Command string `json:"command"`
|
||||
RpcID string `json:"rpcid"`
|
||||
RpcType string `json:"rpctype"`
|
||||
}
|
||||
|
||||
type SetViewCommand struct {
|
||||
@ -52,3 +55,12 @@ type BlockFileAppendCommand struct {
|
||||
func (bfac *BlockFileAppendCommand) GetCommand() string {
|
||||
return CommandBlockFileAppend
|
||||
}
|
||||
|
||||
type StreamFileCommand struct {
|
||||
Command string `json:"command"`
|
||||
FileName string `json:"filename"`
|
||||
}
|
||||
|
||||
func (c *StreamFileCommand) GetCommand() string {
|
||||
return CommandStreamFile
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user