Vision: [Discussion] squeezenet model; different num of classes

Created on 22 May 2017 · 8 comments · Source: pytorch/vision

For training from scratch, the code works fine. But most people will use the pre-trained weights, so on top of setting self.num_classes you also have to replace the final_conv layer after loading the pre-trained weights:

        final_conv = nn.Conv2d(512, self.num_classes, kernel_size=1)
        self.classifier = nn.Sequential(
            nn.Dropout(p=0.5),
            final_conv,
            nn.ReLU(inplace=True),
            nn.AvgPool2d(13))

While this may be straightforward in other models, here final_conv (a local variable in __init__) actually lives inside a Sequential module. When I tried to change only the final_conv layer inside self.classifier (by writing classifier[1] = nn.Conv2d(...) etc.), I got the error 'Sequential object does not support item assignment'.

Eventually I just replaced the whole classifier block with a new nn.Sequential module containing a conv layer with the new number of classes.
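A minimal sketch of that workaround, assuming a pre-trained SqueezeNet 1.1 and a hypothetical new class count of 10:

    import torch.nn as nn
    from torchvision import models

    num_classes = 10  # hypothetical new class count
    model = models.squeezenet1_1(pretrained=True)

    # Rebuild the whole classifier block; only final_conv changes shape.
    final_conv = nn.Conv2d(512, num_classes, kernel_size=1)
    model.classifier = nn.Sequential(
        nn.Dropout(p=0.5),
        final_conv,
        nn.ReLU(inplace=True),
        nn.AvgPool2d(13))
    model.num_classes = num_classes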

A few things to ask/discuss:

  1. Am I making some mistake in trying to change only the final_conv layer inside the classifier Sequential module?
  2. If not, should we change that part so that it's easier to swap out just the final_conv layer?
  3. If that isn't preferred, should I just add a comment somewhere so that people know they have to replace the classifier itself?
  4. People should be able to figure it out, this is minor :P

The 3rd option makes sense, since they might also have to change the AvgPool kernel size; I had removed it and was using the functional API for the pooling, so I didn't realize that until now.
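For reference, the functional equivalent of that AvgPool2d layer would be something along these lines inside forward():

    import torch.nn.functional as F

    # instead of keeping nn.AvgPool2d(13) inside the classifier
    x = F.avg_pool2d(x, kernel_size=13)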

PS: Wasn't sure if this should go in Issues or on the discuss forum. If the latter is better, I will remove it and post there.

Most helpful comment

@fmassa Your example doesn't work because it doesn't handle different parameter sizes. Also, pretrained_dict will have fewer or the same number of items as model_dict, and passing fewer items to load_state_dict() will raise an error by design.

I pruned a working example down to this:

    # Update the pretrained weights with the model's own tensors wherever the
    # shapes differ (e.g. the resized final_conv), then load the merged dict.
    model_dict = model.state_dict()
    pretrained_dict = model_zoo.load_url(model_urls[name])

    diff = {k: v for k, v in model_dict.items()
            if k in pretrained_dict and pretrained_dict[k].size() != v.size()}

    pretrained_dict.update(diff)

    model.load_state_dict(pretrained_dict)
    return model, diff

With the model loading code, do you mean adding a method to nn.Module or a helper function to torchvision? I'd really like to see update_state_dict() on nn.Module (my second suggestion above); then there is no need to repeat it for text, sound, you name it.

All 8 comments

Adding to point 2: would converting it to an ordered dict help?

I just realized this is the structure for all the models in the repo: a features Sequential module and a classifier Sequential module. So this issue applies everywhere, and people would know to change the classifier, I guess?

I just submitted this PR with a rough version of a function that can change the number of outputs for various networks, including SqueezeNet:

https://github.com/pytorch/vision/pull/175

Yeah, looks useful! Although I think people would prefer to do this themselves? Not sure. Let's see what comments you receive on the PR. Also, can you add a block comment in the function so that people don't have to read the code to understand what it does?

Ack, being able to use num_classes and pretrained=True at the same time for every model would be a big win. All the constructor functions instantiate pretrained models in two steps (build the module, then load the weights). Therefore, the approach was to update the state_dict() only where the size() of a loaded parameter matches the size() of the corresponding parameter in the vanilla module instantiated with the **kwargs. You can reuse this diff to freeze layers too (see the sketch below).
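A hypothetical sketch of that freezing idea, assuming model and diff come from the snippet quoted in the most helpful comment above (only the re-initialized, size-mismatched parameters stay trainable):

    import torch.optim as optim

    # Freeze everything except the parameters whose shapes differed from the
    # pretrained checkpoint, i.e. the keys collected in `diff`.
    for name, param in model.named_parameters():
        param.requires_grad = name in diff

    # Hand only the still-trainable parameters to the optimizer.
    optimizer = optim.SGD((p for p in model.parameters() if p.requires_grad), lr=0.001)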

Both the state_dict-merging version and @danielhauagge's programmatic version loaded the whole model zoo: https://github.com/ahirner/pytorch-retraining/blob/test_pr/load_pretrained.ipynb
Merging at the level of the state_dict() doesn't assume that layers have specific types, though.

Should I make a PR?

PS: if so, what would be a nicer place / API?

    # update_model(model, state_dict) in torchvision/models/utils.py
    model = AlexNet(**kwargs)
    if pretrained:
        update_model(model, model_zoo.load_url(model_urls['alexnet']))
    return model

or, because this will also work for models beyond torchvision:

    # update_state_dict(self, state_dict, diff=False) for nn.Module
    model = AlexNet(**kwargs)
    if pretrained:
        diff_keys = model.update_state_dict(model_zoo.load_url(model_urls['alexnet']))
        return (model, diff_keys) if diff else model
    return model
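For concreteness, a rough sketch of what such an update_state_dict() helper could do, written here as a standalone function rather than an actual nn.Module method; the name and the returned diff keys follow the proposal above:

    def update_state_dict(module, state_dict):
        """Load state_dict into module, keeping the module's own tensors for keys
        that are missing or whose shapes don't match; return those keys."""
        own_state = module.state_dict()
        merged, diff_keys = {}, []
        for k, v in own_state.items():
            if k in state_dict and state_dict[k].size() == v.size():
                merged[k] = state_dict[k]
            else:
                merged[k] = v  # keep the freshly initialized tensor
                diff_keys.append(k)
        module.load_state_dict(merged)
        return diff_keys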

I think we should change the model-loading code to handle the case where the model differs from the saved weights (allowing the structure of the network to be modified).
Here is an example:

def _update_with_pretrained_weights(model, model_path):
    model_dict = model.state_dict()
    pretrained_dict = torch.load(model_path)

    # 1. filter out unnecessary keys
    pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
    # 2. overwrite entries in the existing state dict
    model_dict.update(pretrained_dict)
    # 3. load the new state dict
    model.load_state_dict(model_dict)

@fmassa Your example doesn't work because it doesn't handle different parameter sizes. Also, pretrained_dict will have fewer or the same number of items as model_dict, and passing fewer items to load_state_dict() will raise an error by design.

I pruned a working example down to this:

    # Update the pretrained weights with the model's own tensors wherever the
    # shapes differ (e.g. the resized final_conv), then load the merged dict.
    model_dict = model.state_dict()
    pretrained_dict = model_zoo.load_url(model_urls[name])

    diff = {k: v for k, v in model_dict.items()
            if k in pretrained_dict and pretrained_dict[k].size() != v.size()}

    pretrained_dict.update(diff)

    model.load_state_dict(pretrained_dict)
    return model, diff

With the model loading code, do you mean adding a method to nn.Module or a helper function to torchvision? I'd really like to see update_state_dict() on nn.Module (my second suggestion above); then there is no need to repeat it for text, sound, you name it.


Hi @ahirner,
to freeze some of the weights using the layer names (the keys in state_dict), what should I do?

