As discussed with @vfdev-5 in #309, it could sometimes be useful to have a handler that stores the full history of output predictions for visualization purposes. Below is my first attempt at implementing it.
```python
import torch
from ignite.engine import Events


class EpochOutputStore(object):
    """EpochOutputStore handler to save the output prediction and target history
    of the current epoch, which can be useful, e.g., for visualization.

    Note:
        This can potentially lead to a memory error if the output data is
        larger than the available RAM.

    Args:
        output_transform (callable, optional): a callable that is used to
            transform the :class:`~ignite.engine.engine.Engine`'s
            ``process_function``'s output into the form ``(y_pred, y)``, e.g.,
            ``lambda output: (output[0], output[1])``.

    Examples:
        .. code-block:: python

            import ...

            eos = EpochOutputStore()
            trainer = create_supervised_trainer(model, optimizer, loss)
            train_evaluator = create_supervised_evaluator(model, metrics={"acc": Accuracy()})
            eos.attach(train_evaluator)

            @trainer.on(Events.EPOCH_COMPLETED)
            def log_training_results(engine):
                train_evaluator.run(train_loader)
                y_pred, y = eos.get_output()
                # plotting
    """

    def __init__(self, output_transform=lambda x: x):
        self.predictions = None
        self.targets = None
        self.output_transform = output_transform

    def reset(self):
        # Clear the history at the start of every epoch
        self.predictions = []
        self.targets = []

    def update(self, engine):
        # Store one batch of (y_pred, y) per completed iteration
        y_pred, y = self.output_transform(engine.state.output)
        self.predictions.append(y_pred)
        self.targets.append(y)

    def attach(self, engine):
        engine.add_event_handler(Events.EPOCH_STARTED, self.reset)
        engine.add_event_handler(Events.ITERATION_COMPLETED, self.update)

    def get_output(self, to_numpy=False):
        # Concatenate the stored batches along the batch dimension
        prediction_tensor = torch.cat(self.predictions, dim=0)
        target_tensor = torch.cat(self.targets, dim=0)
        if to_numpy:
            prediction_tensor = prediction_tensor.cpu().detach().numpy()
            target_tensor = target_tensor.cpu().detach().numpy()
        return prediction_tensor, target_tensor
```
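To make the handler's lifecycle concrete, here is a dependency-free sketch of the same reset/update pattern. `MiniEngine` and `EpochOutputStoreSketch` are hypothetical stand-ins (not ignite's API): the engine only fires the two events the handler uses, and plain lists replace tensors, so list concatenation plays the role of `torch.cat`. It illustrates that, because `reset` runs on every epoch start, `get_output` only returns the outputs of the most recent epoch.

```python
from collections import defaultdict


class MiniEngine:
    """Hypothetical stand-in for ignite's Engine: fires only two events."""

    EPOCH_STARTED = "epoch_started"
    ITERATION_COMPLETED = "iteration_completed"

    def __init__(self, process_function):
        self.process_function = process_function
        self.handlers = defaultdict(list)
        self.output = None  # plays the role of engine.state.output

    def add_event_handler(self, event, handler):
        self.handlers[event].append(handler)

    def _fire(self, event):
        # This sketch always passes the engine to the handler
        for handler in self.handlers[event]:
            handler(self)

    def run(self, data, max_epochs=1):
        for _ in range(max_epochs):
            self._fire(self.EPOCH_STARTED)
            for batch in data:
                self.output = self.process_function(self, batch)
                self._fire(self.ITERATION_COMPLETED)


class EpochOutputStoreSketch:
    """Same reset/update logic as EpochOutputStore, with lists instead of tensors."""

    def __init__(self, output_transform=lambda x: x):
        self.predictions = None
        self.targets = None
        self.output_transform = output_transform

    def reset(self, engine):
        self.predictions = []
        self.targets = []

    def update(self, engine):
        y_pred, y = self.output_transform(engine.output)
        self.predictions.append(y_pred)
        self.targets.append(y)

    def attach(self, engine):
        engine.add_event_handler(engine.EPOCH_STARTED, self.reset)
        engine.add_event_handler(engine.ITERATION_COMPLETED, self.update)

    def get_output(self):
        # Flattening the per-batch lists stands in for torch.cat(..., dim=0)
        predictions = [p for batch in self.predictions for p in batch]
        targets = [t for batch in self.targets for t in batch]
        return predictions, targets


def dummy_process(engine, batch):
    # Each "batch" yields two predictions and two targets
    return [batch * 10, batch * 10 + 1], [batch, batch]


engine = MiniEngine(dummy_process)
store = EpochOutputStoreSketch()
store.attach(engine)
engine.run(range(3), max_epochs=2)
y_pred, y = store.get_output()
print(y_pred)  # [0, 1, 10, 11, 20, 21]: only the last epoch survives the reset
```

With real ignite, the same multi-epoch behavior could be checked by running the engine with `max_epochs=2` and asserting that only one epoch's worth of batches is stored.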
@ZhiliangWu thank you for this FR! It looks good!
EDIT: this is a feature request (FR) and I reacted as if it were a PR, sorry 😊 The following applies only if you would like to contribute with a PR.
Please follow the contribution guidelines: https://github.com/pytorch/ignite/blob/master/CONTRIBUTING.md
In particular, you have to open a Pull Request (PR) on GitHub. If you are not familiar with the process, see https://github.com/pytorch/ignite/blob/master/CONTRIBUTING.md#send-a-pr
In addition, could you provide some tests? It would also be very nice to add this handler to the docs. I suppose it will be located in ignite.contrib.handlers, so I think the relevant doc is https://github.com/pytorch/ignite/blob/master/docs/source/contrib/handlers.rst
Thank you again 😊
Hi @sdesrozis, I tried to write some tests based on the existing test cases in the repository. As I am not very familiar with testing code that uses the Engine object, could you kindly check them?
```python
import numpy as np
import torch
from ignite.engine import Engine, Events
import pytest

from stores import EpochOutputStore


@pytest.fixture
def dummy_trainer():
    y_pred = torch.zeros(5)
    y = torch.ones(5)

    def dummy_process_function(engine, batch):
        return y_pred, y

    dummy_trainer = Engine(dummy_process_function)
    return dummy_trainer


@pytest.fixture
def eos():
    return EpochOutputStore()


class TestEpochOutputStore(object):
    def test_reset(self, dummy_trainer, eos):
        eos.attach(dummy_trainer)
        dummy_trainer.run(range(2))
        eos.reset()
        assert eos.predictions == []
        assert eos.targets == []

    def test_update_one_iteration(self, dummy_trainer, eos):
        eos.attach(dummy_trainer)
        dummy_trainer.run(range(1))
        assert len(eos.predictions) == 1
        assert len(eos.targets) == 1

    def test_update_five_iterations(self, dummy_trainer, eos):
        eos.attach(dummy_trainer)
        dummy_trainer.run(range(5))
        assert len(eos.predictions) == 5
        assert len(eos.targets) == 5

    def test_attatch(self, dummy_trainer, eos):
        eos.attach(dummy_trainer)
        assert dummy_trainer.has_event_handler(eos.reset, Events.EPOCH_STARTED)
        assert dummy_trainer.has_event_handler(eos.update,
                                               Events.ITERATION_COMPLETED)

    def test_get_output(self, dummy_trainer, eos):
        eos.attach(dummy_trainer)
        dummy_trainer.run(range(1))
        assert len(eos.get_output()) == 2

    def test_get_numpy_output(self, dummy_trainer, eos):
        eos.attach(dummy_trainer)
        dummy_trainer.run(range(1))
        y_pred, y = eos.get_output(to_numpy=True)
        assert isinstance(y_pred, np.ndarray)
        assert isinstance(y, np.ndarray)
```
On my laptop, all of them pass:
```
$ pytest test_stores.py -vvv
========================================================================= test session starts ==========================================================================
platform linux -- Python 3.7.7, pytest-6.0.1, py-1.9.0, pluggy-0.13.1 -- xxxxxxxxxx/miniconda2/envs/torch_play/bin/python
cachedir: .pytest_cache
rootdir: xxxx
collected 6 items
test_stores.py::TestEpochOutputStore::test_reset PASSED [ 16%]
test_stores.py::TestEpochOutputStore::test_update_one_iteration PASSED [ 33%]
test_stores.py::TestEpochOutputStore::test_update_five_iterations PASSED [ 50%]
test_stores.py::TestEpochOutputStore::test_attatch PASSED [ 66%]
test_stores.py::TestEpochOutputStore::test_get_output PASSED [ 83%]
test_stores.py::TestEpochOutputStore::test_get_numpy_output PASSED [100%]
========================================================================== 6 passed in 0.34s ===========================================================================
```
Hi @ZhiliangWu !
Your tests combining your metric and the engine look good.
Writing good tests with full coverage is a hard task. I think you could have a look at tests/ignite/metrics/test_epoch_metric.py for inspiration; IMO your tests should be more or less similar.
Feel free to look at the other metric tests as well, they are very informative.
However, you should open a PR to make it easier to comment on specific parts of your code. As I mentioned in my previous post, you can check here: https://github.com/pytorch/ignite/blob/master/CONTRIBUTING.md#send-a-pr
By the way, thank you again for your work!! 😊
Hi @ZhiliangWu, thanks for the FR. I just have a remark on the implementation. Maybe we could simply store all outputs in a single list instead of separating them into predictions/targets?
```python
import torch
from ignite.engine import Events


class EpochOutputStore:  # no need to inherit from object for Python 3 code
    def __init__(self):
        self.data = None

    def reset(self):
        self.data = []

    def update(self, engine):
        self.data.append(engine.state.output)

    def attach(self, engine):
        engine.add_event_handler(Events.EPOCH_STARTED, self.reset)
        engine.add_event_handler(Events.ITERATION_COMPLETED, self.update)
```
This way, it is more generic (for example, if some users would also like to store x or other values). But I agree that in this case we cannot provide a helper method like get_output as in your implementation...
What do you think?
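If the generic store is used and the user knows that every stored output is a tuple with the same fields (an assumption, since `engine.state.output` can have any structure), the per-field lists can still be recovered afterwards with `zip`. `split_outputs` below is a hypothetical helper, not part of ignite, and plain strings stand in for tensors; with tensors, each resulting field could then be passed to `torch.cat(..., dim=0)` as in the first implementation.

```python
def split_outputs(data):
    """Hypothetical helper: [(a1, b1, ...), (a2, b2, ...)] -> ([a1, a2], [b1, b2], ...).

    Assumes every stored output is a tuple with the same number of fields;
    zip() would silently truncate to the shortest tuple otherwise.
    """
    if not data:
        return ()
    return tuple(list(field) for field in zip(*data))


# E.g. a store whose outputs were (x, y_pred, y) for each of two iterations
data = [("x0", "p0", "t0"), ("x1", "p1", "t1")]
xs, y_preds, ys = split_outputs(data)
```

Keeping this unpacking outside the handler preserves the generic design: the store makes no assumption about the output structure, and only the caller, who knows it, splits the fields.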
Hi @vfdev-5, yes, I agree that saving everything in a data list is more general. Maybe I was a bit too focused on my understanding of output as the output of the network. I will try to figure out the PR process and send an update.