1
0
mirror of https://github.com/bitwarden/server.git synced 2024-11-28 13:15:12 +01:00
bitwarden-server/test/Core.Test/Utilities/CustomRedisProcessingStrategyTests.cs
Shane Melton 2bf8438ff7
[EC-502] Rate Limiting Improvements (#2231)
* [EC-502] Add custom Redis IP rate limit processing strategy

* [EC-502] Formatting

* [EC-502] Add documentation and app setting config options

* [EC-502] Formatting

* [EC-502] Fix appsettings.json keys

* [EC-502] Replace magic string for cache key

* [EC-502] Add tests for custom processing strategy

* [EC-502] Formatting

* [EC-502] Use base class for custom processing strategy

* [EC-502] Fix failing test
2022-08-31 14:17:29 -07:00

199 lines
6.6 KiB
C#

using AspNetCoreRateLimit;
using Bit.Core.Settings;
using Bit.Core.Utilities;
using Microsoft.Extensions.Caching.Memory;
using Microsoft.Extensions.Logging;
using Moq;
using StackExchange.Redis;
using Xunit;
namespace Bit.Core.Test.Utilities;
/// <summary>
/// Unit tests for <see cref="CustomRedisProcessingStrategy"/>. Redis
/// (<see cref="IConnectionMultiplexer"/> / <see cref="IDatabase"/>) is always mocked; the memory
/// cache is either mocked or a real <see cref="MemoryCache"/>, depending on the test.
/// </summary>
public class CustomRedisProcessingStrategyTests
{
    #region Sample RateLimit Options for Testing

    /// <summary>
    /// Settings handed to every strategy under test: distributed rate limiting enabled, with a
    /// Redis timeout threshold of 2 and a sliding window of 5 seconds.
    /// </summary>
    private readonly GlobalSettings _sampleSettings = new()
    {
        DistributedIpRateLimiting = new GlobalSettings.DistributedIpRateLimitingSettings
        {
            Enabled = true,
            MaxRedisTimeoutsThreshold = 2,
            SlidingWindowSeconds = 5
        }
    };

    /// <summary>Request identity passed to every <c>ProcessRequestAsync</c> call.</summary>
    private readonly ClientRequestIdentity _sampleClientId = new()
    {
        ClientId = "test",
        ClientIp = "127.0.0.1",
        HttpVerb = "GET",
        Path = "/"
    };

    /// <summary>Rate limit rule passed to every call: 5 requests per 1 minute on "/".</summary>
    private readonly RateLimitRule _sampleRule = new()
    {
        Endpoint = "/",
        Limit = 5,
        Period = "1m",
        PeriodTimespan = TimeSpan.FromMinutes(1)
    };

    /// <summary>Default-constructed options; no test customizes them.</summary>
    private readonly RateLimitOptions _sampleOptions = new();

    #endregion

    private readonly Mock<ICounterKeyBuilder> _mockCounterKeyBuilder = new();

    /// <summary>
    /// Database mock created by the most recent <see cref="BuildProcessingStrategy"/> call.
    /// It is null until that helper has run, so <see cref="VerifyRedisCalls"/> must only be
    /// used after a strategy has been built.
    /// </summary>
    private Mock<IDatabase> _mockDb;

    /// <summary>
    /// Configures the counter key builder to return the sample client id for any
    /// identity/rule pair.
    /// </summary>
    public CustomRedisProcessingStrategyTests()
    {
        _mockCounterKeyBuilder
            .Setup(x => x.Build(It.IsAny<ClientRequestIdentity>(), It.IsAny<RateLimitRule>()))
            .Returns(_sampleClientId.ClientId);
    }

    /// <summary>
    /// With a connected Redis whose script evaluation returns 1, the result count is 1 and the
    /// script is evaluated exactly once.
    /// </summary>
    [Fact]
    public async Task IncrementRateLimitCount_When_RedisIsHealthy()
    {
        // Arrange
        var strategy = BuildProcessingStrategy();

        // Act
        var result = await strategy.ProcessRequestAsync(
            _sampleClientId,
            _sampleRule,
            _mockCounterKeyBuilder.Object,
            _sampleOptions,
            CancellationToken.None);

        // Assert
        Assert.Equal(1, result.Count);
        VerifyRedisCalls(Times.Once());
    }

    /// <summary>
    /// When the connection reports it is not connected, the result count is 0 and no script is
    /// evaluated.
    /// </summary>
    [Fact]
    public async Task SkipRateLimit_When_RedisIsDown()
    {
        // Arrange
        var strategy = BuildProcessingStrategy(isRedisConnected: false);

        // Act
        var result = await strategy.ProcessRequestAsync(
            _sampleClientId,
            _sampleRule,
            _mockCounterKeyBuilder.Object,
            _sampleOptions,
            CancellationToken.None);

        // Assert
        Assert.Equal(0, result.Count);
        VerifyRedisCalls(Times.Never());
    }

    /// <summary>
    /// When the cache already holds a timeout counter above the configured threshold, the
    /// result count is 0 and no script is evaluated.
    /// </summary>
    [Fact]
    public async Task SkipRateLimit_When_TimeoutThresholdExceeded()
    {
        // Arrange
        var mockCache = new Mock<IMemoryCache>();

        // Declared as object so it can be passed as the out argument of the TryGetValue setup.
        object existingCount = new CustomRedisProcessingStrategy.TimeoutCounter
        {
            Count = _sampleSettings.DistributedIpRateLimiting.MaxRedisTimeoutsThreshold + 1
        };
        mockCache.Setup(x => x.TryGetValue(It.IsAny<object>(), out existingCount)).Returns(true);

        var strategy = BuildProcessingStrategy(mockCache: mockCache.Object);

        // Act
        var result = await strategy.ProcessRequestAsync(
            _sampleClientId,
            _sampleRule,
            _mockCounterKeyBuilder.Object,
            _sampleOptions,
            CancellationToken.None);

        // Assert
        Assert.Equal(0, result.Count);
        VerifyRedisCalls(Times.Never());
    }

    /// <summary>
    /// When script evaluation throws a <see cref="RedisTimeoutException"/>, the result count is
    /// 0 and a timeout counter of 1 is written to a cache entry that has an absolute expiration.
    /// </summary>
    [Fact]
    public async Task SkipRateLimit_When_RedisTimeoutException()
    {
        // Arrange
        var mockCache = new Mock<IMemoryCache>();
        var mockCacheEntry = new Mock<ICacheEntry>();

        // Track property values so whatever is assigned to the entry can be read back below.
        mockCacheEntry.SetupAllProperties();
        mockCache.Setup(x => x.CreateEntry(It.IsAny<object>())).Returns(mockCacheEntry.Object);

        var strategy = BuildProcessingStrategy(mockCache: mockCache.Object, throwRedisTimeout: true);

        // Act
        var result = await strategy.ProcessRequestAsync(
            _sampleClientId,
            _sampleRule,
            _mockCounterKeyBuilder.Object,
            _sampleOptions,
            CancellationToken.None);

        // Assert
        Assert.Equal(0, result.Count); // Skip rate limiting
        VerifyRedisCalls(Times.Once());

        // Fails with a clear assertion message (not a NullReferenceException or
        // InvalidCastException) if nothing, or the wrong type, was stored in the entry.
        var timeoutCounter =
            Assert.IsAssignableFrom<CustomRedisProcessingStrategy.TimeoutCounter>(mockCacheEntry.Object.Value);
        Assert.Equal(1, timeoutCounter.Count); // Timeout count increased/cached
        Assert.NotNull(mockCacheEntry.Object.AbsoluteExpiration);
        mockCache.Verify(x => x.CreateEntry(It.IsAny<object>()));
    }

    /// <summary>
    /// With a real memory cache and a Redis that always times out, three requests evaluate the
    /// script only as many times as the configured timeout threshold (2), so the third request
    /// does not reach Redis.
    /// </summary>
    [Fact]
    public async Task BackoffRedis_After_ThresholdExceeded()
    {
        // Arrange
        // MemoryCache is IDisposable; dispose it when the test finishes.
        using var memoryCache = new MemoryCache(new MemoryCacheOptions());
        var strategy = BuildProcessingStrategy(mockCache: memoryCache, throwRedisTimeout: true);

        // Act

        // Redis Timeout 1
        await strategy.ProcessRequestAsync(
            _sampleClientId,
            _sampleRule,
            _mockCounterKeyBuilder.Object,
            _sampleOptions,
            CancellationToken.None);

        // Redis Timeout 2
        await strategy.ProcessRequestAsync(
            _sampleClientId,
            _sampleRule,
            _mockCounterKeyBuilder.Object,
            _sampleOptions,
            CancellationToken.None);

        // Skip Redis
        await strategy.ProcessRequestAsync(
            _sampleClientId,
            _sampleRule,
            _mockCounterKeyBuilder.Object,
            _sampleOptions,
            CancellationToken.None);

        // Assert
        VerifyRedisCalls(Times.Exactly(_sampleSettings.DistributedIpRateLimiting.MaxRedisTimeoutsThreshold));
    }

    /// <summary>
    /// Verifies how many times <c>ScriptEvaluateAsync</c> was invoked on the database mock
    /// created by the last <see cref="BuildProcessingStrategy"/> call.
    /// </summary>
    /// <param name="times">Expected number of script evaluations.</param>
    private void VerifyRedisCalls(Times times)
    {
        _mockDb.Verify(
            x => x.ScriptEvaluateAsync(It.IsAny<LuaScript>(), It.IsAny<object>(), It.IsAny<CommandFlags>()),
            times);
    }

    /// <summary>
    /// Builds a <see cref="CustomRedisProcessingStrategy"/> wired to freshly mocked Redis
    /// dependencies and stores the database mock in <see cref="_mockDb"/> for later verification.
    /// </summary>
    /// <param name="isRedisConnected">Value the mocked connection returns from <c>IsConnected</c>.</param>
    /// <param name="throwRedisTimeout">
    /// When true, script evaluation throws a <see cref="RedisTimeoutException"/>; otherwise it
    /// returns a Redis result of 1.
    /// </param>
    /// <param name="mockCache">Cache to inject; a default <see cref="IMemoryCache"/> mock is used when null.</param>
    /// <returns>The strategy under test.</returns>
    private CustomRedisProcessingStrategy BuildProcessingStrategy(
        bool isRedisConnected = true,
        bool throwRedisTimeout = false,
        IMemoryCache mockCache = null)
    {
        var mockRedisConnection = new Mock<IConnectionMultiplexer>();
        mockRedisConnection.Setup(x => x.IsConnected).Returns(isRedisConnected);

        _mockDb = new Mock<IDatabase>();
        var mockScriptEvaluate = _mockDb
            .Setup(x =>
                x.ScriptEvaluateAsync(It.IsAny<LuaScript>(), It.IsAny<object>(), It.IsAny<CommandFlags>()));

        if (throwRedisTimeout)
        {
            mockScriptEvaluate.ThrowsAsync(new RedisTimeoutException("Timeout", CommandStatus.WaitingToBeSent));
        }
        else
        {
            mockScriptEvaluate.ReturnsAsync(RedisResult.Create(1));
        }

        mockRedisConnection
            .Setup(x => x.GetDatabase(It.IsAny<int>(), It.IsAny<object>()))
            .Returns(_mockDb.Object);

        var mockLogger = new Mock<ILogger<CustomRedisProcessingStrategy>>();
        var mockConfig = new Mock<IRateLimitConfiguration>();
        mockCache ??= new Mock<IMemoryCache>().Object;

        return new CustomRedisProcessingStrategy(
            mockRedisConnection.Object,
            mockConfig.Object,
            mockLogger.Object,
            mockCache,
            _sampleSettings);
    }
}