zt3ff3n 发表于 2022-10-13 16:24

论文分享:Discovering faster matrix multiplication …

一、概述

DeepMind的最新暴力美学《AlphaTensor》,文章对矩阵乘法算法的设计过程进行巧妙的表征(表征本身的设计是前人已有的工作),将寻找快速矩阵乘法算法问题转化为寻找张量的更低秩的分解问题,并利用(Sampled)AlphaZero+足够算力进行暴力求解。AlphaTensor得出的新的矩阵乘法算法,对依赖矩阵乘法的计算过程都能起到显著的加速作用(较常用算法提速10%~20%)。
正文链接:Discovering faster matrix multiplication algorithms with reinforcement learning - Nature
附录链接:https://static-content.springer.com/esm/art%3A10.1038%2Fs41586-022-05172-4/MediaObjects/41586_2022_5172_MOESM1_ESM.pdf
官方blog: Discovering novel algorithms with AlphaTensor
alphazero 以及 sampled alphazero相关内容可移步:强化学习实验室:model based专题三--MuZero系列
二、方法

如果一个实际应用问题可以被抽象成在某个搜索空间内的优化问题,并且该搜索问题需要进行序列多步的优化,就可以使用深度强化学习(尤其是基于MCTS搜索的算法)+足够算力支持,进行暴力求解(至少可以得到可行解)。
在实际问题里,最难的或者说技巧性最强的一步是在(1)实际问题的抽象、建模 ;(2)高效搜索空间的设计。
这篇文章瞄准广泛存在的“矩阵乘法”问题(矩阵乘法的应用广泛存在,寻找更快速的算法,提升矩阵乘法的计算效率意义重大),在“矩阵乘法”问题上,给出了“AI改进经典算法”的一次完整示范:(1)问题的建模,(2)搜索空间设计,(3)搜索算法设计,(4)性能的比较。

具体设计如下:
2.1 两个矩阵相乘,存在多种计算方式

首先,2个矩阵相乘,是存在多种计算方式的。我们在教科书上学过以下2种:
(1)向量内积


(2)向量外积之和(可以看成是分块矩阵乘法)


除了常用的这两种,为了优化计算效率,数学家陆续研究出第3种、第4种、......
以2x2矩阵为例,经典的算法是德国数学家Volken Strassen于1969年发现的更高效的Strassen's algorithm,与上面2种基本算法相比,乘法的次数可以由8次减少到7次。



Strassen's algorithm

在计算机中,乘法带来的计算消耗远大于加法带来的计算消耗。快速矩阵乘法问题就是去寻找数值乘法次数更少的矩阵乘法算法的问题。这类算法的核心思想是:构造中间变量,用更多次的加法来代替乘法。一个简单的例子是平方差公式:a^2-b^2 =(a+b)*(a-b)。也就是说,这些算法本质上描述的是相同的计算过程,区别在于对变量的排列组合以及约简方式(构造中间变量)不一样。
2.2 问题的抽象:矩阵乘法问题的张量表征

既然描述的是同一个计算过程,我们可以对这个过程进行统一的表征,然后利用AI再对这一表征进行转换和约简,从而发现不同的“新算法”。
一种尺寸的矩阵乘法的计算过程对应一个唯一的张量表征(3维表示)



矩阵乘法问题的张量表征

以2x2的矩阵乘法为例,c1=a1*b1+a2*b3,也就是c1的值对应了a1*b1和a2*b3两部分。分别以a、b、c为x轴、y轴、z轴,将这一对应关系表示在3维坐标中就得到 (a1, b1, c1) 和 (a2, b3, c1) 2个点。在c1这个维度上,将 (a1, b1, c1) 和 (a2, b3, c1) 2个点的位置标记为“1”,其余的位置标记为“0”,就得到了这种对应关系的一种“数值表征”。
类似的,也可以得到在c2、c3、c4维度上各自的“数值表征”。这样,我们就得到了表示这种对应关系的一个3维的张量,size为(4, 4, 4),将其记为 \mathcal{T}_{2,2,2},表示的是大小为2x2和2x2的矩阵相乘。\mathcal{T}_{2,2,2}的示意图如上图所示。
2.3 搜索空间设计:矩阵乘法与张量表征的低秩分解存在一一对应关系

