KaaPexei 发表于 2022-12-27 21:10

1. Cuda矩阵乘法GeMM性能优化

0. 工作背景

由于近期有对cuda进行优化的工作, 然鹅自己已经有一段时间没有进行深度cuda算法优化, 所以才有了以下对矩阵乘法gemm进行优化的康复训练. 这里我将gemm从约300ms优化到了42ms, 并且达到了cutlass库的70%以上性能, 作为初步结果已经比较令人满意了.
1. 问题描述

首先我们需要对问题进行良好的定义. GeMM, General Matrix Multiply的简写, 指的是
\mathbf{C}=\alpha \text{op}_A(\mathbf{A})\text{op}_B(\mathbf{B})+\beta\mathbf{C}\\
其中的 \text{op}_A 和 \text{op}_B 指的是矩阵转置或者共轭的操作. 这里由于我们只关心并行算法的实现, 所以接下来所有的计算都是基于 \alpha=1, \beta=0 , 不对输入矩阵进行转置, 且输入输出的矩阵都是row-major的情况的. 所以对于 m\times k 的矩阵 \mathbf{A} 和 k\times n 的矩阵 \mathbf{B} , 输出矩阵 \mathbf{C} 是 m\times n 的, 并且 \mathbf{C} 的每一个元素 c_{ij}=\sum\limits_{k=0}^Ka_{ik}b_{kj}.\\ 很容易就可以发现, 对于规模为 m, n, k 的矩阵乘法, 所需要的计算量为 mnk 次 fma (即 d=a\times b+c )操作. 于是接下来我们的性能单位为TFlops, 即每秒钟进行的fma操作次数, 以 10^{12} 次每秒为单位. 而当矩阵较小的时候, 数据量无法喂饱庞大的GPU计算资源, 所以接下来使用 m=n=k=8192 这个足够大但是计算耗时相对可以接受的大小来进行测试, 测试使用的GPU是RTX 4090, 单精度浮点算力约为80TFlops, 显存带宽约为1TB/s 使用Nsight Compute进行定频约2.12GHz进行耗时测试.
2. 基准实现 (baseline)

基准实现是我们进行优化的起点, 通常是初学者都能想到的比较naive的实现. 容易发现, 矩阵乘法 \mathbf{C}=\mathbf{AB} 的结果都可以被单独的计算出来, 于是就有了以下基础版并行代码:
dim3 block_naive = { 32, 32, 1 };
dim3 grid_naive = { c_col / 32, c_row / 32, 1 };

__global__ void gemm_naive(float* a, float* b, float* c, size_t a_x, size_t b_x)
{
        int row = blockIdx.y * blockDim.y + threadIdx.y;
        int col = blockIdx.x * blockDim.x + threadIdx.x;
        float sum(0);
        for (int k(0); k < a_x; ++k)
        {
                sum += a * b;
        }
        c = sum;
}
通过分配 m*n 个线程来计算矩阵 \mathbf{C} 的每个元素来获取并行性. 这个基准实现的耗时是292ms, 对应约1.88TFlops, 并未充分利用GPU硬件资源.
3. 共享内存优化 (shared memory)

稍微分析代码可以发现, 我们的基准实现每次都需要从全局内存 (global memory) 中读取矩阵元素, 而且是每计算一次fma都需要读取两个float, 即8 bytes, 所以计算读写比仅仅有0.125. 但是为什么得到的1.88TFlops远大于0.125乘以带宽的1TB/s得到的0.125TFlops呢? 这是由于存在各级缓存与读取的合并, 并不会每次都从全局内存中读取. 但是内存带宽瓶颈仍然是存在的, 一个很通常的做法是引入共享内存, 即每个线程块首先将计算需要的矩阵块读取进共享内存, 然后再利用共享内存进行读取, 这样可以大大减轻对全局内存的访问. 于是我们将矩阵 \mathbf{C} 分成 32\times 32 的小块, 然后每个线程块解决每个小块的计算, 但是线程块对矩阵 \mathbf{A}, \mathbf{B} 的读取被分解成多次, 每次也是读取 32\times 32 的小块:
constexpr unsigned int TILE_DIM = 32;

dim3 block_shared = { TILE_DIM, TILE_DIM, 1 };
dim3 grid_shared = { c_col / TILE_DIM, c_row / TILE_DIM, 1 };

