Asp.Net Core UseMiddleware内部实现

起因

学习了在Asp.Net Core中如何通过Use使用中间件之后,便好奇UseMiddleware是如何将中间件注册到请求管道中的.至于如何编写中间件,可以参考《给Asp.Net Core写一个结束进程的中间件》.通过UseMiddleware注册的中间件,在编写时需要遵循一定的格式和规范.
规范如下:
  1. 构造函数的参数中必须包含RequestDelegate next(其余构造函数参数由UseMiddleware的args或容器提供)
  2. 必须有名为Invoke或者InvokeAsync的公开实例方法,第一个参数为HttpContext类型,返回类型为Task;其余参数会在每次请求时从容器中解析

至于为什么要这样,在看了UseMiddleware源码之后,才真的明白:UseMiddleware最终还是通过Use将中间件注册到请求管道中的.
下面来看Asp.Net Core的UseMiddleware内部是如何将中间件添加到请求管道中的.

实现

/// <summary>
/// Extension methods for adding typed middleware.
/// </summary>
public static class UseMiddlewareExtensions
{
    internal const string InvokeMethodName = "Invoke";
    internal const string InvokeAsyncMethodName = "InvokeAsync";

    // Reflection handle for the private GetService helper below. Compile emits a call to it
    // for every Invoke/InvokeAsync parameter after the first one.
    private static readonly MethodInfo GetServiceInfo = typeof(UseMiddlewareExtensions).GetMethod(nameof(GetService), BindingFlags.NonPublic | BindingFlags.Static)!;

    /// <summary>
    /// Adds a middleware type to the application's request pipeline.
    /// </summary>
    /// <typeparam name="TMiddleware">The middleware type.</typeparam>
    /// <param name="app">The <see cref="IApplicationBuilder"/> instance.</param>
    /// <param name="args">The arguments to pass to the middleware type instance's constructor.</param>
    /// <returns>The <see cref="IApplicationBuilder"/> instance.</returns>
    public static IApplicationBuilder UseMiddleware<TMiddleware>(this IApplicationBuilder app, params object[] args)
    {
        return app.UseMiddleware(typeof(TMiddleware), args);
    }

    /// <summary>
    /// Adds a middleware type to the application's request pipeline.
    /// Both code paths below end up registering the middleware through <c>app.Use</c>.
    /// </summary>
    /// <param name="app">The <see cref="IApplicationBuilder"/> instance.</param>
    /// <param name="middleware">The middleware type.</param>
    /// <param name="args">The arguments to pass to the middleware type instance's constructor.</param>
    /// <returns>The <see cref="IApplicationBuilder"/> instance.</returns>
    /// <exception cref="NotSupportedException">
    /// <paramref name="middleware"/> implements <see cref="IMiddleware"/> and explicit
    /// <paramref name="args"/> were supplied.
    /// </exception>
    /// <exception cref="InvalidOperationException">
    /// Thrown when the registered factory runs and <paramref name="middleware"/> does not expose
    /// exactly one valid public <c>Invoke</c>/<c>InvokeAsync</c> method.
    /// </exception>
    public static IApplicationBuilder UseMiddleware(this IApplicationBuilder app, Type middleware, params object[] args)
    {
        // 1. Middleware that implements IMiddleware is handled by UseMiddlewareInterface.
        if (typeof(IMiddleware).GetTypeInfo().IsAssignableFrom(middleware.GetTypeInfo()))
        {
            // IMiddleware doesn't support passing args directly since it's
            // activated from the container
            if (args.Length > 0)
            {
                throw new NotSupportedException(Resources.FormatException_UseMiddlewareExplicitArgumentsNotSupported(typeof(IMiddleware)));
            }

            return UseMiddlewareInterface(app, middleware);
        }

        // 2. Middleware that does not implement IMiddleware (convention-based).
        // Captured now; used only as the per-request fallback provider further down.
        var applicationServices = app.ApplicationServices;
        return app.Use(next =>
        {
            // 2.1 Locate and validate the single public instance method named Invoke or InvokeAsync.
            var methodInfo = ResolveInvokeMethod(middleware, out var parameters);

            // The constructor receives `next` first, followed by the caller-supplied args.
            var ctorArgs = new object[args.Length + 1];
            ctorArgs[0] = next;
            Array.Copy(args, 0, ctorArgs, 1, args.Length);

            // 2.2 Instantiate the middleware from its type and constructor arguments.
            // app.ApplicationServices is handed to ActivatorUtilities as the service provider;
            // the instance is created once per run of this factory lambda, not per request.
            var instance = ActivatorUtilities.CreateInstance(app.ApplicationServices, middleware, ctorArgs);

            if (parameters.Length == 1)
            {
                // The method takes only HttpContext, so it can be bound directly as a RequestDelegate.
                return (RequestDelegate)methodInfo.CreateDelegate(typeof(RequestDelegate), instance);
            }

            // 2.3 The method takes additional parameters: build an expression tree that fetches
            // each of them from an IServiceProvider and compile it into a delegate.
            var factory = Compile<object>(methodInfo, parameters);

            return context =>
            {
                var serviceProvider = context.RequestServices ?? applicationServices;
                if (serviceProvider == null)
                {
                    throw new InvalidOperationException(Resources.FormatException_UseMiddlewareIServiceProviderNotAvailable(nameof(IServiceProvider)));
                }

                return factory(instance, context, serviceProvider);
            };
        });
    }

