TL;DR

最近,AIGC是极火热的讨论话题,而文生图可以说是AIGC的代表性工作。目前,效果最好的文生图模型是基于扩散模型的,当进一步深入扩散模型时,又对他的损失函数产生了很大的疑问。通过查找各方资料,才发现扩散模型与变分自编码器在损失定义上同出一门,理解了变分自编码器的损失自然也能理解扩散模型的损失。

另外,变分自编码器已经作为基础模型,集成到许多后续工作中,例如:

  1. Stable Diffusion用变分自编码器获取图片的潜在表征(latents)进行前向扩散,避免直接在像素空间中前向扩散,极大地提升了计算效率;
  2. 作为变分自编码器的拓展性工作,向量化离散变分自编码器(Vector Quantised-Variational AutoEncoder, VQ-VAE)已经被广泛用作图像分词器,如BEITDALL·E等。

可以说,变分自编码器是过不去的一个坎,极有必要对变分自编码器做细致的了解。

但是,查阅已有资料发现,有关变分自编码器的教程总是伴随复杂的公式推导,而实现的代码又难以与公式严格对应。另外,理论部分还涉及变分推断、ELBO、重参数等等多种技巧,让人摸不着头脑。本文将从基本原理入手,逐步介绍变分自编码器的概念、损失函数、推断过程等关键内容,旨在对变分自编码器理论的来龙去脉进行详细的解释,并将推导过程与具体实现相结合,帮助更好地理解变分自编码器。

理论部分

什么是自编码器?:自编码器(AutoEncoder, AE)是一种无监督方式训练的神经网络,主要思想是将高维的输入数据进行编码、压缩,得到低维的特征表示,然后将该特征解码回原始数据,从而学习数据的特征表示。可以用于数据压缩、降维、异常检测、图像去噪等。

如图所示,自编码器包含两个部分:

  1. 编码器(Encoder):将原始高维数据映射到低维隐空间中,以得到低维特征表示;
  2. 解码器(Decoder):低维隐空间中的特征表示作为输入,将其重新映射到原始数据空间,以得到重建数据。

记原始输入数据点为xx,编码器为gϕg_{\phi},编码后的特征为zz,解码器为fθf_{\theta},解码重建后的数据为xx',那么就有

z=gϕ(x)x=fθ(z)(1)\begin{aligned} z &= g_{\phi}(x) \\ x' &= f_{\theta}(z) \end{aligned} \tag{1}

其中ϕ\phiθ\theta分别为编码器g()g(\cdot)和解码器f()f(\cdot)的参数。最终的目标是学习一个恒等映射,即

xfθ(gϕ(x))(2)x' \approx f_{\theta}(g_{\phi}(x)) \tag{2}

损失可以用xx'xx间的距离度量定义,如熵、MSE等,下面用MSE定义损失

LAE(θ,ϕ)=1ni=1n(x(i)fθ(gϕ(x(i))))2(3)L_{AE} (\theta, \phi) = \frac{1}{n} \sum_{i=1}^n (x^{(i)} - f_{\theta}(g_{\phi}(x^{(i)})))^2 \tag{3}

自编码器与内容生成:那么训练结束后,获得了编码器、解码器两个网络,除了对原始数据的压缩、降维,是否还可以用来生成数据?比如在隐空间随机取一个特征,用解码器对这个特征进行重构,从而得到新的数据。

这听起来是合理的,但事实上这样做的结果却不尽如人意,原因是:

  1. 自编码器的训练目标是重构输入数据,模型规模较大、数据量较小的情况下,能做到一对一的映射,但也引入了过拟合问题;
  2. 训练过程中没有对隐空间作任何限制,也就是说隐空间是以任意方式组织的,导致是不连续的,呈现不规则的、无界的分布。

也就是说,隐空间中随机选取特征可能不具有任何实际含义,导致解码后的结果无意义。

变分自编码器如何解决这个问题?:变分自编码器(Variational AutoEncoder)是一种改进的自编码器,目的是使自编码器能应用于内容生成。其思想是:将原始数据编码为隐空间中的概率分布,而不是特定的单个特征,使隐空间具有可采样的特性。

进一步地,为了使隐空间具有可采样的特性,可以令隐变量zz服从某简单分布(如正态分布),那么可以通过下面步骤采样得到隐层表征,并重构生成数据:

  1. 从先验概率pθ(z)p_{\theta}(z)中采样,得到特征z(i)z^{(i)}
  2. 用似然函数pθ(xz=z(i))p_{\theta}(x|z=z^{(i)})重构数据,得到xx'

