// Copyright 2024, Command Line Inc. // SPDX-License-Identifier: Apache-2.0 package wshprc import ( "context" "fmt" "log" "runtime/debug" "sync" "sync/atomic" "time" ) type SimpleCommandHandlerFn func(context.Context, *RpcServer, string, any) (any, error) type StreamCommandHandlerFn func(context.Context, *RpcServer, *RpcPacket) error type RpcServer struct { CVar *sync.Cond NextSeqNum *atomic.Int64 RespPacketsInFlight map[int64]string // seqnum -> rpcId AckList []int64 RpcReqs map[string]*RpcInfo SendCh chan *RpcPacket RecvCh chan *RpcPacket SimpleCommandHandlers map[string]SimpleCommandHandlerFn StreamCommandHandlers map[string]StreamCommandHandlerFn } func MakeRpcServer(sendCh chan *RpcPacket, recvCh chan *RpcPacket) *RpcServer { if cap(sendCh) < MaxInFlightPackets { panic(fmt.Errorf("sendCh buffer size must be at least MaxInFlightPackets(%d)", MaxInFlightPackets)) } rtn := &RpcServer{ CVar: sync.NewCond(&sync.Mutex{}), NextSeqNum: &atomic.Int64{}, RespPacketsInFlight: make(map[int64]string), AckList: nil, RpcReqs: make(map[string]*RpcInfo), SendCh: sendCh, RecvCh: recvCh, SimpleCommandHandlers: make(map[string]SimpleCommandHandlerFn), StreamCommandHandlers: make(map[string]StreamCommandHandlerFn), } go rtn.runRecvLoop() return rtn } func (s *RpcServer) shouldUseStreamHandler(command string) bool { s.CVar.L.Lock() defer s.CVar.L.Unlock() _, ok := s.StreamCommandHandlers[command] return ok } func (s *RpcServer) getStreamHandler(command string) StreamCommandHandlerFn { s.CVar.L.Lock() defer s.CVar.L.Unlock() return s.StreamCommandHandlers[command] } func (s *RpcServer) getSimpleHandler(command string) SimpleCommandHandlerFn { s.CVar.L.Lock() defer s.CVar.L.Unlock() return s.SimpleCommandHandlers[command] } func (s *RpcServer) RegisterSimpleCommandHandler(command string, handler SimpleCommandHandlerFn) { s.CVar.L.Lock() defer s.CVar.L.Unlock() if s.StreamCommandHandlers[command] != nil { panic(fmt.Errorf("command %q already registered as a stream handler", command)) } s.SimpleCommandHandlers[command] = handler } func (s *RpcServer) RegisterStreamCommandHandler(command string, handler StreamCommandHandlerFn) { s.CVar.L.Lock() defer s.CVar.L.Unlock() if s.SimpleCommandHandlers[command] != nil { panic(fmt.Errorf("command %q already registered as a simple handler", command)) } s.StreamCommandHandlers[command] = handler } func (s *RpcServer) runRecvLoop() { defer func() { if r := recover(); r != nil { log.Printf("RpcServer.runRecvLoop() panic: %v", r) debug.PrintStack() } }() for pk := range s.RecvCh { s.handleAcks(pk.Acks) if pk.RpcType == RpcType_Req { if s.shouldUseStreamHandler(pk.Command) { s.handleStreamReq(pk) } else { s.handleSimpleReq(pk) } continue } log.Printf("RpcClient.runRecvLoop() bad packet type: %v", pk) } log.Printf("RpcServer.runRecvLoop() normal exit") } func (s *RpcServer) ackResp(seqNum int64) { if seqNum == 0 { return } s.CVar.L.Lock() defer s.CVar.L.Unlock() s.AckList = append(s.AckList, seqNum) } func makeContextFromTimeout(timeout *TimeoutInfo) (context.Context, context.CancelFunc) { if timeout == nil { return context.Background(), func() {} } return context.WithDeadline(context.Background(), time.UnixMilli(timeout.Deadline)) } func (s *RpcServer) SendResponse(ctx context.Context, pk *RpcPacket) error { return s.waitForSend(ctx, pk) } func (s *RpcServer) waitForSend(ctx context.Context, pk *RpcPacket) error { s.CVar.L.Lock() defer s.CVar.L.Unlock() for { if ctx.Err() != nil { return ctx.Err() } if len(s.RespPacketsInFlight) >= MaxInFlightPackets { s.CVar.Wait() continue } rpcInfo := s.RpcReqs[pk.RpcId] if rpcInfo != nil { if len(rpcInfo.PacketsInFlight) >= MaxUnackedPerRpc { s.CVar.Wait() continue } } break } s.RespPacketsInFlight[pk.SeqNum] = pk.RpcId pk.Acks = s.grabAcks_nolock() s.SendCh <- pk rpcInfo := s.RpcReqs[pk.RpcId] if !pk.RespDone && rpcInfo != nil { rpcInfo = &RpcInfo{ CloseSync: &sync.Once{}, RpcId: pk.RpcId, PkCh: make(chan *RpcPacket, MaxUnackedPerRpc), PacketsInFlight: make(map[int64]bool), } s.RpcReqs[pk.RpcId] = rpcInfo } if rpcInfo != nil { rpcInfo.PacketsInFlight[pk.SeqNum] = true } if pk.RespDone { delete(s.RpcReqs, pk.RpcId) } return nil } func (s *RpcServer) handleAcks(acks []int64) { if len(acks) == 0 { return } s.CVar.L.Lock() defer s.CVar.L.Unlock() for _, ack := range acks { rpcId, ok := s.RespPacketsInFlight[ack] if !ok { continue } rpcInfo := s.RpcReqs[rpcId] if rpcInfo != nil { delete(rpcInfo.PacketsInFlight, ack) } delete(s.RespPacketsInFlight, ack) } s.CVar.Broadcast() } func (s *RpcServer) handleSimpleReq(pk *RpcPacket) { s.ackResp(pk.SeqNum) handler := s.getSimpleHandler(pk.Command) if handler == nil { s.sendErrorResp(pk, fmt.Errorf("unknown command: %s", pk.Command)) log.Printf("RpcServer.handleReq() unknown command: %s", pk.Command) return } go func() { defer func() { if r := recover(); r != nil { log.Printf("RpcServer.handleReq(%q) panic: %v", pk.Command, r) debug.PrintStack() } }() ctx, cancelFn := makeContextFromTimeout(pk.Timeout) defer cancelFn() data, err := handler(ctx, s, pk.Command, pk.Data) seqNum := s.NextSeqNum.Add(1) respPk := &RpcPacket{ Command: pk.Command, RpcId: pk.RpcId, RpcType: RpcType_Resp, SeqNum: seqNum, RespDone: true, } if err != nil { respPk.Error = err.Error() } else { respPk.Data = data } s.waitForSend(ctx, respPk) }() } func (s *RpcServer) grabAcks_nolock() []int64 { acks := s.AckList s.AckList = nil return acks } func (s *RpcServer) sendErrorResp(pk *RpcPacket, err error) { respPk := &RpcPacket{ Command: pk.Command, RpcId: pk.RpcId, RpcType: RpcType_Resp, SeqNum: s.NextSeqNum.Add(1), RespDone: true, Error: err.Error(), } s.waitForSend(context.Background(), respPk) } func (s *RpcServer) makeRespPk(pk *RpcPacket, data any, done bool) *RpcPacket { return &RpcPacket{ Command: pk.Command, RpcId: pk.RpcId, RpcType: RpcType_Resp, SeqNum: s.NextSeqNum.Add(1), RespDone: done, Data: data, } } func (s *RpcServer) handleStreamReq(pk *RpcPacket) { s.ackResp(pk.SeqNum) handler := s.getStreamHandler(pk.Command) if handler == nil { s.ackResp(pk.SeqNum) s.sendErrorResp(pk, fmt.Errorf("unknown command: %s", pk.Command)) log.Printf("RpcServer.handleStreamReq() unknown command: %s", pk.Command) return } go func() { defer func() { r := recover() if r == nil { return } log.Printf("RpcServer.handleStreamReq(%q) panic: %v", pk.Command, r) debug.PrintStack() respPk := &RpcPacket{ Command: pk.Command, RpcId: pk.RpcId, RpcType: RpcType_Resp, SeqNum: s.NextSeqNum.Add(1), RespDone: true, Error: fmt.Sprintf("panic: %v", r), } s.waitForSend(context.Background(), respPk) }() ctx, cancelFn := makeContextFromTimeout(pk.Timeout) defer cancelFn() err := handler(ctx, s, pk) if err != nil { respPk := &RpcPacket{ Command: pk.Command, RpcId: pk.RpcId, RpcType: RpcType_Resp, SeqNum: s.NextSeqNum.Add(1), RespDone: true, Error: err.Error(), } s.waitForSend(ctx, respPk) return } // check if RespDone has been set, if not, send it here }() }