Ignite: add gradient clipping to `create_supervised_trainer()`

Created on 30 Jan 2019 · 6 Comments · Source: pytorch/ignite

It would be good to add gradient clipping to the trainers created by create_supervised_trainer. The clipping itself is already provided by torch.nn.utils.clip_grad_norm_, so only the trainer's update function needs to call it.

One possible implementation could be:

import math
from torch.nn.utils import clip_grad_norm_

def create_supervised_trainer(model, optimizer, loss_fn,
                              device=None, non_blocking=False,
                              prepare_batch=_prepare_batch,
                              gradient_clip=math.inf):
    """
    Factory function for creating a trainer for supervised models.
    Args:
        model (`torch.nn.Module`): the model to train.
        optimizer (`torch.optim.Optimizer`): the optimizer to use.
        loss_fn (torch.nn loss function): the loss function to use.
        device (str, optional): device type specification (default: None).
            Applies to both model and batches.
        non_blocking (bool, optional): if True and this copy is between CPU and GPU, the copy may occur asynchronously
            with respect to the host. For other cases, this argument has no effect.
        prepare_batch (callable, optional): function that receives `batch`, `device`, `non_blocking` and outputs
            tuple of tensors `(batch_x, batch_y)`.
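        gradient_clip (float, optional): maximum norm of the gradients, passed to `clip_grad_norm_` as `max_norm`
            (default: math.inf, which leaves the gradients unchanged).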
    Note: `engine.state.output` for this engine is the loss of the processed batch.
    Returns:
        Engine: a trainer engine with supervised update function.
    """
    if device:
        model.to(device)

    def _update(engine, batch):
        model.train()
        optimizer.zero_grad()
        x, y = prepare_batch(batch, device=device, non_blocking=non_blocking)
        y_pred = model(x)
        loss = loss_fn(y_pred, y)
        loss.backward()
        clip_grad_norm_(model.parameters(), gradient_clip)
        optimizer.step()
        return loss.item()

    return Engine(_update)
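
To illustrate why `math.inf` works as a "no clipping" default in the proposal above, here is a minimal sketch using only PyTorch. The `grad_norm` helper and the toy model are assumptions made for illustration; they are not part of ignite or of the original proposal.

```python
import math

import torch
from torch.nn.utils import clip_grad_norm_


def grad_norm(parameters):
    """Hypothetical helper: total L2 norm over all parameter gradients."""
    return torch.norm(torch.stack([p.grad.norm(2) for p in parameters])).item()


torch.manual_seed(0)
model = torch.nn.Linear(4, 1)
loss = model(torch.randn(8, 4)).pow(2).sum()
loss.backward()

norm_before = grad_norm(model.parameters())

# An infinite threshold leaves the gradients untouched.
clip_grad_norm_(model.parameters(), math.inf)
norm_unclipped = grad_norm(model.parameters())

# A finite threshold rescales the gradients so their total norm is at most max_norm.
max_norm = 0.5 * norm_before
clip_grad_norm_(model.parameters(), max_norm)
norm_clipped = grad_norm(model.parameters())
```

With this default, existing callers of the trainer would see no change in behaviour unless they pass a finite `gradient_clip`.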

All 6 comments

@lmarti thanks for the feedback. We discussed a similar question in #375.
Methods like create_supervised_trainer are just helpers for basic usage. For anything beyond that, use Engine directly with a custom process_fn.

We can discuss whether such a trainer would be useful and whether it should be placed in the contrib.engines module.
cc @willprice

Sorry, I missed that one. I had the same doubts w.r.t. moving it to contrib.engines. My point against doing it is that the code would be so similar to the one in create_supervised_trainer. In any case, you are driving here.

A general way to maintain this would be to fire a new event (GRADIENT_COMPUTED?) between loss.backward() and optimizer.step()

It doesn't have to be added to the core events; it can be added just for supervised_trainer, as we did with supervised_tbptt_trainer.

@AntoinePrv I think it would be simpler to write a custom processing function instead of custom events.

@vfdev-5 While I agree with you, it would be nice to have options. In particular, it would be great if we could have more events, comparable to the fastai callback system. The callbacks listed there are (corresponding ignite events in parentheses):

  1. on_train_begin() (Events.STARTED)
  2. on_epoch_begin() (Events.EPOCH_STARTED)
  3. on_batch_begin() (Events.ITERATION_STARTED)
  4. on_loss_begin()*: Called after forward pass but before loss has been computed.
  5. on_backward_begin()*: Called after forward pass and loss computation but before backprop.
  6. on_backward_end()*: Called after backprop but before optimizer step.
  7. on_step_end()*: Called after optimizer step but before gradients are zeroed.
  8. on_batch_end() (Events.ITERATION_COMPLETED)
  9. on_epoch_end() (Events.EPOCH_COMPLETED)
  10. on_train_end() (Events.COMPLETED)
  \* These fastai callbacks have no corresponding ignite events.

Having these as options provides the following advantages:
  • It adds even more flexibility to the engine.
  • Many fastai features, such as LRFinder and gradient clipping, are implemented as callbacks. It would be easy to port those over if we had these events.

@sudarshan85 we can think about providing a generic callback class in the contrib module.
But I can hardly imagine a class that uses all of these on_* methods. The example you cited, LRFinder, implements just 3 methods: on_train_begin, on_batch_end, on_train_end. This is very similar to the behaviour of our classes with an attach method, such as Metric and ProgressBar, which handle 2-3 events of the Engine.
