deeper

  • 1

    获得赞
  • 1

    发布的文章
  • 1

    答辩的项目

半弱监督CNN模型在蘑菇识别比赛中的应用

Batch大小为8,循环次数为30次,通过在线上环境完成训练,模型最优精度评分为97.60。

什么蘑菇?
PyTorch CNN ResNeXt
最后更新 2021/08/27 17:42 阅读 5489

什么蘑菇?

最后更新 2021/08/27 17:42

阅读 5489

PyTorch CNN ResNeXt

比赛结束,第四名,97.6的准确率。虽然是个简单的图像分类比赛,但是想拿高分也是不容易的。

下面分三个部分讲述下解决这个蘑菇分类问题的心路历程。

1 问题分析与探索型数据分析

        一是分析题目和测试数据集可知,这个问题是个多类别标签图像分类问题,第一直觉就是使用卷积神经网络CNN来解决这个问题;二是从测试数据集分析可知,标签总共有9类,这9类分别是什么,不知道哈;三是分析测试数据集,这些采集的照片都是通过拍照得到的,切照片的中心一般就是我们的蘑菇了;四是分析图像可知,蘑菇从外观、大小上看区别不打,主要从形状和颜色上进行区分,因此在数据扩增的时候要注意颜色抖动的设置;五是纠结于使用CNN还是现在比较火的vision transformer,考虑到平台要用认证过的模型,因此决定使用CNN;六是数据扩增的选择问题,考虑到类别不均衡问题,我这里用Yolov5中的Mosaic增强方法对同类别的图像进行了增强;七是训练技巧的选择问题,尽量考虑一些提点特别明显的方法,后面会介绍。

2 基线模型与调优

2.1 基线模型的选择    

       常用的CNN基线模型有Resnet家族、efficientnet家族等,做了简单的测试,使用resnet50和efficientnetb3的结果不是很理想,因此寻求泛化能力比较强的模型。我选择的模型来自论文Billion-scale semi-supervised learning for image classification,该论文中提出使用海量的无标签图像,通过弱监督的方式,来提升resnet50和resnext101模型在ImageNet数据集上的准确率。通过测试确实效果很炸,因此我们采用半弱监督学习得到的预训练resnet50和resnext101模型。    

        smsl-resnet50和smsl-resnext101 32x8在ImageNet的Top-1准确率分别为81.2和84.3,效果很不错。在没有太多数据增强和训练技巧加持下,在我们比赛中的成绩,很容易就超过了95%。实验证明确实是非常好的基线模型。    模型确立好以后,那么就需要对模型的输出进行处理,这是常规操作了,将输出的特征维度不变,将输出维度改为9,使用torch很容易实现。    

path = remote_helper.get_remote_data('https://www.flyai.com/m/semi_weakly_supervised_resnext101_32x8-b4712904.pth')    
net = resnext101_32x8d(pretrained=False)    
net.load_state_dict(torch.load(path), strict=True)    
feature = net.fc.in_features    net.fc = nn.Linear(in_features=feature, out_features=9, bias=True)

2.2 调优    

       一是数据扩增。将图像尺寸从224调制416, 最后确立在380的时候,可以兼顾训练时间和准确率。批尺寸大小设为8, 因此为了得到较好的收敛速度,使用梯度累积的技术,没8个step做一次梯度更新,这样就相当于批尺寸为64了,其实还可以试下更大的梯度累积,也许效果更好。另外三个比较好的方法,Mosaic,CutMix和TTA对涨点效果较好。对于Mosaic是使用同类别的图像进行马赛克拼接得到一个新图像,来提升该类别图像的数量和多样性,对于CutMix就是拼接两张图片,并按照大小来求损失,这里的CutMix和Mosaic有点不同,就是CutMix是可以用于不同类别图像融合,但是这个比赛中使用的Mosaic只用于同类别的图像融合和扩增。对于TTA,这是测试时增强,在推理的时候,对图像做随机变化,然后融合结果,提升结果的准确性。    

        二是模型的优化,这里仅仅简单的替换了输出线性层,没有做复杂的特征处理或者Arcface等操作。学习率为0.01,权重衰减为0.001。        

        三是训练技巧,主要包括标签标签平滑、交叉验证、偏置无L2衰减约束、指数移动平均。其中标签平滑和CutMix类似,主要是针对损失函数处理,用这个技巧将分数从95%提升到96%;交叉验证是常规操作,将分数从96%提升到97%;偏置无L2衰减是指在训练的过程对模型中的偏置不加衰减,可以提升模型的拟合能力;指数移动平均,在最后分数超过95%的时候,开始对每个step的模型做指数移动平均,提升模型的泛化能力。

3 总结

       一是这个比赛我比第一名少了将近0.1个点,差距还是挺大的。说明图像分类问题虽然实现baseline简单,但是要取得好的成绩还是很困难的,再接再厉吧!二是在比赛中的图像尺寸带来的增益确实不错,如果能够在图像加载时对尺寸做一些优化,比如多尺度图像学习,也许还有不错的收益;三是比赛中只使用了单个模型,效果可能不如多个模型带来增益多;四是弱监督学习可以加入进来;五是在比赛中EMA让人有爱有恨,效果很难感觉出来;在图像识别中,记住label smoothing, 图像扩增、大尺寸图像、强基线模型、CutMix都是可以稳定提点的方法。

本文为作者在FlyAI平台发布的原创内容,采用知识共享署名-非商业性使用-禁止演绎 4.0 国际许可协议进行许可,转载请附上原文出处链接和本声明。
本文链接地址:https://flyai.com/n/172544
立即参加 什么蘑菇?
代码展示

PyTorch CNN ResNeXt

选择查看文件
$vue{codeKeys}
  • $vue{ix}
赞赏贡献者还可以查看和下载优质代码内容哦!
赞赏 ¥16.53元
©以上内容仅用于在FlyAI平台交流学习,禁止转载、商用;违者将依法追究法律责任。
讨论
500字
表情
每日优质讨论奖励 20FAI
发送
每日优质讨论奖励 20FAI
删除确认
是否删除该条评论?
取消 删除
感谢您的关注
该篇内容公开后我们将会给你推送公开通知
好的
发布成功!
您的公开申请已发送至后台审核,
通过后将公开展示本详情页!
知道了
向贡献者赞赏
¥16.53
微信支付
支付宝

请先绑定您的微信账号 点击立即绑定

立即支付
温馨提示:
支付成功后不支持申请退款,请理性消费;
支付成功将自动解锁当前页面代码内容,付款前请确认账号信息。
微信扫码支付
请前往Web网页进行支付

敬请谅解,如有疑问请联系FlyAI客服

知道了
举报
请选择举报理由
确定
提示
确定要删除?
取消删除

今日签到成功

获得 $vue{sianData.sign_fai} FAI的GPU算力积分

知道了