Pytorch-lightning: Saving / loading hparams type in checkpoint

Created on 24 Oct 2020  ·  22 Comments  ·  Source: PyTorchLightning/pytorch-lightning

๐Ÿ› Bug

There seems to be an issue with saving or loading hyperparams type in checkpoints. Related to #924 (unrelated to #3998).

Please reproduce using the BoringModel and post here

Here is the BoringModel gist using the snippet from @awaelchli in #924
(running in colab has been unreliable, see #3998)
https://gist.github.com/chiragraman/c235f6f2a25b2432bde1a08ae6ed1b03

Behavior

With my local master at 155f4e9a, I'm getting the following error:

type of hparams <class 'pytorch_lightning.utilities.parsing.AttributeDict'>
class of hparams type AttributeDict
accessing hparams 1 no problem
GPU available: True, used: False
TPU available: False, using: 0 TPU cores
/home/chirag/miniconda3/envs/ml/lib/python3.8/site-packages/pytorch_lightning/utilities/distributed.py:45: UserWarning: GPU available but not used. Set the --gpus flag when calling the script.
  warnings.warn(*args, **kwargs)
/home/chirag/miniconda3/envs/ml/lib/python3.8/site-packages/pytorch_lightning/utilities/distributed.py:45: UserWarning: The dataloader, val dataloader 0, does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` (try 8 which is the number of cpus on this machine) in the `DataLoader` init to improve performance.
  warnings.warn(*args, **kwargs)
/home/chirag/miniconda3/envs/ml/lib/python3.8/site-packages/pytorch_lightning/utilities/distributed.py:45: UserWarning: The dataloader, train dataloader, does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` (try 8 which is the number of cpus on this machine) in the `DataLoader` init to improve performance.
  warnings.warn(*args, **kwargs)
Epoch 0: 100%|████████████████████| 2/2 [00:00<00:00, 427.58it/s, loss=2.280, v_num=11]
type of hparams <class 'pytorch_lightning.utilities.parsing.AttributeDict'>
class of hparams type AttributeDict
Traceback (most recent call last):
  File "/home/chirag/miniconda3/envs/ml/lib/python3.8/site-packages/pytorch_lightning/utilities/parsing.py", line 160, in __getattr__
    return self[key]
KeyError: 'something'

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "ckpt-loading.py", line 118, in <module>
    run_test()
  File "ckpt-loading.py", line 106, in run_test
    model = BoringModel.load_from_checkpoint(ckpt_path)
  File "/home/chirag/miniconda3/envs/ml/lib/python3.8/site-packages/pytorch_lightning/core/saving.py", line 154, in load_from_checkpoint
    model = cls._load_model_state(checkpoint, strict=strict, **kwargs)
  File "/home/chirag/miniconda3/envs/ml/lib/python3.8/site-packages/pytorch_lightning/core/saving.py", line 198, in _load_model_state
    model = cls(**_cls_kwargs_filtered)
  File "ckpt-loading.py", line 40, in __init__
    print("accessing hparams", self.hparams.something, "no problem")
  File "/home/chirag/miniconda3/envs/ml/lib/python3.8/site-packages/pytorch_lightning/utilities/parsing.py", line 162, in __getattr__
    raise AttributeError(f'Missing attribute "{key}"') from exp
AttributeError: Missing attribute "something"

Also, for isolating the problem, look at this code:

https://github.com/PyTorchLightning/pytorch-lightning/blob/207ff728c940ff7d8bb317a83d22378b759c9292/pytorch_lightning/trainer/connectors/checkpoint_connector.py#L297-L306

If you add a breakpoint or log line for debugging after the inner if, so that:

if isinstance(model.hparams, Container):
    print("Saving hparams type")
    checkpoint[LightningModule.CHECKPOINT_HYPER_PARAMS_TYPE] = type(model.hparams)

you should see that it doesn't get triggered.

Environment

  • PyTorch Lightning Version: 1.0.2
  • PyTorch Version (e.g., 1.0): 1.6
  • OS (e.g., Linux): Linux
  • How you installed PyTorch (conda, pip, source): conda
  • Python version: 3.8.5

Labels: documentation, help wanted

All 22 comments

Did you print the type of model.hparams elsewhere? Is it being converted silently away from omegaconf?

Looking into this right now, breakpointing in the code over in saving.py to look at the stored type.

[Edit:] Okay, so I pulled from remote so that my head is at 207ff728, and I can at least get around the crash. The key cls.CHECKPOINT_HYPER_PARAMS_TYPE is indeed None, however. From the code I think it's effectively not needed unless omegaconf is being used. I now have the following output, omitting the warnings for brevity:

type of hparams <class 'pytorch_lightning.utilities.parsing.AttributeDict'>
class of hparams type AttributeDict
accessing hparams 1 no problem

Epoch 0: 100%|████████████████████| 2/2 [00:00<00:00, 324.05it/s, loss=0.174, v_num=16]
type of hparams <class 'pytorch_lightning.utilities.parsing.AttributeDict'>
class of hparams type AttributeDict
accessing hparams 1 no problem
dict_keys(['epoch', 'global_step', 'pytorch-lightning_version', 'callbacks', 'optimizer_states', 'lr_schedulers', 'state_dict', 'hparams_name', 'hyper_parameters'])
hparams name saved to ckpt: hparams
hparams contents saved to ckpt: {'something': 1}
hparams type saved to ckpt: <class 'dict'>

Testing: 0it [00:00, ?it/s]--------------------------------------------------------------------------------
Testing: 100%|█████████████████████████████████████| 64/64 [00:00<00:00, 7663.23it/s]

However, the type of the hparams passed in is AttributeDict instead of Namespace, so the loading code seems to be ignoring CHECKPOINT_HYPER_PARAMS_TYPE (which is also None) and passing in the hparams as-is. If the expectation is that the user's type is preserved, then I'd say this is a bug; if the expectation is that any user type will always be silently converted to an AttributeDict, then I think this could be made more explicit in the docs.

Thanks, I'm glad we agree we're not seeing any errors on master branch.

The AttributeDict, as far as I know, was introduced exactly because of this, to have the flexibility of passing in a dict or Namespace.
The setter self.hparams = ... helps take care of this.
My understanding is that the Lightning-recommended way is to pass in hyperparameters as regular arguments and then call
self.save_hyperparameters().
The way of passing in a namespace/dict and then assigning to self.hparams is the "old" way of doing it, which indeed shows exactly the problems you are describing about remembering the type. This is the reason why we moved away from the single-argument hparams namespace, and we no longer do any type conversions except for backward compatibility. I recommend just using save_hyperparameters().
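For concreteness, here is a minimal sketch of that recommended pattern (the module and argument names are made up for illustration): hyperparameters go into __init__ as regular arguments, self.save_hyperparameters() collects them, and they are then available under self.hparams and stored in the checkpoint.

import pytorch_lightning as pl
import torch


class RecommendedModel(pl.LightningModule):
    def __init__(self, learning_rate: float = 1e-3, hidden_dim: int = 64):
        super().__init__()
        # collects learning_rate and hidden_dim into self.hparams and has
        # Lightning write them into the "hyper_parameters" key of the checkpoint
        self.save_hyperparameters()
        self.layer = torch.nn.Linear(32, self.hparams.hidden_dim)

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=self.hparams.learning_rate)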

@Borda you had worked on some fixes for hparams issues, any thoughts on this?

Hmm, I think even before the saving, at the first invocation of __init__ it is already an AttributeDict, as seen from the first set of print statements. So even if one calls save_hyperparameters in __init__, I think it will only ever receive an AttributeDict, no?

I don't think even the latest doc assumes this. From the docstring on save_hyperparameters, it reads:

args – single object of dict, NameSpace or OmegaConf or string names or arguments from class __init__

but I don't see how it will ever get Namespace or even dict if called within __init__, leading me to treat this more as a bug.

Also, regarding the section on hyperparameters in the latest docs here ( https://pytorch-lightning.readthedocs.io/en/latest/hyperparameters.html?highlight=save%20hyperparameters ), it still reads like a valid and supported way of saving hyperparams, so maybe that should just be dropped if it's not encouraged.


From a design standpoint, I think it's a little complicated that the type the user passes in is somehow changed when it is accessed in the init. I think transparency might be a wiser design choice, where any intermediate type wrangling is abstracted away from the end user. Stating my subjective opinion on the design here to hopefully foster a discussion.

@awaelchli wdyt about making hparams into a model hook for people to explicitly save and load? i think the automatic deduction of args is really tricky and runs into all of these edge cases

That seems like the wrong design choice - the args you pass in are presumably used in other hooks as well, so it makes sense to pass them into the init. I think the simpler fix here would be to save the type regardless of whether omegaconf is available or not. That nested if seems like incorrect branching.

I'd assume it is okay to always have the type saved in the ckpt whether or not omegaconf is available.

[Edit] What I meant to say is that hooks ought to serve the purpose of allowing behavior that is plugged in and sort of self-contained, not of providing data that is accessed in the rest of the system. This seems like a violation of the dependency-inversion principle: the module shouldn't depend on the hook; the hook should depend on the module.

@chiragraman Maybe we need to do this?

        if model.hparams:
            if hasattr(model, '_hparams_name'):
                checkpoint[LightningModule.CHECKPOINT_HYPER_PARAMS_NAME] = model._hparams_name
            # add arguments to the checkpoint
            if OMEGACONF_AVAILABLE:
                checkpoint[LightningModule.CHECKPOINT_HYPER_PARAMS_KEY] = model.hparams
                # if isinstance(model.hparams, Container):
                #     checkpoint[LightningModule.CHECKPOINT_HYPER_PARAMS_TYPE] = type(model.hparams)
            else:
                checkpoint[LightningModule.CHECKPOINT_HYPER_PARAMS_KEY] = dict(model.hparams)

            # I moved this line out of the inner if-statements
            checkpoint[LightningModule.CHECKPOINT_HYPER_PARAMS_TYPE] = type(model.hparams)

Sorry for the slow resolution of this problem, but my head is a bit twisted looking at all this hparams logic; I am not sure what to do, to be honest. Even with this change, the type after reloading will still be AttributeDict:

def _convert_loaded_hparams(model_args: dict, hparams_type: Optional[Union[Callable, str]] = None) -> object:
    """Convert hparams according given type in callable or string (past) format."""
    # if not hparams type define
    if not hparams_type:
        return model_args
    # if past checkpoint loaded, convert str to callable
    if isinstance(hparams_type, str):
        hparams_type = AttributeDict
    # convert hparams
    return hparams_type(model_args)

and the setter of LightningModule.hparams always converts dict and Namespace to AttributeDict. Looking at these different places, it really looks like this is intentional to have AttributeDict as a common type to handle both dict and Namespace.
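To make that concrete, here is a small sketch (the class name is made up) of the behaviour described above, at the Lightning version discussed in this issue: assigning either a plain dict or an argparse Namespace to self.hparams ends up as an AttributeDict.

from argparse import Namespace

import pytorch_lightning as pl


class SetterDemo(pl.LightningModule):
    def __init__(self, hparams):
        super().__init__()
        # the hparams setter converts dict / Namespace to AttributeDict
        self.hparams = hparams
        print(type(self.hparams))  # pytorch_lightning.utilities.parsing.AttributeDict


SetterDemo({"something": 1})
SetterDemo(Namespace(something=1))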

decision: update docs to use save_hyperparameters and not the previous deprecated way.

open question: do we need to keep backward compatibility for the old way? @Borda

@awaelchli I think your fix would be a welcome partial change already. I think the AttributeDict discussion is a more detailed one in terms of designing for transparency, but I think the fix would definitely be an improvement.

@edenlightning the previous way isn't deprecated. Point 3 here. I think there is something to be said about the usability this provides to the end user in the most common use-case. If it is being deprecated, can we plan for it rather than deprecating it immediately? Also, is there a reason to drop it altogether rather than trying to fix the issues?

@chiragraman I'm sorry but my fix doesn't check out. We're circling back to the time when we attempted to fix the old way of storing hyperparameters. The AttributeDict is a _workaround_ to make both dict and Namespace users happy. The real fix, and that is what we came up with, is to separate two things:

  1. access of hyperparameters within a LightningModule
  2. saving the hyperparameters to the checkpoint

save_hyperparameters() takes care of 2. Optionally, it can also take care of 1) if we pass in the names, but by default, we let the user decide about 1.
This is the recommended way and Point 3 in the docs you pointed out should be labelled as deprecated. I will send a PR asap.

I think there is something to say about the usability this provides to the end user in the most common use-case. If it is being deprecated, can we plan for it rather than deprecating it immediately?

This usability is guaranteed. Your use case is valid and will not be broken.

@awaelchli thanks for the explanation! I think the idea behind the implicit saving was a great feature, but I can see it would need more work to execute correctly. In some measure this would entail keeping AttributeDict as an internal type to be used only in library code while preserving the original type in user code. However, I think a consistent solution is indeed better, and I think explicitly asking the user code to save params to the checkpoint isn't a massive hit in usability, so this is great.

1) if we pass in the names, but by default, we let the user decide about 1.

Just wanted to confirm what this means for AttributeDict? I can see perhaps how the workaround came about, since the docs still assume that the params could historically be of multiple types, but then this type was perhaps introduced and converted in a single place instead of ensuring preservation of the original type in user hooks. If saving is going to be made an explicit responsibility of the user code, may I request reverting to letting the original type be preserved?

Just wanted to confirm what this means for AttributeDict?

After the user calls self.save_hyperparameters(), they can access each parameter through self.hparams, if they want to, but this is optional. In the context of save_hyperparameters, I believe this is the only use for AttributeDict.

If saving is going to be made an explicit responsibility of the user code, may I request reverting to letting the original type be preserved?

I'm not sure what you mean. With the save_hyperparameters method, we preserve all types. The container that holds all parameters is an AttributeDict, and this gets stored to the checkpoint. So, for example:

def __init__(self, arg1: type1, arg2: type2):
    super().__init__()
    self.save_hyperparameters()
    self.arg1 = arg1
    self.arg2 = arg2

You train your model with Lightning then inspect the checkpoint. You will find:

assert isinstance(checkpoint["hyper_parameters"], AttributeDict)
assert isinstance(checkpoint["hyper_parameters"]["arg1"], type1)
assert isinstance(checkpoint["hyper_parameters"]["arg2"], type2)

I was thinking about this a bit more and there is a strong argument against having AttributeDict in the checkpoint.
Since it is a PyTorch Lightning construct, it requires pytorch-lightning in the environment when unpickling the checkpoint.
This is a problem when one tries to load a checkpoint with pure torch.load in an environment where Lightning is not installed.
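For illustration, a sketch of that failure mode (the checkpoint path is hypothetical): unpickling has to resolve the class of the stored hyper_parameters object, so torch.load itself would fail if that object were an AttributeDict and pytorch_lightning were not importable.

import torch

# In an environment without pytorch_lightning installed, this raises a
# ModuleNotFoundError during unpickling if hyper_parameters was saved as an
# AttributeDict; with a plain dict it loads fine.
ckpt = torch.load("example.ckpt", map_location="cpu")
print(type(ckpt["hyper_parameters"]))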

Hi Adrian,

Exactly! Your last comment is closer to what I was alluding to. If the design goal is that Lightning is a thin platform on top of PyTorch (which I think is one of its advantages compared to other libraries), then any custom type intended as an internal type should be transparent to the user code, not only in the serialized checkpoint. I ought to have been clearer about my question; this is what I meant:

  1. access of hyperparameters within a LightningModule
  2. saving the hyperparameters to the checkpoint

So if explicitly calling save_hyperparameters (by the user) takes care of 2., is the user free to use whatever type they want for 1., like so:

def __init__(self, hparams: type1, arg2: type2):
    super().__init__()
    self.save_hyperparameters()
    self.hparams = hparams
    self.arg2 = arg2

In the above code, what would be the type of hparams? I assumed that with the new change it should be type1 (be that Namespace or dict or anything else), as opposed to the current case where it is changed into AttributeDict. Just wanted to confirm this.

In the above code, what would be the type of hparams?

Currently it will have type AttributeDict, you are right, but only because Lightning offers this as a "feature" that all arguments collected with save_hyperparameters are accessible via self.hparams. I think the example you just made is interesting, because practically, the two ways self.save_hyperparameters and self.hparams = hparams are mutually exclusive. They are not meant to be mixed.

Here is what happens if you don't have self.hparams = hparams:

def __init__(self, hparams: type1, arg2: type2):
    super().__init__()
    self.save_hyperparameters()
    # self.hparams = hparams
    self.arg2 = arg2

def some_method(self):
    x = self.hparams.hparams.batch_size

Why hparams.hparams? Because hparams itself is now considered a hyperparameter. Everything passed into the model as argument is considered a hyperparameter. This means that it is possible to mix _old way_ with the _new way_, but it is not what we should do!

What we expect to see in the old way:

def __init__(self, hparams):  # pass in a single container with all hyperparameters
    super().__init__()
    # self.save_hyperparameters()   <- does not exist in "old way", forbidden!
    self.hparams = hparams  # <- mandatory

def some_method(self):
    x = self.hparams.batch_size  # works (self.hparams is an AttributeDict)

How we expect the new way to work:

def __init__(self, batch_size, arg2, arg3, ...):  # individual hyperparameters
    super().__init__()
    self.save_hyperparameters() #   <- mandatory 
    # self.hparams = hparams  # forbidden

def some_method(self):
    x = self.hparams.batch_size  # works (self.hparams is an AttributeDict)

While self.hparams = hparams is "forbidden", self.hparams is still a getter, so the user is free to access their hyperparameters through this object. This is simply a luxury of not having to init all arguments via

self.arg1 = arg1
self.arg2 = arg2

but the user is still free to do this if they want and they can completely ignore self.hparams.
So, to answer your question:

is the user free to use whatever type they want for 1.,

Yes! The only restriction is they can't use the variable self.hparams = with an assignment because this is reserved for the old way of managing hyperparameters.

Conclusion: In the new way, we want to put all hyperparameters into the signature of init. Then we use the convenience function self.save_hyperparameters and all elements are available under self.hparams.

What if I have too many hyperparameters and I can't write them all into the signature?
Easy enough:

def __init__(self, **kwargs):
    super().__init__()
    self.save_hyperparameters()

def some_method(self):
    x = self.hparams.batch_size  # works

I think a good way to move forward is to

  1. remove the setter for self.hparams (only provide the getter)
  2. store hyperparameters in the checkpoint as a plain dict, not an AttributeDict
  3. remove references to the old way of saving and loading hyperparameters in the docs, or add a warning that it is deprecated

Adrian, thank you so much for taking the time to write all that out! That clears everything up. Haha, coming from a C++ background it took me a few years to get used to Python's philosophy of treating coders as adults and letting them shoot themselves in the foot, but given that, I can't see anything here that is problematic. Maybe a gradual phasing out of the possibility of mixing the old and new ways? But that's only for the long term, and perhaps easily handled by adding this to the docs?

Thanks again.

Yes that is true, the Python magic is hiding a bit too much here.

Maybe a gradual phasing out the possibility of mixing the old and new ways?

https://github.com/PyTorchLightning/pytorch-lightning/issues/4333#issuecomment-719422339 what do you think of this plan?

Ah sorry, I think we both posted comments at the same moment and I missed yours.

  1. I have to check the relevant code, but if it isn't already, I think making hparams a read-only attribute might be a valid choice to consider here (see the sketch after this list). save_hyperparameters() would then serve as the explicit setter for the attribute, but in the interest of the single-responsibility principle it should only act as a setter. I think your separation of access and serialization responsibilities is great, and in keeping with that we could perhaps rename this function to something like set_hyperparameters(), since the word save might currently be misleading in terms of having serialization side effects.

If that sounds good, the added benefit would be that the serialization code (to checkpoint or JSON or whatever in the future) would only invoke the getter of the attribute and serialize the returned object. That way we also don't have to worry about inheritance chains or parsing class specs to see whether the thing has been set while serializing, like in #3998.

  2. Yes, I agree; dict seems a better option. The only thing I can think of is the implications of ordering, in case this needs to be an OrderedDict? But in general, a type from the std lib sounds great.

  3. Yup yup. +1.
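As referenced in point 1 above, here is a hypothetical sketch (not the actual pytorch_lightning implementation; the kwargs-based signature is a simplification of how save_hyperparameters really collects init arguments) of a read-only hparams with save_hyperparameters acting as the sole setter:

class HyperparametersMixin:
    """Hypothetical mixin: hparams is exposed only through a getter."""

    def save_hyperparameters(self, **hparams):
        # the only sanctioned way to populate the hyperparameters
        self._hparams = dict(hparams)

    @property
    def hparams(self):
        # read-only: there is deliberately no setter, so `self.hparams = ...`
        # raises an AttributeError
        return getattr(self, "_hparams", {})

# usage inside a module's __init__ (illustrative):
#     self.save_hyperparameters(batch_size=32, lr=1e-3)
#     self.hparams["batch_size"]  # -> 32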

@awaelchli I was going to take a crack at points 1 and 2 above; I think the PR only updates the docs, right? Wanted to confirm before proceeding, since incorporating the first two points would also solve #3998.

yes, that update to the docs was just a first step. Please feel free to give it a try! Would appreciate the help.

Cool, I'll start a draft PR to discuss the details. Can we leave this issue open till then?