那么,接下来的问题就是如何估计变分自编码器的参数θ\theta。在解决这个问题前,先从贝叶斯模型角度讲解“变分推断”是怎么回事。

从贝叶斯模型谈起:假设输入变量为xx,隐变量是zz(在分类问题中即标签yy,回归问题中就是预测值),那么贝叶斯模型中有

  • 先验概率p(z)p(z)
  • 似然函数p(xz)p(x|z)
  • 后验概率p(zx)p(z|x)

它们之间的联系可以用贝叶斯公式描述:

p(zx)=p(xz)p(z)p(x)(4.1)p(z|x) = \frac{p(x|z) p(z)}{p(x)} \tag{4.1}

其中

p(x)=p(x,z)dz=p(xz)p(z)dz(4.2)p(x) = \int p(x, z) dz = \int p(x|z) p(z) dz \tag{4.2}

其中,p(z)p(z)p(xz)p(x|z)可以从数据集估计得到,那么目的就是为了求解后验概率分布p(zx)p(z|x)。将已知项代入上式就能得到结果,但可以看到,p(zx)=p(xz)p(z)p(xz)p(z)dzp(z|x) = \frac{p(x|z) p(z)}{\int p(x|z) p(z) dz}涉及积分计算,这就很难求解了,需要通过近似推断的方法求解,这就引入了变分推断。

“变分”是什么意思?:“变分”来自变分推断(Variational Inference, VI),是通过引入一个已知分布(如高斯分布)q(zx)q(z|x)来逼近复杂分布p(zx)p(z|x),设已知分布参数为ϕ\phi、复杂分布参数为θ\theta,将两个分布记作qϕ(zx)q_{\phi}(z|x)pθ(zx)p_{\theta}(z|x)。那么希望两个分布越接近越好,可以用KL散度来度量。

但注意到,KL散度是非对称的:

  • KL(PQ)=EzP(z)logP(z)Q(z)\text{KL}(P||Q) = \mathbb{E}_{z \sim P(z)} \log \frac{P(z)}{Q(z)},是指用分布QQ近似分布PP,需要保证任意P(z)>0P(z) > 0的地方都有Q(z)>0Q(z) > 0,结果是QQ的分布会覆盖整个PP的分布;
  • KL(QP)=EzQ(z)logQ(z)P(z)\text{KL}(Q||P) = \mathbb{E}_{z \sim Q(z)} \log \frac{Q(z)}{P(z)},是指用分布PP近似分布QQ,当P(z)0P(z) \rightarrow 0时一定有Q(z)0Q(z) \rightarrow 0,结果是使QQ逼近PP的其中一个峰。

在变分推断中,一般用反向KL散度,即

ϕ=argminϕKL(qϕ(zx)pθ(zx))=argminϕEzqϕ(zx)logqϕ(zx)pθ(zx)(5)\begin{aligned} \phi^* &= \arg \min_{\phi} \text{KL}(q_{\phi}(z|x) || p_{\theta}(z|x)) \\ &= \arg \min_{\phi} \mathbb{E}_{z \sim q_{\phi}(z|x)} \log \frac{q_{\phi}(z|x)}{p_{\theta}(z|x)} \end{aligned} \tag{5}

其中pθ(zx)p_{\theta}(z|x)未知,需要经过一系列变换才能进行优化。

变分推断与ELBO:对上式进行变换,由贝叶斯公式有pθ(zx)=pθ(xz)pθ(z)pθ(x)p_{\theta}(z|x) = \frac{p_{\theta}(x|z) p_{\theta}(z)}{p_{\theta}(x)},代入可以得到

