From 394b9dce2378d5c7b7c993096146740479e945a8 Mon Sep 17 00:00:00 2001
From: sawka <mike.sawka@gmail.com>
Date: Wed, 29 May 2024 23:58:29 -0700
Subject: [PATCH] working on rpc server

---
 pkg/wshrpc/rpc_client.go |  32 ++---
 pkg/wshrpc/rpc_server.go | 247 +++++++++++++++++++++++++++++++++++++++
 2 files changed, 263 insertions(+), 16 deletions(-)
 create mode 100644 pkg/wshrpc/rpc_server.go

diff --git a/pkg/wshrpc/rpc_client.go b/pkg/wshrpc/rpc_client.go
index f0388d1e7..425fae7c6 100644
--- a/pkg/wshrpc/rpc_client.go
+++ b/pkg/wshrpc/rpc_client.go
@@ -28,10 +28,10 @@ type RpcClient struct {
 }
 
 type RpcInfo struct {
-	CloseSync          *sync.Once
-	RpcId              string
-	ReqPacketsInFlight map[int64]bool // seqnum -> bool
-	RespCh             chan *RpcPacket
+	CloseSync       *sync.Once
+	RpcId           string
+	PacketsInFlight map[int64]bool  // seqnum -> bool (for clients this is for requests, for servers it is for responses)
+	PkCh            chan *RpcPacket // for clients this is for responses, for servers it is for requests
 }
 
 func MakeRpcClient(sendCh chan *RpcPacket, recvCh chan *RpcPacket) *RpcClient {
@@ -88,7 +88,7 @@ func (c *RpcClient) handleResp(pk *RpcPacket) {
 		return
 	}
 	select {
-	case rpcInfo.RespCh <- pk:
+	case rpcInfo.PkCh <- pk:
 	default:
 		log.Printf("RpcClient.handleResp() respCh full, dropping packet")
 	}
@@ -131,7 +131,7 @@ func (c *RpcClient) waitForReq(ctx context.Context, req *RpcPacket) (*RpcInfo, e
 			continue
 		}
 		if rpcInfo, ok := c.RpcReqs[req.RpcId]; ok {
-			if len(rpcInfo.ReqPacketsInFlight) >= MaxUnackedPerRpc {
+			if len(rpcInfo.PacketsInFlight) >= MaxUnackedPerRpc {
 				c.CVar.Wait()
 				continue
 			}
@@ -147,12 +147,12 @@ func (c *RpcClient) waitForReq(ctx context.Context, req *RpcPacket) (*RpcInfo, e
 	rpcInfo := c.RpcReqs[req.RpcId]
 	if rpcInfo == nil {
 		rpcInfo = &RpcInfo{
-			CloseSync:          &sync.Once{},
-			RpcId:              req.RpcId,
-			ReqPacketsInFlight: make(map[int64]bool),
-			RespCh:             make(chan *RpcPacket, MaxUnackedPerRpc),
+			CloseSync:       &sync.Once{},
+			RpcId:           req.RpcId,
+			PacketsInFlight: make(map[int64]bool),
+			PkCh:            make(chan *RpcPacket, MaxUnackedPerRpc),
 		}
-		rpcInfo.ReqPacketsInFlight[req.SeqNum] = true
+		rpcInfo.PacketsInFlight[req.SeqNum] = true
 		c.RpcReqs[req.RpcId] = rpcInfo
 	}
 	return rpcInfo, nil
@@ -171,7 +171,7 @@ func (c *RpcClient) handleAcks(acks []int64) {
 		}
 		rpcInfo := c.RpcReqs[rpcId]
 		if rpcInfo != nil {
-			delete(rpcInfo.ReqPacketsInFlight, ack)
+			delete(rpcInfo.PacketsInFlight, ack)
 		}
 		delete(c.ReqPacketsInFlight, ack)
 	}
@@ -188,12 +188,12 @@ func (c *RpcClient) removeReqInfo(rpcId string, clearSend bool) {
 			// unblock the recv loop if it happens to be waiting
 			// because the delete has already happens, it will not be able to send again on the channel
 			select {
-			case <-rpcInfo.RespCh:
+			case <-rpcInfo.PkCh:
 			default:
 			}
 		}
 		rpcInfo.CloseSync.Do(func() {
-			close(rpcInfo.RespCh)
+			close(rpcInfo.PkCh)
 		})
 	}
 }
@@ -225,7 +225,7 @@ func (c *RpcClient) SimpleReq(ctx context.Context, command string, data any) (an
 	select {
 	case <-ctx.Done():
 		return nil, ctx.Err()
-	case rtnPacket = <-rpcInfo.RespCh:
+	case rtnPacket = <-rpcInfo.PkCh:
 		// fallthrough
 	}
 	if rtnPacket.Error != "" {
@@ -256,7 +256,7 @@ func (c *RpcClient) StreamReq(ctx context.Context, command string, data any, res
 	if err != nil {
 		return nil, err
 	}
-	return rpcInfo.RespCh, nil
+	return rpcInfo.PkCh, nil
 }
 
 func (c *RpcClient) EndStreamReq(rpcId string) {
diff --git a/pkg/wshrpc/rpc_server.go b/pkg/wshrpc/rpc_server.go
new file mode 100644
index 000000000..8ba3a1bb0
--- /dev/null
+++ b/pkg/wshrpc/rpc_server.go
@@ -0,0 +1,247 @@
+// Copyright 2024, Command Line Inc.
+// SPDX-License-Identifier: Apache-2.0
+
+package wshprc
+
+import (
+	"context"
+	"fmt"
+	"log"
+	"runtime/debug"
+	"sync"
+	"sync/atomic"
+	"time"
+)
+
+type SimpleCommandHandlerFn func(context.Context, *RpcServer, string, any) (any, error)
+type StreamCommandHandlerFn func(context.Context, *RpcServer, string, any) error
+
+type RpcServer struct {
+	CVar                  *sync.Cond
+	NextSeqNum            *atomic.Int64
+	RespPacketsInFlight   map[int64]string // seqnum -> rpcId
+	AckList               []int64
+	RpcReqs               map[string]*RpcInfo
+	SendCh                chan *RpcPacket
+	RecvCh                chan *RpcPacket
+	SimpleCommandHandlers map[string]SimpleCommandHandlerFn
+	StreamCommandHandlers map[string]StreamCommandHandlerFn
+}
+
+func MakeRpcServer(sendCh chan *RpcPacket, recvCh chan *RpcPacket) *RpcServer {
+	if cap(sendCh) < MaxInFlightPackets {
+		panic(fmt.Errorf("sendCh buffer size must be at least MaxInFlightPackets(%d)", MaxInFlightPackets))
+	}
+	rtn := &RpcServer{
+		CVar:                sync.NewCond(&sync.Mutex{}),
+		NextSeqNum:          &atomic.Int64{},
+		RespPacketsInFlight: make(map[int64]string),
+		AckList:             nil,
+		RpcReqs:             make(map[string]*RpcInfo),
+		SendCh:              sendCh,
+		RecvCh:              recvCh,
+	}
+	go rtn.runRecvLoop()
+	return rtn
+}
+
+func (s *RpcServer) RegisterSimpleCommandHandler(command string, handler SimpleCommandHandlerFn) {
+	s.CVar.L.Lock()
+	defer s.CVar.L.Unlock()
+	s.SimpleCommandHandlers[command] = handler
+}
+
+func (s *RpcServer) RegisterStreamCommandHandler(command string, handler StreamCommandHandlerFn) {
+	s.CVar.L.Lock()
+	defer s.CVar.L.Unlock()
+	s.StreamCommandHandlers[command] = handler
+}
+
+func (s *RpcServer) runRecvLoop() {
+	defer func() {
+		if r := recover(); r != nil {
+			log.Printf("RpcServer.runRecvLoop() panic: %v", r)
+			debug.PrintStack()
+		}
+	}()
+	for pk := range s.RecvCh {
+		s.handleAcks(pk.Acks)
+		if pk.RpcType == RpcType_Req {
+			if pk.ReqDone {
+				s.handleSimpleReq(pk)
+			} else {
+				s.handleStreamReq(pk)
+			}
+			continue
+		}
+		log.Printf("RpcClient.runRecvLoop() bad packet type: %v", pk)
+	}
+	log.Printf("RpcServer.runRecvLoop() normal exit")
+}
+
+func (s *RpcServer) ackResp(seqNum int64) {
+	if seqNum == 0 {
+		return
+	}
+	s.CVar.L.Lock()
+	defer s.CVar.L.Unlock()
+	s.AckList = append(s.AckList, seqNum)
+}
+
+func makeContextFromTimeout(timeout *TimeoutInfo) (context.Context, context.CancelFunc) {
+	if timeout == nil {
+		return context.Background(), func() {}
+	}
+	return context.WithDeadline(context.Background(), time.UnixMilli(timeout.Deadline))
+}
+
+func (s *RpcServer) SendResponse(ctx context.Context, pk *RpcPacket) error {
+	return s.waitForSend(ctx, pk)
+}
+
+func (s *RpcServer) waitForSend(ctx context.Context, pk *RpcPacket) error {
+	s.CVar.L.Lock()
+	defer s.CVar.L.Unlock()
+	for {
+		if ctx.Err() != nil {
+			return ctx.Err()
+		}
+		if len(s.RespPacketsInFlight) >= MaxInFlightPackets {
+			s.CVar.Wait()
+			continue
+		}
+		rpcInfo := s.RpcReqs[pk.RpcId]
+		if rpcInfo != nil {
+			if len(rpcInfo.PacketsInFlight) >= MaxUnackedPerRpc {
+				s.CVar.Wait()
+				continue
+			}
+		}
+		break
+	}
+	s.RespPacketsInFlight[pk.SeqNum] = pk.RpcId
+	pk.Acks = s.grabAcks_nolock()
+	s.SendCh <- pk
+	rpcInfo := s.RpcReqs[pk.RpcId]
+	if !pk.RespDone && rpcInfo != nil {
+		rpcInfo = &RpcInfo{
+			CloseSync:       &sync.Once{},
+			RpcId:           pk.RpcId,
+			PkCh:            make(chan *RpcPacket, MaxUnackedPerRpc),
+			PacketsInFlight: make(map[int64]bool),
+		}
+		s.RpcReqs[pk.RpcId] = rpcInfo
+	}
+	if rpcInfo != nil {
+		rpcInfo.PacketsInFlight[pk.SeqNum] = true
+	}
+	if pk.RespDone {
+		delete(s.RpcReqs, pk.RpcId)
+	}
+	return nil
+}
+
+func (s *RpcServer) handleAcks(acks []int64) {
+	if len(acks) == 0 {
+		return
+	}
+	s.CVar.L.Lock()
+	defer s.CVar.L.Unlock()
+	for _, ack := range acks {
+		rpcId, ok := s.RespPacketsInFlight[ack]
+		if !ok {
+			continue
+		}
+		rpcInfo := s.RpcReqs[rpcId]
+		if rpcInfo != nil {
+			delete(rpcInfo.PacketsInFlight, ack)
+		}
+		delete(s.RespPacketsInFlight, ack)
+	}
+	s.CVar.Broadcast()
+}
+
+func (s *RpcServer) handleSimpleReq(pk *RpcPacket) {
+	s.ackResp(pk.SeqNum)
+	handler, ok := s.SimpleCommandHandlers[pk.Command]
+	if !ok {
+		log.Printf("RpcServer.handleReq() unknown command: %s", pk.Command)
+		return
+	}
+	go func() {
+		defer func() {
+			if r := recover(); r != nil {
+				log.Printf("RpcServer.handleReq(%q) panic: %v", pk.Command, r)
+				debug.PrintStack()
+			}
+		}()
+		ctx, cancelFn := makeContextFromTimeout(pk.Timeout)
+		defer cancelFn()
+		data, err := handler(ctx, s, pk.Command, pk.Data)
+		seqNum := s.NextSeqNum.Add(1)
+		respPk := &RpcPacket{
+			Command:  pk.Command,
+			RpcId:    pk.RpcId,
+			RpcType:  RpcType_Resp,
+			SeqNum:   seqNum,
+			RespDone: true,
+		}
+		if err != nil {
+			respPk.Error = err.Error()
+		} else {
+			respPk.Data = data
+		}
+		s.waitForSend(ctx, respPk)
+	}()
+}
+
+func (s *RpcServer) grabAcks_nolock() []int64 {
+	acks := s.AckList
+	s.AckList = nil
+	return acks
+}
+
+func (s *RpcServer) handleStreamReq(pk *RpcPacket) {
+	s.ackResp(pk.SeqNum)
+	handler, ok := s.StreamCommandHandlers[pk.Command]
+	if !ok {
+		s.ackResp(pk.SeqNum)
+		log.Printf("RpcServer.handleStreamReq() unknown command: %s", pk.Command)
+		return
+	}
+	go func() {
+		defer func() {
+			r := recover()
+			if r == nil {
+				return
+			}
+			log.Printf("RpcServer.handleStreamReq(%q) panic: %v", pk.Command, r)
+			debug.PrintStack()
+			respPk := &RpcPacket{
+				Command:  pk.Command,
+				RpcId:    pk.RpcId,
+				RpcType:  RpcType_Resp,
+				SeqNum:   s.NextSeqNum.Add(1),
+				RespDone: true,
+				Error:    fmt.Sprintf("panic: %v", r),
+			}
+			s.waitForSend(context.Background(), respPk)
+		}()
+		ctx, cancelFn := makeContextFromTimeout(pk.Timeout)
+		defer cancelFn()
+		err := handler(ctx, s, pk.Command, pk.Data)
+		if err != nil {
+			respPk := &RpcPacket{
+				Command:  pk.Command,
+				RpcId:    pk.RpcId,
+				RpcType:  RpcType_Resp,
+				SeqNum:   s.NextSeqNum.Add(1),
+				RespDone: true,
+				Error:    err.Error(),
+			}
+			s.waitForSend(ctx, respPk)
+			return
+		}
+		// check if RespDone has been set, if not, send it here
+	}()
+}