mirror of
https://github.com/bitwarden/server.git
synced 2025-02-22 02:51:33 +01:00
[EC-502] Rate Limiting Improvements (#2231)
* [EC-502] Add custom Redis IP rate limit processing strategy * [EC-502] Formatting * [EC-502] Add documentation and app setting config options * [EC-502] Formatting * [EC-502] Fix appsettings.json keys * [EC-502] Replace magic string for cache key * [EC-502] Add tests for custom processing strategy * [EC-502] Formatting * [EC-502] Use base class for custom processing strategy * [EC-502] Fix failing test
This commit is contained in:
parent
e0f9d99b49
commit
2bf8438ff7
@ -69,6 +69,11 @@
|
||||
"accessKeyId": "SECRET",
|
||||
"accessKeySecret": "SECRET",
|
||||
"region": "SECRET"
|
||||
},
|
||||
"distributedIpRateLimiting": {
|
||||
"enabled": true,
|
||||
"maxRedisTimeoutsThreshold": 10,
|
||||
"slidingWindowSeconds": 120
|
||||
}
|
||||
},
|
||||
"IpRateLimitOptions": {
|
||||
|
@ -69,6 +69,8 @@ public class GlobalSettings : IGlobalSettings
|
||||
public virtual ISsoSettings Sso { get; set; } = new SsoSettings();
|
||||
public virtual StripeSettings Stripe { get; set; } = new StripeSettings();
|
||||
public virtual ITwoFactorAuthSettings TwoFactorAuth { get; set; } = new TwoFactorAuthSettings();
|
||||
public virtual DistributedIpRateLimitingSettings DistributedIpRateLimiting { get; set; } =
|
||||
new DistributedIpRateLimitingSettings();
|
||||
|
||||
public string BuildExternalUri(string explicitValue, string name)
|
||||
{
|
||||
@ -498,4 +500,23 @@ public class GlobalSettings : IGlobalSettings
|
||||
{
|
||||
public bool EmailOnNewDeviceLogin { get; set; } = false;
|
||||
}
|
||||
|
||||
public class DistributedIpRateLimitingSettings
|
||||
{
|
||||
public bool Enabled { get; set; } = true;
|
||||
|
||||
/// <summary>
|
||||
/// Maximum number of Redis timeouts that can be experienced within the sliding timeout
|
||||
/// window before IP rate limiting is temporarily disabled.
|
||||
/// TODO: Determine/discuss a suitable maximum
|
||||
/// </summary>
|
||||
public int MaxRedisTimeoutsThreshold { get; set; } = 10;
|
||||
|
||||
/// <summary>
|
||||
/// Length of the sliding window in seconds to track Redis timeout exceptions.
|
||||
/// TODO: Determine/discuss a suitable sliding window
|
||||
/// </summary>
|
||||
public int SlidingWindowSeconds { get; set; } = 120;
|
||||
}
|
||||
|
||||
}
|
||||
|
102
src/Core/Utilities/CustomRedisProcessingStrategy.cs
Normal file
102
src/Core/Utilities/CustomRedisProcessingStrategy.cs
Normal file
@ -0,0 +1,102 @@
|
||||
using AspNetCoreRateLimit;
|
||||
using AspNetCoreRateLimit.Redis;
|
||||
using Bit.Core.Settings;
|
||||
using Microsoft.Extensions.Caching.Memory;
|
||||
using Microsoft.Extensions.Logging;
|
||||
using StackExchange.Redis;
|
||||
|
||||
namespace Bit.Core.Utilities;
|
||||
|
||||
/// <summary>
|
||||
/// A modified version of <see cref="AspNetCoreRateLimit.Redis.RedisProcessingStrategy"/> that gracefully
|
||||
/// handles a disrupted Redis connection. If the connection is down or the number of failed requests within
|
||||
/// a given time period exceed the configured threshold, then rate limiting is temporarily disabled.
|
||||
/// </summary>
|
||||
/// <remarks>
|
||||
/// This is necessary to ensure the service does not become unresponsive due to Redis being out of service. As
|
||||
/// the default implementation would throw an exception and exit the request pipeline for all requests.
|
||||
/// </remarks>
|
||||
public class CustomRedisProcessingStrategy : RedisProcessingStrategy
|
||||
{
|
||||
private readonly IConnectionMultiplexer _connectionMultiplexer;
|
||||
private readonly ILogger<CustomRedisProcessingStrategy> _logger;
|
||||
private readonly IMemoryCache _memoryCache;
|
||||
private readonly GlobalSettings.DistributedIpRateLimitingSettings _distributedSettings;
|
||||
|
||||
private const string _redisTimeoutCacheKey = "IpRateLimitRedisTimeout";
|
||||
|
||||
public CustomRedisProcessingStrategy(
|
||||
IConnectionMultiplexer connectionMultiplexer,
|
||||
IRateLimitConfiguration config,
|
||||
ILogger<CustomRedisProcessingStrategy> logger,
|
||||
IMemoryCache memoryCache,
|
||||
GlobalSettings globalSettings)
|
||||
: base(connectionMultiplexer, config, logger)
|
||||
{
|
||||
_connectionMultiplexer = connectionMultiplexer;
|
||||
_logger = logger;
|
||||
_memoryCache = memoryCache;
|
||||
_distributedSettings = globalSettings.DistributedIpRateLimiting;
|
||||
}
|
||||
|
||||
public override async Task<RateLimitCounter> ProcessRequestAsync(ClientRequestIdentity requestIdentity,
|
||||
RateLimitRule rule, ICounterKeyBuilder counterKeyBuilder, RateLimitOptions rateLimitOptions,
|
||||
CancellationToken cancellationToken = default)
|
||||
{
|
||||
// If Redis is down entirely, skip rate limiting
|
||||
if (!_connectionMultiplexer.IsConnected)
|
||||
{
|
||||
_logger.LogDebug("Redis connection is down, skipping IP rate limiting");
|
||||
return SkipRateLimitResult();
|
||||
}
|
||||
|
||||
// Check if any Redis timeouts have occured recently
|
||||
if (_memoryCache.TryGetValue<TimeoutCounter>(_redisTimeoutCacheKey, out var timeoutCounter))
|
||||
{
|
||||
// We've exceeded threshold, backoff Redis and skip rate limiting for now
|
||||
if (timeoutCounter.Count >= _distributedSettings.MaxRedisTimeoutsThreshold)
|
||||
{
|
||||
_logger.LogDebug(
|
||||
"Redis timeout threshold has been exceeded, backing off and skipping IP rate limiting");
|
||||
return SkipRateLimitResult();
|
||||
}
|
||||
}
|
||||
|
||||
try
|
||||
{
|
||||
return await base.ProcessRequestAsync(requestIdentity, rule, counterKeyBuilder, rateLimitOptions, cancellationToken);
|
||||
}
|
||||
catch (RedisTimeoutException)
|
||||
{
|
||||
// If this is the first timeout we've had, start a new counter and sliding window
|
||||
timeoutCounter ??= new TimeoutCounter()
|
||||
{
|
||||
Count = 0,
|
||||
ExpiresAt = DateTime.UtcNow.AddSeconds(_distributedSettings.SlidingWindowSeconds)
|
||||
};
|
||||
timeoutCounter.Count++;
|
||||
|
||||
_memoryCache.Set(_redisTimeoutCacheKey, timeoutCounter,
|
||||
new MemoryCacheEntryOptions { AbsoluteExpiration = timeoutCounter.ExpiresAt });
|
||||
|
||||
// Just because Redis timed out does not mean we should kill the request
|
||||
return SkipRateLimitResult();
|
||||
}
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// A RateLimitCounter result used when the rate limiting middleware should
|
||||
/// fail open and allow the request to proceed without checking request limits.
|
||||
/// </summary>
|
||||
private static RateLimitCounter SkipRateLimitResult()
|
||||
{
|
||||
return new RateLimitCounter { Count = 0, Timestamp = DateTime.UtcNow };
|
||||
}
|
||||
|
||||
internal class TimeoutCounter
|
||||
{
|
||||
public DateTime ExpiresAt { get; init; }
|
||||
|
||||
public int Count { get; set; }
|
||||
}
|
||||
}
|
@ -59,6 +59,11 @@
|
||||
"accessKeyId": "SECRET",
|
||||
"accessKeySecret": "SECRET",
|
||||
"region": "SECRET"
|
||||
},
|
||||
"distributedIpRateLimiting": {
|
||||
"enabled": true,
|
||||
"maxRedisTimeoutsThreshold": 10,
|
||||
"slidingWindowSeconds": 120
|
||||
}
|
||||
},
|
||||
"IpRateLimitOptions": {
|
||||
|
@ -2,7 +2,6 @@
|
||||
using System.Security.Claims;
|
||||
using System.Security.Cryptography.X509Certificates;
|
||||
using AspNetCoreRateLimit;
|
||||
using AspNetCoreRateLimit.Redis;
|
||||
using Bit.Core.Entities;
|
||||
using Bit.Core.Enums;
|
||||
using Bit.Core.HostedServices;
|
||||
@ -609,13 +608,20 @@ public static class ServiceCollectionExtensions
|
||||
services.AddHostedService<IpRateLimitSeedStartupService>();
|
||||
services.AddSingleton<IRateLimitConfiguration, RateLimitConfiguration>();
|
||||
|
||||
if (string.IsNullOrEmpty(globalSettings.Redis.ConnectionString))
|
||||
if (!globalSettings.DistributedIpRateLimiting.Enabled || string.IsNullOrEmpty(globalSettings.Redis.ConnectionString))
|
||||
{
|
||||
services.AddInMemoryRateLimiting();
|
||||
}
|
||||
else
|
||||
{
|
||||
services.AddRedisRateLimiting(); // Requires a registered IConnectionMultiplexer
|
||||
// Use memory stores for Ip and Client Policy stores as we don't currently use them
|
||||
// and they add unnecessary Redis network delays checking for policies that don't exist
|
||||
services.AddSingleton<IIpPolicyStore, MemoryCacheIpPolicyStore>();
|
||||
services.AddSingleton<IClientPolicyStore, MemoryCacheClientPolicyStore>();
|
||||
|
||||
// Use a custom Redis processing strategy that skips Ip limiting if Redis is down
|
||||
// Requires a registered IConnectionMultiplexer
|
||||
services.AddSingleton<IProcessingStrategy, CustomRedisProcessingStrategy>();
|
||||
}
|
||||
}
|
||||
|
||||
|
198
test/Core.Test/Utilities/CustomRedisProcessingStrategyTests.cs
Normal file
198
test/Core.Test/Utilities/CustomRedisProcessingStrategyTests.cs
Normal file
@ -0,0 +1,198 @@
|
||||
using AspNetCoreRateLimit;
|
||||
using Bit.Core.Settings;
|
||||
using Bit.Core.Utilities;
|
||||
using Microsoft.Extensions.Caching.Memory;
|
||||
using Microsoft.Extensions.Logging;
|
||||
using Moq;
|
||||
using StackExchange.Redis;
|
||||
using Xunit;
|
||||
|
||||
namespace Bit.Core.Test.Utilities;
|
||||
|
||||
public class CustomRedisProcessingStrategyTests
|
||||
{
|
||||
#region Sample RateLimit Options for Testing
|
||||
|
||||
private readonly GlobalSettings _sampleSettings = new()
|
||||
{
|
||||
DistributedIpRateLimiting = new GlobalSettings.DistributedIpRateLimitingSettings
|
||||
{
|
||||
Enabled = true,
|
||||
MaxRedisTimeoutsThreshold = 2,
|
||||
SlidingWindowSeconds = 5
|
||||
}
|
||||
};
|
||||
|
||||
private readonly ClientRequestIdentity _sampleClientId = new()
|
||||
{
|
||||
ClientId = "test",
|
||||
ClientIp = "127.0.0.1",
|
||||
HttpVerb = "GET",
|
||||
Path = "/"
|
||||
};
|
||||
|
||||
private readonly RateLimitRule _sampleRule = new() { Endpoint = "/", Limit = 5, Period = "1m", PeriodTimespan = TimeSpan.FromMinutes(1) };
|
||||
|
||||
private readonly RateLimitOptions _sampleOptions = new() { };
|
||||
|
||||
#endregion
|
||||
|
||||
private readonly Mock<ICounterKeyBuilder> _mockCounterKeyBuilder = new();
|
||||
private Mock<IDatabase> _mockDb;
|
||||
|
||||
public CustomRedisProcessingStrategyTests()
|
||||
{
|
||||
_mockCounterKeyBuilder
|
||||
.Setup(x =>
|
||||
x.Build(It.IsAny<ClientRequestIdentity>(), It.IsAny<RateLimitRule>()))
|
||||
.Returns(_sampleClientId.ClientId);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public async Task IncrementRateLimitCount_When_RedisIsHealthy()
|
||||
{
|
||||
// Arrange
|
||||
var strategy = BuildProcessingStrategy();
|
||||
|
||||
// Act
|
||||
var result = await strategy.ProcessRequestAsync(_sampleClientId, _sampleRule, _mockCounterKeyBuilder.Object, _sampleOptions,
|
||||
CancellationToken.None);
|
||||
|
||||
// Assert
|
||||
Assert.Equal(1, result.Count);
|
||||
VerifyRedisCalls(Times.Once());
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public async Task SkipRateLimit_When_RedisIsDown()
|
||||
{
|
||||
// Arrange
|
||||
var strategy = BuildProcessingStrategy(false);
|
||||
|
||||
// Act
|
||||
var result = await strategy.ProcessRequestAsync(_sampleClientId, _sampleRule, _mockCounterKeyBuilder.Object, _sampleOptions,
|
||||
CancellationToken.None);
|
||||
|
||||
// Assert
|
||||
Assert.Equal(0, result.Count);
|
||||
VerifyRedisCalls(Times.Never());
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public async Task SkipRateLimit_When_TimeoutThresholdExceeded()
|
||||
{
|
||||
// Arrange
|
||||
var mockCache = new Mock<IMemoryCache>();
|
||||
object existingCount = new CustomRedisProcessingStrategy.TimeoutCounter
|
||||
{
|
||||
Count = _sampleSettings.DistributedIpRateLimiting.MaxRedisTimeoutsThreshold + 1
|
||||
};
|
||||
mockCache.Setup(x => x.TryGetValue(It.IsAny<object>(), out existingCount)).Returns(true);
|
||||
|
||||
var strategy = BuildProcessingStrategy(mockCache: mockCache.Object);
|
||||
|
||||
// Act
|
||||
var result = await strategy.ProcessRequestAsync(_sampleClientId, _sampleRule, _mockCounterKeyBuilder.Object, _sampleOptions,
|
||||
CancellationToken.None);
|
||||
|
||||
// Assert
|
||||
Assert.Equal(0, result.Count);
|
||||
VerifyRedisCalls(Times.Never());
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public async Task SkipRateLimit_When_RedisTimeoutException()
|
||||
{
|
||||
// Arrange
|
||||
var mockCache = new Mock<IMemoryCache>();
|
||||
var mockCacheEntry = new Mock<ICacheEntry>();
|
||||
mockCacheEntry.SetupAllProperties();
|
||||
mockCache.Setup(x => x.CreateEntry(It.IsAny<object>())).Returns(mockCacheEntry.Object);
|
||||
|
||||
var strategy = BuildProcessingStrategy(mockCache: mockCache.Object, throwRedisTimeout: true);
|
||||
|
||||
// Act
|
||||
var result = await strategy.ProcessRequestAsync(_sampleClientId, _sampleRule, _mockCounterKeyBuilder.Object, _sampleOptions,
|
||||
CancellationToken.None);
|
||||
|
||||
var timeoutCounter = ((CustomRedisProcessingStrategy.TimeoutCounter)mockCacheEntry.Object.Value);
|
||||
|
||||
// Assert
|
||||
Assert.Equal(0, result.Count); // Skip rate limiting
|
||||
VerifyRedisCalls(Times.Once());
|
||||
|
||||
Assert.Equal(1, timeoutCounter.Count); // Timeout count increased/cached
|
||||
Assert.NotNull(mockCacheEntry.Object.AbsoluteExpiration);
|
||||
mockCache.Verify(x => x.CreateEntry(It.IsAny<object>()));
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public async Task BackoffRedis_After_ThresholdExceeded()
|
||||
{
|
||||
// Arrange
|
||||
var memoryCache = new MemoryCache(new MemoryCacheOptions());
|
||||
var strategy = BuildProcessingStrategy(mockCache: memoryCache, throwRedisTimeout: true);
|
||||
|
||||
// Act
|
||||
|
||||
// Redis Timeout 1
|
||||
await strategy.ProcessRequestAsync(_sampleClientId, _sampleRule, _mockCounterKeyBuilder.Object, _sampleOptions,
|
||||
CancellationToken.None);
|
||||
|
||||
// Redis Timeout 2
|
||||
await strategy.ProcessRequestAsync(_sampleClientId, _sampleRule, _mockCounterKeyBuilder.Object, _sampleOptions,
|
||||
CancellationToken.None);
|
||||
|
||||
// Skip Redis
|
||||
await strategy.ProcessRequestAsync(_sampleClientId, _sampleRule, _mockCounterKeyBuilder.Object, _sampleOptions,
|
||||
CancellationToken.None);
|
||||
|
||||
// Assert
|
||||
VerifyRedisCalls(Times.Exactly(_sampleSettings.DistributedIpRateLimiting.MaxRedisTimeoutsThreshold));
|
||||
}
|
||||
|
||||
private void VerifyRedisCalls(Times times)
|
||||
{
|
||||
_mockDb.Verify(x =>
|
||||
x.ScriptEvaluateAsync(It.IsAny<LuaScript>(), It.IsAny<object>(), It.IsAny<CommandFlags>()),
|
||||
times);
|
||||
}
|
||||
|
||||
private CustomRedisProcessingStrategy BuildProcessingStrategy(
|
||||
bool isRedisConnected = true,
|
||||
bool throwRedisTimeout = false,
|
||||
IMemoryCache mockCache = null)
|
||||
{
|
||||
var mockRedisConnection = new Mock<IConnectionMultiplexer>();
|
||||
|
||||
mockRedisConnection.Setup(x => x.IsConnected).Returns(isRedisConnected);
|
||||
|
||||
_mockDb = new Mock<IDatabase>();
|
||||
|
||||
var mockScriptEvaluate = _mockDb
|
||||
.Setup(x =>
|
||||
x.ScriptEvaluateAsync(It.IsAny<LuaScript>(), It.IsAny<object>(), It.IsAny<CommandFlags>()));
|
||||
|
||||
if (throwRedisTimeout)
|
||||
{
|
||||
mockScriptEvaluate.ThrowsAsync(new RedisTimeoutException("Timeout", CommandStatus.WaitingToBeSent));
|
||||
}
|
||||
else
|
||||
{
|
||||
mockScriptEvaluate.ReturnsAsync(RedisResult.Create(1));
|
||||
}
|
||||
|
||||
mockRedisConnection
|
||||
.Setup(x =>
|
||||
x.GetDatabase(It.IsAny<int>(), It.IsAny<object>()))
|
||||
.Returns(_mockDb.Object);
|
||||
|
||||
var mockLogger = new Mock<ILogger<CustomRedisProcessingStrategy>>();
|
||||
var mockConfig = new Mock<IRateLimitConfiguration>();
|
||||
|
||||
mockCache ??= new Mock<IMemoryCache>().Object;
|
||||
|
||||
return new CustomRedisProcessingStrategy(mockRedisConnection.Object, mockConfig.Object,
|
||||
mockLogger.Object, mockCache, _sampleSettings);
|
||||
}
|
||||
}
|
Loading…
Reference in New Issue
Block a user