林业资源对维护生态环境和促进国家经济发展具有重要意义, 而林业病害防治是林业发展和建设过程中一项至关重要的基础性工作, 精准的诊断林业病害可以及时减少林业病害对经济带来的损失. 传统的林业病害识别工作主要以人工调查为主, 相关工作人员不仅需要具备林业领域相关知识, 并且整个工作流程存在成本高、数据存储量大、耗时久、效率低等问题, 无法保证林区病害检测的轻量级、时效性和准确性.
随着计算机视觉技术的飞速发展, 越来越多的学者开始关注于运用深度学习相关技术对林业病害进行识别, 并且取得了不错的成果. 如李浩等[1]提出的基于深度学习的松线虫病害松木识别, 吴云志等[2]提出的一种植物病害图像识别卷积网络架构, 牟文芊等[3]提出的基于SENet和深度可分离卷积胶囊网络的茶树叶部病害图像识别, Amara等[4]提出的基于LeNet的卷积神经网络模型对香蕉病害进行识别以及Brahimi等[5]采用迁移学习的方法, 利用GoogLeNet和AlexNet模型识别PlantVillage数据集中的西红柿病害图像. 以上的深度学习模型需要大量的数据样本才能训练出一个有着良好精确度的分类器, 并且往往只针对某一种林业病害. 目前网络上关于林业病害相关的标准公开数据集较少, 一些不常见的林业病害往往只能获取到几张病害图片数据, 这会导致模型在训练阶段遭遇过拟合问题. 因此, 如何利用少量的数据, 让深度学习模型能够有效学习和泛化, 成为当前亟需解决的问题.
针对上述问题, 本文采用了小样本学习的方法, 提出了一种基于深度相互学习策略的元基线模型(deep mutual learning-meta-baseline, DML-MB)对多种林业病害进行识别, 其实验结果表明, DML-MB模型能够在少量样本的情况下实现较好的分类精度. 从而为林业病害识别领域提供了一种全新的解决方法.
1 材料和方法 1.1 材料本文使用的数据集由宜宾市林业局提供, 包含了黄葛树、小叶榕、马尾松等8种宜宾市范围内常见树种. 首先, 为了得到一个较为标准的林业病害数据集, 对原始数据集按照病害种类进行二次分类, 最终得到了一个包含7个病害种类的数据集. 其次对数据集进行数据增强和归一化处理, 对每类图片进行水平翻转、随机旋转、添加噪声等操作增加每类样本数量. 最终得到了一个新的林业病害数据集, 该数据集包含7种病害类, 每个类拥有30张图像, 样本总量为210张. 林业病害各类数据见表1.
部分林业病害图像如图1所示.
1.2 方法
小样本学习主要利用训练任务之间的共性, 通过学习少量的标签样本后获得一个有效的分类器, 从而让模型具有学会学习的能力, 对于新的类别, 只需要少量的样本就能够实现较好精度的识别. 小样本学习也是元学习在监督学习领域的应用. 元学习是一种高层次的跨任务学习策略, 可以天然地应用在各种小样本学习模型中. 目前, 常见的基于元学习框架的小样本学习策略[6, 7]有以下3种: 采用度量学习的策略、采用外部记忆的策略, 以及采用参数优化的策略.
采用度量学习的策略指学习一个成对相似性度量, 其代表了度量支持集样本和查询集样本之间相似性的度量模块, 它可以是一个距离度量也可以是一个可学习的网络. Vinyals等[8]提出的匹配网络和Snell等[9]提出的原型网络采用了固定的度量, Sung等[10]提出的关系网络则采用了可学习的CNN来评估成对样本之间的相似性.
采用外部记忆的策略是指在模型中添加一个额外的记忆模块来保存从支持集中提取出来的特征信息, 帮助网络进行学习, 从而辅助后面的学习任务. 如Munkhdalai等[11]提出了一种带有外部记忆模块的网络, Ravi等[12]提出了将LSTM和元学习相结合的优化算法.
采用参数优化的策略是指通过一个学习算法来优化特定任务的模型. 这种策略可以为模型元学习到一个好的初始化参数, 使得模型能够在少量样本下, 几次迭代就能够很好地适应新的任务. 如Finn等[13]提出了元学习MAML算法, 该算法通过跨任务训练策略为基础学习其找到一个良好的初始化参数, 从而模型能够更好地进行训练.
上述的基于元学习框架的小样本学习策略在图像领域已经取得了一定的成果, 但是随着Chen等[14]基于元学习提出了一种新的元基线模型, 我们发现其表现出来的效果明显优于之前的几种小样本学习策略. 所谓的基线就是指采用特征提取器+线性分类器的组合, 或者特征提取器+距离度量分类器组合, 而元基线模型就是将基线组合套入元学习框架中. 但是, 在使用该元基线模型来对本文的小样本任务进行识别时, 发现其分类器识别精度不高, 不能够充分地提取到图像特征, 模型还有上升的空间. 针对该问题, 本文引入了一种新的分类器改进, 即深度相互学习策略, 并提出了DML-MB模型. 该模型结合了元基线模型和深度相互学习策略的思想, 提升了深度神经网络的泛化性, 获得了更多的图像特征, 解决了模型在少量样本训练下容易过拟合的问题, 识别精度得到了有效提升.
2 模型构建 2.1 模型建立本文提出的DML-MB模型分为两个阶段: 第1个是预训练阶段, 即在一个带有大量标签的基类数据集上训练出一个分类器, 并删除该分类器中的全连接层作为元学习阶段的特征提取器. 第2个是元学习阶段, 在这个阶段新类图像数据会在元学习框架中进行训练. 本文使用CIFAR-100作为基类数据集, 新类数据集是在林业局获取并整理后的林业病害数据集.
2.1.1 分类器模型传统的分类器模型在训练阶段需要大量数据, 其训练效果的优劣取决于是否能提取到充分的特征, 并且其模型参数量大, 执行可能缓慢, 需要大量内存进行存储. 为了解决以上的问题, Hinton等[15]在2015年提出了知识蒸馏算法, 即利用一个预训练好的教师模型向一个未训练的学生模型进行单向知识转移. 实验表明学生网络通过模仿教师网络的类别概率, 优化过程变得更为容易, 而且能够表现出与教师网络相近甚至更好的性能. 但知识蒸馏算法需要提前预训练好的教师网络, 仅对学生网络进行单向的知识传递, 难以在学生网络学习的过程中获得反馈信息来对训练过程进行优化调整.
在本文的预训练阶段, 分类器模型引入了由Zhang等[16]提出的深度相互学习(deep mutual learning, DML)策略. 该策略采用多个网络同时进行训练, 每个网络之间相互进行学习, 在训练过程中不仅接受真值标记的监督, 还参考了同伴网络的学习经验来进一步提升泛化能力. 每个网络由常规的监督学习损失和拟态损失来共同训练, 监督学习损失是为了度量网络预测目标类别与真实标签之间的差异, 拟态损失采用KL散度来度量两个网络预概率分布之间的差异. 分类器模型训练的具体流程如图2所示.
给定M个类中的N个样本
$ ^{}p_1^m\left( {{x_i}} \right) = \frac{{{\rm{exp}}\left( {{\textit{z}}_1^m} \right)}}{{\displaystyle\mathop \sum \nolimits_{m = 1}^{{M}} {\rm{exp}}\left( {{\textit{z}}_1^m} \right)}}\;\;\; $ | (1) |
其中,
Zhang等[16]的研究表明, DML策略提供了一种简单有效的方法, 通过与其他网络协同训练提高了模型的泛化能力, 相较于传统分类器模型, 该策略能够提取到更多的图像特征.
2.1.2 DML-MB模型在进入元学习模型之前, 需要将训练之后的分类器网络去掉全连接层, 得到特征提取器. 采用元学习的策略, 将林业病害数据集中的4个类划分成训练集, 3个类划分成测试集. 该策略的好处在于, 模型在支持集的帮助下能够更好地学习查询集与标签之间的映射关系. 在元学习训练阶段, 每次会在训练集中采样得到不同任务, 在每个任务中, 从训练集中选择出N个类, 再从N类中选取K个样本(即N-way K-shot)构成了支持集Nsupport, 查询集Nquery会在N个类中的剩余样本数据中采样得到. 采用任务机制使得模型学会不同任务中的共性部分, 让模型泛化性得到增强. 元训练阶段具体过程为: 首先, 将支持集的样本输入到特征提取器中提取出样本的特征. 然后, 计算支持集中每个类样本的平均特征, 同时在查询集也抽取一定数量的样本进行特征提取. 接下来使用余弦相似度分别计算查询集和支持集之间样本的相关程度. 最后将计算出的相关度分数与从查询集抽取的样本标签进行对比. 元学习模型框架如图3所示. 在元测试阶段, 按照同样的方法, 在测试集上抽取支持集和查询集用于二次训练[17]. 经过了上述两个阶段, 模型对于病害特征的识别性能得到了增强.
2.2 损失函数
本文采用DML策略的分类器模型在学习过程中有两个损失函数, 一个是传统的监督损失函数, 一个是网络间的拟态损失函数. 对于多类别分类, 将训练网络
$ {L_{{C_1}}} = - \mathop \sum \limits_{i = 1}^{{N}} \mathop \sum \limits_{m = 1}^{{M}} I\left( {\left. {{y_i}, m} \right)} \right.{\text{ln}}\left( {p_1^m\left( {{x_i}} \right)} \right) $ | (2) |
带有一个指标函数I, 定义见式(3):
$ I\left( {{y_i}, m} \right) = \left\{ {\begin{array}{*{20}{c}} 1,&{{y_i} = m} \\ 0,&{{y_i} \ne m} \end{array}} \right. $ | (3) |
其中,
拟态损失函数采用KL散度来度量两个网络
$ {D_{{\rm{KL}}}}\left( {{p_2}||{p_1}} \right) = \mathop \sum \limits_{i = 1}^N \mathop \sum \limits_{m = 1}^M p_2^m\left( {{x_i}} \right)\ln\frac{{p_2^m\left( {{x_i}} \right)}}{{p_1^m\left( {{x_i}} \right)}} $ | (4) |
最后, 网络
$ {L_{{\theta _1}}} = {L_{{C_1}}} + {D_{{\rm{KL}}}}\left( {{p_2}||{p_1}} \right) $ | (5) |
类似地, 网络
$ {L_{{\theta _2}}} = {L_{{C_2}}} + {D_{{\rm{KL}}}}\left( {{p_1}||{p_2}} \right) $ | (6) |
元学习框架的损失函数是由每个训练任务的损失函数一起构成的, 所以需要计算每个任务的损失函数. 首先, 在支持集中计算N个类的质心, 质心的定义见式(7):
$ {W_c} = \frac{1}{{|{S_c}|}}\sum\nolimits_{x \in {S_c}} {{F_\theta }} (x) $ | (7) |
然后, 计算查询集中每个样本的预测概率分布, 具体定义见式(8):
$ p\left( {y = c|x} \right) = \frac{{{\rm{exp}}\left( { < {F_\theta }\left( x \right), {W_c} > } \right)}}{{\displaystyle\mathop \sum \nolimits_{c'} {\rm{exp}}\left( { < {F_\theta }\left( x \right), {W_{c'}} > } \right)}} $ | (8) |
其中, S为支持集,
最后, 损失函数为p和查询集样本标签计算的交叉熵损失, 具体定义见式(9):
$ loss = - \frac{1}{N}\mathop \sum \limits_{i = 1}^N \mathop \sum \limits_{k = 1}^K {y_{i, k}}\ln\left( {{p_{i, k}}} \right) $ | (9) |
本实验使用的是Linux操作系统, 采用NVIDIA的RTX-A4000 16 GB显卡、CUDA 11.0、Python 3.7、PyTorch 1.7.
具体实验分为3个部分: 采用DML策略的分类器网络, 与未采用DML策略的分类器网络、采用知识蒸馏的分类器网络进行实验精度的对比; 使用本文构建的DML-MB模型实现林业病害任务的识别, 并对比不同分类器网络对实验精度影响; 与目前主流的小样本学习模型以及元基线模型进行分类精度对比.
3.1 分类器网络实验及分析分类器网络采用了DML的策略, 为了对比不同网络组合对于DML-MB模型分类精度的影响, 分别选择了ResNet-32, MobileNetV3-Large, WRN-28-10 (表2和表3中用Res、Mob、WRN表示)网络进行实验, 并在CIFAR-100数据集上进行精度对比. 训练的基本设置为: 迭代200次, 初始学习率为0.1, 动量为0.9, 学习率每60个epochs下降0.1. 实验的具体结果如表2、表3所示.
由表2可知, 相较于单独用一个网络作为分类器而言, 分类器网络间采用DML策略后识别精度有着明显的提升, 精度提升最大的一组甚至达到3.82%. 这表明DML策略是有效的, 我们关于分类器使用多种网络协调训练能够提升精度的设想可行性较大.
经研究发现, DML策略与知识蒸馏策略有异曲同工之妙, 因此我们进一步对比了两种方法. 由表3结果可知, 采用WRN-28-10、MobileNetV3-Large作为教师网络蒸馏ResNet-32学生网络后, 学生网络的精度分别提升了0.43%和0.16%, 而采用DML策略后ResNet-32网络精度分别提升了1.64%和2.02%, 这表明DML策略在精度提升上明显优于知识蒸馏策略, 我们猜测是因为知识蒸馏是教师网络向学生网络知识的单向传递, 教师网络难以在学生网络学习的过程中获得反馈信息来对训练过程进行优化调整, 所以精度提升效果不佳. 而DML策略让网络间知识进行相互传递学习, 能够很好地解决了以上的问题.
3.2 DML-MB模型实验本次实验采用的数据是由林业局获取并整理后的病害数据集. 实验方法采用小样本学习中最常采用的1-shot和5-shot方式.
进行实验时, 选择采用了DML策略训练后精度最高的ResNet32、WRN-28-10、MobileNetV3-Large网络去掉全连接层作为元基线模型的特征提取器, 并用支持集数据进行fine-tuning. DML-MB具体的训练设置为: 迭代20次, 每个任务的批处理量为4, 学习率为0.001, 学习动量为0.9, 权重衰减为1E–4, 优化器选择SGD. 最后, 我们列出模型在1-shot和5-shot上的测试精度, 如表4所示.
由表4可知, 采用WRN-28-10网络作为特征提取器的模型精度显著高于另外两者, 因此, 可用其进行更深层次的分类精确度研究. 但考虑到WRN-28-10的参数量高达38.6 M, 并不能很好地满足林业病害识别领域轻量化的需求, 综合考虑下, 我们选择参数量仅有3.78 M, 而精度却只稍逊WRN一筹的MobileNetV3-Large作为元基线模型的特征提取器进行实验. 模型的训练精度如图4所示.
由图4可知, 训练后的DML-MB模型1-shot和5-shot的训练精度分别为61.38%以及73.56%.
3.3 与主流小样本学习模型及元基线模型对比试验及分析在本节中, 主要选择了两种应用广泛的小样本模型—原型网络、匹配网络以及元基线模型与本文提出的模型进行对比实验. 在进行实验时, 4种模型训练的基本设置均为迭代100次, 学习率为0.001, 并采用随机梯度下降法. 最后, 实验对比结果如表5所示.
由表5可知, 匹配网络和原型网络在林业病害数据集上1-shot和5-shot识别精度皆低于元基线网络, 而本文提出的改进相较于元基线模型在5-shot上的识别精度提升了1.82%, 这表明在深度学习小样本领域, 无法从数据层面提供可靠支持的情况下, 我们可以使用不同网络协同训练、相互学习, 来提升模型的泛化能力, 这不失为小样本领域提高精度的一个有效解决方案.
4 结论与展望本文针对传统林业病害识别领域存在可用数据少、分类准确度低等问题, 建立了林业病害图像数据集, 并提出了一种基于小样本学习的改进元基线模型, DML-MB模型, 该模型引入了DML策略, 提升了深度神经网络的泛化性, 提取到了更丰富的图像特征. 模型在林业病害图像数据集上1-shot和5-shot测试精度达到了61.38%和73.56%, 相比于现有的基于小样本学习的模型拥有更好的识别精度, 为小样本林业病害识别领域了一个全新的解决办法. 下一步的研究工作可考虑从以下两方面着手.
(1)建设公共林业病害数据集和统一的评估标准. 由于数据的采集方式各不相同, 得到的林业病害数据在质量、尺寸和数据方面都差异较大. 因此制定统一的林业病害数据集标准对林业病害识别领域至关重要.
(2)结合Transformer模型的优点继续对DML-MB模型进行优化和改进. 通过利用Transformer模型结构中含有的自注意力层, 提升DML-MB模型感知关键区域的能力, 进而增强模型在小样本图像分类任务上的表现.
[1] |
李浩, 方伟泉, 李浪浪, 等. 基于深度学习的松材线虫病害松木识别. 林业工程学报, 2021, 6(6): 142-147. |
[2] |
吴云志, 刘翱宇, 朱小宁, 等. 一种植物病害图像识别卷积网络架构. 安徽农业大学学报, 2021, 48(1): 150-156. |
[3] |
牟文芊, 董萌萍, 孙文杰, 等. 基于SENet和深度可分离卷积胶囊网络的茶树叶部病害图像识别. 山东农业大学学报(自然科学版), 2021, 52(1): 23-28. |
[4] |
Amara J, Bouaziz B, Algergawy A. A deep learning-based approach for banana leaf diseases classification. In: Mitschang B, Nicklas D, Leymann F, et al., eds. Datenbanksysteme für Business, Technologie und Web (BTW 2017)-Workshopband. Bonn: Gesellschaft für Informatik, 2017. 79–88.
|
[5] |
Brahimi M, Boukhalfa K, Moussaoui A. Deep learning for tomato diseases: Classification and symptoms visualization. Applied Artificial Intelligence, 2017, 31(4): 299-315. DOI:10.1080/08839514.2017.1315516 |
[6] |
祝钧桃, 姚光乐, 张葛祥, 等. 深度神经网络的小样本学习综述. 计算机工程与应用, 2021, 57(7): 22-33. DOI:10.3778/j.issn.1002-8331.2012-0200 |
[7] |
李新叶, 龙慎鹏, 朱婧. 基于深度神经网络的少样本学习综述. 计算机应用研究, 2020, 37(8): 2241-2247. |
[8] |
Vinyals O, Blundell C, Lillicrap T, et al. Matching networks for one shot learning. Proceedings of the 30th International Conference on Neural Information Processing Systems. Barcelona: Curran Associates Inc., 2016. 3637–3645.
|
[9] |
Snell J, Swersky K, Zemel R. Prototypical networks for few-shot learning. Proceedings of the 31st Conference on Neural Information Processing Systems. Long Beach: NIPS, 2017. 4077–4087.
|
[10] |
Sung F, Yang YX, Zhang L, et al. Learning to compare: Relation network for few-shot learning. Proceedings of the 2018 IEEE/CVF Conference on Computer Vision and Pattern Recognition. Salt Lake City: IEEE, 2018. 1199–1208.
|
[11] |
Munkhdalai T, Yu H. Meta networks. Proceedings of the 34th International Conference on Machine Learning. Sydney: JMLR.org, 2017. 2554–2563.
|
[12] |
Ravi S, Larochelle H. Optimization as a model for few-shot learning. Proceedings of the 5th International Conference on Learning Representations. Toulon: OpenReview.net, 2017. 1–11.
|
[13] |
Finn C, Abbeel P, Levine S. Model-agnostic meta-learning for fast adaptation of deep networks. Proceedings of the 34th International Conference on Machine Learning. Sydney: PMLR, 2017. 1126–1135.
|
[14] |
Chen YB, Wang XL, Liu Z, et al. A new meta-baseline for few-shot learning. arXiv:2003.04390v2, 2020.
|
[15] |
Hinton G, Vinyals O, Dean J. Distilling the knowledge in a neural network. arXiv:1503.02531, 2015.
|
[16] |
Zhang Y, Xiang T, Hospedales TM, et al. Deep mutual learning. Proceedings of 2018 IEEE/CVF Conference on Computer Vision and Pattern Recognition. Salt Lake City: IEEE, 2018. 4320–4328.
|
[17] |
肖伟, 冯全, 张建华, 等. 基于小样本学习的植物病害识别研究. 中国农机化学报, 2021, 42(11): 138-143. |