I am trying to use a scheduler whose step() function should be called on every iteration rather than every epoch (i.e., a 1-cycle-like scheduler). However, I am not sure what the best way to implement it is with the most recent version of pytorch-lightning.
I've seen some discussions and docs here (including the discussion about LR warm-up that proposes a possible solution), but I don't know if they are still relevant for the most recent version of the package. Also, they don't seem to address exactly the same question.
Currently, I am trying to do something similar to what the following snippet shows.
from typing import Callable, Optional

import pytorch_lightning as pl
from torch.optim import Optimizer


class Module(pl.LightningModule):
    ...

    def configure_optimizers(self):
        # create_optimizer / create_scheduler are my own helper functions
        opt = create_optimizer(self)
        sched = create_scheduler(total_steps=self.num_train_batches * self.num_epochs)
        # keep a reference so the scheduler can be stepped manually below
        self.scheduler = sched
        return {'optimizer': opt, 'lr_scheduler': sched}

    def on_batch_end(self):
        # step the scheduler after every training batch
        self.scheduler.step()

    def on_epoch_end(self):
        pass  # do nothing

    def optimizer_step(
        self,
        epoch: int,
        batch_idx: int,
        optimizer: Optimizer,
        optimizer_idx: int,
        second_order_closure: Optional[Callable] = None,
    ):
        optimizer.step()
        optimizer.zero_grad()

    def training_step(self, batch, batch_no):
        loss = self(batch)
        return {'loss': loss}
So I am calling step() "manually" at the end of every batch. But am I doing it right? Will the Trainer also call my scheduler at the end of each epoch? (That is not something I would expect.) What is the "right way" to implement an iteration-based scheduler?
My goal is to configure the pl.LightningModule so that the scheduler is stepped at the end of every training batch, is not called during validation, and is not called again at the end of the epoch.
Environment: pytorch-lightning 0.7.3, installed via conda.
By default, the scheduler is updated after every epoch. You can change the interval, I guess:
opt = create_optimizer(self)
sched = {
    'scheduler': create_scheduler(total_steps=self.num_train_batches * self.num_epochs),
    'interval': 'step'
}
return {'optimizer': opt, 'lr_scheduler': sched}
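For anyone landing here later, below is a minimal sketch of what that configuration can look like end-to-end, assuming torch.optim.lr_scheduler.OneCycleLR as the per-step scheduler; the Adam optimizer and the learning-rate values are illustrative placeholders, and self.num_train_batches / self.num_epochs are assumed to be attributes set elsewhere on the module, as in the snippet above.

import torch
import pytorch_lightning as pl


class Module(pl.LightningModule):
    ...

    def configure_optimizers(self):
        # illustrative optimizer and learning rate, not prescribed by Lightning
        opt = torch.optim.Adam(self.parameters(), lr=1e-3)
        sched = {
            # OneCycleLR is designed to be stepped once per batch
            'scheduler': torch.optim.lr_scheduler.OneCycleLR(
                opt,
                max_lr=1e-3,
                total_steps=self.num_train_batches * self.num_epochs,
            ),
            # 'step' makes Lightning call sched.step() after every training
            # batch instead of once per epoch
            'interval': 'step',
        }
        return {'optimizer': opt, 'lr_scheduler': sched}

With the interval set to 'step', there should be no need for the manual on_batch_end / optimizer_step overrides from the original snippet; Lightning steps the scheduler itself after each training batch.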
@rohitgr7 that's great! I didn't know about this option. Do you know if there is a doc page describing the available options and parameters for configuring the Trainer? Or is this currently only documented in the docstrings and the code?
In the docs there is a Trainer section; that section lists all the flags.
It's a little bit buried, but I was looking for this yesterday and found it here under configure_optimizers() at the bottom of the _note_ https://pytorch-lightning.readthedocs.io/en/latest/lightning-module.html#lightningmodule-class
@williamFalcon close this?