TL;DR
最近,AIGC是极火热的讨论话题,而文生图可以说是AIGC的代表性工作。目前,效果最好的文生图模型是基于扩散模型的,当进一步深入扩散模型时,又对他的损失函数产生了很大的疑问。通过查找各方资料,才发现扩散模型与变分自编码器在损失定义上同出一门,理解了变分自编码器的损失自然也能理解扩散模型的损失。
另外,变分自编码器已经作为基础模型,集成到许多后续工作中,例如:
- Stable Diffusion用变分自编码器获取图片的潜在表征(latents)进行前向扩散,避免直接在像素空间中前向扩散,极大地提升了计算效率;
- 作为变分自编码器的拓展性工作,向量化离散变分自编码器(Vector Quantised-Variational AutoEncoder, VQ-VAE)已经被广泛用作图像分词器,如BEIT、DALL·E等。
可以说,变分自编码器是过不去的一个坎,极有必要对变分自编码器做细致的了解。
但是,查阅已有资料发现,有关变分自编码器的教程总是伴随复杂的公式推导,而实现的代码又难以与公式严格对应。另外,理论部分还涉及变分推断、ELBO、重参数等等多种技巧,让人摸不着头脑。本文将从基本原理入手,逐步介绍变分自编码器的概念、损失函数、推断过程等关键内容,旨在对变分自编码器理论的来龙去脉进行详细的解释,并将推导过程与具体实现相结合,帮助更好地理解变分自编码器。
理论部分
什么是自编码器?:自编码器(AutoEncoder, AE)是一种无监督方式训练的神经网络,主要思想是将高维的输入数据进行编码、压缩,得到低维的特征表示,然后将该特征解码回原始数据,从而学习数据的特征表示。可以用于数据压缩、降维、异常检测、图像去噪等。
如图所示,自编码器包含两个部分:
- 编码器(Encoder):将原始高维数据映射到低维隐空间中,以得到低维特征表示;
- 解码器(Decoder):低维隐空间中的特征表示作为输入,将其重新映射到原始数据空间,以得到重建数据。
记原始输入数据点为x,编码器为gϕ,编码后的特征为z,解码器为fθ,解码重建后的数据为x′,那么就有
zx′=gϕ(x)=fθ(z)(1)
其中ϕ和θ分别为编码器g(⋅)和解码器f(⋅)的参数。最终的目标是学习一个恒等映射,即
x′≈fθ(gϕ(x))(2)
损失可以用x′与x间的距离度量定义,如熵、MSE等,下面用MSE定义损失
LAE(θ,ϕ)=n1i=1∑n(x(i)−fθ(gϕ(x(i))))2(3)
自编码器与内容生成:那么训练结束后,获得了编码器、解码器两个网络,除了对原始数据的压缩、降维,是否还可以用来生成数据?比如在隐空间随机取一个特征,用解码器对这个特征进行重构,从而得到新的数据。
这听起来是合理的,但事实上这样做的结果却不尽如人意,原因是:
- 自编码器的训练目标是重构输入数据,模型规模较大、数据量较小的情况下,能做到一对一的映射,但也引入了过拟合问题;
- 训练过程中没有对隐空间作任何限制,也就是说隐空间是以任意方式组织的,导致是不连续的,呈现不规则的、无界的分布。
也就是说,隐空间中随机选取特征可能不具有任何实际含义,导致解码后的结果无意义。
变分自编码器如何解决这个问题?:变分自编码器(Variational AutoEncoder)是一种改进的自编码器,目的是使自编码器能应用于内容生成。其思想是:将原始数据编码为隐空间中的概率分布,而不是特定的单个特征,使隐空间具有可采样的特性。
进一步地,为了使隐空间具有可采样的特性,可以令隐变量z服从某简单分布(如正态分布),那么可以通过下面步骤采样得到隐层表征,并重构生成数据:
- 从先验概率pθ(z)中采样,得到特征z(i);
- 用似然函数pθ(x∣z=z(i))重构数据,得到x′。
那么,接下来的问题就是如何估计变分自编码器的参数θ。在解决这个问题前,先从贝叶斯模型角度讲解“变分推断”是怎么回事。
从贝叶斯模型谈起:假设输入变量为x,隐变量是z(在分类问题中即标签y,回归问题中就是预测值),那么贝叶斯模型中有
- 先验概率p(z)
- 似然函数p(x∣z)
- 后验概率p(z∣x)
它们之间的联系可以用贝叶斯公式描述:
p(z∣x)=p(x)p(x∣z)p(z)(4.1)
其中
p(x)=∫p(x,z)dz=∫p(x∣z)p(z)dz(4.2)
其中,p(z)和p(x∣z)可以从数据集估计得到,那么目的就是为了求解后验概率分布p(z∣x)。将已知项代入上式就能得到结果,但可以看到,p(z∣x)=∫p(x∣z)p(z)dzp(x∣z)p(z)涉及积分计算,这就很难求解了,需要通过近似推断的方法求解,这就引入了变分推断。
“变分”是什么意思?:“变分”来自变分推断(Variational Inference, VI),是通过引入一个已知分布(如高斯分布)q(z∣x)来逼近复杂分布p(z∣x),设已知分布参数为ϕ、复杂分布参数为θ,将两个分布记作qϕ(z∣x)和pθ(z∣x)。那么希望两个分布越接近越好,可以用KL散度来度量。
但注意到,KL散度是非对称的:
- KL(P∣∣Q)=Ez∼P(z)logQ(z)P(z),是指用分布Q近似分布P,需要保证任意P(z)>0的地方都有Q(z)>0,结果是Q的分布会覆盖整个P的分布;
- KL(Q∣∣P)=Ez∼Q(z)logP(z)Q(z),是指用分布P近似分布Q,当P(z)→0时一定有Q(z)→0,结果是使Q逼近P的其中一个峰。
在变分推断中,一般用反向KL散度,即
ϕ∗=argϕminKL(qϕ(z∣x)∣∣pθ(z∣x))=argϕminEz∼qϕ(z∣x)logpθ(z∣x)qϕ(z∣x)(5)
其中pθ(z∣x)未知,需要经过一系列变换才能进行优化。
变分推断与ELBO:对上式进行变换,由贝叶斯公式有pθ(z∣x)=pθ(x)pθ(x∣z)pθ(z),代入可以得到
KL(qϕ(z∣x)∣∣pθ(z∣x))=Ez∼qϕ(z∣x)logpθ(x∣z)pθ(z)qϕ(z∣x)pθ(x)=Ez∼qϕ(z∣x)logpθ(x∣z)pθ(z)qϕ(z∣x)+logpθ(x)=Ez∼qϕ(z∣x)(logpθ(z)qϕ(z∣x)−logpθ(x∣z))+logpθ(x)=KL(qϕ(z∣x)∣∣pθ(z))−Ez∼qϕ(z∣x)logpθ(x∣z)+logpθ(x)Ez∼qϕ(z∣x)logpθ(x)=logpθ(x)(6)
多项式移项整理后,可以得到
logpθ(x)=KL(qϕ(z∣x)∣∣pθ(z∣x))−KL(qϕ(z∣x)∣∣pθ(z))+Ez∼qϕ(z∣x)logpθ(x∣z)(7)
由于KL散度非负,即KL(qϕ(z∣x)∣∣pθ(z∣x))≥0,因此
logpθ(x)≥−KL(qϕ(z∣x)∣∣pθ(z))+Ez∼qϕ(z∣x)logpθ(x∣z)(8)
右边多项式可以视作logpθ(x)的下界,或称证据变量x的下界,定义为证据下界(Evidence Lower Bound, ELBO),即
−LVI=−KL(qϕ(z∣x)∣∣pθ(z))+Ez∼qϕ(z∣x)logpθ(x∣z)(9)
那么优化目标就可以进行转换,即
ϕ∗=argϕminKL(qϕ(z∣x)∣∣pθ(z∣x))=argϕminLVI(10)
回到变分自编码器:VAE的训练目标定义为最大化真实数据的概率分布,也即
θ∗=argθmaxi=1∏npθ(x(i))=argθmaxi=1∑nlogpθ(x(i))(11)
上面提到,用贝叶斯公式直接展开上式,会引入积分项导致难以求解。而由式(8)又可知,(−LVI)是logpθ(x)的一个下界,那么通过最大化下界,可以间接地最大化logpθ(x),也就是
θ∗,ϕ∗=argθ,ϕmaxi=1∑n−KL(qϕ(z(i)∣x(i))∣∣pθ(z(i)))+Ez∼qϕ(z∣x(i))logpθ(x(i)∣z)(12)
通常最小化损失,因此记变分自编码器的损失为
LVAE=n1i=1∑n−Ez∼qϕ(z∣x(i))logpθ(x(i)∣z)+KL(qϕ(z(i)∣x(i))∣∣pθ(z(i)))(13)
其中,qϕ(z∣x)是编码器部分,pθ(x∣z)是解码器部分,pθ(z)是期望的令z服从的已知简单分布(如正态分布、均匀分布等)。
损失的具体形式:写到这里,已经完成了形式化的损失函数定义,许多教程在这里就结束了。但阅读一些具体实现的代码,发现损失如式(14)所示,很难将其联系到式(13)上:
LVAE=n1i=1∑n∣∣x(i)−x′(i)∣∣2+21∣∣μ(i)2+σ(i)2−logσ(i)2−1∣∣2(14)
其中x(i)是样本点,x′(i)是重构后的样本点。上面引入近似分布(也即编码器)qϕ(z∣x)是高斯分布,即qϕ(z(i)∣x(i))∼N(μ(i),σ(i)2I),μ(i)和σ(i)2表示x(i)输入对应的均值、方差。
接下来说明,如何从式(13)得到(14)。
形式化损失与具体损失的联系:回到式(13),我们可以将其拆分为重构损失、正则项损失两部分:
{LreconLregu=n1∑i=1n−Ez∼qϕ(z∣x(i))logpθ(x(i)∣z)=n1∑i=1nKL(qϕ(z(i)∣x(i))∣∣pθ(z(i)))(15)
其中:
- z∼qϕ(z∣x(i))表示采样过程,涉及到重参数技巧;
- Lrecon是重构损失,与自编码器一致,Lregu是正则项损失,目的是更好地组织隐空间,使其具有可采样的特性,并防止过拟合;
- 注意到这两项是相互对抗的,因为最小化Lregu使KL(qϕ(z(i)∣x(i))∣∣pθ(z(i)))=0时,z就没有了任何差异,这样重建准确率就很低,导致Lrecon很高,因此最终目的是达到两项的平衡状态。
再看式(15)中各项概率分布:
- pθ(z):为了方便采样,一般令z∼N(0,I),这是人为指定的;
- qϕ(z∣x):编码器部分,前面变分推断部分已经提到,用高斯分布拟合,得到N(μ,σ2I);
- pθ(x∣z):解码器部分,还没定,也可以选择一个简单分布拟合,如伯努利分布或者高斯分布。
当pθ(x∣z)采用伯努利分布,即多元二项分布,有
pθ(x∣z)=k=1∏dpθ(zk)xk(1−pθ(zk))1−xk(16.1)
其中d表示随机变量x的维度,此时xk∈{0,1},k=1,⋯,d,那么
Lrecon=n1i=1∑n−Ez∼qϕ(z∣x(i))logpθ(x(i)∣z)=n1i=1∑nlog(−k=1∏dpθ(zk(i))xk(i)(1−pθ(zk(i)))1−xk(i))=n1i=1∑nk=1∑d(−xk(i)logpθ(zk(i))−(1−xk(i))log(1−pθ(zk(i))))(16.2)
此时用二元交叉熵作为损失函数。
当pθ(x∣z)采用高斯分布,回顾多维高斯分布:若随机变量x∼N(μ,Σ),有
p(x)=(2π)d/2∣Σ∣1/21exp[−21(x−μ)TΣ−1(x−μ)](17.1)
很容易得到pθ(x(i)∣z)的表达式,进一步地,简化假设各分量独立(即Σ为对角阵σ2I),μ为关于z的函数,那么
Lrecon=n1i=1∑n−Ez∼qϕ(z∣x(i))logpθ(x(i)∣z)=n1i=1∑nlog⎝⎛−∏k=1d(2π)dσk2(z(i))1exp(−21∣∣σ(z(i))x(i)−μ(z(i))∣∣2)⎠⎞=n1i=1∑n(21∣∣σ(z(i))x(i)−μ(z(i))∣∣2+21k=1∑dlog(2π)dσk2(z(i)))=n1i=1∑n(21∣∣σ(z(i))x(i)−μ(z(i))∣∣2+2dk=1∑dlog2π+21k=1∑dσk2(z(i)))(17.2)
为简化计算,令方差项σ(z)为常数c,损失可以简化为MSE损失:
Lrecon=n1i=1∑n2c1∣∣x(i)−μθ(z(i))∣∣2+C(17.3)
注意到,μθ(z(i))即重构的数据x′(i)。
再看正则项损失,有
⎩⎨⎧qϕ(z(i)∣x(i))pθ(z(i))=∏k=1h(2π)hσk2(x(i))1exp(−21∣∣σ(x(i))z(i)−μ(x(i))∣∣2)=∏k=1h(2π)h1exp(−21∣∣z(i)∣∣2)(18.1)
Lregu=n1i=1∑nKL(qϕ(z(i)∣x(i))∣∣pθ(z(i)))=n1i=1∑n∫qϕ(z(i)∣x(i))logpθ(z(i))qϕ(z(i)∣x(i))dz(i)=⋯=n1i=1∑n21∣∣μ2(x(i))+σ2(x(i))−logσ2(x(i))−1∣∣220.1式代入计算,略(18.2)
也即
Lregu=n1i=1∑n21∣∣μ(i)2+σ(i)2−logσ(i)2−1∣∣2(18.3)
实现细节
编码器与解码器网络:变分推断中提到用高斯分布来逼近pθ(z∣x),也就是说希望编码器qϕ(z∣x)输出高斯概率分布。直接令神经网络gϕ(x)拟合分布参数μ和σ2(考虑到σ2非负,一般用logσ2),那么有
μ,logσ2=gϕ(x)(19.1)
解码器部分就比较简单了,只要将采样得到的z重建,同样用神经网络fθ(z)表示,也就是
x′=fθ(z)(19.2)
隐层特征z的采样:目前,已经令编码器得到分布N(μ(i),σ(i)2I)了,那么如何得到隐层特征z(i)呢?能够直接从分布中采样得到呢?答案是不可以,因为采样操作是不可导的,导致最终误差无法通过网络反传到编码器实现参数更新。
解决方法是采用重参数技巧(Reparameterization Trick),希望从正态分布N(μ,σ2I)中采样,可以先从标准正态分布N(0,I)中采样ϵ,然后用以下变换得到z(由正态分布性质可证):
z=μϵ+σ(20)
这样做,就可以把不可导的采样操作移除到梯度计算图之外,实现误差反传。
具体实现:下面是在MNIST数据集上进实现的的变分自编码器
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111
| import torch import torch.nn as nn import torch.optim as optim from torchvision import datasets, transforms from torch.utils.data import DataLoader
class VAE(nn.Module): def __init__(self, input_size, hidden_size, latent_size): super(VAE, self).__init__() self.input_size = input_size self.hidden_size = hidden_size self.latent_size = latent_size self.encoder = nn.Sequential( nn.Linear(self.input_size, self.hidden_size), nn.ReLU(), nn.Linear(self.hidden_size, self.hidden_size), nn.ReLU() ) self.mean = nn.Linear(self.hidden_size, self.latent_size) self.logvar = nn.Linear(self.hidden_size, self.latent_size) self.decoder = nn.Sequential( nn.Linear(self.latent_size, self.hidden_size), nn.ReLU(), nn.Linear(self.hidden_size, self.hidden_size), nn.ReLU(), nn.Linear(self.hidden_size, self.input_size), nn.Sigmoid() ) def encode(self, x): h = self.encoder(x) mean = self.mean(h) logvar = self.logvar(h) return mean, logvar def reparameterize(self, mean, logvar): std = torch.exp(0.5 * logvar) eps = torch.randn_like(std) z = mean + eps * std return z def decode(self, z): x_hat = self.decoder(z) return x_hat def forward(self, x): mean, logvar = self.encode(x) z = self.reparameterize(mean, logvar) x_hat = self.decode(z) return x_hat, mean, logvar
def train(model, dataloader, optimizer, criterion, device): model.train() train_loss = 0 for batch_idx, (data, _) in enumerate(dataloader): data = data.view(data.size(0), -1) data = data.to(device) optimizer.zero_grad() recon_batch, mu, logvar = model(data) loss = criterion(recon_batch, data, mu, logvar) loss.backward() train_loss += loss.item() optimizer.step() return train_loss / len(dataloader.dataset)
@torch.no_grad() def test(model, dataloader, criterion, device): model.eval() test_loss = 0 for data, _ in dataloader: data = data.view(data.size(0), -1) data = data.to(device) recon_batch, mu, logvar = model(data) test_loss += criterion(recon_batch, data, mu, logvar).item() return test_loss / len(dataloader.dataset)
def loss_fn(recon_x, x, mu, logvar): BCE = nn.functional.binary_cross_entropy(recon_x, x, reduction='sum') KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp()) return BCE + KLD
if __name__ == "__main__": batch_size = 128 train_dataset = datasets.MNIST(root='./data', train=True, transform=transforms.ToTensor(), download=True) train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True) test_dataset = datasets.MNIST(root='./data', train=False, transform=transforms.ToTensor(), download=True) test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=True)
input_size = 784 hidden_size = 256 latent_size = 20 model = VAE(input_size, hidden_size, latent_size).to('cuda') optimizer = optim.Adam(model.parameters(), lr=1e-3)
epochs = 10 for epoch in range(1, epochs+1): train_loss = train(model, train_loader, optimizer, loss_fn, 'cuda') test_loss = test(model, test_loader, loss_fn, 'cuda') print('Epoch {}: Train Loss {:.4f}, Test Loss {:.4f}'.format(epoch, train_loss, test_loss))
torch.save(model.state_dict(), 'vae.pth')
|
可以用下面代码进行推断
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23
| import torch from torchvision.utils import save_image from vae import VAE
input_size = 784 hidden_size = 256 latent_size = 20
vae = VAE(input_size, hidden_size, latent_size).to('cuda') vae.load_state_dict(torch.load('vae.pth')) vae.eval()
z = torch.randn(64, latent_size)
with torch.no_grad(): z = z.to("cuda") x_hat = vae.decode(z)
save_image(x_hat.view(64, 1, 28, 28), 'generated_samples.png')
|
可以多训练几轮,达到更好的效果
参考资料