jquave 发表于 2022-7-22 18:11

矩阵乘法&优化方法

我们定义一个矩阵 为一个大小为m行n列的矩阵,那么两个矩阵的乘法可以定义为:

https://www.zhihu.com/equation?tex=A_%7Bm+%5Ctimes+k%7D+B_%7Bk+%5Ctimes+n%7D+%3D+C_%7Bm+%5Ctimes+n%7D

https://www.zhihu.com/equation?tex=C_%7Bi%2Cj%7D+%3D+%5Csum_%7Bt+%3D+1%7D%5E%7Bk%7D+A_%7Bi%2Ct%7D+B_%7Bt%2Cj%7D
一个简单的矩阵乘法实现

for (int i = 0; i < M; i++)
for (int j = 0; j < N; j++)
    for (int k = 0; k < K; k++)
      C += A * B
性能:
M, K, N时间(ms)GFLOPs1024, 1024, 102419981.074循环重排序(reordering)

调整上述实现中j循环和k循环的顺序,那么我们可以得到
for (int i = 0; i < M; i++)
for (int k = 0; k < K; k++)
    for (int j = 0; j < N; j++)
      C += A * B很明显,循环重排之后,对于A,B,C的空间访问局部性都很得到了保证。


M, K, N时间(ms)GFLOPs1024, 1024, 102420010.69分块(tiling)

分块可以进一步的提升B C矩阵的空间局部性, 我们把C分成多个title,然后针对每一个title,A和B种对应的行tile和列tile也会切成相应大小的tile进行多个小矩阵乘法,最后加和到C的tile中,当我们把tile的大小限定到合适的范围内时,就可以把整个tile填充到cache内,分块的好处就体现在一个block内的计算小到可以被cache容纳。
for (int i_outer = 0; i_outer < iOuterBound; i_outer++) {
for (int j_outer = 0; j_outer < jOuterBound; j_outer++) {
    for (int k_outer = 0; k_outer < kOuterBound; k_outer++) {
      for (int i_inner = 0; i_inner < iTile; i_inner++) {
      for (int k_inner = 0; k_inner < kTile; k_inner++) {
          for (int j_inner = 0; j_inner < jTile; j_inner++) {
            C[(i_outer * iTile + i_inner) * N +
            (j_outer * jTile + j_inner)] +=
                A[(i_outer * iTile + i_inner) * K +
                  (k_outer * kTile + k_inner)] *
                B[(k_outer * kTile + k_inner) * N +
                  (j_outer * jTile + j_inner)];
          }
      }
      }
    }
}
}M, K, N时间(ms)GFLOPs1024, 1024, 102420310.557向量化(SIMD)

把最内层的连续的多个乘法计算用一个向量乘法完成,需要用到avx256操作。
M, K, N时间(ms)GFLOPs1024, 1024, 10249422.716array packing

把分块优化技术中对B的一个块的访问转化为空间上连续的一段数组访问,而不是跳跃式的逐段访问,这中优化技术需要我们用一块中间数组来完成。


M, K, N时间(ms)GFLOPs1024, 1024, 10248126.253写缓存优化(write caching)

类似的,对C的每个块的计算结果,我们在写回的时候,也是跳跃式的逐段写回,理想的情况应该是一个块的计算结果在内存中是连续的放置的,所以我们需要再开一块write cache内存空间,然后每个block的计算结果直接在write cache上读写,最后计算完一个block之后,整块写回C数组中对应的不同数组段上。
M, K, N时间(ms)GFLOPs1024, 1024, 102410520.384并行

直接对乘法循环最外层进行unroll,有一定效果,单次乘法的时间下降到了16ms,但离线性扩展还很遥远。

M, K, N时间(ms)GFLOPs1024, 1024, 102416130.178

本文例程:

TVM codegen

参考TVM官方的GEMM优化示例
我在实验机器上面测试了一把,最后使用所有6个优化技术,生成如下的IR:
@main = primfn(A_1: handle, B_1: handle, C_1: handle) -> ()
attr = {"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True}
buffers = {C: Buffer(C_2: Pointer(float32), float32, , []),
             B: Buffer(B_2: Pointer(float32), float32, , []),
             A: Buffer(A_2: Pointer(float32), float32, , [])}
buffer_map = {A_1: A, B_1: B, C_1: C} {
allocate(packedB: Pointer(global float32x32), float32x32, ), storage_scope = global {
    for (bigN: int32, 0, 32) "parallel" {
      for (k: int32, 0, 1024) {
      packedB = (float32x32*)B_2
      }
    }
    for (m.outer: int32, 0, 32) "parallel" {
      allocate(C.global: Pointer(global float32), float32, ), storage_scope = global;
      for (n.outer: int32, 0, 32) {
      for (m.c.init: int32, 0, 32) {
          C.global = broadcast(0f32, 32)
      }
      for (k.outer: int32, 0, 256) {
          for (m.c: int32, 0, 32) {
            C.global = ((float32x32*)C.global + (broadcast((float32*)A_2[(((m.outer*32768) + (m.c*1024)) + (k.outer*4))], 32)*(float32x32*)packedB))
            C.global = ((float32x32*)C.global + (broadcast((float32*)A_2[((((m.outer*32768) + (m.c*1024)) + (k.outer*4)) + 1)], 32)*(float32x32*)packedB))
            C.global = ((float32x32*)C.global + (broadcast((float32*)A_2[((((m.outer*32768) + (m.c*1024)) + (k.outer*4)) + 2)], 32)*(float32x32*)packedB))
            C.global = ((float32x32*)C.global + (broadcast((float32*)A_2[((((m.outer*32768) + (m.c*1024)) + (k.outer*4)) + 3)], 32)*(float32x32*)packedB))
          }
      }
      for (m.inner: int32, 0, 32) {
          for (n.inner: int32, 0, 32) {
            C_2[((((m.outer*32768) + (m.inner*1024)) + (n.outer*32)) + n.inner)] = (float32*)C.global[((m.inner*32) + n.inner)]
          }
      }
      }
    }
}
}对于M=1024, K=1024, N=1024的矩阵乘法,耗时14.036 ms,也即142.857GFlops.
综上,手写的实现达到了TVM优化版本的130.178/142.857 ≈ 91%

JamesB 发表于 2022-7-22 18:20

reordering 部分代码 应该时 C += A * B 吧

mastertravels77 发表于 2022-7-22 18:25

是的 已经更正 谢谢
页: [1]
查看完整版本: 矩阵乘法&优化方法