找回密码
 立即注册
查看: 806|回复: 20

C++加速矩阵乘法的最简单方法

[复制链接]
发表于 2023-4-18 08:14 | 显示全部楼层 |阅读模式
假设我们封装了矩阵类,现在要实现矩阵乘法 C=AB 。为了讨论方便,设所有矩阵都是 n 阶矩阵。本文假定矩阵是按行存储的。最普通的实现方式如下:
for(int i=0;i<n;++i)
    for(int j=0;j<n;++j)
        for(int k=0;k<n;++k)
            C[j]+=A[k]*B[k][j];
这种方式,当然是速度不快的方式。对于矩阵乘法的加速,有如下两种策略:

  • 基于算法的优化
  • 基于硬件的优化
对于算法优化,最广为人知的是Strassen算法,能达到 O(n^{2.7}) 的时间复杂度,这甚至还不是渐进时间复杂度意义上最快的算法。但在实际的库中,没有用Strassen算法实现的,因为常数太大,要想超越朴素的算法,矩阵的规模必须非常大。本文暂不探讨算法优化。
对于硬件优化,有通用矩阵乘法(GEMM),里面有各种循环展开、基于内存布局等的优化技巧。此外还可以多线程(C++11有<thread>库)并发执行。这里不说的那么复杂,就说说最简单的一个技巧:调换循环顺序。例如将原来的 ijk 顺序改成 ikj 顺序:
for(int i=0;i<n;++i)
    for(int k=0;k<n;++k)
        s=A[k];
        for(int j=0;j<n;++j)
            C[j]+=s*B[k][j];
对于一千阶的矩阵,速度提升5倍左右。下面分析一下为什么效果这么显著。
首先从矩阵的存储方式说起。一般而言,矩阵有两种存储方式:

  • 一维数组
  • 二维数组
对于矩阵乘法而言,一维数组显然比二维数组好得多(下文的分析也能看出这一点)。但是如果用矩阵类实现其他功能的话,可能其他功能用二维数组更方便一些。所以下面对于这两种实现方式都加以分析。
造成矩阵乘法慢的原因,除了算法上的 O(n^3) 以外,还有内存访问不连续。这会导致cache命中率不高。所以为了加速,就要尽可能使内存访问连续,即不要跳来跳去。我们定义一个概念:跳跃数,来衡量访问的不连续程度。
对于最普通的实现方式(顺序: ijk ),它是依次计算 C 中的每个元素。当计算 C 中任一个元素时,需要将 A 对应的行与 B 对应的列依次相乘相加。之前已假设过,矩阵是按行存储的,所以在 A 相应行中不断向右移动时,内存访问是连续的。但 B 相应列不断向下移动时,内存访问是不连续的。计算完 C 的一个元素时, B 相应列中已经间断地访问了 n 次,而 A 只间断 1 次(这一次就是算完后跳转回本行的开头),故总共是 n+1 次。这样计算完 C 中所有 n^2  个元素,跳转了 n^3+n^2 次。但刚才没有计数 C 的跳转次数,加上以后是 n^3+n^2+n 。(注意,在计算完 C 中每行的最后一个元素时, A 是从相应行末尾转到下一行开头。如果使用一维数组实现的话,这是连续地访问,要减掉这 n 次。同时, C 没有跳转次数了,还要减掉 n 次。因此对于一维数组,跳转数是 n^3+n^2-n 次)
而如果以顺序 ikj 实现,它将 C 中元素一行一行计算。当计算 C 中任一行的第一个元素时,先访问 A 中相应行的第一列元素,和 B 中第一列的第一行元素,然后 B 不断往右挪(不间断),算完后跳转到下一行(如果二维数组则间断一次,一维数组不间断),此时 A 往右挪一个元素(不间断)。依次这样挪动,这样算完 C 的这一行元素后,恰好按顺序将 B 遍历一遍,间断了 n 次(一维数组是 1 次),且恰好从左往右遍历了 A 的相应行,间断了 1 次(一维数组没有间断),加起来是 n+1 次(一维数组是 1 次)。故算完 C 的所有 n 行后,跳转了 n^2+n 次(一维数组是 n 次)。刚才没有算 C 的跳转,算上后跳跃数是 2n^2+n 次(一维数组是 n^2 次)。
(我上面这块写的很乱,找时间修改一下)
由此可见:

  • 顺序 ikj 的跳转数渐进地少于顺序 ijk 的跳转数
  • 一维数组比二维数组好
下面是各个循环顺序的跳跃数列表(下面是我写文章时现计算的,可能会因为粗心犯错)

  • 顺序 ikj —— 2n^2+n (二维数组)—— n^2 (一维数组)
  • 顺序 kij —— 3n^2 (二维数组)—— 2n^2 (一维数组)
  • 顺序 jik —— n^3+2n^2 (二维数组)—— n^3+n^2+n (一维数组)
  • 顺序 ijk —— n^3+n^2+n (二维数组)—— n^3+n^2-n (一维数组)
  • 顺序 kji —— 2n^3+n (二维数组)—— 2n^3 (一维数组)
  • 顺序 jki —— 2n^3+n^2 (二维数组)—— 2n^3+n^2 (一维数组)
因此从速度来说:
ikj>kij>jik>ijk>kji>jki
实测速度:( 1000 阶, \text{gcc} , -O3 优化)

  • ikj:847ms
  • kij:1028ms
  • jik:2733ms
  • ijk:4552ms
  • kji:17269ms
  • jki:17197ms
发现基本符合( kji 与 jki 略微违反,其实它们很接近,回过头看理论分析也表明很接近)
发表于 2023-4-18 08:19 | 显示全部楼层
这样也会不会便于编译器优化用simd?
发表于 2023-4-18 08:24 | 显示全部楼层
不清楚
发表于 2023-4-18 08:28 | 显示全部楼层
有几个优化点,一个是对外层循环做tile,利用访问内存的局部性尽量多利用缓存中现有的数据进行计算,通常对i和j做blocking,第二是对外层循环并行化,第三是内层循环可以向量化,硬件通常提供了fma,也就是y=a*b+c的指令,最后在向量化的基础上还可以unroll,减少循环次数
发表于 2023-4-18 08:31 | 显示全部楼层
一般的实现都用一位数组封装成多维数组实现,cpp中通过重载实现起来很方便
发表于 2023-4-18 08:32 | 显示全部楼层
说的有道理
发表于 2023-4-18 08:34 | 显示全部楼层
这个优化的专业
发表于 2023-4-18 08:43 | 显示全部楼层
simd 大法比写这个更要快一些
发表于 2023-4-18 08:52 | 显示全部楼层
确实unrolling挺重要的
发表于 2023-4-18 08:55 | 显示全部楼层
csapp里面讲到了
懒得打字嘛,点击右侧快捷回复 【右侧内容,后台自定义】
您需要登录后才可以回帖 登录 | 立即注册

本版积分规则

小黑屋|手机版|Unity开发者联盟 ( 粤ICP备20003399号 )

GMT+8, 2024-11-16 22:51 , Processed in 0.106700 second(s), 27 queries .

Powered by Discuz! X3.5 Licensed

© 2001-2024 Discuz! Team.

快速回复 返回顶部 返回列表