Pytorch-lightning: Why is `accumulate_grad_batches` set to 1 for the first epoch when the argument is provided as an int?

Created on 14 Jul 2020 · 7 Comments · Source: PyTorchLightning/pytorch-lightning

🚀 Feature

Is there any reason why it is set to 1 for the first epoch?
I think it should be set to the number the user specifies, because the current behaviour causes a lot of confusion.

Alternatives

Change the key of the schedule dict to 0 in training_tricks.py:

def configure_accumulated_gradients(self, accumulate_grad_batches):
    if isinstance(accumulate_grad_batches, dict):
        # A dict is passed through as an explicit {epoch: factor} schedule.
        self.accumulation_scheduler = GradientAccumulationScheduler(accumulate_grad_batches)
    elif isinstance(accumulate_grad_batches, int):
        # An int is wrapped into a one-entry schedule starting at epoch 1,
        # so epoch 0 falls back to the default factor of 1.
        schedule = {1: accumulate_grad_batches}  # => schedule = {0: accumulate_grad_batches}
        self.accumulation_scheduler = GradientAccumulationScheduler(schedule)
    else:
        raise TypeError("Gradient accumulation supports only int and dict types")
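
To make the confusion concrete, here is a minimal sketch (my own illustration, not the library code) of how an epoch-keyed schedule is resolved, assuming the factor defaults to 1 for any epoch earlier than the first key:

def resolve_accumulation(schedule, epoch):
    # Use the factor of the latest schedule key that is <= the current epoch;
    # before the first key, fall back to a factor of 1 (no accumulation).
    factor = 1
    for start_epoch in sorted(schedule):
        if epoch >= start_epoch:
            factor = schedule[start_epoch]
    return factor

# With the current int handling, schedule = {1: n}, so epoch 0 resolves to 1:
assert resolve_accumulation({1: 4}, epoch=0) == 1
assert resolve_accumulation({1: 4}, epoch=1) == 4
# With the proposed {0: n}, accumulation applies from the very first epoch:
assert resolve_accumulation({0: 4}, epoch=0) == 4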

All 7 comments

I think you misunderstand.
The value there means: accumulate this many batches from epoch 0 onwards.
https://pytorch-lightning.readthedocs.io/en/0.8.5/api/pytorch_lightning.trainer.html#accumulate-grad-batches
I don't see an error in the code.
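
For reference, the dict form of the argument is documented to work like this (illustrative values, not taken from the issue):

from pytorch_lightning import Trainer

# Accumulate 1 batch for epochs 0-4, 2 batches for epochs 5-9,
# and 4 batches from epoch 10 onwards (example values only).
trainer = Trainer(accumulate_grad_batches={0: 1, 5: 2, 10: 4})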

By default, if the user specifies nothing, the dict will be {0: 1} (via the elif branch).

If the user specifies the accumulate_grad_batches argument as an int, the schedule becomes {0: 1, 1: n}: https://github.com/PyTorchLightning/pytorch-lightning/blob/1d565e175d98103c2ebd6164e681f76143501da9/pytorch_lightning/callbacks/gradient_accumulation_scheduler.py#L43-L50

This means accumulation only starts from the second epoch.
I've checked with a debugger, and trainer.accumulate_grad_batches is set to 1 during the first epoch.
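
A workaround on versions that show this behaviour would be to pass an explicit schedule keyed at epoch 0 instead of an int, using the dict form the Trainer already accepts (sketch, with an example value):

from pytorch_lightning import Trainer

# Accumulate 4 batches starting from epoch 0 (example value);
# this is what passing the int 4 is expected to do.
trainer = Trainer(accumulate_grad_batches={0: 4})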

I'm sorry, I still don't get it. I cannot reproduce this. Here is what I tried in the console:

>>> x = Trainer(accumulate_grad_batches=3)
>>> x.accumulation_scheduler.scheduling
{0: 3}

I'm using the latest master here

Ohhh, I missed it; I was looking at the wrong version, sorry.
It is a bug in the previous version, so upgrading 0.8.4 -> 0.8.5 will fix it 😄.

I found that this PR fixed the issue: https://github.com/PyTorchLightning/pytorch-lightning/pull/2513/files
Closing here 🙏

I should have known... I had already forgotten about this bugfix. Glad you found it.

