Describe the bug
When loading a model directly from a checkpoint I get the error "OSError: Checkpoint does not contain hyperparameters. Are your model hyperparameters stored in self.hparams?"
But my model clearly has the hparams.
To Reproduce
Just create a model, save a checkpoint, and try to load it as explained in the documentation:
pretrained_model = MyLightningModule.load_from_checkpoint(
checkpoint_path='/path/to/pytorch_checkpoint.ckpt'
)
Possible reason
I found this code in trainer_io.py, around line 301:
try:
    torch.save(checkpoint, filepath)
except AttributeError:
    if 'hparams' in checkpoint:
        del checkpoint['hparams']
    torch.save(checkpoint, filepath)
Obviously, if the code that saves the checkpoint deletes the hparams, the load function will not find them...
Expected behavior
A more concise way to load a checkpoint, without the need for the load_from_metrics function.
IIRC, that was a hack to work around an edge case where the hparams weren't picklable. It seems like the original ticket #433 is still open. @williamFalcon do we still need this hack?
Maybe the problem is about saving an object with a lambda function. I see this line in the issue log:
File "/private/home/falc/.local/lib/python3.7/site-packages/torch/serialization.py", line 224, in save
return _with_file_like(f, "wb", lambda f: _save(obj, f, pickle_module, pickle_protocol))
Pickle doesn't allow lambda functions to be saved, but if that's the reason, I believe it's an easy fix.
https://stackoverflow.com/questions/25348532/can-python-pickle-lambda-functions
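A two-line illustration of the underlying limitation (standalone, not Lightning code):

import pickle
# pickle refuses lambdas because they have no importable qualified name
pickle.dumps(lambda storage, loc: storage)  # raises PicklingError: Can't pickle <function <lambda> ...>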
I'm also having the same issue, but I'm not using any lambdas.
In general, lambda functions are not serializable, so all such items should be removed before saving.
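If someone wants to go that route, a rough sketch of such a filter (my own illustration, not code from the library) could look like this:

import pickle

def drop_unpicklable(hparams_dict):
    """Return a copy of hparams_dict without the entries pickle cannot serialize."""
    clean = {}
    for key, value in hparams_dict.items():
        try:
            pickle.dumps(value)
            clean[key] = value
        except Exception:
            print(f"dropping unpicklable hparam: {key}")
    return clean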
@wakandan @ricardorei @neggert interested in sending a PR?
I've encountered the same problem. Looks like the culprit is this line: https://github.com/williamFalcon/pytorch-lightning/blob/6666ca5af39aa2d3e5a483da3d7f6bb76514cc9f/pytorch_lightning/trainer/trainer_io.py#L321
After a bit of debugging I've figured out that the return of vars(hparams) actually contains all of the bound methods of hparams! This seems to produce a pickling error. In my case it goes like this:
AttributeError: Can't pickle local object 'ArgumentParser.__init__.<locals>.identity'
This exception in turn gets handled in https://github.com/williamFalcon/pytorch-lightning/blob/6666ca5af39aa2d3e5a483da3d7f6bb76514cc9f/pytorch_lightning/trainer/trainer_io.py#L264 by removing the hparams altogether.
Do we need this vars call at all? The TTNamespace that is normally used here is perfectly picklable on its own.
Ahh, good catch. So this works with an argparse.Namespace, but will fail with a TTNamespace? Want to send a PR that changes it to just pickle the Namespace directly and removes the hacky exception handling?
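For what it's worth, a hypothetical sketch of that change on the dump side (the surrounding variable names are my assumption, not the actual patch):

# before: vars() drags the TTNamespace's bound methods into the dict and breaks pickling
# checkpoint['hparams'] = vars(model.hparams)
# after: store the namespace object itself; both argparse.Namespace and TTNamespace pickle fine
checkpoint['hparams'] = model.hparams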
I just encountered the same issue. Does this mean I can't load the models I trained in the past couple of days, or is there some workaround until this is fixed?
If you still have access to hparams, here is a quick fix for load_from_checkpoint (I am not suggesting changing the method; this is just in case someone needs this functionality before it is fixed).
@classmethod
def load_from_checkpoint(cls, checkpoint_path, hparams, map_location=None):
    """
    Primary way of loading a model from a checkpoint.
    :param checkpoint_path: path to the checkpoint file
    :param hparams: Namespace of hyperparameters used to instantiate the model
    :param map_location: dict for mapping storage locations, e.g. {'cuda:1': 'cuda:0'}
    :return: model with the checkpoint weights loaded
    """
    if map_location is not None:
        checkpoint = torch.load(checkpoint_path, map_location=map_location)
    else:
        checkpoint = torch.load(checkpoint_path, map_location=lambda storage, loc: storage)

    # try:
    #     ckpt_hparams = checkpoint['hparams']
    # except KeyError:
    #     raise IOError(
    #         "Checkpoint does not contain hyperparameters. Are your model hyperparameters stored"
    #         "in self.hparams?"
    #     )
    # hparams = Namespace(**ckpt_hparams)

    # load the state_dict on the model automatically
    model = cls(hparams)
    model.load_state_dict(checkpoint['state_dict'])

    # give model a chance to load something
    model.on_load_checkpoint(checkpoint)
    return model
This is how to use it:
from model_utils.model_definitions.my_classifier import MyCoolModule
from argparse import Namespace
checkpoint_path='/home/verena/.../checkpoints/_ckpt_epoch_18.ckpt'
hparams = {
"batch_size":32,
...
}
namespace = Namespace(**hparams)
model = MyCoolModule.load_from_checkpoint(checkpoint_path=checkpoint_path, hparams=namespace)
I faced the same issue. Thanks @expectopatronum for the workaround, it helped me a lot.
Here's a solution that doesn't require modifying your model (from #599).
model = MyModel(whatever, args, you, want)
checkpoint = torch.load(checkpoint_path, map_location=lambda storage, loc: storage)
model.load_state_dict(checkpoint['state_dict'])
Hi guys, it seems that in my case the load_from_checkpoint function didn't load the params for me. I had to use the code posted above by @neggert instead (manually calling load_state_dict).
Hope it helps.
solved in 0.7.1
if not, we can reopen
@williamFalcon
Based on the documentation I found (this seemed most relevant), I'm not totally sure about the best practices for how saving/loading is supposed to work:
I initialized a trainer as follows (disabled tensorboard because it was erroring due to the TF dependency):
trainer = pl.Trainer(gpus=1, val_check_interval=0.25, use_amp=True, logger=False)
Then I observed that the models were saved in ./checkpoints/ by default, and thus assumed that when I restarted training from the same dir it would load the weights. But that was not the case (it seemed); instead I got this message:
/opt/conda/lib/python3.7/site-packages/pytorch_lightning/callbacks/model_checkpoint.py:82: UserWarning: Checkpoint directory ~/answerbot/accuracy/checkpoints exists and is not empty with save_top_k != 0.All files in this directory will be deleted when a checkpoint is saved!
What is the best practice for loading/saving a model with no logger? Is this possible? Thank you in advance.
@pertschuk As I remember, auto loading is disabled in the latest master. Can you check your lightning version?
@Ir1d 0.7.1
@pertschuk I believe that auto restoring is removed, and you should load the weights on your own. The doc you linked is not updated yet.
@Ir1d Is there a callback or function to override to integrate weight loading / saving with PL checkpointing?
For example I'm training a huggingface/transformers model and want to save checkpoints in that format.
@pertschuk sorry, I don't know the transformers model. You see, a PL checkpoint wraps up a lot of things; you can see them by calling dict.keys() on the loaded checkpoint. And you'll find that model.load_state_dict(dict['state_dict']) is exactly the weight loading for pure PyTorch. Hope this helps. Also, the above script by @neggert works perfectly for me.
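For instance, something along these lines (a minimal illustration; the checkpoint path and model are placeholders):

import torch

checkpoint = torch.load('/path/to/_ckpt_epoch_18.ckpt', map_location=lambda storage, loc: storage)
print(checkpoint.keys())  # e.g. epoch, global_step, state_dict, optimizer_states, ...
# 'state_dict' is a plain PyTorch state dict, so any module with matching parameter names loads it directly
model.load_state_dict(checkpoint['state_dict'])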
I also have been using lightning with pytorch transformers. I save checkpoints normally without changing anything in lightning.
If for some reason I need to resume training from a given checkpoint, I just use the resume_from_checkpoint Trainer argument.
If I just want to load weights from a pretrained model, I use the load_weights flag and call the function load_weights_from_checkpoint that is implemented in my "base" model.
parser = HyperOptArgumentParser(strategy="random_search", add_help=False)
parser.add_argument(
    "--resume_from_checkpoint",
    default=None,
    type=str,
    help=(
        "To resume training from a specific checkpoint pass in the path here."
        "(e.g. 'some/path/to/my_checkpoint.ckpt')"
    ),
)
parser.add_argument(
    "--load_weights",
    default=None,
    type=str,
    help=(
        "Loads the model weights from a given checkpoint while resume_from_checkpoint "
        "resumes the entire training session (model/optimizer/scheduler etc..). "
        "If architectures are different this will load only the common module parts."
    ),
)
.....
trainer = Trainer(
    logger=setup_testube_logger(),
    checkpoint_callback=True,
    early_stop_callback=early_stop_callback,
    default_save_path="experiments/",
    gradient_clip_val=hparams.gradient_clip_val,
    gpus=hparams.gpus,
    show_progress_bar=False,
    overfit_pct=hparams.overfit_pct,
    check_val_every_n_epoch=hparams.check_val_every_n_epoch,
    fast_dev_run=False,
    accumulate_grad_batches=hparams.accumulate_grad_batches,
    max_epochs=hparams.max_epochs,
    min_epochs=hparams.min_epochs,
    train_percent_check=hparams.train_percent_check,
    val_percent_check=hparams.val_percent_check,
    val_check_interval=hparams.val_check_interval,
    log_save_interval=hparams.log_save_interval,
    row_log_interval=hparams.row_log_interval,
    distributed_backend=hparams.distributed_backend,
    precision=hparams.precision,
    weights_summary=hparams.weights_summary,
    resume_from_checkpoint=hparams.resume_from_checkpoint,
    profiler=hparams.profiler,
    log_gpu_memory="all",
)
model = build_model(hparams)
if hparams.load_weights:
    model.load_weights_from_checkpoint(hparams.load_weights)
log.info(f"{model.__class__.__name__} train starting:")
trainer.fit(model)
My load_weights_from_checkpoint function:
def load_weights_from_checkpoint(self, checkpoint: str) -> None:
    """Function that loads the weights from a given checkpoint file.
    Note:
        If the checkpoint model architecture is different from `self`, only
        the common parts will be loaded.
    :param checkpoint: Path to the checkpoint containing the weights to be loaded.
    """
    log.info(f"loading model weights from {checkpoint}.")
    checkpoint = torch.load(checkpoint, map_location=lambda storage, loc: storage)
    pretrained_dict = checkpoint["state_dict"]
    model_dict = self.state_dict()
    # 1. filter out unnecessary keys
    pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
    # 2. overwrite entries in the existing state dict
    model_dict.update(pretrained_dict)
    # 3. load the updated state dict (not pretrained_dict, which may be missing keys)
    self.load_state_dict(model_dict)
Does this solve your problem of loading pre-trained weights and resuming training sessions?
PS: The initial purpose of this issue was solved some versions ago and it's now working.
@ricardorei yes it does, thank you
EDIT: this seems to be an apex/amp fp16 precision bug
Okay, sorry to keep posting here, but I have run into a VERY confusing issue and would appreciate any ideas for guidance @ricardorei. I am trying to export the models to save in huggingface/transformers format for reuse. The saved model appears to have a state_dict identical to that of the model wrapped in the PyTorch Lightning module, but the results of passing the same inputs through are not the same?
import os
os.makedirs('./test-1', exist_ok=True)
# model is PytorchLightning Module and model.model = Transformers model
model.model.save_pretrained('./test-1')
loaded_model = AlbertForSequenceClassification.from_pretrained('./test-1')
loaded_model.cuda()
for k, v in loaded_model.state_dict().items():
    assert torch.all(model.model.state_dict()[k].eq(v))  # this assert works

correct = 0
total = 0

def call_model(inputs, model):
    return model(inputs['input_ids'].cuda(),
                 token_type_ids=inputs['token_type_ids'].cuda(),
                 attention_mask=inputs['attention_mask'].cuda())[0]

for ex in get_data():
    label = 0 if ex['is_impossible'] else 1
    inputs = tokenizer.encode_plus(ex['question'],
                                   ex['context'],
                                   add_special_tokens=True,
                                   max_length=256,
                                   return_tensors='pt')
    lightning_logits = call_model(inputs, model.model)
    transformers_logits = call_model(inputs, loaded_model)
    assert torch.all(lightning_logits.eq(transformers_logits))  # this assert fails ???
Note: I also tried saving/loading the state_dict for the PyTorch Lightning module itself, and it's the same problem: the state dicts match up but give different outputs during inference? I'm totally lost.
@pertschuk you should check how big the difference is. I noticed some small differences when using big transformer models. I actually have an open issue about this in Lightning and in Fairseq:
https://github.com/pytorch/fairseq/issues/1605
https://github.com/PyTorchLightning/pytorch-lightning/issues/669
If the difference is really small this should not affect your results and is basically a precision issue.
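For example, a quick way to quantify the gap between the two outputs from the snippet above (just an illustration):

diff = (lightning_logits - transformers_logits).abs()
print(f"max abs diff: {diff.max().item():.3e}, mean abs diff: {diff.mean().item():.3e}")
# differences around 1e-6 are ordinary float32 noise; anything large points to a real mismatch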
@ricardorei unfortunately it was a very large error, but I fixed it by disabling mixed-precision training, in case anyone else finds this thread. Frustrating, as training is much slower now... but at least it works!
For those having the issue of not being able to load the model using the load_from_checkpoint method: I tried using the workaround here. After playing around, I noticed that there is a problem with the way the state_dict is being loaded. After using the code snippet from a post on PyTorch's forum, I managed to solve the problem.
To be more specific, the weights were not actually being loaded into the model, but there was no error message.
@williamFalcon Please take note of this.
I checked: the problem arises when we are using a self.model attribute to define our forward pass and also our parameters.
I encountered the same issue when using self.model.
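For reference, a sketch of the kind of key remapping that works around this (my own reconstruction, assuming the mismatch comes from the model. prefix that a self.model attribute adds to the checkpoint keys; checkpoint_path and bare_model are placeholders):

import torch

checkpoint = torch.load(checkpoint_path, map_location=lambda storage, loc: storage)
# a LightningModule that keeps the network in self.model saves keys as 'model.<param>';
# strip that prefix so they match the bare network's own state_dict
state_dict = {
    (k[len('model.'):] if k.startswith('model.') else k): v
    for k, v in checkpoint['state_dict'].items()
}
result = bare_model.load_state_dict(state_dict, strict=False)
print(result.missing_keys, result.unexpected_keys)  # sanity-check that the weights really loaded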
Here's a solution that doesn't require modifying your model (from #599).
model = MyModel(whatever, args, you, want)
checkpoint = torch.load(checkpoint_path, map_location=lambda storage, loc: storage)
model.load_state_dict(checkpoint['state_dict'])
For some reason, even after the fix, I am forced to use the quoted solution. The normal load_from_checkpoint function still gives me: pytorch_lightning.utilities.exceptions.MisconfigurationException: Checkpoint contains hyperparameters but MyModule's __init__ is missing the argument 'hparams'. Are you loading the correct checkpoint?
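(For anyone else hitting this exception: the error text itself says the checkpoint's hparams are passed back into your __init__, so the module has to accept them. A minimal sketch under that assumption, with MyModule and the hyperparameter names as stand-ins:)

import torch
import pytorch_lightning as pl

class MyModule(pl.LightningModule):
    def __init__(self, hparams):
        super().__init__()
        self.hparams = hparams  # keeping the namespace on self.hparams is what gets it stored in the checkpoint
        self.layer = torch.nn.Linear(hparams.input_dim, hparams.output_dim)

    def forward(self, x):
        return self.layer(x)

# with that signature in place, loading works again:
# model = MyModule.load_from_checkpoint('/path/to/checkpoint.ckpt')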
solved in 0.7.1
if not, we can reopen
My version of PL is 0.7.6
Mind trying v0.8rc1 or the latest master?