|
我们定义一个矩阵 为一个大小为m行n列的矩阵,那么两个矩阵的乘法可以定义为:
一个简单的矩阵乘法实现
for (int i = 0; i < M; i++)
for (int j = 0; j < N; j++)
for (int k = 0; k < K; k++)
C[j] += A[k] * B[k][j]
性能:
M, K, N | 时间(ms) | GFLOPs | 1024, 1024, 1024 | 1998 | 1.074 | 循环重排序(reordering)
调整上述实现中j循环和k循环的顺序,那么我们可以得到
for (int i = 0; i < M; i++)
for (int k = 0; k < K; k++)
for (int j = 0; j < N; j++)
C[j] += A[k] * B[k][j]很明显,循环重排之后,对于A,B,C的空间访问局部性都很得到了保证。
M, K, N | 时间(ms) | GFLOPs | 1024, 1024, 1024 | 200 | 10.69 | 分块(tiling)
分块可以进一步的提升B C矩阵的空间局部性, 我们把C分成多个title,然后针对每一个title,A和B种对应的行tile和列tile也会切成相应大小的tile进行多个小矩阵乘法,最后加和到C的tile中,当我们把tile的大小限定到合适的范围内时,就可以把整个tile填充到cache内,分块的好处就体现在一个block内的计算小到可以被cache容纳。
for (int i_outer = 0; i_outer < iOuterBound; i_outer++) {
for (int j_outer = 0; j_outer < jOuterBound; j_outer++) {
for (int k_outer = 0; k_outer < kOuterBound; k_outer++) {
for (int i_inner = 0; i_inner < iTile; i_inner++) {
for (int k_inner = 0; k_inner < kTile; k_inner++) {
for (int j_inner = 0; j_inner < jTile; j_inner++) {
C[(i_outer * iTile + i_inner) * N +
(j_outer * jTile + j_inner)] +=
A[(i_outer * iTile + i_inner) * K +
(k_outer * kTile + k_inner)] *
B[(k_outer * kTile + k_inner) * N +
(j_outer * jTile + j_inner)];
}
}
}
}
}
}M, K, N | 时间(ms) | GFLOPs | 1024, 1024, 1024 | 203 | 10.557 | 向量化(SIMD)
把最内层的连续的多个乘法计算用一个向量乘法完成,需要用到avx256操作。
M, K, N | 时间(ms) | GFLOPs | 1024, 1024, 1024 | 94 | 22.716 | array packing
把分块优化技术中对B的一个块的访问转化为空间上连续的一段数组访问,而不是跳跃式的逐段访问,这中优化技术需要我们用一块中间数组来完成。
M, K, N | 时间(ms) | GFLOPs | 1024, 1024, 1024 | 81 | 26.253 | 写缓存优化(write caching)
类似的,对C的每个块的计算结果,我们在写回的时候,也是跳跃式的逐段写回,理想的情况应该是一个块的计算结果在内存中是连续的放置的,所以我们需要再开一块write cache内存空间,然后每个block的计算结果直接在write cache上读写,最后计算完一个block之后,整块写回C数组中对应的不同数组段上。
M, K, N | 时间(ms) | GFLOPs | 1024, 1024, 1024 | 105 | 20.384 | 并行
直接对乘法循环最外层进行unroll,有一定效果,单次乘法的时间下降到了16ms,但离线性扩展还很遥远。
M, K, N | 时间(ms) | GFLOPs | 1024, 1024, 1024 | 16 | 130.178 |
本文例程:
TVM codegen
参考TVM官方的GEMM优化示例
我在实验机器上面测试了一把,最后使用所有6个优化技术,生成如下的IR:
@main = primfn(A_1: handle, B_1: handle, C_1: handle) -> ()
attr = {&#34;from_legacy_te_schedule&#34;: True, &#34;global_symbol&#34;: &#34;main&#34;, &#34;tir.noalias&#34;: True}
buffers = {C: Buffer(C_2: Pointer(float32), float32, [1024, 1024], []),
B: Buffer(B_2: Pointer(float32), float32, [1024, 1024], []),
A: Buffer(A_2: Pointer(float32), float32, [1024, 1024], [])}
buffer_map = {A_1: A, B_1: B, C_1: C} {
allocate(packedB: Pointer(global float32x32), float32x32, [32768]), storage_scope = global {
for (bigN: int32, 0, 32) &#34;parallel&#34; {
for (k: int32, 0, 1024) {
packedB[ramp(((bigN*32768) + (k*32)), 1, 32)] = (float32x32*)B_2[ramp(((k*1024) + (bigN*32)), 1, 32)]
}
}
for (m.outer: int32, 0, 32) &#34;parallel&#34; {
allocate(C.global: Pointer(global float32), float32, [1024]), storage_scope = global;
for (n.outer: int32, 0, 32) {
for (m.c.init: int32, 0, 32) {
C.global[ramp((m.c.init*32), 1, 32)] = broadcast(0f32, 32)
}
for (k.outer: int32, 0, 256) {
for (m.c: int32, 0, 32) {
C.global[ramp((m.c*32), 1, 32)] = ((float32x32*)C.global[ramp((m.c*32), 1, 32)] + (broadcast((float32*)A_2[(((m.outer*32768) + (m.c*1024)) + (k.outer*4))], 32)*(float32x32*)packedB[ramp(((n.outer*32768) + (k.outer*128)), 1, 32)]))
C.global[ramp((m.c*32), 1, 32)] = ((float32x32*)C.global[ramp((m.c*32), 1, 32)] + (broadcast((float32*)A_2[((((m.outer*32768) + (m.c*1024)) + (k.outer*4)) + 1)], 32)*(float32x32*)packedB[ramp((((n.outer*32768) + (k.outer*128)) + 32), 1, 32)]))
C.global[ramp((m.c*32), 1, 32)] = ((float32x32*)C.global[ramp((m.c*32), 1, 32)] + (broadcast((float32*)A_2[((((m.outer*32768) + (m.c*1024)) + (k.outer*4)) + 2)], 32)*(float32x32*)packedB[ramp((((n.outer*32768) + (k.outer*128)) + 64), 1, 32)]))
C.global[ramp((m.c*32), 1, 32)] = ((float32x32*)C.global[ramp((m.c*32), 1, 32)] + (broadcast((float32*)A_2[((((m.outer*32768) + (m.c*1024)) + (k.outer*4)) + 3)], 32)*(float32x32*)packedB[ramp((((n.outer*32768) + (k.outer*128)) + 96), 1, 32)]))
}
}
for (m.inner: int32, 0, 32) {
for (n.inner: int32, 0, 32) {
C_2[((((m.outer*32768) + (m.inner*1024)) + (n.outer*32)) + n.inner)] = (float32*)C.global[((m.inner*32) + n.inner)]
}
}
}
}
}
}对于M=1024, K=1024, N=1024的矩阵乘法,耗时14.036 ms,也即142.857GFlops.
综上,手写的实现达到了TVM优化版本的130.178/142.857 ≈ 91% |
本帖子中包含更多资源
您需要 登录 才可以下载或查看,没有账号?立即注册
×
|