- 优化器 (1) SGD和Adam
- 优化器 (2) Muon
引言
在上一篇文章中,我们深入介绍了基于梯度下降算法的各种机器学习优化器。这些优化器以SGD和Adam为代表,通过迭代更新模型参数,最小化损失函数。
在实际的机器学习模型中,参数往往是一个矩阵(比如MLP、transformer中的QKV等)。然而在SGD和Adam中,我们一般将参数看作一个一维向量,即 $\theta\in\mathbb{R}^d$,其中 $d$ 是参数的数量。这种强制展平的优化方式忽视了模型的结构。
近期,有一项研究在社区中引起了不小的关注。这项研究提出了一个全新的优化器Muon (MomentUm Orthogonalized by Newton-Schulz),这个优化器将参数作为一个矩阵而非向量来看待,因此可以更好地利用参数的结构信息,苏剑林在博客将这个工作称为【从向量到矩阵的本质跨越】。
在LLM的预训练中,研究者们发现Muon拥有比默认选择AdamW更好的性能,且训练的稳定性更高、收敛速度更快。更重要的是,Muon只需要维护一份梯度的动量估计(Adam需要维护一阶矩和二阶矩),因此在大规模训练中有很大优势。因此,Muon在Kimi-k1.5和Kimi-k2上均被用作预训练阶段的优化器。
在这篇博客中,我们就跟随着苏剑林的一系列博客,详细介绍Muon优化器的原理和实现。
一、Muon的数学原理 [1]
1.1. Muon的更新公式
Muon的全称是 “MomentUm Orthogonalized by Newton-Schulz”,从名字中可以看到,Muon的主要创新点是利用Newton-Schulz方法对参数的动量进行正交化。
具体来说,Muon是一个适用于矩阵参数 $W\in\mathbb{R}^{n\times m}$ 的优化器,其更新规则是:
\[\begin{equation} \begin{aligned} G_t&=\nabla_{W}\mathcal{J}(W_{t-1})\\ M_t&=\beta M_{t-1} + G_t\\ W_t&=W_{t-1} - \eta_t[\text{msign}(M_t)+\lambda W_{t-1}] \end{aligned} \label{eq:muon_update} \end{equation}\]公式 \eqref{eq:muon_update} 中出现了一个矩阵符号函数 $\text{msign}(M)$,它并不是简单地对矩阵的每个元素进行 $\text{sign}$ 操作,而是 $\text{sign}$ 函数在矩阵上的推广。其定义如下:
\[\begin{equation} \begin{aligned} U,\Sigma,V^T&=\text{SVD}(M)\\ \text{msign}(M)&=U_{[:,:r]}V_{[:,:r]}^T \end{aligned} \end{equation}\]其中,$U\in\mathbb{R}^{n\times n}$、$V\in\mathbb{R}^{m\times m}$ 为正交矩阵,$\Sigma\in\mathbb{R}^{n\times m}$ 为对角矩阵,$r$ 是矩阵 $M$ 的秩。
关于 $\text{msign}(M)$ 的更多理论分析我们后续再进行介绍。在这里我们先对Muon有一个定性的结论,即:“Muon是一个类似Adam的自适应学习率优化器”。
在上一篇博客中我们已经看到,像AdaGrad、RMSProp、Adam等自适应学习率优化器的特点是通过维护【梯度平方的滑动平均的平方根】来动态调整每个参数的学习率。这种调节方法有两大好处:
- 对损失函数进行常数缩放不影响优化轨迹。
- 每个参数的更新幅度尽可能一致。
而Muon也正好具有这两大好处,具体来说:
- 当损失函数乘上常数因子 $\lambda$,此时动量矩阵 $M$ 也会乘上相同的因子,但经过SVD分解之后的特征向量矩阵 $UV^T$ 不会发生变化。因此 $\text{msign}(M)$ 不变,即优化轨迹不变。
- 由于 $\text{msign}(M)$ 是对动量矩阵 $M$ 的正交化,因此它表现出“各向同性”,即每个参数的更新幅度是一致的。
事后考古发现,一篇2015年的论文也提出了类似的思想,当时称为 “Stochastic Spectral Descent”。
1.2. 矩阵符号函数
在这一小节中,我们对矩阵符号函数 $\text{msign}(M)$ 的数学性质进行详细介绍。
首先我们说明,为什么它是 $\text{sign}$ 函数在矩阵上的推广。
利用 SVD,我们可以证明(具体证明展示在附录中):
\[\begin{equation} \begin{aligned} \text{msign}(M)&=(MM^T)^{-1/2}M\\ &=M(M^TM)^{-1/2} \end{aligned} \label{eq:msign} \end{equation}\]对于实数 $x$,我们有 $\text{sign}(x)=x(x^2)^{-1/2}$。我们可以发现,这二者在形式上式非常相似的。具体来说,当矩阵 $M$ 是一个对角矩阵时,$\text{msign}(M)$ 就等于 $\text{sign}(M)$:
\[\begin{equation} \begin{aligned} \text{msign}(M) &=\text{diag}(m)\left[\text{diag}(m)^2\right]^{-1/2}\\ &=\text{diag}\left(\text{sign}(m)\right)\\ &=\text{sign}(M) \end{aligned} \end{equation}\]此外,考虑一维向量 $m\in\mathbb{R}^{n\times 1}$ ,我们有 $\text{msign}(m)=m/\lVert m\rVert_2$,相当于给参数 $m$ 做了L2归一化。
综上所述,Muon对于性质的矩阵参数有着不同的处理方式:
- 对于一般的矩阵参数(如MLP、QKV等),使用 $\text{msign}(M)$ 对动量进行正交化。
- 对于对角矩阵参数(如LayerNorm中的gamma),则是对动量进行 $\text{sign}$ 操作。
- 对于向量参数(如MLP的偏置和词表embedding),则是对动量进行L2归一化。
值得注意的是,虽然词表embedding也属于矩阵参数,但它们在使用上是稀疏的,因此更合理的方式是对它们的动量进行L2归一化。
值得补充的是,当矩阵 $M$ 是一个满秩方阵,即 $m=n=r$ 时,$\text{msign}(M)$ 还可以看作是对矩阵 $M$ 的一个最优正交近似,即:
\[\begin{equation} \text{msign}(M)=\arg\min_{OO^T=I}\lVert M-O\rVert_F^2 \label{eq:msign_opt} \end{equation}\]1.3. Newton-schulz迭代求解
由于SVD的计算开销较大,因此在实践中,作者提出了使用Newton-schulz迭代来近似计算 $\text{msign}(M)$。
迭代近似的出发点是公式 \eqref{eq:msign}。不失一般性地,我们设 $n\ge m$。考虑对矩阵函数 $X^{-1/2}$ 在 $X=I$ 处进行二阶泰勒展开得:
\[\begin{equation} \begin{aligned} X^{-1/2} &\approx I-\frac{1}{2}(X-I)+\frac{3}{8}(X-I)^2\\ &=\frac{15}{8}-\frac{5}{4}X+\frac{3}{8}X^2 \end{aligned} \end{equation}\]因此,我们有:
\[\begin{equation} \begin{aligned} \text{msign}(M) &=M(M^TM)^{-1/2}\\ &\approx\frac{15}{8}M-\frac{5}{4}M(M^TM)+\frac{3}{8}M(M^TM)^2 \end{aligned} \end{equation}\]因此,Newton-schulz算法的迭代公式为:
\[\begin{equation} X_{t+1}=\frac{15}{8}X_t-\frac{5}{4}X_t(X_t^TX_t)+\frac{3}{8}X_t(X_t^TX_t)^2 \label{eq:newton-schulz} \end{equation}\]然而,当我们查看Muon的官方实现时,发现其Newton-schulz的迭代公式与公式 \eqref{eq:newton-schulz} 大致相同,但常数项不同。苏剑林的博客中对这个不一致做出了解释,他认为官方实现对目的在于加速Newton-schulz迭代的收敛。感兴趣的读者可以在附录3中找到这个解释的具体内容。
二、从范数视角看Muon的优势 [1,2]
上面我们提到,Muon优化器是一种从向量到矩阵的本质跨越。它把模型的参数以矩阵的形式来进行优化,从而获得了更好的理论优势。在这一节中我们我们试图分析这一观点。
读者可能会感到好奇的是,把参数看作向量和矩阵有什么本质的区别吗?因为矩阵和向量都是一堆数字的排列,形式的不同会对优化的结果产生什么影响?
在苏剑林的博客 [1] 中提到了一种观点,即矩阵有一些性质和概念是与向量不同的。比如说,矩阵中有迹(trace)这个概念,它是矩阵主对角线上元素的和。矩阵的迹有一个重要特性是在相似变换下保持不变,且等于矩阵的所有特征值之和。从这个例子就可以看出,矩阵的对角线元素跟非对角线元素,地位其实是不完全对等的。而Muon正是因为考虑了这种不对等性,才有着更好的效果。
那么Muon究竟捕捉了矩阵的什么关键特性呢?我们下面将从范数的视角来重新看待Muon,试图来回答这个问题。
2.1. Muon控制了模型输出值的变化量
对于一个向量 $x\in\mathbb{R}^{d_x}$,我们可以使用下面的向量RMS范数来衡量其大小:
\[\begin{equation} \begin{aligned} \lVert x\rVert_{RMS} &:=\sqrt{\frac{1}{d_x}\sum_{i=1}^{d_x}x_i^2}\\ &=\frac{1}{\sqrt{d_x}}\lVert x\rVert_2 \end{aligned} \label{eq:vector-rms-norm} \end{equation}\]其中,$\lVert x\rVert_2=\sqrt{\sum_{i=1}^{d_x}x_i^2}$ 是向量2范数。
考虑一个线性变换 $y=Wx$,我们就可以使用该向量范数的【诱导范数】来衡量权重矩阵 $W\in\mathbb{R}^{d_y\times d_x}$ 的大小:
\[\begin{equation} \begin{aligned} \lVert W\rVert_{RMS} &:=\sup_{x\neq 0}\frac{\lVert Wx\rVert_{RMS}}{\lVert x\rVert_{RMS}}\\ &=\sqrt{\frac{d_x}{d_y}}\cdot\lVert W\rVert_{sp} \end{aligned} \label{eq:matrix-rms-norm} \end{equation}\]其中,$\lVert \cdot \rVert_{sp}$ 是矩阵谱范数,它是向量2范数所诱导的矩阵范数,其定义为:
\[\begin{equation} \lVert W \rVert_{sp}:=\sup_{x\neq 0}\frac{\lVert Wx\rVert_2}{\lVert x\rVert_2} \label{eq:matrix-sp-norm} \end{equation}\]公式 \eqref{eq:matrix-rms-norm} 的证明展示在附录4中。
当我们将权重矩阵更新 $\Delta W$ 时,输出向量 $\Delta y$ 的改变量为:
\[\begin{equation} \Delta y=\Delta Wx \end{equation}\]由于【任意向量范数与其诱导范数都是相容的】,因此我们可以写出 $\Delta y$ 的RMS范数的上界,即:
\[\begin{equation} \begin{aligned} \lVert \Delta y\rVert_{RMS} &=\lVert \Delta Wx\rVert_{RMS}\\ &\le \lVert \Delta W\rVert_{RMS}\cdot\lVert x\rVert_{RMS}\\ \end{aligned} \end{equation}\]这意味着我们可以通过控制权重更新 $\Delta W$ 的诱导范数大小,我们可以直接控制网络输出的变化量 $\Delta y$,从而实现更加稳定的训练过程。
2.2. Muon的带约束优化问题形式
Muon的核心思想可以总结为:在【稳】的前提下寻找【尽可能快】的更新量。即:
- 对模型的扰动尽可能小;
- 对Loss对贡献要尽可能大。
为此,Muon把参数的更新量 $\Delta W$ 定义为如下带约束优化问题的最优解:
\[\begin{equation} \begin{aligned} &\min_{\Delta W}\quad\langle\nabla_W\mathcal{J},\Delta W\rangle\\ &\text{s.t.}\quad\lVert \Delta W\rVert_{RMS}\le\beta \end{aligned} \label{eq:muon-opt} \end{equation}\]其中,矩阵内积定义为:$\langle A,B\rangle=\text{Tr}(A^TB)$。
可以证明,公式 \eqref{eq:muon-opt} 的最优解为:
\[\begin{equation} \Delta W\propto- \text{msign}(\nabla_W\mathcal{J}) \label{eq:muon-opt-sol} \end{equation}\]公式 \eqref{eq:muon-opt-sol} 的证明在附录5中。
公式 \eqref{eq:muon-opt} 中的目标函数正是梯度下降的目标函数。而公式 \eqref{eq:muon-opt-sol} 告诉我们,当给梯度下降的参数更新量施加一个范数约束时,得到的最优更新量正好就是Muon优化器。
因此,我们可以看到,Muon相当于一个在RMS范数约束下的梯度下降法,这个范数约束更好地度量了矩阵之间的本质差异,从而使得更新的每一步都走的更精准、更本质。
三、Scaling Muon to Large Models [3,4]
在这一节中,我们主要讨论如何在更大尺寸的模型上应用Muon优化器。虽然Muon的理论优势在小尺寸模型上已经被证明,但在更大尺寸的模型上,想要让Muon稳定优于经典的Adam等优化器,还需要一些工程优化。
3.1. 权重衰减项的引入 [3]
苏剑林的团队在将Muon应用在更大尺寸的模型上时,发现在训练前期收敛确实很快,但很快就被Adam追上,甚至还会有各种崩溃的苗头出现。
为此,他们在公式 \eqref{eq:muon-opt} 中引入了一个权重衰减项,即:
\[\begin{equation} \Delta W\propto- \text{msign}(\nabla_W\mathcal{J}+\lambda W) \end{equation}\]此时发现,Muon就能一直保持领先于Adam。他们分析,权重衰减项在其中起的作用就是让参数的范数保持有界。
当某一个优化器给出的更新向量 $\Phi_{t}$ 时,考虑任意一种矩阵范数 $\lVert \cdot \rVert$,当我们在参数更新中加入权重衰减项后,始终有下面不等式成立:
\[\begin{equation} \begin{aligned} \lVert W_t\rVert &=\lVert W_{t-1}-\eta_{t}(\Phi_{t}+\lambda W_{t-1})\rVert\\ &=\lVert (1-\eta_{t} \lambda)W_{t-1}-\eta_{t}\Phi_{t}\rVert\\ &\le (1-\eta_{t} \lambda) \lVert W_{t-1}\rVert + \eta_{t}\lambda \lVert \Phi_{t}/\lambda\rVert\\ &\le \max(\lVert W_{t-1}\rVert, \lVert \Phi_{t}/\lambda\rVert) \end{aligned} \end{equation}\]对于Muon来说,我们选择 $\Phi_{t}=-\text{msign}(\nabla_W\mathcal{J})$ 和谱范数 $\lVert \cdot \rVert_{sp}$。由于 $\lVert \text{msign}(\nabla_W\mathcal{J})\rVert_{sp}=1$,因此有:
\[\begin{equation} \lVert W_t\rVert_{sp}\le\max(\lVert W_{t-1}\rVert_{sp}, 1/\lambda)\le\cdots\le\max(\lVert W_{0}\rVert_{sp}, 1/\lambda) \end{equation}\]这样就保证了参数 $W$ 的模长是有界的。又因为 $\lVert y\rVert=\lVert Wx\rVert\le \lVert W\rVert_{sp}\cdot\lVert x\rVert$,因此模型的输出也能够被控制住,不会有爆炸的风险,这对于 Attention Ligits 爆炸等的问题是尤其重要的。
3.2. Update RMS对齐:快速找到最优超参数 [3]
当我们尝试一个新的优化器时,一个非常重要的问题就是如何快速找到最优的超参数。比如说Muon中至少有两个重要超参数:学习率 $\eta_{t}$ 和权重衰减系数 $\lambda$。
苏剑林团队提出了一种名为Update RMS对齐的超参迁移思路,可以快速地将Adam中经过检验的超参数快速应用到其他优化器上。
对于一个矩阵 $W\in\mathbb{R}^{n\times m}$,我们定义其RMS为:
\[\begin{equation} \begin{aligned} \text{RMS}(W) &:=\frac{\lVert W\rVert_F}{\sqrt{nm}}\\ &=\sqrt{\frac{1}{nm}\sum_{i=1}^n\sum_{j=1}^m W_{ij}^2} \end{aligned} \end{equation}\]注意这里的RMS和公式 \eqref{eq:matrix-rms-norm} 中定义的RMS范数不一样,请读者务必加以区分。
在实验中观察到,使用Adam更新参数时,更新量的RMS基本稳定在0.2~0.4之间 [5]。因此,我们可以将Muon的Update RMS也对齐到0.2,即改为:
\[\begin{equation} W_t=W_{t-1}-\eta_{t}(0.2\cdot\Phi_{t}/\text{RMS}(\Phi_{t})+\lambda W_{t-1})\\ \label{eq:muon-update} \end{equation}\]这样一来,我们就可以服用Adam中的 $\eta_{t}$ 和 $\lambda$,使得Muon和Adam对参数的更新幅度大致相同的效果。实践表明,通过这个简单策略从Adam迁移到Muon,就能训出明显优于Adam的效果,接近进一步对Muon超参进行精搜索的结果。
更进一步地,对于Muon来说,其更新量的RMS是可以解析地算出来的:
\[\begin{equation} \begin{aligned} \text{RMS}(\Phi_{t}) &=\text{RMS}(U_{[:,:r]}V_{[:,:r]}^T)\\ &=\sqrt{\frac{1}{nm}\sum_{i=1}^n\sum_{j=1}^m\sum_{k=1}^r U_{ik}^2V_{kj}^2}\\ &=\sqrt{\frac{1}{nm}\sum_{k=1}^r\left( \sum_{i=1}^n U_{ik}^2 \right)\left(\sum_{j=1}^m V_{kj}^2 \right)}\\ &=\sqrt{\frac{r}{nm}} \end{aligned} \end{equation}\]在实践中,一个矩阵是严格低秩的概率比较小,因此我们可以认为 $r\approx\min(n,m)$,从而有 $\text{RMS}(\Phi_{t})\approx\sqrt{1/\max(n,m)}$。
因此公式 \eqref{eq:muon-update} 可以进一步写为:
\[\begin{equation} W_t=W_{t-1}-\eta_{t}(0.2\sqrt{\max(n,m)}\cdot\Phi_{t}+\lambda W_{t-1})\\ \end{equation}\]上面的公式同时也表明了:在Muon中不适宜对所有的参数 $W$ 使用同一个学习率,否则必然会导致某些参数学习过快/过慢的不同步问题,从而影响最终效果。
3.3. QK-Clip:将Muon应用到100B以上的模型 [4]
上述两个优化手段的有效性在16B的模型上得到了验证,然而当他们试图进一步将Muon拓展到100B参数以上的模型时,则出现了新的问题:MaxLogits爆炸。
为了解决这个问题,他们使用了一种新技术:QK-Clip。该方法从一个非常本质的角度去看待和解决MaxLogit爆炸现象,并且无损模型效果,这也是Kimi K2(1000B参数)的关键训练技术之一。
MaxLogits是指attention矩阵的最大值,即:
\[\begin{equation} S_{max}=\max_{i,j}q_i\cdot k_j \end{equation}\]MaxLogits爆炸是指:$S_{max}$ 随着训练的推进一直往上涨,增长速度是线性甚至是超线性的,并且在相当长的时间内没有稳定的迹象。
尽管经过Softmax之后都会变成小于1,顶多是浪费了一个attention head,但在最坏情况下MaxLogits爆炸会引起梯度爆炸和训练崩溃。因此,我们也应当尽可能地避免MaxLogits爆炸的情况。
3.3.1. softcap和QK-Norm
在上面我们提到,可以使用权重衰减来一定程度上防止MaxLogits爆炸的情况出现,但这种策略仅仅适用于小模型。当模型参数量越来越大,训练的不稳定因素越多,权重衰减就越难稳定训练过程,且会造成严重的效果损失。
一种直接的方式是给Logits加上一个上界 softcap:
\[\begin{equation} \begin{aligned} O&=\text{softmax}(\text{softcap}(QK^T))V\\ \text{softcap}(x)&=\tau\tanh(x/\tau)\\ \end{aligned} \end{equation}\]由于 $\tanh$ 的有界性,softcap能够保证logits的有界性。但问题在于无法保证softcap之前的Logits也是有界的,所以softcap只是将一个问题转化为了另一个问题,实际上并没有解决问题。
为此,Gemma3、Qwen3等模型都该用了QK-Norm:
\[\begin{equation} \begin{aligned} O&=\text{softmax}(\tilde{Q}\tilde{K}^T)V\\ \tilde{Q}&=\text{RMSNorm}(Q)\\ \tilde{K}&=\text{RMSNorm}(K)\\ \end{aligned} \end{equation}\]QK-Norm是一种压制MaxLogits的有效方法,但它的一个严重问题是只适用于MHA、GQA等注意力机制,无法用在MLA中。
这是因为QK-Norm需要完整写出Q、K矩阵,但对于MLA来说,其训练阶段和推理阶段的Q、K矩阵是不同的,因此在推理阶段没法做QK-Norm。
3.3.2. QK-Clip
其实对QK进行缩放的关键问题就是:什么时候缩放、缩放多少。为了解决这个问题,QK-Clip将MaxLogit本身作为触发缩放的信号。具体来说,当MaxLogit超过一个阈值 $\tau$ 时,就对QK进行缩放,缩放比例为 $\gamma=\tau/S_{max}$,这样一来新的MaxLogit就一定不会超过 $\tau$。
同时,由于QK-Clip是对参数直接进行操作,因此不会影响推理阶段,自然也就能够兼容MLA。
Appendix
Apd.1. Proof on Eq. \eqref{eq:msign}
下面我们证明恒等式:
\[\begin{equation} \begin{aligned} \text{msign}(M)&=(MM^T)^{-1/2}M\\ &=M(M^TM)^{-1/2} \end{aligned} \end{equation}\]我们首先计算 $MM^T$:
\[\begin{equation} \begin{aligned} MM^T &=U\Sigma V^T V\Sigma^T U^T\\ &=U\Sigma\Sigma^T U^T\\ &=U\Lambda U^T\\ \end{aligned} \end{equation}\]其中 $\Lambda=\Sigma\Sigma^T\in\mathbb{R}^{n\times n}$ 是对角矩阵,其对角线元素为:
\[\begin{equation} \Lambda_{ii}= \begin{cases} \sigma_i^2,&i=1,\dots,r\\ 0,&i\gt r \end{cases} \end{equation}\]代入恒等式右侧得:
\[\begin{equation} \begin{aligned} (MM^T)^{-1/2}M &=U\Lambda^{-1/2}U^T\cdot U\Sigma V^T\\ &=U\Lambda^{-1/2}\Sigma V^T\\ \end{aligned} \end{equation}\]其中,$\Lambda^{-1/2}\Sigma$ 是一个对角矩阵,其对角线元素为:
\[\begin{equation} (\Lambda^{-1/2}\Sigma)_{ii}= \begin{cases} 1,&i=1,\dots,r\\ 0,&i\gt r \end{cases} \end{equation}\]因此,代入上式得:
\[\begin{equation} \begin{aligned} (MM^T)^{-1/2}M &=U\Lambda^{-1/2}\Sigma V^T\\ &=U \begin{bmatrix} I_r&0\\0&0 \end{bmatrix} V^T\\ &=U_{[:,:r]}V_{[:,:r]}^T\\ &=\text{msign}(M) \end{aligned} \end{equation}\]得证。
Apd.2. Proof on Eq. \eqref{eq:msign_opt}
下面我们证明当矩阵 $M$ 是一个满秩方阵,即 $m=n=r$ 时,有:
\[\begin{equation} \text{msign}(M)=\arg\min_{OO^T=I}\lVert M-O\rVert_F^2 \end{equation}\]对于正交矩阵 $O$,我们有:
\[\begin{equation} \begin{aligned} \lVert M-O\rVert_F^2 &=\lVert M\rVert_F^2+\lVert O\rVert_F^2-2{\langle M,O\rangle}_F\\ &=\lVert M\rVert_F^2+n-2\text{Tr}(MO^T)\\ &=\lVert M\rVert_F^2+n-2\text{Tr}(U\Sigma V^T O^T)\\ &=\lVert M\rVert_F^2+n-2\text{Tr}(\Sigma V^T O^TU)\\ &=\lVert M\rVert_F^2+n-2\sum_{i=1}^n\Sigma_{ii}\cdot\left(V^T O^TU\right)_{ii} \end{aligned} \end{equation}\]由于 $\Sigma_{ii}=\sigma_i\gt 0$,因此我们需要最大化 $\sum_{i=1}^n\Sigma_{ii}\cdot\left(V^T O^TU\right)_{ii}$。
又因为 $V^T O^TU$ 是一个正交矩阵,其对角线元素不超过1。因此当 $V^T O^TU=I$ 时,$\sum_{i=1}^n\Sigma_{ii}\cdot\left(V^T O^TU\right)_{ii}$。最大,即 $\lVert M-O\rVert_F^2$ 最小。此时,$O=UV^T=\text{msign}(M)$。
Apd.3. Newton-schulz迭代的官方实现 [1]
在Muon的官方实现中,Newton-schulz迭代的实现为:
\[\begin{equation} X_{t+1}=3.4445X_t-4.7750X_t(X_t^TX_t)+2.0315X_t(X_t^TX_t)^2 \end{equation}\]可以看到,所选用的常数项与公式 \eqref{eq:newton-schulz} 不同。在苏剑林的博客 [1] 中,他认为这样设计是为了加速迭代的收敛速度。下面我们复述一下博客中的内容。
考虑更一般的迭代过程:
\[\begin{equation} X_{t+1}=aX_t+bX_t(X_t^TX_t)+cX_t(X_t^TX_t)^2 \label{eq:newton-schulz-general} \end{equation}\]其中 $a,b,c$ 是三个待求解的常数项。
我们选择的迭代初始值是 $X_0=M/ \lVert M\rVert_F$,之所以要除以F范数,是因为这样子不改变SVD分解得到的正交矩阵 $U,V$,但可以让 $X_0$ 的奇异值都控制在 $[0,1]$ 之间。
设 $U\Sigma_t V^T=\text{SVD}(X_t)$,则代入公式 \eqref{eq:newton-schulz-general} 得:
\[\begin{equation} X_{t+1}=U_{[:,:r]}\left( a\Sigma_{t,[:,:r]}+b\Sigma_{t,[:,:r]}^3+c\Sigma_{t,[:,:r]}^5 \right)V_{[:,:r]}^T \end{equation}\]因此,我们可以看到公式 \eqref{eq:newton-schulz-general} 实际上是在对奇异值对角矩阵 $\Sigma_{t,[:,:r]}$ 进行迭代。这是因为:令 $g(x)=ax+bx^3+cx^5$,则有 $\Sigma_{t+1,[:,:r]}=g(\Sigma_{t,[:,:r]})$。又因为对角阵的幂等于对角线元素各自取幂,所以问题简化成单个奇异值 $\sigma$ 的迭代。
由于我们的目标是将 $X_0=M=U_{[:,:r]}\Sigma U_{[:r,:r]}V_{[:,:r]}^T$ 迭代到 $X_T=\text{msign}(M)=U_{[:,:r]}V_{[:,:r]}^T$,因此我们需要把中间的对角矩阵 $\Sigma U_{[:r,:r]}$ 单位阵 $I$。也就是说,整个迭代过程等价于利用 $x_{t+1}=g(x_t)$ 将实数 $x_0$ 迭代到 $x_T=1$。
我们将常数 $a,b,c$ 的选择视为一个最优化问题,目标是使得迭代过程对于任意初始值 $x_0$ 都能尽可能快地收敛到 $x_T=1$,即迭代步数 $T$ 越小越好。
我们将 $g(x)$ 重新写为:
\[\begin{equation} g(x)=x+\kappa x(x^2-x_1^2)(x^2-x_2^2) \end{equation}\]不失一般性地,我们令 $x_1\le x_2$。这样子我们就能直观地写出 $g(x)$ 的所有不动点:$0, \pm x_1,\pm x_2$。由于我们的目标是收敛到 $x_T=1$,因此作为初始化我们选择 $x_1\lt 1,x_2\gt 1$。只要我们确定好迭代步长 $T$,我们就可以通过优化 $x_T$ 和 $x_0$ 之间的距离,来确定最优的 $a,b,c$。具体来说,我们的求解思路如下:
- 选定超参数 $n,m,T$。
- 生成随机矩阵 $M\in\mathbb{R}^{n\times m}$,并计算其SVD分解 $M=U\Sigma V^T$。
- 初始化 $x_0=\sigma$,并迭代 $T$ 步,得到 $x_T=g^T(x_0)$。
- 最小化 $(x_T-1)^2$,通过梯度回传求得最优的 $a,b,c$。
我编写了一个python脚本来进行上述模拟,结论发现:最优参数与矩阵大小、迭代次数 $T$ 都有着明显关系。而Muon官方实现中选择的常数,大概是迭代部署 $T=5$ 时方阵的最优解。
Apd.4. Proof on Eq. \eqref{eq:matrix-rms-norm}
下面我们证明由向量RMS范数诱导出的矩阵RMS范数有如下形式:
\[\begin{equation} \begin{aligned} \lVert W\rVert_{RMS} &:=\sup_{x\neq 0}\frac{\lVert Wx\rVert_{RMS}}{\lVert x\rVert_{RMS}}\\ &=\sqrt{\frac{d_x}{d_y}}\cdot\lVert W\rVert_{sp} \end{aligned} \end{equation}\]由公式 \eqref{eq:vector-rms-norm} 我们可以用2范数来表示上式的两个向量RMS范数:
\[\begin{equation} \begin{aligned} \lVert x\rVert_{rms}&=\frac{1}{\sqrt{d_x}}\lVert x\rVert_2\\ \lVert Wx\rVert_{rms}&=\frac{1}{\sqrt{d_y}}\lVert Wx\rVert_2\\ \end{aligned} \end{equation}\]代入诱导范数的定义得:
\[\begin{equation} \begin{aligned} \lVert W\rVert_{RMS} &=\sup_{x\neq 0}\frac{\frac{1}{\sqrt{d_y}}\lVert Wx\rVert_2}{\frac{1}{\sqrt{d_x}}\lVert x\rVert_2}\\ &=\sqrt{\frac{d_x}{d_y}}\sup_{x\neq 0}\frac{\lVert Wx\rVert_2}{\lVert x\rVert_2}\\ &=\sqrt{\frac{d_x}{d_y}}\cdot\lVert W\rVert_{sp} \end{aligned} \end{equation}\]最后一步是利用了谱范数的定义,即公式 \eqref{eq:matrix-sp-norm}。
得证。
Apd.5. Proof on Eq. \eqref{eq:muon-opt-sol}
下面我们证明带约束的优化问题 \eqref{eq:muon-opt} 的最优解为公式 \eqref{eq:muon-opt-sol}。
Step 1. 优化问题变形
将公式 \eqref{eq:matrix-rms-norm} 代入 \eqref{eq:muon-opt} 中的约束条件并移项后,可以得到约束条件等价于:
\[\begin{equation} \lVert \Delta W\rVert_{sp}\le\beta\sqrt{\frac{d_y}{d_x}}:=C \end{equation}\]我们记 $U\Sigma V^T=\text{SVD}(\nabla_W\mathcal{J})$,则目标函数可以做如下变形:
\[\begin{equation} \begin{aligned} \langle\nabla_W\mathcal{J},\Delta W\rangle &=\text{Tr}\left(\left(U\Sigma V^T\right)^T\Delta W\right)\\ &=\text{Tr}\left(V\Sigma^T U^T\Delta W\right)\\ &=\text{Tr}\left(\Sigma^T U^T\Delta WV\right) \end{aligned} \end{equation}\]我们定义 $\Delta W’=U^T\Delta WV$,由于 $U$ 和 $V$ 是正交的,因此有 $\Delta W=U\Delta W’ V^T$。
又由于 $\Sigma$ 是对角矩阵,即 $\Sigma^T=\Sigma$,因此目标函数可以写为:
\[\begin{equation} \begin{aligned} \langle\nabla_W\mathcal{J},\Delta W\rangle &=\text{Tr}\left(\Sigma^T U^T\Delta WV\right)\\ &=Tr\left(\Sigma \Delta W'\right)\\ &=\sum\sigma_i(\Delta W')_{ii} \end{aligned} \end{equation}\]由于矩阵的谱范数在正交变化下保持不变,因此有:
\[\begin{equation} \begin{aligned} \lVert \Delta W\rVert_{sp} &=\lVert U^T\Delta WV\rVert_{sp}\\ &=\lVert \Delta W'\rVert_{sp} \end{aligned} \end{equation}\]因此我们的优化问题可以等价地写为:
\[\begin{equation} \begin{aligned} &\min_{\Delta W}\quad\sum\sigma_i(\Delta W')_{ii}\\ &\text{s.t.}\quad\lVert \Delta W'\rVert_{sp}\le C \end{aligned} \end{equation}\]Step 2. 求解
对于任意矩阵,其谱范数都大于等于其任意对角线元素的绝对值,即:
\[\begin{equation} \lvert (\Delta W')_{ii}\rvert\le\lVert \Delta W'\rVert_{sp}\le C \end{equation}\]因此我们有 $(\Delta W’)_{ii}\in[-C,C]$。
由于 $\sigma_i\ge 0$,为了使目标函数最小,我们需要给每个对角线元素取最小值,即 $(\Delta W’)_{ii}=-C$。
同时,为了使约束条件中的谱范数尽可能小,非对角线元素我们应该置0,即 $(\Delta W’)_{ij}=0$。
此时,最优解 $\Delta W’ = -C\cdot I$,此时 $\lVert \Delta W’\rVert_{sp}=C$。
最后我们便可以求出原始的最优解:
\[\begin{equation} \begin{aligned} \Delta W &=U\Delta W' V^T\\ &=-C\cdot UV^T\\ &\propto -\text{msign}(\nabla_W\mathcal{J}) \end{aligned} \end{equation}\]公式 \eqref{eq:muon-opt-sol} 得证。
Reference
[2] Building the Muon Optimizer in PyTorch: A Geometric Approach to Neural Network Optimization