I could be wrong, but I noticed the following in the code of the LightningModule's `optimizer_step`:
```python
if on_tpu:
    xm.optimizer_step(optimizer)
elif using_native_amp:
    self.trainer.scaler.step(optimizer)
elif using_lbfgs:
    optimizer.step(second_order_closure)
else:
    optimizer.step()
```
If someone uses a custom optimizer whose `step` needs to call the closure (re-evaluating the loss) multiple times, this won't work: the closure is only forwarded in the LBFGS branch.
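For example, here is a minimal sketch of such an optimizer, in the style of a backtracking line search; `MultiEvalOptimizer` and its internals are illustrative names of mine, not anything from the Lightning or PyTorch codebases:

```python
import torch

class MultiEvalOptimizer(torch.optim.Optimizer):
    """Hypothetical optimizer: step() re-evaluates the closure several
    times (backtracking line search), much like torch.optim.LBFGS."""

    def __init__(self, params, lr=1.0):
        super().__init__(params, dict(lr=lr))

    @torch.no_grad()
    def step(self, closure):
        with torch.enable_grad():
            loss = closure()  # first call: computes loss and gradients
        params = [p for g in self.param_groups for p in g["params"]]
        # Freeze the descent direction so trial steps can be undone exactly,
        # even though later closure calls overwrite p.grad.
        direction = [p.grad.clone() if p.grad is not None else None
                     for p in params]
        lr = self.param_groups[0]["lr"]
        for _ in range(5):  # shrink the step until the loss improves
            for p, d in zip(params, direction):
                if d is not None:
                    p.add_(d, alpha=-lr)
            with torch.enable_grad():
                trial_loss = closure()  # the closure is needed again here
            if trial_loss < loss:
                return trial_loss
            for p, d in zip(params, direction):  # undo the trial step
                if d is not None:
                    p.add_(d, alpha=lr)
            lr *= 0.5
        return loss
```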
Since every class that inherits from `torch.optim.Optimizer` accepts a closure in its `step` method (even if it doesn't use it), we could just do:
```python
if on_tpu:
    xm.optimizer_step(optimizer)
elif using_native_amp:
    self.trainer.scaler.step(optimizer)
# elif using_lbfgs:
#     optimizer.step(second_order_closure)
else:
    optimizer.step(second_order_closure)
```
and drop the `using_lbfgs` argument?
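To illustrate the point about stock optimizers: a closure can already be passed to e.g. `torch.optim.SGD`, whose `step` simply calls it once to obtain the loss. This toy example is mine, not from the issue:

```python
import torch

w = torch.nn.Parameter(torch.tensor([3.0]))
opt = torch.optim.SGD([w], lr=0.1)

def closure():
    opt.zero_grad()
    loss = (w ** 2).sum()
    loss.backward()
    return loss

# SGD doesn't need a closure, but step() accepts one and calls it once,
# so passing second_order_closure unconditionally is harmless.
loss = opt.step(closure)
print(loss.item(), w.item())  # loss 9.0; w moves from 3.0 to 2.4
```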
Currently, the user has to override `optimizer_step` themselves to make such optimizers work; a sketch of that workaround follows below.
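A minimal sketch of that workaround, assuming the `optimizer_step` signature matches the flags quoted in the snippet above (the exact parameters may differ between versions):

```python
import pytorch_lightning as pl

class MyModule(pl.LightningModule):
    # Signature assumed from the flags referenced in the snippet above;
    # check your installed version for the exact parameters.
    def optimizer_step(self, epoch, batch_idx, optimizer, optimizer_idx,
                       second_order_closure=None, on_tpu=False,
                       using_native_amp=False, using_lbfgs=False):
        # Always forward the closure; optimizers that ignore it are unaffected.
        optimizer.step(second_order_closure)
```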
It sounds like a good way to go to me...
cc: @williamFalcon
We had that before but removed it for some reason; I can't remember why. But I originally agreed.
Let's do this after 0.9.0 since we need to dig back into why it changed.
@awaelchli want to take a look at this now? or better post v1?
@williamFalcon LBFGS not being compatible with native amp is the only exception I found in the code. Maybe you mean that?
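That incompatibility comes from native amp's `GradScaler`, whose `step` does not support closures, while LBFGS requires one. A minimal sketch of the standard native-amp step, assuming a CUDA device is available (toy model and data are mine):

```python
import torch

model = torch.nn.Linear(4, 1).cuda()
opt = torch.optim.SGD(model.parameters(), lr=0.1)
scaler = torch.cuda.amp.GradScaler()

x = torch.randn(8, 4, device="cuda")
y = torch.randn(8, 1, device="cuda")

with torch.cuda.amp.autocast():
    loss = torch.nn.functional.mse_loss(model(x), y)

scaler.scale(loss).backward()
scaler.step(opt)   # GradScaler.step() raises if given a closure,
scaler.update()    # hence closure-based LBFGS can't use native amp
```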
@edenlightning I don't think it's essential for v1.0, but if you wish to have the final API without the `using_lbfgs` argument for `optimizer_step`, then I could send a PR.