开发者

Pytorch中的modle.train,model.eval,with torch.no_grad解读

目录
  • modle.train,model.eval,with torch.no_grad解读
  • model.eval()与torch.no_grad()的作用
    • model.eval()
    • torch.no_grad()
    • 异同
  • 总结

    modle.train,model.eval,androidwith torch.no_grad解读

    1. 最近在学习pytorch过程中遇到了几个问题

    不理解为什么在训练和测试函数中model.eval(),和mophpdel.train()的区别,经查阅后做如下整理

    一般情况下,我们训练过程如下:

    www.devze.com

    拿到数据后进行训练,在训练过程中,使用

    • model.train():告诉我们的网络,这个阶段是用来训练的,可以更新参数。

    训练完成后进行预测,在预http://www.devze.com测过程中,使用

    • model.eval(): 告诉我们的网络,这个阶段是用来测试的,于是模型的参数在该阶段不进行更新。

    2. 但是为什么在eval()阶段会使用with torch.no_grad()?

    查阅相关资料:传送门

    with torch.no_grad -开发者_自学开发 disables tracking of gradients in autograd.

    model.eval() changes the forward() behaviour of the module it is called upon

           eg, it disables dropout and has BATch norm use the entire population statistics

    总结一下就是说,在eval阶段了,即使不更新,但是在模型中所使用的dropout或者batch norm也就失效了,直接都会进行预测,而使用no_grad则设置让梯度Autograd设置为False(因为在训练中我们默认是True),这样保证了反向过程为纯粹的测试,而不变参数。

    另外,参考文档说这样避免每一个参数都要设置,解放了GPU底层的时间开销,在测试阶段统一梯度设置为False

    model.eval()与torch.no_grad()的作用

    model.eval()

    经常在模型推理代码的前面, 都会添加model.eval(), 主要有3个作用:

    • 1.不进行dropout
    • 2.不更新batchnorm的mean 和var 参数
    • 3.不进行梯度反向传播, 但梯度仍然会计算

    torch.no_grad()

    torch.no_grad的一般使用方法是, 在代码块外面用with torch.no_grad()给包起来。 如下面这样:

    with torch.no_grad():
     # your code

    它的主要作用有2个:

    • 1.不进行梯度的计算(当然也就没办法反向传播了), 节约显存和算力
    • 2.dropout和batchnorn还是会正常更新

    异同

    从上面的介绍中可以非常明确的看出,它们的相同点是一般都用在推理阶段, 但它们的作用是完全不同的, 也没有重叠。 可以一起使用。

    总结

    以上为个人经验,希望能给大家一javascript个参考,也希望大家多多支持我们。

    0

    上一篇:

    下一篇:

    精彩评论

    暂无评论...
    验证码 换一张
    取 消

    最新开发

    开发排行榜