Ignite: Accumulate loss and perform a backward update after several iterations?

Created on 1 Apr 2019 · 2 Comments · Source: pytorch/ignite

Is there a way to define a number N so that the backward update on the loss is performed only after N iterations? It is a common scenario that we cannot increase the batch size due to memory constraints, and we would like to do fewer loss computations for faster training.

question

All 2 comments

@MottoX there are several ways to do this. The simplest is the following:


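from ignite.engine import Engine

# Assumes model, optimizer, criterion, device, non_blocking and prepare_batch
# are already defined elsewhere.
accumulation_steps = 4

def update_fn(engine, batch):
    model.train()

    x, y = prepare_batch(batch, device=device, non_blocking=non_blocking)
    y_pred = model(x)
    # Scale the loss so the accumulated gradient matches an average over N batches
    loss = criterion(y_pred, y) / accumulation_steps
    loss.backward()

    # engine.state.iteration starts at 1, so this fires on iterations N, 2N, ...
    if engine.state.iteration % accumulation_steps == 0:
        optimizer.step()
        # Reset gradients only after the update, so earlier batches still contribute
        optimizer.zero_grad()

    return loss.item()

trainer = Engine(update_fn)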
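
To see why the loss is divided by accumulation_steps, here is a minimal sketch in plain Python (no PyTorch or Ignite needed). Everything in it is a made-up illustration: a toy 1-D model y_pred = w * x, a mean squared error loss, and the helper names mse_grad and accumulated_grad. It assumes the criterion averages over the batch and that all micro-batches have the same size; under those assumptions the accumulated gradient should match the gradient of one large batch.

```python
# Hypothetical toy setup: model y_pred = w * x with a mean squared error loss.
# Plain Python stands in for autograd so the arithmetic is easy to follow.

def mse_grad(w, batch):
    """Gradient of mean((w * x - y) ** 2) over `batch` with respect to w."""
    return sum(2.0 * (w * x - y) * x for x, y in batch) / len(batch)


def accumulated_grad(w, data, micro_batch_size):
    """Add up gradients of (loss / accumulation_steps) over equal-sized micro-batches."""
    micro_batches = [
        data[i:i + micro_batch_size]
        for i in range(0, len(data), micro_batch_size)
    ]
    accumulation_steps = len(micro_batches)
    grad = 0.0  # plays the role of the .grad buffer that backward() adds into
    for micro_batch in micro_batches:
        grad += mse_grad(w, micro_batch) / accumulation_steps
    return grad


# Eight (x, y) samples, processed either as one batch or as four micro-batches of two.
data = [(1.0, 2.0), (2.0, 3.0), (3.0, 7.0), (4.0, 8.0),
        (5.0, 9.0), (6.0, 13.0), (7.0, 15.0), (8.0, 15.0)]
w = 0.5

full = mse_grad(w, data)
accum = accumulated_grad(w, data, micro_batch_size=2)
print(full, accum)
```

Without the division, the accumulated gradient would be accumulation_steps times larger, which acts like an unintended increase of the learning rate. If the last micro-batch is smaller than the others, the two values no longer match exactly.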

Thank you!
