stonstad 发表于 2022-10-8 14:22

超全解读 DeepMind AlphaTensor: 使用AI(RL)自动设计算法 ...

本文详细解读 AlphaTensor,包括此工作在宏观上的定位意义和未来展望、在微观上的算法细节和结果,只需了解矩阵乘法即可阅读本文。有 RL (reinforcement learning 强化学习) 或 AutoML (自动机器学习) 基础会更好~一些参考链接: .
这是原提问中笔者的一个简要回答:
<hr/>背景介绍:矩阵乘

矩阵是非常重要的数学工具,尤其对于深度学习。全连接层几乎处处都在矩阵乘,而卷积神经网络的现代实现 (cuda im2col + GEMM) 也非常依赖矩阵乘。已经有无数的工作在优化计算机的矩阵乘效率。system角度的优化就有向量化、访存命中等等。而从数学角度,矩阵乘需要的计算次数其实也有优化空间。
在一次矩阵乘中,包含了许多次对两个数字的数值乘法或加法运算。考虑计算机架构的乘指令开销远大于加法,我们希望在保证结果正确的前提下,尽量减少矩阵乘中的数值乘法次数。下图是一个 2\times 2 矩阵乘的例子,请见图注。



对A、B两个2x2矩阵相乘得到C。左:按照矩阵乘法的定义需要做8次数值乘法 即h1~h8。右:按照Strassen在50年前提出的算法只需要7次。

可见对 2\times 2矩阵乘,按照Strassen的算法来算,比按照定义要更快 (7次乘法 vs 8次)。拓展到N阶矩阵乘,朴素算法的数值乘法次数显然是 \mathcal{O}(N^3) 的。而利用Strassen,把矩阵 2\times 2 分块,然后分治着乘,有乘法次数 T(N)=7\,T(\frac{N}{2})+\mathcal{O}(N^2),意味着总次数是 \Theta(N^{\text{log}_27})\approx\Theta(N^{2.81}) 的,这就加速了。
因此我们看到:人工设计矩阵乘法算法的流程,可以达到比原始乘法定义更少的数值乘法次数,实现更快的矩阵乘。那能不能让人工智能自动设计算法流程呢?
一句话介绍 AlphaTensor

AlphaTensor利用人工智能技术 (RL强化学习) 对算法 (矩阵乘) 流程进行自动设计。

[*]"算法":指不基于深度学习的传统算法/数值计算方法。
[*]"自动设计":把"算法"的流程参数化 (一套参数即可对应一个算法的流程),然后自动寻找一套好的参数,使其对应的算法流程的执行效率 (时间复杂度) 最优。这样就可以把算法流程的设计问题转化为一个参数的搜索/优化问题。优化目标是时间复杂度,优化手段可以是RL强化学习。
重点/难点在哪?搜索空间

可以看到上面"自动设计算法"这一套其实很像一个自动超参数优化问题:三要素 (搜索空间、搜索方法、优化目标) 非常明确。之前做自动机器学习AutoML/NAS的同学应该很熟悉这套方法论。而想套用这套方法论,难就难在第一步如何定义搜索空间,即如何把矩阵乘算法的流程进行数学化的抽象 (参数化)。
怎么定义搜索空间? 两个对应

AlphaTensor使用如下两个对应关系来构造搜索空间 (实际上这个搜索空间是已有的,在中会看到这个搜索空间已被很多前人工作使用):

[*]对应关系①:一种尺寸的矩阵乘法定义对应一个表征张量
[*]对应关系②:表征张量的一种低秩分解 (分解为R个秩1项) 对应一种包含R次数值乘法的矩阵乘法算法流程
这样,通过自动尝试找更小的R,就可以"自动设计算法"。以 2\times 2 矩阵乘为例,其表征张量记为 \mathcal{T}_2,张量尺寸是 4\times 4\times 4,总共包含64个元素。请见下图与图注:



