Internal variables (batch, predictions, etc.) of the training loop (training + validation step) should be made transparent to callbacks. Right now, there is no way to access these internal variables of the training loop through callbacks without making them attributes of the lightning module. This doesn't seem optimal, as it pollutes the lightning module with non-essential code.
Use case: Visualize images and predictions from a training batch. Right now, there are two ways:
log method in pl module and call this from training_step method of pl module. By doing this, we are essentially polluting the pl module with non-essential code.
Hi! Thanks for your contribution, great first issue!
each callback method has access to the trainer and model...
what PL version are you using? We made these changes quite recently... :]
cc: @ethanwharris
Hi @Borda, I am using PL version 0.7.1. In this version, callbacks do have access to the trainer and pl module as you said. However, I wanted to point out that I can't access the internal variables of the pl module's training step method from a callback. Example:
Let's say we have the following module class from pl examples:
class LitModel(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.l1 = torch.nn.Linear(28 * 28, 10)

    def forward(self, x):
        return torch.relu(self.l1(x.view(x.size(0), -1)))

    def training_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        return {'loss': F.cross_entropy(y_hat, y)}
I'd try to write the visualization callback as follows to visualize images from training batch:
class VizCallback(pl.Callback):
    def on_batch_end(self, trainer, pl_module):
        # here I wish to have access to images `x` from training_step,
        # but `x` is not reachable from the callback
        grid_img = torchvision.utils.make_grid(x)
        pl_module.logger.experiment.log_image(grid_img)
The issue is that I can't access the training samples x in the example above from a callback without linking it to the pl module. As the code gets more complicated, I wish to have access to several variables of training_step, which could be model predictions, attention maps, and so on, from a callback. In order to access these variables from a callback, I'd have to make them attributes of the pl module and update them in every training step, which I think is non-ideal. See what I mean?
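For concreteness, the attribute-caching workaround described above (and called non-ideal) looks roughly like this. This is a dependency-free sketch: `DummyModule` and `DummyVizCallback` are hypothetical stand-ins for a LightningModule and a Callback, not real PL classes.

```python
class DummyModule:
    """Hypothetical stand-in for a LightningModule, showing the
    workaround: cache loop internals on the module itself."""

    def training_step(self, batch, batch_idx):
        x, y = batch
        # non-essential caching, done only so a callback can see `x` later
        self.last_inputs = x
        return {"loss": sum(x)}  # placeholder "loss"


class DummyVizCallback:
    """Hypothetical callback that reads the cached attribute back."""

    def on_batch_end(self, trainer, pl_module):
        # tightly coupled to the module having set `last_inputs`
        return pl_module.last_inputs


module = DummyModule()
module.training_step(([1, 2, 3], [0, 1, 1]), batch_idx=0)
inputs = DummyVizCallback().on_batch_end(trainer=None, pl_module=module)
print(inputs)  # [1, 2, 3]
```

The callback only works if the module remembers to set `last_inputs`, which is exactly the coupling being complained about.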
I have figured out a way to retrieve this information in a callback: use hiddens. Not sure if it's the right way to do it, but at least it works with little modification to your model. Here is my training_step() function:
def training_step(self, batch, batch_idx):
    inputs, targets = batch
    preds = self(inputs)
    loss = ...  # whatever loss you use
    hiddens = {"inputs": inputs.detach(), "predictions": preds.detach(), "targets": targets.detach()}
    return {"loss": loss, "hiddens": hiddens}
with this you will get everything in your callback
class VizCallback(pl.Callback):
    def on_batch_end(self, trainer, pl_module):
        # everything returned under "hiddens" is available on the trainer
        if trainer.hiddens:
            inputs = trainer.hiddens["inputs"]
            preds = trainer.hiddens["predictions"]
            targets = trainer.hiddens["targets"]
            grid_img = torchvision.utils.make_grid(inputs)
            pl_module.logger.experiment.log_image(grid_img)
But there is nothing for the validation loop; as far as I can see, there is no on_batch_end for validation. Am I wrong?
@czotti good hack haha but we need an actual way of doing this.
@PyTorchLightning/core-contributors ?
iirc, you can return any arbitrary key and it will be added to trainer.callback_metrics - you don't need to override the hiddens key
I think any key-returning idea (like overriding the hiddens key, or returning a dictionary relevant for the callback) would spoil the neat separation between callbacks and the pl module. The user would have to remember to cache/return variables in the pl module that are relevant for the callback, and when they wish to switch off the callback, they would have to delete the corresponding cache/return code from the pl module. Thus, I suggested that all the internal variables of training and validation be exposed to callbacks by default, so that users don't have to remember to cache callback-relevant state in the pl module. Please let me know if I wasn't clear. :)
iirc, you can return any arbitrary key and it will be added to trainer.callback_metrics - you don't need to override the hiddens key
Unfortunately, in the case of on_batch_end callback the Trainer's self.callback_metrics is not yet populated.
@williamFalcon I think the way to go here is to have more of the looping variables as trainer attributes. However, we need to be careful with this as we don't want our trainer to be bloated.
we can offer some read-only properties...
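As a rough illustration of what "read-only properties" could look like here: a hypothetical `MiniTrainer` (not the real PL Trainer) whose loop writes private attributes, while callbacks only see them through properties with no setter.

```python
class MiniTrainer:
    """Hypothetical sketch: the loop writes private attributes,
    callbacks read them through read-only properties."""

    def __init__(self):
        self._batch = None

    def _run_batch(self, batch):
        # the training loop updates the private attribute each step
        self._batch = batch

    @property
    def batch(self):
        # read-only: no setter is defined, so callbacks cannot mutate it
        return self._batch


trainer = MiniTrainer()
trainer._run_batch([1, 2, 3])
print(trainer.batch)  # [1, 2, 3]
# trainer.batch = [4]  # would raise AttributeError
```

This keeps the Trainer's public surface small while still making loop internals transparent to callbacks.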
I haven't gone through this in detail but it seems relevant to the discussion here:
Repo: https://github.com/pberkes/persistent_locals
Blog post: http://code.activestate.com/recipes/577283-decorator-to-expose-local-variables-of-a-function-/
Quick summary from the README of above repo:
The proposed solution consists of a decorator that makes the local variables accessible from a read-only property of the function, 'locals'. For example:
@persistent_locals
def is_sum_lt_prod(a, b, c):
    sum = a + b + c
    prod = a * b * c
    return sum < prod
after calling the function, e.g. is_sum_lt_prod(2,1,2), we can analyze
the intermediate results as
is_sum_lt_prod.locals
which returns
{'a': 2, 'b': 1, 'c': 2, 'prod': 4, 'sum': 5}
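For reference, the core of that decorator can be sketched in a few lines of stdlib Python. This is a simplified version using `sys.setprofile`, not the exact implementation from the linked repo: a profile hook captures the frame's locals when the wrapped function returns.

```python
import sys


def persistent_locals(func):
    """Expose a function's local variables after it runs.

    Simplified sketch of the idea behind pberkes/persistent_locals:
    a profile hook grabs the frame's locals when the wrapped function
    returns and stores them on the wrapper as `.locals`.
    """
    def wrapper(*args, **kwargs):
        captured = {}

        def profiler(frame, event, arg):
            # fire only when *our* function's frame returns
            if event == 'return' and frame.f_code is func.__code__:
                captured.update(frame.f_locals)

        old = sys.getprofile()
        sys.setprofile(profiler)
        try:
            return func(*args, **kwargs)
        finally:
            sys.setprofile(old)
            wrapper.locals = captured

    wrapper.locals = {}
    return wrapper


@persistent_locals
def is_sum_lt_prod(a, b, c):
    sum = a + b + c
    prod = a * b * c
    return sum < prod
```

After `is_sum_lt_prod(2, 1, 2)`, `is_sum_lt_prod.locals` holds the intermediate values, matching the README example above.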
Please ignore my above comment, I don't think the solution is neat and maintainable...
Yes, this would be useful. Currently there's no proper way to pass batch computation result to on_test_batch_end for a custom Callback (hiddens and callback_metrics won't work either). As a hacky workaround I'm just attaching the results of test_step to self lol.
I needed to do this in my code as well (perform some weird operation on inference outputs and store them in a custom format to be able to log to a csv later), and I was thinking of setting model class attributes i.e. self.current_test_output = ... and then accessing it in the callback.
But then I thought maybe this would cause issues during multiprocessing, since I'm not exactly sure how the model is split/synced during DP/DDP.
This issue has been automatically marked as stale because it has not had recent activity. It will be closed if no further activity occurs. Thank you for your contributions.
Has this been resolved? I am trying to implement a callback for visualization and currently facing the same issue.
Are there any solutions \ workarounds for this?
having a similar issue of complex reporting of predictions that I want to separate from the main module
I'm using czotti's hack with self.hiddens (see above), but I agree that it pollutes the pl module.