Ignite: create_supervised_trainer fails if the model's device differs from the device argument

Created on 18 Jun 2018 · 4 Comments · Source: pytorch/ignite

Feature or bug to discuss. Suppose I modify the mnist.py example as follows:

def run(train_batch_size, val_batch_size, epochs, lr, momentum, log_interval):
    train_loader, val_loader = get_data_loaders(train_batch_size, val_batch_size)
    model = Net()
    device = 'cpu'

    if torch.cuda.is_available():
        device = 'cuda'
        # Deliberately left commented out: the model stays on the CPU.
        # model = model.to(device)

    optimizer = SGD(model.parameters(), lr=lr, momentum=momentum)
    trainer = create_supervised_trainer(model, optimizer, F.nll_loss, device=device)
    evaluator = create_supervised_evaluator(model,
                                            metrics={'accuracy': CategoricalAccuracy(),
                                                     'nll': Loss(F.nll_loss)},
                                            device=device)

The call model(x) will fail because x and model are on different devices. Maybe we need to:

  • check this and assert or
  • remove device from the arguments and set it up internally according to the model or
  • set model on the provided device
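
As a rough illustration of the first option (check and assert), here is a minimal sketch. The helper name `assert_model_on_device` is hypothetical and not part of ignite; it assumes `model` is a `torch.nn.Module` and only compares the device kind ('cpu' vs 'cuda'), not the device index.

```python
def assert_model_on_device(model, device):
    """Fail early when `model` and the requested `device` disagree.

    Hypothetical helper, not part of ignite. `model` is assumed to be a
    torch.nn.Module; `device` may be a string such as 'cpu' or 'cuda:0',
    a torch.device, or None.
    """
    if device is None:
        return  # nothing was requested, so there is nothing to compare

    # Keep only the device kind: 'cuda:0' -> 'cuda'.
    expected = str(device).split(":")[0]

    for name, param in model.named_parameters():
        # torch.device exposes the device kind ('cpu', 'cuda') as `.type`.
        if param.device.type != expected:
            raise ValueError(
                "parameter {!r} is on {}, but device={!r} was requested".format(
                    name, param.device, device
                )
            )
```

Calling this just before `create_supervised_trainer(...)` in the snippet above would turn the failure inside `model(x)` into an immediate error that names the offending parameter.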

@alykhantejani @elanmart @jasonkriss What do you think?

0.1.0

All 4 comments

I don't think this is a huge issue, and I don't think we should take care of setting the correct device for the user.

One thing we could perhaps do is add device='auto', which would set device to the device of the model parameters, provided all parameters live on the same device.
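
To make the `device='auto'` idea concrete, here is one possible sketch. Both function names are hypothetical, and `'auto'` is only the value proposed in this comment, not an existing ignite option. The code assumes `model` is a `torch.nn.Module`, whose parameters each carry a hashable `.device` attribute.

```python
def infer_model_device(model):
    """Return the single device shared by all parameters of `model`.

    Hypothetical helper. Raises ValueError if the model has no parameters
    or if its parameters are spread over several devices.
    """
    devices = {param.device for param in model.parameters()}
    if len(devices) != 1:
        raise ValueError(
            "cannot infer a unique device, found: {}".format(
                sorted(str(d) for d in devices)
            )
        )
    return devices.pop()


def resolve_device(model, device):
    """Map the proposed device='auto' to the model's own device."""
    if device == "auto":
        return infer_model_device(model)
    return device  # explicit values (including None) pass through unchanged
```

With something like this, a factory could call `resolve_device(model, device)` once and use the result when moving batches, so that batches always follow the model.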

I agree that this is not a big deal. It's just that if someone follows the documentation as written:

device (optional): device type specification (default: None)

they would wonder why it is applied to batches only...

Oh, yeah, on second thought we could actually go with

  • set model on the provided device

or rename the arg / clarify the docs.

Yeah, I think we should set the model on the provided device. This would fix the bug, and it makes sense because these factory functions should make things as easy as possible.
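
A minimal sketch of what "set the model on the provided device" might look like, written as a wrapper rather than a change to ignite itself. `make_trainer` is a hypothetical name; `factory` stands for a callable with the call shape used in the issue, `factory(model, optimizer, loss_fn, device=device)`, such as `create_supervised_trainer`.

```python
def make_trainer(factory, model, optimizer, loss_fn, device=None):
    """Move `model` to `device` (when given), then delegate to `factory`.

    Hypothetical wrapper, not ignite's implementation. `model` is assumed
    to be a torch.nn.Module, whose `.to(device)` moves its parameters and
    returns the module itself.
    """
    if device is not None:
        model.to(device)
    # The factory still receives `device` so it can move batches as before.
    return factory(model, optimizer, loss_fn, device=device)
```

One caveat with this design: in the issue's snippet the optimizer is built before the trainer, and the PyTorch optimizer documentation recommends moving a model to its device before constructing the optimizer, so the move would ideally happen ahead of the `SGD(...)` call.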
