Pytorch-lightning: Schedulers like get_linear_schedule_with_warmup need access to the length of the train dataset

Created on 4 Mar 2020 · 14 Comments · Source: PyTorchLightning/pytorch-lightning

🐛 Bug

If you're using an LR scheduler that needs the number of batches in the train dataset, like @huggingface's get_linear_schedule_with_warmup, there's currently no way to access the dataset in configure_optimizers(), because it appears to be called before train_dataloader().

It would be nice to have some way to load the datasets before the optimizers are configured, and to make the dataset available to other methods with something like self.train_dataset = train_dataset.

Code sample:

from transformers import get_linear_schedule_with_warmup

train_steps = int(len(train_dataset) / (batch_size * grad_steps) * epochs)
scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=int(0.1 * train_steps),
                                            num_training_steps=train_steps)
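The step arithmetic in the sample can be checked on its own. A minimal sketch, assuming a hypothetical helper `compute_schedule_steps` and the same 10% warmup fraction used above:

```python
def compute_schedule_steps(num_examples, batch_size, grad_steps, epochs, warmup_fraction=0.1):
    """Return (train_steps, warmup_steps) using the formula from the code sample.

    Each optimizer update consumes batch_size * grad_steps examples,
    and the data is seen `epochs` times.
    """
    train_steps = int(num_examples / (batch_size * grad_steps) * epochs)
    warmup_steps = int(warmup_fraction * train_steps)
    return train_steps, warmup_steps


# 1000 examples, batch size 8, 2 accumulation steps, 3 epochs
print(compute_schedule_steps(1000, 8, 2, 3))  # (187, 18)
```

Because the division happens before truncation, fractional accumulation windows are summed across epochs; flooring per epoch instead (125 batches // 2 = 62 updates) would give 62 * 3 = 186.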
Labels: enhancement, help wanted, won't fix

All 14 comments

Well, you can pass the train dataset or loader in the constructor, so it will be available as an attribute. Any reason not to do so?
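A framework-free sketch of that suggestion; the hypothetical `Model` class stands in for a LightningModule, and a plain list stands in for the dataset:

```python
class Model:
    """Hypothetical stand-in for a LightningModule that receives its data up front."""

    def __init__(self, train_dataset, batch_size, epochs):
        # Stored as attributes, so every hook (including configure_optimizers)
        # can see them regardless of the order in which hooks are called.
        self.train_dataset = train_dataset
        self.batch_size = batch_size
        self.epochs = epochs

    def num_training_steps(self):
        # ceil(len / batch_size) batches per epoch, matching a DataLoader with drop_last=False
        batches_per_epoch = -(-len(self.train_dataset) // self.batch_size)
        return batches_per_epoch * self.epochs


model = Model(train_dataset=list(range(1000)), batch_size=32, epochs=2)
print(model.num_training_steps())  # 64
```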

I would support this change. As far as I know, it should just change the order in which the dataset and optimizer methods are called, without impacting anything else.

Another note on these schedulers: I believe you have to override optimizer_step to step the scheduler on each update rather than per epoch (which is the only thing Lightning supports as of now).
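To illustrate why the stepping interval matters, here is a pure-Python simulation (no real optimizer; the counter stands in for `scheduler.step()`) comparing per-epoch stepping with stepping once per optimizer update while gradients are accumulated over `accumulate` batches:

```python
def count_scheduler_steps(num_batches, accumulate, epochs, interval="step"):
    """Count how many times scheduler.step() would run under each interval."""
    steps = 0
    for _ in range(epochs):
        for batch_idx in range(num_batches):
            # An optimizer update happens after every `accumulate` batches;
            # in this simplified loop a trailing partial window never updates.
            is_update = (batch_idx + 1) % accumulate == 0
            if interval == "step" and is_update:
                steps += 1
        if interval == "epoch":
            steps += 1
    return steps


print(count_scheduler_steps(100, 4, 3, "step"))   # 75
print(count_scheduler_steps(100, 4, 3, "epoch"))  # 3
```

A warmup schedule sized for 75 updates but stepped only 3 times would stay almost entirely in its warmup phase.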

@rmrao It seems the latest PL supports per-step LR schedulers, as of #941.

Ah ok great!

Fixed with #941, closing

I think that this issue should not be closed yet. From what I can see in the PR for #941 it started to support granular LR stepping, but it does not cover usage for something like the get_linear_schedule_with_warmup mentioned in the first post of this issue as creating such schedule requires access to number of epochs (or total number of steps). Am I missing something?
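To see why the total is required: after warmup, the multiplier's slope depends on `num_training_steps`. A pure-Python re-implementation for illustration (not the transformers source) of the linear warmup/decay multiplier that get_linear_schedule_with_warmup applies: a linear rise from 0 to 1 over the warmup steps, then a linear fall to 0 at the last training step:

```python
def linear_warmup_multiplier(step, num_warmup_steps, num_training_steps):
    """LR multiplier: ramps 0 -> 1 during warmup, then decays linearly to 0."""
    if step < num_warmup_steps:
        return step / max(1, num_warmup_steps)
    # The decay slope needs num_training_steps, which is the whole problem here.
    remaining = num_training_steps - step
    return max(0.0, remaining / max(1, num_training_steps - num_warmup_steps))


for step in (0, 5, 10, 55, 100):
    print(step, linear_warmup_multiplier(step, num_warmup_steps=10, num_training_steps=100))
```

A function of this shape can be passed to torch.optim.lr_scheduler.LambdaLR, e.g. with functools.partial binding the two step counts.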

@SkafteNicki pls ^^

Agree that PR #941 only covers the granular LR stepping.
Regarding the case with get_linear_schedule_with_warmup, I do not think lightning needs a specific feature to support this, since the user can already achieve this with a bit of code:

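from transformers import get_linear_schedule_with_warmup

def configure_optimizers(self):
    optimizer = ...
    train_steps = len(self.train_dataloader()) * self.hparams.max_epochs
    lr_scheduler = get_linear_schedule_with_warmup(
        optimizer, num_warmup_steps=int(0.1 * train_steps), num_training_steps=train_steps)
    return [optimizer], [{"scheduler": lr_scheduler, "interval": "step"}]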

where the trainer is then initialized with pl.Trainer(max_epochs=self.hparams.max_epochs) (and probably early stopping disabled).
If we want this to be a fully supported feature, then we need to expose the trainer to the model i.e. make it an argument to configure_optimizers(self, trainer), which I do not think is a good idea.

@SkafteNicki thanks for the response. That's almost exactly what I did:

    @lru_cache()
    def total_steps(self):
        return len(self.train_dataloader()) // self.hparams.accumulate_grad_batches * self.hparams.epochs

    def configure_optimizers(self):
        optimizer = AdamW(self.model.parameters(), lr=self.hparams.lr)
        lr_scheduler = get_linear_schedule_with_warmup(
                    optimizer,
                    num_warmup_steps=self.hparams.warmup_steps,
                    num_training_steps=self.total_steps(),
        )
        return [optimizer], [{"scheduler": lr_scheduler, "interval": "step"}]

If that's the "recommended" way of doing it then I'm fine with that :)

Any chance we can revive this issue?
I think this feature still needs to be supported, especially when you are doing distributed_backend="ddp|ddp2|horovod".

In this case

total_steps = len(self.train_dataloader()) // self.hparams.accumulate_grad_batches * self.hparams.epochs // num_distributed_processes

And num_distributed_processes is usually not specified in the arguments when running on a SLURM cluster. In addition, when users choose a different distributed backend (e.g. ddp vs. horovod), the way to get num_distributed_processes also differs (or you can get it from the trainer).
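A worked version of the formula above, with the process count passed explicitly as a hypothetical `num_processes` parameter (as the comment notes, where that number comes from depends on the backend). It assumes, as the formula does, that the dataloader length is the undivided length, measured before any per-process sampler splits the data:

```python
def total_steps(num_batches, accumulate_grad_batches, epochs, num_processes=1):
    # Same operator order as the formula: floor-divide by accumulation,
    # multiply by epochs, then floor-divide across processes.
    return num_batches // accumulate_grad_batches * epochs // num_processes


print(total_steps(1000, 2, 3))                   # 1500 on a single process
print(total_steps(1000, 2, 3, num_processes=4))  # 375 per process with 4 workers
```

Without the final division, each of the 4 processes would size its schedule for 1500 updates while performing only 375, so the learning rate would never finish decaying.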

I agree with @SkafteNicki that it's bad to pass the trainer into configure_optimizers(self, trainer).
What I'm imagining is that we could provide a special LambdaLR scheduler class that is configured in the training loop, so that total_steps can be passed in as a parameter to the lambda.
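One possible reading of that idea, sketched in plain Python: the user supplies a factory that takes `total_steps` and returns the lambda, and the training loop calls it once the total is known. `DeferredLambdaSchedule` and `configure` are hypothetical names, not a Lightning API:

```python
class DeferredLambdaSchedule:
    """Hypothetical holder for an LR lambda whose total_steps is filled in later."""

    def __init__(self, make_lambda):
        self.make_lambda = make_lambda  # callable: total_steps -> (step -> multiplier)
        self.lr_lambda = None

    def configure(self, total_steps):
        # Would be called by the training loop once dataloaders, accumulation
        # and device count are known.
        self.lr_lambda = self.make_lambda(total_steps)
        return self

    def multiplier(self, step):
        if self.lr_lambda is None:
            raise RuntimeError("configure(total_steps) must be called before use")
        return self.lr_lambda(step)


# The user describes only the shape of the schedule: linear decay to zero.
schedule = DeferredLambdaSchedule(lambda total: (lambda step: max(0.0, 1 - step / total)))
schedule.configure(total_steps=200)
print(schedule.multiplier(50))  # 0.75
```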

I am not sure how deeply this should be integrated into Lightning; it is, after all, a feature for specific types of schedulers (those that rely on knowing the total number of steps) in the specific case where the user does not know in advance how many distributed processes will be allocated. I will leave this up to the core team.

That said, the Callback system in Lightning already allows for this.
In the configure_optimizers method, we start by defining only the arguments we know at that point, using functools.partial:

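from functools import partial
import transformers
from torch import optim

def configure_optimizers(self):
    optimizer = optim.Adam(self.parameters(), lr=self.hparams.learning_rate)
    # Bind only what is known now; the callback fills in the step counts later
    scheduler = partial(transformers.get_cosine_schedule_with_warmup,
                        optimizer=optimizer, num_cycles=0.5)
    return [optimizer], [{'scheduler': scheduler, 'interval': 'step'}]  # step per update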

Then we define a callback that passes in the remaining arguments, which are available through the trainer, in the on_train_start method (i.e. before any training actually starts):

import pytorch_lightning as pl

class TransformerLrScheduler(pl.Callback):
    def on_train_start(self, trainer, pl_module):
        for lr_scheduler in trainer.lr_schedulers:
            # Only entries still holding a functools.partial need to be completed
            if callable(lr_scheduler['scheduler']):
                scheduler = lr_scheduler['scheduler']

                n_train = len(pl_module.train_dataloader())
                n_accumulate_grad = trainer.accumulate_grad_batches
                n_max_epochs = trainer.max_epochs
                # num_gpus is 0 on CPU, so clamp to 1; use trainer.tpu_cores on TPU
                n_devices = max(1, trainer.num_gpus)

                num_training_steps = n_train // n_accumulate_grad * n_max_epochs // n_devices
                num_warmup_steps = int(0.1 * num_training_steps)

                # Here we actually define the lr scheduler
                scheduler = scheduler(num_warmup_steps=num_warmup_steps,
                                      num_training_steps=num_training_steps)

                lr_scheduler['scheduler'] = scheduler

We then, of course, initialize the trainer with pl.Trainer(callbacks=[TransformerLrScheduler()]).
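The two-stage binding this callback relies on can be checked without Lightning. In this pure-Python sketch, the hypothetical `make_schedule` stands in for `transformers.get_cosine_schedule_with_warmup` (it only records its arguments), and a list of dicts stands in for `trainer.lr_schedulers`:

```python
from functools import partial


def make_schedule(optimizer, num_warmup_steps, num_training_steps, num_cycles=0.5):
    """Hypothetical stand-in that just records the arguments it was built with."""
    return {"optimizer": optimizer, "warmup": num_warmup_steps,
            "total": num_training_steps, "cycles": num_cycles}


# Stage 1 (configure_optimizers): bind what is known up front.
lr_schedulers = [{"scheduler": partial(make_schedule, optimizer="adam", num_cycles=0.5)}]

# Stage 2 (on_train_start): fill in step counts derived from the trainer.
n_train, n_accumulate, n_epochs, n_devices = 1000, 2, 3, 2
num_training_steps = n_train // n_accumulate * n_epochs // n_devices
for entry in lr_schedulers:
    if callable(entry["scheduler"]):
        entry["scheduler"] = entry["scheduler"](
            num_warmup_steps=int(0.1 * num_training_steps),
            num_training_steps=num_training_steps,
        )

print(lr_schedulers[0]["scheduler"])
```

After stage 2 the entry holds the constructed object, which is no longer callable, so running the loop a second time leaves it untouched.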

@williamFalcon thoughts?

I think we could easily fix the originally reported issue if configure_optimizers() were simply called after train_dataloader(). In that case, we could save the length of the dataloader in an attribute and use it in the configure_optimizers() method.
I don't like the previously mentioned solution with an explicit call to len(self.train_dataloader()), because I have to call prepare_data() first and construct the train dataloader one more, unnecessary time.
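A framework-free sketch of what that ordering would allow: the loader is built once, its length is cached on the module, and the optimizer hook only reads the cached value. `Module`, `build_count`, and `num_batches` are hypothetical names, and here `configure_optimizers` returns only the step count a scheduler would use rather than real optimizers:

```python
class Module:
    """Hypothetical module whose hooks run in the order the comment proposes."""

    def __init__(self, dataset, batch_size):
        self.dataset = dataset
        self.batch_size = batch_size
        self.build_count = 0   # how many times the loader was constructed
        self.num_batches = None

    def train_dataloader(self):
        self.build_count += 1
        loader = [self.dataset[i:i + self.batch_size]
                  for i in range(0, len(self.dataset), self.batch_size)]
        self.num_batches = len(loader)  # cached for configure_optimizers
        return loader

    def configure_optimizers(self, epochs):
        # Reads the cached length instead of rebuilding the loader.
        return self.num_batches * epochs


module = Module(list(range(100)), batch_size=10)
loader = module.train_dataloader()     # called first, as proposed
print(module.configure_optimizers(3))  # 30
print(module.build_count)              # 1
```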

This issue has been automatically marked as stale because it hasn't had any recent activity. This issue will be closed in 7 days if no further activity occurs. Thank you for your contributions, Pytorch Lightning Team!
