Keras: Access entire minibatch inside custom objective function

Created on 1 Sep 2016 · 1 Comment · Source: keras-team/keras

Hello,
I am working on a problem where I am using an objective which depends on the entire minibatch. Is there a way to access this when creating a custom objective function?

To illustrate with code: MSE looks like this and takes `y_true` and `y_pred` as arguments:

```python
from keras import backend as K

def mean_squared_error(y_true, y_pred):
    return K.mean(K.square(y_pred - y_true), axis=-1)
```
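For intuition about the shapes involved, here is the same computation in NumPy, which stands in for the Keras backend `K` so the snippet runs without building a model (the batch below is made-up example data). With arrays of shape `(batch_size, output_dim)`, `axis=-1` averages over the output dimension only and leaves one error per sample:

```python
import numpy as np

def mean_squared_error(y_true, y_pred):
    # Same expression as the Keras version, with np in place of K:
    # average over the last (output) axis only, keep the batch axis.
    return np.mean(np.square(y_pred - y_true), axis=-1)

# Hypothetical minibatch: 4 samples, 3 outputs each.
y_true = np.zeros((4, 3))
y_pred = np.array([[1.0, 1.0, 1.0],
                   [2.0, 2.0, 2.0],
                   [0.0, 0.0, 3.0],
                   [0.0, 0.0, 0.0]])

per_sample = mean_squared_error(y_true, y_pred)
print(per_sample.shape)  # (4,)
print(per_sample)        # [1. 4. 3. 0.]
```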

Now say I wanted the loss for a minibatch to be, for example, the product (instead of the sum) of the MSEs between each `y_true` and `y_pred` in the batch (a contrived example). Can I access the entire minibatch to define this? For the sake of argument, suppose `y_true_batch` were a list of every `y_true` in the batch:

```python
def product_error(y_true_batch, y_pred_batch):
    error = 1
    for i in range(len(y_true_batch)):
        # Per-sample MSE, multiplied into the running product.
        error *= mean_squared_error(y_true_batch[i], y_pred_batch[i])
    return error
```

The exact definition above doesn't really matter (and this one in particular wouldn't work because of the Python `for` loop). The real question is: can I access the equivalent of `y_true_batch` when defining a custom objective?

Thanks!

_Note:_
I've searched and haven't seen anyone address this, although a similar comment was made in the thread of issue #369.

Most helpful comment

`y_true` and `y_pred` are already batches of targets and predictions, respectively.
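Since both arguments already hold the full minibatch, the loop in `product_error` can be replaced by a reduction over the batch axis. A minimal sketch, assuming 2-D targets of shape `(batch_size, output_dim)`; `make_product_error` is a hypothetical helper that takes the tensor module as an argument, so the same body runs on NumPy for checking and, presumably, on the Keras backend, whose `K.mean`, `K.square` and `K.prod` accept the same `axis` keyword:

```python
import numpy as np

def make_product_error(ops):
    """Build a loss that multiplies the per-sample MSEs of a batch.

    `ops` is any module exposing mean/square/prod with NumPy-style
    `axis` arguments (numpy here; keras.backend as K is an assumption).
    """
    def product_error(y_true_batch, y_pred_batch):
        # One MSE per sample: reduce the output axis, keep the batch axis.
        per_sample = ops.mean(ops.square(y_pred_batch - y_true_batch), axis=-1)
        # Multiply across the batch instead of looping in Python.
        return ops.prod(per_sample)
    return product_error

product_error = make_product_error(np)

# Hypothetical batch: 3 samples, 2 outputs each.
y_true = np.zeros((3, 2))
y_pred = np.array([[1.0, 1.0],
                   [2.0, 2.0],
                   [3.0, 1.0]])
print(product_error(y_true, y_pred))  # 1.0 * 4.0 * 5.0 = 20.0
```

In a model this would be passed as `model.compile(optimizer='sgd', loss=make_product_error(K))` with `K` imported from `keras.backend`. The loss returns a single scalar for the batch rather than one value per sample; how that interacts with Keras's own averaging and sample weighting may vary between versions, so check it before relying on it. A product of many per-sample errors below 1 also shrinks toward zero quickly as the batch grows.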