接下来考虑对\mathcal{T}_{2,2,2}这个张量进行低秩分解,得到R个秩为1的张量(每个秩1张量是通过三个向量做外积得到的)即:\mathcal{T}_{2,2,2}=\Sigma_{r=1}^{R} \bm{U}^{(r)} \otimes \bm{V}^{(r)} \otimes \bm{W}^{(r)}\\
接上例,\bm{U},\bm{V},\bm{W}都是4行R列的矩阵 (\bm{U}^{(r)}, \bm{V}^{(r)}, \bm{W}^{(r)}就都是4维列向量,表示第 r 列),然后\bm{U}^{(r)} \otimes \bm{V}^{(r)} \otimes \bm{W}^{(r)}就是一个秩为1的张量, \mathcal{T}_{2,2,2} 是这R个秩1张量的和。



矩阵乘法算法(Strassen's algorithm)张量低秩分解

矩阵乘法算法张量低秩分解

[*]图c中绿色矩阵的第 i 列系数与向量 \left^\mathsf{T} 作內积,可以得到图b中计算 m_i 的第一项(绿色部分)。
[*]图c中紫色矩阵的第 i 列系数与向量 \left^\mathsf{T} 作內积,可以得到图b中计算 m_i 的第二项(紫色部分)。
[*]最后,图c中黄色矩阵的第 j 行系数与向量 \left^\mathsf{T} 作內积,可以得到图b中 c_j 的计算过程(黄色部分)。
[*]注意:因系数项都是比较小的常数,系数与对应 a_j, b_j, m_i 的“乘法计算”耗时可忽略。
这样,\bm{U},\bm{V},\bm{W}矩阵的列数就等于低秩分解后秩1张量的个数,也等于矩阵乘法算法中需要的乘法次数。经前人研究,有个结论:一种3维张量的分解方法(一组\bm{U},\bm{V},\bm{W}矩阵)就唯一确定了一个矩阵乘法的算法流程。张量分解的秩越小,导出的矩阵乘法算法中的乘法次数越少。
对应张量低秩分解→矩阵乘法算法的转换过程如下算法所示:



算法:张量低秩分解→矩阵乘法算法

至此,已经将“设计更高效的矩阵乘法算法问题”,转换为“寻找秩更小的张量低秩分解问题”。
在AlphaTensor之前,人们已经发现了这个转化过程,并已经使用各种搜索算法来求解这个问题。本文的主要贡献是将搜索算法换为Sampled AlphaZero并取得了更好的搜索结果。2.4 搜索算法设计:基于Sampled AlphaZero的强化学习搜索算法

寻找秩更小的张量低秩分解问题定义如下:
\text{min}_{\bm{U},\bm{V},\bm{W}} R \\ \text{subject to } \Sigma_{r=1}^{R} \bm{U}^{(r)} \otimes \bm{V}^{(r)} \otimes \bm{W}^{(r)}=\mathcal{T}_{2,2,2} \\
MDP建模:该问题本身可以看做是single state下(可以将输入的待分解张量看做是state),指数动作空间(分解矩阵)的参数搜索问题。DeepMind对搜索空间做了进一步的简化,并将single state下的搜索问题,建模为一个multiple state的MDP。核心是将寻找最小 R 的问题,转换为求解最小step的最短路径问题。
初始状态 s_0=\mathcal{T}_{2,2,2},在每个step t\in\{0,\ldots R\} ,agent的动作是生成 t step对应的3个向量 \bm{U}^{(t)},\bm{V}^{(t)},\bm{W}^{(t)}。同时,agent每多走一步都会收到 r=-1 的惩罚,环境的状态转移到:
s_{t+1}\leftarrow s_{t} - \bm{U}^{(t)} \otimes \bm{V}^{(t)} \otimes \bm{W}^{(t)} \\直到某个时刻 T,s_{T}=\textbf{0} ,即找到了一个可行的动作序列使约束条件\Sigma_{t=1}^{T} \bm{U}^{(t)} \otimes \bm{V}^{(t)} \otimes \bm{W}^{(t)}=\mathcal{T}_{2,2,2}成立。agent执行动作的step数越多,在整个episode收到的cumulative reward就越小(即惩罚越大),因此,随着训练,agent会去寻找step数更少的动作序列。为了避免agent长时间做无效探索,设置了一个最大step数 T_\text{limit},step数超过 T_\text{limit} ,会给一个较大的惩罚(正比于剩余张量 s_{T_\text{limit}}的秩)。
这里在描述方法思想时,用了两个2x2矩阵相乘的张量 \mathcal{T}_{2,2,2} 作为例子,待分解的张量可以是任意形状。高维动作空间:这里,还有一个问题是,即便在每个step,agent只需要生成3个向量 \bm{U}^{(t)},\bm{V}^{(t)},\bm{W}^{(t)} ,向量的每一维数值都可以取任意实数,这仍然是一个高维连续动作空间的优化问题。DeepMind对动作空间做了进一步的简化,约束向量中的每个数值仅能在\{-2,-1,0,1,2\}中选择。
即便如此,离散化后的动作空间仍然非常巨大,假设每个向量的长度为 N ,离散化的完整动作空间为 5^{3N} (随向量维度指数增长)。因此,在实际执行时,DeepMind采用了前序的Sampled AlphaZero的工作,作为backbone的搜索算法。Sampled AlphaZero针对高维的离散动作空间,设计了基于动作空间下采样的策略迭代算子。
这里放一张AlphaZero的核心思想:



MCTS search

关于Sampled AlphaZero详情,请移步:
网络结构:对3维tensor state的encoder以及3个向量action的decoder部分,都采用了比较复杂的Transformer的结构。
这里重点解释下对action的decoder部分,对上文提到的离散化后的指数动作空间 5^{3N},进行了序列化拆解,转换为一个序列生成问题,采用了Causal Transformer,autoregressively decode a discrete action a\in\{-2,-1,0,1,2\} for each vector dimension,总共decode 3N个action。


三、实验

实验设置:
有监督的辅助任务:

[*]增广数据集



[*]混合的training loss



[*]training architecture


主要结论:
1. Discovery of matrix multiplication algorithms
① re-discovers the best algorithms known for multiplying matrices.
② improves over the best algorithms known for several matrix sizes.
③ generates a large database of matrix multiplication algorithms — up to thousands of algorithms for each size (the space is richer than previously known).


有学者指出,这个结果本身提升不大,文中仅对比强调了Strassen's algorithm,但目前理论上最快的算法达到 \text{O}(n^{2.373}) (Ryan Williams在Tweet也说“这种文章不可能被计算理论顶会STOC/FOCS/SODA接收”)。2. Rapid tailored algorithm discovery(面向特定硬件的优化: 基于硬件,定制算法)
在reward function中增加在实际硬件上运行的耗时惩罚:we provide an additional reward at the terminal state (after the agent found a correct algorithm) equal to the negative of the runtime of the algorithm when benchmarked on the target hardware.


图c表明,在一个硬件上优化出的算法,在另一个硬件上会有performance drop,因此,面向特定硬件定制化算法(软件,即硬件软件联合优化)是非常有必要的。
四、总结

DeepMind选问题还是选的好!
不同于以往Learning to optimize任务中(例如经典的Pointer Network 求解TSP问题),针对具体的problem cases,耗费大量算力,train一个神经网络模型,在testing的problem cases上测试+比较model的泛化性。性能的好坏完全依赖于模型的泛化性,导致神经网络模型最终在性能上(例如解的质量)往往不如经典的启发式算法,只是在求解时间上有一定的优势(因为仅调用模型做inference)。
Vinyals O, Fortunato M, Jaitly N. Pointer networks. Advances in neural information processing systems, 2015, 28.AlphaTensor找到了合适的问题设定,模型的输出是直接“应用广泛的且可大量次复用的高性能算法”,不再去拼训练好的模型的泛化性。虽然花费了大量算力训练,但找到更好的“算法”本身就是有意义的。
这就给了我们一种启示,不单纯追求AI模型本身的泛化性,而是利用(训练代价昂贵的)AI模型去优化一些“fundamental algorithms”,对“fundamental algorithms”一点点的性能提升也可能给这个世界带来很大的变化。这也就间接证明了AI算法的价值(再一次为AI续命)。
页: [1]
查看完整版本: 论文分享:Discovering faster matrix multiplication …