ASP.NET Core 并发限制中间件（ConcurrencyLimiter）

起因

前几天在阅读 ASP.NET Core 中间件的相关源码时，看到了这个并发限制中间件，就顺便研究了一下它的实现。

1. 如何使用

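并发限制中间件提供两种排队策略：
  1. 队列策略（QueuePolicy）：等待中的请求先进先出（FIFO）；队列已满时，新来的请求直接被拒绝（返回 503）。
  2. 栈策略（StackPolicy）：等待中的请求后进先出（LIFO）；栈已满时，最早进入等待的请求被拒绝（返回 503），为新请求腾出位置。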

配置示例如下：


// Use the queue (FIFO) policy.
services.AddQueuePolicy(options =>
{

    // Maximum number of requests executing at the same time; further requests wait in the queue.
    options.MaxConcurrentRequests = 100;

    // Maximum number of requests allowed to wait in the queue; once it is full, new requests are rejected with 503.
    options.RequestQueueLimit = 100;

});

// Use the stack (LIFO) policy.
services.AddStackPolicy(options =>
{
    // Maximum number of requests executing at the same time; further requests wait on the stack.
    options.MaxConcurrentRequests = 100;

    // Maximum number of requests allowed to wait; once full, the oldest waiting request is rejected with 503 to make room for the new one.
    options.RequestQueueLimit = 100;

});
// If both policies are registered, the one registered later overrides the earlier one.
// Enable the concurrency-limiter middleware.
app.UseConcurrencyLimiter();

2. 源码实现

注：为了方便调试，我把这些源码拷贝到了自己的项目中，并在拷贝出来的类名末尾都加了 1（下文贴出的代码仍使用原类名）。
QueuePolicyServiceCollectionExtensions源码:
/// <summary>
/// Extension methods that select the queueing strategy used by <see cref="ConcurrencyLimiterMiddleware"/>.
/// </summary>
public static class QueuePolicyServiceCollectionExtensions
{
    /// <summary>
    /// Registers a FIFO queue (QueuePolicy) as the singleton <see cref="IQueuePolicy"/>
    /// consumed by <see cref="ConcurrencyLimiterMiddleware"/>.
    /// </summary>
    /// <param name="services">The <see cref="IServiceCollection"/> to add services to.</param>
    /// <param name="configure">Configures the <see cref="QueuePolicyOptions"/>.
    /// Mandatory, because <see cref="QueuePolicyOptions.MaxConcurrentRequests"/> must be provided.</param>
    /// <returns>The <see cref="IServiceCollection"/>, so that further calls can be chained.</returns>
    public static IServiceCollection AddQueuePolicy(this IServiceCollection services, Action<QueuePolicyOptions> configure) =>
        services
            .Configure(configure)
            .AddSingleton<IQueuePolicy, QueuePolicy>();

    /// <summary>
    /// Registers a LIFO stack (StackPolicy) as the singleton <see cref="IQueuePolicy"/>
    /// consumed by <see cref="ConcurrencyLimiterMiddleware"/>.
    /// </summary>
    /// <param name="services">The <see cref="IServiceCollection"/> to add services to.</param>
    /// <param name="configure">Configures the <see cref="QueuePolicyOptions"/>.
    /// Mandatory, because <see cref="QueuePolicyOptions.MaxConcurrentRequests"/> must be provided.</param>
    /// <returns>The <see cref="IServiceCollection"/>, so that further calls can be chained.</returns>
    public static IServiceCollection AddStackPolicy(this IServiceCollection services, Action<QueuePolicyOptions> configure) =>
        services
            .Configure(configure)
            .AddSingleton<IQueuePolicy, StackPolicy>();
}

QueuePolicy源码:

/// <summary>
/// FIFO queueing strategy for <see cref="ConcurrencyLimiterMiddleware"/>: at most
/// <c>MaxConcurrentRequests</c> requests run at once, at most <c>RequestQueueLimit</c> more wait
/// on a <see cref="SemaphoreSlim"/>, and any request beyond that total is rejected.
/// </summary>
internal class QueuePolicy : IQueuePolicy, IDisposable
{
    private readonly int _maxConcurrentRequests;
    private readonly int _requestQueueLimit;
    private readonly SemaphoreSlim _serverSemaphore;

    // Guards every read-modify-write of TotalRequests.
    private readonly object _totalRequestsLock = new object();

    /// <summary>
    /// Number of requests admitted by <see cref="TryEnterAsync"/> and not yet released by
    /// <see cref="OnExit"/> (running plus waiting for a semaphore slot).
    /// </summary>
    public int TotalRequests { get; private set; }

    /// <summary>
    /// Creates the policy from the configured <see cref="QueuePolicyOptions"/>.
    /// </summary>
    /// <param name="options">Supplies <c>MaxConcurrentRequests</c> and <c>RequestQueueLimit</c>.</param>
    /// <exception cref="ArgumentException">
    /// <c>MaxConcurrentRequests</c> is not positive, or <c>RequestQueueLimit</c> is negative.
    /// </exception>
    public QueuePolicy(IOptions<QueuePolicyOptions> options)
    {
        _maxConcurrentRequests = options.Value.MaxConcurrentRequests;
        if (_maxConcurrentRequests <= 0)
        {
            // ArgumentException takes (message, paramName); the arguments were previously swapped,
            // which put the field name into Message and the sentence into ParamName.
            throw new ArgumentException("MaxConcurrentRequests must be a positive integer.", nameof(options));
        }

        _requestQueueLimit = options.Value.RequestQueueLimit;
        if (_requestQueueLimit < 0)
        {
            throw new ArgumentException("The RequestQueueLimit cannot be a negative number.", nameof(options));
        }

        // One semaphore slot per request that may execute concurrently.
        _serverSemaphore = new SemaphoreSlim(_maxConcurrentRequests);

        // Debug trace: shows when the policy instance is constructed.
        Console.WriteLine("QueuePolicy ctor");
    }

    /// <summary>
    /// Tries to admit a request.
    /// </summary>
    /// <returns>
    /// <c>false</c> if the request is rejected because the admitted requests already reach
    /// <c>RequestQueueLimit + MaxConcurrentRequests</c>; otherwise <c>true</c> once a semaphore
    /// slot has been acquired and the request may proceed.
    /// </returns>
    /// <remarks>
    /// The semaphore is not released here; after a <c>true</c> result the caller must call
    /// <see cref="OnExit"/> when the request leaves the server.
    /// </remarks>
    public async ValueTask<bool> TryEnterAsync()
    {
        lock (_totalRequestsLock)
        {
            // Sum in 64-bit arithmetic: with a very large RequestQueueLimit (e.g. int.MaxValue)
            // the int sum would overflow to a negative number and every request would be rejected.
            if (TotalRequests >= (long)_requestQueueLimit + _maxConcurrentRequests)
            {
                return false;
            }

            TotalRequests++;
        }

        // Wait asynchronously for a free slot; this takes one slot from the semaphore.
        await _serverSemaphore.WaitAsync();

        return true;
    }

    /// <summary>
    /// Releases the slot taken by a successful <see cref="TryEnterAsync"/>.
    /// </summary>
    public void OnExit()
    {
        // Return the slot to the semaphore so that a waiting request can proceed.
        _serverSemaphore.Release();

        lock (_totalRequestsLock)
        {
            TotalRequests--;
        }
    }

    /// <summary>
    /// Disposes the underlying <see cref="SemaphoreSlim"/>.
    /// </summary>
    public void Dispose()
    {
        _serverSemaphore.Dispose();
    }
}

StackPolicy源码:

/// <summary>
/// LIFO queueing strategy for <see cref="ConcurrencyLimiterMiddleware"/>: at most
/// <c>MaxConcurrentRequests</c> requests run at once; waiting requests are kept in a circular
/// buffer of size <c>RequestQueueLimit</c>, and the most recently queued one is released first.
/// When the buffer is full, the oldest waiting request is rejected to make room for the new one.
/// </summary>
internal class StackPolicy : IQueuePolicy
{
    // Circular buffer of waiting requests; grows up to _maxQueueCapacity entries, after which slots are reused.
    private readonly List<ResettableBooleanCompletionSource> _buffer;

    // A completion source available for reuse. Public, presumably because ResettableBooleanCompletionSource
    // (constructed with this policy) hands itself back here after being reset — TODO confirm.
    public ResettableBooleanCompletionSource _cachedResettableTCS;

    private readonly int _maxQueueCapacity;
    private readonly int _maxConcurrentRequests;

    // Set once the buffer has been filled; from then on existing slots are overwritten instead of appended.
    private bool _hasReachedCapacity;

    // Index of the slot the next queued request is written to.
    private int _head;

    // Number of requests currently waiting in the buffer.
    private int _queueLength;

    private readonly object _bufferLock = new object();

    private static readonly ValueTask<bool> _trueTask = new ValueTask<bool>(true);
    private static readonly ValueTask<bool> _falseTask = new ValueTask<bool>(false);

    // Number of requests that may still start immediately without waiting.
    private int _freeServerSpots;

    /// <summary>
    /// Creates the policy from the configured <see cref="QueuePolicyOptions"/>.
    /// </summary>
    /// <param name="options">Supplies <c>MaxConcurrentRequests</c> and <c>RequestQueueLimit</c>.</param>
    /// <exception cref="ArgumentException">
    /// <c>MaxConcurrentRequests</c> is not positive (no request could ever run), or
    /// <c>RequestQueueLimit</c> is negative. These are the same rules QueuePolicy enforces.
    /// </exception>
    public StackPolicy(IOptions<QueuePolicyOptions> options)
    {
        _buffer = new List<ResettableBooleanCompletionSource>();

        _maxQueueCapacity = options.Value.RequestQueueLimit;
        if (_maxQueueCapacity < 0)
        {
            throw new ArgumentException("The RequestQueueLimit cannot be a negative number.", nameof(options));
        }

        _maxConcurrentRequests = options.Value.MaxConcurrentRequests;
        if (_maxConcurrentRequests <= 0)
        {
            throw new ArgumentException("MaxConcurrentRequests must be a positive integer.", nameof(options));
        }

        _freeServerSpots = options.Value.MaxConcurrentRequests;

        // Debug trace: shows when the policy instance is constructed.
        Console.WriteLine("StackPolicy ctor");
    }

    /// <summary>
    /// Tries to admit a request.
    /// </summary>
    /// <returns>
    /// An already-completed <c>true</c> task when a free slot exists; an already-completed
    /// <c>false</c> task when no slot is free and <c>RequestQueueLimit</c> is 0; otherwise a pending
    /// task that completes with <c>true</c> when <see cref="OnExit"/> hands it a slot, or with
    /// <c>false</c> if it is evicted as the oldest waiter.
    /// </returns>
    public ValueTask<bool> TryEnterAsync()
    {
        lock (_bufferLock)
        {
            // Fast path: a free slot means the request runs immediately.
            if (_freeServerSpots > 0)
            {
                _freeServerSpots--;
                return _trueTask;
            }

            // With a queue capacity of 0 there is nowhere to park the request, and the eviction
            // below would index an empty buffer (ArgumentOutOfRangeException), so reject outright.
            if (_maxQueueCapacity == 0)
            {
                return _falseTask;
            }

            // Queue full: the slot at _head holds the oldest waiting request, so reject it to make room.
            if (_queueLength == _maxQueueCapacity)
            {
                _hasReachedCapacity = true;
                _buffer[_head].Complete(false);
                _queueLength--;
            }

            // Take the cached completion source if there is one, otherwise allocate a new one.
            var tcs = _cachedResettableTCS ??= new ResettableBooleanCompletionSource(this);
            _cachedResettableTCS = null;

            // Reuse an existing slot once the buffer has grown past _head; otherwise append.
            if (_hasReachedCapacity || _queueLength < _buffer.Count)
            {
                _buffer[_head] = tcs;
            }
            else
            {
                _buffer.Add(tcs);
            }
            _queueLength++;

            // Advance _head circularly for the next request.
            _head++;
            if (_head == _maxQueueCapacity)
            {
                _head = 0;
            }

            return tcs.GetValueTask();
        }
    }

    /// <summary>
    /// Called when an admitted request finishes. Hands its slot to the most recently queued
    /// request, or returns it to the free pool if nobody is waiting.
    /// </summary>
    /// <exception cref="InvalidOperationException">
    /// Called more often than <see cref="TryEnterAsync"/> succeeded.
    /// </exception>
    public void OnExit()
    {
        lock (_bufferLock)
        {
            if (_queueLength == 0)
            {
                _freeServerSpots++;

                if (_freeServerSpots > _maxConcurrentRequests)
                {
                    _freeServerSpots--;
                    throw new InvalidOperationException("OnExit must only be called once per successful call to TryEnterAsync");
                }

                return;
            }

            // Step _head back circularly to the most recently queued request and release it.
            if (_head == 0)
            {
                _head = _maxQueueCapacity - 1;
            }
            else
            {
                _head--;
            }

            _buffer[_head].Complete(true);
            _queueLength--;
        }
    }
}

ConcurrencyLimiterExtensions源码:

/// <summary>
/// Pipeline registration helpers for <see cref="ConcurrencyLimiterMiddleware"/>.
/// </summary>
public static class ConcurrencyLimiterExtensions
{
    /// <summary>
    /// Inserts <see cref="ConcurrencyLimiterMiddleware"/> into the request pipeline so that the number
    /// of concurrently executing requests is governed by the registered <c>IQueuePolicy</c>.
    /// </summary>
    /// <param name="app">The <see cref="IApplicationBuilder"/> to add the middleware to.</param>
    /// <returns>The <see cref="IApplicationBuilder"/> returned by <c>UseMiddleware</c>, for chaining.</returns>
    /// <exception cref="ArgumentNullException"><paramref name="app"/> is <c>null</c>.</exception>
    public static IApplicationBuilder UseConcurrencyLimiter(this IApplicationBuilder app) =>
        app is null
            ? throw new ArgumentNullException(nameof(app))
            : app.UseMiddleware<ConcurrencyLimiterMiddleware>();
}

ConcurrencyLimiterMiddleware源码:

/// <summary>
/// Limits the number of concurrent requests allowed in the application by asking an
/// <see cref="IQueuePolicy"/> for admission before running the rest of the pipeline.
/// </summary>
public class ConcurrencyLimiterMiddleware
{
    private readonly IQueuePolicy _queuePolicy;
    private readonly RequestDelegate _next;
    private readonly RequestDelegate _onRejected;
    private readonly ILogger _logger;

    /// <summary>
    /// Creates a new <see cref="ConcurrencyLimiterMiddleware"/>.
    /// </summary>
    /// <param name="next">The <see cref="RequestDelegate"/> representing the next middleware in the pipeline.</param>
    /// <param name="loggerFactory">The <see cref="ILoggerFactory"/> used to create this middleware's logger.</param>
    /// <param name="queue">The queueing strategy that decides whether a request may run.</param>
    /// <param name="options">Middleware options; <c>OnRejected</c> must be set.</param>
    /// <exception cref="ArgumentException"><c>options.Value.OnRejected</c> is <c>null</c>.</exception>
    public ConcurrencyLimiterMiddleware(RequestDelegate next, ILoggerFactory loggerFactory, IQueuePolicy queue, IOptions<ConcurrencyLimiterOptions> options)
    {
        // OnRejected is invoked for every rejected request, so it is mandatory.
        if (options.Value.OnRejected is null)
        {
            throw new ArgumentException("The value of 'options.OnRejected' must not be null.", nameof(options));
        }

        _next = next;
        _logger = loggerFactory.CreateLogger<ConcurrencyLimiterMiddleware>();
        _onRejected = options.Value.OnRejected;
        _queuePolicy = queue;
    }

    /// <summary>
    /// Asks the queue policy for admission. An admitted request runs the rest of the pipeline and
    /// then releases its slot via <c>OnExit</c>. A rejected request gets a 503 status code and the
    /// <c>OnRejected</c> callback.
    /// </summary>
    /// <param name="context">The <see cref="HttpContext"/> for the current request.</param>
    /// <returns>A <see cref="Task"/> that completes when the request leaves.</returns>
    public async Task Invoke(HttpContext context)
    {
        var admission = _queuePolicy.TryEnterAsync();

        // The ValueTask is consumed exactly once (either .Result or await), because the
        // policy may reset and reuse its backing source afterwards.
        bool admitted;
        if (!admission.IsCompleted)
        {
            // The request actually had to wait; time how long it spent queued.
            using (ConcurrencyLimiterEventSource.Log.QueueTimer())
            {
                admitted = await admission;
            }
        }
        else
        {
            ConcurrencyLimiterEventSource.Log.QueueSkipped();
            admitted = admission.Result;
        }

        if (!admitted)
        {
            ConcurrencyLimiterEventSource.Log.RequestRejected();
            ConcurrencyLimiterLog.RequestRejectedQueueFull(_logger);
            context.Response.StatusCode = StatusCodes.Status503ServiceUnavailable;
            await _onRejected(context);
            return;
        }

        try
        {
            await _next(context);
        }
        finally
        {
            // Always hand the slot back, even if the pipeline threw.
            _queuePolicy.OnExit();
        }
    }

    /// <summary>
    /// Cached <see cref="LoggerMessage"/> delegates for this middleware. Only
    /// <see cref="RequestRejectedQueueFull"/> is called from this class.
    /// </summary>
    private static class ConcurrencyLimiterLog
    {
        private static readonly Action<ILogger, int, Exception> _requestEnqueued = LoggerMessage.Define<int>(
            LogLevel.Debug,
            new EventId(1, "RequestEnqueued"),
            "MaxConcurrentRequests limit reached, request has been queued. Current active requests: {ActiveRequests}.");

        private static readonly Action<ILogger, int, Exception> _requestDequeued = LoggerMessage.Define<int>(
            LogLevel.Debug,
            new EventId(2, "RequestDequeued"),
            "Request dequeued. Current active requests: {ActiveRequests}.");

        private static readonly Action<ILogger, int, Exception> _requestRunImmediately = LoggerMessage.Define<int>(
            LogLevel.Debug,
            new EventId(3, "RequestRunImmediately"),
            "Below MaxConcurrentRequests limit, running request immediately. Current active requests: {ActiveRequests}");

        private static readonly Action<ILogger, Exception> _requestRejectedQueueFull = LoggerMessage.Define(
            LogLevel.Debug,
            new EventId(4, "RequestRejectedQueueFull"),
            "Currently at the 'RequestQueueLimit', rejecting this request with a '503 server not availible' error");

        internal static void RequestEnqueued(ILogger logger, int activeRequests) =>
            _requestEnqueued(logger, activeRequests, null);

        internal static void RequestDequeued(ILogger logger, int activeRequests) =>
            _requestDequeued(logger, activeRequests, null);

        internal static void RequestRunImmediately(ILogger logger, int activeRequests) =>
            _requestRunImmediately(logger, activeRequests, null);

        internal static void RequestRejectedQueueFull(ILogger logger) =>
            _requestRejectedQueueFull(logger, null);
    }
}


秋风 2020-08-10