生成模型 (1.1)

变分推断

Posted by Zifeng Mai on November 7, 2025

引言

在这一篇文章中,我们从变分 (Variation) 的视角来看待生成模型。

在讲解诸如VAE、DDPM等经典方法之前,我想先画一些时间讲清楚什么是【变分】。

我一开始学习VAE时,发现VAE中的隐变量是服从高斯分布的,而AE中的隐变量仍然服从某个未知的(往往也是难解的)隐分布。这个时候我可能认为变分就是引入一个隐分布的先验假设。在学习DDPM时,发现优化的ELBO损失是似然函数的“变分下界”,这里又出现了变分,但似乎和高斯分布之类的没有任何关系,然后我就搞不清楚到底变分的意义是什么了。

到后面我了解到数学中有一类称为【变分法】的方法,被用于求解泛函的极值问题。这似乎和优化问题扯上了一些关系,但也是死活没发现泛函体现在哪。

因此,我希望在学习生成模型之前,先把变分(准确来说是【变分推断】)的含义理解清楚。

在本文中,我们将看到数学上的变分法和机器学习中的变分推断实际上都是在处理某个函数空间中的优化问题,其中变分法用于求解泛函的极值所对应的最优函数,而变分推断则用于求解与难解的后验分布最接近且可解的概率分布

以下的内容主要来自知乎中的高赞回答博客

一、什么是推断

在机器学习领域中,推断 (Inference) 是指从观测数据 $x$ 计算未知量 $z$ 的分布的过程,即计算后验分布 $p(z\mid x)$ 的过程。

这里的未知量 $z$ 可以包括多种,比如说模型参数、隐变量,甚至是未来的观测数据。

推断主要用到的数学工具是贝叶斯定理,因此也被称为贝叶斯推断 (Bayesian Inference):

\[\begin{equation} p(z\mid x) = \frac{p(x\mid z)p(z)}{p(x)} \end{equation}\]

其中:

  • $p(z)$ 称为先验 (prior),是对未知量 $z$ 的一种无知识的估计
  • $p(z\mid x)$ 称为后验 (posterior)
  • $p(x)$ 称为证据 (evidence)
  • $p(x\mid z)$ 称为似然 (likelihood)

从这里可以看出机器学习的两种不同流派:

  1. 概率学派:利用最大似然对未知量 $z$ 进行点估计,即 $\hat{z}:=\arg\max_zp(x\mid z)$;
  2. 贝叶斯学派:估计未知量 $z$ 的整个分布,即 $p(z\mid x)$;

进行贝叶斯推断主要有以下三个步骤:

  1. 根据问题,选定先验 $p(z)$ 以及似然 $p(x\mid z)$ 所服从的分布;

  2. 计算证据:

    \[\begin{equation} p(x) = \int p(x\mid z)p(z)\mathrm{d}z \end{equation}\]
  3. 根据公式 $(1)$ 计算后验概率。

可以发现,整个推断的复杂度几乎全部集中在第二步,即求解公式 $(2)$ 的积分上。公式 $(2)$ 实际上是一个多维积分,其复杂度取决于隐变量 $z$ 的维度。在实际的机器学习问题中,隐变量(如模型参数)的维度通常都非常高,可能是成千上万维的一个高维向量。因此,在解决实际问题时贝叶斯推断几乎不可用。这也是为什么我们称 $p(x)$ 为难解的 (intractable)。

二、什么是变分推断

2.1. 从计算转为优化

为了解决贝叶斯推断复杂度过高的问题,我们可以使用变分推断 (Variational Inference) 来近似后验概率 $p(z\mid x)$。

变分推断的核心步骤包括以下两步:

  1. 引入一个带参数 $\lambda$ 的变分分布 (variational distribution) $q(z;\lambda)$,这个分布是可解的;
  2. 通过优化参数 $\lambda$,使得变分分布 $q(z;\lambda)$ 尽可能接近真实后验分布 $p(z\mid x)$。

可以看到,变分推断实际上是把贝叶斯推断的计算问题变为优化问题,通过优化参数 $\lambda$ 来逼近后验分布:

\[\begin{equation} \lambda^*=\arg\min_\lambda \mathcal{D}(p(z\mid x),q(z;\lambda)) \end{equation}\]

优化问题收敛后,我们就可以用可解的 $q(z;\lambda^*)$ 分布来替代难解的 $p(z\mid x)$ 分布。

公式 $(3)$ 中的 $\mathcal{D}$ 是度量两个分布之间距离的度量。一般来说,我们采用KL散度作为这个度量。在上一篇文章中,我们已经介绍了KL散度:

\[\begin{equation} \begin{aligned} \mathcal{D}_{KL}(p\|q) &:=\int p(x)\log\frac{p(x)}{q(x)}\mathrm{d}x\\ &=\mathbb{E}_{x\sim p}[\log p(x)-\log q(x)] \end{aligned} \end{equation}\]

KL散度满足以下性质:$\mathcal{D}_{KL}(p|q)\ge 0$,当且仅当 $p=q$ 时等号成立。

2.2. ELBO的引入

上面提到,变分推断的目标是求解以下最优化问题:

\[\begin{equation} \min_q \mathcal{D}_{KL}(q(z)\|p(z\mid x)) \end{equation}\]

我们已经知道,后验分布 $p(z\mid x)$ 是难解的,因此我们需要做适当的变形。

首先根据定义展开KL散度:

\[\begin{equation} \begin{aligned} \mathcal{D}_{KL}(q(z)\|p(z\mid x)) &=\int q(z)\log\frac{q(z)}{p(z\mid x)}\mathrm{d}z\\ &=\int q(z)\log\frac{q(z)}{\frac{p(x\mid z)p(z)}{p(x)}}\mathrm{d}z\\ &=\int q(z)\log\left(q(z)\cdot\frac{p(x)}{p(x, z)}\right)\mathrm{d}z\\ &=\int q(z)\left(\log q(z)+\log p(x) - \log p(x, z)\right)\mathrm{d}z\\ &=\int q(z)\log q(z)\mathrm{d}z+\int q(z)\log p(x)\mathrm{d}z - \int q(z)\log p(x, z)\mathrm{d}z\\ &=\int q(z)\log q(z)\mathrm{d}z+\log p(x)- \int q(z)\log p(x, z)\mathrm{d}z\\ &=\log p(x) - \left(\int q(z)\log p(x, z)\mathrm{d}z-\int q(z)\log q(z)\mathrm{d}z\right)\\ &=\log p(x) - \int q(z)\log \frac{p(x, z)}{q(z)}\mathrm{d}z \end{aligned} \end{equation}\]

我们称第二项为经验下界 (Evidence Lower Bound, ELBO),因为它是经验 $p(x)$ 的一个下界估计:

\[\begin{equation} \begin{aligned} ELBO &:=\int q(z)\log \frac{p(x, z)}{q(z)}\mathrm{d}z\\ &=\log p(x) - \mathcal{D}_{KL}(q(z)\|p(z\mid x))\\ &\le \log p(x) \end{aligned} \end{equation}\]

ELBO项可以进一步做如下变形:

\[\begin{equation} \begin{aligned} ELBO &:=\int q(z)\log \frac{p(x, z)}{q(z)}\mathrm{d}z\\ &=\mathbb{E}_{q}[\log p(x,z)] - \mathbb{E}_{q}[\log q(z)] \end{aligned} \end{equation}\]

因此,我们的优化问题变为:

\[\begin{equation} \begin{aligned} q^* &=\arg\min_q \mathcal{D}_{KL}(q(z)\|p(z\mid x))\\ &=\arg\min_q\left(\log p(x) - ELBO \right)\\ &=\arg\max_q ELBO \end{aligned} \end{equation}\]

即转变为最大化ELBO项。

2.3. 求解变分推断

下面介绍两种求解变分推断问题的实际方法

2.3.1. 平均场变分族

平均场变分族 (Mean-Field Variational Family) 基于的核心假设是:隐变量的不同分量 $z={z_1,z_2,\dots,z_K}$ 之间是相互独立的。因此,变分分布 $q(z;\lambda)$ 可以分解为:

\[\begin{equation} q(z;\lambda) = \prod_{k=1}^{K}q_k(z_k;\lambda_k) \end{equation}\]

在这个假设下,我们可以通过坐标上升变分推断 (Coordinate Ascent Variational Inference, CAVI) 方法进行优化。

我们对ELBO的形式进行一定推导:

\[\begin{equation} \begin{aligned} ELBO(q) &=\mathbb{E}_{q}[\log p(x,z)] - \mathbb{E}_{q}[\log q(z;\lambda)]\\ &=\mathbb{E}_{q}[\log p(x,z)] - \sum_{k=1}^K \mathbb{E}_{q_k}[\log q_k(z_k;\lambda_k)]\\ \end{aligned} \end{equation}\]

我们记 $z_{-j}={z_1,\dots,z_{j-1},z_{j+1},\dots,z_K}$ 表示除 $z_j$ 之外的隐变量。我们在固定 $q_{-j}$ 的情况下,希望优化 $q_j$。

考虑ELBO关于 $q_j$ 的泛函:

\[\begin{equation} \begin{aligned} ELBO(q_j;q_{-j}) &=\mathbb{E}_{q_j,q_{-j}}[\log p(x,z)] - \mathbb{E}_{q_j}[\log q_j(z_j;\lambda_j)]- \sum_{k\neq j} \mathbb{E}_{q_k}[\log q_k(z_k;\lambda_k)]\\ &=\mathbb{E}_{q_j}[\mathbb{E}_{q_{-j}}[\log p(x,z)]] - \mathbb{E}_{q_j}[\log q_j(z_j;\lambda_j)]+\text{const} \\ \end{aligned} \end{equation}\]

记 $L_{j}(z_{j})=\mathbb{E}{q{-j}}\left[\log p(x,z)\right]$,则:

\[\begin{equation} \begin{aligned} ELBO(q_j;q_{-j}) &=\mathbb{E}_{q_j}[L_j(z_j)] - \mathbb{E}_{q_j}[\log q_j(z_j;\lambda_j)]+\text{const} \\ \end{aligned} \end{equation}\]