KL(qϕ(zx)pθ(zx))=Ezqϕ(zx)logqϕ(zx)pθ(x)pθ(xz)pθ(z)=Ezqϕ(zx)logqϕ(zx)pθ(xz)pθ(z)+logpθ(x)Ezqϕ(zx)logpθ(x)=logpθ(x)=Ezqϕ(zx)(logqϕ(zx)pθ(z)logpθ(xz))+logpθ(x)=KL(qϕ(zx)pθ(z))Ezqϕ(zx)logpθ(xz)+logpθ(x)(6)\begin{aligned} \text{KL}(q_{\phi}(z|x) || p_{\theta}(z|x)) &= \mathbb{E}_{z \sim q_{\phi}(z|x)} \log \frac{q_{\phi}(z|x) p_{\theta}(x)}{p_{\theta}(x|z) p_{\theta}(z)} \\ &= \mathbb{E}_{z \sim q_{\phi}(z|x)} \log \frac{q_{\phi}(z|x)}{p_{\theta}(x|z) p_{\theta}(z)} + \log p_{\theta}(x) & \scriptstyle{\mathbb{E}_{z \sim q_{\phi}(z|x)} \log p_{\theta}(x) = \log p_{\theta}(x)}\\ &= \mathbb{E}_{z \sim q_{\phi}(z|x)} \left( \log \frac{q_{\phi}(z|x)}{p_{\theta}(z)} - \log p_{\theta}(x|z) \right) + \log p_{\theta}(x) \\ &= \text{KL}(q_{\phi}(z|x)||p_{\theta}(z)) - \mathbb{E}_{z \sim q_{\phi}(z|x)}\log p_{\theta}(x|z) + \log p_{\theta}(x) \\ \end{aligned} \tag{6}

多项式移项整理后,可以得到

logpθ(x)=KL(qϕ(zx)pθ(zx))KL(qϕ(zx)pθ(z))+Ezqϕ(zx)logpθ(xz)(7)\log p_{\theta}(x) = \text{KL}(q_{\phi}(z|x) || p_{\theta}(z|x)) - \text{KL}(q_{\phi}(z|x)||p_{\theta}(z)) + \mathbb{E}_{z \sim q_{\phi}(z|x)}\log p_{\theta}(x|z) \tag{7}

由于KL散度非负,即KL(qϕ(zx)pθ(zx))0\text{KL}(q_{\phi}(z|x) || p_{\theta}(z|x)) \geq 0,因此

logpθ(x)KL(qϕ(zx)pθ(z))+Ezqϕ(zx)logpθ(xz)(8)\log p_{\theta}(x) \geq - \text{KL}(q_{\phi}(z|x)||p_{\theta}(z)) + \mathbb{E}_{z \sim q_{\phi}(z|x)}\log p_{\theta}(x|z) \tag{8}

右边多项式可以视作logpθ(x)\log p_{\theta}(x)的下界,或称证据变量xx的下界,定义为证据下界(Evidence Lower Bound, ELBO),即

LVI=KL(qϕ(zx)pθ(z))+Ezqϕ(zx)logpθ(xz)(9)-L_{\text{VI}} = - \text{KL}(q_{\phi}(z|x)||p_{\theta}(z)) + \mathbb{E}_{z \sim q_{\phi}(z|x)}\log p_{\theta}(x|z) \tag{9}

那么优化目标就可以进行转换,即

ϕ=argminϕKL(qϕ(zx)pθ(zx))=argminϕLVI(10)\phi^* = \arg \min_{\phi} \text{KL}(q_{\phi}(z|x) || p_{\theta}(z|x)) = \arg \min_{\phi} L_{\text{VI}} \tag{10}

回到变分自编码器:VAE的训练目标定义为最大化真实数据的概率分布,也即

θ=argmaxθi=1npθ(x(i))=argmaxθi=1nlogpθ(x(i))(11)\begin{aligned} \theta^* &= \arg \max_{\theta} \prod_{i=1}^n p_{\theta} (x^{(i)}) \\ &= \arg \max_{\theta} \sum_{i=1}^n \log p_{\theta} (x^{(i)}) \\ \end{aligned} \tag{11}

上面提到,用贝叶斯公式直接展开上式,会引入积分项导致难以求解。而由式(8)(8)又可知,(LVI)(-L_{VI})logpθ(x)\log p_{\theta} (x)的一个下界,那么通过最大化下界,可以间接地最大化logpθ(x)\log p_{\theta} (x),也就是

θ,ϕ=argmaxθ,ϕi=1nKL(qϕ(z(i)x(i))pθ(z(i)))+Ezqϕ(zx(i))logpθ(x(i)z)(12)\theta^*, \phi^* = \arg \max_{\theta, \phi} \sum_{i=1}^n - \text{KL}(q_{\phi}(z^{(i)}|x^{(i)})||p_{\theta}(z^{(i)})) + \mathbb{E}_{z \sim q_{\phi}(z|x^{(i)})}\log p_{\theta}(x^{(i)}|z) \tag{12}

