Pytorch-lightning: Allow callbacks to access internal variables of training loop

Created on 26 Mar 2020  ·  18 Comments  ·  Source: PyTorchLightning/pytorch-lightning

🚀 Feature

Internal variables (batch, predictions, etc.) of the training loop (training + validation step) should be made transparent to callbacks. Right now, there is no way to access these internal variables of the training loop through callbacks without making them attributes of the lightning module. This is not optimal, as it pollutes the lightning module with non-essential code.

Motivation

Use case: Visualize images and predictions from a training batch. Right now, there are two ways:

  1. Add a log method in the pl module and call it from the module's training_step method. By doing this, we are essentially polluting the pl module with non-essential code.
  2. Write a visualization callback. As of now, a callback has access to the pl module and the trainer, but it still can't access the variables (images, predictions, etc.) in the training step. We could make these variables attributes of the pl module, but then updating these attributes in every training step (so that the callback can access them) would also count as "non-essential code", which defeats the point of the callback. It also spoils the neat separation between callbacks and the pl module, as we'd be caching/updating attributes in the pl module even when the callback is switched off.
Labels: enhancement, help wanted, let's do it!, question, won't fix

Most helpful comment

we can offer some read-only properties...

All 18 comments

Hi! Thanks for your contribution, great first issue!

Each callback method has access to the trainer and the model...
What PL version are you using? We made these changes quite recently... :]
cc: @ethanwharris

Hi @Borda, I am using pl version 0.7.1. In this version, callbacks do have access to the trainer and the pl module, as you said. However, I wanted to point out that I can't access the internal variables of the pl module's training_step method from a callback. Example:

Let's say we have the following module class from pl examples:

import torch
import torch.nn.functional as F
import pytorch_lightning as pl

class LitModel(pl.LightningModule):

    def __init__(self):
        super().__init__()
        self.l1 = torch.nn.Linear(28 * 28, 10)

    def forward(self, x):
        return torch.relu(self.l1(x.view(x.size(0), -1)))

    def training_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        return {'loss': F.cross_entropy(y_hat, y)}

I'd try to write the visualization callback as follows to visualize images from training batch:

class VizCallback(pl.Callback):

    def on_batch_end(self, trainer, pl_module):
        # here I wish to have access to images `x` from training_step,
        # but `x` is not reachable from a callback
        grid_img = torchvision.utils.make_grid(x)
        pl_module.logger.experiment.log_image(grid_img)

The issue is that I can't access the training samples `x` in the example above from a callback without linking it to the pl module. As the code gets more complicated, I'd like access to several variables of training_step from a callback: model predictions, attention maps, and so on. To access these variables from a callback, I'd have to make them attributes of the pl module and update them in every training step, which I think is non-ideal. See what I mean?
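To make the caching concrete, here is a minimal plain-Python sketch (no Lightning; the TinyModule/VizCallback names are hypothetical) of the attribute-based workaround being criticized: the module must remember to stash its internals just so a callback can read them later.

```python
class TinyModule:
    """Stand-in for a LightningModule; no framework code."""

    def training_step(self, batch, batch_idx):
        x, y = batch
        y_hat = [xi * 2 for xi in x]  # stand-in for self(x)
        # non-essential caching, done only so the callback can see it
        self.last_x, self.last_preds = x, y_hat
        return {"loss": 0.0}


class VizCallback:
    def on_batch_end(self, trainer, pl_module):
        # the callback can only reach x via the cached attribute
        return pl_module.last_x


module = TinyModule()
module.training_step(([1, 2, 3], [0, 1, 0]), 0)
assert VizCallback().on_batch_end(None, module) == [1, 2, 3]
```

Note how `training_step` now carries a line that exists purely for the callback's benefit; removing the callback means remembering to remove that line too.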

I have figured out a way to retrieve this information in a callback: use hiddens. Not sure if it's the right way to do it, but at least it works with few modifications to your model. Here is my training_step() function:

def training_step(self, batch, batch_idx):
    inputs, targets = batch
    preds = self(inputs)
    loss = F.cross_entropy(preds, targets)  # or whatever loss you use
    hiddens = {"inputs": inputs.detach(), "predictions": preds.detach(), "targets": targets.detach()}
    return {"loss": loss, "hiddens": hiddens}

With this, you get everything you need in your callback:

class VizCallback(pl.Callback):
    def on_batch_end(self, trainer, pl_module):
        # the variables cached in training_step are now reachable here
        if trainer.hiddens:
            inputs = trainer.hiddens["inputs"]
            preds = trainer.hiddens["predictions"]
            targets = trainer.hiddens["targets"]
            grid_img = torchvision.utils.make_grid(inputs)
            pl_module.logger.experiment.log_image(grid_img)

But there is nothing for the validation loop. As far as I can see, there is no on_batch_end for validation. Am I wrong?

@czotti good hack haha but we need an actual way of doing this.
@PyTorchLightning/core-contributors ?

iirc, you can return any arbitrary key and it will be added to trainer.callback_metrics - you don't need to override the hiddens key

I think any key-returning idea (like overriding the hiddens key or returning a callback-specific dictionary) would spoil the neat separation between callbacks and the pl module: the user would have to remember to cache/return callback-relevant variables in the pl module, and when they switch off the callback, they would have to delete the corresponding cache/return code as well. That's why I suggested exposing all the internal variables of training and validation to callbacks by default, so users don't have to remember to cache callback-relevant stuff in the pl module. Please let me know if I wasn't clear. :)

iirc, you can return any arbitrary key and it will be added to trainer.callback_metrics - you don't need to override the hiddens key

Unfortunately, in the case of on_batch_end callback the Trainer's self.callback_metrics is not yet populated.

@williamFalcon I think the way to go here is to have more of the looping variables as trainer attributes. However, we need to be careful with this as we don't want our trainer to be bloated.

we can offer some read-only properties...
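For illustration, here is a minimal plain-Python sketch of what such read-only properties could look like (the names are hypothetical, not the actual Trainer API): the trainer updates a private field inside the loop and exposes it through a property with no setter, so callbacks can read loop state but not mutate it.

```python
class Trainer:
    """Toy trainer exposing loop state to callbacks as a read-only property."""

    def __init__(self):
        self._batch = None

    @property
    def batch(self):
        # callbacks can read the current batch but cannot rebind it
        return self._batch

    def _run_step(self, batch):
        self._batch = batch  # internal loop update before callbacks fire


trainer = Trainer()
trainer._run_step([1, 2])
assert trainer.batch == [1, 2]
# trainer.batch = []  # would raise AttributeError: no setter defined
```

The same pattern would extend to predictions, the current epoch's outputs, and so on, without letting a callback corrupt the loop state.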

I haven't gone through this in detail but it seems relevant to the discussion here:

Repo: https://github.com/pberkes/persistent_locals
Blog post: http://code.activestate.com/recipes/577283-decorator-to-expose-local-variables-of-a-function-/

Quick summary from the README of above repo:

The proposed solution consists of a decorator that makes the local variables accessible from a read-only property of the function, 'locals'. For example:

@persistent_locals
def is_sum_lt_prod(a, b, c):
    sum = a + b + c
    prod = a * b * c
    return sum < prod

after calling the function, e.g. is_sum_lt_prod(2,1,2), we can analyze
the intermediate results as

is_sum_lt_prod.locals

which returns

{'a': 2, 'b': 1, 'c': 2, 'prod': 4, 'sum': 5}
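For reference, the recipe can be sketched in a few lines of pure Python with sys.setprofile. This is a sketch based on the recipe's description, not the exact library code: a profiler hook copies the function's locals when its own frame returns.

```python
import sys


class persistent_locals:
    """Decorator exposing a function's locals via a read-only `.locals`."""

    def __init__(self, func):
        self.func = func
        self._locals = {}

    def __call__(self, *args, **kwargs):
        def tracer(frame, event, arg):
            # capture locals when the wrapped function itself returns
            if event == 'return' and frame.f_code is self.func.__code__:
                self._locals = frame.f_locals.copy()

        sys.setprofile(tracer)
        try:
            return self.func(*args, **kwargs)
        finally:
            sys.setprofile(None)

    @property
    def locals(self):
        return self._locals


@persistent_locals
def is_sum_lt_prod(a, b, c):
    sum = a + b + c
    prod = a * b * c
    return sum < prod


is_sum_lt_prod(2, 1, 2)
# is_sum_lt_prod.locals -> {'a': 2, 'b': 1, 'c': 2, 'sum': 5, 'prod': 4}
```

Note that sys.setprofile is global state: this would interfere with profilers and debuggers, which is part of why it is hard to recommend inside a training loop.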

Please ignore my above comment, I don't think the solution is neat and maintainable...

Yes, this would be useful. Currently there's no proper way to pass batch computation result to on_test_batch_end for a custom Callback (hiddens and callback_metrics won't work either). As a hacky workaround I'm just attaching the results of test_step to self lol.

I needed to do this in my code as well (perform some weird operation on inference outputs and store them in a custom format to be able to log to a csv later), and I was thinking of setting model class attributes i.e. self.current_test_output = ... and then accessing it in the callback.

But then I thought maybe this would cause issues during multiprocessing, since I'm not exactly sure how the model is split/synced during DP/DDP.

This issue has been automatically marked as stale because it has not had recent activity. It will be closed if no further activity occurs. Thank you for your contributions.

Has this been resolved? I am trying to implement a callback for visualization and currently facing the same issue.

Are there any solutions / workarounds for this?

Having a similar issue: complex reporting of predictions that I want to separate from the main module.

I'm using czotti's hack with self.hiddens (see above), but I agree that it pollutes the pl module.

