using Microsoft.Extensions.Logging;
|
using Microsoft.Extensions.Options;
|
using StackExchange.Redis;
|
using WIDESEAWCS_RedisService.Connection;
|
using WIDESEAWCS_RedisService.Options;
|
|
namespace WIDESEAWCS_RedisService.RateLimiting
{
    /// <summary>
    /// Redis-backed <see cref="IRateLimitingService"/> offering fixed-window, sliding-window and
    /// token-bucket limiters. Each check-and-update is performed by a single Lua script, so it
    /// executes atomically on the Redis server.
    /// </summary>
    /// <remarks>
    /// The sliding-window and token-bucket limiters use this process's clock
    /// (<see cref="DateTimeOffset.UtcNow"/>); with several application instances they assume the
    /// instances' clocks are reasonably in sync.
    /// NOTE(review): <see cref="IsAllowed"/> (string counter) and <see cref="IsAllowedSliding"/>
    /// (sorted set) both store under <c>{KeyPrefix}rate:{key}</c>; using both limiters with the
    /// same key makes Redis reply WRONGTYPE. Keep their keys distinct.
    /// </remarks>
    public class RedisRateLimitingService : IRateLimitingService
    {
        private readonly IRedisConnectionManager _connectionManager;

        private readonly RedisOptions _options;

        private readonly ILogger<RedisRateLimitingService> _logger;

        // Fixed window counter.
        // KEYS[1] = counter key; ARGV[1] = limit; ARGV[2] = window length (ms).
        // Returns 1 if the request is admitted, 0 otherwise.
        private const string FixedWindowScript = @"
local key = KEYS[1]
local limit = tonumber(ARGV[1])
local window = tonumber(ARGV[2])
local current = tonumber(redis.call('GET', key) or '0')
local allowed = 0
if current < limit then
    redis.call('INCR', key)
    allowed = 1
end
-- Start the window on first use; also re-arm a counter that exists without a TTL
-- so it can never block the key permanently.
if redis.call('PTTL', key) == -1 then
    redis.call('PEXPIRE', key, window)
end
return allowed";

        // Sliding window log stored as a sorted set (score = request time in ms).
        // KEYS[1] = set key; ARGV[1] = limit; ARGV[2] = window (ms); ARGV[3] = now (ms);
        // ARGV[4] = unique member for this request.
        // Returns 1 if the request is admitted (and recorded), 0 otherwise.
        private const string SlidingWindowScript = @"
local key = KEYS[1]
local limit = tonumber(ARGV[1])
local window = tonumber(ARGV[2])
local now = tonumber(ARGV[3])
local member = ARGV[4]
redis.call('ZREMRANGEBYSCORE', key, '-inf', now - window)
if redis.call('ZCARD', key) < limit then
    redis.call('ZADD', key, now, member)
    redis.call('PEXPIRE', key, window + 1000)
    return 1
end
return 0";

        // Token bucket stored as a hash { tokens, last }.
        // KEYS[1] = bucket key; ARGV[1] = capacity; ARGV[2] = tokens added per interval;
        // ARGV[3] = interval (ms); ARGV[4] = now (ms).
        // Returns 1 if a token was taken, 0 otherwise.
        private const string TokenBucketScript = @"
local key = KEYS[1]
local max = tonumber(ARGV[1])
local rate = tonumber(ARGV[2])
local interval = tonumber(ARGV[3])
local now = tonumber(ARGV[4])
local info = redis.call('HMGET', key, 'tokens', 'last')
local tokens = tonumber(info[1]) or max
local last = tonumber(info[2]) or now
-- A clock that moved backwards must not produce a negative refill.
local elapsed = math.max(0, now - last)
local periods = math.floor(elapsed / interval)
tokens = tokens + periods * rate
if tokens >= max then
    -- Bucket full: surplus refills are discarded and the refill clock restarts now.
    tokens = max
    last = now
else
    -- Advance by whole intervals only, keeping partial progress toward the next refill.
    last = last + periods * interval
end
if tokens > 0 then
    tokens = tokens - 1
    redis.call('HMSET', key, 'tokens', tokens, 'last', last)
    -- Integral TTL (PEXPIRE rejects fractional values): twice the time to refill from empty.
    redis.call('PEXPIRE', key, math.ceil(max / rate) * interval * 2)
    return 1
end
return 0";

        public RedisRateLimitingService(
            IRedisConnectionManager connectionManager,
            IOptions<RedisOptions> options,
            ILogger<RedisRateLimitingService> logger)
        {
            _connectionManager = connectionManager;
            _options = options.Value;
            _logger = logger;
        }

        /// <summary>Builds the namespaced Redis key: <c>{KeyPrefix}rate:{key}</c>.</summary>
        private string BuildKey(string key) => $"{_options.KeyPrefix}rate:{key}";

