using AspNetCoreRateLimit;
using Bit.Core.Settings;
using Bit.Core.Utilities;
using Microsoft.Extensions.Caching.Memory;
using Microsoft.Extensions.Logging;
using Moq;
using StackExchange.Redis;
using Xunit;

namespace Bit.Core.Test.Utilities;

/// <summary>
/// Unit tests for <see cref="CustomRedisProcessingStrategy"/>. Redis, the in-memory cache and the
/// counter key builder are replaced with mocks so each test can control whether Redis reports as
/// connected, whether script evaluation times out, and what the timeout counter cache contains.
/// </summary>
public class CustomRedisProcessingStrategyTests
{
    #region Sample RateLimit Options for Testing

    // Threshold of 2 is what BackoffRedis_After_ThresholdExceeded asserts against.
    private readonly GlobalSettings _sampleSettings = new()
    {
        DistributedIpRateLimiting = new GlobalSettings.DistributedIpRateLimitingSettings
        {
            Enabled = true,
            MaxRedisTimeoutsThreshold = 2,
            SlidingWindowSeconds = 5
        }
    };

    private readonly ClientRequestIdentity _sampleClientId = new()
    {
        ClientId = "test",
        ClientIp = "127.0.0.1",
        HttpVerb = "GET",
        Path = "/"
    };

    private readonly RateLimitRule _sampleRule = new()
    {
        Endpoint = "/",
        Limit = 5,
        Period = "1m",
        PeriodTimespan = TimeSpan.FromMinutes(1)
    };

    private readonly RateLimitOptions _sampleOptions = new() { };

    #endregion

    private readonly Mock<ICounterKeyBuilder> _mockCounterKeyBuilder = new();

    // Re-created by every BuildProcessingStrategy call; VerifyRedisCalls inspects the latest instance.
    private Mock<IDatabase> _mockDb;

    /// <summary>
    /// Configures the counter key builder mock to return the sample client id for any
    /// identity/rule pair.
    /// </summary>
    public CustomRedisProcessingStrategyTests()
    {
        _mockCounterKeyBuilder
            .Setup(x => x.Build(It.IsAny<ClientRequestIdentity>(), It.IsAny<RateLimitRule>()))
            .Returns(_sampleClientId.ClientId);
    }

    /// <summary>
    /// With a connected Redis whose script evaluation returns 1, the resulting count is 1 and
    /// the script is evaluated exactly once.
    /// </summary>
    [Fact]
    public async Task IncrementRateLimitCount_When_RedisIsHealthy()
    {
        // Arrange
        var strategy = BuildProcessingStrategy();

        // Act
        var result = await strategy.ProcessRequestAsync(_sampleClientId, _sampleRule, _mockCounterKeyBuilder.Object,
            _sampleOptions, CancellationToken.None);

        // Assert
        Assert.Equal(1, result.Count);
        VerifyRedisCalls(Times.Once());
    }

    /// <summary>
    /// When the connection reports it is not connected, the resulting count is 0 and no script
    /// is evaluated.
    /// </summary>
    [Fact]
    public async Task SkipRateLimit_When_RedisIsDown()
    {
        // Arrange
        var strategy = BuildProcessingStrategy(false);

        // Act
        var result = await strategy.ProcessRequestAsync(_sampleClientId, _sampleRule, _mockCounterKeyBuilder.Object,
            _sampleOptions, CancellationToken.None);

        // Assert
        Assert.Equal(0, result.Count);
        VerifyRedisCalls(Times.Never());
    }

    /// <summary>
    /// When the cache already holds a timeout counter above the configured threshold, the
    /// resulting count is 0 and no script is evaluated.
    /// </summary>
    [Fact]
    public async Task SkipRateLimit_When_TimeoutThresholdExceeded()
    {
        // Arrange
        var mockCache = new Mock<IMemoryCache>();

        // Declared as object because IMemoryCache.TryGetValue exposes its value as an out object.
        object existingCount = new CustomRedisProcessingStrategy.TimeoutCounter
        {
            Count = _sampleSettings.DistributedIpRateLimiting.MaxRedisTimeoutsThreshold + 1
        };
        mockCache.Setup(x => x.TryGetValue(It.IsAny<object>(), out existingCount)).Returns(true);

        var strategy = BuildProcessingStrategy(mockCache: mockCache.Object);

        // Act
        var result = await strategy.ProcessRequestAsync(_sampleClientId, _sampleRule, _mockCounterKeyBuilder.Object,
            _sampleOptions, CancellationToken.None);

        // Assert
        Assert.Equal(0, result.Count);
        VerifyRedisCalls(Times.Never());
    }