通常最小化损失,因此记变分自编码器的损失为

LVAE=1ni=1nEzqϕ(zx(i))logpθ(x(i)z)+KL(qϕ(z(i)x(i))pθ(z(i)))(13)L_{\text{VAE}} = \frac{1}{n} \sum_{i=1}^n - \mathbb{E}_{z \sim q_{\phi}(z|x^{(i)})}\log p_{\theta}(x^{(i)}|z) + \text{KL}(q_{\phi}(z^{(i)}|x^{(i)})||p_{\theta}(z^{(i)})) \tag{13}

其中,qϕ(zx)q_{\phi}(z|x)是编码器部分,pθ(xz)p_{\theta}(x|z)是解码器部分,pθ(z)p_{\theta}(z)是期望的令zz服从的已知简单分布(如正态分布、均匀分布等)。

损失的具体形式:写到这里,已经完成了形式化的损失函数定义,许多教程在这里就结束了。但阅读一些具体实现的代码,发现损失如式(14)(14)所示,很难将其联系到式(13)(13)上:

LVAE=1ni=1nx(i)x(i)2+12μ(i)2+σ(i)2logσ(i)212(14)L_{\text{VAE}} = \frac{1}{n} \sum_{i=1}^n ||x^{(i)} - x'^{(i)}||^2 + \frac{1}{2} ||\mu^{(i)2} + \sigma^{(i)2} - \log \sigma^{(i)2} - 1||^2 \tag{14}

其中x(i)x^{(i)}是样本点,x(i)x'^{(i)}是重构后的样本点。上面引入近似分布(也即编码器)qϕ(zx)q_{\phi}(z|x)是高斯分布,即qϕ(z(i)x(i))N(μ(i),σ(i)2I)q_{\phi}(z^{(i)}|x^{(i)}) \sim \mathcal{N}(\mu^{(i)}, \sigma^{(i)2}I)μ(i)\mu^{(i)}σ(i)2\sigma^{(i)2}表示x(i)x^{(i)}输入对应的均值、方差。

接下来说明,如何从式(13)(13)得到(14)(14)

形式化损失与具体损失的联系:回到式(13)(13),我们可以将其拆分为重构损失、正则项损失两部分:

