Pytorch-lightning: Error in `load_from_checkpoint` when LightningModule init contains 'hparams'

Created on 8 Oct 2020 · 25 comments · Source: PyTorchLightning/pytorch-lightning

❓ Questions and Help

What is your question?

Just pulled master today, and load_from_checkpoint no longer works. I am wondering if this is a backwards compatibility issue, or if I need to do something differently to load now. I've provided the HEAD commits of the local repo I'm installing from in the environment section below.

Error in loading

Here's the error I'm getting now without changing any code:

File "/home/chirag/miniconda3/envs/ml/lib/python3.8/site-packages/pytorch_lightning/core/saving.py", line 153, in load_from_checkpoint
    model = cls._load_model_state(checkpoint, strict=strict, **kwargs)
  File "/home/chirag/miniconda3/envs/ml/lib/python3.8/site-packages/pytorch_lightning/core/saving.py", line 193, in _load_model_state
    model = cls(**_cls_kwargs)
TypeError: __init__() missing 1 required positional argument: 'hparams'

My LightningModule has always had an hparams member variable, so everything should have been saved (and loading did indeed work before this pull), in accordance with the docs on stable: https://pytorch-lightning.readthedocs.io/en/stable/hyperparameters.html#lightningmodule-hyperparameters

What's your environment?

❯ git rev-parse --short HEAD                                             
1d3c7dc8    # DOES NOT WORK

❯ git rev-parse --short HEAD@{1}                                          
90929fa4    # WORKS OKAY

All 25 comments

can you put your LightningModule init code and load_from_checkpoint line here?

Hi Rohit,

Sure. Here's the init code:

class ProcessSystemBase(pl.LightningModule):

    def __init__(self, hparams: Namespace, builder: RecurrentBuilder) -> None:
        """ Initialize the Process module.

        Args:
            hparams -- The hyperparameters for the experiment
            builder -- The factory to build backbone components

        """
        super().__init__()

        ##----------------------------------------
        # Hack to combat lightning loading hparams as dicts
        # Refer issue #924 
        if isinstance(hparams, dict):
            hparams = Namespace(**hparams)
        ##----------------------------------------
        self.hparams = hparams
        self.components = builder.init_components(hparams)

I have a factory object called builder that handles custom logic for initializing components, but I'm sure its details aren't relevant here. The conversion to Namespace is needed following #924, where the hparams type is not saved in the checkpoint.

And here is how I load it:

builder = RecurrentBuilder()
process = ProcessSystemBase.load_from_checkpoint(
        str(ckpt_path), builder=builder
)
process.freeze()

ckpt_path is a pathlib.Path object containing the filepath to the checkpoint. This worked before my pull, as shown by the rev-parse outputs in the original post.

Thanks for taking a look @rohitgr7

@rohitgr7 did you manage to find anything on this? I tested again by retraining and reloading and still run into the same problem, so it doesn't seem like a backwards compatibility issue anymore.

Ah sorry, this issue totally slipped my mind. Can you reproduce it with the bug_report_model?
It would be easier to check the issue :)

I was blocked on this, so I dove in and found the issue. It's essentially in these lines (also on master, so I assume this applies to 1.0 as well):

https://github.com/PyTorchLightning/pytorch-lightning/blob/01402e35948620e4ff77922ec30becf1f9b2d2fb/pytorch_lightning/core/saving.py#L189-L191

So you see here that when _cls_kwargs is filtered, the remaining keys are not kept and passed on as hparams. In my case above, cls_init_args_name contains ['hparams', 'builder'] as expected, but after filtering, the kwargs only contain the key ['builder']. I fixed this locally by explicitly checking whether cls_init_args_name contains hparams and passing the rest of the keys back in as the hparams dict. Ideally this would also be converted to a Namespace if that's how it was saved (see #924, which is still a problem anyway; my checkpoint didn't have the hparams type saved). So here is my inelegant fix (it's ugly and meant for debugging; if you want me to open a PR I will come up with a better solution):

    if not cls_spec.varkw:
        # filter kwargs according to class init unless it allows any argument via kwargs
        _cls_kwargs_filtered = {k: v for k, v in _cls_kwargs.items() if k in cls_init_args_name}
        if "hparams" in cls_init_args_name:
            # pass everything that was filtered out back in as the hparams dict
            _cls_hparam_kwargs = {k: v for k, v in _cls_kwargs.items() if k not in cls_init_args_name}
            _cls_kwargs_filtered["hparams"] = _cls_hparam_kwargs
    else:
        # the init accepts **kwargs, so nothing needs to be filtered
        _cls_kwargs_filtered = _cls_kwargs

    model = cls(**_cls_kwargs_filtered)

So here is my question @rohitgr7: the docs suggest that assigning anything to hparams in init means it gets automatically saved. Am I mistaken in assuming that this implies it will also be loaded correctly and assigned back to hparams? If that is indeed the contract, I think the problem is that the loading code doesn't handle it at all. The bug_report_model doesn't have an hparams init arg, so you won't be able to see this problem there as-is.

@chiragraman I have created a simple version of your workflow; can you please check whether this is what you are trying to do? If yes, I am not getting any error with it, so maybe there is some backward compatibility issue in the new update.
https://gist.github.com/rohitgr7/f8637d73ed0840e771af8757ba8f34c5

@rohitgr7 thanks for this! I ran it and it matches the behavior you're seeing, so I went back and debugged a bit more. The only difference I found between my actual code and the simple version you created is that the value for hparams_name in my checkpoint is missing, while in the simple version it correctly contains the value 'hparams':

# My code:
(Pdb) checkpoint["hparams_name"]
(Pdb)

# Simple setup
(Pdb) checkpoint["hparams_name"]
'hparams'
(Pdb)

The debugger here is in the _load_model_state method. All this definitely makes things a lot clearer. Because there is no hparams_name value in the dictionary, the problem arises from here:

https://github.com/PyTorchLightning/pytorch-lightning/blob/01402e35948620e4ff77922ec30becf1f9b2d2fb/pytorch_lightning/core/saving.py#L181-L183

This is where the reassignment I'm doing in my debug fix above is supposed to happen, but that if check on args_name fails and the whole block is skipped. I currently have no idea why the key exists but the value is empty in the checkpoint. If this further isolation helps and triggers an epiphany, do let me know :) Thanks for the help, Rohit.

@rohitgr7 Yo Rohit, I found the problem and can reproduce it with the bug report model. The issue is that saving the value for cls.CHECKPOINT_HYPER_PARAMS_NAME to the checkpoint fails for subclassed LightningModules.

The hparams_name is set by looking for ".hparams" in the class spec. This will obviously fail if your LightningModule is subclassed from a parent LightningModule that handles the assignment of the hparams member variable in the parent __init__. Consequently, when the child is loaded from a checkpoint, the key for cls.CHECKPOINT_HYPER_PARAMS_NAME exists in the checkpoint but has no value.
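For illustration, here is a minimal sketch of the pattern (roughly the shape of the debug model I link below, with details trimmed):

import pytorch_lightning as pl

# The assignment to self.hparams happens only in the parent's __init__, so the
# child's class source never contains ".hparams" and hparams_name ends up empty.
class LitModel(pl.LightningModule):
    def __init__(self, hparams, another_param):
        super().__init__()
        self.hparams = hparams          # assignment lives only in the parent
        self.another_param = another_param

class ChildLitModel(LitModel):
    def __init__(self, hparams, another_param):
        super().__init__(hparams, another_param)   # no ".hparams" in this class body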

I don't know whether subclassing your custom LightningModule is unsupported by design? I think the right encapsulation sometimes requires it, such as in my case, and most probably whenever someone writes code intended to be used by other developers in a real-world production setting. If it is supported, I think this part would have to be updated to detect whether .hparams has been assigned in a more elegant way:

https://github.com/PyTorchLightning/pytorch-lightning/blob/5c1eff351b035db5881d0cff81b1d9c9e150e2d0/pytorch_lightning/core/lightning.py#L1631-L1636

(maybe @williamFalcon can weigh in on the design goals question?)


So for a quick recap, here is the error in loading:

File "/home/chirag/miniconda3/envs/ml/lib/python3.8/site-packages/pytorch_lightning/core/saving.py", line 197, in _load_model_state
    model = cls(**_cls_kwargs_filtered)
TypeError: __init__() missing 1 required positional argument: 'hparams'

And here is the debug model code to reproduce it: https://gist.github.com/chiragraman/ff0e6b0239cb635df30093dd31ae3dce

@chiragraman which version of Lightning are you running your bug report model on?

@rohitgr7 Just pulled master yesterday, I'm updated to 1.0.1

I tried your code on both master and 1.0.1, but I'm not getting any error. Am I missing something here?

Hmm, just ran it again and it crashed again. Here's the HEAD for my local lightning repo:

❯ git rev-parse --short HEAD                                      
f967fbba

Hmm, if you add a breakpoint right before this,

https://github.com/PyTorchLightning/pytorch-lightning/blob/155f4e9aa5d6175f8d3db23d483154b065412cd0/pytorch_lightning/core/saving.py#L194

can you verify if the checkpoint contains any value for "hparams_name"? In the meantime, let me pull master again and retry.
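For reference, here is roughly how I'm inspecting the checkpoint outside the debugger (a plain torch.load of the file; the "hyper_parameters" key name is my assumption about where the values themselves live, and it may differ between versions):

import torch

# Load the raw checkpoint dict and look at the hparams-related entries.
checkpoint = torch.load(ckpt_path, map_location="cpu")
print(checkpoint.get("hparams_name"))      # empty on my side, 'hparams' on yours
print(checkpoint.get("hyper_parameters"))  # the saved hyperparameter values, if any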

can you verify if the checkpoint contains any value for "hparams_name"

Yes, it does: 'hparams'. Can you create a colab maybe? It would be better, I guess, to see what's happening here.

Hmm, just updated master so that my local HEAD is at 155f4e9a, and I can confirm it still crashes. If it helps at all, here is the entire console dump. I also can't see how this could work, given that .hparams doesn't appear in the lines parsed for ChildLitModel; this:

https://github.com/PyTorchLightning/pytorch-lightning/blob/5c1eff351b035db5881d0cff81b1d9c9e150e2d0/pytorch_lightning/core/lightning.py#L1635

has to fail, and the function would return None.

GPU available: True, used: False
TPU available: False, using: 0 TPU cores
/home/chirag/miniconda3/envs/ml/lib/python3.8/site-packages/pytorch_lightning/utilities/distributed.py:45: UserWarning: GPU available but not used. Set the --gpus flag when calling the script.
  warnings.warn(*args, **kwargs)

  | Name | Type   | Params
--------------------------------
0 | l1   | Linear | 7 K   
/home/chirag/miniconda3/envs/ml/lib/python3.8/site-packages/pytorch_lightning/utilities/distributed.py:45: UserWarning: The dataloader, val dataloader 0, does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` (try 8 which is the number of cpus on this machine) in the `DataLoader` init to improve performance.
  warnings.warn(*args, **kwargs)
/home/chirag/miniconda3/envs/ml/lib/python3.8/site-packages/pytorch_lightning/utilities/distributed.py:45: UserWarning: The dataloader, train dataloader, does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` (try 8 which is the number of cpus on this machine) in the `DataLoader` init to improve performance.
  warnings.warn(*args, **kwargs)
Epoch 0:  92%|████████████████████████▊  | 11/12 [00:00<00:00, 108.34it/s, loss=2.013, v_num=5]Epoch 0: val_loss reached 1.82315 (best 1.82315), saving model to /home/chirag/Projects/test/lightning_logs/epoch=0-val_loss=1.82.ckpt as top 3
Epoch 1:  83%|████████████████Epoch 1: val_loss reached 1.76090 (best 1.76090), saving model to /home/chirag/Projects/test/lightning_logs/epoch=1-val_loss=1.76.ckpt as top 3
Epoch 1: 100%|███████████████████████████| 12/12 [00:00<00:00, 131.93it/s, loss=1.860, v_num=5]
Traceback (most recent call last):
  File "lit.py", line 109, in <module>
    new_model = ChildLitModel.load_from_checkpoint(ckpt_path, another_param='something2')
  File "/home/chirag/miniconda3/envs/ml/lib/python3.8/site-packages/pytorch_lightning/core/saving.py", line 154, in load_from_checkpoint
    model = cls._load_model_state(checkpoint, strict=strict, **kwargs)
  File "/home/chirag/miniconda3/envs/ml/lib/python3.8/site-packages/pytorch_lightning/core/saving.py", line 194, in _load_model_state
    model = cls(**_cls_kwargs)
TypeError: __init__() missing 1 required positional argument: 'hparams'

@rohitgr7 I created a colab here: https://colab.research.google.com/drive/1HvWVVTK8j2Nj52qU4Q4YCyzOm0_aLQF3?usp=sharing

But it works there, and I really don't know how for now :/

It's because of the notebook env, I think. I tried it as a Python script and it throws the error, but in a notebook it doesn't. What's happening here, haha 😅?

Phew, I was starting to lose my mind. Yeah, it works in that colab I shared. Haha, I think the script behaviour is at least explainable by looking at the code; I would expect the hparams value to be missing.

yep, true.

@rohitgr7 have you already started working on this? So adding an hparams argument to the bug report model will reproduce the issue?

@awaelchli I haven't started working on it since I don't know the best way to solve it.
@chiragraman has linked a code sample similar to the bug_report_model to reproduce this issue, along with some details on why it is happening: https://github.com/PyTorchLightning/pytorch-lightning/issues/3998#issuecomment-710517824.

PS: don't run this code in a notebook since it won't throw any error.

I want to point out the following about the reproducible code that you posted:

class LitModel(pl.LightningModule):
    def __init__(self, hparams, another_param):
        super().__init__()

        # if isinstance(hparams, dict):
        #     hparams = Namespace(**hparams)

        # this is recommended Lightning way
        # it saves everything passed into __init__
        # and allows you to access it as self.myparam1, self.myparam2
        self.save_hyperparameters()

        # this is optional
        # only needed if you want to access hyperparameters via
        # self.hparams.myparam
        self.hparams = hparams

        # this is optional, if you call self.save_hyperparameters(), it's done for you
        self.another_param = another_param

        self.l1 = torch.nn.Linear(28 * 28, 10)

With save_hyperparameters(), this works smoothly.
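For completeness, once save_hyperparameters() has captured the init args, loading shouldn't need them passed again. A minimal sketch, assuming the checkpoint was written by the code above:

# hparams and another_param were stored by save_hyperparameters(), so
# load_from_checkpoint can reconstruct the module without extra kwargs
# (you can still pass kwargs to override individual values).
model = LitModel.load_from_checkpoint(ckpt_path)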

@awaelchli you mean this in the context of the #924 / #4333 discussion we're having, right? The two lines you've commented out were just there to work around the Namespace issue. Can you still reproduce the issue with the inheritance chain, though?

Yes, and I commented them out to show that with the save_hyperparameters() function you don't need the workaround.

You can still reproduce the issue with the inheritance chain though?

No. Your child module has the same args as the parent one, so they are passed correctly. Only if you have different args in the child module do you need to call self.save_hyperparameters() there too.
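As a rough sketch of that different-args case (ChildWithExtraArg and extra_param are hypothetical names):

# A child with a different init signature needs its own save_hyperparameters()
# call so the extra argument is captured in the checkpoint as well.
class ChildWithExtraArg(LitModel):
    def __init__(self, hparams, another_param, extra_param):
        super().__init__(hparams, another_param)
        self.save_hyperparameters()
        self.extra_param = extra_param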

Aaah okay, that makes sense. Using your recommended method avoids the offending code altogether here:

https://github.com/PyTorchLightning/pytorch-lightning/blob/5c1eff351b035db5881d0cff81b1d9c9e150e2d0/pytorch_lightning/core/lightning.py#L1631-L1636

So it makes sense that it works. However, the code above doesn't (and wouldn't) work for the setup in the reproducible example. If the recommended way is to explicitly save hyperparameters, then assigning to .hparams should be dropped altogether, in my opinion, and the code above removed to avoid unnecessary bloat.

So in the spirit of offering the start of a solution instead of simply pointing out the problem, here is my opinion. The user code should be decoupled from the library code. If the way forward is to use save_hyperparameters(), I think it would be more elegant for the library to also call it internally, based on a condition, even if the user hasn't. We want to preserve the usability at the user level that results from simply being able to assign to hparams.

The problem lies in what this condition should be. I'd contend that it's a little hacky to parse the class spec to see whether the string ".hparams=" exists in there, which is doomed to fail the minute you use inheritance. If the same check needs to be upheld for some reason, one would have to walk up the inheritance chain, which seems even more inelegant and might cause issues with complex subclassing patterns. So why is good old hasattr() not a good option here? (See the discussion here: https://hynek.me/articles/hasattr/)
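A rough sketch of what I have in mind, purely hypothetical and not the actual library code:

from typing import Optional

def infer_hparams_name(model) -> Optional[str]:
    # Check the constructed instance instead of parsing the class source,
    # so an assignment made in a parent __init__ is detected as well.
    return "hparams" if getattr(model, "hparams", None) is not None else None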

[Edit] So the idea here is that save_hyperparameters would, by default, save all the parameters assigned to hparams by the user, but the user could still alter this default behavior by calling it explicitly to save specific params only. That makes more sense to me because, in the most common use case, the end user doesn't need to worry about making an explicit call to save params.
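A hypothetical sketch of that idea (not actual Lightning code): make the hparams setter itself record the assignment, so it gets checkpointed no matter where in the inheritance chain it happens.

class HparamsTrackingMixin:
    @property
    def hparams(self):
        return getattr(self, "_hparams", None)

    @hparams.setter
    def hparams(self, hp):
        # Record both the values and the attribute name at assignment time,
        # so checkpoint saving no longer needs to parse the class source.
        self._hparams = hp
        self._hparams_name = "hparams"  # what would be written to the checkpoint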
