Is there a way to define a number N so that the backward update is performed only every N iterations? It is a common scenario that we cannot increase the batch size due to memory constraints, and we would like to perform fewer updates for faster training.
@MottoX there are several ways to do this. The simplest is the following:
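```python
from ignite.engine import Engine

# model, optimizer, criterion, prepare_batch, device and non_blocking
# are assumed to be defined as for a regular supervised trainer.
accumulation_steps = 4

def update_fn(engine, batch):
    model.train()
    x, y = prepare_batch(batch, device=device, non_blocking=non_blocking)
    y_pred = model(x)
    # Scale the loss so the accumulated gradient averages over accumulation_steps batches
    loss = criterion(y_pred, y) / accumulation_steps
    loss.backward()  # gradients keep adding up in .grad across iterations
    if engine.state.iteration % accumulation_steps == 0:
        optimizer.step()
        # Reset only after the step; zeroing before backward() would discard earlier batches
        optimizer.zero_grad()
    return loss.item()

trainer = Engine(update_fn)
```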
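To see why the loss is divided by `accumulation_steps` and why `zero_grad()` has to come after `step()`, here is a small framework-free sketch (no PyTorch or Ignite involved). It assumes a hypothetical 1-D linear model `y = w * x` with mean-squared-error loss and hand-written gradients, and equally sized micro-batches; with unequal sizes, dividing by `accumulation_steps` only approximates the full-batch mean. The helper names are illustrative and not part of any library.

```python
# Hypothetical 1-D linear model y = w * x with MSE loss and manual gradients.

def grad_mse(w, xs, ys):
    """Gradient of mean((w*x - y)^2) with respect to w over one micro-batch."""
    return sum(2 * (w * x - y) * x for x, y in zip(xs, ys)) / len(xs)

def full_batch_step(w, xs, ys, lr):
    """One SGD step on the whole batch at once (what we cannot fit in memory)."""
    return w - lr * grad_mse(w, xs, ys)

def accumulated_step(w, micro_batches, lr):
    """Accumulate scaled gradients over all micro-batches, then step once."""
    accumulation_steps = len(micro_batches)
    grad = 0.0  # plays the role of .grad; cleared only after the step
    for xs, ys in micro_batches:
        grad += grad_mse(w, xs, ys) / accumulation_steps  # like loss / accumulation_steps
    return w - lr * grad

def buggy_step(w, micro_batches, lr):
    """Mimics calling zero_grad() on the step iteration, before backward()."""
    accumulation_steps = len(micro_batches)
    grad = 0.0
    for i, (xs, ys) in enumerate(micro_batches, start=1):
        if i % accumulation_steps == 0:
            grad = 0.0  # earlier micro-batch gradients are thrown away here
        grad += grad_mse(w, xs, ys) / accumulation_steps
    return w - lr * grad

xs = [1.0, 2.0, 3.0, 4.0]
ys = [2.0, 4.1, 5.9, 8.2]
micro_batches = [(xs[:2], ys[:2]), (xs[2:], ys[2:])]
print(full_batch_step(0.5, xs, ys, lr=0.01))
print(accumulated_step(0.5, micro_batches, lr=0.01))
print(buggy_step(0.5, micro_batches, lr=0.01))
```

With equally sized micro-batches, `accumulated_step` and `full_batch_step` agree up to floating-point rounding, while `buggy_step` keeps only the last micro-batch's gradient.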
Thank you!