From f8a65abb61717e2aae2aca8fb549434255d74730 Mon Sep 17 00:00:00 2001 From: Geoff Bourne Date: Fri, 5 Aug 2022 20:50:32 -0500 Subject: [PATCH] Wait for connections to finish when stopping (#107) --- cmd/mc-router/main.go | 5 +++- server/connector.go | 57 ++++++++++++++++++++++++++++++------------- 2 files changed, 44 insertions(+), 18 deletions(-) diff --git a/cmd/mc-router/main.go b/cmd/mc-router/main.go index c433498..ba87230 100644 --- a/cmd/mc-router/main.go +++ b/cmd/mc-router/main.go @@ -152,7 +152,10 @@ func main() { // wait for process-stop signal <-c - logrus.Info("Stopping") + logrus.Info("Stopping. Waiting for connections to complete...") + signal.Stop(c) + connector.WaitForConnections() + logrus.Info("Stopped") } func parseMappings(vals []string) map[string]string { diff --git a/server/connector.go b/server/connector.go index 62e7550..d09fcc6 100644 --- a/server/connector.go +++ b/server/connector.go @@ -6,6 +6,8 @@ import ( "io" "net" "strconv" + "sync" + "sync/atomic" "time" "github.com/go-kit/kit/metrics" @@ -21,10 +23,6 @@ const ( var noDeadline time.Time -type Connector interface { - StartAcceptingConnections(ctx context.Context, listenAddress string, connRateLimit int) error -} - type ConnectorMetrics struct { Errors metrics.Counter BytesTransmitted metrics.Counter @@ -32,21 +30,25 @@ type ConnectorMetrics struct { ActiveConnections metrics.Gauge } -func NewConnector(metrics *ConnectorMetrics, sendProxyProto bool) Connector { +func NewConnector(metrics *ConnectorMetrics, sendProxyProto bool) *Connector { - return &connectorImpl{ - metrics: metrics, - sendProxyProto: sendProxyProto, + return &Connector{ + metrics: metrics, + sendProxyProto: sendProxyProto, + connectionsCond: sync.NewCond(&sync.Mutex{}), } } -type connectorImpl struct { +type Connector struct { state mcproto.State metrics *ConnectorMetrics sendProxyProto bool + + activeConnections int32 + connectionsCond *sync.Cond } -func (c *connectorImpl) StartAcceptingConnections(ctx context.Context, listenAddress string, connRateLimit int) error { +func (c *Connector) StartAcceptingConnections(ctx context.Context, listenAddress string, connRateLimit int) error { ln, err := net.Listen("tcp", listenAddress) if err != nil { @@ -60,7 +62,22 @@ func (c *connectorImpl) StartAcceptingConnections(ctx context.Context, listenAdd return nil } -func (c *connectorImpl) acceptConnections(ctx context.Context, ln net.Listener, connRateLimit int) { +func (c *Connector) WaitForConnections() { + c.connectionsCond.L.Lock() + defer c.connectionsCond.L.Unlock() + + for { + count := atomic.LoadInt32(&c.activeConnections) + if count > 0 { + logrus.Infof("Waiting on %d connection(s)", count) + c.connectionsCond.Wait() + } else { + break + } + } +} + +func (c *Connector) acceptConnections(ctx context.Context, ln net.Listener, connRateLimit int) { //noinspection GoUnhandledErrorResult defer ln.Close() @@ -82,7 +99,7 @@ func (c *connectorImpl) acceptConnections(ctx context.Context, ln net.Listener, } } -func (c *connectorImpl) HandleConnection(ctx context.Context, frontendConn net.Conn) { +func (c *Connector) HandleConnection(ctx context.Context, frontendConn net.Conn) { c.metrics.Connections.With("side", "frontend").Add(1) //noinspection GoUnhandledErrorResult defer frontendConn.Close() @@ -164,7 +181,7 @@ func (c *connectorImpl) HandleConnection(ctx context.Context, frontendConn net.C } } -func (c *connectorImpl) findAndConnectBackend(ctx context.Context, frontendConn net.Conn, +func (c *Connector) findAndConnectBackend(ctx context.Context, frontendConn net.Conn, clientAddr net.Addr, preReadContent io.Reader, serverAddress string) { backendHostPort, resolvedHost, waker := Routes.FindBackendForServerAddress(ctx, serverAddress) @@ -202,8 +219,14 @@ func (c *connectorImpl) findAndConnectBackend(ctx context.Context, frontendConn } c.metrics.Connections.With("side", "backend", "host", resolvedHost).Add(1) - c.metrics.ActiveConnections.Add(1) - defer c.metrics.ActiveConnections.Add(-1) + + c.metrics.ActiveConnections.Set(float64( + atomic.AddInt32(&c.activeConnections, 1))) + defer func() { + c.metrics.ActiveConnections.Set(float64( + atomic.AddInt32(&c.activeConnections, -1))) + c.connectionsCond.Signal() + }() // PROXY protocol implementation if c.sendProxyProto { @@ -257,7 +280,7 @@ func (c *connectorImpl) findAndConnectBackend(ctx context.Context, frontendConn c.pumpConnections(ctx, frontendConn, backendConn) } -func (c *connectorImpl) pumpConnections(ctx context.Context, frontendConn, backendConn net.Conn) { +func (c *Connector) pumpConnections(ctx context.Context, frontendConn, backendConn net.Conn) { //noinspection GoUnhandledErrorResult defer backendConn.Close() @@ -283,7 +306,7 @@ func (c *connectorImpl) pumpConnections(ctx context.Context, frontendConn, backe } } -func (c *connectorImpl) pumpFrames(incoming io.Reader, outgoing io.Writer, errors chan<- error, from, to string, clientAddr net.Addr) { +func (c *Connector) pumpFrames(incoming io.Reader, outgoing io.Writer, errors chan<- error, from, to string, clientAddr net.Addr) { amount, err := io.Copy(outgoing, incoming) logrus. WithField("client", clientAddr).