Pytorch-lightning: Issue with epoch count with repeated save/restore

Created on 15 Oct 2020 · 22 comments · Source: PyTorchLightning/pytorch-lightning

🐛 Bug

I'm trying to save and restore the state of both a model and a pytorch-lightning trainer.

I suspect the epoch count is wrong, because I'm not able to save and restore several times with the same max_epochs value.

Here's what I do:

Step 1: run model for max_epochs = 1. Save checkpoint (gets saved as epoch=0.ckpt)
Step 2: load the previous checkpoint and rerun with max_epochs = 1. No training is run (because 1 epoch was already run before). A checkpoint is saved again; however, this one is called epoch=1.ckpt.
Step 3: load the checkpoint from step 2 and rerun with max_epochs = 1. Training fails because it believes step 2 was run for 2 epochs (while max_epochs is 1).

Output:

pytorch_lightning.utilities.exceptions.MisconfigurationException: 
            you restored a checkpoint with current_epoch=2
            but the Trainer(max_epochs=1)

Code below to reproduce.

What am I doing wrong? This should be a supported scenario, right?

Thanks!

To Reproduce

Run the code below 3 times from the same location.

import os
import torch
from torch import nn
import torch.nn.functional as F
from torchvision.datasets import MNIST
from torch.utils.data import DataLoader, random_split
from torchvision import transforms
import pytorch_lightning as pl
from pathlib import Path

class LitAutoEncoder(pl.LightningModule):

    def __init__(self):
        super().__init__()
        self.encoder = nn.Sequential(nn.Linear(28 * 28, 128), nn.ReLU(), nn.Linear(128, 3))
        self.decoder = nn.Sequential(nn.Linear(3, 128), nn.ReLU(), nn.Linear(128, 28 * 28))

    def forward(self, x):
        # in lightning, forward defines the prediction/inference actions
        embedding = self.encoder(x)
        return embedding

    def training_step(self, batch, batch_idx):
        # training_step defines the training loop. It is independent of forward
        x, y = batch
        x = x.view(x.size(0), -1)
        z = self.encoder(x)
        x_hat = self.decoder(z)
        loss = F.mse_loss(x_hat, x)
        self.log('train_loss', loss)
        return loss

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
        return optimizer

dataset = MNIST(os.getcwd(), download=True, transform=transforms.ToTensor())
train, val, _ = random_split(dataset, [5500, 500, len(dataset) - 6000])

def get_last_checkpoint(checkpoint_folder):
    if os.path.exists(checkpoint_folder):
        past_experiments = sorted(Path(checkpoint_folder).iterdir(), key=os.path.getmtime)

        for experiment in past_experiments[::-1]:
            experiment_folder = os.path.join(experiment, "checkpoints")
            if os.path.exists(experiment_folder):
                checkpoints = os.listdir(experiment_folder)

                if len(checkpoints):
                    checkpoints.sort()
                    path = os.path.join(experiment_folder, checkpoints[-1])
                    return path

    return None

chk = get_last_checkpoint('lightning_logs')
if chk is not None:
    print("loading from ", chk)
    autoencoder = LitAutoEncoder.load_from_checkpoint(chk)
else:
    autoencoder = LitAutoEncoder()

trainer = pl.Trainer(max_epochs=1, resume_from_checkpoint=chk)
trainer.fit(autoencoder, DataLoader(train), DataLoader(val))
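
As a quick check (not part of the original snippet), printing the loop counters right after fit() makes the drift between runs visible; current_epoch and global_step are standard Trainer attributes:

print("current_epoch:", trainer.current_epoch, "global_step:", trainer.global_step)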

Expected behavior

The epoch count should not increase on the second run. It should be possible to load a checkpoint and save it back unchanged, several times in a row.

Environment

  • CUDA:
    - GPU:
    - Quadro P2000 with Max-Q Design
    - available: True
    - version: 10.2
  • Packages:
    - numpy: 1.19.2
    - pyTorch_debug: True
    - pyTorch_version: 1.8.0.dev20201014
    - pytorch-lightning: 1.0.2
    - tqdm: 4.50.2
  • System:
    - OS: Windows
    - architecture:
    - 64bit
    - WindowsPE
    - processor: Intel64 Family 6 Model 158 Stepping 10, GenuineIntel
    - python: 3.7.9
    - version: 10.0.17763

Additional context

Labels: Checkpoint · Priority P0 · bug / fix · help wanted

All 22 comments

I think the reason is that it still runs the training loop, which eventually calls on_train_end and then the checkpoint callback, which creates this new checkpoint. I also noticed that when using resume_from_checkpoint, it creates a new folder for the new checkpoints, but I think it should use the old one, since we are resuming training here.

When it calls on_train_end, maybe it could check whether anything has actually been run and, if not, not increase the epoch and global step counts? This would basically duplicate the checkpoint folder, but in a consistent way?
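
A minimal user-side sketch of that check (HasTrainedTracker is a hypothetical callback, not part of Lightning): remember whether any training batch actually ran, so a no-op resume could be detected in on_train_end before the counters and checkpoints are touched.

import pytorch_lightning as pl

# Hypothetical callback sketching the idea above: record whether any training
# batch ran, so a resume that trains nothing can be detected in on_train_end.
class HasTrainedTracker(pl.Callback):
    def __init__(self):
        self.has_trained = False

    def on_train_batch_end(self, trainer, pl_module, *args, **kwargs):
        self.has_trained = True

    def on_train_end(self, trainer, pl_module):
        if not self.has_trained:
            print("No batches ran; epoch/global_step should stay as loaded from the checkpoint.")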

Good suggestion. But I still think it should not create a new version for checkpoints if we are resuming training; this might create some problems with the logs, though, I'm not sure.
@PyTorchLightning/core-contributors, any suggestions here?

But I still think it should not create a new version for checkpoints if we are resuming training; this might create some problems with the logs, though, I'm not sure.

Since the logger handles the versioning, I would leave it up to the logger to decide how this is managed. For example, in some loggers you have to set an argument when you want to resume logging to an existing experiment. In that case, the version number should become the same and checkpoints can be saved to the same location.
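
For illustration, with the TensorBoard logger this could look like pinning the version so a resumed run logs (and, with the default checkpoint callback, saves checkpoints) into the same lightning_logs/version_0 folder; version=0 here is just an example value.

from pytorch_lightning.loggers import TensorBoardLogger

# Reuse the existing experiment version instead of creating a new version_N
# folder on resume (adjust version to the run being resumed).
logger = TensorBoardLogger(save_dir=".", name="lightning_logs", version=0)
trainer = pl.Trainer(max_epochs=1, logger=logger, resume_from_checkpoint=chk)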

I also see that the checkpoint stores the next loop count, both for the epoch index (in this case epoch index = 1, even though the file is called epoch=0.ckpt) and for the global step, as if we were ready to start the next loop.

Another way to fix this would be to save the state before incrementing the counters (and increment them on state load).
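
For anyone wanting to confirm this, the stored counters can be inspected directly with torch; the path below is the one produced by the repro script above, and 'epoch'/'global_step' are the usual keys in a Lightning checkpoint dict.

import torch

# Load the checkpoint from the first run and print the stored counters; per
# the observation above, they are already one loop ahead of the filename.
ckpt = torch.load("lightning_logs/version_0/checkpoints/epoch=0.ckpt", map_location="cpu")
print(ckpt["epoch"], ckpt["global_step"])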

I found that the same bug is present when running on a SLURM cluster, when the task is rescheduled more times than the number of epochs.

Hey All,

I am working on resolving this bug right now. I will keep you posted on my progress :)

Best,
Thomas Chaton.

Dear @michele-arrival @andreamad8,

I have a draft PR trying to reproduce this behaviour: https://github.com/PyTorchLightning/pytorch-lightning/pull/4277/files
It seems to PASS. Would you mind having a look and maybe extending it to catch the bug?

Also, I found another weird behaviour where validation isn't called on the first test() but is called when loading from checkpoints. Another issue 👍

Thanks for these great catches!

Best regards,
Thomas Chaton.

I will pull and check, thanks for your help

Andrea

Hi @tchaton ,

thanks for the new test! I believe that in your test_checkpoint_repeated_strategy, in steps 1-5 you are not saving the checkpoint in the same folder you read the previous checkpoint from, but rather into the default ./lightning_logs, therefore never triggering the chain. Maybe you are missing an explicit path to the trainer?
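
In case it helps, a hedged guess at what "an explicit path" could look like in that test, assuming tmpdir is the pytest fixture the test already uses:

# Point the trainer at the test's temp directory so every step writes its
# checkpoints where the previous step's checkpoint is read from.
trainer = pl.Trainer(default_root_dir=tmpdir, max_epochs=1, resume_from_checkpoint=chk)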

I can't replicate the FAIL with your test. That ModelCheckpoint changes everything. The checkpoint is correctly saved as epoch=00.ckpt, epoch=00-v0.ckpt, epoch=00-v1.ckpt etc... I don't understand what is going on inside that callback; however, it's the vanilla case (with the trainer constructed with resume_from_checkpoint=file) that fails.

Hey @michele-arrival,

I extended the test and it still PASSES. Could you make a test that reproduces your bug with a BoringModel?

Best,
T.C

Yes, I have a version that fails. Can I push to your branch, or would you rather get a patch?

Hey @michele-arrival ,

Looking into it.

Best, T.C

Thanks, eager to test when ready

Hi, just FYI: this doesn't fail anymore; however, the saved checkpoint is still epoch=0.ckpt on the first iteration and epoch=1.ckpt on the following ones, which I believe is wrong?

This issue is not completely resolved yet; it's still creating checkpoints. I think I have a better suggestion and will send a PR. Also, this new has_trained should be a property of the trainer, IMO. @tchaton, thoughts?

Hi @michele-arrival , can you check whether #4372 solves your issue?

Hi @michele-arrival , can you check whether #4372 solves your issue?

Hi @rohitgr7, thanks for this. Yes, I believe after your fix it stops saving new checkpoints after the first one.

Here's the snippet of code I just used:

import os
import shutil
from pathlib import Path

import pytorch_lightning as pl
from tests.base import BoringModel


def get_last_checkpoint(checkpoint_folder):
    if os.path.exists(checkpoint_folder):
        past_experiments = sorted(Path(checkpoint_folder).iterdir(), key=os.path.getmtime)

        for experiment in past_experiments[::-1]:
            experiment_folder = os.path.join(experiment, "checkpoints")
            if os.path.exists(experiment_folder):
                checkpoints = os.listdir(experiment_folder)

                if len(checkpoints):
                    checkpoints.sort()
                    path = os.path.join(experiment_folder, checkpoints[-1])
                    return path

    return None

autoencoder = BoringModel()

tmpdir = 'lightning_logs'
shutil.rmtree(tmpdir, ignore_errors=True)

for i in range(10):
    chk = get_last_checkpoint(tmpdir) if i > 0 else None

    trainer = pl.Trainer(max_epochs=1, resume_from_checkpoint=chk, gpus=-1)

    autoencoder.validation_step_end = None
    autoencoder.validation_epoch_end = None
    trainer.fit(autoencoder)
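
As an extra sanity check (not part of the snippet above), listing the checkpoint files afterwards confirms that the repeated runs no longer pile up new ones:

# List every checkpoint created across the 10 runs; after the fix, only the
# files from the first real training run should remain.
for ckpt_path in sorted(Path(tmpdir).rglob("*.ckpt")):
    print(ckpt_path)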