计算机系统应用  2019, Vol. 28 Issue (10): 201-206   PDF    
基于生成对抗网络的数据增强方法
张晓峰, 吴刚     
中国科学技术大学 信息科学技术学院, 合肥 230031
摘要:深度学习在分类任务上取得了革命性的突破, 但是需要大量的有标签数据作为支撑. 当数据匮乏的时候,神经网络极易出现过拟合的问题, 这种现象在小规模数据集上尤为明显. 针对这一难题, 本文提出了一种基于生成对抗网络的数据增强方法, 并将其应用于解决由于数据匮乏, 神经网络难以训练的问题. 实验结果表明, 合成的数据和真实的数据相比既具有语义上的相似性, 同时又能呈现出文本上的多样性; 加入合成的数据后, 神经网络能够更加稳定地训练, 而且分类的准确度也有了进一步的提高. 将提出的算法和其他一些数据增强的技术对比, 我们的方法结果最好, 从而证明了这种技术的可行性和有效性.
关键词: 生成对抗网络    数据增强    图片分类    卷积神经网络    深度学习    
Data Augmentation Method Based on Generative Adversarial Network
ZHANG Xiao-Feng, WU Gang     
School of Information Science and Technology, University of Science and Technology of China, Hefei 230031, China
Abstract: Deep learning has revolutionized the performance of classification, but meanwhile demands sufficient labeled data for training. Given insufficient data, neural network is apt to overfitting, which is quite general in low data regime. We propose a data augmentation technique based on generative adversarial network to address the network training and data shortage problem. The experimental results show that the synthesized data has semantic similarity compared with the real data, and at the same time it can present the diversity of the context. After adding the synthesized data, the neural network can be trained more stably, and the accuracy of the classification is further improved. Comparing the proposed algorithm with some other data augmentation techniques, the proposed method has the best performance, which proves the feasibility and effectiveness of this technique.
Key words: generative adversarial network     data augmentation     image classification     convolutional network     deep learnning    

随着近些年来深度学习的发展, 深度神经网络[1]在分类任务上取得了革命性的突破. 基于深度神经网络的分类器在有充足标签样本为训练数据的前提下可以达到很高的准确度. 但是往往在一些场景下, 有标签的数据难以收集或者获取这些数据成本高昂, 费时费力. 当数据不足时, 神经网络很难稳定训练并且泛化能力较弱. 如何在小规模数据集上有效的训练神经网络成为当下的一个研究热点. 常见的应对小规模数据集训练问题的措施主要有以下3种:

(1) 无监督预训练和有监督微调相结合的方法. 通过引入和训练数据具有相同分布的大量无标签数据的方式, 神经网络可以先收敛到一个较优的初始点, 然后再在小数据上微调. 但是这种方式存在一个潜在的假设: 无标签的数据容易获得而且收集成本不高, 但是在一些数据难以获取的场景中, 例如医疗图像, 这种方法将无法应用.

(2) 迁移学习[2]的方法. 相比于第一种方法, 迁移学习的要求更加宽泛, 额外的无标签数据不需要和训练数据具有相同的分布, 只要相似或者分布有重叠即可. 在视觉识别当中, 一些视觉的基本模式像边缘、纹理等在自然图像中都是共通的, 这一点构成了迁移学习的理论保证. 大量的实践表明, 在源领域(source domain)上学习大量数据后的网络再迁移到目标领域(target domain)上, 网络的性能会得到极大的提升. 但是当源领域和目标领域之间差距甚大时, 迁移学习是否有所帮助, 目前还未有研究.

(3) 数据增强[3]的方法. 通过合成或者转换的方式, 从有限的数据中生成新的数据, 数据增强技术一直以来都是一种重要的克服数据不足的手段. 传统的图像领域的数据增强技术是建立在一系列已知的仿射变换——例如旋转、缩放、位移等, 以及一些简单的图像处理手段——例如光照色彩变换、对比度变换、添加噪声等基础上的. 这些变化的前提是不改变图像的标签, 并且只能局限在图像领域. 这种基于几何变换和图像操作的数据增强方法可以在一定程度上缓解神经网络过拟合的问题, 提高泛化能力. 但是相比与原始数据而言, 增加的数据点并没有从根本上解决数据不足的难题; 同时, 这种数据增强方式需要人为设定转换函数和对应的参数, 一般都是凭借经验知识, 最优数据增强通常难以实现, 所以模型的泛化性能只能得到有限的提升.