一个尺寸为4x4x4的矩阵乘法表征张量T2,其对应2x2矩阵的乘法定义即C=A·B;张量T2有4x4x4=64个元素,深色取值为1,浅色取值为0

[对应关系①]

我们说 "\mathcal{T}_2对应了2x2矩阵乘法的定义",是由于:

[*]若 \left\langle a_i,b_j,c_k \right\rangle 位置是1,则代表 c_k 会累计上 a_ib_j。例如上图 c_1 这个维度 (用上下左右前后描述这个三维张量的话,c_1就是最"后"侧或者说最"内"侧的那一层16个格子),有两个格子是1,分别是 a_1b_1 和 a_2b_3,也就意味着 c_1=a_1b_1+a_2b_3.
[*]从前往后逐层考虑,就可以依次得到 c1,c2,c3,c4 的定义,从而得到矩阵乘法的定义。这也就是对应关系①。
[对应关系②]

接下来我们考虑对 \mathcal{T}_2 这个张量进行低秩分解,得到 R 个秩为1的项,即: \mathcal{T}_2=\Sigma_{r=1}^{R} \bm{U}^{(r)} \otimes \bm{V}^{(r)} \otimes \bm{W}^{(r)}\\其中 \otimes 是外积。别慌,这个式子很简单。由于 \mathcal{T}_2 是 4\times 4\times 4 的张量,所以 \bm{U},\bm{V},\bm{W} 都是 4 行 R 列的矩阵 (\bm{U}^{(r)}, \bm{V}^{(r)}, \bm{W}^{(r)} 就都是 4 维列向量)。这里有一个非常直观的几何意义:



左:Strassen算法的流程可以看到只包含m1~m7共七次数值乘法。右:Strassen算法对应的低秩分解得到的U V W.

我们拿 \bm{U},\bm{V},\bm{W} 的第 2 列为例,这可以考虑 m_2 和 c_{1\sim4} 是如何被计算出来的。

[*]绿色矩阵的第二列 \bm{U}^{(2)}=\left^\mathsf{T} 第3、4项为1,代表了上图左侧中 m_2 计算过程中的绿色部分 (a_3+a_4).
[*]类似地,紫色的 \bm{V}^{(2)}=\left^\mathsf{T} 只有第1项为1,代表 m_2 计算过程中的紫色部分 b_1.
[*]黄色矩阵则稍微有些不同,其第二列 \bm{W}^{(2)}=\left^\mathsf{T} 表示 m_2 在输出的 c_1,c_2,c_3,c_4 里的系数分别是 0,0,1,\text{-}1.
这样,三个 \bm{U},\bm{V},\bm{W} 矩阵 (对应着一种\mathcal{T}_2的分解方法) 就唯一确定了一个 2\times 2 矩阵乘法的算法流程。上图Strassen的分解矩阵是是 4 行 7 列的,也就是说 R=7。根据上文"低秩分解矩阵UVW的秩R等于算法流程中的数值乘法次数",所以Strassen的数值乘法次数也就是7。OK,这也就是我们上文说的对应关系②,一种低秩分解唯一确定一种算法流程。
在搜索空间中搜索

定义好了搜索空间 (也就是对应关系①②) 后,考虑 N 阶方阵乘法的定义直接确定了张量 \mathcal{T}_N 的每个元素的0/1取值,所以我们不断尝试对 \mathcal{T}_N 进行低秩分解就可以实现 N 阶方阵乘法算法流程的自动搜索了。搜索目标很简单,就是尽可能让三个 \bm{U},\bm{V},\bm{W} 矩阵的秩尽可能低,因为他们的秩就等于矩阵乘需要的数值乘法次数。
然而实际这个低秩分解问题是 np-hard的,所以这里只好进行启发式搜索 (heuristic search) 了,什么无梯度优化、强化学习都可以拿来用。这就到了 DeepMind 最擅长的领域了——玩游戏 (强化学习)。
RL 状态、动作、reward


