Keras: Add option to return model predictions to train_on_batch

Created on 13 Nov 2017 · 11 comments · Source: keras-team/keras

It can be useful to get the predictions of a model during training (e.g. for debugging, boosting, etc.) without running an additional forward pass. As far as I know there is no support for this in `train_on_batch`, but the underlying backends do support it.

At first I thought it would be possible with a custom metric function (see below), but unfortunately this doesn't seem to work, as a metric is always reduced to a float value.

```python
def metric(y_true, y_pred):
    return y_pred
```

After inspecting the Keras code I noticed that `test_on_batch` merely fetches the model outputs to obtain the predictions. This resulted in the following workaround to get predictions during training:

```python
model = models.Model(inputs, outputs)
model.compile(optimizer=optimizer, loss='categorical_crossentropy', metrics=['accuracy'])

# Explicitly add the output as a new metric
model.metrics_tensors += model.outputs
```

This method works but depends on the internal workings of the Keras training module. It would be nice to have a cleaner way that directly uses the Keras API instead of its internal representation.

Is there another way (that I missed) to get the network output during training without resorting to hacks? If not, would it be possible to add a simple flag to `make_train_function()` that returns the network output? If needed, I can draft a pull request. Other suggestions and ideas are welcome.


All 11 comments

I would also suggest adding the same functionality to evaluation on batch.

That would enable metrics to be computed over entire epochs instead of as batch averages (which doesn't work for metrics like AUC).
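To illustrate the point above, here is a minimal sketch of why per-batch predictions are needed for epoch-level metrics: AUC must be computed over all samples at once, so per-batch values are accumulated first. The batch data below is made up for illustration; `roc_auc_score` is from scikit-learn.

```python
# Accumulate per-batch (y_true, y_pred) pairs, then compute an epoch-level
# metric such as AUC, which cannot be obtained by averaging batch values.
import numpy as np
from sklearn.metrics import roc_auc_score

# Pretend each tuple is (y_true, y_pred) collected for one training batch.
batches = [
    (np.array([0, 1]), np.array([0.2, 0.9])),
    (np.array([1, 0]), np.array([0.7, 0.4])),
]

all_true = np.concatenate([t for t, _ in batches])
all_pred = np.concatenate([p for _, p in batches])

epoch_auc = roc_auc_score(all_true, all_pred)
print(epoch_auc)  # 1.0 for this toy data (all positives score above all negatives)
```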

To add to my 'hack' above: it does not survive saving and loading models. After loading, the train function is created without the model outputs added to the metrics. In other words, the hack only works right after compiling the model.

I'm currently working on a PR for this. I will add a flag to both the train and test function.

It would be really nice to have the output of the network as an output without doing another forward pass! It would make hard negative mining or boosting much easier.
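As a sketch of the hard negative mining use case mentioned above: once per-sample predictions are available during training, the negatives the model scores highest can be selected for the next batch. The numpy data below is purely illustrative, not part of the Keras API.

```python
# Hard negative mining from per-sample predictions (illustrative data).
import numpy as np

labels = np.array([0, 0, 0, 0, 1, 1])               # 0 = negative, 1 = positive
scores = np.array([0.1, 0.8, 0.3, 0.6, 0.9, 0.7])   # predicted confidence for class 1

neg_idx = np.where(labels == 0)[0]
# "Hard" negatives are the negatives the model is most confident about.
k = 2
hardest = neg_idx[np.argsort(scores[neg_idx])[::-1][:k]]
print(hardest)  # indices of the two hardest negatives: [1 3]
```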

This PR would allow passing extra fetches to the train function and retrieving their values afterwards.

@hermansje it would be a nice addition to do this within the keras-api. For now I've used the approach below:

After compiling, add the output tensors as a metric:

```python
model.metrics_tensors += model.outputs
model.metrics_names += ['predictions']
```

This has to be repeated when loading a model (after a save). The train and test functions also need to be reset to `None`, which makes sure that the model takes the new tensors into account:

```python
model.train_function = None
model.test_function = None
model.metrics_tensors += model.outputs
model.metrics_names += ['predictions']
```

As an additional note in favor of adding this functionality: it would reduce training times by nearly half for some common GAN configurations.

@wouterbulten, your hack works great for calculating the average outputs per batch, but is there a way to keep the output of every single sample in the testing dataset?

@ale152 what do you mean exactly? Using this method you will get the predictions for each element in the batch, so for the full set you just need to store them yourself.
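To make the suggestion above concrete, here is a small sketch of storing the per-batch predictions yourself and stacking them into one array for the full set. The shapes and loop are arbitrary illustration; in practice the inner value would come from `train_on_batch` with the hack applied.

```python
# Collect per-batch predictions and concatenate them for the full dataset.
import numpy as np

per_batch_preds = []
for _ in range(3):  # pretend this loop wraps model.train_on_batch(...)
    batch_preds = np.random.rand(4, 2)  # (batch_size, n_classes), illustrative
    per_batch_preds.append(batch_preds)

all_preds = np.concatenate(per_batch_preds, axis=0)
print(all_preds.shape)  # (12, 2)
```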

Doesn't train_on_batch now return an accuracy metric? Am I misunderstanding this? Thanks in advance.

Hi, try this:

```python
c1 = model.train_on_batch(x, y)
print(c1)
```

This will print the current value of the loss and of the metrics you passed to `compile()`.
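For completeness, a minimal runnable sketch of what `train_on_batch` returns when metrics are compiled in (a list of loss and metric values, not predictions). The model size and data here are arbitrary illustration, assuming `tensorflow.keras`:

```python
# Show that train_on_batch returns [loss, accuracy] when metrics are compiled in.
import numpy as np
from tensorflow import keras

model = keras.Sequential([
    keras.Input(shape=(4,)),
    keras.layers.Dense(2, activation='softmax'),
])
model.compile(optimizer='sgd', loss='categorical_crossentropy',
              metrics=['accuracy'])

x = np.random.rand(8, 4).astype('float32')
y = np.eye(2)[np.random.randint(0, 2, size=8)].astype('float32')

result = model.train_on_batch(x, y)  # [loss, accuracy]
print(result)
```

Note that this list contains only scalar metric values, which is exactly why the issue asks for a way to also return the per-sample predictions.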

Best, -MH