    /// <summary>
    /// When script evaluation throws a <see cref="RedisTimeoutException"/>, the resulting count
    /// is 0 and a timeout counter with Count = 1 and an absolute expiration is written to the cache.
    /// </summary>
    [Fact]
    public async Task SkipRateLimit_When_RedisTimeoutException()
    {
        // Arrange
        var mockCache = new Mock<IMemoryCache>();
        var mockCacheEntry = new Mock<ICacheEntry>();

        // SetupAllProperties makes the entry remember assigned Value/AbsoluteExpiration for the asserts below.
        mockCacheEntry.SetupAllProperties();
        mockCache.Setup(x => x.CreateEntry(It.IsAny<object>())).Returns(mockCacheEntry.Object);

        var strategy = BuildProcessingStrategy(mockCache: mockCache.Object, throwRedisTimeout: true);

        // Act
        var result = await strategy.ProcessRequestAsync(_sampleClientId, _sampleRule, _mockCounterKeyBuilder.Object,
            _sampleOptions, CancellationToken.None);

        var timeoutCounter = ((CustomRedisProcessingStrategy.TimeoutCounter)mockCacheEntry.Object.Value);

        // Assert
        Assert.Equal(0, result.Count); // Skip rate limiting
        VerifyRedisCalls(Times.Once());

        Assert.Equal(1, timeoutCounter.Count); // Timeout count increased/cached
        Assert.NotNull(mockCacheEntry.Object.AbsoluteExpiration);
        mockCache.Verify(x => x.CreateEntry(It.IsAny<object>()));
    }

    /// <summary>
    /// After as many consecutive timeouts as the configured threshold, a further request does not
    /// evaluate the script again. Uses a real <see cref="MemoryCache"/> so the counter persists
    /// between calls.
    /// </summary>
    [Fact]
    public async Task BackoffRedis_After_ThresholdExceeded()
    {
        // Arrange
        // Disposed at the end of the test so the cache does not outlive it.
        using var memoryCache = new MemoryCache(new MemoryCacheOptions());
        var strategy = BuildProcessingStrategy(mockCache: memoryCache, throwRedisTimeout: true);

        // Act

        // Redis Timeout 1
        await strategy.ProcessRequestAsync(_sampleClientId, _sampleRule, _mockCounterKeyBuilder.Object,
            _sampleOptions, CancellationToken.None);

        // Redis Timeout 2
        await strategy.ProcessRequestAsync(_sampleClientId, _sampleRule, _mockCounterKeyBuilder.Object,
            _sampleOptions, CancellationToken.None);

        // Skip Redis
        await strategy.ProcessRequestAsync(_sampleClientId, _sampleRule, _mockCounterKeyBuilder.Object,
            _sampleOptions, CancellationToken.None);

        // Assert
        VerifyRedisCalls(Times.Exactly(_sampleSettings.DistributedIpRateLimiting.MaxRedisTimeoutsThreshold));
    }

    /// <summary>
    /// Verifies how many times the script was evaluated on the database mock created by the most
    /// recent <see cref="BuildProcessingStrategy"/> call.
    /// </summary>
    /// <param name="times">Expected number of ScriptEvaluateAsync invocations.</param>
    private void VerifyRedisCalls(Times times)
    {
        // NOTE(review): matches the (LuaScript, object, CommandFlags) overload - confirm this is the
        // overload the strategy under test actually calls.
        _mockDb.Verify(
            x => x.ScriptEvaluateAsync(It.IsAny<LuaScript>(), It.IsAny<object>(), It.IsAny<CommandFlags>()),
            times);
    }

    /// <summary>
    /// Creates a <see cref="CustomRedisProcessingStrategy"/> wired to mocked dependencies and
    /// replaces <c>_mockDb</c> with a fresh database mock.
    /// </summary>
    /// <param name="isRedisConnected">Value the mocked connection returns from IsConnected.</param>
    /// <param name="throwRedisTimeout">
    /// When true, script evaluation throws a <see cref="RedisTimeoutException"/>; otherwise it
    /// returns a result created from 1.
    /// </param>
    /// <param name="mockCache">Cache handed to the strategy; a default mock is used when null.</param>
    /// <returns>The strategy instance under test.</returns>
    private CustomRedisProcessingStrategy BuildProcessingStrategy(
        bool isRedisConnected = true,
        bool throwRedisTimeout = false,
        IMemoryCache mockCache = null)
    {
        var mockRedisConnection = new Mock<IConnectionMultiplexer>();
        mockRedisConnection.Setup(x => x.IsConnected).Returns(isRedisConnected);

        _mockDb = new Mock<IDatabase>();

        var mockScriptEvaluate = _mockDb
            .Setup(x => x.ScriptEvaluateAsync(It.IsAny<LuaScript>(), It.IsAny<object>(), It.IsAny<CommandFlags>()));

        if (throwRedisTimeout)
        {
            mockScriptEvaluate.ThrowsAsync(new RedisTimeoutException("Timeout", CommandStatus.WaitingToBeSent));
        }
        else
        {
            mockScriptEvaluate.ReturnsAsync(RedisResult.Create(1));
        }

        mockRedisConnection
            .Setup(x => x.GetDatabase(It.IsAny<int>(), It.IsAny<object>()))
            .Returns(_mockDb.Object);

        // NOTE(review): logger and configuration types are inferred from the constructor argument
        // order - confirm against CustomRedisProcessingStrategy's constructor signature.
        var mockLogger = new Mock<ILogger<CustomRedisProcessingStrategy>>();
        var mockConfig = new Mock<IRateLimitConfiguration>();

        mockCache ??= new Mock<IMemoryCache>().Object;

        return new CustomRedisProcessingStrategy(mockRedisConnection.Object, mockConfig.Object,
            mockLogger.Object, mockCache, _sampleSettings);
    }
}