注意力机制在图像和自然语言处理方面有着广泛的应用. Zhang等[5]提出的SAGAN首次将自我注意力机制与GAN结合, 减少参数计算量的同时, 也聚焦了更多的全局信息. Fu等[6]提出双重注意力机制, 在空间和通道两个维度进行特征融合, 用于语义分割. Tang等[7]结合双重注意力机制, 应用于语义图像合成.
受到以上实验的启发, 针对出现的问题, 我们提出结合双重注意力机制的端到端模型, 该模型基于StackGAN++基本结构, 以双重注意力机制去最大化融合文本和图像的特征, 树状结构生成低到高分辨率(
该模型旨在提高生成图像的全局真实度. 全局真实度指图像内容的完整度, 颜色的鲜明度, 场景的对比度和亮度符合人眼视觉感知的程度.
我们在CUB[9]鸟类数据集上验证了该方法, 并使用IS和SSIM指标判定生成图像的多样性、质量和全局真实度. 实验结果分析证明, 与原有技术相比, 我们模型生成的图像一定程度上呈现了更多的鸟类特征, 并提升了整体的亮度和颜色鲜明度, 使生成图像感知上更加接近于真实图像.
1 模型及方法 1.1 模型结构本文网络结构整体框图如图1所示. 结构主要由文本编码器、2个生成器、3个判别器和VGG19网络构成. 文本编码器使用文献[10]中提供的字符级编码器(char-CNN-RNN), 生成器采用前后级联的方式, 第一个生成器包含1个全连接层和4个上采样层, 第二个生成器包含连接层, 空间和通道注意力模块, 2个残差网络[11]和1个上采样层. VGG19网络作为额外约束, 判别生成图像和真实图像的相似度.
网络大致分为两个阶段, 每个阶段都包含多个输入, 如式(1)所示:
$ \left\{ \begin{gathered} {c_i} = \varphi ({t_i}) \hfill \\ {C_i} = {F_{ca}}({c_i}) \hfill \\ {f_0} = {F_0}(Z) \hfill \\ {f_i} = {F_i}({f_{i - 1}}, {C_i}) \hfill \\ {I_i} = {G_j}({f_i}), \; j \in \left\{ {1, 2} \right\} \hfill \\ \end{gathered} \right. $ | (1) |
其中,
由于图像像素区域和文本存在对应关系, 不同通道存在依赖关系, 我们引入空间和通道注意力机制, 输入为文本向量和低分辨率特征的融合矩阵, 引导生成器更多关注整体特征的关联性和匹配度. 由于高分辨率图像是在低分辨率图像的基础上进行细化, 所以低分辨率图像的好坏决定着最终输出的好坏. 虽然低分辨率图像更加的模糊, 缺少细节, 但是却保留着更多的全局特征. 所以我们将机制放置在G1的连接层后, 即残差模块前, 引导生成器在低分辨率维度上关注更多的全局特征. 注意力机制模块如图2, 图3所示.
$ \left\{ \begin{gathered} \alpha = {\omega _q}\omega _k^{\rm{T}}, \; \alpha \in {R^{C \times C}} \hfill \\ \beta = {\textit{Softmax}} (\alpha ) \hfill \\ \gamma = \beta {\omega _v}, \; \beta \in {R^{C \times H \times W}} \hfill \\ H = \sigma \gamma + h \hfill \\ \end{gathered} \right. $ | (2) |
其中,
空间注意力机制忽略了通道间的语义关联性, 关注像素间的特征信息, 运算与通道注意力机制类似. 两个模块输出最后从通道维度进行拼接, 得到最终的结果.
1.1.2 VGG19
增强型超分辨率生成对抗网络(ESRGAN)[12]中指出, 使用VGG19的第5个maxpool层前的最后一层卷积层去提取图像特征, 使得生成图像特征在亮度和颜色感知上更接近于真实图像. 受其启发, 我们引入VGG19的前35层网络层进行预训练处理, 用来提取生成图像和真实图像的特征, 求取两者的L1损失, 作为生成图像真实度的判别约束.
1.2 时间复杂度空间注意力模块输入
生成器损失包含非条件损失和条件损失两部分. 非条件损失用来判别图像是真实的或是虚假的; 条件损失用来判别图像和文本是否匹配.
$ \begin{array}{l}{L}_{G}=-\underset{非条件损失}{\underbrace{{{E}}_{{I}_{i}\sim {p}_{{G}_{i}}}\left[\mathrm{log}{D}_{i}({G}_{j}({f}_{i}))\right]}}- \underset{条件损失}{\underbrace{{E}_{{I}_{i}\sim {p}_{{G}_{i}}}\left[\mathrm{log}{D}_{i}({G}_{j}({f}_{i}), C)\right]}}\end{array} $ | (3) |
其中,
两个生成器对应两个尺度的图像分布生成, 各自后面接一个判别器. 不同尺度生成图像送入判别器中, 计算交叉熵损失, 返回真假概率和图像文本匹配概率. 生成器
判别器损失包含非条件损失、条件损失和真实度损失3部分.
$ \begin{array}{l}{L}_{D} = -\underset{非条件损失}{\underbrace{{{E}}_{{R}_{i}\sim {p}_{{{\rm{data}}}_{i}}}[\mathrm{log}{D}_{i}({R}_{i})]-{E}_{{I}_{i}\sim {p}_{{G}_{i}}}[\mathrm{log}(1-{D}_{i}({G}_{j}({f}_{i})))]}}\\ \underset{条件损失}{\underbrace{-{{E}}_{{R}_{i}\sim {p}_{{{\rm{data}}}_{i}}}[\mathrm{log}{D}_{i}({R}_{i}, C)]-{E}_{{I}_{i}\sim{p}_{{G}_{i}}}[\mathrm{log}(1-{D}_{i}({G}_{j}({f}_{i}), C))]}}\\ -\underset{真实度损失}{\underbrace{\mu \text{L}1}}\end{array} $ | (4) |
$ {{L}}1 = {E_{{f_i}}}{\left\| {G({f_i}) - {R_i}} \right\|_1} $ | (5) |
其中, L1表示真实度损失. 由VGG19提取真实图像和不同尺度图像的特征空间, 送入判别器计算L1范数距离损失, 通过最小化损失, 达到优化效果.
非条件损失分别计算真实图像、各个尺度生成图像的交叉熵损失, 优化判别器判别真假的能力. 条件损失采用正负对比计算, 正计算包括真实图像和对应标签, 生成图像和对应标签两个组合, 负计算指真实图像和不对应标签. 通过正负对比学习, 优化判别器判别图像文本匹配能力.
2 实验结果和分析 2.1 实验环境本文实验基于搭载GTX1070i显卡的CentOS 7操作系统, 使用Python 2.7编程语言, PyTorch框架.
实验设置训练过程中生成器和判别器学习率为0.0001, batch_size为8, 迭代次数为160次.
2.2 实验数据集及评估指标 2.2.1 数据集2.2.2 评估指标
本文采用Inception Score (IS)和SSIM作为评估标准. IS基于预先在ImageNet数据集[13]上训练好的Inception V3网络. 其计算公式如下:
$ IS(G) = \exp ({E_{x \sim {p_g}}}{D_{\rm{KL}}}(p(y|x)\parallel p(y))) $ | (6) |
其中,
公式表明, IS评估生成图像的多样性和质量, 好的模型应该生成清晰且多样的图像, 所以边际分布
SSIM (structural similarity), 结构相似性度量指标, 已被证明更符合人眼的视觉感知特性. 我们用其评估生成图像的真实度. SSIM包含亮度、对比度、结构3个度量模块. 其计算公式如下:
亮度对比函数:
$ l(x, y) = \frac{{2{\mu _x}{\mu _y} + {C_1}}}{{\mu _x^2 + \mu _y^2 + {C_1}}} $ | (7) |
对比度对比函数:
$ c(x, y) = \frac{{2{\sigma _x}{\sigma _y} + {C_2}}}{{\sigma _x^2 + \sigma _y^2 + {C_2}}} $ | (8) |
结构对比函数:
$ s(x, y) = \frac{{{\sigma _{xy}} + {c_a}}}{{{\sigma _x}{\sigma _y} + {c_a}}} $ | (9) |
最后把3个函数组合起来得到SSIM指数函数:
$ {\textit{SSIM}}(x, y) = {\left[ {l(x, y)} \right]^\alpha }{\left[ {c(x, y)} \right]^\beta }{\left[ {s(x, y)} \right]^\gamma } $ | (10) |
为了节省内存占用率, 我们将StackGAN++缩减为两个阶段, 生成
由图7可以很明显观察到, StackGAN++模型生成的
我们列举以往不同模型在CUB数据集上的IS值, 进行一个对比, 见表2. 我们所提方法评估的IS值能够达到5.4, 高于所比较的以往模型.
为了定量地评估我们模型对真实度提升的贡献, 我们用SSIM指标在生成图像和真实图像做相似性评估, 在StackGAN++模型和我们模型做了对比实验, 见表3.
由表3看出, 相同模型下, 更高分辨率的生成图像具有更高的SSIM值, 符合图像质量提升导致真实度提升的逻辑. 以此为前提, 对比不同模型在相同分辨率的SSIM值, 我们的模型值更高, 则图像真实度相比更高. 结合实验结果图来看, 我们模型生成的图像人眼感知与真实图像样本也更加相似.
3 结论本文提出一种以堆叠式结构为基础, 着重关注图像全局特征真实度的生成对抗网络, 应用于文本生成图像任务. 实验结果证明, 同以往的模型对比, 结果图像更加专注于全局特征, 颜色的鲜明度和整体视觉效果更加具有真实感, 更接近于真实图片. 这是因为我们引入双重注意力机制引导图像学习对应文本的更多特征; 使用真实感损失约束, 提高生成图像的真实感. 在文本单词向量级别, 增添图像子区域的细节, 提升文本和图像的语义一致性, 应用于更加复杂的数据集, 会是接下来研究的一个方向.
[1] |
Goodfellow IJ, Pouget-Abadie J, Mirza M, et al. Generative adversarial nets. Proceedings of the 27th International Conference on Neural Information Processing Systems. Montreal, Quebec: NIPS, 2014. 2672–2680.
|
[2] |
Reed S, Akata Z, Yan XC, et al. Generative adversarial text to image synthesis. Proceedings of the 33rd International Conference on Machine Learning. New York: JMLR.org, 2016. 1060–1069.
|
[3] |
Zhang H, Xu T, Li HS, et al. StackGAN: Text to photo-realistic image synthesis with stacked generative adversarial networks. 2017 IEEE International Conference on Computer Vision (ICCV). Venice: IEEE, 2017. 5908–5916.
|
[4] |
Zhang H, Xu T, Li HS, et al. StackGAN++: Realistic image synthesis with stacked generative adversarial networks. IEEE Transactions on Pattern Analysis and Machine Intelligence, 2019, 41(8): 1947-1962. DOI:10.1109/TPAMI.2018.2856256 |
[5] |
Zhang H, Goodfellow I, Metaxas D, et al. Self-attention generative adversarial networks. Proceedings of the 36th International Conference on Machine Learning. Long Beach: PMLR, 2019. 7354–7363.
|
[6] |
Fu J, Liu J, Tian HJ, et al. Dual attention network for scene segmentation. 2019 IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR). Long Beach: IEEE, 2019. 3146–3154.
|
[7] |
Tang H, Bai S, Sebe N. Dual attention GANs for semantic image synthesis. Proceedings of the 28th ACM International Conference on Multimedia. Seattle: ACM, 2020. 1994–2002.
|
[8] |
Simonyan K, Zisserman A. Very deep convolutional networks for large-scale image recognition. 3rd International Conference on Learning Representations. San Diego: ICLR, 2015.
|
[9] |
Wah C, Branson S, Welinder P, et al. The caltech-UCSD birds-200-2011 dataset. California: California Institute of Technology, 2011.
|
[10] |
Reed S, Akata Z, Lee H, et al. Learning deep representations of fine-grained visual descriptions. 2016 IEEE Conference on Computer Vision and Pattern Recognition (CVPR). Las Vegas: IEEE, 2016. 49–58.
|
[11] |
He KM, Zhang XY, Ren SQ, et al. Deep residual learning for image recognition. 2016 IEEE Conference on Computer Vision and Pattern Recognition (CVPR). Las Vegas: IEEE, 2016. 770–778.
|
[12] |
Wang XT, Yu K, Wu SX, et al. ESRGAN: Enhanced super-resolution generative adversarial networks. Leal-Taixé L, Roth S. Computer Vision—ECCV 2018 Workshops. Cham: Springer, 2018. 63–79.
|
[13] |
Russakovsky O, Deng J, Su H, et al. ImageNet large scale visual recognition challenge. International Journal of Computer Vision, 2015, 115(3): 211-252. DOI:10.1007/s11263-015-0816-y |