    /// <summary>
    /// Finds the single public instance method named <c>Invoke</c> or <c>InvokeAsync</c> on
    /// <paramref name="middleware"/> and checks that it returns a <see cref="Task"/> and takes
    /// an <see cref="HttpContext"/> as its first parameter.
    /// </summary>
    /// <param name="middleware">The convention-based middleware type to inspect.</param>
    /// <param name="parameters">On success, the parameters of the returned method.</param>
    /// <returns>The validated invoke method.</returns>
    /// <exception cref="InvalidOperationException">
    /// No matching method exists, more than one exists, the return type is not assignable to
    /// <see cref="Task"/>, or the first parameter is not <see cref="HttpContext"/>.
    /// </exception>
    private static MethodInfo ResolveInvokeMethod(Type middleware, out ParameterInfo[] parameters)
    {
        var invokeMethods = middleware
            .GetMethods(BindingFlags.Instance | BindingFlags.Public)
            .Where(m =>
                string.Equals(m.Name, InvokeMethodName, StringComparison.Ordinal)
                || string.Equals(m.Name, InvokeAsyncMethodName, StringComparison.Ordinal))
            .ToArray();

        // Without this guard, indexing [0] below fails with an opaque IndexOutOfRangeException.
        if (invokeMethods.Length == 0)
        {
            throw new InvalidOperationException(
                $"No public '{InvokeMethodName}' or '{InvokeAsyncMethodName}' method found for middleware of type '{middleware}'.");
        }

        // Picking "the first" of several candidates would be arbitrary, so reject the ambiguity.
        if (invokeMethods.Length > 1)
        {
            throw new InvalidOperationException(
                $"Multiple public '{InvokeMethodName}' or '{InvokeAsyncMethodName}' methods are available on middleware of type '{middleware}'.");
        }

        var methodInfo = invokeMethods[0];
        if (!typeof(Task).IsAssignableFrom(methodInfo.ReturnType))
        {
            throw new InvalidOperationException(
                $"'{InvokeMethodName}' or '{InvokeAsyncMethodName}' does not return an object of type '{nameof(Task)}'.");
        }

        parameters = methodInfo.GetParameters();
        if (parameters.Length == 0 || parameters[0].ParameterType != typeof(HttpContext))
        {
            throw new InvalidOperationException(
                $"The '{InvokeMethodName}' or '{InvokeAsyncMethodName}' method's first argument must be of type '{nameof(HttpContext)}'.");
        }

        return methodInfo;
    }

    /// <summary>
    /// Registers a middleware type that implements <see cref="IMiddleware"/>. On every invocation
    /// of the resulting request delegate the middleware is obtained from the
    /// <see cref="IMiddlewareFactory"/>, invoked, and then released.
    /// </summary>
    /// <param name="app">The <see cref="IApplicationBuilder"/> instance.</param>
    /// <param name="middlewareType">The middleware type to create through the factory.</param>
    /// <returns>The <see cref="IApplicationBuilder"/> instance.</returns>
    private static IApplicationBuilder UseMiddlewareInterface(IApplicationBuilder app, Type middlewareType)
    {
        return app.Use(next =>
        {
            return async context =>
            {
                // 1. Get the IMiddlewareFactory from the request's service provider.
                var middlewareFactory = (IMiddlewareFactory?)context.RequestServices.GetService(typeof(IMiddlewareFactory));
                if (middlewareFactory == null)
                {
                    // No middleware factory
                    throw new InvalidOperationException(Resources.FormatException_UseMiddlewareNoMiddlewareFactory(typeof(IMiddlewareFactory)));
                }

                // 2. Ask the factory for an instance of the middleware.
                var middleware = middlewareFactory.Create(middlewareType);
                if (middleware == null)
                {
                    // The factory returned null, it's a broken implementation
                    throw new InvalidOperationException(Resources.FormatException_UseMiddlewareUnableToCreateMiddleware(middlewareFactory.GetType(), middlewareType));
                }

                try
                {
                    // 3. Run the middleware.
                    await middleware.InvokeAsync(context, next);
                }
                finally
                {
                    // Always hand the instance back to the factory, even when InvokeAsync throws.
                    middlewareFactory.Release(middleware);
                }
            };
        });
    }

