开发者

Pytorch中关于model.eval()的作用及分析

目录
  • model.eval()的作用及分析
    • 结论
  • Pytorch踩坑之model.eval()问题
    • 比较常见的有两方面的原因
    • 1) data
    • 2)model.state_dict()
    • model.eval()   vs   torch.no_grad()
  • 总结

    model.eval()的作用及分析

    • model.eval() 作用等同于 self.train(False)

    简而言之,就是评估模式。而非训练模式。

    在评估模式下,BATchNorm层,dropout层等用于优化训练而添加的网络层会被关闭,从而使得评估时不会发生偏移。

    结论

    在对模型进行评估时,应该配合使用with torch.no_grad() 与 model.eval():

      loop:
        model.train()  # 切换至训练模式
    http://www.devze.com    train……
        model.eval()
        wandroidith torch.no_grad():
          EvaLuation
      end loop

    Pytorch踩坑之model.eval()问题

    最近在写代码时遇到一个问题,原本训练好的模型,加载进来进行inference准确率直接掉了5个点,这简直不能忍啊~下意识地感知到我肯定又在哪里写了bug了~~~于是开始到处排查,从model load到data load,最终在一个被我封装好的module的犄角旮旯里找到了问题,于是顺便就在这里总结一下,避免以后再犯。 

    对于训练好的模型加载进来准确率和原先的不符,

    比较常见的有两方面的原因

    • data
    • model.state_dict() 

    1) data

    数据方面,检查前后两次加载的data有没有发生变化。首先检查 transforms.Normalize 使用的均值和方差是否和训练时相同;另外检查在这个过程中数据是否经过了存储形式的改变,这有可能会带来数据精度的变化导致一定的信息丢失。比如我过用的其中一个数据集,原先将图片存储成向量形式,但其对应的是“png”格式的数据(后来在原始文件中发现了相应的描述。),而我进行了一次data-to-img操作,将向量转换成了“jpg”形式,这时加载进来便造成了掉点。

    2)model.state_dict()

    第一方面造成的掉点一般不会太严重,第二方面造成的掉点就比较严重了,一旦模型的参数加载错了,那就误差大了。

    如果是参数没有正确加载进来则比较容易发现,这时准确率非常低,几乎等于瞎猜。

    而我这次遇到的情况是,准确率并不是特别低,只掉了几个点,检查了多次,均显示模型参数已经成功加载了。后来仔细查看后发现在其中一次调用模型进行inference时,忘了写 ‘model.eval()’,造成了模型的参数发生变化,再次调用则出现了掉点。于是又回顾了一下model.eval()和model.train()的具体作用。如下:

    model.train() 和 model.eval() 一般在模型训练和评价的时候会加上这两句,主要是针对由于model 在训练时和评价时 Batch Normalization 和 Dropout 方法模式不同:

    • a) model.eval(),不启用 BatchNormalization 和 Dropout。此时pytorch会自动把BN和DropOupythont固定住,不会取平均,而是用训练好的值。不然的话,一旦test的batch_开发者_Pythonsize过小,很容易就会因BN层导致模型performance损失较大;
    • b) model.train() :启用 BatchNormalization 和 Dropout。 在模型测试阶段使用model.train() 让model变成训练模式,此时 dropout和bandroidatch normalization的操作在训练q起到防止网络过拟合的问题。

    因此,在使用PyTorch进行训练和测试时一定要记得把实例化的model指定train/eval。

    model.eval()   vs   torch.no_grad()

    虽然二者都是eval的时候使用,但其作用并不相同:

    model.eval() 负责改变batchnorm、dropout的工作方式,如在eval()模式下,dropout是不工作的。

    见下方代码:

     import torch
     import torch.nn as nn
    
     drop = nn.Dropout()
     x = torch.ones(10)
    
     # Train mode 
     drop.train()
     print(drop(x)) # tensor([2., 2., 0., 2., 2., 2., 2., 0., 0., 2.]) 
    
     # Eval mode 
     drop.eval()
     jsprint(drop(x)) # tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1.])

    torch.no_grad() 负责关掉梯度计算,节省eval的时间。

    只进行inference时,model.eval()是必须使用的,否则会影响结果准确性。 而torch.no_grad()并不是强制的,只影响运行效率。

    总结

    以上为个人经验,希望能给大家一个参考,也希望大家多多支持我们。

    0

    上一篇:

    下一篇:

    精彩评论

    暂无评论...
    验证码 换一张
    取 消

    最新开发

    开发排行榜