Pytorch-lightning: Trainer.scale_batch_size requires model.batch_size instead of model.hparams.batch_size

Created on 3 Jul 2020 · 7 comments · Source: PyTorchLightning/pytorch-lightning

🐛 Bug

Trainer.scale_batch_size only works if the model has a batch_size attribute; it does not work with model.hparams.batch_size, even though all documentation suggests the opposite.

To Reproduce

All of my hyperparameters are available as model.hparams, as suggested in the documentation (hyperparameters, option 3).
This means that my batch_size is available as model.hparams.batch_size.

This should be fully compatible with the example code in the Trainer.scale_batch_size() documentation, since that code also uses model.hparams.batch_size instead of model.batch_size.

However, when I pass my model to Trainer.scale_batch_size, I get the following error:

pytorch_lightning.utilities.exceptions.MisconfigurationException: Field batch_size not found in `model.hparams`

Example code

import pytorch_lightning as pl
from pytorch_lightning import Trainer

class LitModel(pl.LightningModule):
    def __init__(self, hparams):
        super().__init__()
        self.hparams = hparams  # hyperparameters, including batch_size

model = LitModel(args)  # `args` holds the parsed hyperparameters
trainer = Trainer()
trainer.scale_batch_size(model)

Expected behavior

Either Trainer.scale_batch_size should work with model.hparams, or the error message, the linked documentation examples, and the docstrings should all change (i.e. here, here and here).

(I would prefer the first option: I think it should work with both model.batch_size and model.hparams.batch_size.)
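A sketch of what such a dual lookup could look like (the helper name below is hypothetical and not Lightning's actual implementation; it just illustrates checking both locations):

from pytorch_lightning.utilities.exceptions import MisconfigurationException

def find_batch_size_owner(model, attribute="batch_size"):
    """Hypothetical helper: return whichever object holds `attribute`,
    either the model itself or its `hparams` namespace."""
    if hasattr(model, attribute):
        return model
    hparams = getattr(model, "hparams", None)
    if hparams is not None and hasattr(hparams, attribute):
        return hparams
    raise MisconfigurationException(
        f"Field `{attribute}` not found in `model` or `model.hparams`"
    )

# the tuner could then read and write through whichever owner it found:
owner = find_batch_size_owner(model)  # `model` as constructed above
owner.batch_size = 2 * owner.batch_size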

Environment

  • pytorch-lightning 0.8.4
Labels: bug / fix · good first issue · help wanted

All 7 comments

Hi! Thanks for your contribution, great first issue!

It seems like a nice first issue. @wietsedv, mind sending a PR? :rabbit:

Appears to be the same with the learning rate parameter.

A clean fix on the user side while waiting for the PR is to actually use self.hparams.batch_size and define self.batch_size as a property of your module:

@property
def batch_size(self):
    return self.hparams.batch_size

@batch_size.setter
def batch_size(self, batch_size):
    self.hparams.batch_size = batch_size

That way you keep your hyperparameters together in case you want to dump them somewhere, without having to add specific code.
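For reference, a minimal sketch of where those properties live, using the LitModel from the issue (training_step, configure_optimizers and the dataloaders are omitted for brevity):

from argparse import Namespace
import pytorch_lightning as pl
from pytorch_lightning import Trainer

class LitModel(pl.LightningModule):
    def __init__(self, hparams):
        super().__init__()
        self.hparams = hparams  # e.g. Namespace(batch_size=32, ...)

    @property
    def batch_size(self):
        return self.hparams.batch_size

    @batch_size.setter
    def batch_size(self, batch_size):
        self.hparams.batch_size = batch_size

model = LitModel(Namespace(batch_size=32))
trainer = Trainer()
trainer.scale_batch_size(model)  # now finds model.batch_size and writes back via the setter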

From #1896 it seems that the problem is on the docs side rather than in scale_batch_size()'s implementation: self.batch_size is the correct place to look for the parameter, not self.hparams.batch_size.
My fix above is thus also obsolete.

> From #1896 it seems that the problem is on the docs side rather than in scale_batch_size()'s implementation.
> My fix above is thus also obsolete.

I just tried your fix and it seemed to work :)

Yes, it does work, but from what they said in the PR I linked, hparams was only there as a temporary solution, and all hyperparameters are intended to be set as direct instance attributes in __init__.
My fix is obsolete with respect to the intended usage of PL.
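For completeness, a minimal sketch of that intended pattern (assuming a Lightning version that provides self.save_hyperparameters(); the parameter names are illustrative):

import pytorch_lightning as pl

class LitModel(pl.LightningModule):
    def __init__(self, batch_size: int = 32, learning_rate: float = 1e-3):
        super().__init__()
        # hyperparameters as direct instance attributes, so scale_batch_size
        # and the learning rate finder can read and update them
        self.batch_size = batch_size
        self.learning_rate = learning_rate
        # still records them together for logging/checkpointing
        self.save_hyperparameters()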