    /// <summary>
    /// Builds and compiles a delegate that calls <paramref name="methodInfo"/> on a middleware
    /// instance, passing the <see cref="HttpContext"/> as the first argument and resolving every
    /// remaining argument from the supplied <see cref="IServiceProvider"/>.
    /// </summary>
    /// <typeparam name="T">The static type of the instance argument of the compiled delegate.</typeparam>
    /// <param name="methodInfo">The middleware's Invoke/InvokeAsync method.</param>
    /// <param name="parameters">The parameters of <paramref name="methodInfo"/>; index 0 receives the HttpContext.</param>
    /// <returns>A delegate taking (instance, httpContext, serviceProvider) and returning the method's task.</returns>
    /// <exception cref="NotSupportedException">A parameter after the first is passed by reference (ref/out).</exception>
    private static Func<T, HttpContext, IServiceProvider, Task> Compile<T>(MethodInfo methodInfo, ParameterInfo[] parameters)
    {
        // If we call something like
        //
        // public class Middleware
        // {
        //    public Task Invoke(HttpContext context, ILoggerFactory loggerFactory)
        //    {
        //
        //    }
        // }
        //

        // We'll end up with something like this:
        //   Generic version:
        //
        //   Task Invoke(Middleware instance, HttpContext httpContext, IServiceProvider provider)
        //   {
        //      return instance.Invoke(httpContext, (ILoggerFactory)UseMiddlewareExtensions.GetService(provider, typeof(ILoggerFactory), typeof(Middleware)));
        //   }

        //   Non generic version:
        //
        //   Task Invoke(object instance, HttpContext httpContext, IServiceProvider provider)
        //   {
        //      return ((Middleware)instance).Invoke(httpContext, (ILoggerFactory)UseMiddlewareExtensions.GetService(provider, typeof(ILoggerFactory), typeof(Middleware)));
        //   }

        var middleware = typeof(T);

        var httpContextArg = Expression.Parameter(typeof(HttpContext), "httpContext");
        var providerArg = Expression.Parameter(typeof(IServiceProvider), "serviceProvider");
        var instanceArg = Expression.Parameter(middleware, "middleware");

        var methodArguments = new Expression[parameters.Length];
        methodArguments[0] = httpContextArg;
        for (int i = 1; i < parameters.Length; i++)
        {
            var parameterType = parameters[i].ParameterType;
            if (parameterType.IsByRef)
            {
                // Report the method that was actually found (Invoke or InvokeAsync).
                throw new NotSupportedException(Resources.FormatException_InvokeDoesNotSupportRefOrOutParams(methodInfo.Name));
            }

            // Arguments for GetService(IServiceProvider sp, Type type, Type middleware).
            var parameterTypeExpression = new Expression[]
            {
                providerArg,
                Expression.Constant(parameterType, typeof(Type)),
                Expression.Constant(methodInfo.DeclaringType, typeof(Type))
            };

            var getServiceCall = Expression.Call(GetServiceInfo, parameterTypeExpression);
            // GetService returns object; cast it to the declared parameter type.
            methodArguments[i] = Expression.Convert(getServiceCall, parameterType);
        }

        // When T is not the declaring type (e.g. T is object), cast the instance before the call.
        Expression middlewareInstanceArg = instanceArg;
        if (methodInfo.DeclaringType != null && methodInfo.DeclaringType != typeof(T))
        {
            middlewareInstanceArg = Expression.Convert(middlewareInstanceArg, methodInfo.DeclaringType);
        }

        var body = Expression.Call(middlewareInstanceArg, methodInfo, methodArguments);

        var lambda = Expression.Lambda<Func<T, HttpContext, IServiceProvider, Task>>(body, instanceArg, httpContextArg, providerArg);

        return lambda.Compile();
    }

    /// <summary>
    /// Resolves <paramref name="type"/> from <paramref name="sp"/> and throws instead of
    /// returning <c>null</c> when the service is not available.
    /// </summary>
    /// <param name="sp">The service provider to query.</param>
    /// <param name="type">The service type required by the middleware method.</param>
    /// <param name="middleware">The middleware type, used only for the error message.</param>
    /// <returns>The resolved, non-null service instance.</returns>
    /// <exception cref="InvalidOperationException">The provider returned <c>null</c> for <paramref name="type"/>.</exception>
    private static object GetService(IServiceProvider sp, Type type, Type middleware)
    {
        var service = sp.GetService(type);
        if (service == null)
        {
            throw new InvalidOperationException(Resources.FormatException_InvokeMiddlewareNoService(type, middleware));
        }

        return service;
    }
}
秋风 2020-07-26