用 Keras 搭建 GAN:图像去模糊中的应用(附代码)

有生之年
6个月前 阅读 150 点赞 2

2014年 Ian Goodfellow 提出了生成对抗网络(GAN)。这篇文章主要介绍在Keras中搭建GAN实现图像去模糊。所有的Keras代码可点击这里。


  快速回忆生成对抗网络

GAN中两个网络的训练相互竞争。生成器( generator) 合成具有说服力的假输入来误导判别器(discriminator ),而判别器则是来识别这个输入是真的还是假的

生成对抗网络训练过程— 来源


训练过程主要有三步

  • 根据噪声,生成器合成假的输入
  • 用真的输入和假的输入共同训练判别器
  • 训练整个模型:整个模型中判别器与生成器连接

注意:在第三步中,判别器的权重是固定的


将这两个网络连接起来是由于生成器的输出没有可用的反馈。我们唯一的准则就是看判别器是否接受生成器的合成的例子。

这些只是对生成对抗网络的一个简单回顾,如果还是不够明白的话,可以参考完整介绍。


  数据

Ian Goodfellow首次使用GAN模型是生成MNIST数据。 而本篇文章是使用生成对抗网络进行图像去模糊。因此生成器的输入不是噪声,而是模糊图像。

数据集来自GOPRO数据,你可以下载精简版数据集(https://drive.google.com/file/d/1H0PIXvJH4c40pk7ou6nAwoxuR4Qh_Sa2/view?usp=sharing),也可以下载完整版数据集(https://drive.google.com/file/d/1SlURvdQsokgsoyTosAaELc4zRjQz9T2U/view?usp=sharing)。其中包含了来自不同街道视角的人造模糊图像,根据不同的场景将数据集分在各个子文件夹中。

我们先把图像分到 A(模糊)和 B(清晰)两个文件夹。这个 A&B 结构对应于原始文章pix2pix 。我创建了一个 自定义脚本来执行这个任务。 看看 README 后尝试一下吧。 


  模型

训练过程还是一样,首先来看一下神经网络结构。


生成器

生成器要生成清晰图像,网络是基于ResNet blocks的,它可以记录对原始模糊图像操作的过程。原文还使用了基于UNet的版本,但我目前还没有实现。这两种结构都可以很好地进行图像去模糊。

DeblurGAN 生成器网络 结构 — 来源


核心是采用9个ResNet blocks对原始图像进行上采样。来看一下Keras上的实现!



ResNet 层就是一个基本的卷积层,其中,输入和输出相加,形成最终输出。


生成器结构的 Keras 实现


按照计划,用9个ResNet blocks对输入进行上采样。我们在输入到输出增加一个连接,然后除以2 来对输出进行归一化。

这就是生成器了! 我们再来看看判别器的结构吧。


判别器

判别器的目标就是要确定一张输入图片是否为合成的。因此判别器的结构采用卷积结构,而且是一个单值输出

判别器结构的 Keras 实现

最后一步就是建立完整的模型。这个GAN的一个特点就是输入的是真实图片而不是噪声 。因此我们就有了一个对生成器输出的直接反馈


接下来看看采用两个损失如何充分利用这个特殊性。


  训练


损失

我们提取生成器最后和整个模型最后的损失。

第一个是感知损失,根据生成器输出直接可以计算得到。第一个损失保证 GAN 模型针对的是去模糊任务。它比较了VGG第一次卷积的输出


第二个损失是对整个模型输出计算的 Wasserstein loss,计算了两张图像的平均差值。众所周知,这种损失可以提高生成对抗网络的收敛性。



训练流程

第一步是加载数据并初始化模型。我们使用自定义函数加载数据集,然后对模型使用 Adam 优化器。我们设置 Keras 可训练选项来防止判别器进行训练。



然后我们进行epochs(一个完整的数据集通过了神经网络一次并且返回了一次的过程,称为一个epoch),并将整个数据集分批次(batches)。


最后根据两者的损失,可以相继训练判别器和生成器。用生成器生成假的输入,训练判别器区别真假输入,并对整个模型进行训练。


你可以参考Github来查看完整的循环。


  实验

我使用的是在AWS 实例(p2.xlarge)上配置深度学习 AMI (version 3.0)进行的 。对GOPRO 精简版数据集的训练时间大约有 5 个小时(50个epochs)。

图像去模糊结果


从左到右:原始图像,模糊图像,GAN 输出


上面的输出结果都是我们用 Keras 进行 Deblur GAN 的结果。即使是对高度模糊,网络也可以减小模糊,产生一张具有更多信息的图片,使得车灯更加汇聚,树枝更加清晰。


左图: GOPRO 测试图像,右图:GAN 输出结果


因为引入了 VGG 来计算损失,所以会产生图像顶部出现感应特征的局限。


左图: GOPRO 测试图像,右图:GAN 输出结果


希望你们可以喜欢这篇关于生成对抗网络用于图像去模糊的文章。 你可以评论,关注我或者联系我。

如果你对机器视觉感兴趣,我们还写过一篇用Keras实现基于内容的图像复原 。下面是生成对抗网络资源的列表。


左图: GOPRO 测试图像,右图:GAN 输出结果


生成对抗网络资源

NIPS 2016: Generative Adversarial Networks by Ian Goodfellow

ICCV 2017: Tutorials on GAN

GAN Implementations with Keras by Eric Linder-Noren

A List of Generative Adversarial Networks Resources by deeplearning4j

Really-awesome-gan by Holger Caesar

| 2
登录后可评论,马上登录吧~
评论 ( 0 )

还没有人评论...