Ignite: Custom configuration how/where to save the best model with ModelCheckpoint

Created on 27 Dec 2018 · 6 Comments · Source: pytorch/ignite

A user may want to integrate ModelCheckpoint with an experiment-tracking package, e.g. mlflow, polyaxon, etc.
In that case logs, model weights, etc. can be stored on cloud storage, e.g.

exp_tracking.log_artifact(filepath)

By default, ModelCheckpoint saves the model to the provided path dirname.
The idea is to provide the flexibility to execute custom code when the model is saved, so that it can be stored wherever we would like.

What do you think?

cc @elanmart

enhancement help wanted

All 6 comments

I have been using a hackish solution for this problem: a lambda function, called at the end of _internal_save(), that receives the path to the saved file.
This works fine with the mlflow.log_artifact function, which expects a file. However, I feel that this approach is somewhat hackish. Maybe a better approach would be to use handles to attach "output targets".
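
The two ideas above (a hook that receives the saved file's path, and attachable "output targets") can be sketched without touching Ignite internals. This is a minimal, library-independent illustration, not Ignite's API: `CheckpointWriter`, `add_target`, and `FakeTracker` are hypothetical names, and `pickle` stands in for `torch.save`. Each attached target is simply a callable that takes the path of the file just written.

```python
import os
import pickle
import tempfile


class CheckpointWriter:
    """Hypothetical saver: writes an object to disk, then notifies output targets."""

    def __init__(self, dirname):
        self.dirname = dirname
        self._targets = []  # callables that take the saved file path

    def add_target(self, target):
        if not callable(target):
            raise TypeError("target must be callable and accept a file path")
        self._targets.append(target)

    def save(self, obj, fname):
        os.makedirs(self.dirname, exist_ok=True)
        path = os.path.join(self.dirname, fname)
        with open(path, "wb") as f:
            pickle.dump(obj, f)  # stand-in for torch.save(obj, path)
        for target in self._targets:
            target(path)  # a tracker's log_artifact could be attached here
        return path


class FakeTracker:
    """Stand-in for an experiment tracker; it only records the paths it is given."""

    def __init__(self):
        self.artifacts = []

    def log_artifact(self, filepath):
        self.artifacts.append(filepath)


tracker = FakeTracker()
writer = CheckpointWriter(tempfile.mkdtemp())
writer.add_target(tracker.log_artifact)
saved_path = writer.save({"weights": [0.1, 0.2]}, "model_1.pth")
```

Because targets only see a path, the saver itself stays unaware of where the file ends up, which is the decoupling the "output targets" suggestion is after.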

This sounds nice. I also did some hacking and just want to share the code, in case it is useful to any of you. In my case, I needed a custom save method. I didn't want to use torch.save(), because my own model class is still under development and I want to keep compatibility between all its versions. My save method simply saves the hyperparameters and weights, from which the class is recreated when it is loaded.

I derived my ModelSaver class from ignite.handlers.ModelCheckpoint and overrode the _internal_save method. There are two more small features: saving the model on exception and saving it when training is completed. Sorry for the incomplete documentation, but I have only a little time at the moment.

Long story short... here is my code:

from ignite.engine import Events
import ignite
import os


class ModelSaver(ignite.handlers.ModelCheckpoint):
    """
    Extends :class:`ignite.handlers.ModelCheckpoint` with an option to provide a custom save method,
    saving the final model after training ends and saving a model if an exception is raised during training.
    """

    def __init__(self, *args, save_method=None, save_on_exception=True, save_on_completed=True, **kwargs):
        if not isinstance(save_on_exception, bool):
            raise TypeError(
                "Argument save_on_exception must be of type bool, got {} instead.".format(type(save_on_exception)))
        if not isinstance(save_on_completed, bool):
            raise TypeError(
                "Argument save_on_completed must be of type bool, got {} instead.".format(type(save_on_completed)))
        if save_method is not None and not callable(save_method):
            raise TypeError(
                "Argument save_method must be callable accepting two arguments: Model object to save and path")

        self._save_method = save_method
        self._save_on_completed = save_on_completed
        self._save_on_exception = save_on_exception

        super(ModelSaver, self).__init__(*args, **kwargs)

    def _internal_save(self, obj, path):
        if self._save_method is not None:
            self._save_method(obj, path)
        else:
            super(ModelSaver, self)._internal_save(obj, path)

    def _on_exception(self, engine, exception, to_save):
        for name, obj in to_save.items():
            fname = '{}_{}_{}{}.pth'.format(self._fname_prefix, name, self._iteration, "_on_exception")
            path = os.path.join(self._dirname, fname)
            if os.path.exists(path):
                os.remove(path)
            self._save(obj=obj, path=path)

    def _on_completed(self, engine, to_save):
        for name, obj in to_save.items():
            fname = '{}_{}_{}{}.pth'.format(self._fname_prefix, name, self._iteration, "_on_completed")
            path = os.path.join(self._dirname, fname)
            if os.path.exists(path):
                os.remove(path)
            self._save(obj=obj, path=path)

    def attach(self, engine, model_dict):
        """
        Attaches the model saver to an engine object.

        Args:
            engine (Engine): engine object
            model_dict (dict): A dict mapping names to objects, e.g. {'mymodel': model}
        """
        engine.add_event_handler(Events.EPOCH_COMPLETED, self, model_dict)
        engine.add_event_handler(Events.COMPLETED, self._on_completed, model_dict)
        engine.add_event_handler(Events.EXCEPTION_RAISED, self._on_exception, model_dict)
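
A `save_method` of the kind described above takes the object and a path and writes only what is needed to rebuild the model. The following is a minimal sketch, assuming a hypothetical `TinyModel` class that exposes `hparams` and `weights` attributes; these names, `save_model`, `load_model`, and the JSON layout are illustrative and are not part of Ignite or of the author's actual model code.

```python
import json
import os
import tempfile


class TinyModel:
    """Hypothetical model that is fully described by hyperparameters and weights."""

    def __init__(self, n_inputs, n_outputs, weights=None):
        self.hparams = {"n_inputs": n_inputs, "n_outputs": n_outputs}
        if weights is None:
            weights = [0.0] * (n_inputs * n_outputs)
        self.weights = weights


def save_model(model, path):
    """Two-argument (object, path) callable, as ModelSaver requires of save_method."""
    with open(path, "w") as f:
        json.dump({"hparams": model.hparams, "weights": model.weights}, f)


def load_model(path):
    """Recreate the model from the stored hyperparameters, then restore its weights."""
    with open(path) as f:
        state = json.load(f)
    return TinyModel(weights=state["weights"], **state["hparams"])


path = os.path.join(tempfile.mkdtemp(), "tiny_model.json")
original = TinyModel(2, 3, weights=[0.5] * 6)
save_model(original, path)
restored = load_model(path)
```

Passing `save_method=save_model` to `ModelSaver` would route checkpoints through this function. Depending on the Ignite version and the `atomic` setting, the second argument may arrive as an open temporary file object rather than a string path, so that is worth checking before relying on `open(path)`.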

@Bibonaut thanks for sharing the code! Looks nice! We can think about putting it into the contrib module.

@Bibonaut would you like to send a PR with this code and some basic tests? It would be awesome!

@vfdev-5 Sorry for my late response, still very busy :| I'd really like to contribute if you can wait until mid-May, when I have a deadline to meet. After that I'd like to share and discuss with you some more features I built to make my workflow and investigations a lot easier.

@Bibonaut no problem, we can wait until your deadline :)