最近兴起的一些生成模型, 由于其出色的性能引起了人们的广泛关注. 例如变分自编码网络(Variational Auto-Encoding network, VAE)[4]和生成对抗网络(Generative Adversarial Network, GAN)[5], 其生成样本的方法也可以用于数据增强. 这种基于网络合成的方法相比于传统的数据增强技术虽然过程更加复杂, 但是生成的样本更加多样, 同时还可以应用于图像编辑, 图像去噪等各种场景. 本文主要介绍的是基于生成对抗网络的数据增强技术, 并将这种方法应用于小规模数据集的分类任务.

1 相关生成对抗网络模型介绍 1.1 生成对抗网络

生成模型可以分成显式密度模型和隐式密度模型两种. 生成对抗网络是一种隐式密度模型, 即网络没有显式的给出数据分布的密度函数, GAN的网络结构如图1所示, 是由生成网络(Generator, G)和判别网络(Discriminator, D)两部分组成. 假设在低维空间Z存在一个简单容易采样的分布p(z), 例如标准正态分布N(0, I), 生成网络构成一个映射函数G: ZX, 判别网络需要判别输入是来自真实数据还是生成网络生成的数据. 生成网络输入噪声z, 输出生成的图像数据; 判别网络输入的数据或者来自真实数据集, 或者来自生成网络合成的数据, 输出数据为真的概率.

图 1 GAN结构示意图

GD相互竞争: G试图欺骗D从而以假乱真, 而D则不断提高甄别能力防止G合成的数据鱼目混珠, 理论上最终生成的数据分布Pg和真实的数据分布Pdata可以相等. 可以用式(1)概括整个GAN网络的优化函数:

$\begin{split} & \mathop {\min }\limits_\theta \mathop {\max }\limits_\phi ({E_{x \sim {p_{data}}(x)}}[\log D(x,\phi )] \\ & \;\;\;\;\;\;\;\;\;\;\;\;+ {E_{z \sim p(z)}}[\log (1 - D(G(z,\theta ),\phi ))]) \\ \end{split} $ (1)
1.2 条件生成对抗网络

GAN本质上属于无监督学习的范畴, 其判别网络仅仅输出数据真假的概率. 条件生成对抗网络(Conditional-GAN)[6]在GAN的基础上, 加入类别的信息Y, 从而可以生成指定类别的数据. Conditional-GAN的优化函数可以写成式(2):

$\begin{split} & \mathop {\min }\limits_\theta \mathop {\max }\limits_\phi ({E_{x \sim {p_{data}}(x)}}[\log D(x{\rm{|}}y,\phi )] \\ &\;\;\;\;\;\;\;\;\;\;\;\;+ {E_{z \sim p(z)}}[\log (1 - D(G(z|y,\theta ),\phi ))]) \\ \end{split} $ (2)

Conditional-GAN的判别器D仍然只有一个输出来判断真假, 而半监督学习生成对抗网络(Semi-GAN)[7]在Conditional-GAN 的基础上, 判别器输出增加到K+1个(K代表数据的类别个数), K个输出表示真实数据的分类概率, 第K+1个表示数据为假的概. Conditional-GAN和Semi-GAN的结构如图2所示.

图 2 Conditional-GAN与Semi-GAN结构对比

2 数据增强网络模型与算法

本文从数据增强的目的出发, 通过改进生成对抗网络的结构和训练算法, 设计了一种基于生成对抗网络的数据增强技术, 并提出了一种新的网络结构, 即数据增强生成对抗网络(Data Augmentation GAN, DAGAN). 与其他的GAN结构相比, 我们提出的网络结构更加适用于数据增强任务, 即生成的样本和原始数据真假难分的同时, 还可以做到类间可分, 从而有利于分类器在在合成的数据点上学习到分类界限. 在训练算法上, 本文将DAGAN的训练过程和分类器的训练过程相结合, 并提出一种新的损失函数, 称之为“2K”损失函数, 从而可以做到在线数据增强, 即数据处理和分类器训练可以在内存中同步处理, 不需要另外的数据存储空间.

2.1 基于GAN的数据增强网络

