GAN的基础理解及代码解析

本文最后更新于:1 年前

引言

在先前的图像方法增强合集(这个还没写)里,了解到GAN可以用于生成图像,且是无监督学习,即意味着我们无需人为对样本打标签,就可以学习到样本数据里的相关图像信息

论文原文:Generative Adversarial Nets

注:PO在上面的是作者后面经修改的版本,在arXiv上的是14年的初稿,related work部分比较空荡荡

模型介绍

GAN(Generative Adversarial Nets)生成对抗网络,它是由两个部分组成,一个部分是生成器GG(Generator),用于生成图像;一个部分是判别器DD(Discriminator),用于判别图像的真假,即图像来源是我们的样本数据集,还是生成的图像

我们通过名字以及引言部分的介绍,可以知道GAN是用来生成图像的网络,那么对抗指的是什么?

在原论文中,作者是以一个制造假币的团伙和警察来介绍对抗的概念:制造假币的罪犯(这个就是生成模型)希望自己制造的假币能像真币一样流通在市场上,即假币假得跟真币一样;而警察(这个是判别模型)则是负责抓造假币的罪犯.二者是一个对立、对抗的关系.那么通过警察识别并查获假币,造假币的罪犯为了使假币流通不断提升技术以期假币能以假乱真,在二者的对抗关系下,发展到最后,警察将无法试别假币与真币.

上述即解释了整个模型的思想,生成模型GG期望能生成足够类似样本数据集的图片以欺骗判别模型DD,而判别模型则期望能够最大限度的区分输入的图片是来自样本数据集的还是生成模型GG.

原理解释

上述对GAN的模型解释,从原文中抽象一点的角度来说,即生成器GG是用于学习样本数据的数据分布情况,当我们学习到合适的分布情况下,则通过高斯分布或是均匀分布生成的噪音变量zz用于生成器GG的输入,可以映射到对应的分布中去,生成x~\widetilde{x},则此刻判别模型DD无法判别输入的图片是来自数据集中的xx还是生成的x~\widetilde{x},则此刻判别概率达到最优,为12\frac{1}{2}

因此我们为了使生成器GG能学习到在xx上的分布pgp_g,定义了输入噪音变量pz(z)p_z(z)作为先验,GG的多层感知机的参数为(论文中GGDD都是这个模型,便于利用反向传播更新参数)θg\theta_g,利用zzθg\theta_g共同表示由G到数据分布空间的映射函数:G(z;θg)G(z;\theta_g);而我们的DD的映射函数则用:D(x,θd)D(x,\theta_d)表示
根据本节开头所述,我们所期望的是:
1.DD尽可能的区分开真实样本和生成样本;
2.GG尽可能的骗过D,让它无法分辨真是样本和生成样本;
于是就有以下价值函数V(G,D)V(G,D)成立

minGmaxDV(D,G)=Expdata(x)[logD(x)]+Ezpz(z)[log(1D(G(z)))]\min_{G}\max_{D}V(D,G) = \mathbb{E}_ {x \sim p_{data}(x)}\left[\log{D(x)} \right] +\mathbb{E}_ {z \sim p_z(z)}\left[\log{(1-D(G(z)))} \right]

我们知道对于映射DD来说,其输出是一个介于[0,1]\left[0,1 \right]的概率值,而log2N\log_{2}{N}对应的是单调递增函数,则欲使DD取得max,即使得D(x)D(x)趋于1,即分辨出来自样本数据集中的图片,让D(G(z))D(G(z))趋于0,即分辨出来自GG中生成的图片,则可以使得整个V(D,G)V(D,G)max;此刻则满足期望的第1点;
而欲使得GG取得min,则是使得D(G(z))D(G(z))为1,即使得DD将其误认为是样本数据集中的数据,则此刻整个log\log取值为-\infty,为最小(当然D(G(z))也取不到1,只是趋向1),则此刻满足期望的第二点

以下是论文中的算法部分:

GAN-Algorithm

从上图可以看出,我们先是对判别模型通过循环进行了优化,对于它而言,它的优化是对其参数加上所求得的梯度值以更新模型(即增强其区分能力);
而后出了判别模型优化的循环后再对我们的生成模型做部分的改变,于它而言对参数的更新则是减去梯度值,以更新模型,增强图形的生成能力.

我们肯定注意得到,它先是对DD优化,再对GG做小部分的优化.从前面那个造假币的例子中其逻辑可以想清,此处不赘述

当然,这一部分算法最大的问题,应该(我觉得)是出在超参数难以调控,难以达到收敛情况,即使得DD的判断变为12\frac{1}{2}

代码讲解

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
class Generator(nn.Module):
def __init__(self):
super().__init__()

def block(in_feat, out_feat, normalize=True):
layers = [nn.Linear(in_feat, out_feat)]
if normalize:
layers.append(nn.BatchNorm1d(out_feat, 0.8))
layers.append(nn.LeakyReLU(0.2, inplace=True))
return layers

