PyTorch Lightning: load checkpoint from URL

Created on 20 Apr 2020 · 5 comments · Source: PyTorchLightning/pytorch-lightning

Let's enable loading weights directly from a URL.

Option 1:

Automate it with our current API:

Trainer.load_from_checkpoint('http://')

Option 2:

Have a separate method:

Trainer.load_from_checkpoint_at_url('http://')

Resources

We can use torch.hub.load_state_dict_from_url under the hood:
https://pytorch.org/docs/stable/hub.html#torch.hub.load_state_dict_from_url
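A minimal sketch of Option 1, assuming a standalone helper (the name load_checkpoint and its url_loader/file_loader parameters are hypothetical, not Lightning API). The path string is inspected: http(s) URLs go through torch.hub.load_state_dict_from_url, which downloads the file into the torch cache and then calls torch.load on it, and anything else goes to torch.load directly. The loaders are injectable so the dispatch can be tested without network access or torch installed.

```python
from typing import Any, Callable, Optional
from urllib.parse import urlparse


def load_checkpoint(
    path: str,
    map_location: Any = None,
    url_loader: Optional[Callable[..., Any]] = None,
    file_loader: Optional[Callable[..., Any]] = None,
) -> Any:
    """Load a checkpoint from a local path or an http(s) URL (hypothetical helper)."""
    if urlparse(path).scheme in ("http", "https"):
        if url_loader is None:
            # Downloads into the torch hub cache, then calls torch.load on the file.
            from torch.hub import load_state_dict_from_url as url_loader
        return url_loader(path, map_location=map_location)
    if file_loader is None:
        # Plain local file: no download step needed.
        from torch import load as file_loader
    return file_loader(path, map_location=map_location)
```

Option 2 would amount to exposing the URL branch as its own method, so callers opt in explicitly instead of relying on inspection of the path string.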

Any thoughts on which one is better?
@PyTorchLightning/core-contributors

Labels: enhancement, good first issue, help wanted, let's do it!

All 5 comments

I recently had to implement similar functionality and found this library very useful: https://github.com/RaRe-Technologies/smart_open. It handles protocols other than http(s), such as S3, through a very simple interface similar to Python's built-in open. Thought you might be interested.

Slight preference for Option 1.

@yukw777 Nice. Want to submit a PR?
That seems like overkill for a first version, no? We want to keep this simple.

Maybe we do a v1 that supports only http and https, and expand to other protocols in v2?
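For a v1 limited to http and https, one option (an assumption, not something decided in the thread) is to reject other URL schemes explicitly, so that a path like s3://bucket/model.ckpt fails with a clear error instead of being handed to torch.load as a local file name. The helper name is_remote_checkpoint is hypothetical.

```python
from urllib.parse import urlparse

# v1 scope from the discussion: only these remote schemes are accepted.
SUPPORTED_SCHEMES = ("http", "https")


def is_remote_checkpoint(path: str) -> bool:
    """Return True for http(s) URLs and False for local paths.

    Raises ValueError for any other URL scheme (e.g. s3://, gs://) so it
    fails loudly rather than being treated as a local file name.
    """
    scheme = urlparse(path).scheme.lower()
    if scheme in SUPPORTED_SCHEMES:
        return True
    # No scheme, or a one-letter "scheme" that is really a Windows drive (C:\...).
    if scheme == "" or len(scheme) == 1:
        return False
    raise ValueError(
        f"Unsupported checkpoint URL scheme {scheme!r}; "
        f"only {', '.join(SUPPORTED_SCHEMES)} are supported for now"
    )
```

Supporting more protocols in v2 would then mean widening SUPPORTED_SCHEMES and adding a download path for each new scheme.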

@williamFalcon Yeah, I can give it a shot. I just read the code of load_state_dict_from_url, and it seems we don't really need smart_open for this, so I'll proceed without it. I also prefer Option 1.

A few questions:

  1. Should we provide a default directory to which we'd download the checkpoints? If so, where?
  2. What's your recommendation for writing tests for this? It would be simplest to upload a test checkpoint file somewhere and just download that...
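On question 2, an alternative to uploading a test checkpoint somewhere (a suggestion, not something agreed in the thread) is to serve a temporary directory from a local HTTP server inside the test, using only the standard library, so the test needs no external network. The serve_directory helper below is hypothetical.

```python
import contextlib
import functools
import http.server
import threading
from typing import Iterator


class _QuietHandler(http.server.SimpleHTTPRequestHandler):
    """Static file handler that does not log every request to stderr."""

    def log_message(self, format, *args):
        pass


@contextlib.contextmanager
def serve_directory(directory: str) -> Iterator[str]:
    """Serve `directory` on a free localhost port and yield its base URL."""
    handler = functools.partial(_QuietHandler, directory=directory)
    # Port 0 lets the OS pick an unused port, so parallel tests don't collide.
    server = http.server.ThreadingHTTPServer(("127.0.0.1", 0), handler)
    thread = threading.Thread(target=server.serve_forever, daemon=True)
    thread.start()
    try:
        host, port = server.server_address[:2]
        yield f"http://{host}:{port}"
    finally:
        server.shutdown()
        server.server_close()
        thread.join()
```

A real test would torch.save a tiny checkpoint into a temporary directory, enter serve_directory on it, and pass base_url + "/tiny.ckpt" to the loader under test.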

Maybe we could also use the torch cache for that, since after all it's still a PyTorch model? Or do we want to create yet another cache?
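Reusing the torch cache would mean downloaded checkpoints land where torch.hub.load_state_dict_from_url puts them by default: the hub directory's checkpoints subfolder, under the basename of the URL path. The stdlib sketch below mirrors that default layout for illustration only (both function names are hypothetical); real code should call torch.hub.get_dir(), which also honors torch.hub.set_dir().

```python
import os
from typing import Mapping, Optional
from urllib.parse import urlparse


def default_checkpoint_dir(env: Optional[Mapping[str, str]] = None) -> str:
    """Mirror torch.hub's default download dir: <TORCH_HOME>/hub/checkpoints.

    TORCH_HOME falls back to $XDG_CACHE_HOME/torch, then ~/.cache/torch.
    This ignores torch.hub.set_dir(); real code should use torch.hub.get_dir().
    """
    env = os.environ if env is None else env
    torch_home = env.get(
        "TORCH_HOME", os.path.join(env.get("XDG_CACHE_HOME", "~/.cache"), "torch")
    )
    return os.path.join(os.path.expanduser(torch_home), "hub", "checkpoints")


def cached_checkpoint_path(url: str, env: Optional[Mapping[str, str]] = None) -> str:
    """Cache location for a downloaded checkpoint: the URL path's basename."""
    filename = os.path.basename(urlparse(url).path)
    return os.path.join(default_checkpoint_dir(env), filename)
```

One consequence of this layout is that two URLs ending in the same file name share a cache entry; load_state_dict_from_url's file_name argument can be used to override the cached name.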
