在.Net 7中Array/Span的IndexOf性能优化

前言

前一段时间,.NET 源码中有一项关于 Array 下 IndexOf 的性能优化。虽说是 Array 的方法,其实和 Span 有关:Array 下的 IndexOf 和 Span 下的 Contains 在处理值类型(int 和 long)时,会调用 SpanHelpers.T.cs 文件中的 IndexOfValueType 方法。这次是使用指令集(向量化)进行的优化。
  1. 具体的issue: Vectorize SpanHelpers<T>.IndexOf (#60974)

测试代码

using System;
using BenchmarkDotNet.Attributes;
using BenchmarkDotNet.Order;

namespace CSharpBenchmarks.SpanTest
{
    /// <summary>
    /// Benchmarks the span extension methods <c>IndexOf</c> and <c>LastIndexOf</c> when
    /// searching for the byte sequence <see cref="searchBytes"/> inside <see cref="bytes"/>.
    /// Each benchmark performs the search <see cref="Times"/> times and returns the sum of
    /// the positions reported by the individual searches.
    /// </summary>
    [MemoryDiagnoser]
    [DisassemblyDiagnoser(printSource: true)]
    [Orderer(SummaryOrderPolicy.FastestToSlowest)]
    public class IndexOfTest
    {
        /// <summary>Buffer that every search scans.</summary>
        public byte[] bytes = new byte[] { 0, 0, 0, 0, 71, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 71, 0, 1, 0, 0, 0, 0, 1, 1, 0, 2, 0, 1, 1, 0, 1, 1, 0, 1, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 1, 0 };

        /// <summary>Byte sequence looked for inside <see cref="bytes"/>.</summary>
        public byte[] searchBytes = new byte[] { 0, 1, 0, 0 };

        /// <summary>How many searches a single benchmark invocation performs.</summary>
        [Params(1024, 2048)]
        public int Times { get; set; }

        /// <summary>
        /// Runs a forward sequence search <see cref="Times"/> times and returns the summed
        /// first-match positions.
        /// </summary>
        [Benchmark]
        public int SpanIndexOf()
        {
            int total = 0;
            int iteration = 0;
            while (iteration < Times)
            {
                total += bytes.AsSpan().IndexOf(searchBytes);
                iteration++;
            }

            return total;
        }

        /// <summary>
        /// Runs a backward sequence search <see cref="Times"/> times and returns the summed
        /// last-match positions.
        /// </summary>
        [Benchmark]
        public int SpanLastIndexOf()
        {
            int total = 0;
            int iteration = 0;
            while (iteration < Times)
            {
                total += bytes.AsSpan().LastIndexOf(searchBytes);
                iteration++;
            }

            return total;
        }
    }
}

测试在.Net 7中Array和Span的IndexOf方法性能改进

从测试结果可以得出:IndexOf 在 .NET 7 上相对 .NET 6 提升了 71%,LastIndexOf 在 .NET 7 上相对 .NET 6 提升了 36%(提升没有那么多)。因为是指令集优化,生成的汇编代码比较多,这里就不贴了。

我们可以学习一下 IndexOfValueType 的源码:

/// <summary>
/// Returns the index of the first of <paramref name="length"/> consecutive elements, starting at
/// <paramref name="searchSpace"/>, that equals <paramref name="value"/>, or -1 if none does.
/// When <see cref="Vector"/> hardware acceleration is available, <typeparamref name="T"/> is a
/// supported vector element type, and the range holds at least two full vectors, the search
/// compares a whole <see cref="Vector{T}"/> at a time. Otherwise it compares element by element,
/// unrolled by 8, then by 4, then one at a time.
/// </summary>
/// <param name="searchSpace">
/// Reference to the first element to examine. Elements are read through <c>Unsafe.Add</c> with no
/// bounds checks, so the caller presumably guarantees that <paramref name="length"/> elements are
/// readable from here — TODO confirm at call sites.
/// </param>
/// <param name="value">The value to look for; matches are decided with <c>Equals</c>.</param>
/// <param name="length">Number of elements to examine; asserted non-negative in debug builds.</param>
/// <returns>Zero-based index of the first match, or -1 when there is no match.</returns>
internal static unsafe int IndexOfValueType<T>(ref T searchSpace, T value, int length) where T : struct, IEquatable<T>
{
    Debug.Assert(length >= 0);

    nint index = 0; // Use nint for arithmetic to avoid unnecessary 64->32->64 truncations
    // Vector path: taken only when at least two full vectors fit in the range.
    if (Vector.IsHardwareAccelerated && Vector<T>.IsTypeSupported && (Vector<T>.Count * 2) <= length)
    {
        // Every lane of valueVector holds `value`; each chunk is compared against it lane-wise.
        Vector<T> valueVector = new Vector<T>(value);
        Vector<T> compareVector = default;
        Vector<T> matchVector = default;
        if ((uint)length % (uint)Vector<T>.Count != 0)
        {
            // Number of elements is not a multiple of Vector<T>.Count, so do one
            // check and shift only enough for the remaining set to be a multiple
            // of Vector<T>.Count.
            // The shift is smaller than Vector<T>.Count, so the next vector read overlaps this
            // one; elements in the overlap are simply compared twice.
            compareVector = Unsafe.As<T, Vector<T>>(ref Unsafe.Add(ref searchSpace, index));
            matchVector = Vector.Equals(valueVector, compareVector);
            if (matchVector != Vector<T>.Zero)
            {
                goto VectorMatch;
            }
            index += length % Vector<T>.Count;
            length -= length % Vector<T>.Count;
        }
        // From here, length is a multiple of Vector<T>.Count, so whole vectors cover the rest.
        while (length > 0)
        {
            compareVector = Unsafe.As<T, Vector<T>>(ref Unsafe.Add(ref searchSpace, index));
            matchVector = Vector.Equals(valueVector, compareVector);
            if (matchVector != Vector<T>.Zero)
            {
                goto VectorMatch;
            }
            index += Vector<T>.Count;
            length -= Vector<T>.Count;
        }
        // All vectors were checked with no match (length is now 0); skip the scalar loops.
        goto NotFound;
    VectorMatch:
        // matchVector was non-zero, so some lane of compareVector equals value.
        // Return the position of the first such lane. If no lane matched (not expected),
        // control falls through to the scalar loops below.
        for (int i = 0; i < Vector<T>.Count; i++)
            if (compareVector[i].Equals(value))
                return (int)(index + i);
    }

    // Scalar path: compare 8 elements per iteration while at least 8 remain.
    while (length >= 8)
    {
        if (value.Equals(Unsafe.Add(ref searchSpace, index)))
            goto Found;
        if (value.Equals(Unsafe.Add(ref searchSpace, index + 1)))
            goto Found1;
        if (value.Equals(Unsafe.Add(ref searchSpace, index + 2)))
            goto Found2;
        if (value.Equals(Unsafe.Add(ref searchSpace, index + 3)))
            goto Found3;
        if (value.Equals(Unsafe.Add(ref searchSpace, index + 4)))
            goto Found4;
        if (value.Equals(Unsafe.Add(ref searchSpace, index + 5)))
            goto Found5;
        if (value.Equals(Unsafe.Add(ref searchSpace, index + 6)))
            goto Found6;
        if (value.Equals(Unsafe.Add(ref searchSpace, index + 7)))
            goto Found7;

        length -= 8;
        index += 8;
    }

    // Then 4 at a time while at least 4 remain.
    while (length >= 4)
    {
        if (value.Equals(Unsafe.Add(ref searchSpace, index)))
            goto Found;
        if (value.Equals(Unsafe.Add(ref searchSpace, index + 1)))
            goto Found1;
        if (value.Equals(Unsafe.Add(ref searchSpace, index + 2)))
            goto Found2;
        if (value.Equals(Unsafe.Add(ref searchSpace, index + 3)))
            goto Found3;

        length -= 4;
        index += 4;
    }

    // Finally, the remaining 0-3 elements one at a time.
    while (length > 0)
    {
        if (value.Equals(Unsafe.Add(ref searchSpace, index)))
            goto Found;

        index += 1;
        length--;
    }
NotFound:
    return -1;

// Each FoundN label returns index + N, where N is the offset of the matching element within
// the current unrolled group.
Found: // Workaround for https://github.com/dotnet/runtime/issues/8795
    return (int)index;
Found1:
    return (int)(index + 1);
Found2:
    return (int)(index + 2);
Found3:
    return (int)(index + 3);
Found4:
    return (int)(index + 4);
Found5:
    return (int)(index + 5);
Found6:
    return (int)(index + 6);
Found7:
    return (int)(index + 7);
}


秋风 2022-06-05