mirror of
https://github.com/wavetermdev/waveterm.git
synced 2025-01-17 20:51:55 +01:00
195 lines
4.9 KiB
Go
195 lines
4.9 KiB
Go
// Copyright 2024, Command Line Inc.
|
|
// SPDX-License-Identifier: Apache-2.0
|
|
|
|
package wshprc
|
|
|
|
import (
|
|
"context"
|
|
"fmt"
|
|
"log"
|
|
"sync"
|
|
"testing"
|
|
"time"
|
|
)
|
|
|
|
func TestSimple(t *testing.T) {
|
|
sendCh := make(chan *RpcPacket, MaxInFlightPackets)
|
|
recvCh := make(chan *RpcPacket, MaxInFlightPackets)
|
|
client := MakeRpcClient(sendCh, recvCh)
|
|
ctx, cancelFn := context.WithTimeout(context.Background(), 2*time.Second)
|
|
defer cancelFn()
|
|
var wg sync.WaitGroup
|
|
wg.Add(2)
|
|
go func() {
|
|
defer wg.Done()
|
|
resp, err := client.SimpleReq(ctx, "test", "hello")
|
|
if err != nil {
|
|
t.Errorf("SimpleReq() failed: %v", err)
|
|
return
|
|
}
|
|
if resp != "world" {
|
|
t.Errorf("SimpleReq() failed: expected 'world', got '%s'", resp)
|
|
}
|
|
}()
|
|
go func() {
|
|
defer wg.Done()
|
|
req := <-sendCh
|
|
if req.Command != "test" {
|
|
t.Errorf("expected 'test', got '%s'", req.Command)
|
|
}
|
|
if req.Data != "hello" {
|
|
t.Errorf("expected 'hello', got '%s'", req.Data)
|
|
}
|
|
resp := &RpcPacket{
|
|
Command: "test",
|
|
RpcId: req.RpcId,
|
|
RpcType: RpcType_Resp,
|
|
SeqNum: 1,
|
|
RespDone: true,
|
|
Acks: []int64{req.SeqNum},
|
|
Data: "world",
|
|
}
|
|
recvCh <- resp
|
|
}()
|
|
wg.Wait()
|
|
}
|
|
|
|
func makeRpcResp(req *RpcPacket, data any, seqNum int64, done bool) *RpcPacket {
|
|
return &RpcPacket{
|
|
Command: req.Command,
|
|
RpcId: req.RpcId,
|
|
RpcType: RpcType_Resp,
|
|
SeqNum: seqNum,
|
|
RespDone: done,
|
|
Data: data,
|
|
}
|
|
}
|
|
|
|
func TestStream(t *testing.T) {
|
|
sendCh := make(chan *RpcPacket, MaxInFlightPackets)
|
|
recvCh := make(chan *RpcPacket, MaxInFlightPackets)
|
|
client := MakeRpcClient(sendCh, recvCh)
|
|
ctx, cancelFn := context.WithTimeout(context.Background(), 2*time.Second)
|
|
defer cancelFn()
|
|
var wg sync.WaitGroup
|
|
wg.Add(2)
|
|
go func() {
|
|
defer wg.Done()
|
|
respCh, err := client.StreamReq(ctx, "test", "hello", 1000)
|
|
if err != nil {
|
|
t.Errorf("StreamReq() failed: %v", err)
|
|
return
|
|
}
|
|
var output []string
|
|
for resp := range respCh {
|
|
if resp.Error != "" {
|
|
t.Errorf("StreamReq() failed: %v", resp.Error)
|
|
return
|
|
}
|
|
output = append(output, resp.Data.(string))
|
|
}
|
|
if len(output) != 3 {
|
|
t.Errorf("expected 3 responses, got %d (%v)", len(output), output)
|
|
return
|
|
}
|
|
if output[0] != "one" || output[1] != "two" || output[2] != "three" {
|
|
t.Errorf("expected 'one', 'two', 'three', got %v", output)
|
|
return
|
|
}
|
|
}()
|
|
go func() {
|
|
defer wg.Done()
|
|
req := <-sendCh
|
|
if req.Command != "test" {
|
|
t.Errorf("expected 'test', got '%s'", req.Command)
|
|
}
|
|
if req.Data != "hello" {
|
|
t.Errorf("expected 'hello', got '%s'", req.Data)
|
|
}
|
|
resp := makeRpcResp(req, "one", 1, false)
|
|
recvCh <- resp
|
|
resp = makeRpcResp(req, "two", 2, false)
|
|
recvCh <- resp
|
|
resp = makeRpcResp(req, "three", 3, true)
|
|
recvCh <- resp
|
|
}()
|
|
wg.Wait()
|
|
}
|
|
|
|
func TestSimpleClientServer(t *testing.T) {
|
|
sendCh := make(chan *RpcPacket, MaxInFlightPackets)
|
|
recvCh := make(chan *RpcPacket, MaxInFlightPackets)
|
|
client := MakeRpcClient(sendCh, recvCh)
|
|
server := MakeRpcServer(recvCh, sendCh)
|
|
ctx, cancelFn := context.WithTimeout(context.Background(), 2*time.Second)
|
|
defer cancelFn()
|
|
server.RegisterSimpleCommandHandler("test", func(ctx context.Context, s *RpcServer, cmd string, data any) (any, error) {
|
|
if data != "hello" {
|
|
return nil, fmt.Errorf("expected 'hello', got '%s'", data)
|
|
}
|
|
return "world", nil
|
|
})
|
|
var wg sync.WaitGroup
|
|
wg.Add(1)
|
|
go func() {
|
|
defer wg.Done()
|
|
resp, err := client.SimpleReq(ctx, "test", "hello")
|
|
if err != nil {
|
|
t.Errorf("SimpleReq() failed: %v", err)
|
|
return
|
|
}
|
|
if resp != "world" {
|
|
t.Errorf("SimpleReq() failed: expected 'world', got '%s'", resp)
|
|
}
|
|
}()
|
|
wg.Wait()
|
|
|
|
}
|
|
|
|
func TestStreamClientServer(t *testing.T) {
|
|
sendCh := make(chan *RpcPacket, MaxInFlightPackets)
|
|
recvCh := make(chan *RpcPacket, MaxInFlightPackets)
|
|
client := MakeRpcClient(sendCh, recvCh)
|
|
server := MakeRpcServer(recvCh, sendCh)
|
|
ctx, cancelFn := context.WithTimeout(context.Background(), 2*time.Second)
|
|
defer cancelFn()
|
|
server.RegisterStreamCommandHandler("test", func(ctx context.Context, s *RpcServer, req *RpcPacket) error {
|
|
pk1 := s.makeRespPk(req, "one", false)
|
|
pk2 := s.makeRespPk(req, "two", false)
|
|
pk3 := s.makeRespPk(req, "three", true)
|
|
s.SendResponse(ctx, pk1)
|
|
s.SendResponse(ctx, pk2)
|
|
s.SendResponse(ctx, pk3)
|
|
return nil
|
|
})
|
|
var wg sync.WaitGroup
|
|
wg.Add(1)
|
|
go func() {
|
|
defer wg.Done()
|
|
respCh, err := client.StreamReq(ctx, "test", "hello", 2*time.Second)
|
|
if err != nil {
|
|
t.Errorf("StreamReq() failed: %v", err)
|
|
return
|
|
}
|
|
var result []string
|
|
for respPk := range respCh {
|
|
if respPk.Error != "" {
|
|
t.Errorf("StreamReq() failed: %v", respPk.Error)
|
|
return
|
|
}
|
|
log.Printf("got response: %#v", respPk)
|
|
result = append(result, respPk.Data.(string))
|
|
}
|
|
if len(result) != 3 {
|
|
t.Errorf("expected 3 responses, got %d", len(result))
|
|
return
|
|
}
|
|
if result[0] != "one" || result[1] != "two" || result[2] != "three" {
|
|
t.Errorf("expected 'one', 'two', 'three', got %v", result)
|
|
return
|
|
}
|
|
}()
|
|
wg.Wait()
|
|
|
|
}
|