黑羽

  • 1

    获得赞
  • 0

    发布的文章
  • 2

    答辩的项目

蘑菇分类 -黑羽

Batch大小为1,循环次数为1次,通过在线上环境完成训练,模型最优精度评分为97.68。

什么蘑菇?
CNN PyTorch 图像分类
最后更新 2021/08/27 17:42 阅读 525

什么蘑菇?

最后更新 2021/08/27 17:42

阅读 525

CNN PyTorch 图像分类

赛题背景: 

         本次赛题背景是日常生活中蘑菇类别较多,而其中包含了很多有毒蘑菇,人很难辨别,容易引起误食,该比赛想通过深度学习的方法帮助人们识别出毒蘑菇。本次竞赛使用的数据集由北欧真菌学家协会提供的9 种常见北欧蘑菇属的图像组成。 

 赛题分析: 

     这个题目就是一个图像分类问题,给定了6045张图片作为训练集,675张图片作为测试集,对于线下的10%数据。

 实验过程: 

      1、5折交叉划分训练集和验证集,遍历每一个类别,训练与验证数据比例大致相同 

       2、试验了各种分类模型

       3、预测的时候使用水平翻转求平均的简单策略

        4、测试各种超参数 batchsize,学习率,增强策略,优化器,损失函数等 

 提分的一些操作:

       1、标签平滑

       2、模型选择 Swim-Transformer

       3、数据集划分 9:1验证集有过拟合问题,8:2稍微好一些

       4、模型融合 

       5、随机裁剪、随机擦除

       6、tta

最终模型和参数:

     (1)数据集划分:     训练集5折交叉 训练集:验证集=8:2

     (2)模型选择:     Swim-Transformer    预训练模型: swin_large_patch4_window12_384_22k.pth

     (3)inputsize: 383*384

     (4)batchsize: 4

     (5)max_epochs: 30

     (6)学习率:    阶梯下降学习率

      def lr_scheduler(optimizer, epoch):

             if epoch < 10:    

                 lr = 1e-3

             elif epoch < 20:

                  lr = 1e-4

            else: lr = 1e-5

           for param_group in optimizer.param_groups:

               param_group['lr'] = lr

          return optimizer

     ( 7)优化器:

             self.optimizer = optim.SGD((self.model_ft.parameters()), lr=1e-3, momentum=momentum, weight_decay=0.0005)         (8)损失函数:    

              LabelSmoothingCrossEntropy(smoothing=0.1).cuda()  

      (9)增强策略:     

              训练集: 随机裁剪、水平翻转、随机擦除

              验证集:简单缩放     

              data_transforms = {         

                     'train': transforms.Compose([             

                          transforms.RandomResizedCrop((input_size, input_size)), 

                          transforms.RandomHorizontalFlip(),            

                          transforms.ToTensor(),            

                          transforms.RandomErasing(p=0.5, scale=(0.02, 0.33), ratio=(0.3, 3.3), value=0, inplace=False),    

                         transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])  ]),         

                    'val': transforms.Compose([             

                        transforms.Resize((input_size,input_size)),             

                        transforms.ToTensor(),            

                        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])         ]),  

(10)测试增强:      

           正常图片+ 水平翻转 -> 求平均   

(11)融合策略:     

          保存每折训练过程中最大Acc和最小Loss的模型,数据集5折交叉,共5*2=10个模型进行融合  

 教训总结:

     1、过早的模型融合,忽略了单个模型的性能比较

     2、验证集过拟合问题

     3、代码没有线下测试,导致线上出问题,浪费了很多实验时间

总结和展望

     1、对图片没有进行预处理等操作,数据分析还不足

      2、多看看最新的研究,最后阶段才尝试使用最新的分类模型,对比之前的模型提分很明显

      3、训练速度较慢,前排大神两次训练间隔才10几分钟就可以得到比较好的分数,我每次训练都要好几个小时,很想学习下,以后可以改进下训练策略

      4、参加几次比赛,我也学到了很多东西,希望以后有更多的人来参与这个平台进行比赛,通过比赛和分享让大家共同进步。谢谢大家!

本文为作者在FlyAI平台发布的原创内容,采用知识共享署名-非商业性使用-禁止演绎 4.0 国际许可协议进行许可,转载请附上原文出处链接和本声明。
本文链接地址:https://flyai.com/n/173660
立即参加 什么蘑菇?
代码展示

CNN PyTorch 图像分类

选择查看文件
$vue{codeKeys}
  • $vue{ix}
赞赏贡献者还可以查看和下载优质代码内容哦!
赞赏 ¥16.72元
©以上内容仅用于在FlyAI平台交流学习,禁止转载、商用;违者将依法追究法律责任。
讨论
500字
表情
每日优质讨论奖励 20FAI
发送
每日优质讨论奖励 20FAI
删除确认
是否删除该条评论?
取消 删除
感谢您的关注
该篇内容公开后我们将会给你推送公开通知
好的
发布成功!
您的公开申请已发送至后台审核,
通过后将公开展示本详情页!
知道了
向贡献者赞赏
¥16.72
微信支付
支付宝

请先绑定您的微信账号 点击立即绑定

立即支付
温馨提示:
支付成功后不支持申请退款,请理性消费;
支付成功将自动解锁当前页面代码内容,付款前请确认账号信息。
微信扫码支付
请前往Web网页进行支付

敬请谅解,如有疑问请联系FlyAI客服

知道了
举报
请选择举报理由
确定
提示
确定要删除?
取消删除

今日签到成功

获得 $vue{sianData.sign_fai} FAI的GPU算力积分

知道了