Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,16 @@ public static IApplicationBuilder UseRateLimiter(this IApplicationBuilder app, R

VerifyServicesAreRegistered(app);

return app.UseMiddleware<RateLimitingMiddleware>(Options.Create(options));
var middleware = app.ApplicationServices.GetService<RateLimitingMiddleware>();
if (middleware == null)
{
throw new InvalidOperationException(Resources.FormatUnableToFindServices(
nameof(IServiceCollection),
nameof(RateLimiterServiceCollectionExtensions.AddRateLimiter)));
}
middleware.Initialize(Options.Create(options));

return app.UseMiddleware<RateLimitingMiddleware>();
}

private static void VerifyServicesAreRegistered(IApplicationBuilder app)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ public static IServiceCollection AddRateLimiter(this IServiceCollection services

services.AddMetrics();
services.AddSingleton<RateLimitingMetrics>();
services.AddSingleton<RateLimitingMiddleware>();
return services;
}
}
56 changes: 37 additions & 19 deletions src/Middleware/RateLimiting/src/RateLimitingMiddleware.cs
Original file line number Diff line number Diff line change
Expand Up @@ -12,75 +12,93 @@ namespace Microsoft.AspNetCore.RateLimiting;
/// <summary>
/// Limits the rate of requests allowed in the application, based on limits set by a user-provided <see cref="PartitionedRateLimiter{TResource}"/>.
/// </summary>
internal sealed partial class RateLimitingMiddleware
internal sealed partial class RateLimitingMiddleware : IMiddleware, IDisposable
{
private readonly RequestDelegate _next;
private readonly Func<OnRejectedContext, CancellationToken, ValueTask>? _defaultOnRejected;
private Func<OnRejectedContext, CancellationToken, ValueTask>? _defaultOnRejected;
private readonly ILogger _logger;
private readonly RateLimitingMetrics _metrics;
private readonly PartitionedRateLimiter<HttpContext>? _globalLimiter;
private readonly PartitionedRateLimiter<HttpContext> _endpointLimiter;
private readonly int _rejectionStatusCode;
private readonly Dictionary<string, DefaultRateLimiterPolicy> _policyMap;
private readonly IServiceProvider _serviceProvider;
private PartitionedRateLimiter<HttpContext>? _globalLimiter;
private PartitionedRateLimiter<HttpContext> _endpointLimiter = null!;
private int _rejectionStatusCode;
private Dictionary<string, DefaultRateLimiterPolicy> _policyMap = null!;
private readonly DefaultKeyType _defaultPolicyKey = new DefaultKeyType("__defaultPolicy", new PolicyNameKey { PolicyName = "__defaultPolicyKey" });

/// <summary>
/// Creates a new <see cref="RateLimitingMiddleware"/>.
/// </summary>
/// <param name="next">The <see cref="RequestDelegate"/> representing the next middleware in the pipeline.</param>
/// <param name="logger">The <see cref="ILogger"/> used for logging.</param>
/// <param name="options">The options for the middleware.</param>
/// <param name="serviceProvider">The service provider.</param>
/// <param name="metrics">The rate limiting metrics.</param>
public RateLimitingMiddleware(RequestDelegate next, ILogger<RateLimitingMiddleware> logger, IOptions<RateLimiterOptions> options, IServiceProvider serviceProvider, RateLimitingMetrics metrics)
public RateLimitingMiddleware(ILogger<RateLimitingMiddleware> logger, IOptions<RateLimiterOptions> options, IServiceProvider serviceProvider, RateLimitingMetrics metrics)
{
ArgumentNullException.ThrowIfNull(next);
ArgumentNullException.ThrowIfNull(logger);
ArgumentNullException.ThrowIfNull(serviceProvider);
ArgumentNullException.ThrowIfNull(metrics);

_next = next;
_logger = logger;
_metrics = metrics;
_serviceProvider = serviceProvider;

Initialize(options);
}

/// <summary>
/// Initialize or re-initialize the rate limiter with new options. Enables overriding the options from UseRateLimiter(...)
/// </summary>
/// <param name="options"></param>
internal void Initialize(IOptions<RateLimiterOptions> options)
{
_defaultOnRejected = options.Value.OnRejected;
_rejectionStatusCode = options.Value.RejectionStatusCode;
_policyMap = new Dictionary<string, DefaultRateLimiterPolicy>(options.Value.PolicyMap);

// Activate policies passed to AddPolicy<TPartitionKey, TPolicy>
foreach (var unactivatedPolicy in options.Value.UnactivatedPolicyMap)
{
_policyMap.Add(unactivatedPolicy.Key, unactivatedPolicy.Value(serviceProvider));
_policyMap.Add(unactivatedPolicy.Key, unactivatedPolicy.Value(_serviceProvider));
}

_globalLimiter = options.Value.GlobalLimiter;

_endpointLimiter?.Dispose();
_endpointLimiter = CreateEndpointLimiter();
}

public void Dispose()
{
_endpointLimiter?.Dispose();
}

// TODO - EventSource?
/// <summary>
/// Invokes the logic of the middleware.
/// </summary>
/// <param name="context">The <see cref="HttpContext"/>.</param>
/// <param name="next">The <see cref="RequestDelegate"/> representing the next middleware in the pipeline.</param>
/// <returns>A <see cref="Task"/> that completes when the request leaves.</returns>
public Task Invoke(HttpContext context)
public Task InvokeAsync(HttpContext context, RequestDelegate next)
{
ArgumentNullException.ThrowIfNull(next);

var endpoint = context.GetEndpoint();
// If this endpoint has a DisableRateLimitingAttribute, don't apply any rate limits.
if (endpoint?.Metadata.GetMetadata<DisableRateLimitingAttribute>() is not null)
{
return _next(context);
return next(context);
}
var enableRateLimitingAttribute = endpoint?.Metadata.GetMetadata<EnableRateLimitingAttribute>();
// If this endpoint has no EnableRateLimitingAttribute & there's no global limiter, don't apply any rate limits.
if (enableRateLimitingAttribute is null && _globalLimiter is null)
{
return _next(context);
return next(context);
}

return InvokeInternal(context, enableRateLimitingAttribute);
return InvokeInternal(context, next, enableRateLimitingAttribute);
}

private async Task InvokeInternal(HttpContext context, EnableRateLimitingAttribute? enableRateLimitingAttribute)
private async Task InvokeInternal(HttpContext context, RequestDelegate next, EnableRateLimitingAttribute? enableRateLimitingAttribute)
{
var policyName = enableRateLimitingAttribute?.PolicyName;

Expand All @@ -100,7 +118,7 @@ private async Task InvokeInternal(HttpContext context, EnableRateLimitingAttribu
{

_metrics.LeaseStart(metricsContext);
await _next(context);
await next(context);
}
finally
{
Expand Down Expand Up @@ -233,7 +251,7 @@ private async ValueTask<LeaseContext> CombinedWaitAsync(HttpContext context, Can
{
endpointLease?.Dispose();
globalLease?.Dispose();
// Don't throw if the request was canceled - instead log.
// Don't throw if the request was canceled - instead log.
if (ex is OperationCanceledException && context.RequestAborted.IsCancellationRequested)
{
RateLimiterLog.RequestCanceled(_logger);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ public void UseRateLimiter_RespectsOptions()

// These should not get used
var services = new ServiceCollection();
services.AddScoped<IMiddlewareFactory, MiddlewareFactory>();
services.AddRateLimiter(options =>
{
options.GlobalLimiter = new TestPartitionedRateLimiter<HttpContext>(new TestRateLimiter(false));
Expand All @@ -57,7 +58,10 @@ public void UseRateLimiter_RespectsOptions()
// Act
appBuilder.UseRateLimiter(options);
var app = appBuilder.Build();
var context = new DefaultHttpContext();
var context = new DefaultHttpContext()
{
RequestServices = serviceProvider,
};
app.Invoke(context);
Assert.Equal(429, context.Response.StatusCode);
}
Expand All @@ -66,6 +70,7 @@ public void UseRateLimiter_RespectsOptions()
public async Task UseRateLimiter_DoNotThrowWithoutOptions()
{
var services = new ServiceCollection();
services.AddScoped<IMiddlewareFactory, MiddlewareFactory>();
services.AddRateLimiter();
services.AddLogging();
var serviceProvider = services.BuildServiceProvider();
Expand All @@ -74,7 +79,11 @@ public async Task UseRateLimiter_DoNotThrowWithoutOptions()
// Act
appBuilder.UseRateLimiter();
var app = appBuilder.Build();
var context = new DefaultHttpContext();
var context = new DefaultHttpContext()
{
RequestServices = serviceProvider,
};

var exception = await Record.ExceptionAsync(() => app.Invoke(context));

// Assert
Expand Down
56 changes: 28 additions & 28 deletions src/Middleware/RateLimiting/test/RateLimitingMetricsTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ public async Task Metrics_Rejected()
using var rateLimitingRequestsCollector = new MetricCollector<long>(meterFactory, RateLimitingMetrics.MeterName, "aspnetcore.rate_limiting.requests");

// Act
await middleware.Invoke(context).DefaultTimeout();
await middleware.InvokeAsync(context, c => Task.CompletedTask).DefaultTimeout();

// Assert
Assert.Equal(StatusCodes.Status503ServiceUnavailable, context.Response.StatusCode);
Expand Down Expand Up @@ -72,24 +72,25 @@ public async Task Metrics_Success()

var middleware = CreateTestRateLimitingMiddleware(
options,
meterFactory: meterFactory,
next: async c =>
{
await syncPoint.WaitToContinue();
});
meterFactory: meterFactory);
var meter = meterFactory.Meters.Single();

var context = new DefaultHttpContext();
context.Request.Method = "GET";

RequestDelegate next = async c =>
{
await syncPoint.WaitToContinue();
};

using var leaseRequestDurationCollector = new MetricCollector<double>(meterFactory, RateLimitingMetrics.MeterName, "aspnetcore.rate_limiting.request_lease.duration");
using var currentLeaseRequestsCollector = new MetricCollector<long>(meterFactory, RateLimitingMetrics.MeterName, "aspnetcore.rate_limiting.active_request_leases");
using var currentRequestsQueuedCollector = new MetricCollector<long>(meterFactory, RateLimitingMetrics.MeterName, "aspnetcore.rate_limiting.queued_requests");
using var queuedRequestDurationCollector = new MetricCollector<double>(meterFactory, RateLimitingMetrics.MeterName, "aspnetcore.rate_limiting.request.time_in_queue");
using var rateLimitingRequestsCollector = new MetricCollector<long>(meterFactory, RateLimitingMetrics.MeterName, "aspnetcore.rate_limiting.requests");

// Act
var middlewareTask = middleware.Invoke(context);
var middlewareTask = middleware.InvokeAsync(context, next);

await syncPoint.WaitForSyncPoint().DefaultTimeout();

Expand Down Expand Up @@ -131,18 +132,17 @@ public async Task Metrics_ListenInMiddleOfRequest_CurrentLeasesNotDecreased()

var middleware = CreateTestRateLimitingMiddleware(
options,
meterFactory: meterFactory,
next: async c =>
{
await syncPoint.WaitToContinue();
});
meterFactory: meterFactory);
var meter = meterFactory.Meters.Single();

var context = new DefaultHttpContext();
context.Request.Method = "GET";

// Act
var middlewareTask = middleware.Invoke(context);
var middlewareTask = middleware.InvokeAsync(context, async c =>
{
await syncPoint.WaitToContinue();
});

await syncPoint.WaitForSyncPoint().DefaultTimeout();

Expand Down Expand Up @@ -186,17 +186,18 @@ public async Task Metrics_Queued()
var middleware = CreateTestRateLimitingMiddleware(
serviceProvider.GetRequiredService<IOptions<RateLimiterOptions>>(),
meterFactory: meterFactory,
next: async c =>
{
await syncPoint.WaitToContinue();
},
serviceProvider: serviceProvider);
var meter = meterFactory.Meters.Single();

var routeEndpointBuilder = new RouteEndpointBuilder(c => Task.CompletedTask, RoutePatternFactory.Parse("/"), 0);
routeEndpointBuilder.Metadata.Add(new EnableRateLimitingAttribute("concurrencyPolicy"));
var endpoint = routeEndpointBuilder.Build();

RequestDelegate next = async c =>
{
await syncPoint.WaitToContinue();
};

using var leaseRequestDurationCollector = new MetricCollector<double>(meterFactory, RateLimitingMetrics.MeterName, "aspnetcore.rate_limiting.request_lease.duration");
using var currentLeaseRequestsCollector = new MetricCollector<long>(meterFactory, RateLimitingMetrics.MeterName, "aspnetcore.rate_limiting.active_request_leases");
using var currentRequestsQueuedCollector = new MetricCollector<long>(meterFactory, RateLimitingMetrics.MeterName, "aspnetcore.rate_limiting.queued_requests");
Expand All @@ -207,15 +208,15 @@ public async Task Metrics_Queued()
var context1 = new DefaultHttpContext();
context1.Request.Method = "GET";
context1.SetEndpoint(endpoint);
var middlewareTask1 = middleware.Invoke(context1);
var middlewareTask1 = middleware.InvokeAsync(context1, next);

// Wait for first request to reach server and block it.
await syncPoint.WaitForSyncPoint().DefaultTimeout();

var context2 = new DefaultHttpContext();
context2.Request.Method = "GET";
context2.SetEndpoint(endpoint);
var middlewareTask2 = middleware.Invoke(context1);
var middlewareTask2 = middleware.InvokeAsync(context1, next);

// Assert second request is queued.
Assert.Collection(currentRequestsQueuedCollector.GetMeasurementSnapshot(),
Expand Down Expand Up @@ -261,30 +262,31 @@ public async Task Metrics_ListenInMiddleOfQueued_CurrentQueueNotDecreased()
var middleware = CreateTestRateLimitingMiddleware(
serviceProvider.GetRequiredService<IOptions<RateLimiterOptions>>(),
meterFactory: meterFactory,
next: async c =>
{
await syncPoint.WaitToContinue();
},
serviceProvider: serviceProvider);
var meter = meterFactory.Meters.Single();

var routeEndpointBuilder = new RouteEndpointBuilder(c => Task.CompletedTask, RoutePatternFactory.Parse("/"), 0);
routeEndpointBuilder.Metadata.Add(new EnableRateLimitingAttribute("concurrencyPolicy"));
var endpoint = routeEndpointBuilder.Build();

RequestDelegate next = async c =>
{
await syncPoint.WaitToContinue();
};

// Act
var context1 = new DefaultHttpContext();
context1.Request.Method = "GET";
context1.SetEndpoint(endpoint);
var middlewareTask1 = middleware.Invoke(context1);
var middlewareTask1 = middleware.InvokeAsync(context1, next);

// Wait for first request to reach server and block it.
await syncPoint.WaitForSyncPoint().DefaultTimeout();

var context2 = new DefaultHttpContext();
context2.Request.Method = "GET";
context2.SetEndpoint(endpoint);
var middlewareTask2 = middleware.Invoke(context1);
var middlewareTask2 = middleware.InvokeAsync(context1, next);

// Start listening while the second request is queued.

Expand Down Expand Up @@ -332,11 +334,9 @@ private static void AssertTag<T>(IReadOnlyDictionary<string, object> tags, strin
}
}

private RateLimitingMiddleware CreateTestRateLimitingMiddleware(IOptions<RateLimiterOptions> options, ILogger<RateLimitingMiddleware> logger = null, IServiceProvider serviceProvider = null, IMeterFactory meterFactory = null, RequestDelegate next = null)
private RateLimitingMiddleware CreateTestRateLimitingMiddleware(IOptions<RateLimiterOptions> options, ILogger<RateLimitingMiddleware> logger = null, IServiceProvider serviceProvider = null, IMeterFactory meterFactory = null)
{
next ??= c => Task.CompletedTask;
return new RateLimitingMiddleware(
next,
logger ?? new NullLoggerFactory().CreateLogger<RateLimitingMiddleware>(),
options,
serviceProvider ?? Mock.Of<IServiceProvider>(),
Expand Down
Loading
Loading