一般的GAN网络其判别器仅仅只有一个输出——判断输入的真假, 如果直接用来生成数据用来做数据增强是不可行的, 因为不能做到按类别生成样本. Conditional-GAN和Semi-Supervised GAN虽然可以利用数据的标签信息, 并且按照给定的类别生成相应的数据, 但是相关的研究工作表明这样的GAN结构其生成的样本多样性不足, 对数据增强的贡献十分有限. 因此, 需要针对我们数据增强的这一特定需求,即生成的数据有利于分类器学习更加紧凑的分类界限, 提升分类性能来设计网络结构. 基于以上考虑, 从生成网络的角度来看, 最优的判别网络需要:

1) 能够正确地将真实数据和生成数据分类;

2) 不能分辨数据是真实的还是合成的.

据此, 在GAN的基础上设计出适合于小规模数据增强任务的GAN网络结构, 即DAGAN. 结构如图3所示.

图 3 DAGAN网络结构

这里, 生成网络采用Conditional-GAN的结构, 隐向量z和类别信息y作为输入, 输出对应类别的数据; 判别网络的输入有两个来源——真实数据或者生成的数据, 输出则变为2K个, 前K个表示输入为真实数据K类的概率, 后K个表示输入为生成数据K类的概率.

可以看出, 就判别网络而言, 从GAN到Semi-Supervised GAN, 再到本文提出的DAGAN, 输出的维度不断增加, 同时应用的领域也更加广泛. 就生成网络而言, Conditional-GAN, Semi-Supervised GAN以及本文的DAGAN都利用了数据的标签信息, 可以根据指定的类别生成相应的数据. DAGAN在利用Conditional-GAN生成器结构的同时, 又增强了判别网络的判别能力, 使之适用于小规模数据集的增强. 表1总结了以上几种GAN网络的特点对比.

表 1 几种GAN网络的对比

2.2 数据增强网络的训练算法

DAGAN的训练分成两个阶段, 第一阶段为数据生成阶段. 生成网络和判别网络优化相反的目标函数, 在不断的对抗中达到平衡. 与GAN不同的是, 由于判别网络有2K个输出, 因此相应的损失函数也将发生改变, 称之为“2K”损失函数. 对于判别网络, 其损失函数如下:

$\begin{split} {L_C} =& - {E_{x.y \sim {p_{data}}(x, y)}}\log [p(y|x,y < k + 1)]\\ &-{E_{x,y \sim {p_g}(x, y)}}\log [p(y|x,k < y < 2k + 1)] \\ \end{split} $ (3)

对于生成网络, 除了对应的判别真假的损失函数之外, 还包括正则化项, 用来保证生成的数据和真实的数据在特征层面尽可能保持相近, 损失函数如(4)式所示:

$\begin{split} {L_G} = & - {E_{x,y \sim p{}_g(x, y)}}\log [p(y - k|x,k < y \leqslant 2k)] \\ &+ \lambda {L_{fm}} \end{split} $ (4)

其中, Lfm为正则化项, 具体形式如下:

$\begin{array}{l} {L_{fm}}\! =\! ||{E_{x,y \sim {p_{data}}(x, y)}}f(x|y) \!-\! {E_{z \sim {p_z}(z),y \sim {p_c}}}f(G(z,y|y))|| \!\!\!\!\!\! \end{array} $ (5)

这里f (x)函数判别网络中间某一层的输出, 即要求在相同类别的前提下, 生成数据和真实数据特征应当相近, 这进一步保证了生成数据和真实数据在同一类别下具有相同的语义.

第二阶段为分类训练阶段, 假设第一阶段训练完成之后, 生成网络已经学习到真实数据的分布. 因此在这一阶段, 生成网络将不再进行训练, 仅仅作为一个数据的提供者, 生成的数据和真实数据一起训练分类网络. 值得注意的是, 这里不需要单独搭建新的分类网络, 判别网络直接作为分类器进行训练. 由于判别网络有2K个输出, 这里规定第i个与第k+i个输出的概率之和表示输入为第i(i=1, 2, …, k)类数据的概率. 第二阶段的判别网络的损失函数由两部分构成, 分别是真实数据和生成数据:

$L_C' = {L_{data}} + {L_{gen}}$ (6)

其中,

$\begin{split} {L_{data}} =& - {E_{x,y \sim {p_{data}}(x, y)}}\log [p(y|x,y < k + 1) \\ & {\rm{ + }}p(y + k|x,y < k + 1)] \end{split} $ (7)
$\begin{split} {L_{gen}} = & - {E_{x,y \sim {p_g}(x, y)}}\log [p(y|x,k < y < 2k + 1) \\ & {\rm{ + }}p(y - k|x,k < y < 2k + 1)] \end{split} $ (8)

