Hey guys,
Loving the project so far. I have a handful of extensions/utilities that I'd love to discuss either adding to this project or putting in a separate library. Would it be possible to get some insight into the future directions of ignite? How are decisions being made? I'm happy to start opening PRs, but obviously wanted to discuss them first.
Some of the things I'm curious about:
If this is better discussed offline over email or something, just let me know.
Thanks a lot!
Hi @jasonkriss,
I think it'd be good to discuss this on github to keep it in the open and let others weigh in. To answer some of the questions please see my comments below:
_Would it be possible to get some insight into the future directions of ignite? How are decisions being made?_
_adding some common update and inference callables (e.g. a supervised updater that essentially just does what tnt did by default)_
_support for multiple validation sets_
_a different approach to metrics that doesn't require coordination between the update/inference function and the event handlers_
_considering using a "state" based approach (like in tnt) rather than a "history" based approach (in order to minimize memory consumption for use cases like #20)_
_allow validation to happen every n iterations instead of being epoch based (similar to what is suggested here_
validate_every_epoch parameter altogether and just let users fire off the trainer.validate() function in handlers on either the TRAINING_EPOCH_COMPLETE or the TRAINING_ITERATION_COMPLETE event.
_a "callback" abstraction for cases where multiple event handlers need to coordinate and pass state around_
Cheers,
Aly
cc @fmassa who might also have opinions on this
Hey @alykhantejani,
Thanks a lot for the thoughtful response! That all sounds great. I'll get into some more detail here but it is of course just a starting point for discussion.
Here are some rough sketches of supervised update and inference callables that I had in mind. Others could be added as well (WGAN, GAN, etc).
class SupervisedUpdate(object):
    def __init__(self, model, optimizer, loss_fn):
        self._model = model
        self._optimizer = optimizer
        self._loss_fn = loss_fn

    def __call__(self, batch):
        self._model.train()
        self._optimizer.zero_grad()
        x, y = batch
        y_hat = self._model(x)
        loss = self._loss_fn(y_hat, y)
        loss.backward()
        self._optimizer.step()
        return loss.item()  # on PyTorch < 0.4 this was loss.data[0]


class SupervisedInference(object):
    def __init__(self, model):
        self._model = model

    def __call__(self, batch):
        self._model.eval()
        x, y = batch
        y_hat = self._model(x)
        return y, y_hat
One of the benefits of having something like this is that it removes the boilerplate for the most common scenarios. Another nice benefit is that it allows us to "standardize" the return signature of the inference callable. This would be important for the proposed validation/metric approach I'll get into below.
I would love to see a summary of what @apaszke had in mind for metrics. Curious how it would compare. This proposed approach addresses the need for multiple validation sets and allows triggering validation every n iterations rather than per epoch. It also introduces the proposed metric approach and callback abstraction.
Here is a code sketch to help illustrate the ideas:
train_loader = ...
val_loader = ...
model = ...
optimizer = ...
loss_fn = ...

update = SupervisedUpdate(model, optimizer, loss_fn)
inference = SupervisedInference(model)

# Keep a reference to the trainer so that run() can be called on it below.
trainer = Trainer(update, inference, callbacks=[
    Validate(train_loader, Accuracy(), VisdomLogger('Train Accuracy'), interval=500),
    Validate(val_loader, Accuracy(), VisdomLogger('Val Accuracy'), interval=500),
    Checkpoint(model, '/tmp/checkpoints', interval=500)
])
trainer.run(train_loader, max_epochs=epochs)
Metrics are stateful and are computed in a "streaming" way. Very similar to the approach in tnt. The API could look something like this:
class Metric(object):
    def reset(self): ...
    def accumulate(self, *args, **kwargs): ...
    def compute(self): ...

Metric#accumulate would receive the output from the inference function at each iteration. As a concrete example, here is a sketch of what an Accuracy metric might look like:
class Accuracy(Metric):
    def reset(self):
        self._correct = 0
        self._count = 0

    def accumulate(self, y, y_hat):
        # Index of the highest score per sample.
        predictions = y_hat.data.max(1, keepdim=True)[1]
        # .item() keeps the counter a plain Python int on PyTorch >= 0.4.
        self._correct += predictions.eq(y.data.view_as(predictions)).sum().item()
        self._count += predictions.shape[0]

    def compute(self):
        return self._correct / self._count
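To make the streaming idea easy to try without torch, here is a minimal pure-Python sketch of the same reset/accumulate/compute cycle. `StreamingAccuracy` and its list-based inputs are assumptions made for illustration; they are not part of ignite or tnt.

```python
class StreamingAccuracy:
    """Hypothetical streaming metric: keeps two counters, never the raw outputs."""

    def __init__(self):
        self.reset()

    def reset(self):
        self._correct = 0
        self._count = 0

    def accumulate(self, y, y_hat):
        # y: list of true class indices; y_hat: list of per-class score lists.
        for target, scores in zip(y, y_hat):
            predicted = max(range(len(scores)), key=scores.__getitem__)
            self._correct += int(predicted == target)
            self._count += 1

    def compute(self):
        if self._count == 0:
            raise ValueError("accumulate() must be called before compute()")
        return self._correct / self._count


metric = StreamingAccuracy()
metric.accumulate([0, 1], [[0.9, 0.1], [0.2, 0.8]])  # both predictions correct
metric.accumulate([1, 0], [[0.7, 0.3], [0.6, 0.4]])  # first wrong, second correct
accuracy = metric.compute()
```

Only the two counters survive between batches, which is the memory argument made for this design.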
Callbacks are a (potentially stateful) abstraction for the coordination of multiple event handlers. They are simply objects that implement methods corresponding to the fired events. They can implement any number of the events/methods they want. For example, a checkpoint callback might look something like this:
import os

import torch


class Checkpoint(Callback):
    def __init__(self, model, dirname, interval):
        super().__init__()
        self._model = model
        self._dirname = dirname
        self._interval = interval
        os.makedirs(dirname, exist_ok=True)

    def training_iteration_completed(self, trainer):
        # Save the weights every `interval` iterations.
        if (trainer.current_iteration > 0) and ((trainer.current_iteration % self._interval) == 0):
            torch.save(self._model.state_dict(), self._path_for(trainer.current_iteration))

    def _path_for(self, iteration):
        return os.path.join(self._dirname, '{}.pkl'.format(iteration))
In this case, Checkpoint only implements the single method/event training_iteration_completed. A more complex example would be a Validate callback. Here is a rough sketch just to give you an idea:
class Validate(Callback):
    def __init__(self, iterator, metric, logger, interval):
        self._iterator = iterator
        self._metric = metric
        self._logger = logger
        self._interval = interval

    def training_iteration_completed(self, trainer):
        if (trainer.current_iteration > 0) and ((trainer.current_iteration % self._interval) == 0):
            trainer.validate(self._iterator)

    def validation_starting(self, trainer):
        self._metric.reset()

    def validation_iteration_completed(self, trainer):
        # The inference callable returns (y, y_hat), so unpack it for accumulate(y, y_hat).
        self._metric.accumulate(*trainer.validation_history[-1])

    def validation_completed(self, trainer):
        trainer.validation_history.clear()
        result = self._metric.compute()
        self._logger.log(trainer.current_iteration, result)
The big win of callbacks is that as a user I can just be like "I know I need to checkpoint so let me just grab the checkpoint callback". You don't need to think about which event to attach it to and you won't run into issues where you attach it to the wrong event. The callback encapsulates all this necessary information itself.
Trainer would now only take the update and inference functions at instantiation.
Trainer#validate would take a validation_data argument. This makes it easy to support multiple validation sets.
Trainer#run would take a training_data argument. This makes it consistent with Trainer#validate. I would also argue that it makes more sense for the data to be passed to run. This makes it possible to train on multiple datasets as well.
Trainer would take a list of callbacks at instantiation.
Trainer._fire_event would be modified to trigger callbacks, e.g.
def _fire_event(self, event_name):
    for callback in self._callbacks:
        getattr(callback, event_name.value)(self)
    ...
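The dispatch described above can be tried in isolation. The following is a minimal sketch, assuming an enum whose values double as callback method names; `Events`, `MiniTrainer`, and `EveryN` are hypothetical stand-ins, not ignite classes.

```python
from enum import Enum


class Events(Enum):
    # The enum values double as the callback method names.
    TRAINING_ITERATION_COMPLETED = "training_iteration_completed"
    TRAINING_COMPLETED = "training_completed"


class Callback:
    """Base class with no-op methods, so subclasses override only what they need."""

    def training_iteration_completed(self, trainer):
        pass

    def training_completed(self, trainer):
        pass


class MiniTrainer:
    def __init__(self, update, callbacks=()):
        self._update = update
        self._callbacks = list(callbacks)
        self.current_iteration = 0

    def _fire_event(self, event_name):
        for callback in self._callbacks:
            getattr(callback, event_name.value)(self)

    def run(self, data):
        for batch in data:
            self._update(batch)
            self.current_iteration += 1
            self._fire_event(Events.TRAINING_ITERATION_COMPLETED)
        self._fire_event(Events.TRAINING_COMPLETED)


class EveryN(Callback):
    """Records the iterations at which it fired (stands in for Checkpoint)."""

    def __init__(self, interval):
        self._interval = interval
        self.fired_at = []

    def training_iteration_completed(self, trainer):
        if trainer.current_iteration % self._interval == 0:
            self.fired_at.append(trainer.current_iteration)


recorder = EveryN(interval=2)
MiniTrainer(update=lambda batch: None, callbacks=[recorder]).run(range(5))
```

The no-op base class is one possible design choice: it lets `getattr` succeed for every event even when a callback implements only one of them.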
What are the main use cases for history in your experience so far? From my perspective, the main use of history is for computing metrics, but if we were to take an approach like the one above, I think that need is mostly obviated. What do you think? If there are other use cases, maybe it's worth having both a "history" and a "state"?
A state object that can be passed to each event/callback and can be written to is super convenient. If we make metrics stateful, using a "state"-based approach will also ensure that we don't eat up a ton of memory unnecessarily.
I would love for ignite to "make the easy things easy and the hard things possible". I think the decision to have the user provide the update and inference closures was a great idea and makes the hard things possible. I think the proposals above do a lot to make the easy things easy while not taking away from the already solid foundation. It should make it easier for a new user to jump in and start benefitting from ignite right away.
Let me know if anything is unclear or if you'd like me to go into more detail on anything. Excited to hear what you all think and see where this goes!
Hi all,
I'd also be interested to know what future directions ignite is planning to take.
If I may add my few cents:
Callables: I think most people will write those things anyway, so it would be nice to have them built-in.
I also liked the idea of State in tnt (albeit I used a slightly modified version, since the dict-like access is just terrible). I'd frequently use it to store / accumulate stuff during the training loop, only to flush it on epoch end, or to exchange information between training / validation events.
Let me advocate against callbacks however.
_The big win of callbacks is that as a user I can just be like "I know I need to checkpoint so let me just grab the checkpoint callback". You don't need to think about which event to attach it to and you won't run into issues where you attach it to the wrong event._
Maybe it's a matter of personal taste, but I really dislike this way of setting up the experiment. It's much easier for me to read through a sequence of events like
@engine.on_epoch_start
def foo():
    ...
where I clearly see what will happen at each step of my experiment. I remember keras had some magic callbacks that I would always forget how to implement / read, and when exactly they will be called.
So instead of
t = Trainer(update, inference, callbacks=[
    Checkpoint(model, '/tmp/checkpoints', interval=500)
])
I'd much rather have a much more flexible approach:
t = Trainer(update, inference)

@t.on_epoch_done
def checkpoint(trainer, state=None):  # state is optional
    torch.save(...)
with ignite providing only a set of most-commonly used utilities.
@jasonkriss @elanmart, thanks for the detailed feedback/suggestions. Here are my thoughts on what's been discussed so far:
On the idea of predefined callables:
On Metrics and State:
History and the proposed Metrics. For example, the History object right now provides the accumulation of all raw data that you would want to compute metrics for. It also provides some basic summary statistics on this data, and for any other metrics the user would define a callable (class or function) to essentially define the compute function of Metrics. So in my understanding History + user-defined compute function is equivalent to a Metric class, unless I am misunderstanding?
History object can be used to also store state, it may be cleaner to also provide a State object that is available to the handlers (which provides a clear() function). Should this live in the trainer (i.e. trainer.state) or be explicitly passed to the handlers alongside the trainer?
On Callbacks:
I feel like the Callback wrapper has a few downsides:
@ignite.on_training_epoch_complete
def checkpoint(trainer, ...):
    torch.save(...)
instead of
def checkpoint(trainer, ...):
    torch.save(...)

trainer.add_event_handler(TRAINING_EPOCH_COMPLETE, checkpoint, ...)
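For comparison, a decorator form can be layered on top of explicit registration. This is only a sketch: `MiniTrainer`, its `on` method, and the string event names are hypothetical and do not claim to match ignite's API.

```python
from collections import defaultdict


class MiniTrainer:
    def __init__(self):
        self._handlers = defaultdict(list)

    def add_event_handler(self, event_name, handler, *args, **kwargs):
        self._handlers[event_name].append((handler, args, kwargs))

    def on(self, event_name, *args, **kwargs):
        """Decorator form of add_event_handler (hypothetical)."""
        def decorator(handler):
            self.add_event_handler(event_name, handler, *args, **kwargs)
            return handler  # leave the function usable on its own
        return decorator

    def fire(self, event_name):
        for handler, args, kwargs in self._handlers[event_name]:
            handler(self, *args, **kwargs)


trainer = MiniTrainer()
calls = []


@trainer.on("training_epoch_completed", "decorated")
def log_epoch(trainer, tag):
    calls.append(tag)


def log_epoch_explicit(trainer, tag):
    calls.append(tag)


trainer.add_event_handler("training_epoch_completed", log_epoch_explicit, "explicit")
trainer.fire("training_epoch_completed")
```

Both styles end up in the same handler list, so the choice between them is about readability rather than behaviour.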
On multiple validation sets:
validation_data_loader + inference function directly to the trainer.validate function. This way calling validate is up to the user, as is the what and how you validate. The other approach could be to accept a list of validation_data and validation_inference_functions to the constructor.
History if you have multiple validation sets running? i.e. should it be the responsibility of the user to clear the history before each validation run on a different set or should we clear the history internally (probably not)?
@alykhantejani @elanmart, really appreciate the feedback! You make some very good points.
It seems like this is pretty uncontroversial. I'll open a PR to get this started. I definitely agree that "callables" is far too vague. We can try to think of better names in the meantime.
One other thought on this topic. Is there any interest in combining the update and inference functions into a single abstraction? For example,
class Supervised(object):
    def __init__(self, model, optimizer, loss_fn):
        self._model = model
        self._optimizer = optimizer
        self._loss_fn = loss_fn

    def update(self, batch):
        self._model.train()
        self._optimizer.zero_grad()
        x, y = batch
        y_hat = self._model(x)
        loss = self._loss_fn(y_hat, y)
        loss.backward()
        self._optimizer.step()
        return loss.item()  # on PyTorch < 0.4 this was loss.data[0]

    def predict(self, batch):
        self._model.eval()
        x, y = batch
        y_hat = self._model(x)
        return y, y_hat
You could then pass this object as the single argument to Trainer#__init__. What do you all think? Is it worth combining them since they will most often be used together or is it better to keep them separate?
@alykhantejani you are correct that History + compute is functionally equivalent to Metrics. The only effective difference is the memory requirement. For example, the Accuracy example would simply store two integers whereas History has to store the complete history. That being said, I think this is less of an issue than I previously thought. As long as you are just attaching a couple of scalars to the history at each point, memory will not be an issue. It will only be an issue if you start attaching a batch of images at each step. So for that reason, I think we should hold off on any Metric abstraction. The History + compute approach is much simpler/cleaner.
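The equivalence, and the memory difference, can be seen in a small pure-Python sketch. `accuracy_from_history` and `RunningAccuracy` are hypothetical names used only for this illustration.

```python
def accuracy_from_history(history):
    # history: list of (target, prediction) pairs kept for the whole run.
    correct = sum(1 for target, prediction in history if target == prediction)
    return correct / len(history)


class RunningAccuracy:
    # Keeps two integers regardless of how many samples are seen.
    def __init__(self):
        self.correct = 0
        self.count = 0

    def accumulate(self, target, prediction):
        self.correct += int(target == prediction)
        self.count += 1

    def compute(self):
        return self.correct / self.count


samples = [(0, 0), (1, 0), (1, 1), (2, 2), (0, 1)]

history = []
running = RunningAccuracy()
for target, prediction in samples:
    history.append((target, prediction))    # grows with the data
    running.accumulate(target, prediction)  # constant size

from_history = accuracy_from_history(history)
from_metric = running.compute()
```

Both paths give the same number; the history list is what grows, which only matters once the stored items are large (for example image batches rather than scalars).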
I think my only driving use case for state is for things (like images) that I want to have access to in handlers but that I don't want to accumulate in history/memory. This isn't really a strong requirement at this point in time (if we aren't using Metrics) so I would probably vote for leaving it out for now to keep things simple. History is a powerful concept and we can revisit the state idea later if necessary as other things develop a bit more.
You both bring up valid points/concerns about callbacks. I think the biggest argument against them is that they aren't nearly as explicit/obvious. The biggest reason I like them is for coordinating logic across multiple events. That being said, I think this could be less common than I anticipated, especially if we stick with History over Metrics and take an approach to validation similar to what I lay out below.
What do you all think of an approach like this...
We introduce a validate handler. Something like this...
def validate(trainer, validation_data, interval=None, clear_history=True):
    if clear_history:
        trainer.validation_history.clear()
    # Validate on every call when no interval is given, otherwise every `interval` iterations.
    if not interval or (trainer.current_iteration % interval) == 0:
        trainer.validate(validation_data)
You could then run this every epoch with:
trainer.add_event_handler(TrainingEvents.TRAINING_EPOCH_COMPLETED, validate, validation_loader)
or every n (500 in this case) iterations, with:
trainer.add_event_handler(TrainingEvents.TRAINING_ITERATION_COMPLETED, validate, validation_loader, interval=500)
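To check the interval logic without a real trainer, here is a small runnable sketch. `FakeTrainer` is a hypothetical stand-in that only tracks the attributes the handler touches, and the handler is called directly instead of through add_event_handler.

```python
class FakeTrainer:
    """Stand-in for the trainer: only tracks what the handler touches."""

    def __init__(self):
        self.current_iteration = 0
        self.validation_history = []
        self.validated_at = []

    def validate(self, validation_data):
        self.validated_at.append(self.current_iteration)
        self.validation_history.extend(validation_data)


def validate(trainer, validation_data, interval=None, clear_history=True):
    if clear_history:
        trainer.validation_history.clear()
    # No interval means "validate on every call".
    if not interval or (trainer.current_iteration % interval) == 0:
        trainer.validate(validation_data)


trainer = FakeTrainer()
for iteration in range(1, 11):
    trainer.current_iteration = iteration
    validate(trainer, ["a", "b"], interval=5)
```

With ten iterations and interval=5, validation runs at iterations 5 and 10, and the history holds only the results of the last run.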
We also introduce a report handler (or maybe there is a better name):
def report(trainer, compute_fn, reporter_fn):
    result = compute_fn(trainer)
    reporter_fn(trainer.current_epoch, trainer.current_iteration, result)
We can have predefined compute_fns for the most common things like accuracy and loss. Or maybe it is better to have separate report_accuracy, report_loss handlers. Either way. We can also have predefined reporter_fns for visdom and text logging and whatever else.
The previous example I introduced above would now look something like this...
train_loader = ...
val_loader = ...
model = ...
optimizer = ...
loss_fn = ...

update = SupervisedUpdate(model, optimizer, loss_fn)
inference = SupervisedInference(model)

trainer = Trainer(update, inference)
trainer.add_event_handler(TrainingEvents.TRAINING_ITERATION_COMPLETED, validate, train_loader, interval=500)
trainer.add_event_handler(TrainingEvents.TRAINING_ITERATION_COMPLETED, validate, val_loader, interval=500)
trainer.add_event_handler(TrainingEvents.TRAINING_ITERATION_COMPLETED, checkpoint, model, '/tmp/checkpoints', interval=500)
trainer.add_event_handler(TrainingEvents.VALIDATION_COMPLETED, report, accuracy, VisdomLogger())
trainer.run(train_loader, max_epochs=epochs)
This is roughly the same amount of code while being much more explicit and allowing the user to more easily reason about the sequence of events.
Of course the details would still need to be worked out a lil bit, but what are your thoughts on the overall ideas?
Hi @jasonkriss,
Sorry for the delay in response, it's been a bit busy around the festive period 🎅
On the "callables". I think in order to support multiple validation sets/procedures we might change the API of validate to support taking in the dataloader + inference function. Therefore, passing a class to the trainer.__init__ won't make much sense (over just the training dataloader and update function). However, this doesn't mean we can't group the update + predict/inference functions in a class and just pass the relevant one to the relevant function, i.e.
trainer = Trainer(train_dataloader, Supervised.update)
trainer.validate(val_dataloader, Supervised.predict)
Regarding the wrapper of validate, I think it could be something like:
def validate(dataloader, inference_fn, iteration_interval=None, epoch_interval=None, clear_history=False)
where the user can choose to validate every n iterations or n epochs (but only one or the other)?
For the report handler, something towards this was started in ignite.handlers.logging. It currently only logs the summary statistics provided on the History objects. The problem with the current logging handler is that the logging text is predefined by the library, which may not suit all users and is hard/messy to make customizable without overcomplicating it. I'm not sure how useful this module is, but it allows people to get off the ground fairly quickly.
Along the lines of what you were suggesting we could have something like the following. ignite.metrics which has common metrics, where a metric is a callable with the interface metric(y, y_pred) and returns a float. We could include (to start with) the following metrics (following Keras):
However, a wrapper for reporting can be tricky as this is hard to make generic. Instead, users can just call the metric and print/log/log to visdom from within their handler i.e.:
def log_validation_metrics(trainer):
    # Positional arguments must come before keyword arguments.
    accuracy = metrics.top_k_accuracy(trainer.validation_history, k=1)
    accuracy = (accuracy * 100.) / len(trainer.validation_data.dataset)
    visdom.line(X=np.array([trainer.current_epoch]),
                Y=np.array([accuracy]),
                win=val_accuracy_plot_window,
                update='append')
    logger.info("validation after {} iterations accuracy = {}".format(trainer.current_iteration, accuracy))
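As an illustration of the proposed metric(y, y_pred) interface returning a float, here is a pure-Python sketch of a top-k accuracy function. The name `top_k_accuracy`, its list-based inputs, and the choice to return a fraction are assumptions for this example, not the eventual ignite.metrics implementation.

```python
def top_k_accuracy(y, y_pred, k=1):
    """Fraction of samples whose true class is among the k highest scores.

    y: list of true class indices; y_pred: list of per-class score lists.
    """
    if len(y) != len(y_pred):
        raise ValueError("y and y_pred must have the same length")
    if not y:
        raise ValueError("cannot compute accuracy on empty input")
    hits = 0
    for target, scores in zip(y, y_pred):
        # Class indices sorted from highest to lowest score.
        ranked = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)
        hits += int(target in ranked[:k])
    return hits / len(y)


y = [0, 2, 1]
y_pred = [
    [0.7, 0.2, 0.1],  # true class 0 is ranked first
    [0.5, 0.1, 0.4],  # true class 2 is ranked second
    [0.3, 0.1, 0.6],  # true class 1 is ranked last
]
top1 = top_k_accuracy(y, y_pred, k=1)
top2 = top_k_accuracy(y, y_pred, k=2)
```

A handler could then call such a function and decide for itself how to print or plot the result, which keeps the reporting side out of the library.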
So, to conclude:
trainer.validate to accept the dataloader and inference function to support validation on multiple validation sets
validate handler that wraps trainer.validate to run on intervals and clear history
ignite.metrics to add common metrics with the interface (y, y_pred)
@jasonkriss @elanmart would these answer some of your concerns/feature requests? If so, I can go ahead and make issues/tasks from these and start working on them.
@alykhantejani no worries! I think this looks great!
I'm definitely on board with the metrics approach. I agree that it's hard to make the reporting part generic so I think this is smart. Perhaps later it becomes clear how we can make this even easier but for now this sounds great! On that same note, I think taking this metrics approach means that the stuff in ignite.handlers.logging is less useful than it was before. Should we maybe consider removing it to simplify things? Whatever you think on that, I'm cool with.
The only other thing I'm still not 100% solid on is the Trainer#validate signature. While I often validate on multiple datasets, it's hard for me to imagine a use case where I would want to use two different validation/inference functions. Is there a use case you have in mind here? I think it would be nice to have the signatures of Trainer#run and Trainer#validate mirror each other by passing the update and inference functions (or a class with both) to Trainer#__init__ but then passing data to Trainer#run and Trainer#validate. What do you think?
@jasonkriss
On ignite.handlers.logging you may be right. We can probably remove this after metrics are in, it shouldn't be a problem as we're still in alpha so no need to worry about backwards compat stuff.
On the trainer, I agree there's probably no reason to pass the inference function to the validate function. Whilst I agree with making the interface of run and validate consistent by taking in the dataloader, I also can't envisage when you would pass a different training data loader in. I'm OK with doing this though.
As for the class wrapping the update and predict/inference function, I think making the __init__ take in separate functions gives the user flexibility in how they want to structure their code, i.e. you can still have a class to encapsulate the logic and pass in SupervisedTraining.update and SupervisedTraining.predict. How we write the examples will probably influence the "preferred" method too.
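A rough sketch of that arrangement: the trainer accepts two plain callables, and a user who prefers a class passes its bound methods. `MiniTrainer` and this stripped-down `SupervisedTraining` are hypothetical; a real version would hold the model, optimizer, and loss.

```python
class SupervisedTraining:
    """Groups the two step functions; a real version would hold model/optimizer."""

    def __init__(self):
        self.calls = []

    def update(self, batch):
        self.calls.append(("update", batch))
        return 0.0  # placeholder loss

    def predict(self, batch):
        self.calls.append(("predict", batch))
        return batch


class MiniTrainer:
    # Takes two plain callables; it does not care where they come from.
    def __init__(self, update, inference):
        self._update = update
        self._inference = inference

    def run(self, data):
        return [self._update(batch) for batch in data]

    def validate(self, data):
        return [self._inference(batch) for batch in data]


steps = SupervisedTraining()
trainer = MiniTrainer(steps.update, steps.predict)  # bound methods share state
losses = trainer.run([1, 2])
outputs = trainer.validate([3])
```

Because bound methods carry their instance with them, the class can still share state between training and inference without the trainer knowing about it.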
Hey, sorry for the lack of response.
Thanks for opening concrete issues, I think I could work on some of them. Please give me a day or two to sort out other things.
@elanmart that'd be great. No rush. Thanks!
@alykhantejani 👍 looks great! Thanks for breaking out the issues. I'll continue the relevant discussions there.
Closing this in favor of individual issues.
@kirk86 was this response meant for another issue/PR? perhaps in torchvision?
@alykhantejani oops, my apologies I was following the long thread and I missed the title. I just deleted so to keep things clear. Thanks!
Hello,
First of all thank you for the good work done on ignite. I enjoy very much using it.
I opened #216 to submit the same idea as the Callback class without reading this thread.
Reading this, I must agree that although the idea comes naturally, callbacks are not as necessary as they seem (as @jasonkriss said).
My strongest case for them was a Visdom class I had where I would accumulate in a list the loss on every iteration, and send to Visdom server on every epoch.
For the sake of transparency: I did not test if sending to Visdom on every iteration would be too much (probably not).
Although allow me to disagree with @alykhantejani. The reason why I like ignite is because it is not linear. Once I have a handle to log, plot, checkpoint, etc., I want to keep it in a separate module where it would not grab my attention.
For that matter I do not like the @trainer.on(...) because it's an invitation to stuff every function in the same place where you instantiate your engine, your model and so on.
I may use it on very specific occasions related to the math/model itself that I want to keep close to the training loop. Otherwise, everything auxiliary must disappear from my eyes.
In conclusion
I don't think a Callback class is worth maintaining another API for now, but I still wanted to share some feedback in case somebody might relate.