在C#中学习如何使用SIMD

前言

在前面说 在C#中如何提高Linq的性能 ,其中就提到在.Net 8之前的版本(.Net Core 3.0之后的版本,)可以通过使用SimdLinq库来提高性能.SimdLinq源码还是比较轻量级的,因为轻量级只提供了以下方法的支持:
  1. Sum(支持的类型: int, uint, long, ulong, float, double)
  2. LongSum(支持的类型: int, uint)
  3. Average(支持的类型: int, uint, long, ulong, float, double)
  4. Min(支持的类型: byte, sbyte, short, ushort, int, uint, long, ulong, float, double)
  5. Max(支持的类型: byte, sbyte, short, ushort, int, uint, long, ulong, float, double)
  6. MinMax(支持的类型: byte, sbyte, short, ushort, int, uint, long, ulong, float, double)
  7. Contains(支持的类型: byte, sbyte, short, ushort, int, uint, long, ulong, float, double)
  8. SequenceEqual(支持的类型:byte, sbyte, short, ushort, int, uint, long, ulong, float, double)

支持的集合有: T[], List<T>, Span<T>, Memory<T>, ReadOnlyMemory<T>, Span<T>, ReadOnlySpan<T>.

理解SIMD(单指令多数据)

这里从汇编代码来说SIMD的.
extern printf
section .data
        dummy   db      13
align   16
        ;pdivector1相当于数组
        pdivector1      dd 1
                        dd 2
                        dd 3
                        dd 4
	;pdivector2相当于数组
        pdivector2      dd 5
                        dd 6
                        dd 7
                        dd 8
        fmt             db "Sum Vector:%d %d %d %d",10,0
section .bss
alignb  16
        pdivector_res   resd    4
section .text
        global main
main:
	;序言
        push    rbp
        mov     rbp,rsp

	;*****将pdivector1,加载到xmm0寄存器中*****
        movdqa  xmm0,[pdivector1]
	;*****将pdivector2和xmm0寄存器内的值,进行加法运算***
        paddd  xmm0,[pdivector2]

        ;将结果保存在内存中
        movdqa  [pdivector_res],xmm0
        ;打印内存中的向量
        mov     rsi,pdivector_res
        mov     rdi,fmt
        call    printpdi

	;尾言
        mov     rsp,rbp
        pop     rbp
        ret

;打印-----------------------------------
printpdi:
        push    rbp
        mov     rbp,rsp

        movdqa  xmm0,[rsi]
        ;从xmmo0中提取打包的值
        pextrd  esi,xmm0,0
        pextrd  edx,xmm0,1
        pextrd  ecx,xmm0,2
        pextrd  r8d,xmm0,3

        ;没有浮点数
        mov     rax,0
        call    printf
        mov     rsp,rbp
        pop     rbp
        ret

使用汇编代码,使用Simd将数据加载xmm0寄存器上,并进行加法运算

重点就这两行代码:

;*****将pdivector1,加载到xmm0寄存器中*****
movdqa  xmm0,[pdivector1]
;*****将pdivector2和xmm0寄存器内的值,进行加法运算***
paddd  xmm0,[pdivector2]

通过画图来看SIMD单指令多数

接着看SimdLinq源码

先看一下SimdLinq源码目录:
SimdLinq源码目录
在学习源码的时候,只需要关心文件名带有Core的源码,包含Core源码就是具体实现.
接着看一下Sum的源码:
static T SumCore<T>(ReadOnlySpan<T> source)
where T : struct, INumber<T>
{
	T sum = T.Zero;

	if (!Vector128.IsHardwareAccelerated || source.Length < Vector128<T>.Count)
	{
		// Not SIMD supported or small source.
		//1. 当硬件不支持,会退变为for循环
		//2. 集合内的数量小于Vector128的支持的数量,如int->4 long->2,会退变for循环
		unchecked // SIMD operation is unchecked so keep same behaviour
		{
			for (int i = 0; i < source.Length; i++)
			{
				sum += source[i];
			}
		}
	}
	else if (!Vector256.IsHardwareAccelerated || source.Length < Vector256<T>.Count)
	{
		// Only 128bit SIMD supported or small source.
		//满足128bit,不足256bit的,数量少的时候

		//获取开始元素
		ref var begin = ref MemoryMarshal.GetReference(source);

		//获取结尾的元素
		ref var last = ref Unsafe.Add(ref begin, source.Length);
		ref var current = ref begin;
		//获取一个初始值为0的Vector128
		var vectorSum = Vector128<T>.Zero;

		//集合的长度减去Vector128<T>的数量,让开始元素进行偏移
		ref var to = ref Unsafe.Add(ref begin, source.Length - Vector128<T>.Count);
		//开始元素的地址是否和小于一次Vector的地址
		while (Unsafe.IsAddressLessThan(ref current, ref to))
		{
			//如果是int类型, 就一次加载前4个元素
			vectorSum += Vector128.LoadUnsafe(ref current);
			current = ref Unsafe.Add(ref current, Vector128<T>.Count); //如果是int类型,偏移4个元素
		}

		//判断current的地址是否小于结尾元素的地址
		//处理不够一次Vector的时候,退变循环处理
		while (Unsafe.IsAddressLessThan(ref current, ref last))
		{
			unchecked // SIMD operation is unchecked so keep same behaviour
			{
				sum += current;
			}
			current = ref Unsafe.Add(ref current, 1); //每次偏移1个元素
		}

		sum += Vector128.Sum(vectorSum); //进行求和计算
	}
	else
	{
		// 256bit SIMD supported
                //Vector256就不注释了
		ref var begin = ref MemoryMarshal.GetReference(source);
		ref var last = ref Unsafe.Add(ref begin, source.Length);
		ref var current = ref begin;
		var vectorSum = Vector256<T>.Zero;

		ref var to = ref Unsafe.Add(ref begin, source.Length - Vector256<T>.Count);
		while (Unsafe.IsAddressLessThan(ref current, ref to))
		{
			vectorSum += Vector256.LoadUnsafe(ref current);
			current = ref Unsafe.Add(ref current, Vector256<T>.Count);
		}
		while (Unsafe.IsAddressLessThan(ref current, ref last))
		{
			unchecked // SIMD operation is unchecked so keep same behaviour
			{
				sum += current;
			}
			current = ref Unsafe.Add(ref current, 1);
		}

		sum += Vector256.Sum(vectorSum);
	}

	return sum;
}


秋风 2024-03-05