Pytorch-lightning: Checkpoint loading at the test time

Created on 19 Oct 2020 · 5 comments · Source: PyTorchLightning/pytorch-lightning

🐛 Bug

Hi,
I am having problems with checkpointing and testing when using DDP in a multi-GPU setup. It seems that the model in each process saves its own checkpoint during training and then tries to load it at test time. However, only one such checkpoint is actually written, so the other processes cannot find theirs and I get the following error:
FileNotFoundError: [Errno 2] No such file or directory: '/home/vobecant/PhD/system_validation/results/debug_cifar10_multigpu_gradclip/checkpoints/ckpt_epoch=07-val_loss=0.80.ckpt'

Please reproduce using the BoringModel and post here

I don't know how to test multi-GPU behavior with DDP in Colab, so I post the snippet here:

# Imports assumed by the snippet (BoringModel comes from the PL bug-report template,
# dm is a LightningDataModule defined elsewhere in the script).
import argparse
import os

import pytorch_lightning as pl
from pytorch_lightning import loggers as pl_loggers
from pytorch_lightning.callbacks import EarlyStopping, LearningRateMonitor

model = BoringModel()

parser = argparse.ArgumentParser()
parser.add_argument('--gpus_per_node', default=1, type=int)
parser.add_argument('--num_nodes', default=1, type=int)
parser.add_argument('--distributed_backend', default='ddp')
# assumed defaults; the original script defines these options elsewhere
parser.add_argument('--patience', default=5, type=int)
parser.add_argument('--monitor', default='val_loss')
parser.add_argument('--logs_dir', default='./logs')
parser.add_argument('--val_size', default=1, type=int)
args = parser.parse_args()

callbacks = []

# Early stopping on the monitored validation metric.
early_stopping = EarlyStopping('val_loss', patience=args.patience)
callbacks.append(early_stopping)

# The checkpoint filename embeds the epoch and the (per-process) val_loss.
filepath = os.path.join('./tmp', 'ckpt_{epoch:02d}-{val_loss:.2f}')
checkpoint_callback = pl.callbacks.ModelCheckpoint(monitor=args.monitor, filepath=filepath)
callbacks.append(checkpoint_callback)

lr_monitor = LearningRateMonitor(logging_interval='epoch')
callbacks.append(lr_monitor)

tb_logger = pl_loggers.TensorBoardLogger(args.logs_dir)
print('Used callbacks: {}'.format(callbacks))

trainer = pl.Trainer.from_argparse_args(args, callbacks=callbacks, logger=tb_logger,
                                        checkpoint_callback=checkpoint_callback, max_epochs=1)

trainer.fit(model, datamodule=dm)

# Test with the best saved checkpoint if a metric is monitored, otherwise with the in-memory model.
result = trainer.test() if (args.monitor and args.val_size > 0) else trainer.test(model)
print(result)

Expected behavior

Test the model with the best checkpoint saved during training.

Environment

  • CUDA:
    • GPU: GeForce GTX 1080 Ti
    • available: True
    • version: 10.1
  • Packages:
    • numpy: 1.19.1
    • pyTorch_debug: False
    • pyTorch_version: 1.6.0
    • pytorch-lightning: 1.0.2
    • tqdm: 4.50.2
  • System:
    • OS: Linux
    • architecture: 64bit, ELF
    • processor: x86_64
    • python: 3.8.5
    • version: #1 SMP Tue Aug 25 17:23:54 UTC 2020

Labels: checkpoint, priority P0, bug / fix, help wanted

All 5 comments

Hi! Thanks for your contribution, great first issue!

@justusschock maybe you can take a stab

I encountered the same issue; my temporary fix is to remove val_acc or val_loss from the checkpoint file name:

     checkpoint_callback = ModelCheckpoint(monitor='val_acc',
                                           mode='max',
-                                          filepath='%s/{epoch:02d}-{val_acc:.2f}' % model_dir,
+                                          filepath='%s/{epoch:02d}' % model_dir,

The processes should have the same copy of the current model. They are getting different val_loss/val_acc because they evaluate the model on different subsets of the dataset.

Ideally, ModelCheckpoint() should aggregate the metrics from all processes before saving the new best model.
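A minimal sketch of what such aggregation could look like today, done manually inside the LightningModule (the hook and key names here are illustrative, and torch.distributed is assumed to have been initialized by the DDP backend):

# Workaround sketch, not Lightning internals: average the monitored metric
# across DDP ranks so every process logs, and therefore checkpoints with,
# the same value.
import torch
import torch.distributed as dist

def validation_epoch_end(self, outputs):
    # outputs: list of dicts returned by validation_step, e.g. {"val_loss": loss}
    val_loss = torch.stack([o["val_loss"] for o in outputs]).mean()
    if dist.is_available() and dist.is_initialized():
        dist.all_reduce(val_loss, op=dist.ReduceOp.SUM)
        val_loss = val_loss / dist.get_world_size()
    self.log("val_loss", val_loss)

With identical val_loss values on every rank, the formatted checkpoint filename is the same in every process, so the test-time load no longer looks for a file that only rank 0 wrote.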

@mzweilin Agreed! However, we do offer the Metrics API, which handles this, and we urge users to move towards it: https://pytorch-lightning.readthedocs.io/en/latest/metrics.html
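For reference, a rough sketch of that suggestion, assuming the pl.metrics classes shipped with 1.0.x (the model and metric names are illustrative, not taken from the issue):

import torch
import pytorch_lightning as pl

class LitClassifier(pl.LightningModule):  # illustrative model
    def __init__(self, num_classes=10):
        super().__init__()
        self.layer = torch.nn.Linear(32, num_classes)
        # Metric objects keep per-process state and sync it across DDP
        # processes when compute() is called.
        self.val_acc = pl.metrics.Accuracy()

    def forward(self, x):
        return self.layer(x)

    def validation_step(self, batch, batch_idx):
        x, y = batch
        self.val_acc.update(self(x), y)

    def validation_epoch_end(self, outputs):
        # compute() aggregates the state gathered on every rank, so each
        # process logs the same value for ModelCheckpoint to monitor.
        self.log("val_acc", self.val_acc.compute())
        self.val_acc.reset()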

The issue comes down to ensuring that sync_dist is enabled for the validation results. Using the bug-report model, change:

# before
def validation_step(self, batch, batch_idx):
    output = self.layer(batch)
    loss = self.loss(batch, output)
    return {"x": loss}

# after
def validation_step(self, batch, batch_idx):
    output = self.layer(batch)
    loss = self.loss(batch, output)
    self.log("x", loss, sync_dist=True)

This will ensure that the validation metric is synced across distributed processes. Feel free to re-open this issue if it doesn't work :)
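Applied to the original snippet, the same change would go on whatever metric ModelCheckpoint monitors, e.g. the val_loss that appears in the filename template (a sketch under that assumption):

def validation_step(self, batch, batch_idx):
    output = self.layer(batch)
    loss = self.loss(batch, output)
    # sync_dist=True reduces (by default, averages) the value across DDP ranks,
    # so every process sees the same val_loss and formats the same
    # 'ckpt_{epoch:02d}-{val_loss:.2f}' checkpoint filename.
    self.log("val_loss", loss, sync_dist=True)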

