Version 0.8.2 and above changed the behavior of either my learning rate scheduler or the WandbLogger. I am using a linear warmup and decay scheduler, but ever since version 0.8.2 the learning rate graph produced by the LearningRateLogger callback looks like the one below:

The period where the learning rate is zero corresponds to the last epoch of training, as you can see below:

This graph raises another issue: the first epoch appears to take twice as many steps as the second and third epochs, even though I specified max_epochs=3. During training, each epoch takes the same amount of time, so this seems like a logging issue.
Note that the above graphs are for a model whose training was stopped early, so the last epoch is slightly shorter than the second to last. This is not the issue.
Neither of these issues (the zero learning rate and the twice-as-long first epoch) exists in version 0.8.1, where both graphs look as they should.
These issues could be caused by the logger, or they might actually occur during training and simply be logged accurately. Looking through the changelog, my guess is that these bugs are caused by "Changed epoch indexing from 0 instead of 1" (#2289). I may also be relying somewhere in my code on epoch indexing starting at 1, but I do not believe this is the case.
Reproducing this problem may be difficult since I can't provide the script and data I used. I used the WandbLogger and the LearningRateLogger callback, and trained with 1400 warmup steps and accumulate_grad_batches set to 2.
I can provide additional code samples or information that you may need.
from functools import partial
from torch.optim.lr_scheduler import LambdaLR

def lr_lambda_func(current_step, num_warmup_steps, num_training_steps):
    # Linear warmup to the base learning rate, then linear decay to zero.
    if current_step < num_warmup_steps:
        return float(current_step) / float(max(1, num_warmup_steps))
    return max(
        0.0,
        float(num_training_steps - current_step)
        / float(max(1, num_training_steps - num_warmup_steps)),
    )

# Inside configure_optimizers(), after `optimizer` has been created:
t_total = int(
    len(self.train_dataloader_object)
    * self.hparams.max_epochs
    // self.hparams.accumulate_grad_batches
)
lr_lambda = partial(
    lr_lambda_func,
    num_warmup_steps=self.hparams.warmup_steps
    * self.hparams.accumulate_grad_batches,
    num_training_steps=t_total,
)
scheduler = LambdaLR(optimizer, lr_lambda, last_epoch=-1)
scheduler_dict = {"scheduler": scheduler, "interval": "step"}
return ([optimizer], [scheduler_dict])
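As a quick sanity check (not part of the original script), the lambda above can be evaluated at a few steps. The numbers here are assumptions for illustration only: 2800 warmup steps (1400 scaled by accumulate_grad_batches=2) and a made-up total of 10000 optimizer steps.

from functools import partial

# Assumed numbers for illustration: 1400 warmup steps scaled by
# accumulate_grad_batches=2, and a hypothetical 10000 total optimizer steps.
check = partial(lr_lambda_func, num_warmup_steps=2800, num_training_steps=10000)

for step in (0, 1400, 2800, 6400, 10000):
    # Multiplier applied to the base learning rate at this step.
    print(step, round(check(step), 3))
# 0 -> 0.0, 1400 -> 0.5, 2800 -> 1.0, 6400 -> 0.5, 10000 -> 0.0:
# a linear ramp over the warmup followed by a linear decay to zero.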
The learning rate should warm up and decay in version 0.8.2 and above the same way it does in earlier versions, and each epoch should contain the same number of steps.
The graphs below highlight the expected behavior. They are from a different model, so they are not directly comparable, but their shape is as expected since they were captured from a model trained with pytorch_lightning version 0.8.1.


Hi! Thanks for your contribution, great first issue!
It would be good to know whether this can be observed with other loggers as well. Could you also run your example with the TensorBoardLogger?
Hey! I believe the problem lies in configure_accumulated_gradients(): when accumulate_grad_batches is an integer, the scheduler is set to use it from current_epoch=1, but the Trainer starts from current_epoch=0, so trainer.accumulate_grad_batches = self.scheduling.get(self.epochs[i]) sets accumulate_grad_batches to the default value (1) for that epoch.
Update: When accumulate_grad_batches is an integer, the scheduler gets {1: accumulate_grad_batches} as input, and then scheduling.update({0: 1}) inserts the "default" value of 1 for the first epoch.
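A minimal sketch of the lookup described above (a simplified assumption, not the actual Lightning source) shows how epoch 0 ends up with an accumulation factor of 1:

# Assumed simplification of the behavior described above: an integer
# accumulate_grad_batches becomes {1: value}, and a default of 1 is
# inserted for epoch 0.
accumulate_grad_batches = 2
scheduling = {1: accumulate_grad_batches}
scheduling.update({0: 1})  # "default" entry for the first epoch

epochs = sorted(scheduling.keys())
for current_epoch in range(3):
    # Pick the last configured epoch <= current_epoch, mirroring the
    # self.scheduling.get(self.epochs[i]) lookup described above.
    active = max(e for e in epochs if e <= current_epoch)
    print(current_epoch, scheduling[active])
# Prints: epoch 0 -> 1 (the default), epochs 1 and 2 -> 2, so the first
# epoch accumulates nothing and takes twice as many optimizer steps.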
@HHousen You could work around this by setting accumulate_grad_batches={0: <your value>} in pl.Trainer (I did so), but you might have problems restoring from a checkpoint, because
n_accum = 1 if self.accumulate_grad_batches is None else self.accumulate_grad_batches
expected_steps = self.num_training_batches / n_accum
in restore_training_state will try to use the dict in a division.
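For reference, a sketch of this workaround using the settings from this issue (accumulate_grad_batches of 2, max_epochs of 3); it illustrates the suggestion above rather than a verified fix for every setup:

from pytorch_lightning import Trainer

# Workaround described above: pass a dict keyed by epoch 0 so the intended
# accumulation factor applies from the very first epoch. The value 2 matches
# the setup reported in this issue; adjust as needed.
trainer = Trainer(max_epochs=3, accumulate_grad_batches={0: 2})
# Caveat from the comment above: restoring from a checkpoint may fail,
# since restore_training_state divides num_training_batches by this value,
# which is now a dict rather than an int.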
@szymonzareba Yep, setting args.accumulate_grad_batches to {0: 2} fixed this problem (I create my pl.Trainer like so: trainer = Trainer.from_argparse_args(args)). Both the learning rate and epoch graphs are now correct. It seems like your reasoning is correct.
@HHousen mind sending a PR?
Sure
See the tests/trainer/test_lr_finder.py::test_accumulation_and_early_stopping test and #2490 for more information.