Pytorch-lightning: Checkpointing is broken on TPUs

Created on 25 Jul 2020 · 6 Comments · Source: PyTorchLightning/pytorch-lightning

🐛 Bug

PyTorch/XLA saves checkpoints using the following API, which pytorch-lightning does not support:

import torch_xla.core.xla_model as xm
xm.save(checkpoint, filepath)  # writes only on the master ordinal and has a barrier inside

It is a little tricky to support because xm.save() has a barrier inside it and checks for rank == 0, while torch.save doesn't. This means torch.save should be called only on the process with rank 0 (which pytorch-lightning does), but xm.save() should be called by all processes (otherwise it will wait forever at the barrier). As a result, the pytorch-lightning code that checks for the rank (here) will need to be switched off on TPUs.
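A minimal sketch of the rank-aware branching this implies, assuming a hypothetical save_checkpoint helper and that a use_tpu flag and the global rank are available to the caller (this is not pytorch-lightning's actual code path):

```python
import torch

try:
    import torch_xla.core.xla_model as xm  # only available in a TPU/XLA environment
    HAVE_XLA = True
except ImportError:
    HAVE_XLA = False


def save_checkpoint(checkpoint: dict, filepath: str, use_tpu: bool, global_rank: int):
    """Hypothetical helper: on TPU, save via xm.save from every process;
    otherwise fall back to torch.save guarded by a rank check."""
    if use_tpu and HAVE_XLA:
        # xm.save has a rendezvous (barrier) inside and only writes on the
        # master ordinal, so every process must reach this call.
        xm.save(checkpoint, filepath)
    elif global_rank == 0:
        # torch.save has no rank check of its own, so guard it here.
        torch.save(checkpoint, filepath)
```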

To Reproduce

  1. Train any model on TPUs using PyTorch/XLA with ptl.Trainer(checkpoint_callback=[ModelCheckpoint(...)], num_tpu_cores=8)
  2. Wait until the model saves one checkpoint then kill the process
  3. Try to load the saved checkpoint with ptl.Trainer(resume_from_checkpoint='path_to_saved_checkpoint', num_tpu_cores=8)
  4. See error (a consolidated sketch of these steps follows the list)
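A rough reproduction script under the assumptions above, using the 0.8.5-era API exactly as written in the steps (a single ModelCheckpoint instance is passed here for simplicity, and MyModel is a hypothetical minimal LightningModule):

```python
import torch
from torch.utils.data import DataLoader, TensorDataset
import pytorch_lightning as ptl
from pytorch_lightning.callbacks import ModelCheckpoint


class MyModel(ptl.LightningModule):
    """Hypothetical minimal model, only used to reproduce the checkpoint issue."""

    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    def forward(self, x):
        return self.layer(x)

    def training_step(self, batch, batch_idx):
        x, y = batch
        loss = torch.nn.functional.cross_entropy(self(x), y)
        return {'loss': loss}

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.01)

    def train_dataloader(self):
        ds = TensorDataset(torch.randn(64, 32), torch.randint(0, 2, (64,)))
        return DataLoader(ds, batch_size=8)


# Step 1: train on 8 TPU cores with checkpointing enabled.
model = MyModel()
trainer = ptl.Trainer(
    checkpoint_callback=ModelCheckpoint(filepath='checkpoints/'),
    num_tpu_cores=8,
    max_epochs=1,
)
trainer.fit(model)

# Steps 2-4: after at least one checkpoint has been written, start a new run
# that resumes from it; with v0.8.5 this fails on TPU.
trainer = ptl.Trainer(
    resume_from_checkpoint='path_to_saved_checkpoint',
    num_tpu_cores=8,
)
trainer.fit(model)
```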

Expected behavior

The checkpoint loads successfully.

Environment

pytorch-lightning==v0.8.5

Additional Context

Thanks to @matt-peters for finding the bug and suggesting the solution mentioned below.

Labels: Priority P0, TPU, bug / fix, help wanted

All 6 comments

@lezwon mind having a look?

sure :]

Thanks, @lezwon. You might want to check this fix: https://github.com/ibeltagy/pytorch-lightning/commit/a5c8d182329a3be88f17f952bee9cc063116c515. It works, but I don't like it. I also tried calling the functions inside xm.save (here) in the main process only, without the barrier, but everything hangs, maybe because the processes go out of sync.
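To illustrate the hang described in the original report, a toy sketch (names assumed, not the actual fix) of why guarding a barrier-containing save with a rank check deadlocks:

```python
import torch_xla.core.xla_model as xm


def broken_save(checkpoint, filepath):
    # Anti-pattern: xm.save contains a rendezvous (barrier). If only the
    # master ordinal reaches it, it waits forever for the other processes.
    if xm.get_ordinal() == 0:
        xm.save(checkpoint, filepath)  # hangs with more than one TPU core


def correct_save(checkpoint, filepath):
    # Every process calls xm.save; internally it writes only on the master
    # ordinal and then releases all processes from the barrier together.
    xm.save(checkpoint, filepath)
```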

@ibeltagy Nice work :] I'll check out your solution and try to get back with a fix for this.

@ibeltagy I am able to reload the checkpoint successfully; however, the training fails due to some XLA device issue. Is it the same error you face? Could you share a notebook reproducing this issue?

Loading the checkpoint only fails when no TPU device is available, because torch.save writes out XLA tensors instead of regular PyTorch tensors. This is a common workflow: train on TPU, then move to a CPU or GPU for further processing. As a result, I don't think it's possible to reproduce this in a notebook. xm.save moves everything to CPU before saving to avoid this problem. I opened PR #2726 that includes a fix.
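For reference, a minimal sketch of why the CPU move matters (this is an illustration, not the fix in #2726): saving XLA tensors directly produces a checkpoint that can only be deserialized where torch_xla is available, so the state should be moved to CPU first.

```python
import torch
import torch_xla.core.xla_model as xm


def save_cpu_checkpoint(model, filepath):
    # Move every tensor to CPU before serializing so the checkpoint can be
    # loaded later on a machine without a TPU or torch_xla installed.
    cpu_state = {k: v.cpu() for k, v in model.state_dict().items()}
    torch.save(cpu_state, filepath)


def save_with_xm(model, filepath):
    # xm.save does the CPU move internally and only writes on the master
    # ordinal; it must be called from every TPU process.
    xm.save(model.state_dict(), filepath)
```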
