mirror of
https://github.com/wavetermdev/waveterm.git
synced 2025-01-02 18:39:05 +01:00
wsh rpc client
This commit is contained in:
parent
2472deb379
commit
45f20bb5c3
@ -8,8 +8,6 @@ import (
|
|||||||
"os"
|
"os"
|
||||||
)
|
)
|
||||||
|
|
||||||
const WaveOSC = "23198"
|
|
||||||
|
|
||||||
func main() {
|
func main() {
|
||||||
barr, err := os.ReadFile("/Users/mike/Downloads/2.png")
|
barr, err := os.ReadFile("/Users/mike/Downloads/2.png")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
264
pkg/wshrpc/rpc_client.go
Normal file
264
pkg/wshrpc/rpc_client.go
Normal file
@ -0,0 +1,264 @@
|
|||||||
|
// Copyright 2024, Command Line Inc.
|
||||||
|
// SPDX-License-Identifier: Apache-2.0
|
||||||
|
|
||||||
|
package wshprc
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"log"
|
||||||
|
"runtime/debug"
|
||||||
|
"sync"
|
||||||
|
"sync/atomic"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/google/uuid"
|
||||||
|
)
|
||||||
|
|
||||||
|
// there is a single go-routine that reads from RecvCh
|
||||||
|
type RpcClient struct {
|
||||||
|
CVar *sync.Cond
|
||||||
|
NextSeqNum *atomic.Int64
|
||||||
|
ReqPacketsInFlight map[int64]string // seqnum -> rpcId
|
||||||
|
AckList []int64
|
||||||
|
RpcReqs map[string]*RpcInfo
|
||||||
|
SendCh chan *RpcPacket
|
||||||
|
RecvCh chan *RpcPacket
|
||||||
|
}
|
||||||
|
|
||||||
|
type RpcInfo struct {
|
||||||
|
CloseSync *sync.Once
|
||||||
|
RpcId string
|
||||||
|
ReqPacketsInFlight map[int64]bool // seqnum -> bool
|
||||||
|
RespCh chan *RpcPacket
|
||||||
|
}
|
||||||
|
|
||||||
|
func MakeRpcClient(sendCh chan *RpcPacket, recvCh chan *RpcPacket) *RpcClient {
|
||||||
|
if cap(sendCh) < MaxInFlightPackets {
|
||||||
|
panic(fmt.Errorf("sendCh buffer size must be at least MaxInFlightPackets(%d)", MaxInFlightPackets))
|
||||||
|
}
|
||||||
|
rtn := &RpcClient{
|
||||||
|
CVar: sync.NewCond(&sync.Mutex{}),
|
||||||
|
NextSeqNum: &atomic.Int64{},
|
||||||
|
ReqPacketsInFlight: make(map[int64]string),
|
||||||
|
AckList: nil,
|
||||||
|
RpcReqs: make(map[string]*RpcInfo),
|
||||||
|
SendCh: sendCh,
|
||||||
|
RecvCh: recvCh,
|
||||||
|
}
|
||||||
|
go rtn.runRecvLoop()
|
||||||
|
return rtn
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *RpcClient) runRecvLoop() {
|
||||||
|
defer func() {
|
||||||
|
if r := recover(); r != nil {
|
||||||
|
log.Printf("RpcClient.runRecvLoop() panic: %v", r)
|
||||||
|
debug.PrintStack()
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
for pk := range c.RecvCh {
|
||||||
|
if pk.RpcType == RpcType_Resp {
|
||||||
|
c.handleResp(pk)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
log.Printf("RpcClient.runRecvLoop() bad packet type: %v", pk)
|
||||||
|
}
|
||||||
|
log.Printf("RpcClient.runRecvLoop() normal exit")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *RpcClient) getRpcInfo(rpcId string) *RpcInfo {
|
||||||
|
c.CVar.L.Lock()
|
||||||
|
defer c.CVar.L.Unlock()
|
||||||
|
return c.RpcReqs[rpcId]
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *RpcClient) handleResp(pk *RpcPacket) {
|
||||||
|
c.handleAcks(pk.Acks)
|
||||||
|
if pk.RpcId == "" {
|
||||||
|
c.ackResp(pk.SeqNum)
|
||||||
|
log.Printf("RpcClient.handleResp() missing rpcId: %v", pk)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
rpcInfo := c.getRpcInfo(pk.RpcId)
|
||||||
|
if rpcInfo == nil {
|
||||||
|
c.ackResp(pk.SeqNum)
|
||||||
|
log.Printf("RpcClient.handleResp() unknown rpcId: %v", pk)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
select {
|
||||||
|
case rpcInfo.RespCh <- pk:
|
||||||
|
default:
|
||||||
|
log.Printf("RpcClient.handleResp() respCh full, dropping packet")
|
||||||
|
}
|
||||||
|
if pk.RespDone {
|
||||||
|
c.removeReqInfo(pk.RpcId, false)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *RpcClient) grabAcks() []int64 {
|
||||||
|
c.CVar.L.Lock()
|
||||||
|
defer c.CVar.L.Unlock()
|
||||||
|
acks := c.AckList
|
||||||
|
c.AckList = nil
|
||||||
|
return acks
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *RpcClient) ackResp(seqNum int64) {
|
||||||
|
if seqNum == 0 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
c.CVar.L.Lock()
|
||||||
|
defer c.CVar.L.Unlock()
|
||||||
|
c.AckList = append(c.AckList, seqNum)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *RpcClient) waitForReq(ctx context.Context, req *RpcPacket) (*RpcInfo, error) {
|
||||||
|
c.CVar.L.Lock()
|
||||||
|
defer c.CVar.L.Unlock()
|
||||||
|
// issue with ctx timeout sync -- we need the cvar to be signaled fairly regularly so we can check ctx.Err()
|
||||||
|
for {
|
||||||
|
if ctx.Err() != nil {
|
||||||
|
return nil, ctx.Err()
|
||||||
|
}
|
||||||
|
if len(c.RpcReqs) >= MaxOpenRpcs {
|
||||||
|
c.CVar.Wait()
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if len(c.ReqPacketsInFlight) >= MaxOpenRpcs {
|
||||||
|
c.CVar.Wait()
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if rpcInfo, ok := c.RpcReqs[req.RpcId]; ok {
|
||||||
|
if len(rpcInfo.ReqPacketsInFlight) >= MaxUnackedPerRpc {
|
||||||
|
c.CVar.Wait()
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
}
|
||||||
|
break
|
||||||
|
}
|
||||||
|
select {
|
||||||
|
case c.SendCh <- req:
|
||||||
|
default:
|
||||||
|
return nil, errors.New("SendCh Full")
|
||||||
|
}
|
||||||
|
c.ReqPacketsInFlight[req.SeqNum] = req.RpcId
|
||||||
|
rpcInfo := c.RpcReqs[req.RpcId]
|
||||||
|
if rpcInfo == nil {
|
||||||
|
rpcInfo = &RpcInfo{
|
||||||
|
CloseSync: &sync.Once{},
|
||||||
|
RpcId: req.RpcId,
|
||||||
|
ReqPacketsInFlight: make(map[int64]bool),
|
||||||
|
RespCh: make(chan *RpcPacket, MaxUnackedPerRpc),
|
||||||
|
}
|
||||||
|
rpcInfo.ReqPacketsInFlight[req.SeqNum] = true
|
||||||
|
c.RpcReqs[req.RpcId] = rpcInfo
|
||||||
|
}
|
||||||
|
return rpcInfo, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *RpcClient) handleAcks(acks []int64) {
|
||||||
|
if len(acks) == 0 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
c.CVar.L.Lock()
|
||||||
|
defer c.CVar.L.Unlock()
|
||||||
|
for _, ack := range acks {
|
||||||
|
rpcId, ok := c.ReqPacketsInFlight[ack]
|
||||||
|
if !ok {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
rpcInfo := c.RpcReqs[rpcId]
|
||||||
|
if rpcInfo != nil {
|
||||||
|
delete(rpcInfo.ReqPacketsInFlight, ack)
|
||||||
|
}
|
||||||
|
delete(c.ReqPacketsInFlight, ack)
|
||||||
|
}
|
||||||
|
c.CVar.Broadcast()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *RpcClient) removeReqInfo(rpcId string, clearSend bool) {
|
||||||
|
c.CVar.L.Lock()
|
||||||
|
defer c.CVar.L.Unlock()
|
||||||
|
rpcInfo := c.RpcReqs[rpcId]
|
||||||
|
delete(c.RpcReqs, rpcId)
|
||||||
|
if rpcInfo != nil {
|
||||||
|
if clearSend {
|
||||||
|
// unblock the recv loop if it happens to be waiting
|
||||||
|
// because the delete has already happens, it will not be able to send again on the channel
|
||||||
|
select {
|
||||||
|
case <-rpcInfo.RespCh:
|
||||||
|
default:
|
||||||
|
}
|
||||||
|
}
|
||||||
|
rpcInfo.CloseSync.Do(func() {
|
||||||
|
close(rpcInfo.RespCh)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *RpcClient) SimpleReq(ctx context.Context, command string, data any) (any, error) {
|
||||||
|
rpcId := uuid.New().String()
|
||||||
|
seqNum := c.NextSeqNum.Add(1)
|
||||||
|
var timeoutInfo *TimeoutInfo
|
||||||
|
deadline, ok := ctx.Deadline()
|
||||||
|
if ok {
|
||||||
|
timeoutInfo = &TimeoutInfo{Deadline: deadline.UnixMilli()}
|
||||||
|
}
|
||||||
|
req := &RpcPacket{
|
||||||
|
Command: command,
|
||||||
|
RpcId: rpcId,
|
||||||
|
RpcType: RpcType_Req,
|
||||||
|
SeqNum: seqNum,
|
||||||
|
ReqDone: true,
|
||||||
|
Acks: c.grabAcks(),
|
||||||
|
Timeout: timeoutInfo,
|
||||||
|
Data: data,
|
||||||
|
}
|
||||||
|
rpcInfo, err := c.waitForReq(ctx, req)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
defer c.removeReqInfo(rpcId, true)
|
||||||
|
var rtnPacket *RpcPacket
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
return nil, ctx.Err()
|
||||||
|
case rtnPacket = <-rpcInfo.RespCh:
|
||||||
|
// fallthrough
|
||||||
|
}
|
||||||
|
if rtnPacket.Error != "" {
|
||||||
|
return nil, errors.New(rtnPacket.Error)
|
||||||
|
}
|
||||||
|
return rtnPacket.Data, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *RpcClient) StreamReq(ctx context.Context, command string, data any, respTimeout time.Duration) (chan *RpcPacket, error) {
|
||||||
|
rpcId := uuid.New().String()
|
||||||
|
seqNum := c.NextSeqNum.Add(1)
|
||||||
|
var timeoutInfo *TimeoutInfo = &TimeoutInfo{RespPacketTimeout: respTimeout.Milliseconds()}
|
||||||
|
deadline, ok := ctx.Deadline()
|
||||||
|
if ok {
|
||||||
|
timeoutInfo.Deadline = deadline.UnixMilli()
|
||||||
|
}
|
||||||
|
req := &RpcPacket{
|
||||||
|
Command: command,
|
||||||
|
RpcId: rpcId,
|
||||||
|
RpcType: RpcType_Req,
|
||||||
|
SeqNum: seqNum,
|
||||||
|
ReqDone: true,
|
||||||
|
Acks: c.grabAcks(),
|
||||||
|
Timeout: timeoutInfo,
|
||||||
|
Data: data,
|
||||||
|
}
|
||||||
|
rpcInfo, err := c.waitForReq(ctx, req)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return rpcInfo.RespCh, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *RpcClient) EndStreamReq(rpcId string) {
|
||||||
|
c.removeReqInfo(rpcId, true)
|
||||||
|
}
|
115
pkg/wshrpc/rpc_test.go
Normal file
115
pkg/wshrpc/rpc_test.go
Normal file
@ -0,0 +1,115 @@
|
|||||||
|
// Copyright 2024, Command Line Inc.
|
||||||
|
// SPDX-License-Identifier: Apache-2.0
|
||||||
|
|
||||||
|
package wshprc
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"sync"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestSimple(t *testing.T) {
|
||||||
|
sendCh := make(chan *RpcPacket, MaxInFlightPackets)
|
||||||
|
recvCh := make(chan *RpcPacket, MaxInFlightPackets)
|
||||||
|
client := MakeRpcClient(sendCh, recvCh)
|
||||||
|
ctx, cancelFn := context.WithTimeout(context.Background(), 2*time.Second)
|
||||||
|
defer cancelFn()
|
||||||
|
var wg sync.WaitGroup
|
||||||
|
wg.Add(2)
|
||||||
|
go func() {
|
||||||
|
defer wg.Done()
|
||||||
|
resp, err := client.SimpleReq(ctx, "test", "hello")
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("SimpleReq() failed: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if resp != "world" {
|
||||||
|
t.Errorf("SimpleReq() failed: expected 'world', got '%s'", resp)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
go func() {
|
||||||
|
defer wg.Done()
|
||||||
|
req := <-sendCh
|
||||||
|
if req.Command != "test" {
|
||||||
|
t.Errorf("expected 'test', got '%s'", req.Command)
|
||||||
|
}
|
||||||
|
if req.Data != "hello" {
|
||||||
|
t.Errorf("expected 'hello', got '%s'", req.Data)
|
||||||
|
}
|
||||||
|
resp := &RpcPacket{
|
||||||
|
Command: "test",
|
||||||
|
RpcId: req.RpcId,
|
||||||
|
RpcType: RpcType_Resp,
|
||||||
|
SeqNum: 1,
|
||||||
|
RespDone: true,
|
||||||
|
Acks: []int64{req.SeqNum},
|
||||||
|
Data: "world",
|
||||||
|
}
|
||||||
|
recvCh <- resp
|
||||||
|
}()
|
||||||
|
wg.Wait()
|
||||||
|
}
|
||||||
|
|
||||||
|
func makeRpcResp(req *RpcPacket, data any, seqNum int64, done bool) *RpcPacket {
|
||||||
|
return &RpcPacket{
|
||||||
|
Command: req.Command,
|
||||||
|
RpcId: req.RpcId,
|
||||||
|
RpcType: RpcType_Resp,
|
||||||
|
SeqNum: seqNum,
|
||||||
|
RespDone: done,
|
||||||
|
Data: data,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestStream(t *testing.T) {
|
||||||
|
sendCh := make(chan *RpcPacket, MaxInFlightPackets)
|
||||||
|
recvCh := make(chan *RpcPacket, MaxInFlightPackets)
|
||||||
|
client := MakeRpcClient(sendCh, recvCh)
|
||||||
|
ctx, cancelFn := context.WithTimeout(context.Background(), 2*time.Second)
|
||||||
|
defer cancelFn()
|
||||||
|
var wg sync.WaitGroup
|
||||||
|
wg.Add(2)
|
||||||
|
go func() {
|
||||||
|
defer wg.Done()
|
||||||
|
respCh, err := client.StreamReq(ctx, "test", "hello", 1000)
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("StreamReq() failed: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
var output []string
|
||||||
|
for resp := range respCh {
|
||||||
|
if resp.Error != "" {
|
||||||
|
t.Errorf("StreamReq() failed: %v", resp.Error)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
output = append(output, resp.Data.(string))
|
||||||
|
}
|
||||||
|
if len(output) != 3 {
|
||||||
|
t.Errorf("expected 3 responses, got %d (%v)", len(output), output)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if output[0] != "one" || output[1] != "two" || output[2] != "three" {
|
||||||
|
t.Errorf("expected 'one', 'two', 'three', got %v", output)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
go func() {
|
||||||
|
defer wg.Done()
|
||||||
|
req := <-sendCh
|
||||||
|
if req.Command != "test" {
|
||||||
|
t.Errorf("expected 'test', got '%s'", req.Command)
|
||||||
|
}
|
||||||
|
if req.Data != "hello" {
|
||||||
|
t.Errorf("expected 'hello', got '%s'", req.Data)
|
||||||
|
}
|
||||||
|
resp := makeRpcResp(req, "one", 1, false)
|
||||||
|
recvCh <- resp
|
||||||
|
resp = makeRpcResp(req, "two", 2, false)
|
||||||
|
recvCh <- resp
|
||||||
|
resp = makeRpcResp(req, "three", 3, true)
|
||||||
|
recvCh <- resp
|
||||||
|
}()
|
||||||
|
wg.Wait()
|
||||||
|
}
|
58
pkg/wshrpc/wshrpc.go
Normal file
58
pkg/wshrpc/wshrpc.go
Normal file
@ -0,0 +1,58 @@
|
|||||||
|
// Copyright 2024, Command Line Inc.
|
||||||
|
// SPDX-License-Identifier: Apache-2.0
|
||||||
|
|
||||||
|
package wshprc
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
MaxOpenRpcs = 10
|
||||||
|
MaxUnackedPerRpc = 10
|
||||||
|
MaxInFlightPackets = MaxOpenRpcs * MaxUnackedPerRpc
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
RpcType_Req = "req"
|
||||||
|
RpcType_Resp = "resp"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
CommandType_Ack = ":ack"
|
||||||
|
CommandType_Ping = ":ping"
|
||||||
|
CommandType_Cancel = ":cancel"
|
||||||
|
CommandType_Timeout = ":timeout"
|
||||||
|
)
|
||||||
|
|
||||||
|
var rpcClientContextKey = struct{}{}
|
||||||
|
|
||||||
|
type TimeoutInfo struct {
|
||||||
|
Deadline int64 `json:"deadline,omitempty"`
|
||||||
|
ReqPacketTimeout int64 `json:"reqpackettimeout,omitempty"` // for streaming requests
|
||||||
|
RespPacketTimeout int64 `json:"resppackettimeout,omitempty"` // for streaming responses
|
||||||
|
}
|
||||||
|
|
||||||
|
type RpcPacket struct {
|
||||||
|
Command string `json:"command"`
|
||||||
|
RpcId string `json:"rpcid"`
|
||||||
|
RpcType string `json:"rpctype"`
|
||||||
|
SeqNum int64 `json:"seqnum"`
|
||||||
|
ReqDone bool `json:"reqdone"`
|
||||||
|
RespDone bool `json:"resdone"`
|
||||||
|
Acks []int64 `json:"acks,omitempty"` // seqnums acked
|
||||||
|
Timeout *TimeoutInfo `json:"timeout,omitempty"` // for initial request only
|
||||||
|
Data any `json:"data"` // json data for command
|
||||||
|
Error string `json:"error,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func GetRpcClient(ctx context.Context) *RpcClient {
|
||||||
|
if ctx == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
val := ctx.Value(rpcClientContextKey)
|
||||||
|
if val == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return val.(*RpcClient)
|
||||||
|
}
|
@ -9,6 +9,7 @@ const (
|
|||||||
CommandSetView = "setview"
|
CommandSetView = "setview"
|
||||||
CommandSetMeta = "setmeta"
|
CommandSetMeta = "setmeta"
|
||||||
CommandBlockFileAppend = "blockfile:append"
|
CommandBlockFileAppend = "blockfile:append"
|
||||||
|
CommandStreamFile = "streamfile"
|
||||||
)
|
)
|
||||||
|
|
||||||
var CommandToTypeMap = map[string]reflect.Type{
|
var CommandToTypeMap = map[string]reflect.Type{
|
||||||
@ -23,6 +24,8 @@ type Command interface {
|
|||||||
// for unmarshalling
|
// for unmarshalling
|
||||||
type baseCommand struct {
|
type baseCommand struct {
|
||||||
Command string `json:"command"`
|
Command string `json:"command"`
|
||||||
|
RpcID string `json:"rpcid"`
|
||||||
|
RpcType string `json:"rpctype"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type SetViewCommand struct {
|
type SetViewCommand struct {
|
||||||
@ -52,3 +55,12 @@ type BlockFileAppendCommand struct {
|
|||||||
func (bfac *BlockFileAppendCommand) GetCommand() string {
|
func (bfac *BlockFileAppendCommand) GetCommand() string {
|
||||||
return CommandBlockFileAppend
|
return CommandBlockFileAppend
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type StreamFileCommand struct {
|
||||||
|
Command string `json:"command"`
|
||||||
|
FileName string `json:"filename"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *StreamFileCommand) GetCommand() string {
|
||||||
|
return CommandStreamFile
|
||||||
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user