wsh rpc client

This commit is contained in:
sawka 2024-05-29 23:17:23 -07:00
parent 2472deb379
commit 45f20bb5c3
5 changed files with 449 additions and 2 deletions

View File

@ -8,8 +8,6 @@ import (
"os"
)
const WaveOSC = "23198"
func main() {
barr, err := os.ReadFile("/Users/mike/Downloads/2.png")
if err != nil {

264
pkg/wshrpc/rpc_client.go Normal file
View File

@ -0,0 +1,264 @@
// Copyright 2024, Command Line Inc.
// SPDX-License-Identifier: Apache-2.0
package wshprc
import (
"context"
"errors"
"fmt"
"log"
"runtime/debug"
"sync"
"sync/atomic"
"time"
"github.com/google/uuid"
)
// there is a single go-routine that reads from RecvCh
type RpcClient struct {
CVar *sync.Cond
NextSeqNum *atomic.Int64
ReqPacketsInFlight map[int64]string // seqnum -> rpcId
AckList []int64
RpcReqs map[string]*RpcInfo
SendCh chan *RpcPacket
RecvCh chan *RpcPacket
}
type RpcInfo struct {
CloseSync *sync.Once
RpcId string
ReqPacketsInFlight map[int64]bool // seqnum -> bool
RespCh chan *RpcPacket
}
func MakeRpcClient(sendCh chan *RpcPacket, recvCh chan *RpcPacket) *RpcClient {
if cap(sendCh) < MaxInFlightPackets {
panic(fmt.Errorf("sendCh buffer size must be at least MaxInFlightPackets(%d)", MaxInFlightPackets))
}
rtn := &RpcClient{
CVar: sync.NewCond(&sync.Mutex{}),
NextSeqNum: &atomic.Int64{},
ReqPacketsInFlight: make(map[int64]string),
AckList: nil,
RpcReqs: make(map[string]*RpcInfo),
SendCh: sendCh,
RecvCh: recvCh,
}
go rtn.runRecvLoop()
return rtn
}
func (c *RpcClient) runRecvLoop() {
defer func() {
if r := recover(); r != nil {
log.Printf("RpcClient.runRecvLoop() panic: %v", r)
debug.PrintStack()
}
}()
for pk := range c.RecvCh {
if pk.RpcType == RpcType_Resp {
c.handleResp(pk)
continue
}
log.Printf("RpcClient.runRecvLoop() bad packet type: %v", pk)
}
log.Printf("RpcClient.runRecvLoop() normal exit")
}
func (c *RpcClient) getRpcInfo(rpcId string) *RpcInfo {
c.CVar.L.Lock()
defer c.CVar.L.Unlock()
return c.RpcReqs[rpcId]
}
func (c *RpcClient) handleResp(pk *RpcPacket) {
c.handleAcks(pk.Acks)
if pk.RpcId == "" {
c.ackResp(pk.SeqNum)
log.Printf("RpcClient.handleResp() missing rpcId: %v", pk)
return
}
rpcInfo := c.getRpcInfo(pk.RpcId)
if rpcInfo == nil {
c.ackResp(pk.SeqNum)
log.Printf("RpcClient.handleResp() unknown rpcId: %v", pk)
return
}
select {
case rpcInfo.RespCh <- pk:
default:
log.Printf("RpcClient.handleResp() respCh full, dropping packet")
}
if pk.RespDone {
c.removeReqInfo(pk.RpcId, false)
}
}
func (c *RpcClient) grabAcks() []int64 {
c.CVar.L.Lock()
defer c.CVar.L.Unlock()
acks := c.AckList
c.AckList = nil
return acks
}
func (c *RpcClient) ackResp(seqNum int64) {
if seqNum == 0 {
return
}
c.CVar.L.Lock()
defer c.CVar.L.Unlock()
c.AckList = append(c.AckList, seqNum)
}
func (c *RpcClient) waitForReq(ctx context.Context, req *RpcPacket) (*RpcInfo, error) {
c.CVar.L.Lock()
defer c.CVar.L.Unlock()
// issue with ctx timeout sync -- we need the cvar to be signaled fairly regularly so we can check ctx.Err()
for {
if ctx.Err() != nil {
return nil, ctx.Err()
}
if len(c.RpcReqs) >= MaxOpenRpcs {
c.CVar.Wait()
continue
}
if len(c.ReqPacketsInFlight) >= MaxOpenRpcs {
c.CVar.Wait()
continue
}
if rpcInfo, ok := c.RpcReqs[req.RpcId]; ok {
if len(rpcInfo.ReqPacketsInFlight) >= MaxUnackedPerRpc {
c.CVar.Wait()
continue
}
}
break
}
select {
case c.SendCh <- req:
default:
return nil, errors.New("SendCh Full")
}
c.ReqPacketsInFlight[req.SeqNum] = req.RpcId
rpcInfo := c.RpcReqs[req.RpcId]
if rpcInfo == nil {
rpcInfo = &RpcInfo{
CloseSync: &sync.Once{},
RpcId: req.RpcId,
ReqPacketsInFlight: make(map[int64]bool),
RespCh: make(chan *RpcPacket, MaxUnackedPerRpc),
}
rpcInfo.ReqPacketsInFlight[req.SeqNum] = true
c.RpcReqs[req.RpcId] = rpcInfo
}
return rpcInfo, nil
}
func (c *RpcClient) handleAcks(acks []int64) {
if len(acks) == 0 {
return
}
c.CVar.L.Lock()
defer c.CVar.L.Unlock()
for _, ack := range acks {
rpcId, ok := c.ReqPacketsInFlight[ack]
if !ok {
continue
}
rpcInfo := c.RpcReqs[rpcId]
if rpcInfo != nil {
delete(rpcInfo.ReqPacketsInFlight, ack)
}
delete(c.ReqPacketsInFlight, ack)
}
c.CVar.Broadcast()
}
func (c *RpcClient) removeReqInfo(rpcId string, clearSend bool) {
c.CVar.L.Lock()
defer c.CVar.L.Unlock()
rpcInfo := c.RpcReqs[rpcId]
delete(c.RpcReqs, rpcId)
if rpcInfo != nil {
if clearSend {
// unblock the recv loop if it happens to be waiting
// because the delete has already happens, it will not be able to send again on the channel
select {
case <-rpcInfo.RespCh:
default:
}
}
rpcInfo.CloseSync.Do(func() {
close(rpcInfo.RespCh)
})
}
}
func (c *RpcClient) SimpleReq(ctx context.Context, command string, data any) (any, error) {
rpcId := uuid.New().String()
seqNum := c.NextSeqNum.Add(1)
var timeoutInfo *TimeoutInfo
deadline, ok := ctx.Deadline()
if ok {
timeoutInfo = &TimeoutInfo{Deadline: deadline.UnixMilli()}
}
req := &RpcPacket{
Command: command,
RpcId: rpcId,
RpcType: RpcType_Req,
SeqNum: seqNum,
ReqDone: true,
Acks: c.grabAcks(),
Timeout: timeoutInfo,
Data: data,
}
rpcInfo, err := c.waitForReq(ctx, req)
if err != nil {
return nil, err
}
defer c.removeReqInfo(rpcId, true)
var rtnPacket *RpcPacket
select {
case <-ctx.Done():
return nil, ctx.Err()
case rtnPacket = <-rpcInfo.RespCh:
// fallthrough
}
if rtnPacket.Error != "" {
return nil, errors.New(rtnPacket.Error)
}
return rtnPacket.Data, nil
}
func (c *RpcClient) StreamReq(ctx context.Context, command string, data any, respTimeout time.Duration) (chan *RpcPacket, error) {
rpcId := uuid.New().String()
seqNum := c.NextSeqNum.Add(1)
var timeoutInfo *TimeoutInfo = &TimeoutInfo{RespPacketTimeout: respTimeout.Milliseconds()}
deadline, ok := ctx.Deadline()
if ok {
timeoutInfo.Deadline = deadline.UnixMilli()
}
req := &RpcPacket{
Command: command,
RpcId: rpcId,
RpcType: RpcType_Req,
SeqNum: seqNum,
ReqDone: true,
Acks: c.grabAcks(),
Timeout: timeoutInfo,
Data: data,
}
rpcInfo, err := c.waitForReq(ctx, req)
if err != nil {
return nil, err
}
return rpcInfo.RespCh, nil
}
func (c *RpcClient) EndStreamReq(rpcId string) {
c.removeReqInfo(rpcId, true)
}

115
pkg/wshrpc/rpc_test.go Normal file
View File

@ -0,0 +1,115 @@
// Copyright 2024, Command Line Inc.
// SPDX-License-Identifier: Apache-2.0
package wshprc
import (
"context"
"sync"
"testing"
"time"
)
func TestSimple(t *testing.T) {
sendCh := make(chan *RpcPacket, MaxInFlightPackets)
recvCh := make(chan *RpcPacket, MaxInFlightPackets)
client := MakeRpcClient(sendCh, recvCh)
ctx, cancelFn := context.WithTimeout(context.Background(), 2*time.Second)
defer cancelFn()
var wg sync.WaitGroup
wg.Add(2)
go func() {
defer wg.Done()
resp, err := client.SimpleReq(ctx, "test", "hello")
if err != nil {
t.Errorf("SimpleReq() failed: %v", err)
return
}
if resp != "world" {
t.Errorf("SimpleReq() failed: expected 'world', got '%s'", resp)
}
}()
go func() {
defer wg.Done()
req := <-sendCh
if req.Command != "test" {
t.Errorf("expected 'test', got '%s'", req.Command)
}
if req.Data != "hello" {
t.Errorf("expected 'hello', got '%s'", req.Data)
}
resp := &RpcPacket{
Command: "test",
RpcId: req.RpcId,
RpcType: RpcType_Resp,
SeqNum: 1,
RespDone: true,
Acks: []int64{req.SeqNum},
Data: "world",
}
recvCh <- resp
}()
wg.Wait()
}
func makeRpcResp(req *RpcPacket, data any, seqNum int64, done bool) *RpcPacket {
return &RpcPacket{
Command: req.Command,
RpcId: req.RpcId,
RpcType: RpcType_Resp,
SeqNum: seqNum,
RespDone: done,
Data: data,
}
}
func TestStream(t *testing.T) {
sendCh := make(chan *RpcPacket, MaxInFlightPackets)
recvCh := make(chan *RpcPacket, MaxInFlightPackets)
client := MakeRpcClient(sendCh, recvCh)
ctx, cancelFn := context.WithTimeout(context.Background(), 2*time.Second)
defer cancelFn()
var wg sync.WaitGroup
wg.Add(2)
go func() {
defer wg.Done()
respCh, err := client.StreamReq(ctx, "test", "hello", 1000)
if err != nil {
t.Errorf("StreamReq() failed: %v", err)
return
}
var output []string
for resp := range respCh {
if resp.Error != "" {
t.Errorf("StreamReq() failed: %v", resp.Error)
return
}
output = append(output, resp.Data.(string))
}
if len(output) != 3 {
t.Errorf("expected 3 responses, got %d (%v)", len(output), output)
return
}
if output[0] != "one" || output[1] != "two" || output[2] != "three" {
t.Errorf("expected 'one', 'two', 'three', got %v", output)
return
}
}()
go func() {
defer wg.Done()
req := <-sendCh
if req.Command != "test" {
t.Errorf("expected 'test', got '%s'", req.Command)
}
if req.Data != "hello" {
t.Errorf("expected 'hello', got '%s'", req.Data)
}
resp := makeRpcResp(req, "one", 1, false)
recvCh <- resp
resp = makeRpcResp(req, "two", 2, false)
recvCh <- resp
resp = makeRpcResp(req, "three", 3, true)
recvCh <- resp
}()
wg.Wait()
}

58
pkg/wshrpc/wshrpc.go Normal file
View File

@ -0,0 +1,58 @@
// Copyright 2024, Command Line Inc.
// SPDX-License-Identifier: Apache-2.0
package wshprc
import (
"context"
)
const (
MaxOpenRpcs = 10
MaxUnackedPerRpc = 10
MaxInFlightPackets = MaxOpenRpcs * MaxUnackedPerRpc
)
const (
RpcType_Req = "req"
RpcType_Resp = "resp"
)
const (
CommandType_Ack = ":ack"
CommandType_Ping = ":ping"
CommandType_Cancel = ":cancel"
CommandType_Timeout = ":timeout"
)
var rpcClientContextKey = struct{}{}
type TimeoutInfo struct {
Deadline int64 `json:"deadline,omitempty"`
ReqPacketTimeout int64 `json:"reqpackettimeout,omitempty"` // for streaming requests
RespPacketTimeout int64 `json:"resppackettimeout,omitempty"` // for streaming responses
}
type RpcPacket struct {
Command string `json:"command"`
RpcId string `json:"rpcid"`
RpcType string `json:"rpctype"`
SeqNum int64 `json:"seqnum"`
ReqDone bool `json:"reqdone"`
RespDone bool `json:"resdone"`
Acks []int64 `json:"acks,omitempty"` // seqnums acked
Timeout *TimeoutInfo `json:"timeout,omitempty"` // for initial request only
Data any `json:"data"` // json data for command
Error string `json:"error,omitempty"`
}
func GetRpcClient(ctx context.Context) *RpcClient {
if ctx == nil {
return nil
}
val := ctx.Value(rpcClientContextKey)
if val == nil {
return nil
}
return val.(*RpcClient)
}

View File

@ -9,6 +9,7 @@ const (
CommandSetView = "setview"
CommandSetMeta = "setmeta"
CommandBlockFileAppend = "blockfile:append"
CommandStreamFile = "streamfile"
)
var CommandToTypeMap = map[string]reflect.Type{
@ -23,6 +24,8 @@ type Command interface {
// for unmarshalling
type baseCommand struct {
Command string `json:"command"`
RpcID string `json:"rpcid"`
RpcType string `json:"rpctype"`
}
type SetViewCommand struct {
@ -52,3 +55,12 @@ type BlockFileAppendCommand struct {
func (bfac *BlockFileAppendCommand) GetCommand() string {
return CommandBlockFileAppend
}
type StreamFileCommand struct {
Command string `json:"command"`
FileName string `json:"filename"`
}
func (c *StreamFileCommand) GetCommand() string {
return CommandStreamFile
}