[*]环境初始状态是 \mathcal{S}_0\leftarrow\mathcal{T}_N
[*]考虑第 t 步,agent 根据状态 \mathcal{S}_{t-1}进行决策,这个决策是给出仨列向量 \bm{U}^{(t)}, \bm{V}^{(t)}, \bm{W}^{(t)},从而使状态发生如下更新:
\mathcal{S}_t\leftarrow \mathcal{S}_{t-1} - \bm{U}^{(r)} \otimes \bm{V}^{(r)} \otimes \bm{W}^{(r)} \\

[*]当 \mathcal{S}_t = \bm{0} 或 trajectory 的步骤数大于预先指定的 R_\text{limit} 时,trajectory 自动终止。
实际上这就是一个硬凑 \mathcal{T}_N 分解的过程。为了促进 agent 尽可能早地凑出零张量 (得到尽可能低秩的 \bm{U},\bm{V},\bm{W}) reward 是这样设计的:

[*]每凑一步得到一个常数 -1的reward
[*]若 R_\text{limit} 步后没凑出零张量,会得到额外的 -\gamma(\mathcal{S}_{R_\text{limit}}) 的 reward。这个数跟最后剩下的这个结果张量的秩有关,秩越大当然惩罚就越多。
离散化动作之后用 MCTS

为了剪枝搜索空间,agent 在做 action 的时候是离散化做的,即每个决策列向量里面的每个元素的取值都必须在 \{-2,-1,0,1,2\} 中选择。这件事情的形象意义是,算法流程中矩阵的每一项前面的系数都只能从这五个数里面选择。
经过离散化之后 (其实是非常强的一步搜索空间剪枝),整个搜索过程和下围棋就很像了:每一步有多个离散分支选择,对应了一棵搜索树。这时为了更科学地做 explore-exploit,MCTS(蒙特卡洛树搜索) 就可以派上用场了。于是得到了 DeepMind 的传家宝——AlphaZero 风格的 RL+MCTS,如下图:



使用RL+MCTS来搜索TN分解的过程。

关于更多的一些 trick 或者网络结构,详见原始论文。值得一提的是 policy network 里面已经在用 transformer decoder 结构了:



policy head 网络结构图

实验结果

理论结果

首先是论文里最醒目的 Fig. 3:



论文中的 Fig. 3

这个图表怎么看呢?首先表格列出了各种矩阵size下的结果,其中 Best rank known 是指人类已知的最优方案的分解后的秩,也就等于数值乘法的次数。可以看到 AlphaTensor 在较大矩阵乘法上有了超越人类已知的新发现。
而在右图中展示了更大矩阵尺寸的情况,最大到了 11x12 和 12x12 的矩阵乘 (当然,也不是很大)。可以看到整个图是呈递增趋势的,也就是说对于越大的矩阵,AlphaTensor 甩开人类的差距越大。
然而,有一个比较悲观的事实是,AlphaTensor 论文在 intro 第一段中就提到 3x3 矩阵的相乘的全局最优解仍未被发现。而由于 AlphaTensor 本质上也是在 heuristic 搜索,所以它自己也是无法保证找到这个最优解...
硬件 runtime 测试

我们基本可以看到 AlphaTensor 搜出来的矩阵乘法算法还是和 Strassen 有着基本一致的、简洁的形式,所以实际可用性还是比较强。AlphaTensor 的作者毫不畏惧,挑了个硬骨头 baseline 也就是 cuda 上的 BLAS 来比,因此结论应该还是比较 solid 的,但主要还是在大型矩阵上的提升更为明显:



在不同架构上,AlphaTensor 的加速情况

从左边两个图可以看到,无论是 GPU 还是 TPU,当矩阵尺寸大于 8192 时,AlphaTensor 的加速比例还是非常可观的。最右边的图展示了一下在 GPU 上还是跑 GPU 定制优化过的算法更好,TPU 亦然。
宏观讨论:定位和意义

