PyTorch loading pretrained weights
I am trying to load a pretrained model file, resnet_18.pth, into PyTorch. Online documentation suggested loading it like so:
weights = torch.load("resnet_18.pth")
When I print weights, it gives something like the following:
('module.layer4.1.bn2.running_mean', tensor([ 9.1797e+01, -2.4204e+02, 5.6480e+01, -2.0762e+02, 4.5270e+01,
-3.2356e+02, 1.8662e+02, -1.4498e+02, -2.3701e+02, 3.2354e+01,
...
All of the tutorials I found load the weights into an instance of the base model:
model = TheModelClass(*args, **kwargs)
model.load_state_dict(torch.load(PATH))
model.eval()
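The tutorial pattern works because load_state_dict is a method of an nn.Module instance: you build the architecture first, then copy the saved tensors into it. A minimal round-trip sketch, assuming a hypothetical TinyNet class in place of TheModelClass and an in-memory buffer in place of a .pth file:

```python
import io

import torch
import torch.nn as nn


class TinyNet(nn.Module):
    """Hypothetical stand-in for TheModelClass; any nn.Module subclass works."""

    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(4, 2)

    def forward(self, x):
        return self.fc(x)


# Save the parameters of one instance into an in-memory buffer
# (a file path such as "resnet_18.pth" is handled the same way).
trained = TinyNet()
buffer = io.BytesIO()
torch.save(trained.state_dict(), buffer)
buffer.seek(0)

# Build a fresh instance of the same architecture, then copy the weights into it.
model = TinyNet()
model.load_state_dict(torch.load(buffer))
model.eval()  # switch layers such as dropout/batch norm to inference behaviour

print(torch.equal(model.fc.weight, trained.fc.weight))
```

The key point is the order: the module object must exist before there is anything to call load_state_dict on.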
I want to use a default ResNet-18 model to apply the weights to, but the resnet18 from torchvision does not have the load_state_dict function. Help is appreciated.
import torch
from torchvision.models import resnet18
resnet18.load_state_dict(torch.load("resnet_18.pth"))
# AttributeError: 'function' object has no attribute 'load_state_dict'
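resnet18 is itself a function that returns a ResNet model (an instance of torchvision.models.resnet.ResNet configured as ResNet-18). To load your own pretrained weights, create the model first and then call load_state_dict on it: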
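import torch
from torchvision.models import resnet18
model = resnet18()  # builds the architecture with randomly initialised weights
model.load_state_dict(torch.load("resnet_18.pth"))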
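One more detail in the printed output: every key starts with module. (for example module.layer4.1.bn2.running_mean). That prefix is usually added when the model was wrapped in torch.nn.DataParallel (or DistributedDataParallel) before its state_dict() was saved, whereas a plain resnet18() expects keys such as layer4.1.bn2.running_mean. With the default strict=True, load_state_dict would then raise an error listing missing and unexpected keys. A minimal sketch of a hypothetical helper, strip_module_prefix, that renames the keys before loading, assuming the prefix is exactly "module.":

```python
from collections import OrderedDict


def strip_module_prefix(state_dict, prefix="module."):
    """Return a copy of state_dict with `prefix` removed from the start of each key.

    Keys that do not start with the prefix are kept unchanged; values are untouched.
    """
    return OrderedDict(
        (key[len(prefix):] if key.startswith(prefix) else key, value)
        for key, value in state_dict.items()
    )


# Plain integers stand in for tensors here; the helper only rewrites the keys.
saved = OrderedDict([
    ("module.layer4.1.bn2.running_mean", 1),
    ("module.fc.weight", 2),
])
cleaned = strip_module_prefix(saved)
print(list(cleaned))
```

Applied to the answer's code, the loading line would become model.load_state_dict(strip_module_prefix(torch.load("resnet_18.pth"))). Another option is to wrap the fresh model in nn.DataParallel before loading, since the wrapper's own state_dict uses the same module. prefix.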
Note that load_state_dict(...) loads the weights in place and does not return the model itself.
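Because the copy happens in place, the result of load_state_dict should not be assigned back to model. A small sketch with two nn.Linear layers (standing in for ResNet-18) showing what the call does return, and how strict=False reports mismatched keys instead of raising:

```python
import torch
import torch.nn as nn

source = nn.Linear(4, 2)  # pretend these are the pretrained weights
model = nn.Linear(4, 2)   # freshly initialised, same architecture

# The tensors are copied into `model` in place; the return value is a
# named tuple listing key mismatches, not the model itself.
result = model.load_state_dict(source.state_dict())
print(result.missing_keys, result.unexpected_keys)
print(torch.equal(model.weight, source.weight))

# Writing `model = model.load_state_dict(...)` would rebind `model` to that
# named tuple, so call the method on its own line and keep using `model`.

# With strict=False, mismatches are reported in the result instead of raised.
partial = {"weight": source.state_dict()["weight"]}
result = model.load_state_dict(partial, strict=False)
print(result.missing_keys)
```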