两个阶段均采用批量随机梯度下降的算法进行参数更新, 具体流程见算法1.

算法1. DAGAN批量随机梯度下降训练算法

输入: 第一阶段的迭代次数KG, 第二阶段的迭代次数KC, 训练集D, 测试集T, 批次数量B

1) 数据生成阶段训练: 分别采样真实数据(x, y)~Pdata(x, y), 以及隐向量数据z~P(z), 随机类别数据y~Pg. 在KG次迭代中, 采用随机梯度下降的方法, 交替更新生成网络和判别网络, 损失函数分别为LGLC.

2)数据分类阶段训练: 分别采样真实数据(x, y)~Pdata(x, y), 以及隐向量数据z~P(z), 随机类别数据y~Pg, 在KC次迭代中, 采用随机梯度下降的算法, 只更新判别网络, 损失函数为L’C.

3) 在测试集上测试判别网络的准确率.

3 实验分析

为了验证DAGAN的生成能力以及生成样本能否提升分类器的准确率, 我们分别在3个数据集上做了验证实验, 分别为CIFAR-10、SVHN以及KDEF数据集. 实验中的网络结构都是基于DCGAN[8]这个网络搭建, 详细的网络结构参数如表2所示.

表 2 CIFAR10实验网络参数与网络结构(SVHN与KDEF数据集实验与之类似)

这里需要说明: G, D, T-Conv, Conv, NIN, NL分别表示生成网络, 判别网络, 反卷积, 卷积, Network in Network, 非线性激活函数.

3.1 CIFAR-10数据集验证实验

CIFAR-10数据集总共包含60 000张RGB图片,其中50 000张为训练图片, 10 000张为测试集图片. 图片为32×32的分辨率, 总共可以分成10类. 为了探究各种数据增强方式对于不同程度的小规模数据集的影响, 我们人为地从该数据集中抽取不同数量的子数据集, 每类从50到1000不等. 实验主要对比以下几种不同的数据增强方式: (1)不采用任何的数据增强方式(C); (2)传统的基于仿射变换和图像操作的数据增强方式(C_aug); (3) GAN在每一类上分别训练, 然后每一类单独生成数据(Vanilla GAN); (4) Semi-Supervised GAN 生成数据(Semi GAN); (5)本文所提出的方法(DAGAN); (6)本文所提出的方法加上传统的数据增强技术(DAGAN_aug). 实验对比了不同方法下训练出来的分类器在测试集上的分类准确率(Acc), 结果见表3.

表 3 不同数据增强方式在CIFAR10数据集上测试集的准确率(%)

从实验结果可以看出, DAGAN_aug是所有方法中对分类器提升最显著的, 表明DAGAN可以在传统数据增强的基础上进一步提升模型的性能, 突破传统数据增强的瓶颈. 另外可以看出DAGAN在数据量较少的时候(每类图片数量小于500张)要优于Vanilla GAN和Semi GAN, 说明本文针对数据增强目的设计的DAGAN网络结构和训练算法更加有利于分类器的性能提升.

3.2 SVHN数据集验证实验

SVHN[9]是真实世界的街道门牌号码识别数据集, 每张图片代表0-9中的一个数字, 分辨率为32×32. 由于每种图片中可能包含不止一种数字, 而标签为中心的数字. 传统的数据增强方式例如翻转、移位等在这样的数据中将不能应用, 因为这些转换方式可能会改变图像的标签. 同样地, 表4给出了不同种数据增强方式在SVHN数据集上的性能对比. 实验仅仅考虑了3种数据增强方式的对比, 即(1)不采用任何的数据增强方式(C); (2)Semi-Supervised GAN生成数据(Semi GAN); (3)本文提出的方法(DAGAN).

表 4 不同数据增强方式在SVHN数据集上测试集的准确率(%)

实验结果和CIFAR10数据集是一致的, 在数据量较少的情况下, DAGAN能够最大程度的提升分类器的分类性能, 且优于Semi GAN的方法.有一点需要注意, 当数据量较多时(每类图片数为500张), Semi GAN和DAGAN两种方法几乎都不起作用, 这主要是因为对于相对比较简单的SVHN数据集, 当训练数据达到一定规模后, 限制网络性能的因素不再是数据, 而是分类网络的结构还有分类算法.

3.3 KDEF数据集验证实验