在 AI for science (AI4science) 中的定位:AlphaTensor 指出了一个新的方向——在棋类对弈/游戏 (AlphaGo/Star)、生物 (AlphaFold) 之外,AI 可能还可以在经典算法/数值方法/理论计算机科学上大放光彩。
在 DeepMind Alpha-系列中的定位:算是 AlphaZero 的拓展吧,因为也是基于 RL+MCTS 一套。
笔者认为从设计搜索空间的 formulation 来看,AlphaTensor 比 AlphaGo 困难不少,而和 AlphaStar 相比各有各的难处 (AlphaStar 难在特征工程)。另外从搜索算法执行起来的难度看,AlphaTensor 应该比下围棋和打星际简单不少,毕竟不涉及博弈,不需要设计自对弈的复杂系统。
和自动机器学习 (AutoML) 的关系:可以认为 AutoML 是对 ML 算法的设计,而 AlphaTensor 就是对非 ML 算法的设计。这两者在哲学上讲是有相似之处的。
笔者认为显然 AlphaTensor 的搜索空间比绝大多数 AutoML 系统要复杂得多。不过呢,AlphaTensor 的搜索空间并不是 AlphaTensor 开辟性提出的,而是之前一些用 heuristic search 来搜索矩阵乘的工作就已经提出了并广泛使用了的。
Contribution:技术上 contribution 不多,搜索空间是拿来即用的,搜索算法基本也是沿用 AlphaZero 的。主要的 contribution 在于 insight 和指路的价值上。
未来方向浅展望

AlphaTensor 开辟了一个大方向 (大坑?) 是:把 AI 做到去自动优化经典算法/数值方法/理论计算机科学方法的算法流程。其中的关键难点应该在于分析算法流程的本质,然后建立算法流程的搜索空间。从哲学上来讲,可以认为这是一种新型的"AutoML"——或许应该叫"AutoAlgo"来的更自然些?
实际上笔者个人认为经典算法和矩阵乘算法在性质上相似的还是不多。例如 dijkstra,huffman tree,kmp,一些常见DP,这些经典算法的渐进复杂度看起来没有矩阵乘这种搜一搜算法流程就能"压缩"的机会。相比之下,或许一些数值方法倒是存在一些"超参数",然后适合使用这种搜一搜的自动优化做法吧。
参考


[*]^Strassen, V. (1969). Gaussian elimination is not optimal. Numerische mathematik, 13(4), 354-356.
[*]^Landsberg, Joseph M. Geometry and complexity theory. Vol. 169. Cambridge University Press, 2017.
[*]^张量可以理解为更高维的矩阵,例如一维张量是向量,二维张量是矩阵,三维张量则是更高维的矩阵.
[*]^外积是Kronecker积的一种特例。两个N维向量的外积是NxN维矩阵。例如 和 的外积是一个矩阵 [, ]。NxN矩阵和N维向量的外积是NxNxN的张量.
[*]^Hillar, C. J., & Lim, L. H. (2013). Most tensor problems are NP-hard. Journal of the ACM (JACM), 60(6), 1-39.
[*]^Fawzi, A., Balog, M., Huang, A., Hubert, T., Romera-Paredes, B., Barekatain, M., ... & Kohli, P. (2022). Discovering faster matrix multiplication algorithms with reinforcement learning. Nature, 610(7930), 47-53.
[*]^Smirnov, A. V. (2013). The bilinear complexity and practical algorithms for matrix multiplication. Computational Mathematics and Mathematical Physics, 53(12), 1781-1795.

jquave 发表于 2022-10-8 14:22

Strassen只算乘法复杂度递推公式应该把N去掉,N应该把加法都算成原子操作了,而且得乘个系数或者写成大O(N)吧[思考]

acecase 发表于 2022-10-8 14:30

[赞]严谨,已修正
页: [1]
查看完整版本: 超全解读 DeepMind AlphaTensor: 使用AI(RL)自动设计算法 ...