Hello,
I am working on a problem where I am using an objective which depends on the entire minibatch. Is there a way to access this when creating a custom objective function?
To illustrate this with code: MSE looks like this and takes `y_true` and `y_pred` as arguments:
```python
from keras import backend as K

def mean_squared_error(y_true, y_pred):
    return K.mean(K.square(y_pred - y_true), axis=-1)
```
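To make the shapes concrete, here is a small sketch that uses NumPy as a stand-in for the Keras backend. It assumes `np.mean`/`np.square` behave like `K.mean`/`K.square` for this purpose, and the arrays are made-up toy data. Because the reduction is over `axis=-1` only, the batch axis survives, so the function returns one loss value per sample:

```python
import numpy as np

# Toy minibatch: 4 samples, each with 3 output values, i.e. shape
# (batch_size, output_dim), the layout a Keras loss function receives.
y_true = np.array([[0.0, 1.0, 2.0],
                   [1.0, 1.0, 1.0],
                   [2.0, 0.0, 0.0],
                   [0.5, 0.5, 0.5]])
y_pred = np.array([[0.0, 1.0, 1.0],
                   [1.0, 2.0, 1.0],
                   [2.0, 0.0, 3.0],
                   [0.5, 0.5, 0.5]])

def mean_squared_error(y_true, y_pred):
    # Same reduction as the Keras version: average over the last axis only,
    # which keeps the batch axis (one loss value per sample).
    return np.mean(np.square(y_pred - y_true), axis=-1)

per_sample = mean_squared_error(y_true, y_pred)
print(per_sample.shape)  # (4,)
print(per_sample)
```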
Now say I wanted the loss for a minibatch to be, for example, the product (instead of the usual sum or average) of the MSEs between each `y_true` and `y_pred` in the batch (a contrived example). Can I access the entire minibatch to define this? For the sake of argument, suppose `y_true_batch` were a list of every `y_true` in the batch:
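```python
# Pseudocode: assumes y_true_batch and y_pred_batch are Python lists
# holding one target/prediction per sample in the minibatch.
def product_error(y_true_batch, y_pred_batch):
    error = 1
    for i in range(len(y_true_batch)):
        error *= mean_squared_error(y_true_batch[i], y_pred_batch[i])
    return error
```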
The actual function definition above doesn't really matter (this one in particular wouldn't work because of the for loop). The real question is: can I access the equivalent of `y_true_batch` when defining a custom objective?
Thanks!
_Note:_
I've searched and haven't found anyone addressing this, although a similar comment was made in the thread of issue #369.
`y_true` and `y_pred` are already batches of targets and predictions.
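Building on that answer, here is a minimal sketch of the product loss from the question, written with NumPy so it can be checked outside a Keras session. It assumes the usual `(batch_size, output_dim)` layout; in an actual Keras loss, `np.prod` would be swapped for the backend equivalent (`K.prod`, or `tf.reduce_prod` in TensorFlow). The loop version is included only to confirm that both compute the same value:

```python
import numpy as np

def mean_squared_error(y_true, y_pred):
    # Per-sample MSE: reduce the feature axis, keep the batch axis.
    return np.mean(np.square(y_pred - y_true), axis=-1)

def product_error(y_true, y_pred):
    # y_true and y_pred already hold the whole minibatch, shape (batch, dim).
    # Combine the per-sample losses across the batch axis with a product
    # instead of a mean, giving one scalar for the minibatch.
    return np.prod(mean_squared_error(y_true, y_pred))

def product_error_loop(y_true_batch, y_pred_batch):
    # The loop from the question, with the indexing fixed; kept only
    # to check that the vectorized version computes the same thing.
    error = 1.0
    for i in range(len(y_true_batch)):
        error *= mean_squared_error(y_true_batch[i], y_pred_batch[i])
    return error

y_true = np.array([[1.0, 2.0], [0.0, 0.0], [3.0, 1.0]])
y_pred = np.array([[2.0, 2.0], [2.0, 0.0], [1.0, 3.0]])
print(product_error(y_true, y_pred))  # 0.5 * 2.0 * 4.0 = 4.0
```

Keras may still apply its own averaging and sample weighting to whatever the loss function returns, so it is worth checking how the Keras version in use handles a scalar return value.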