近些年深度神经网络改变了人类的生活, 无论是在工业、科研、日常生活中都有着很高的应用需求, 然而这种深度神经网络模型对设备的要求苛刻, 通常需要大量的计算量和内存, 解决这一问题目前比较受欢迎的方法是知识蒸馏(knowledge distillation, KD)[1].
知识蒸馏作为一种模型压缩的技术, 通过将复杂模型(教师模型)的知识传递给简单模型(学生模型)来实现, 目的是在花费最小化代价的情况下尽量使其效果向复杂模型不断靠近, 以满足大多数设备的性能. 自vanilla-KD提出之后基于响应的(response-based)[2]知识蒸馏走近人们的视野, 随后基于特征的(feature-based)[3–7]、基于关系的(relation-based)[8–10]的知识蒸馏也被相继提出, 目前受到大家关注的方法大多数都是基于中间特征层的, 这些方法忽略掉了logit层的语义信息更高的特点, 解耦知识蒸馏(decoupled knowledge distillation, DKD)[11]的提出使得基于响应的知识蒸馏重新回到了SOTA行列.
解耦知识蒸馏(DKD)是一种基于响应的方法, 传统的知识蒸馏logit部分存在高度耦合, 这是导致logit蒸馏局限性的根本原因, 该方法通过将logit中的信息拆分为目标类(target class)和非目标类(non-target class)两部分, 来达到解耦的效果, 将经典的KD改为TCKD和NCKD, 从实验结果来看, TCKD只是起到了转移训练样本的“困难”知识, NCKD中转移了大量的暗知识, 这一部分才是logit蒸馏的主角, 将学生模型输出的logit不断地向教师模型的logit靠近, 然而, 这种方法忽略了一个问题, 当师生网络模型差距较大的时候, 即使通过上述方法, 将经典KD进行解耦, 效果也无法达到理想状态[12], 因为教师网络模型和学生网络模型之间的表征差距过大, 学生模型的特征无法对齐教师模型的特征.
为了解决上述的问题, 本文从一个新的角度出发, 认为解决这种体量差距较大的知识蒸馏的关键点是在于如何摒弃掉无用的噪声, 学生模型可以视作一个低配版的教师模型, 由于其体量较小, 学习能力相对较弱, 在训练过程中不能准确地分辨出哪些是需要学习的知识, 这种对噪声的提炼会导致训练效果退化. 本文采用扩散模型[13,14]来执行去噪模块, 先消除掉学生模型内部的噪声信息, 再进行蒸馏. 在这种方法下的学生模型更加贴合教师模型.
本文主要贡献如下.
1)本文提出了一个噪声自适应模块, 能够更精确的指定学生特征的噪声水平.
2)使用了一个由ResNet[15]中两个瓶颈块组成的轻量级扩散模型, 从实际情况出发搭建更加符合知识蒸馏的扩散模型.
3)解决了以前知识蒸馏中师生特征无法对齐的问题, 在利用解耦知识蒸馏在logits上可以获取低纬度语义信息的优势, 来优化模型效果.
1 相关工作 1.1 知识蒸馏知识蒸馏这一理论最早出现是在Hinton等人[1]所发表的vanilla-KD论文中, 采用的是一种师生模型网络, 将预训练的教师模型的知识通过网络传递给学生模型, 输出一般为模型的预测和中间特征. 在这个过程中学生模型的输出为
$ {L}_{\rm kd}: =d({F}^{(t)}{, }{F}^{(s)}) $ | (1) |
其中,
扩散模型是一种对图像增强和图像恢复的非线性的处理方法, 正向扩散过程遵循马尔可夫链的概念, 通过控制时间步t不断向图像中添加噪声, 然后利用预测和去噪逆转这一过程. 用
$ q\left( {{y_t}\mid{y_0}} \right) = \mathcal{N}\left( {{y_t};\sqrt {{{\bar \alpha }_t}} {y_0}, \left( {1 - {{\bar \alpha }_t}} \right)I} \right) $ | (2) |
其中, 定义
$ {\textit{z}} \sim \mathcal{N}\left( {{\textit{z}}, \mu , {\sigma ^2}I} \right) \to {\textit{z}} = \mu + \sigma \cdot \mathcal{E}, \mathcal{E} \sim \mathcal{N}\left( {0, I} \right) $ | (3) |
因此可以将
$ {y_t} = \sqrt {{{\bar \alpha }_t}{y_0}} + \sqrt {1 - {{\bar \alpha }_t}\mathcal{E}} $ | (4) |
在逆向扩散过程中需要训练一个神经网络
$ {\mathcal{L}_{\rm diff}}: = \left\| {{\mathcal{E}_t} - {\Phi _\theta }\left( {{y_t}, t} \right)} \right\|_2^2 $ | (5) |
在推理过程中, 初始噪声为
$ {p_\theta }\left( {{y_{t - 1}}\mid{y_t}} \right): = \mathcal{N}\left( {{y_{t - 1}};{\Phi _\theta }\left( {{y_t}, t} \right), \sigma _t^2I} \right) $ | (6) |
其中,
本文受到上述方法的启发, 利用扩散模型对学生模型进行去噪, 如图1 DDKD架构图所示, 学生模型的特征在初始阶段可视作带有噪声的教师特征, 然后利用教师特征训练扩散模型, 最后用训练好的扩散模型对学生特征进行去噪, 在得到无噪声的学生特征后进行蒸馏.
2 本文方法
在本节中将介绍本文所提出的基于扩散模型的解耦知识蒸馏, 首先KD中的特征对齐任务会被转换为扩散模型的去噪过程, 通过对齐师生模型的特征来获得更有效的蒸馏.
为了更好地提高计算效率, 本文引入了特征自编码器[17]来降低特征映射的维数, 从而简化了扩散过程. 此外, 本文还提出了自适应噪声匹配模块, 以提高学生特征的去噪性能. 在最后将特征的高语义信息进行蒸馏.
2.1 基于扩散模型的解耦知识蒸馏一般来说不同的模型有着不同的架构, 在特征的提取上也会有不同的关注点, 即使是在相同的数据集上训练也是如此. 教师模型体量较大, 在特征提取上比学生模型的效果好, 当教师模型与学生模型之间体量相差越大的时候, 他们的差距就会越加明显. 所以解决知识蒸馏问题的关键在于如何缩小这种师生差距. 在以前的论文中[18]也调查过教师模型和学生模型之间的差异. 教师的预测概率分布比学生的预测概率分布更加清晰, 而且教师模型的预测错误答案的方差也比学生模型要小[19], 以上这些说明, 教师模型的输出相对于学生模型来说要更加显著. 学生模型的预测比教师模型的预测包含了更多的噪声.
由于教师模型和学生模型之间的差距, 这些噪声无法仅凭通过简单的模仿蒸馏中的教师模型来消除. 但可以将教师模型和学生模型关注点更多地放在有价值的信息上面, 本文受到扩散模型的启发[20], 将学生模型视为教师模型的带噪版本, 先用教师特征去训练一个扩散模型, 然后再用它对学生特征进行去噪.
对于一个样本在蒸馏过程中教师特征表示为
然后将学生特征输入到学习扩散模型的迭代去噪过程中, 即式(6)中的
根据经验来说logits在语义水平上要比深层特征更高级, 所以本文中提到的方法是在整个模型架构的最后一部分使用解耦知识蒸馏, 将去噪后的学生输出
由于教师特征尺寸相对较大, 模型在去噪过程会耗费大量的计算资源, 需要转发T次(在本文方法中使用T = 5)噪声预测网络
自编码器只使用重构损失进行训练, 重构损失是原始教师
$ {\mathcal{L}_{\rm ae}}: = \left\| {{{\tilde F}^{\left( {\rm Tea} \right)}} - {F^{\left( {\rm Tea} \right)}}} \right\|_2^2 $ | (7) |
用于训练扩散模型的潜在教师特征
本文还使用卷积层将学生特征投影到与教师潜在特征
$ {\mathcal{L}_{\rm DDKD}}: = d\left( {{{\hat Z}^{\left( {\rm Stu} \right)}}, {Z^{\left( {\rm Tea} \right)}}} \right) $ | (8) |
本文使用简单的MSE损失和KL散度损失作为距离函数
如前文所述, 本文将学生特征视为教师特征的嘈杂版本. 然而, 表示教师和学生特征之间差距的噪声水平是未知的, 并且可能因不同的训练样本而变化. 因此, 不能直接确定应该从哪个初始时间步长开始扩散过程. 为了解决这个问题, 本文引入了一个自适应噪声匹配模块, 将学生特征的噪声水平与预定义的噪声水平相匹配.
如图2所示, 将构建一个简单的卷积模块来学习一个融合了学生输出和高斯噪声的权值γ, 这有助于学生输出能更快匹配与初始时间步长t的噪声特征的相同噪声水平. 因此, 去噪过程中的初始噪声特征变为:
$ Z_T^{\left( {\rm Stu} \right)} = \gamma {Z^{\left( {\rm Stu} \right)}} + \left( {1 - \gamma } \right) \in T $ | (9) |
这种噪声适应可以自然地通过KD损失
DDKD的整体损失函数由原始任务损失、优化扩散模型的扩散损失、学习自编码器的重建损失以及对教师特征和去噪学生特征进行蒸馏的KD损失组成, 即:
$ {\mathcal{L}_{\rm train}} = {\mathcal{L}_{\rm task}} + {\lambda _1}{\mathcal{L}_{\rm diff}} + {\lambda _2}{\mathcal{L}_{\rm ae}} + {\lambda _3}{\mathcal{L}_{\rm DDKD}} $ | (10) |
其中,
为了验证本文方法的有效性, 我们在ImageNet[21]和CIFAR-100[22]上进行实验, 首先会介绍本文的实验细节和相关数据集, 在实验部分会介绍我们的消融实验, 来展示本文方法的有效性.
3.1 数据集介绍ImageNet数据集, 是目前最大的图像识别数据库, 在分类、定位和检测任务中使用较多, 其中包含了
在本文中我们使用的是ImageNet的子数据集, 具有相同的效果, 训练集有
CIFAR-100是一个常用的图像数据集, 用于图像分类任务和计算机视觉研究. 图像内容丰富多样, 涵盖了各种日常物体和动物, 可以更好地检验算法的泛化能力, 该数据集总共包含
为了验证本文方法的可行性, 将在ResNet[15]、ShuffleNet[23]、MobileNet[24]、Wide ResNet (WRN)[25]和VGG[26]网络中进行对比.
对于ImageNet数据集, 本文的训练策略将Epoch设为100, Batchsize设为256, 学习率(LR)初始值设为0.1, 每30个Epoch再衰减为当前的0.1, 优化器(optimizer)使用的是随机梯度下降算法(SGD), 超参数权重衰退(weight decay)为
对于CIFAR-100 数据集, 本文的训练策略将Epoch设为240, Batchsize设为64, 学习率(LR)初始值设为0.05, 优化器选择随机梯度下降算法, 权重衰退为
为了验证不同方法对实验效果的影响, 本文进行了消融实验, 在基线设置上, 使用了ResNet-18和MobileNet V1作为学生模型, 教师模型分别为ResNet-34和ResNet-50网络. 表1为在ImageNet数据集上的实验结果. 可以发现vanilla-KD的效果无论是在同构网络还是异构网络中, 在TOP-1和TOP-5上的准确率都是最低的, 当使用DKD方法时, 准确率会有明显提升, 在MobileNet V1和ResNet-50这种设置下更为明显, 这种方法解决了传统知识蒸馏中的目标类与非目标类的高耦合问题, 比传统知识蒸馏TOP-1的准确率提高了1.37%. 本文的方法相对于目前最先进的方法(DKD)在ResNet-18和ResNet-34这种网络设置下TOP-1的准确率提升了0.64%, 在MobileNet V1和ResNet-50设置下TOP-1的准确率提升了1.82%, 这是因为在蒸馏前利用扩散模型对学生模型进行了去噪过程, 进行了特征对齐. 这个过程减少了学生与教师模型的差距.
3.4 对比实验
为了验证本文方法的有效性, 本文还与其他方法进行对比, 如表2、表3所示, 分别为同构网络和异构网络准确率对比的结果.
从结果上来看, 本文方法在异构网络中有更为显著的效果, 考虑到其他方法是通过添加中间网络, 将知识过渡给学生模型, 本文提出了一个更为可靠的方法从根本上解决师生的差距问题, 在蒸馏开始前利用轻量级的扩散模型解决了在训练前师生特征对齐的问题, 然后再利用解耦知识蒸馏的方法来解决传统知识蒸馏中目标类与非目标类高度耦合的问题.
在CIFAR-100数据集上, 同系列网络架构设置的情况下最好的结果比基准方法提升了0.61%, 准确率提高了2.6个百分点. 在不同系列的网络架构设置下, 最好的结果比基准方法提升了0.71%, 准确率提高了3.4个百分点.
4 结束语本文研究了教师和学生在知识蒸馏方面的差异. 从使用到的基准方法来看, 解耦知识蒸馏在logits部分将其分为目标类和非目标类确实有很大的提升, 但本文从开始就将学生与教师特征对齐, 从本质上缩小了师生之间的差距, 为了减少差异, 提高蒸馏性能, 本文从一个新的角度出发, 提出用扩散模型显式地消除学生特征中的噪声. 在此基础上, 进一步引入了一个带有线性自编码器的轻量级扩散模型来降低该方法的计算成本, 并引入了一个自适应噪声匹配模块来将学生特征与正确的噪声水平相匹配, 从而提高了去噪性能. 在图像分类任务上的大量实验验证了本文的有效性和泛化性.
[1] |
Hinton G, Vinyals O, Dean J. Distilling the knowledge in a neural network. arXiv:1503.02531, 2015.
|
[2] |
Gou JP, Yu BS, Maybank SJ, et al. Knowledge distillation: A survey. International Journal of Computer Vision, 2021, 129(6): 1789-1819. DOI:10.1007/s11263-021-01453-z |
[3] |
Romero A, Ballas N, Kahou SE, et al. FitNets: Hints for thin deep nets. arXiv:1412.6550, 2014.
|
[4] |
Guo ZY, Yan HN, Li H, et al. Class attention transfer based knowledge distillation. Proceedings of the 2023 IEEE/CVF Conference on Computer Vision and Pattern Recognition. Vancouver: IEEE, 2023. 11868–11877.
|
[5] |
Chen ZH, Shamsabadi EA, Jiang S, et al. Robust feature knowledge distillation for enhanced performance of lightweight crack segmentation models. arXiv:2404.06258, 2024.
|
[6] |
Tian YL, Krishnan D, Isola P. Contrastive representation distillation. Proceedings of the 8th International Conference on Learning Representations. Addis Ababa: OpenReview.net, 2020.
|
[7] |
Chen PG, Liu S, Zhao HS, et al. Distilling knowledge via knowledge review. Proceedings of the 2021 IEEE/CVF Conference on Computer Vision and Pattern Recognition. Nashville: IEEE, 2021. 5008–5017.
|
[8] |
Yim J, Joo D, Bae J, et al. A gift from knowledge distillation: Fast optimization, network minimization and transfer learning. Proceedings of the 2017 IEEE Conference on Computer Vision and Pattern Recognition. Honolulu: IEEE, 2017. 4133–4141.
|
[9] |
Lee SH, Kim DH, Song BC. Self-supervised knowledge distillation using singular value decomposition. Proceedings of the 15th European Conference on Computer Vision. Munich: Springer, 2018. 335–350.
|
[10] |
Park W, Kim D, Lu Y, et al. Relational knowledge distillation. Proceedings of the 2019 IEEE/CVF Conference on Computer Vision and Pattern Recognition. Long Beach: IEEE, 2019. 3967–3976.
|
[11] |
Zhao BR, Cui Q, Song RJ, et al. Decoupled knowledge distillation. Proceedings of the 2022 IEEE/CVF Conference on Computer Vision and Pattern Recognition. New Orleans: IEEE, 2022. 11953–11962.
|
[12] |
Huang T, You S, Wang F, et al. Knowledge distillation from a stronger teacher. Proceedings of the 36th International Conference on Neural Information Processing Systems. New Orleans, 2022. 33716–33727.
|
[13] |
Ho J, Jain A, Abbeel P. Denoising diffusion probabilistic models. Proceedings of the 34th International Conference on Neural Information Processing Systems. Vancouver: Curran Associates Inc., 2020. 574.
|
[14] |
Ponglertnapakorn P, Tritrong N, Suwajanakorn S. DiFaReli: Diffusion face relighting. Proceedings of the 2023 IEEE/CVF International Conference on Computer Vision. Paris: IEEE, 2023. 22646–22657.
|
[15] |
He KM, Zhang XY, Ren SQ, et al. Deep residual learning for image recognition. Proceedings of the 2016 IEEE Conference on Computer Vision and Pattern Recognition. Las Vegas: IEEE, 2016. 770–778.
|
[16] |
Song JM, Meng CL, Ermon S. Denoising diffusion implicit models. arXiv:2010.02502, 2020.
|
[17] |
Hinton GE, Salakhutdinov RR. Reducing the dimensionality of data with neural networks. Science, 2006, 313(5786): 504-507. DOI:10.1126/science.1127647 |
[18] |
Kundu S, Sun QR, Fu Y, et al. Analyzing the confidentiality of undistillable teachers in knowledge distillation. Proceedings of the 35th Conference on Neural Information Processing Systems. 2021. 9181–9192.
|
[19] |
Li XC, Fan WS, Song SM, et al. Asymmetric temperature scaling makes larger networks teach well again. Proceedings of the 36th International Conference on Neural Information Processing Systems. New Orleans: Curran Associates Inc., 2022. 277.
|
[20] |
Rombach R, Blattmann A, Lorenz D, et al. High-resolution image synthesis with latent diffusion models. Proceedings of the 2022 IEEE/CVF Conference on Computer Vision and Pattern Recognition. New Orleans: IEEE, 2022. 10684–10695.
|
[21] |
Russakovsky O, Deng J, Su H, et al. ImageNet large scale visual recognition challenge. International Journal of Computer Vision, 2015, 115(3): 211-252. DOI:10.1007/s11263-015-0816-y |
[22] |
Krizhevsky A. Learning multiple layers of features from tiny images. Techinical Report, Toronto: University of Toronto. 2009.
|
[23] |
Zhang XY, Zhou XY, Lin MX, et al. ShuffleNet: An extremely efficient convolutional neural network for mobile devices. Proceedings of the 2018 IEEE/CVF Conference on Computer Vision and Pattern Recognition. Salt Lake City: IEEE, 2018. 6848–6856.
|
[24] |
Sandler M, Howard A, Zhu ML, et al. MobileNetV2: Inverted residuals and linear bottlenecks. Proceedings of the 2018 IEEE/CVF Conference on Computer Vision and Pattern Recognition. Salt Lake City: IEEE, 2018. 4510–4520.
|
[25] |
Zagoruyko S, Komodakis N. Wide residual networks. arXiv:1605.07146, 2016.
|
[26] |
Simonyan K, Zisserman A. Very deep convolutional networks for large-scale image recognition. arXiv:1409.1556, 2014.
|
[27] |
Ahn S, Hu SX, Damianou A, et al. Variational information distillation for knowledge transfer. Proceedings of the 2019 IEEE/CVF Conference on Computer Vision and Pattern Recognition. Long Beach: IEEE, 2019. 9163–9171.
|
[28] |
Passalis N, Tzelepi M, Tefas A. Probabilistic knowledge transfer for lightweight deep representation learning. IEEE Transactions on Neural Networks and Learning Systems, 2021, 32(5): 2030-2039. DOI:10.1109/TNNLS.2020.2995884 |
[29] |
Tian YL, Krishnan D, Isola P. Contrastive representation distillation. arXiv:1910.10699, 2019.
|