{Lrecon=1ni=1nEzqϕ(zx(i))logpθ(x(i)z)Lregu=1ni=1nKL(qϕ(z(i)x(i))pθ(z(i)))(15)\begin{cases} L_{\text{recon}} &= \frac{1}{n} \sum_{i=1}^n - \mathbb{E}_{z \sim q_{\phi}(z|x^{(i)})}\log p_{\theta}(x^{(i)}|z) \\ L_{\text{regu}} &= \frac{1}{n} \sum_{i=1}^n \text{KL}(q_{\phi}(z^{(i)}|x^{(i)})||p_{\theta}(z^{(i)})) \end{cases} \tag{15}

其中:

  • zqϕ(zx(i))z \sim q_{\phi}(z|x^{(i)})表示采样过程,涉及到重参数技巧;
  • LreconL_{\text{recon}}是重构损失,与自编码器一致,LreguL_{\text{regu}}是正则项损失,目的是更好地组织隐空间,使其具有可采样的特性,并防止过拟合;
  • 注意到这两项是相互对抗的,因为最小化LreguL_{\text{regu}}使KL(qϕ(z(i)x(i))pθ(z(i)))=0\text{KL}(q_{\phi}(z^{(i)}|x^{(i)})||p_{\theta}(z^{(i)})) = 0时,zz就没有了任何差异,这样重建准确率就很低,导致LreconL_{\text{recon}}很高,因此最终目的是达到两项的平衡状态。

再看式(15)(15)中各项概率分布:

  • pθ(z)p_{\theta}(z):为了方便采样,一般令zN(0,I)z \sim \mathcal{N}(0, I),这是人为指定的;
  • qϕ(zx)q_{\phi}(z|x):编码器部分,前面变分推断部分已经提到,用高斯分布拟合,得到N(μ,σ2I)\mathcal{N}(\mu, \sigma^2 I)
  • pθ(xz)p_{\theta}(x|z):解码器部分,还没定,也可以选择一个简单分布拟合,如伯努利分布或者高斯分布。

pθ(xz)p_{\theta}(x|z)采用伯努利分布,即多元二项分布,有

pθ(xz)=k=1dpθ(zk)xk(1pθ(zk))1xk(16.1)p_{\theta}(x|z) = \prod_{k=1}^{d} p_{\theta}(z_k)^{x_{k}} (1 - p_{\theta}(z_k))^{1 - x_{k}} \tag{16.1}

其中dd表示随机变量xx的维度,此时xk{0,1},k=1,,dx_k \in \{ 0, 1 \}, k = 1, \cdots, d,那么

Lrecon=1ni=1nEzqϕ(zx(i))logpθ(x(i)z)=1ni=1nlog(k=1dpθ(zk(i))xk(i)(1pθ(zk(i)))1xk(i))=1ni=1nk=1d(xk(i)logpθ(zk(i))(1xk(i))log(1pθ(zk(i))))(16.2)\begin{aligned} L_{\text{recon}} &= \frac{1}{n} \sum_{i=1}^n - \mathbb{E}_{z \sim q_{\phi}(z|x^{(i)})}\log p_{\theta}(x^{(i)}|z) \\ &= \frac{1}{n} \sum_{i=1}^n \log \left( - \prod_{k=1}^{d} p_{\theta}(z^{(i)}_k)^{x^{(i)}_k} (1 - p_{\theta}(z^{(i)}_k))^{1 - x^{(i)}_k} \right) \\ &= \frac{1}{n} \sum_{i=1}^n \sum_{k=1}^{d} \left( - x^{(i)}_k \log p_{\theta}(z^{(i)}_k) - (1 - x^{(i)}_k) \log (1 - p_{\theta}(z^{(i)}_k)) \right) \end{aligned} \tag{16.2}

此时用二元交叉熵作为损失函数。

pθ(xz)p_{\theta}(x|z)采用高斯分布,回顾多维高斯分布:若随机变量xN(μ,Σ)x \sim \mathcal{N}(\mu, \Sigma),有

p(x)=1(2π)d/2Σ1/2exp[12(xμ)TΣ1(xμ)](17.1)p(x) = \frac{1}{(2\pi)^{d/2} |\Sigma|^{1/2}} \exp \left[ - \frac{1}{2} (x - \mu)^T \Sigma^{-1} (x - \mu) \right] \tag{17.1}

很容易得到pθ(x(i)z)p_{\theta}(x^{(i)}|z)的表达式,进一步地,简化假设各分量独立(即Σ\Sigma为对角阵σ2I\sigma^2 I),μ\mu为关于zz的函数,那么

Lrecon=1ni=1nEzqϕ(zx(i))logpθ(x(i)z)=1ni=1nlog(1k=1d(2π)dσk2(z(i))exp(12x(i)μ(z(i))σ(z(i))2))=1ni=1n(12x(i)μ(z(i))σ(z(i))2+12k=1dlog(2π)dσk2(z(i)))=1ni=1n(12x(i)μ(z(i))σ(z(i))2+d2k=1dlog2π+12k=1dσk2(z(i)))(17.2)\begin{aligned} L_{\text{recon}} &= \frac{1}{n} \sum_{i=1}^n - \mathbb{E}_{z \sim q_{\phi}(z|x^{(i)})}\log p_{\theta}(x^{(i)}|z) \\ &= \frac{1}{n} \sum_{i=1}^n \log \left( - \frac{1}{\prod_{k=1}^d \sqrt{(2 \pi)^d \sigma_k^2(z^{(i)})}} \exp \left( - \frac{1}{2} ||\frac{x^{(i)} - \mu(z^{(i)})}{\sigma(z^{(i)})}||^2 \right) \right) \\ &= \frac{1}{n} \sum_{i=1}^n \left( \frac{1}{2} ||\frac{x^{(i)} - \mu(z^{(i)})}{\sigma(z^{(i)})}||^2 + \frac{1}{2} \sum_{k=1}^d \log (2 \pi)^d \sigma_k^2(z^{(i)}) \right) \\ &= \frac{1}{n} \sum_{i=1}^n \left( \frac{1}{2} ||\frac{x^{(i)} - \mu(z^{(i)})}{\sigma(z^{(i)})}||^2 + \frac{d}{2} \sum_{k=1}^d \log 2 \pi + \frac{1}{2} \sum_{k=1}^d \sigma_k^2(z^{(i)}) \right) \end{aligned} \tag{17.2}

为简化计算,令方差项σ(z)\sigma(z)为常数cc,损失可以简化为MSE损失:

Lrecon=1ni=1n12cx(i)μθ(z(i))2+C(17.3)L_{\text{recon}} = \frac{1}{n} \sum_{i=1}^n \frac{1}{2c} ||x^{(i)} - \mu_{\theta}(z^{(i)})||^2 \cancel{+ C} \tag{17.3}

注意到,μθ(z(i))\mu_{\theta}(z^{(i)})即重构的数据x(i)x'^{(i)}

再看正则项损失,有

{qϕ(z(i)x(i))=1k=1h(2π)hσk2(x(i))exp(12z(i)μ(x(i))σ(x(i))2)pθ(z(i))=1k=1h(2π)hexp(12z(i)2)(18.1)\begin{cases} q_{\phi}(z^{(i)}|x^{(i)}) &= \frac{1}{ \prod_{k=1}^h \sqrt{(2 \pi)^h \sigma_k^2(x^{(i)})} } \exp \left( - \frac{1}{2} ||\frac{z^{(i)} - \mu(x^{(i)})}{\sigma(x^{(i)})}||^2 \right) \\ p_{\theta}(z^{(i)}) &= \frac{1}{ \prod_{k=1}^h \sqrt{(2 \pi)^h} } \exp \left( - \frac{1}{2} ||z^{(i)}||^2 \right) \\ \end{cases} \tag{18.1}

Lregu=1ni=1nKL(qϕ(z(i)x(i))pθ(z(i)))=1ni=1nqϕ(z(i)x(i))logqϕ(z(i)x(i))pθ(z(i))dz(i)=20.1式代入计算,略=1ni=1n12μ2(x(i))+σ2(x(i))logσ2(x(i))12(18.2)\begin{aligned} L_{\text{regu}} &= \frac{1}{n} \sum_{i=1}^n \text{KL}(q_{\phi}(z^{(i)}|x^{(i)})||p_{\theta}(z^{(i)})) \\ &= \frac{1}{n} \sum_{i=1}^n \int q_{\phi}(z^{(i)}|x^{(i)}) \log \frac{ q_{\phi}(z^{(i)}|x^{(i)}) }{ p_{\theta}(z^{(i)}) } d z^{(i)} \\ &= \cdots & \scriptstyle{20.1式代入计算,略} \\ &= \frac{1}{n} \sum_{i=1}^n \frac{1}{2} ||\mu^2(x^{(i)}) + \sigma^2(x^{(i)}) - \log \sigma^2(x^{(i)}) - 1||^2 \end{aligned} \tag{18.2}

也即

Lregu=1ni=1n12μ(i)2+σ(i)2logσ(i)212(18.3)L_{\text{regu}} = \frac{1}{n} \sum_{i=1}^n \frac{1}{2} ||\mu^{(i)2} + \sigma^{(i)2} - \log \sigma^{(i)2} - 1||^2 \tag{18.3}

实现细节

编码器与解码器网络:变分推断中提到用高斯分布来逼近pθ(zx)p_{\theta}(z|x),也就是说希望编码器qϕ(zx)q_{\phi}(z|x)输出高斯概率分布。直接令神经网络gϕ(x)g_{\phi}(x)拟合分布参数μ\muσ2\sigma^2(考虑到σ2\sigma^2非负,一般用logσ2\log \sigma^2),那么有

μ,logσ2=gϕ(x)(19.1)\mu, \log \sigma^2 = g_{\phi}(x) \tag{19.1}

解码器部分就比较简单了,只要将采样得到的zz重建,同样用神经网络fθ(z)f_{\theta}(z)表示,也就是

x=fθ(z)(19.2)x' = f_{\theta}(z) \tag{19.2}

隐层特征zz的采样:目前,已经令编码器得到分布N(μ(i),σ(i)2I)\mathcal{N}(\mu^{(i)}, \sigma^{(i)2} I)了,那么如何得到隐层特征z(i)z^{(i)}呢?能够直接从分布中采样得到呢?答案是不可以,因为采样操作是不可导的,导致最终误差无法通过网络反传到编码器实现参数更新。

解决方法是采用重参数技巧(Reparameterization Trick),希望从正态分布N(μ,σ2I)\mathcal{N}(\mu, \sigma^2 I)中采样,可以先从标准正态分布N(0,I)\mathcal{N}(0, I)中采样ϵ\epsilon,然后用以下变换得到zz(由正态分布性质可证):

z=μϵ+σ(20)z = \mu \epsilon + \sigma \tag{20}

这样做,就可以把不可导的采样操作移除到梯度计算图之外,实现误差反传。

具体实现:下面是在MNIST数据集上进实现的的变分自编码器

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader

# 定义变分自编码器模型
class VAE(nn.Module):
def __init__(self, input_size, hidden_size, latent_size):
super(VAE, self).__init__()
self.input_size = input_size
self.hidden_size = hidden_size
self.latent_size = latent_size

self.encoder = nn.Sequential(
nn.Linear(self.input_size, self.hidden_size),
nn.ReLU(),
nn.Linear(self.hidden_size, self.hidden_size),
nn.ReLU()
)

self.mean = nn.Linear(self.hidden_size, self.latent_size)
self.logvar = nn.Linear(self.hidden_size, self.latent_size)

self.decoder = nn.Sequential(
nn.Linear(self.latent_size, self.hidden_size),
nn.ReLU(),
nn.Linear(self.hidden_size, self.hidden_size),
nn.ReLU(),
nn.Linear(self.hidden_size, self.input_size),
nn.Sigmoid()
)

def encode(self, x):
h = self.encoder(x)
mean = self.mean(h)
logvar = self.logvar(h)
return mean, logvar

def reparameterize(self, mean, logvar):
std = torch.exp(0.5 * logvar)
eps = torch.randn_like(std)
z = mean + eps * std
return z

def decode(self, z):
x_hat = self.decoder(z)
return x_hat

def forward(self, x):
mean, logvar = self.encode(x)
z = self.reparameterize(mean, logvar)
x_hat = self.decode(z)
return x_hat, mean, logvar

# 定义训练函数
def train(model, dataloader, optimizer, criterion, device):
model.train()
train_loss = 0
for batch_idx, (data, _) in enumerate(dataloader):
data = data.view(data.size(0), -1)
data = data.to(device)
optimizer.zero_grad()
recon_batch, mu, logvar = model(data)
loss = criterion(recon_batch, data, mu, logvar)
loss.backward()
train_loss += loss.item()
optimizer.step()
return train_loss / len(dataloader.dataset)

# 定义测试函数
@torch.no_grad()
def test(model, dataloader, criterion, device):
model.eval()
test_loss = 0
for data, _ in dataloader:
data = data.view(data.size(0), -1)
data = data.to(device)
recon_batch, mu, logvar = model(data)
test_loss += criterion(recon_batch, data, mu, logvar).item()
return test_loss / len(dataloader.dataset)

# 定义损失函数
def loss_fn(recon_x, x, mu, logvar):
BCE = nn.functional.binary_cross_entropy(recon_x, x, reduction='sum')
KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
return BCE + KLD

if __name__ == "__main__":
# 加载数据集
batch_size = 128
train_dataset = datasets.MNIST(root='./data', train=True, transform=transforms.ToTensor(), download=True)
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_dataset = datasets.MNIST(root='./data', train=False, transform=transforms.ToTensor(), download=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=True)

# 初始化模型和优化器
input_size = 784
hidden_size = 256
latent_size = 20
model = VAE(input_size, hidden_size, latent_size).to('cuda')
optimizer = optim.Adam(model.parameters(), lr=1e-3)

# 训练模型
epochs = 10
for epoch in range(1, epochs+1):
train_loss = train(model, train_loader, optimizer, loss_fn, 'cuda')
test_loss = test(model, test_loader, loss_fn, 'cuda')
print('Epoch {}: Train Loss {:.4f}, Test Loss {:.4f}'.format(epoch, train_loss, test_loss))

torch.save(model.state_dict(), 'vae.pth')

可以用下面代码进行推断

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
import torch
from torchvision.utils import save_image
from vae import VAE

# 加载VAE模型
input_size = 784
hidden_size = 256
latent_size = 20

vae = VAE(input_size, hidden_size, latent_size).to('cuda')
vae.load_state_dict(torch.load('vae.pth'))
vae.eval()

# 从标准正态分布中采样潜在向量
z = torch.randn(64, latent_size)

# 生成新的样本
with torch.no_grad():
z = z.to("cuda")
x_hat = vae.decode(z)

# 将生成的样本保存到文件中
save_image(x_hat.view(64, 1, 28, 28), 'generated_samples.png')

可以多训练几轮,达到更好的效果

参考资料