__global__ void gemm_shared(float* a, float* b, float* c, size_t a_x, size_t b_x)
{
        __shared__ float aTile, bTile;
        int row = blockIdx.y * blockDim.y + threadIdx.y;
        int col = blockIdx.x * blockDim.x + threadIdx.x;
        float sum = 0.0f;
        for (int c0(0); c0 < a_x; c0 += TILE_DIM)
        {
                aTile = a;
                bTile = b[(c0 + threadIdx.y) * b_x + col];
                __syncthreads();
                for (int i = 0; i < TILE_DIM; i++)
                {
                        sum += aTile * bTile;
                }
                __syncthreads();
        }
        c = sum;
}
这里引入了两个__syncthreads(), 是用于避免在其他线程还没完成对shared memory的写入时就开始进行计算的. 这个共享内存的优化版本的耗时是324ms, 对应约1.70TFlops, 居然比基准实现还慢!
原因大概率是对aTile的访问是一个非常影响性能的32-way bank conflict. 具体解释后面会提到. 从这也可以看出, 有时候优化会引入新的需要优化的问题, 这时候不一定会比原实现更快.
4. 另一种并行模式

接下来的实现源于一篇古老的paper: Vasily Volkov, James Demmel: Benchmarking GPUs to tune dense linear algebra. 这个实现在2008年还是state-of-the-art, 但是随着硬件的不断发展, 其实现的GPU峰值性能比例也在逐渐降低.
constexpr unsigned int VecSize_gemm_fast = 256;
constexpr unsigned int VecWarp_gemm_fast = 64;
constexpr unsigned int VecWarpNum_gemm_fast = VecSize_gemm_fast / VecWarp_gemm_fast;
constexpr unsigned int RollWidth_gemm_fast = 64;
constexpr unsigned int RollTimes_gemm_fast = VecWarp_gemm_fast;

dim3 block_fast = { VecWarp_gemm_fast, VecWarpNum_gemm_fast, 1 };
dim3 grid_fast = { c_row / VecSize_gemm_fast, c_col / RollWidth_gemm_fast, 1 };

__global__ void gemm_fast(float* a, float* b, float* c, size_t a_x, size_t b_x)
{
        unsigned int id = threadIdx.x + threadIdx.y * VecWarp_gemm_fast;
        a += a_x * (blockIdx.x * VecSize_gemm_fast + id);
        b += b_x * threadIdx.y + blockIdx.y * RollWidth_gemm_fast + threadIdx.x;
        c += b_x * (blockIdx.x * VecSize_gemm_fast + id) + blockIdx.y * RollWidth_gemm_fast;
        __shared__ float bs;
        float cs = { 0 };
        int cnt(0);
        do
        {
                // read: 16 + 64
                // calc: 4096
                // ratio: 4096 / (4 * 80) = 12.8
                for (int i(0); i < RollTimes_gemm_fast; i += VecWarpNum_gemm_fast)
                {
                        bs = b;
                }
                b += RollTimes_gemm_fast * b_x;
                cnt += RollTimes_gemm_fast;
                __syncthreads();
                for (int i(0); i < RollTimes_gemm_fast; ++i, ++a)
                {
                        float a0 = a;
                        for (int j(0); j < RollWidth_gemm_fast; ++j)
                        {
                                cs += a0 * bs;
                        }
                }
                __syncthreads();
        } while (cnt < a_x);
        for (int i(0); i < RollWidth_gemm_fast; ++i)
        {
                c = cs;
        }
}
该实现只将矩阵 \mathbf{B} 的分块放入共享内存, 而矩阵 \mathbf{A} 的元素从全局内存读取. 同时使用了64个寄存器存放矩阵 \mathbf{C} 的临时结果. 按照其实现, 我们发现每个循环内每个线程读取了16个矩阵 \mathbf{B} 的元素和64个矩阵 \mathbf{A} 的元素, 而进行了4096次fma计算, 所以可以计算得到计算读写比为12.8, 但是测试结果显示其耗时为144ms, 对应约3.82TFlops, 这远低于理论上限12.8T.
个人认为原因是其每次进行fma计算的一个操作数都是在共享内存的, 查看对应ptx代码也发现每次fma计算都要进行shared memory的读取, 虽然不存在bank conflict, 但是这种方式仍然远慢于直接使用寄存器进行fma计算. 于是便有了以下的我自己的一种实现.
5. 进一步8*8分块

