Ignite: Improve create_supervised_trainer with optional output_transform

Created on 4 Apr 2019 · 2 Comments · Source: pytorch/ignite

Following the discussion, the idea is to give more flexibility to users of create_supervised_trainer by adding an optional output_transform argument:

# Assumed imports for running this outside ignite/engine/__init__.py
from ignite.engine import Engine, _prepare_batch


def default_output_transform(x, y, y_pred, loss):
    # Default keeps the existing behaviour: return the loss as a Python float.
    return loss.item()


def create_supervised_trainer(model, optimizer, loss_fn,
                              device=None, non_blocking=False, prepare_batch=_prepare_batch,
                              output_transform=default_output_transform):
    if device:
        model.to(device)

    def _update(engine, batch):
        model.train()
        optimizer.zero_grad()
        x, y = prepare_batch(batch, device=device, non_blocking=non_blocking)
        y_pred = model(x)
        loss = loss_fn(y_pred, y)
        loss.backward()
        optimizer.step()
        # The user-supplied transform decides what ends up in engine.state.output.
        return output_transform(x, y, y_pred, loss)

    return Engine(_update)
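To illustrate what the extra argument would allow, here is a minimal sketch of a custom transform that keeps the predictions and targets alongside the scalar loss. The names `loss_and_predictions` and `ScalarLoss` are hypothetical; `ScalarLoss` is only a stand-in for a zero-dimensional `torch.Tensor`, so the snippet runs without PyTorch installed.

```python
class ScalarLoss:
    """Hypothetical stand-in for a 0-dim torch.Tensor; only .item() is mimicked."""

    def __init__(self, value):
        self._value = float(value)

    def item(self):
        return self._value


def loss_and_predictions(x, y, y_pred, loss):
    """Hypothetical transform: keep predictions and targets next to the loss."""
    return {"loss": loss.item(), "y_pred": y_pred, "y": y}


# Mimic the last line of `_update` under the proposed four-argument signature.
output = loss_and_predictions(
    x=[1.0, 2.0], y=[0, 1], y_pred=[0.2, 0.9], loss=ScalarLoss(0.35)
)
print(output["loss"])
```

Under the proposal, such a function would be passed as `output_transform=loss_and_predictions`, and handlers would then read the dictionary from `engine.state.output` instead of a bare float.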

cc @IlyaOvodov

enhancement help wanted

All 2 comments

@vfdev-5
But I think it would be better to make two functions, output_loss and output_y_pred_y, and use them as the defaults:

def create_supervised_trainer(..., output_transform=output_loss)
def create_supervised_evaluator(..., output_transform=output_y_pred_y)

IMHO, that is more meaningful.
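A rough sketch of what the two named defaults could look like follows. The comment above only gives their names, so the bodies and signatures here are assumptions: in particular, the evaluator computes no loss, so `output_y_pred_y` is assumed to take three arguments. `ScalarLoss` is again a hypothetical stand-in for a loss tensor.

```python
def output_loss(x, y, y_pred, loss):
    """Assumed trainer default: reduce the step to its scalar loss."""
    return loss.item()


def output_y_pred_y(x, y, y_pred):
    """Assumed evaluator default: return the (y_pred, y) pair."""
    return y_pred, y


class ScalarLoss:
    """Hypothetical stand-in for a 0-dim torch.Tensor; only .item() is mimicked."""

    def __init__(self, value):
        self._value = float(value)

    def item(self):
        return self._value


trainer_output = output_loss([1.0], [1], [0.8], ScalarLoss(0.5))
evaluator_output = output_y_pred_y([1.0], [1], [0.8])
print(trainer_output, evaluator_output)
```

Named functions like these make the default behaviour visible in the signature, and the `(y_pred, y)` ordering matches what Ignite metrics consume by default.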

I'll implement it.
