1
0
mirror of https://github.com/bitwarden/server.git synced 2025-02-22 02:51:33 +01:00

[EC-502] Rate Limiting Improvements (#2231)

* [EC-502] Add custom Redis IP rate limit processing strategy

* [EC-502] Formatting

* [EC-502] Add documentation and app setting config options

* [EC-502] Formatting

* [EC-502] Fix appsettings.json keys

* [EC-502] Replace magic string for cache key

* [EC-502] Add tests for custom processing strategy

* [EC-502] Formatting

* [EC-502] Use base class for custom processing strategy

* [EC-502] Fix failing test
This commit is contained in:
Shane Melton 2022-08-31 14:17:29 -07:00 committed by GitHub
parent e0f9d99b49
commit 2bf8438ff7
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 340 additions and 3 deletions

View File

@ -69,6 +69,11 @@
"accessKeyId": "SECRET",
"accessKeySecret": "SECRET",
"region": "SECRET"
},
"distributedIpRateLimiting": {
"enabled": true,
"maxRedisTimeoutsThreshold": 10,
"slidingWindowSeconds": 120
}
},
"IpRateLimitOptions": {

View File

@ -69,6 +69,8 @@ public class GlobalSettings : IGlobalSettings
public virtual ISsoSettings Sso { get; set; } = new SsoSettings();
public virtual StripeSettings Stripe { get; set; } = new StripeSettings();
public virtual ITwoFactorAuthSettings TwoFactorAuth { get; set; } = new TwoFactorAuthSettings();
public virtual DistributedIpRateLimitingSettings DistributedIpRateLimiting { get; set; } =
new DistributedIpRateLimitingSettings();
public string BuildExternalUri(string explicitValue, string name)
{
@ -498,4 +500,23 @@ public class GlobalSettings : IGlobalSettings
{
public bool EmailOnNewDeviceLogin { get; set; } = false;
}
public class DistributedIpRateLimitingSettings
{
public bool Enabled { get; set; } = true;
/// <summary>
/// Maximum number of Redis timeouts that can be experienced within the sliding timeout
/// window before IP rate limiting is temporarily disabled.
/// TODO: Determine/discuss a suitable maximum
/// </summary>
public int MaxRedisTimeoutsThreshold { get; set; } = 10;
/// <summary>
/// Length of the sliding window in seconds to track Redis timeout exceptions.
/// TODO: Determine/discuss a suitable sliding window
/// </summary>
public int SlidingWindowSeconds { get; set; } = 120;
}
}

View File

@ -0,0 +1,102 @@
using AspNetCoreRateLimit;
using AspNetCoreRateLimit.Redis;
using Bit.Core.Settings;
using Microsoft.Extensions.Caching.Memory;
using Microsoft.Extensions.Logging;
using StackExchange.Redis;
namespace Bit.Core.Utilities;
/// <summary>
/// A modified version of <see cref="AspNetCoreRateLimit.Redis.RedisProcessingStrategy"/> that gracefully
/// handles a disrupted Redis connection. If the connection is down or the number of failed requests within
/// a given time period exceed the configured threshold, then rate limiting is temporarily disabled.
/// </summary>
/// <remarks>
/// This is necessary to ensure the service does not become unresponsive due to Redis being out of service. As
/// the default implementation would throw an exception and exit the request pipeline for all requests.
/// </remarks>
public class CustomRedisProcessingStrategy : RedisProcessingStrategy
{
private readonly IConnectionMultiplexer _connectionMultiplexer;
private readonly ILogger<CustomRedisProcessingStrategy> _logger;
private readonly IMemoryCache _memoryCache;
private readonly GlobalSettings.DistributedIpRateLimitingSettings _distributedSettings;
private const string _redisTimeoutCacheKey = "IpRateLimitRedisTimeout";
public CustomRedisProcessingStrategy(
IConnectionMultiplexer connectionMultiplexer,
IRateLimitConfiguration config,
ILogger<CustomRedisProcessingStrategy> logger,
IMemoryCache memoryCache,
GlobalSettings globalSettings)
: base(connectionMultiplexer, config, logger)
{
_connectionMultiplexer = connectionMultiplexer;
_logger = logger;
_memoryCache = memoryCache;
_distributedSettings = globalSettings.DistributedIpRateLimiting;
}
public override async Task<RateLimitCounter> ProcessRequestAsync(ClientRequestIdentity requestIdentity,
RateLimitRule rule, ICounterKeyBuilder counterKeyBuilder, RateLimitOptions rateLimitOptions,
CancellationToken cancellationToken = default)
{
// If Redis is down entirely, skip rate limiting
if (!_connectionMultiplexer.IsConnected)
{
_logger.LogDebug("Redis connection is down, skipping IP rate limiting");
return SkipRateLimitResult();
}
// Check if any Redis timeouts have occured recently
if (_memoryCache.TryGetValue<TimeoutCounter>(_redisTimeoutCacheKey, out var timeoutCounter))
{
// We've exceeded threshold, backoff Redis and skip rate limiting for now
if (timeoutCounter.Count >= _distributedSettings.MaxRedisTimeoutsThreshold)
{
_logger.LogDebug(
"Redis timeout threshold has been exceeded, backing off and skipping IP rate limiting");
return SkipRateLimitResult();
}
}
try
{
return await base.ProcessRequestAsync(requestIdentity, rule, counterKeyBuilder, rateLimitOptions, cancellationToken);
}
catch (RedisTimeoutException)
{
// If this is the first timeout we've had, start a new counter and sliding window
timeoutCounter ??= new TimeoutCounter()
{
Count = 0,
ExpiresAt = DateTime.UtcNow.AddSeconds(_distributedSettings.SlidingWindowSeconds)
};
timeoutCounter.Count++;
_memoryCache.Set(_redisTimeoutCacheKey, timeoutCounter,
new MemoryCacheEntryOptions { AbsoluteExpiration = timeoutCounter.ExpiresAt });
// Just because Redis timed out does not mean we should kill the request
return SkipRateLimitResult();
}
}
/// <summary>
/// A RateLimitCounter result used when the rate limiting middleware should
/// fail open and allow the request to proceed without checking request limits.
/// </summary>
private static RateLimitCounter SkipRateLimitResult()
{
return new RateLimitCounter { Count = 0, Timestamp = DateTime.UtcNow };
}
internal class TimeoutCounter
{
public DateTime ExpiresAt { get; init; }
public int Count { get; set; }
}
}

View File

@ -59,6 +59,11 @@
"accessKeyId": "SECRET",
"accessKeySecret": "SECRET",
"region": "SECRET"
},
"distributedIpRateLimiting": {
"enabled": true,
"maxRedisTimeoutsThreshold": 10,
"slidingWindowSeconds": 120
}
},
"IpRateLimitOptions": {

View File

@ -2,7 +2,6 @@
using System.Security.Claims;
using System.Security.Cryptography.X509Certificates;
using AspNetCoreRateLimit;
using AspNetCoreRateLimit.Redis;
using Bit.Core.Entities;
using Bit.Core.Enums;
using Bit.Core.HostedServices;
@ -609,13 +608,20 @@ public static class ServiceCollectionExtensions
services.AddHostedService<IpRateLimitSeedStartupService>();
services.AddSingleton<IRateLimitConfiguration, RateLimitConfiguration>();
if (string.IsNullOrEmpty(globalSettings.Redis.ConnectionString))
if (!globalSettings.DistributedIpRateLimiting.Enabled || string.IsNullOrEmpty(globalSettings.Redis.ConnectionString))
{
services.AddInMemoryRateLimiting();
}
else
{
services.AddRedisRateLimiting(); // Requires a registered IConnectionMultiplexer
// Use memory stores for Ip and Client Policy stores as we don't currently use them
// and they add unnecessary Redis network delays checking for policies that don't exist
services.AddSingleton<IIpPolicyStore, MemoryCacheIpPolicyStore>();
services.AddSingleton<IClientPolicyStore, MemoryCacheClientPolicyStore>();
// Use a custom Redis processing strategy that skips Ip limiting if Redis is down
// Requires a registered IConnectionMultiplexer
services.AddSingleton<IProcessingStrategy, CustomRedisProcessingStrategy>();
}
}

