Vision: [feature request] Add optional 'directory/path' parameter to save available pretrained models.

Created on 3 Oct 2018 · 6 Comments · Source: pytorch/vision

I may be wrong but I could not find any simple way of specifying where the pre-trained models from torchvision.models are saved.

By default, they are saved in the ...\.torch\models\ directory.
It would be very helpful to have an optional parameter to set the download path for the pre-trained models. (something like model_dir parameter in torch.utils.model_zoo.load_url())

If there is an existing way to do this already, please correct me in the comments.

enhancement models needs discussion

All 6 comments

The documentation is a bit confusing in this regard, to be honest. As documented here (the same function you mentioned), the TORCH_HOME environment variable can be used for this. Since torchvision uses this function (e.g., here) models will be downloaded to $TORCH_HOME/models.

Still, it feels like a workaround. I agree that a model_dir argument would be nice in this case.

EDIT: I forgot to mention that the TORCH_MODEL_ZOO environment variable can override the above, as can be seen here.
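
To make the precedence described above concrete, here is a minimal sketch of the lookup order as this comment describes it. The helper name `resolve_model_dir` is hypothetical and is not part of torch; the `~/.torch` fallback is taken from the default location mentioned in the issue. Later PyTorch releases may use a different directory layout, so treat this as an illustration of the described behaviour, not as the library's code.

```python
import os

def resolve_model_dir(env=None):
    """Mimic the lookup order described in the thread (illustrative helper, not a torch API)."""
    env = os.environ if env is None else env
    # TORCH_HOME falls back to ~/.torch, the default location mentioned in the issue.
    torch_home = os.path.expanduser(env.get("TORCH_HOME", os.path.join("~", ".torch")))
    # TORCH_MODEL_ZOO, when set, takes precedence over $TORCH_HOME/models.
    return env.get("TORCH_MODEL_ZOO", os.path.join(torch_home, "models"))

# Only TORCH_HOME set: models go to its "models" subdirectory.
print(resolve_model_dir({"TORCH_HOME": "/data/torch"}))
# Both set: TORCH_MODEL_ZOO wins.
print(resolve_model_dir({"TORCH_HOME": "/data/torch", "TORCH_MODEL_ZOO": "/data/zoo"}))
```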

Thanks for the feedback,
I agree that having better documentation for the TORCH_MODEL_ZOO env var would be great. Could you send a PR improving the docs?

Thanks!

Thanks a lot @fadel.
This is exactly what I was looking for.

Also, the changes made to the docs in the above pull request aren't reflected on the documentation website yet: https://pytorch.org/docs/stable/torchvision/models.html

@fmassa Could you please tell me whether having model_dir directly as a parameter in torchvision.models, alongside pretrained, would be a good idea, or would that be counter-intuitive?

(I can send a PR to do the same if that is required.)

Has the model_dir parameter been added yet? If not, it would be a great addition.

@rsk2327 I had submitted a PR regarding this but as you can see it is still open. Meanwhile, you can save it in the desired directory using the environment variable TORCH_HOME.
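
One way to apply the TORCH_HOME suggestion to a single call, without changing the variable for the whole process, is a small context manager. This is only a sketch: `torch_home` is a hypothetical helper, not a torch or torchvision API, and it assumes the variable is read at the moment the weights are requested. If your version reads it at import time, set it before importing instead.

```python
import os
from contextlib import contextmanager

@contextmanager
def torch_home(path):
    """Temporarily point TORCH_HOME at `path` (hypothetical helper, not a torch API)."""
    previous = os.environ.get("TORCH_HOME")
    os.environ["TORCH_HOME"] = path
    try:
        yield path
    finally:
        # Restore whatever was set before so other code is unaffected.
        if previous is None:
            os.environ.pop("TORCH_HOME", None)
        else:
            os.environ["TORCH_HOME"] = previous

with torch_home("/tmp/my_models"):
    # Inside this block one would request the pretrained model, e.g.
    # torchvision.models.vgg11(pretrained=True); the call is left out
    # so the sketch runs without torch installed.
    print(os.environ["TORCH_HOME"])
```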

I have just come from issue #2299, and I am inclined to agree with adding an extra parameter that controls the model path for both downloading and loading.

As @fmassa explained in #2299, using TORCH_HOME makes our code agnostic to the model path. After thinking it over, I find this convenient for novices, since users do not need to think about the model path when downloading or loading.

But it is much less flexible for those familiar with PyTorch and Python.

For this reason, I have worked out a compromise as a temporary solution. (Much of the code follows the torchvision source.)

For downloading:

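from urllib.parse import urlparse
import os
import re

import torch.hub

# Matches the hash embedded in file names such as 'vgg11-bbd30ac9.pth'.
HASH_REGEX = re.compile(r'-([a-f0-9]*)\.')

def download_model(url, dst_path):
    """Download `url` into `dst_path` and return the file name."""
    filename = os.path.basename(urlparse(url).path)
    match = HASH_REGEX.search(filename)
    hash_prefix = match.group(1) if match else None

    # Public counterpart of the private model_zoo._download_url_to_file used
    # originally; this assumes a PyTorch version that ships torch.hub.
    torch.hub.download_url_to_file(url, os.path.join(dst_path, filename), hash_prefix, progress=True)
    return filename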

Downloading demo

model_urls = {
    'vgg11': 'https://download.pytorch.org/models/vgg11-bbd30ac9.pth',
    'vgg13': 'https://download.pytorch.org/models/vgg13-c768596a.pth',
    'vgg16': 'https://download.pytorch.org/models/vgg16-397923af.pth',
    'vgg19': 'https://download.pytorch.org/models/vgg19-dcbb9e9d.pth',
    'vgg11_bn': 'https://download.pytorch.org/models/vgg11_bn-6002323d.pth',
    'vgg13_bn': 'https://download.pytorch.org/models/vgg13_bn-abd245e5.pth',
    'vgg16_bn': 'https://download.pytorch.org/models/vgg16_bn-6c64b313.pth',
    'vgg19_bn': 'https://download.pytorch.org/models/vgg19_bn-c79401a0.pth',
}
path = 'D:/Software/DataSet/models/vgg'
os.makedirs(path, exist_ok=True)  # create the target directory if it is missing
for url in model_urls.values():
    download_model(url, path)
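
The hash prefix that download_model extracts from the file name can also be used to check a file that is already on disk, for example after copying it between machines. The sketch below assumes the prefix is the start of the file's SHA-256 digest; `matches_name_hash` is a hypothetical helper, not part of torch.

```python
import hashlib
import os
import re

# Same pattern as in download_model: the hash sits between '-' and '.'.
HASH_REGEX = re.compile(r'-([a-f0-9]*)\.')

def matches_name_hash(path):
    """Return True if the file's SHA-256 digest starts with the hash in its name."""
    found = HASH_REGEX.search(os.path.basename(path))
    if not found or not found.group(1):
        raise ValueError('no hash prefix in file name: %s' % path)
    sha256 = hashlib.sha256()
    with open(path, 'rb') as f:
        # Read in chunks so large checkpoints do not have to fit in memory.
        for chunk in iter(lambda: f.read(8192), b''):
            sha256.update(chunk)
    return sha256.hexdigest().startswith(found.group(1))
```

A False result would suggest a truncated or corrupted file that should be downloaded again.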

For loading:

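import glob
import os

import torch
import torchvision.models as models

def load_model(model_name, model_dir):
    # Look the constructor up by name instead of using eval().
    # init_weights=False skips random initialisation; the VGG constructors accept it.
    model = getattr(models, model_name)(init_weights=False)
    pattern = os.path.join(model_dir, '%s-[a-z0-9]*.pth' % model_name)

    matches = glob.glob(pattern)
    if not matches:
        raise FileNotFoundError('no checkpoint matching %s' % pattern)

    model.load_state_dict(torch.load(matches[0], map_location='cpu'))
    return model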

Loading demo

model_dir = 'D:/Software/DataSet/models/vgg/'
model = load_model('vgg11', model_dir)

Hope it helps!