既然共享内存的速度也无法满足要求, 那么将一小块矩阵直接存入寄存器怎么样呢?
constexpr unsigned int TileSize_faster = 128;
constexpr unsigned int RollLength_faster = 16;
constexpr unsigned int KernelSize_faster = 8;
constexpr unsigned int KernelLength_faster = 8;
constexpr unsigned int KernelNum_faster = TileSize_faster / KernelSize_faster;
constexpr unsigned int ThreadNum_faster = KernelNum_faster * KernelNum_faster;

dim3 block_faster = { KernelNum_faster, KernelNum_faster, 1 };
dim3 grid_faster = { c_row / TileSize_faster, c_col / TileSize_faster, 1 };

__global__ void gemm_faster(float* a, float* b, float* c, size_t a_x, size_t b_x)
{
        __shared__ float ta, tb;
        int row = blockIdx.y * blockDim.y * KernelSize_faster;
        int col = blockIdx.x * blockDim.x * KernelSize_faster;
        float ar;
        float br;
        float cr = { 0 };
        for (int c0(0); c0 < a_x; c0 += RollLength_faster)
        {
                // read: 128*16*2
                // calc: 128*128*16
                // ratio: 128/2 / 4 = 16
                int id = threadIdx.x + threadIdx.y * KernelNum_faster;
                for (int c1(0); c1 < TileSize_faster; c1 += ThreadNum_faster / RollLength_faster)
                {
                        int x = id % RollLength_faster;
                        int y = c1 + id / RollLength_faster;
                        ta = a[(row + y) * a_x + c0 + x];
                }
                for (int c1(0); c1 < RollLength_faster; c1 += ThreadNum_faster / TileSize_faster)
                {
                        int x = id % TileSize_faster;
                        int y = c1 + id / TileSize_faster;
                        tb = b[(c0 + y) * b_x + col + x];
                }
                __syncthreads();
                for (int c1(0); c1 < RollLength_faster; c1 += KernelLength_faster)
                {
                        for (int i(0); i < KernelSize_faster; ++i)
                        {
                                for (int j(0); j < KernelLength_faster; ++j)
                                {
                                        // broadcast and 2 way conflict: threadIdx.y = (2 * n) and (2 * n + 1)
                                        ar = ta;
                                        // 4 way conflict:
                                        // 4 threads from threadIdx.x 0, 4, 8, 12 access the same bank
                                        // 4 threads from threadIdx.x 1, 5, 9, 13 access the same bank
                                        // we need to make thread 0, 4 access bank 0, 1
                                        br = tb;
                                }
                        }
                        for (int i(0); i < KernelSize_faster; ++i)
                                for (int k(0); k < KernelLength_faster; ++k)
                                        for (int j(0); j < KernelSize_faster; ++j)
                                                cr += ar * br;
                }
                __syncthreads();
        }
        for (int c0(0); c0 < KernelSize_faster; ++c0)
        {
                for (int c1(0); c1 < KernelSize_faster; ++c1)
                {
                        c = cr;
                }
        }
}
这里我直接将 8\times 8 的小块矩阵ar, br, cr存入了寄存器, 所以相当于引入了一个三级的存储系统: global memory最慢, shared memory次之, register最快, 所以首先把矩阵元素从global memory读取进shared memory, 然后再将shared memory载入到寄存器内, 以此达到更高的性能.
实测该实现的耗时为54ms, 对应约10.2TFlops, 已经达到了很高的水平.
不过这份代码仍有进一步优化的空间: 这里存在bank conflict. 什么是bank conflict呢? 可以将shared memory视为一个每32个32bit元素作为一行的矩阵, 而对于同一列的32bit元素, 是属于同一个bank的, 所以总共有32个bank.
n-way bank conflict指的是同一warp (同时运行的32个线程) 中n个不同的线程访问了shared memory同一列的不同元素, 这样会造成串行的访问. 当然, 如果这些线程访问的是同一个元素, 由于broadcast机制的存在, 实际上并不存在conflict.
在上面的代码中我们发现读取ta时存在2-way bank conflict, 读取tb时存在4-way bank conflict (具体分析留给读者), tb属于比较严重的了. 于是就有了下面的进一步优化:
6. 解决bank conflict

