Pytorch-lightning: Incorrect Implementation for Accumulating Batch Gradients in Trainer

Created on 10 Aug 2019 · 7 comments · Source: PyTorchLightning/pytorch-lightning

Current Behavior:
If accumulate_grad_batches is set above the default of 1, the Trainer takes the loss from each batch and runs loss.backward() for every accumulated batch, calling optimizer.step() once the desired number of batches has been backpropagated.
Loss averaging is only done for the reported batch_loss_value, not for the gradients themselves.

Correct Behavior:
The loss from the output needs to be divided by accumulate_grad_batches before loss.backward() is run; otherwise the overall magnitude of the accumulated gradient can be up to N times greater for a simulated batch size N times bigger than the actual one.
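To make the magnitude argument concrete, here is a minimal, self-contained sketch (not taken from the Lightning code base; the linear model, random data, and the split into N = 4 micro-batches are purely illustrative). With a mean-reduced loss, accumulating unscaled backward calls inflates the gradient by a factor of N, while scaling each loss by 1/N reproduces the gradient of a single pass over the full simulated batch:

```python
import torch

torch.manual_seed(0)
model = torch.nn.Linear(4, 1)
data, target = torch.randn(8, 4), torch.randn(8, 1)
loss_fn = torch.nn.MSELoss()  # mean reduction

# Reference: one backward pass over the full "simulated" batch of 8 samples.
model.zero_grad()
loss_fn(model(data), target).backward()
ref_grad = model.weight.grad.clone()

# Accumulate over N = 4 micro-batches without scaling the loss.
model.zero_grad()
for x, y in zip(data.chunk(4), target.chunk(4)):
    loss_fn(model(x), y).backward()
print((model.weight.grad.norm() / ref_grad.norm()).item())  # ~4.0

# Scale each micro-batch loss by 1/N before backward.
model.zero_grad()
for x, y in zip(data.chunk(4), target.chunk(4)):
    (loss_fn(model(x), y) / 4).backward()
print((model.weight.grad.norm() / ref_grad.norm()).item())  # ~1.0
```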

bug / fix

All 7 comments

@AS-researcher6 good find, I see the bug. The division is done once the step is applied, but it should happen on line 926.

Current order:

  1. clip
  2. step
  3. zero_grad
  4. divide accumulated loss by nb accumulated batches

Correct order:

  1. divide accumulated loss by nb accumulated batches
  2. clip
  3. step
  4. zero_grad

Submitting the change.
Mind sanity-checking PR #88?

Gladly! Not sure if you're done fixing things, but I don't see the commit where the loss is divided by self.accumulate_grad_batches before loss.backward(). Otherwise, the original version with self.batch_loss_value += loss.item() is fine as long as self.batch_loss_value is not averaged afterwards.

Look at PR #88, line 928.

Sorry, I don't think it's quite right yet. Accounting for multiple batches needs to happen in the loss itself before backpropagation; otherwise it's like N accumulated additions of loss.backward() rather than N accumulated additions of (1/N * loss).backward().
After line 898 and before the loss.backward() call in the if statement on line 903, there needs to be a line loss = loss / self.accumulate_grad_batches. Line 922 should still be self.batch_loss_value += loss.item(). Line 928 can be removed entirely, since the averaging has already been accounted for.
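For reference, here is a minimal sketch of the accumulation pattern being described, written as a generic PyTorch training loop rather than the actual Trainer internals (the toy model, optimizer, loader, and the max_norm value are assumptions for illustration; only the name accumulate_grad_batches mirrors the discussion):

```python
import torch

# Toy setup; in Lightning this would come from the LightningModule and Trainer.
model = torch.nn.Linear(4, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
loss_fn = torch.nn.MSELoss()
train_loader = [(torch.randn(2, 4), torch.randn(2, 1)) for _ in range(8)]

accumulate_grad_batches = 4
batch_loss_value = 0.0  # kept only for reporting, as in the discussion

optimizer.zero_grad()
for batch_idx, (x, y) in enumerate(train_loader):
    loss = loss_fn(model(x), y)

    # The fix: scale before backward, so N accumulated calls behave like
    # N additions of (1/N * loss).backward().
    loss = loss / accumulate_grad_batches
    loss.backward()

    batch_loss_value += loss.item()  # already scaled; no later division needed

    if (batch_idx + 1) % accumulate_grad_batches == 0:
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)  # clip
        optimizer.step()                                                  # step
        optimizer.zero_grad()                                             # zero_grad
        batch_loss_value = 0.0
```

The ordering inside the if block matches the clip → step → zero_grad sequence listed above, with the division moved ahead of loss.backward().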

@AS-researcher6 I see, updated. I was following the wrong approach before.

How about now?

@AS-researcher6 Merged to master. It should be correct now; please verify.

All good. Thanks for the fixes!
I'm working on a SLURM cluster myself, so I may submit pull requests for expanded Trainer functionality and bring up more things in the future.