View File

@ -0,0 +1,198 @@
using AspNetCoreRateLimit;
using Bit.Core.Settings;
using Bit.Core.Utilities;
using Microsoft.Extensions.Caching.Memory;
using Microsoft.Extensions.Logging;
using Moq;
using StackExchange.Redis;
using Xunit;
namespace Bit.Core.Test.Utilities;
public class CustomRedisProcessingStrategyTests
{
#region Sample RateLimit Options for Testing
private readonly GlobalSettings _sampleSettings = new()
{
DistributedIpRateLimiting = new GlobalSettings.DistributedIpRateLimitingSettings
{
Enabled = true,
MaxRedisTimeoutsThreshold = 2,
SlidingWindowSeconds = 5
}
};
private readonly ClientRequestIdentity _sampleClientId = new()
{
ClientId = "test",
ClientIp = "127.0.0.1",
HttpVerb = "GET",
Path = "/"
};
private readonly RateLimitRule _sampleRule = new() { Endpoint = "/", Limit = 5, Period = "1m", PeriodTimespan = TimeSpan.FromMinutes(1) };
private readonly RateLimitOptions _sampleOptions = new() { };
#endregion
private readonly Mock<ICounterKeyBuilder> _mockCounterKeyBuilder = new();
private Mock<IDatabase> _mockDb;
public CustomRedisProcessingStrategyTests()
{
_mockCounterKeyBuilder
.Setup(x =>
x.Build(It.IsAny<ClientRequestIdentity>(), It.IsAny<RateLimitRule>()))
.Returns(_sampleClientId.ClientId);
}
[Fact]
public async Task IncrementRateLimitCount_When_RedisIsHealthy()
{
// Arrange
var strategy = BuildProcessingStrategy();
// Act
var result = await strategy.ProcessRequestAsync(_sampleClientId, _sampleRule, _mockCounterKeyBuilder.Object, _sampleOptions,
CancellationToken.None);
// Assert
Assert.Equal(1, result.Count);
VerifyRedisCalls(Times.Once());
}
[Fact]
public async Task SkipRateLimit_When_RedisIsDown()
{
// Arrange
var strategy = BuildProcessingStrategy(false);
// Act
var result = await strategy.ProcessRequestAsync(_sampleClientId, _sampleRule, _mockCounterKeyBuilder.Object, _sampleOptions,
CancellationToken.None);
// Assert
Assert.Equal(0, result.Count);
VerifyRedisCalls(Times.Never());
}
[Fact]
public async Task SkipRateLimit_When_TimeoutThresholdExceeded()
{
// Arrange
var mockCache = new Mock<IMemoryCache>();
object existingCount = new CustomRedisProcessingStrategy.TimeoutCounter
{
Count = _sampleSettings.DistributedIpRateLimiting.MaxRedisTimeoutsThreshold + 1
};
mockCache.Setup(x => x.TryGetValue(It.IsAny<object>(), out existingCount)).Returns(true);
var strategy = BuildProcessingStrategy(mockCache: mockCache.Object);
// Act
var result = await strategy.ProcessRequestAsync(_sampleClientId, _sampleRule, _mockCounterKeyBuilder.Object, _sampleOptions,
CancellationToken.None);
// Assert
Assert.Equal(0, result.Count);
VerifyRedisCalls(Times.Never());
}
[Fact]
public async Task SkipRateLimit_When_RedisTimeoutException()
{
// Arrange
var mockCache = new Mock<IMemoryCache>();
var mockCacheEntry = new Mock<ICacheEntry>();
mockCacheEntry.SetupAllProperties();
mockCache.Setup(x => x.CreateEntry(It.IsAny<object>())).Returns(mockCacheEntry.Object);
var strategy = BuildProcessingStrategy(mockCache: mockCache.Object, throwRedisTimeout: true);
// Act
var result = await strategy.ProcessRequestAsync(_sampleClientId, _sampleRule, _mockCounterKeyBuilder.Object, _sampleOptions,
CancellationToken.None);
var timeoutCounter = ((CustomRedisProcessingStrategy.TimeoutCounter)mockCacheEntry.Object.Value);
// Assert
Assert.Equal(0, result.Count); // Skip rate limiting
VerifyRedisCalls(Times.Once());
Assert.Equal(1, timeoutCounter.Count); // Timeout count increased/cached
Assert.NotNull(mockCacheEntry.Object.AbsoluteExpiration);
mockCache.Verify(x => x.CreateEntry(It.IsAny<object>()));
}
[Fact]
public async Task BackoffRedis_After_ThresholdExceeded()
{
// Arrange
var memoryCache = new MemoryCache(new MemoryCacheOptions());
var strategy = BuildProcessingStrategy(mockCache: memoryCache, throwRedisTimeout: true);
// Act
// Redis Timeout 1
await strategy.ProcessRequestAsync(_sampleClientId, _sampleRule, _mockCounterKeyBuilder.Object, _sampleOptions,
CancellationToken.None);
// Redis Timeout 2
await strategy.ProcessRequestAsync(_sampleClientId, _sampleRule, _mockCounterKeyBuilder.Object, _sampleOptions,
CancellationToken.None);
// Skip Redis
await strategy.ProcessRequestAsync(_sampleClientId, _sampleRule, _mockCounterKeyBuilder.Object, _sampleOptions,
CancellationToken.None);
// Assert
VerifyRedisCalls(Times.Exactly(_sampleSettings.DistributedIpRateLimiting.MaxRedisTimeoutsThreshold));
}
private void VerifyRedisCalls(Times times)
{
_mockDb.Verify(x =>
x.ScriptEvaluateAsync(It.IsAny<LuaScript>(), It.IsAny<object>(), It.IsAny<CommandFlags>()),
times);
}
private CustomRedisProcessingStrategy BuildProcessingStrategy(
bool isRedisConnected = true,
bool throwRedisTimeout = false,
IMemoryCache mockCache = null)
{
var mockRedisConnection = new Mock<IConnectionMultiplexer>();
mockRedisConnection.Setup(x => x.IsConnected).Returns(isRedisConnected);
_mockDb = new Mock<IDatabase>();
var mockScriptEvaluate = _mockDb
.Setup(x =>
x.ScriptEvaluateAsync(It.IsAny<LuaScript>(), It.IsAny<object>(), It.IsAny<CommandFlags>()));
if (throwRedisTimeout)
{
mockScriptEvaluate.ThrowsAsync(new RedisTimeoutException("Timeout", CommandStatus.WaitingToBeSent));
}
else
{
mockScriptEvaluate.ReturnsAsync(RedisResult.Create(1));
}
mockRedisConnection
.Setup(x =>
x.GetDatabase(It.IsAny<int>(), It.IsAny<object>()))
.Returns(_mockDb.Object);
var mockLogger = new Mock<ILogger<CustomRedisProcessingStrategy>>();
var mockConfig = new Mock<IRateLimitConfiguration>();
mockCache ??= new Mock<IMemoryCache>().Object;
return new CustomRedisProcessingStrategy(mockRedisConnection.Object, mockConfig.Object,
mockLogger.Object, mockCache, _sampleSettings);
}
}