Pytorch/XLA saves checkpoints using the following syntax, which is not supported in pytorch-lightning:
import torch_xla.core.xla_model as xm
xm.save(checkpoint, filepath)  # checkpoint/filepath are illustrative arguments
It is a little tricky to support because xm.save() has a barrier inside it and checks for rank == 0, while torch.save doesn't. This means torch.save should be called only on the process with rank 0 (which pytorch-lightning does), but xm.save() should be called by all processes (or they will wait forever at the barrier). As a result, the pytorch-lightning code that checks for the rank (here) will need to be switched off on TPUs.
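To illustrate the two call patterns, here is a minimal sketch (the save_checkpoint helper and its arguments are illustrative, not pytorch-lightning's actual internals):

import torch
import torch_xla.core.xla_model as xm

def save_checkpoint(state_dict, path, on_tpu, rank):
    # Illustrative helper, not pytorch-lightning's real code path.
    if on_tpu:
        # Every process must call xm.save: it contains a barrier and only
        # the master ordinal writes the file. Guarding it with a rank
        # check leaves the other processes waiting at the barrier forever.
        xm.save(state_dict, path)
    elif rank == 0:
        # torch.save has no barrier and no rank awareness, so the caller
        # must make sure only rank 0 writes the file.
        torch.save(state_dict, path)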
ptl.Trainer(checkpoint_callback=[ModelCheckpoint(...)], num_tpu_cores=8)
ptl.Trainer(resume_from_checkpoint='path_to_saved_checkpoint', num_tpu_cores=8)

Expected behavior: loading the checkpoint successfully.
pytorch-lightning==v0.8.5
Thanks to @matt-peters for finding the bug and suggesting the solution mentioned below.
@lezwon mind having a look?
sure :]
Thanks, @lezwon. You might want to check this fix here https://github.com/ibeltagy/pytorch-lightning/commit/a5c8d182329a3be88f17f952bee9cc063116c515, which works, but I don't like it. I also tried calling the functions inside xm.save (here) in the main process only, without the barrier, but everything hangs, maybe because the processes go out of sync.
@ibeltagy Nice work :] I'll check out your solution and try to get back with a fix for this.
@ibeltagy I am able to reload the checkpoint successfully; however, the training fails due to some xla device issue. Is this the same error you're facing? Could you share a notebook reproducing this issue?
Loading the checkpoint only fails when no TPU device is available, because torch.save writes out XLA tensors instead of pytorch tensors. This is a common workflow: train on TPU, then move to a CPU or GPU for further processing. As a result, I don't think it's possible to reproduce this in a notebook. xm.save moves everything to CPU before saving to avoid this problem. I opened PR #2726 that includes a fix.
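For context, the gist of what xm.save does before writing is roughly this (a simplified sketch, not the actual torch_xla implementation; move_to_cpu is a hypothetical helper):

import torch

def move_to_cpu(obj):
    # Recursively move tensors to CPU so the checkpoint contains plain
    # pytorch tensors instead of XLA tensors and can be loaded on a
    # machine without a TPU.
    if isinstance(obj, torch.Tensor):
        return obj.cpu()
    if isinstance(obj, dict):
        return {k: move_to_cpu(v) for k, v in obj.items()}
    if isinstance(obj, (list, tuple)):
        return type(obj)(move_to_cpu(v) for v in obj)
    return obj

# usage: torch.save(move_to_cpu(checkpoint_dict), 'checkpoint.ckpt')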