Following the discussion, the idea is to give more flexibility to users of `create_supervised_trainer` by letting them choose what the update function returns:
```python
from ignite.engine import Engine, _prepare_batch


def default_output_transform(x, y, y_pred, loss):
    # By default the engine's output is just the scalar loss value
    return loss.item()


def create_supervised_trainer(model, optimizer, loss_fn,
                              device=None, non_blocking=False, prepare_batch=_prepare_batch,
                              output_transform=default_output_transform):
    if device:
        model.to(device)

    def _update(engine, batch):
        model.train()
        optimizer.zero_grad()
        x, y = prepare_batch(batch, device=device, non_blocking=non_blocking)
        y_pred = model(x)
        loss = loss_fn(y_pred, y)
        loss.backward()
        optimizer.step()
        # Let the user decide what the engine's output contains
        return output_transform(x, y, y_pred, loss)

    return Engine(_update)
```
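To illustrate what the `output_transform` hook buys, here is a minimal, framework-free sketch. It does not use ignite or torch: `FakeLoss` and `make_update` are hypothetical stand-ins for a scalar loss tensor and a simplified `_update` closure (no optimizer, device handling or backward pass), so the example runs anywhere.

```python
class FakeLoss:
    """Hypothetical stand-in for a scalar torch tensor; only .item() is needed."""

    def __init__(self, value):
        self.value = value

    def item(self):
        return self.value


def default_output_transform(x, y, y_pred, loss):
    # Same default as proposed: expose only the scalar loss
    return loss.item()


def make_update(model, loss_fn, output_transform=default_output_transform):
    # Hypothetical simplification of _update: forward pass, loss, then the transform
    def _update(batch):
        x, y = batch
        y_pred = model(x)
        loss = loss_fn(y_pred, y)
        return output_transform(x, y, y_pred, loss)
    return _update


def model(x):
    # Toy "model": doubles every input
    return [2 * v for v in x]


def loss_fn(y_pred, y):
    # Toy L1 loss wrapped in FakeLoss
    return FakeLoss(sum(abs(a - b) for a, b in zip(y_pred, y)))


batch = ([1, 2, 3], [2, 4, 7])

default_update = make_update(model, loss_fn)
print(default_update(batch))  # 1

# A user-supplied transform can also expose predictions and targets,
# e.g. so that metrics can be attached to the trainer itself.
rich_update = make_update(
    model, loss_fn,
    output_transform=lambda x, y, y_pred, loss: {"loss": loss.item(), "y_pred": y_pred, "y": y},
)
print(rich_update(batch))
```

The default keeps the existing behaviour (only the loss), while callers who need more pay for it explicitly.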
cc @IlyaOvodov
@vfdev-5
But I think it would be better to make two functions, `output_loss` and `output_y_pred_y`, and use them as defaults:
```python
def create_supervised_trainer(..., output_transform=output_loss)
def create_supervised_evaluator(..., output_transform=output_y_pred_y)
```
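A minimal sketch of the two proposed defaults, under stated assumptions: the comment only names `output_loss` and `output_y_pred_y`, so their bodies here are guesses. `output_loss` mirrors `default_output_transform`, and `output_y_pred_y` assumes the evaluator computes no loss and therefore receives only `x, y, y_pred`. `FakeLoss` and `accuracy` are hypothetical helpers so the example runs without torch or ignite.

```python
class FakeLoss:
    """Hypothetical stand-in for a scalar torch tensor."""

    def __init__(self, value):
        self.value = value

    def item(self):
        return self.value


def output_loss(x, y, y_pred, loss):
    # Assumed trainer default: the scalar loss value
    return loss.item()


def output_y_pred_y(x, y, y_pred):
    # Assumed evaluator default: the (y_pred, y) pair that metrics consume
    return y_pred, y


def accuracy(output):
    # Hypothetical metric that unpacks the (y_pred, y) pair
    y_pred, y = output
    return sum(p == t for p, t in zip(y_pred, y)) / len(y)


print(output_loss(None, [1], [1], FakeLoss(0.5)))  # 0.5
print(accuracy(output_y_pred_y(None, [1, 0, 1, 1], [1, 1, 1, 1])))  # 0.75
```

Naming each default after what it returns makes the trainer's and evaluator's outputs self-documenting at the call site.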
It is more meaningful, IMHO.
I'll make the change.