Consider the case where a user would like to integrate ModelCheckpoint with an experiment-tracking package, e.g. mlflow, polyaxon, etc.
In that case logs, model weights, etc. can be stored in cloud storage, e.g.
exp_tracking.log_artifact(filepath)
By default, ModelCheckpoint saves the model to the provided path dirname.
The idea is to provide the flexibility to execute custom code when the model is saved, so that it can be stored wherever we like.
What do you think?
cc @elanmart
I have been using a hackish solution for this problem: a lambda function, called at the end of _internal_save(), that receives the path to the saved file.
This works fine with the mlflow.log_artifact function, which expects a file. However, I feel that this approach is somewhat hackish. Maybe a better approach would be to use handles to attach "output targets".
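To make the idea concrete, here is a minimal sketch of such a post-save hook. It does not use ignite: SimpleCheckpoint and its on_saved argument are hypothetical stand-ins, and pickle replaces torch.save so that the example runs on its own. In a real setup the callback could be something like mlflow.log_artifact.

```python
import os
import pickle
import tempfile


class SimpleCheckpoint:
    """Hypothetical stand-in for a checkpoint handler with a post-save hook."""

    def __init__(self, dirname, on_saved=None):
        if on_saved is not None and not callable(on_saved):
            raise TypeError("Argument on_saved must be callable accepting a file path")
        self._dirname = dirname
        self._on_saved = on_saved

    def save(self, name, obj):
        path = os.path.join(self._dirname, "{}.pkl".format(name))
        with open(path, "wb") as f:
            pickle.dump(obj, f)
        # Hook: hand the saved file to user code, e.g. an experiment tracker.
        if self._on_saved is not None:
            self._on_saved(path)
        return path


# Usage: collect the saved paths in a list instead of uploading them.
uploaded = []
checkpoint_dir = tempfile.mkdtemp()
checkpointer = SimpleCheckpoint(checkpoint_dir, on_saved=uploaded.append)
saved_path = checkpointer.save("model", {"weights": [0.1, 0.2]})
```

The file is still written locally first; the hook only decides what happens to it afterwards, which is why a function that expects a file path fits here.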
This sounds nice. I also did some hacking and just want to share the code, in case it can be useful for anyone of you. In my case, I needed a custom save method. I didn't want to use torch.save(), because my own model class is still under development and I want to have compatibility between all its versions. My save method simply saves hyperparameters and weights from which the class is recreated when it is loaded.
I derived my ModelSaver class from ignite.handlers.ModelCheckpoint and overrode the _internal_save method. There are two more small features: it saves the model when an exception is raised and when training is completed. Sorry for the incomplete documentation, but I have very little time at the moment.
Long story short... here is my code:
import os

import ignite
from ignite.engine import Events


class ModelSaver(ignite.handlers.ModelCheckpoint):
    """
    Extends :class:`ignite.handlers.ModelCheckpoint` with an option to provide a custom save method,
    saving the final model after training ends and saving a model if an exception is raised during training
    """

    def __init__(self, *args, save_method=None, save_on_exception=True, save_on_completed=True, **kwargs):
        if not isinstance(save_on_exception, bool):
            raise TypeError(
                "Argument save_on_exception must be of type bool, got {} instead.".format(type(save_on_exception)))
        if not isinstance(save_on_completed, bool):
            raise TypeError(
                "Argument save_on_completed must be of type bool, got {} instead.".format(type(save_on_completed)))
        if save_method is not None and not callable(save_method):
            raise TypeError(
                "Argument save_method must be callable accepting two arguments: Model object to save and path")
        self._save_method = save_method
        self._save_on_completed = save_on_completed
        self._save_on_exception = save_on_exception
        super(ModelSaver, self).__init__(*args, **kwargs)

    def _internal_save(self, obj, path):
        # Use the custom save method if one was given, else fall back to the default.
        if self._save_method is not None:
            self._save_method(obj, path)
        else:
            super(ModelSaver, self)._internal_save(obj, path)

    def _on_exception(self, engine, exception, to_save):
        for name, obj in to_save.items():
            fname = '{}_{}_{}{}.pth'.format(self._fname_prefix, name, self._iteration, "_on_exception")
            path = os.path.join(self._dirname, fname)
            # Overwrite a file left over from an earlier save.
            if os.path.exists(path):
                os.remove(path)
            self._save(obj=obj, path=path)

    def _on_completed(self, engine, to_save):
        for name, obj in to_save.items():
            fname = '{}_{}_{}{}.pth'.format(self._fname_prefix, name, self._iteration, "_on_completed")
            path = os.path.join(self._dirname, fname)
            # Overwrite a file left over from an earlier save.
            if os.path.exists(path):
                os.remove(path)
            self._save(obj=obj, path=path)

    def attach(self, engine, model_dict):
        """
        Attaches the model saver to an engine object

        Args:
            engine (Engine): engine object
            model_dict (dict): A dict mapping names to objects, e.g. {'mymodel': model}
        """
        engine.add_event_handler(Events.EPOCH_COMPLETED, self, model_dict)
        # Register the optional handlers only if they were requested in __init__.
        if self._save_on_completed:
            engine.add_event_handler(Events.COMPLETED, self._on_completed, model_dict)
        if self._save_on_exception:
            engine.add_event_handler(Events.EXCEPTION_RAISED, self._on_exception, model_dict)
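For illustration, a save method of the kind described above (store only hyperparameters and weights, then rebuild the class on load) might look like the following sketch. TinyModel, save_model and load_model are hypothetical names, and pickle with plain Python lists stands in for real tensors so that the example runs without torch.

```python
import os
import pickle
import tempfile


class TinyModel:
    """Hypothetical model that can be rebuilt from hyperparameters and weights."""

    def __init__(self, n_inputs, n_outputs):
        self.hparams = {"n_inputs": n_inputs, "n_outputs": n_outputs}
        self.weights = [[0.0] * n_inputs for _ in range(n_outputs)]


def save_model(model, path):
    # Store only plain data, not the pickled class instance, so that
    # later versions of the class can still read the file.
    with open(path, "wb") as f:
        pickle.dump({"hparams": model.hparams, "weights": model.weights}, f)


def load_model(path):
    with open(path, "rb") as f:
        state = pickle.load(f)
    # Recreate the object from its hyperparameters, then restore the weights.
    model = TinyModel(**state["hparams"])
    model.weights = state["weights"]
    return model


model = TinyModel(n_inputs=3, n_outputs=2)
model.weights[0][1] = 0.5
path = os.path.join(tempfile.mkdtemp(), "tiny_model.pkl")
save_model(model, path)
restored = load_model(path)
```

A function with an (obj, path) signature like save_model is the kind of callable that save_method is meant to receive. Whether path arrives as a string or as an open file object depends on the ignite version and its settings, so that part is an assumption to check before plugging it in.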
@Bibonaut thanks for sharing the code! Looks nice! We could think about putting it into the contrib module.
@Bibonaut would you like to send a PR with this code and some basic tests? It would be awesome!
@vfdev-5 Sorry for my late response, still very busy :| I'd really like to contribute if you can wait until mid-May, when I have a deadline to meet. After that I'd like to share and discuss with you some more features I built that make my workflow and investigations a lot easier.
@Bibonaut no problem, we can wait until your deadline :)