Yolov5: Loading model custom trained weights using Pytorch hub

Created on 5 Dec 2020  ·  13 Comments  ·  Source: ultralytics/yolov5

❔Question

Loading model custom trained weights using Pytorch hub

Additional context

Hi,

I'm trying to load my custom model weights using torch hub. As the screenshot shows, loading succeeded in one case and failed in the other.

[image]

In the failing case I got the following error:

[image]

Suppose I use the first method to load the custom weights. How can I run inference directly on a PIL image? When I tried, I got the following error:

[image]

All 13 comments

Hello @p9anand, thank you for your interest in 🚀 YOLOv5! Please visit our ⭐️ Tutorials to get started, where you can find quickstart guides for simple tasks like Custom Data Training all the way to advanced concepts like Hyperparameter Evolution.

If this is a 🐛 Bug Report, please provide screenshots and minimum viable code to reproduce your issue, otherwise we can not help you.

If this is a custom training ❓ Question, please provide as much information as possible, including dataset images, training logs, screenshots, and a public link to online W&B logging if available.

For business inquiries or professional support requests please visit https://www.ultralytics.com or email Glenn Jocher at [email protected].

Requirements

Python 3.8 or later with all requirements.txt dependencies installed, including torch>=1.7. To install run:

$ pip install -r requirements.txt

Environments

YOLOv5 may be run in any of the following up-to-date verified environments (with all dependencies including CUDA/CUDNN, Python and PyTorch preinstalled):

Status

CI CPU testing

If this badge is green, all YOLOv5 GitHub Actions Continuous Integration (CI) tests are currently passing. CI tests verify correct operation of YOLOv5 training (train.py), testing (test.py), inference (detect.py) and export (export.py) on MacOS, Windows, and Ubuntu every 24 hours and on every commit.

I think something like this should work:

model = torch.hub.load('ultralytics/yolov5', 'yolov5s', classes=6)
model.load_state_dict(torch.load('...')['model'].state_dict())

model = model.fuse().autoshape()

I tried that, but got the following error:

[image]

The custom weights and autoshape both loaded successfully.

It seems that this line

https://github.com/ultralytics/yolov5/blob/ba48f867ead1fcdc59f2a69fbe40c915bd4936bd/models/yolo.py#L192

didn't work as expected. Could you please upload your model so that I can debug the problem? (Sorry, I haven't trained any custom model.)

Hi @p9anand

I made a quick fix, as below:

import torch
from PIL import Image


def copy_attr(a, b, include=(), exclude=()):
    # Copy attributes from b to a, options to only include [...] and to exclude [...]
    for k, v in b.__dict__.items():
        if (len(include) and k not in include) or k.startswith('_') or k in exclude:
            continue
        else:
            setattr(a, k, v)

# my network is so slow that I couldn't download your model above :(
# so I just use the Ultralytics pretrained model here
model = torch.hub.load('ultralytics/yolov5', 'yolov5s', pretrained=False, classes=80)
checkpoint_ = torch.load('../yolov5-ultralytics/weights/yolov5s.pt')['model']
model.load_state_dict(checkpoint_.state_dict())

copy_attr(model, checkpoint_, include=('yaml', 'nc', 'hyp', 'names', 'stride'), exclude=())

model = model.fuse().autoshape()

img = Image.open('./notebooks/assets/zidane.jpg')
output = model(img)

print(f'prediction: {output.pred}')

Hope it can help you
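
To see what copy_attr() does in isolation, here is a minimal sketch that runs it on plain Python objects instead of real models. The SimpleNamespace stand-ins and their attribute values are hypothetical; only the function logic comes from the snippet above.

```python
from types import SimpleNamespace


def copy_attr(a, b, include=(), exclude=()):
    # Copy attributes from b to a. Skips private names, anything in `exclude`,
    # and (when `include` is non-empty) anything not listed in `include`.
    for k, v in b.__dict__.items():
        if (len(include) and k not in include) or k.startswith('_') or k in exclude:
            continue
        setattr(a, k, v)


# Hypothetical stand-ins for the freshly created hub model and the checkpoint model.
model = SimpleNamespace(nc=2)
checkpoint = SimpleNamespace(
    nc=2,
    names=['cat', 'dog'],
    stride=[8, 16, 32],
    _private=1,   # skipped: starts with an underscore
    extra='x',    # skipped: not in the include list
)

copy_attr(model, checkpoint, include=('yaml', 'nc', 'hyp', 'names', 'stride'))

print(model.names)
print(hasattr(model, 'extra'))
```

Attributes listed in include but absent from the checkpoint (here yaml and hyp) are simply not copied, so the call does not fail when a checkpoint lacks some of them.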

It worked. Thanks for the help!!

@zhiqwang would you be able to submit a PR with your updates to copy_attr()? Thanks!!

Hi @glenn-jocher It's my pleasure to contribute this awesome repo.

The copy_attr() function is actually copied from

https://github.com/ultralytics/yolov5/blob/ba48f867ead1fcdc59f2a69fbe40c915bd4936bd/utils/torch_utils.py#L199-L205

This issue only occurs when someone loads a trained model with torch.hub and sets pretrained to False. It works normally with

model = torch.hub.load('ultralytics/yolov5', 'yolov5s', pretrained=True, classes=80)

There is also a tutorial on torch.hub usage in #36, where you commented:

Load a State Dict

To load a custom state dict, first load a PyTorch Hub model of the same kind with the same number of classes:

model = torch.hub.load('ultralytics/yolov5', 'yolov5s', classes=10)
model.load_state_dict(torch.load('yolov5s_10cls.pt')['model'].state_dict())

So maybe it would be better to put the directions for using copy_attr() there?

@zhiqwang ah yes, I understand. We probably want to make changes in two places:

  1. tutorial with the changes you recommended
  2. yolo.py addition that automatically creates a default names attribute, to be used when people don't pass their own names in 1.
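
The second item could be sketched roughly as follows. This is only an illustration of the idea of a default names attribute, not necessarily what the eventual change to yolo.py does; default_names() and resolve_names() are hypothetical helpers, and using the class index as a placeholder label is an assumption.

```python
def default_names(nc):
    # Hypothetical helper: placeholder labels '0', '1', ..., one per class.
    return [str(i) for i in range(nc)]


def resolve_names(nc, names=None):
    # Use the caller's names when given, otherwise fall back to placeholders.
    if names is None:
        return default_names(nc)
    if len(names) != nc:
        raise ValueError(f'expected {nc} names, got {len(names)}')
    return list(names)


print(resolve_names(3))
print(resolve_names(2, ['cat', 'dog']))
```

With a fallback like this, code that reads model.names keeps working even when only a state dict was loaded, and real class names can still be assigned afterwards.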

@p9anand @zhiqwang I've updated the PyTorch Hub tutorial as follows and implemented a default class names list in PR #1608.

@p9anand can you confirm that the new tutorial directions work for you? They are here. I think names is the only attribute that was missing before.

Load a State Dict

To load a custom state dict, first load a PyTorch Hub model of the same kind with the same number of classes:

model = torch.hub.load('ultralytics/yolov5', 'yolov5s', classes=10)  # create model
ckpt = torch.load('yolov5s_10cls.pt')['model']  # load checkpoint
model.load_state_dict(ckpt.state_dict())
model.names = ckpt.names  # transfer class names (recommended)

@zhiqwang ah also, the .autoshape() method contains a .fuse() call inside it, so you can simply do model.autoshape().
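
The requirement above that the hub model be created with the same number of classes as the checkpoint comes down to tensor shapes: load_state_dict() needs every parameter to have the shape the model expects. A minimal sketch of that check on plain dictionaries follows. The layer names and shapes are illustrative only, find_shape_mismatches() is a hypothetical helper, and the head width assumes 3 anchors with 5 box/objectness values per anchor.

```python
def find_shape_mismatches(model_shapes, ckpt_shapes):
    # Return {key: (model_shape, ckpt_shape)} for every key whose shape differs
    # or that is missing on one side (reported as None).
    mismatches = {}
    for key in sorted(set(model_shapes) | set(ckpt_shapes)):
        m, c = model_shapes.get(key), ckpt_shapes.get(key)
        if m != c:
            mismatches[key] = (m, c)
    return mismatches


# With real models the shapes could be collected like this:
#   shapes = {k: tuple(v.shape) for k, v in model.state_dict().items()}
# Illustrative values: the head has 3 * (nc + 5) output channels (assumption above).
hub_80 = {'backbone.weight': (32, 3, 3, 3), 'head.weight': (3 * (80 + 5), 128, 1, 1)}
ckpt_10 = {'backbone.weight': (32, 3, 3, 3), 'head.weight': (3 * (10 + 5), 128, 1, 1)}

print(find_shape_mismatches(hub_80, ckpt_10))
```

Only the head differs in this example, which is why passing the matching classes value when creating the hub model is enough to make the state dict load.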

Hi @glenn-jocher, got it. It's cleaner now.

@glenn-jocher : It worked. Thanks for the help!!