self.model = nn.Sequential(
*block(opt.latent_dim, 128, normalize=False),
*block(128, 256),
*block(256, 512),
*block(512, 1024),
nn.Linear(1024, int(np.prod(img_shape))),
nn.Tanh())

def forward(self, z):
img = self.model(z) # [64,784]
img = img.view(img.size(0), *img_shape) # [64,1,28,28]
return img

上面这个是Generator的代码,通过一个nn.Sequential的序列容器严格规定网络中layer的执行顺序,前一层输出作为后一层输入,严格遵守,否则报RuntimeError
初始进去的zz的维度是[64,100],64是batch_size,100是它的初始维度,维度通过全连接层变成了1024后变为img_shape:[1,28,28]的所有元素的乘积即784,然后通过view将维度转换为[64,1,28,28]
此即完成了由一个高斯分布生成的噪声变为灰度图片的过程

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
class Discriminator(nn.Module):

def __init__(self):
super().__init__()

self.model = nn.Sequential(nn.Linear(int(np.prod(img_shape)), 512),
nn.LeakyReLU(0.2, inplace=True),
nn.Linear(512, 256),
nn.LeakyReLU(0.2, inplace=True),
nn.Linear(256, 1),
nn.Sigmoid())

def forward(self, img):
# img: [64,1,28,28]
img_flat = img.view(img.size(0), -1) # img_flat: [64,784]
validity = self.model(img_flat) # validity: [64,1]
return validity

上面这个是Discriminator的代码,其实就是Generator的反过程,将输入的MNIST数据集中的手写图片转为[64,784]后由model训练,根据sigmoid做分类,其中LeakyReLU激活函数就是在ReLU的负半轴开启了衰弱的梯度衰减(LeakyReLU=max(0,x)+leakmin(0,x)LeakyReLU= \max(0,x) + leak * \min(0,x)),返回的分类结果是[64,1]的向量,里面的值是图片来自generator或是数据集的概率

1
adversarial_loss = torch.nn.BCELoss()

采用的损失函数是二分类交叉熵,我们的V(D,G)V(D,G)价值函数可以从该损失函数公式推得具体没细看

1
2
3
4
5
6
optimizer_g = torch.optim.Adam(generator.parameters(),
lr=opt.lr,
betas=(opt.b1, opt.b2))
optimizer_d = torch.optim.Adam(discriminator.parameters(),
lr=opt.lr,
betas=(opt.b1, opt.b2))

对生成器和判别器都采用了Adam优化器,之后就是对优化器使用的三部曲:

  1. 梯度清零
  2. backward
  3. 更新梯度

代码如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
for epoch in range(opt.n_epochs):
for i, (imgs, _) in enumerate(dataloader):
# imgs: [64,1,28,28]

valid = Variable(
Tensor(imgs.size(0), 1).fill_(1.0), # [64,1]值为1的向量
requires_grad=False)
fake = Variable(
Tensor(imgs.size(0), 1).fill_(0.0), # [64,1]值为0的向量
requires_grad=False)

# 配置输入
real_imgs = Variable(imgs.type(Tensor)) # 转换type类型为torch.cuda.FloatTensor, [64,1,28,28]

# 训练G
optimizer_g.zero_grad() # 梯度置0/梯度清零

# 噪声作为输入
z = Variable(Tensor(np.random.normal( 0, 1, # normal->符合高斯分布的概率密度随机数
(imgs.shape[0], opt.latent_dim)))) # [64,100]

# 生成批量images
gen_imgs = generator(z)

# generator's loss 尽可能欺骗D
g_loss = adversarial_loss(discriminator(gen_imgs), valid)

g_loss.backward() # backward
optimizer_g.step() # 更新梯度

# 训练D
optimizer_d.zero_grad()

# discriminator's loss 尽可能试别数据集图片和生成器生成的图片
real_loss = adversarial_loss(discriminator(real_imgs), valid)
fake_loss = adversarial_loss(discriminator(gen_imgs.detach()), fake)
d_loss = (real_loss + fake_loss) / 2

d_loss.backward() # backward
optimizer_d.step() # 更新梯度

代码参考源于这个大佬:GitHub跳转

实验展示

所用数据集是MNIST
以下是100个epoch出来的结果

GAN-EX-0

GAN-EX-1

GAN-EX-2

可以看出GAN的生成器学到了一些东西,但其实效果不是那么好,因为超参数很难控制好,置其收敛

总结

总的来说,GAN是一个开创性的想法,随它之后可以看到与之相关的论文呈井喷式增加,GAN的他引次数也达50k加次,确实很厉害
它的效果虽然不尽人意,但是开创性的思维带来了一个领域的突破,后期会更新1~2个改良后的GAN.


本博客所有文章除特别声明外,均采用 CC BY-SA 4.0 协议 ,转载请注明出处!