KDEF[10]数据集是一种人脸表情数据集, 包含35个男性和35个女性, 年龄在20至30岁之间. 没有胡须, 耳环或眼镜, 且没有明显的化妆. 7种不同的表情, 每个表情有5个角度. 总共4900张彩色图, 尺寸为562×762像素. 实验中我们仅采用正面角度, 因此只有490张图片, 根据表情进行分类.

本次实验生成网络的结构没有变化, 与表2类似,判别网络采用VGG-16,由于数据量过少, 因此我们采用的VGG-16是在ImageNet数据集上预训练过的. 实验对比了以下几种数据增强方式的性能: (1)不采用任何数据增强方式, 仅仅是预训练的分类器(C); (2) GAN在每一类上分别训练, 然后每一类单独生成数据(Vanilla GAN); (3) Semi-Supervised GAN生成数据(Semi GAN); (4)本文所提的方法(DAGAN). 实验结果如表5所示, 从结果来看, DAGAN依然是性能最好的结构, 同时说明DAGAN可以和预训练的策略相结合, 进一步提升分类器的性能, 突破数据增强技术的瓶颈.

表 5 不同数据增强方式在KDEF数据集上测试集的准确率(%)

3.4 生成图片展示

以上3个数据集的实验说明了DAGAN结构的可行性和有效性, 为了进一步表明DAGAN生成的图片和原始图片具有相同的语义, 而且呈现出内容上的多样性, 这一部分将展示3个数据集上DAGAN生成的数据样本, 并和原始数据相比较, 如图4所示.

从生成图片来看, CIFAR-10数据集每一行都是有着渐变的效果, 这是通过对隐变量z差值实现的; 而每一列都是一个不同的类别, 这是通过控制类别信息y实现的. SVHN数据集每一行都是属于相同的类别, 而每一列图片的z保持相同, 所以每一列的图片具有相同的风格. 以上都说明DAGAN生成的图片是可编辑的, 同时也可以看出生成的图像呈现比较丰富的多样性, 从而印证了DAGAN可以用于数据增强任务.

图 4 CIFAR-10数据集、SVHN数据集和KDEF数据集原始图片和生成图片对比

4 结论

由于深度神经网络在小规模数据集上难以训练,容易出现过拟合的问题, 本文提出一种基于生成对抗网络的数据增强技术, 通过在大量的实验, 以及和其他模型的对比, 验证了所提方法的可行性和有效性. DAGAN既可以有效提升分类器的分类性能, 同时生成的图像数据和真实数据相比具有语义的相似性和内容的多样性.

参考文献
[1]
Krizhevsky A, Sutskever I, Hinton GE. Imagenet classification with deep convolutional neural networks. Proceedings of the 25th International Conference on Neural Information Processing Systems. Lake Tahoe, NV, USA. 2012. 1097–1105.
[2]
Raina R, Battle A, Lee H, et al. Self-taught learning: Transfer learning from unlabeled data. Proceedings of the 24th International Conference on Machine Learning. Corvalis, OR, USA. 2007. 759–766.
[3]
Perez L, Wang J. The effectiveness of data augmentation in image classification using deep learning. arXiv: 1712.04621, 2017.
[4]
Kingma DP, Welling M. Auto-encoding variational Bayes. arXiv: 1312.6114, 2013.
[5]
Goodfellow IJ, Pouget-Abadie J, Mirza M, et al. Generative adversarial nets. Proceedings of the 27th International Conference on Neural Information Processing Systems. Montreal, Quebec, Canada. 2014. 2672–2680.
[6]
Mirza M, Osindero S. Conditional generative adversarial nets. arXiv: 1411.1784, 2014.
[7]
Odena A. Semi-supervised learning with generative adversarial networks. arXiv: 1606.01583, 2016.
[8]
Radford A, Metz L, Chintala S. Unsupervised representation learning with deep convolutional generative adversarial networks. arXiv: 1511.06434, 2015.
[9]
Netzer Y, Wang T, Coates A, et al. Reading digits in natural images with unsupervised feature learning. Proceedings of 2011 NIPS Workshop on Deep Learning and Unsupervised Feature Learning. Granda, Spain. 2011. 4.
[10]
Goeleven E, De Raedt R, Leyman L, et al. The Karolinska directed emotional faces: A validation study. Cognition and Emotion, 2008, 22(6): 1094-1118. DOI:10.1080/02699930701626582