constexpr unsigned int TileSizeA_extreme = 256;
constexpr unsigned int TileSizeB_extreme = 128;
constexpr unsigned int RollLength_extreme = 16;
constexpr unsigned int KernelSize_extreme = 8;
constexpr unsigned int KernelLength_extreme = 8;
constexpr unsigned int KernelNumX_extreme = TileSizeB_extreme / KernelSize_extreme;
constexpr unsigned int KernelNumY_extreme = TileSizeA_extreme / KernelSize_extreme;
constexpr unsigned int ThreadNum_extreme = KernelNumX_extreme * KernelNumY_extreme;

dim3 block_extreme = { KernelNumX_extreme, KernelNumY_extreme, 1 };
dim3 grid_extreme = { c_row / TileSizeB_extreme, c_col / TileSizeA_extreme, 1 };

__global__ void gemm_extreme(float* a, float* b, float* c, size_t a_x, size_t b_x)
{
        __shared__ float ta, tb;
        int row = blockIdx.y * blockDim.y * KernelSize_extreme;
        int col = blockIdx.x * blockDim.x * KernelSize_extreme;
        float ar;
        float br;
        float cr = { 0 };
        int id = threadIdx.x + threadIdx.y * KernelNumX_extreme;
        for (int c0(0); c0 < a_x; c0 += RollLength_extreme)
        {
                // read: (128 + 256) * 16
                // calc: 128 * 256 * 16
                // ratio: 128 * 256 / (4 * (128 + 256)) = 21.3
                int x = id % RollLength_extreme;
                int x8 = (x >> 3) << 3;
                for (int c1(0); c1 < TileSizeA_extreme; c1 += ThreadNum_extreme / RollLength_extreme)
                {
                        int y = c1 + id / RollLength_extreme;
                        int new_x_0 = x8 + ((x + ((y >> 3) & 1)) & 7);
                        ta = a[(row + y) * a_x + c0 + x];
                }
                x = id % TileSizeB_extreme;
                x8 = (x >> 3) << 3;
                int new_x = x8 + ((x + (x >> 5)) & 7);
                for (int c1(0); c1 < RollLength_extreme; c1 += ThreadNum_extreme / TileSizeB_extreme)
                {
                        int y = c1 + id / TileSizeB_extreme;
                        tb = b[(c0 + y) * b_x + col + x];
                }
                __syncthreads();
                for (int c1(0); c1 < RollLength_extreme; c1 += KernelLength_extreme)
                {
                        for (int i(0); i < KernelSize_extreme; ++i)
                        {
                                int new_i = (i + (threadIdx.x >> 2)) & 7;
                                for (int j(0); j < KernelLength_extreme; ++j)
                                {
                                        int new_j = (j + (threadIdx.y & 1)) & 7;
                                        ar = ta;
                                        br = tb;
                                }
                        }
                        for (int i(0); i < KernelSize_extreme; ++i)
                                for (int j(0); j < KernelSize_extreme; ++j)
                                        for (int k(0); k < KernelLength_extreme; ++k)
                                                cr += ar * br;
                }
                __syncthreads();
        }
        for (int c0(0); c0 < KernelSize_extreme; ++c0)
        {
                for (int c1(0); c1 < KernelSize_extreme; ++c1)
                {
                        c = cr;
                }
        }
}
这个解决方案的核心思想是通过对threadIdx.x不同的线程进行循环内的轮换, 使得其对shared memory读取的时候不会读取到同一个bank. 而且由于ar, br是寄存器, 并不能进行实际意义上的针对不同线程进行不同位置的访存, 所以必须将这种轮换同时作用与shared memory的写入时.
最终, 这份代码的耗时为42ms, 达到了13.1TFlops, 而作为对比, cutlass的gemm实现仅仅使用了24ms, 达到了22.9TFlops, 所以说继续优化的空间还是有的.

pc8888888 发表于 2022-12-27 21:13

这收藏的可比赞多多了

APSchmidt 发表于 2022-12-27 21:19

cuda的计算结构比较复杂,所以对于特定的一个算法,针对性的优化,可以想象性能方面可以有很大的提升余地

LiteralliJeff 发表于 2022-12-27 21:20

雀食,都是满满的干货
页: [1]
查看完整版本: 1. Cuda矩阵乘法GeMM性能优化