Ignite: ModelCheckpoint's _saved variable and EarlyStopping

Created on 9 Apr 2020 · 5 Comments · Source: pytorch/ignite

I'm using ignite 0.2.1, in a setup similar to the transfer-learning-conv-ai repo by Hugging Face. In these lines, you can see that:

  • the checkpoint is being saved for every epoch
  • just the last three saved checkpoints are being retained on disk
  • the last checkpoint (due to _saved[-1]) is being renamed to be the final trained model
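
To show why `_saved[-1]` is not necessarily the best model in this setup, here is a minimal plain-Python sketch (not ignite code) of a "keep the last three" window. The retained list is ordered by save time, so its last element is just the most recent checkpoint. The file names and scores are made up for illustration.

```python
from collections import deque

# Hypothetical (epoch, validation score) pairs, for illustration only.
history = [(1, 0.30), (2, 0.55), (3, 0.50), (4, 0.45), (5, 0.40)]

# Keep only the last 3 checkpoints, like n_saved=3 without a score_function.
window = deque(maxlen=3)
for epoch, score in history:
    window.append(("checkpoint_{}.pth".format(epoch), score))

latest = window[-1]                          # what _saved[-1] would point to
best = max(window, key=lambda item: item[1])  # best checkpoint still on disk

print("latest:", latest)          # latest: ('checkpoint_5.pth', 0.4)
print("best in window:", best)    # best in window: ('checkpoint_3.pth', 0.5)
# The overall best (epoch 2, 0.55) has already been evicted from the window.
```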

In my code, I'm additionally using the EarlyStopping class with a configurable patience like this:

    valid_es_handler = EarlyStopping(patience=args.patience, score_function=early_stopping_score_function,
                                     trainer=trainer)
    validator.add_event_handler(Events.COMPLETED, valid_es_handler)

Now what I want to accomplish is this: I want to identify and rename the best (in terms of validation set score) trained model from the window of stored checkpoints.
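
However the checkpoints end up being retained, the final "rename the best one" step can be sketched in plain Python. This is a minimal sketch under assumptions: the list of (score, path) pairs, the `promote_best` helper and the `final_model.pth` name are hypothetical, and empty files stand in for real checkpoints.

```python
import os
import tempfile

def promote_best(saved, final_path):
    """Hypothetical helper: rename the highest-scoring checkpoint to final_path.

    `saved` is a list of (score, path) pairs.
    """
    best_score, best_path = max(saved, key=lambda item: item[0])
    os.replace(best_path, final_path)  # rename, overwriting final_path if it exists
    return best_score, final_path

with tempfile.TemporaryDirectory() as log_dir:
    saved = []
    for epoch, score in [(8, 0.39), (9, 0.422), (10, 0.61)]:
        path = os.path.join(log_dir, "best_model_{}_val_score={}.pth".format(epoch, score))
        open(path, "wb").close()  # empty placeholder instead of a real torch.save()
        saved.append((score, path))

    best_score, final_path = promote_best(saved, os.path.join(log_dir, "final_model.pth"))
    remaining = sorted(os.listdir(log_dir))

print(best_score, os.path.basename(final_path))  # 0.61 final_model.pth
print(remaining)
```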

I think the first change needed is to go from n_saved=3 to n_saved=args.patience, so that the window of saved checkpoints matches the patience used for early stopping.

Next, it looks like I also need to pass the same early_stopping_score_function to ModelCheckpoint via its score_function arg, which would create a score-based priority queue of saved checkpoints.

And with those changes, it looks like _saved[-1] would still point to the "best" model checkpoint in the window. Is my understanding of the changes correct?
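
The score-based priority queue idea can be sketched in plain Python. This approximates how a score-based checkpoint handler retains files; it is not ignite's actual code. It keeps the n highest-scoring checkpoints sorted ascending by score, so the last entry is the best one seen so far, whatever n is. In particular, n does not have to equal the EarlyStopping patience for this to hold.

```python
import heapq

def keep_best(scored_checkpoints, n_saved):
    """Return the n_saved highest-scoring (score, name) pairs, sorted ascending."""
    kept = []  # min-heap: kept[0] is the worst retained checkpoint
    for score, name in scored_checkpoints:
        if len(kept) < n_saved:
            heapq.heappush(kept, (score, name))
        elif score > kept[0][0]:
            # New checkpoint beats the worst one: evict the worst (its file would be deleted).
            heapq.heapreplace(kept, (score, name))
    return sorted(kept)

scores = [0.30, 0.55, 0.50, 0.45, 0.40]
checkpoints = [(s, "epoch_{}".format(i + 1)) for i, s in enumerate(scores)]

for n in (1, 3, 5):
    saved = keep_best(checkpoints, n)
    print(n, saved[-1])  # the best checkpoint is always last: (0.55, 'epoch_2')
```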

Also, I haven't looked at the newer versions of ignite after 0.2.1, but could you please share what the breaking changes are (using the above linked code as an example)? I might consider upgrading to the latest ignite if the changes needed are minimal.

@vfdev-5

The other thing I don't understand is this - the score function would be called on the engine, but for our use-case, this engine should be the validator (for both EarlyStopping and ModelCheckpoint), right?

But this line in the transfer-learning-conv-ai repo:

    trainer.add_event_handler(Events.EPOCH_COMPLETED, checkpoint_handler, {'mymodel': getattr(model, 'module', model)}) # "getattr" take care of distributed encapsulation

will end up making the score function call on the trainer Engine if I understand correctly. How do I ensure that the validator is used for the score function in the checkpoint_handler?
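
In ignite, a handler added with add_event_handler is called with the engine that fired the event as its first argument, so a score function inside the handler sees whichever engine the handler is attached to. The toy classes below are a hypothetical stand-in (not ignite's Engine) that only illustrates this dispatch rule. With the thread's score function, which reads `metrics["score"]` directly, the trainer case would raise a KeyError instead of returning None.

```python
from types import SimpleNamespace

class ToyEngine:
    """A toy stand-in for an event-driven engine (NOT ignite's Engine)."""

    def __init__(self, name):
        self.name = name
        self.state = SimpleNamespace(metrics={})
        self._handlers = {}

    def add_event_handler(self, event, handler):
        self._handlers.setdefault(event, []).append(handler)

    def fire(self, event):
        # Each handler receives the engine that fired the event.
        for handler in self._handlers.get(event, []):
            handler(self)

def score_function(engine):
    # .get() returns None instead of raising when the metric is missing.
    return engine.state.metrics.get("score")

seen = []

def handler(engine):
    seen.append((engine.name, score_function(engine)))

trainer = ToyEngine("trainer")
evaluator = ToyEngine("evaluator")
evaluator.state.metrics["score"] = 0.61  # only the evaluator computes the metric

trainer.add_event_handler("EPOCH_COMPLETED", handler)
evaluator.add_event_handler("COMPLETED", handler)

trainer.fire("EPOCH_COMPLETED")
evaluator.fire("COMPLETED")
print(seen)  # [('trainer', None), ('evaluator', 0.61)]
```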


Thank you for this report +1

I don't remember ignite 0.2.1 in detail, but for checkpointing, please look at the following code:

    global_step_transform = global_step_from_engine(trainer)

    best_model_handler = ModelCheckpoint(
        dirname=output_path,
        filename_prefix="best",
        n_saved=n_saved,
        global_step_transform=global_step_transform,
        score_name="{}_{}".format(tag, metric_name.lower()),
        score_function=get_default_score_fn(metric_name),
    )

    evaluator.add_event_handler(Events.COMPLETED, best_model_handler, {"model": model,})

This snippet is from https://github.com/pytorch/ignite/blob/master/ignite/contrib/engines/common.py, which provides helpers to define common handlers.

So it's possible to save with respect to a metric :) and the score is appended to the name of the checkpoint file.
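
Judging from the file names produced by the example run later in this thread (e.g. `logs/best_model_10_val_score=0.61.pth`), the saved name combines the prefix, the object name, a step counter and `score_name=score`. A sketch that reproduces that pattern; the helper name is hypothetical, and the exact format may differ between ignite versions.

```python
def checkpoint_filename(prefix, name, step, score_name, score, ext=".pth"):
    """Hypothetical helper mimicking names like best_model_10_val_score=0.61.pth."""
    return "{}_{}_{}_{}={}{}".format(prefix, name, step, score_name, score, ext)

print(checkpoint_filename("best", "model", 10, "val_score", 0.61))
# best_model_10_val_score=0.61.pth
```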

I hope this helps!

EDIT: OK, you pointed to internal ignite code, so I suppose you have already seen that :'

EDIT 2: for the second part of your question, I think the checkpoint handler should be attached to the evaluator (as in the snippet I shared). However, I don't know whether that works with ignite 0.2.1...

REMARK: Maybe we could refactor the HuggingFace code to use a recent version of ignite? The requirements.txt just refers to pytorch-ignite, so I guess that means 0.3 (see https://github.com/huggingface/transfer-learning-conv-ai/blob/master/requirements.txt)

@vfdev-5 you should have more inputs.

@g-karthik, please tell us if @sdesrozis's solution does not fit.

> And with those changes, it looks like _saved[-1] would still point to the "best" model checkpoint in the window. Is my understanding of the changes correct?

There was a bug with that, found recently: https://github.com/pytorch/ignite/pull/745
It has since been fixed, and the fix is available in the nightly release.

> Also, I haven't looked at the newer versions of ignite after 0.2.1, but could you please share what the breaking changes are (using the above linked code as an example)? I might consider upgrading to the latest ignite if the changes needed are minimal.

Please see the release notes of 0.3.0, and keep us updated if you have other questions :)

@vfdev-5 @sdesrozis I don't know what global_step_transform is for; I definitely don't see that arg in 0.2.1. But I don't think that is related to my question.

@sdesrozis's answer in EDIT 2 doesn't make sense to me. Should model checkpointing be done by the trainer or the validator? @vfdev-5 Can you please give a clean example of how to do model checkpointing along with early stopping, both based on the same score function?

> Now what I want to accomplish is this: I want to identify and rename the best (in terms of validation set score) trained model from the window of stored checkpoints.

@g-karthik, in v0.2.1 it should work like this:


Code

import torch.nn as nn
import ignite
print(ignite.__version__)

import logging

from ignite.engine import Engine, Events
from ignite.handlers import ModelCheckpoint, EarlyStopping


##### Setup logger to see what happens 
def setup_logger(name, level=logging.INFO, format="%(name)s %(levelname)s: %(message)s"):
    logger = logging.getLogger(name)
    logger.setLevel(level)

    # Remove previous handlers
    if logger.hasHandlers():
        for h in list(logger.handlers):
            logger.removeHandler(h)

    formatter = logging.Formatter(format)

    ch = logging.StreamHandler()
    ch.setLevel(level)
    ch.setFormatter(formatter)
    logger.addHandler(ch)
    return logger


##### Setup simulated validation scores
val_scores = [
    0.1, 0.2, 0.3, 0.41, 0.51,  # goes up
    0.42, 0.43, 0.39, 0.422, # plateau
    0.61, # jump
    0.52, 0.50, 0.41, 0.404, 0.412, 0.432, 0.41, 0.41, 0.41, 0.41 # down and plateau => should stop
]

##### Setup model, trainer, evaluator
model = nn.Linear(1, 1)

trainer = Engine(lambda e, b: None)
trainer._logger = setup_logger("trainer")

# For example purposes only, evaluation function writes the score
def eval_fn(e, b):
    i = trainer.state.epoch - 1
    e.state.metrics["score"] = val_scores[i]

evaluator = Engine(eval_fn)

##### Compute validation score:
@trainer.on(Events.EPOCH_COMPLETED)
def run_validation(_):
    evaluator.run([0])
    print("{} - Val score: {}".format(trainer.state.epoch, evaluator.state.metrics["score"]))


##### Setup ModelCheckpoint to save best models
n_saved = 5  # We need to have 5 best models seen during the whole training. This is unrelated to EarlyStopping and its patience 

def score_function(_evaluator):
    return _evaluator.state.metrics["score"]

best_model_handler = ModelCheckpoint(
    dirname="logs",
    filename_prefix="best",
    n_saved=n_saved,
    score_name="val_score",
    score_function=score_function,
)

# As we need to save the best model based on validation score, it is simpler to attach the handler to the evaluator:
evaluator.add_event_handler(Events.COMPLETED, best_model_handler, {"model": model,})


##### Setup EarlyStopping to stop training when the val score stops improving

patience = 7  # Tolerate 7 epochs without beating the best score (0.61 at epoch 10), i.e. stop at epoch 17

val_es = EarlyStopping(patience=patience, score_function=score_function, trainer=trainer)
val_es._logger = setup_logger("EarlyStopping")

evaluator.add_event_handler(Events.COMPLETED, val_es)


##### Let's train
trainer.run([0, 1, 2], max_epochs=len(val_scores))

##### Outlook
print("\nBest model:", best_model_handler._saved[-1])
print("\nOther saved models:")
for v in best_model_handler._saved[:-1]:
    print("\t", v)



Output

0.2.1
trainer INFO: Engine run starting with max_epochs=20.
trainer INFO: Epoch[1] Complete. Time taken: 00:00:00
1 - Val score: 0.1
trainer INFO: Epoch[2] Complete. Time taken: 00:00:00
2 - Val score: 0.2
trainer INFO: Epoch[3] Complete. Time taken: 00:00:00
3 - Val score: 0.3
trainer INFO: Epoch[4] Complete. Time taken: 00:00:00
4 - Val score: 0.41
trainer INFO: Epoch[5] Complete. Time taken: 00:00:00
5 - Val score: 0.51
trainer INFO: Epoch[6] Complete. Time taken: 00:00:00
6 - Val score: 0.42
trainer INFO: Epoch[7] Complete. Time taken: 00:00:00
7 - Val score: 0.43
trainer INFO: Epoch[8] Complete. Time taken: 00:00:00
8 - Val score: 0.39
trainer INFO: Epoch[9] Complete. Time taken: 00:00:00
9 - Val score: 0.422
trainer INFO: Epoch[10] Complete. Time taken: 00:00:00
10 - Val score: 0.61
trainer INFO: Epoch[11] Complete. Time taken: 00:00:00
11 - Val score: 0.52
trainer INFO: Epoch[12] Complete. Time taken: 00:00:00
12 - Val score: 0.5
trainer INFO: Epoch[13] Complete. Time taken: 00:00:00
13 - Val score: 0.41
trainer INFO: Epoch[14] Complete. Time taken: 00:00:00
14 - Val score: 0.404
trainer INFO: Epoch[15] Complete. Time taken: 00:00:00
15 - Val score: 0.412
trainer INFO: Epoch[16] Complete. Time taken: 00:00:00
16 - Val score: 0.432
trainer INFO: Epoch[17] Complete. Time taken: 00:00:00
EarlyStopping INFO: EarlyStopping: Stop training
trainer INFO: Terminate signaled. Engine will stop after current iteration is finished.
17 - Val score: 0.41
trainer INFO: Engine run complete. Time taken 00:00:00

Best model: (0.61, ['logs/best_model_10_val_score=0.61.pth'])

Other saved models:
         (0.432, ['logs/best_model_16_val_score=0.432.pth'])
         (0.5, ['logs/best_model_12_val_score=0.5.pth'])
         (0.51, ['logs/best_model_5_val_score=0.51.pth'])
         (0.52, ['logs/best_model_11_val_score=0.52.pth'])
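
The output above can be cross-checked without ignite. The following plain-Python simulation is a simplified model of the two handlers, not their actual implementation: it keeps the `n_saved` best scores and counts epochs without improvement, treating a score equal to the best as "no improvement" (ignite versions may differ on ties, which do not occur in this data). It reproduces the stop at epoch 17 and the five retained checkpoints.

```python
val_scores = [
    0.1, 0.2, 0.3, 0.41, 0.51,  # goes up
    0.42, 0.43, 0.39, 0.422,  # plateau
    0.61,  # jump
    0.52, 0.50, 0.41, 0.404, 0.412, 0.432, 0.41, 0.41, 0.41, 0.41,  # down and plateau
]

def simulate(scores, patience, n_saved):
    best_score, counter = None, 0
    saved = []  # (score, epoch) pairs, kept sorted ascending
    for epoch, score in enumerate(scores, start=1):
        # Checkpointing runs first (it was attached to the evaluator first).
        if len(saved) < n_saved or score > saved[0][0]:
            saved.append((score, epoch))
            saved.sort()
            if len(saved) > n_saved:
                saved.pop(0)  # evict the worst checkpoint
        # Early stopping: count epochs that do not beat the best score.
        if best_score is None or score > best_score:
            best_score, counter = score, 0
        else:
            counter += 1
            if counter >= patience:
                return epoch, saved
    return len(scores), saved

stop_epoch, saved = simulate(val_scores, patience=7, n_saved=5)
print("stopped at epoch", stop_epoch)  # stopped at epoch 17
print("best:", saved[-1])              # best: (0.61, 10)
print("others:", saved[:-1])           # others: [(0.432, 16), (0.5, 12), (0.51, 5), (0.52, 11)]
```

Changing `n_saved` (say, to 2) leaves `saved[-1]` at `(0.61, 10)`, which is consistent with the point that `n_saved` and `patience` are independent.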



Same code using nightly release 0.4.0.dev20200412
Minor changes:

  • uses built-in setup_logger, _logger -> logger
  • uses global_step_transform to correctly log trainer's epoch
  • uses last_checkpoint to fetch best model
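
The global_step_transform change matters because the evaluator is re-run from scratch on every trainer epoch, so its own epoch counter stays at 1 and cannot distinguish the checkpoints. A transform that reads the step from the trainer fixes that. Below is a simplified, hypothetical stand-in for global_step_from_engine that uses plain namespaces instead of real engines. ignite's actual helper also takes the event name into account.

```python
from types import SimpleNamespace

def global_step_from(source_engine):
    """Hypothetical simplified version: always report the source engine's epoch."""
    def transform(engine, event_name):
        # `engine` (the one running the handler) is ignored on purpose.
        return source_engine.state.epoch
    return transform

trainer = SimpleNamespace(state=SimpleNamespace(epoch=0))
evaluator = SimpleNamespace(state=SimpleNamespace(epoch=1))  # fresh 1-epoch run each time

transform = global_step_from(trainer)
steps = []
for epoch in (10, 11, 12):
    trainer.state.epoch = epoch
    steps.append((evaluator.state.epoch, transform(evaluator, "COMPLETED")))

print(steps)  # [(1, 10), (1, 11), (1, 12)]
```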


Code

import torch.nn as nn
import ignite
print(ignite.__version__)

from ignite.engine import Engine, Events
from ignite.handlers import ModelCheckpoint, EarlyStopping, global_step_from_engine
from ignite.utils import setup_logger

##### Setup simulated validation scores
val_scores = [
    0.1, 0.2, 0.3, 0.41, 0.51,  # goes up
    0.42, 0.43, 0.39, 0.422, # plateau
    0.61, # jump
    0.52, 0.50, 0.41, 0.404, 0.412, 0.432, 0.41, 0.41, 0.41, 0.41 # down and plateau => should stop
]

##### Setup model, trainer, evaluator
model = nn.Linear(1, 1)

trainer = Engine(lambda e, b: None)
trainer.logger = setup_logger("trainer", format="%(name)s %(levelname)s: %(message)s")

# For example purposes only, evaluation function writes the score
def eval_fn(e, b):
    i = trainer.state.epoch - 1
    e.state.metrics["score"] = val_scores[i]

evaluator = Engine(eval_fn)

##### Compute validation score:
@trainer.on(Events.EPOCH_COMPLETED)
def run_validation(_):
    evaluator.run([0])
    print("{} - Val score: {}".format(trainer.state.epoch, evaluator.state.metrics["score"]))


##### Setup ModelCheckpoint to save best models
n_saved = 5  # We need to have 5 best models seen during the whole training. This is unrelated to EarlyStopping and its patience 

def score_function(_evaluator):
    return _evaluator.state.metrics["score"]

best_model_handler = ModelCheckpoint(
    dirname="logs",
    filename_prefix="best",
    n_saved=n_saved,
    score_name="val_score",
    global_step_transform=global_step_from_engine(trainer),
    score_function=score_function,
)

# As we need to save the best model based on validation score, it is simpler to attach the handler to the evaluator:
evaluator.add_event_handler(Events.COMPLETED, best_model_handler, {"model": model,})


##### Setup EarlyStopping to stop training when the val score stops improving

patience = 7  # Tolerate 7 epochs without beating the best score (0.61 at epoch 10), i.e. stop at epoch 17

val_es = EarlyStopping(patience=patience, score_function=score_function, trainer=trainer)
val_es.logger = setup_logger("EarlyStopping")

evaluator.add_event_handler(Events.COMPLETED, val_es)


##### Let's train
trainer.run([0, 1, 2], max_epochs=len(val_scores))

##### Outlook
print("\nBest model:", best_model_handler.last_checkpoint)
print("\nOther saved models:")
for v in best_model_handler._saved[:-1]:
    print("\t", v.filename)

What do you think?

HTH

@vfdev-5 Thanks for the example!

I think it makes sense now. I was thinking that n_saved for ModelCheckpoint should depend on (or be proportional to) the patience of EarlyStopping if we wanted to fetch the best model, but I understand now that they can be independent. I will try using the evaluator for model checkpointing.