        /// <summary>
        /// Converts a window/interval to whole milliseconds, rejecting values below 1 ms
        /// (a zero expiry would delete the key immediately and silently disable limiting).
        /// </summary>
        /// <exception cref="ArgumentOutOfRangeException">The value is shorter than 1 ms.</exception>
        private static long ToPositiveMilliseconds(TimeSpan value, string paramName)
        {
            var milliseconds = (long)value.TotalMilliseconds;
            if (milliseconds <= 0)
            {
                throw new ArgumentOutOfRangeException(paramName, value, "Must be at least 1 millisecond.");
            }

            return milliseconds;
        }

        /// <summary>
        /// Fixed-window check: admits at most <paramref name="maxRequests"/> requests per window,
        /// where the window starts with the first admitted request for <paramref name="key"/>.
        /// </summary>
        /// <param name="key">Caller-chosen limiter key (namespaced via the configured prefix).</param>
        /// <param name="maxRequests">Maximum admitted requests per window.</param>
        /// <param name="window">Window length; must be at least 1 ms.</param>
        /// <returns><c>true</c> if the request is admitted and counted.</returns>
        /// <exception cref="ArgumentOutOfRangeException"><paramref name="window"/> is shorter than 1 ms.</exception>
        public bool IsAllowed(string key, int maxRequests, TimeSpan window)
        {
            var windowMs = ToPositiveMilliseconds(window, nameof(window));
            var db = _connectionManager.GetDatabase();

            var result = db.ScriptEvaluate(
                FixedWindowScript,
                new RedisKey[] { BuildKey(key) },
                new RedisValue[] { maxRequests, windowMs });

            return (long)result == 1;
        }

        /// <summary>
        /// Sliding-window check: admits the request if fewer than <paramref name="maxRequests"/>
        /// requests were admitted in the trailing <paramref name="window"/>. Pruning, counting and
        /// recording happen in one atomic script, so concurrent callers cannot over-admit.
        /// </summary>
        /// <param name="key">Caller-chosen limiter key (namespaced via the configured prefix).</param>
        /// <param name="maxRequests">Maximum admitted requests within any trailing window.</param>
        /// <param name="window">Window length; must be at least 1 ms.</param>
        /// <returns><c>true</c> if the request is admitted and recorded.</returns>
        /// <exception cref="ArgumentOutOfRangeException"><paramref name="window"/> is shorter than 1 ms.</exception>
        public bool IsAllowedSliding(string key, int maxRequests, TimeSpan window)
        {
            var windowMs = ToPositiveMilliseconds(window, nameof(window));
            var db = _connectionManager.GetDatabase();
            var now = DateTimeOffset.UtcNow.ToUnixTimeMilliseconds();

            // Unique member: requests landing in the same millisecond must be counted separately.
            var member = Guid.NewGuid().ToString("N");

            var result = db.ScriptEvaluate(
                SlidingWindowScript,
                new RedisKey[] { BuildKey(key) },
                new RedisValue[] { maxRequests, windowMs, now, member });

            return (long)result == 1;
        }

        /// <summary>
        /// Token-bucket check: a bucket of <paramref name="maxTokens"/> tokens (initially full)
        /// gains <paramref name="refillRate"/> tokens every <paramref name="refillInterval"/>;
        /// each admitted request consumes one token.
        /// </summary>
        /// <param name="key">Caller-chosen bucket key (stored under <c>rate:token:{key}</c>).</param>
        /// <param name="maxTokens">Bucket capacity.</param>
        /// <param name="refillRate">Tokens added per interval; must be positive.</param>
        /// <param name="refillInterval">Refill period; must be at least 1 ms.</param>
        /// <returns><c>true</c> if a token was taken.</returns>
        /// <exception cref="ArgumentOutOfRangeException">
        /// <paramref name="refillRate"/> is not positive or <paramref name="refillInterval"/> is shorter than 1 ms.
        /// </exception>
        public bool TryAcquireToken(string key, int maxTokens, int refillRate, TimeSpan refillInterval)
        {
            if (refillRate <= 0)
            {
                throw new ArgumentOutOfRangeException(nameof(refillRate), refillRate, "Refill rate must be positive.");
            }

            var intervalMs = ToPositiveMilliseconds(refillInterval, nameof(refillInterval));
            var db = _connectionManager.GetDatabase();
            var now = DateTimeOffset.UtcNow.ToUnixTimeMilliseconds();

            var result = db.ScriptEvaluate(
                TokenBucketScript,
                new RedisKey[] { BuildKey($"token:{key}") },
                new RedisValue[] { maxTokens, refillRate, intervalMs, now });

            return (long)result == 1;
        }

        /// <summary>
        /// Remaining requests in the current fixed window (see <see cref="IsAllowed"/>), clamped at 0.
        /// Returns <paramref name="maxRequests"/> when no window is active. Reads the fixed-window
        /// counter only; <paramref name="window"/> is not used.
        /// </summary>
        public long GetRemainingRequests(string key, int maxRequests, TimeSpan window)
        {
            var db = _connectionManager.GetDatabase();
            var used = db.StringGet(BuildKey(key));
            if (used.IsNullOrEmpty)
            {
                return maxRequests;
            }

            return Math.Max(0, maxRequests - (long)used);
        }
    }
}
|