注意 $q_j$ 满足约束条件

\[\begin{equation} \int q_j(z_j)\mathrm{d}z_j=1 \end{equation}\]

我们通过拉格朗日乘子法来求泛函 $ELBO(q_j)$ 的极值。构造如下的拉格朗日函数:

\[\begin{equation} \begin{aligned} \mathcal{J} &=ELBO(q_j) + \lambda\left(\int q_j(z_j)\mathrm{d}z_j-1\right)\\ &=\mathbb{E}_{q_j}[L_j(z_j)] - \mathbb{E}_{q_j}[\log q_j(z_j;\lambda_j)] + \lambda\left(\int q_j(z_j)\mathrm{d}z_j-1\right)\\ &=\int q_j(z_j)L_j(z_j)\mathrm{d}z_j-\int q_j(z_j)\log q_j(z_j)\mathrm{d}z_j+ \lambda\left(\int q_j(z_j)\mathrm{d}z_j-1\right)\\ &=\int q_j(z_j)\left( L_j(z_j)-\log q_j(z_j)+\lambda \right)\mathrm{d}z_j-\lambda\\ \end{aligned} \end{equation}\]

对 $q_j$ 求导得:

\[\begin{equation} \begin{aligned} \frac{\delta\mathcal{J}}{\delta q_j} &= L_j(z_j)-\log q_j(z_j)+\lambda-1 \end{aligned} \end{equation}\]

令导数为0得:

\[\begin{equation} \begin{aligned} \log q_j^*(z_j) &=L_j(z_j)+\lambda-1\\ &=L_j(z_j) + \text{const}\\ &=\mathbb{E}_{q_{-j}}[\log p(x,z)] + \text{const}\\ \end{aligned} \end{equation}\]

因此,最优变分分布满足以下性质:

\[\begin{equation} \begin{aligned} q_j^*(z_j) &\propto \exp\left(\mathbb{E}_{q_{-j}}[\log p(x,z)]\right)\\ \end{aligned} \end{equation}\]

2.3.2. 黑盒变分推断

黑盒变分推断 (Black Box Variational Inference, BBVI) 是一种通用的、模型无关的变分推断方法,核心思想是将变分推断转化为一个可以通过随机梯度优化的问题,无需为特定模型推导复杂的更新方程。

我们将ELBO看作一个关于参数 $\lambda$ 的函数,则求偏导得:

\[\begin{equation} \begin{aligned} \frac{\partial ELBO(\lambda)}{\partial \lambda} &=\frac{\partial }{\partial \lambda}\int q(z;\lambda)\log \frac{p(x, z)}{q(z;\lambda)}\mathrm{d}z\\ &=\int\frac{\partial }{\partial \lambda}\left[q(z;\lambda)\log\frac{p(x, z)}{q(z;\lambda)}\right]\mathrm{d}z\\ &=\int\left[\frac{\partial q(z;\lambda)}{\partial \lambda}\log p(x, z) - \frac{\partial q(z;\lambda)\log q(z;\lambda)}{\partial \lambda}\right]\mathrm{d}z\\ &=\int\frac{\partial q(z;\lambda)}{\partial \lambda}\log p(x, z)\mathrm{d}z - \int\frac{\partial q(z;\lambda)}{\partial \lambda}\log q(z;\lambda)\mathrm{d}z - \int\frac{\partial q(z;\lambda)}{\partial \lambda}\mathrm{d}z\\ &=\int\frac{\partial q(z;\lambda)}{\partial \lambda}(\log p(x, z)-\log q(z;\lambda))\mathrm{d}z - \int\frac{\partial q(z;\lambda)}{\partial \lambda}\mathrm{d}z\\ \end{aligned} \end{equation}\]

其中,

\[\begin{equation} \begin{aligned} \int\frac{\partial q(z;\lambda)}{\partial \lambda}\mathrm{d}z &=\frac{\partial }{\partial \lambda}\int q(z;\lambda)\mathrm{d}z\\ &=0 \end{aligned} \end{equation}\]

代入得

\[\begin{equation} \begin{aligned} \frac{\partial ELBO(\lambda)}{\partial \lambda} &=\int\frac{\partial q(z;\lambda)}{\partial \lambda}(\log p(x, z)-\log q(z;\lambda))\mathrm{d}z\\ &=\int q(z;\lambda)\frac{\partial \log q(z;\lambda)}{\partial \lambda}(\log p(x, z)-\log q(z;\lambda))\mathrm{d}z\\ &=\int q(z;\lambda)\nabla_\lambda\log q(z;\lambda)(\log p(x, z)-\log q(z;\lambda))\mathrm{d}z\\ &=\mathbb{E}_q[\nabla_\lambda\log q(z;\lambda)(\log p(x, z)-\log q(z;\lambda))] \end{aligned} \end{equation}\]

上面的公式可以写成随机梯度下降的形式来优化:

\[\begin{equation} \lambda \leftarrow \lambda - \eta\cdot\frac{1}{N}\sum_{i=1}^N\nabla_\lambda\log q(z_i;\lambda)(\log p(x, z_i)-\log q(z_i;\lambda)) \end{equation}\]