Keras: Is it possible to implement sparse mean squared error loss in Keras

Created on 21 Jun 2017 · 4 comments · Source: keras-team/keras

I want to modify the following Keras mean squared error (MSE) loss so that the loss is only computed at a sparse subset of the output positions.

from keras import backend as K

def mean_squared_error(y_true, y_pred):
    return K.mean(K.square(y_pred - y_true), axis=-1)

For instance, suppose y_true is a 3-channel image whose 3rd channel is non-zero only at the pixels where the loss should be computed. The loss would then be computed between the 2-channel prediction y_pred and the first two channels of y_true, only at the pixels where the 3rd channel of y_true is non-zero. Any idea if it is possible to implement a custom sparse MSE loss like this in Keras?
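Roughly, I imagine something like this (an untested sketch; it assumes the last channel of y_true is a binary mask and y_pred has one channel fewer than y_true):

from keras import backend as K

def channel_masked_mse(y_true, y_pred):
    # last channel of y_true is the mask, the remaining channels are the target
    mask = K.cast(K.not_equal(y_true[..., -1:], 0), K.floatx())
    target = y_true[..., :-1]
    # squared error only where the mask is non-zero (mask broadcasts over channels)
    masked_sq_err = K.square(mask * (target - y_pred))
    # sum of per-channel errors averaged over the unmasked pixels only;
    # epsilon avoids 0/0 when every pixel is masked
    return K.sum(masked_sq_err) / (K.sum(mask) + K.epsilon())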

stale

All 4 comments

This is not the exact loss you are looking for, but I hope it gives you a hint for writing your own function:

from keras import backend as K

def masked_mse(mask_value):
    def f(y_true, y_pred):
        # 1.0 where the target is valid, 0.0 where it equals the mask value
        mask_true = K.cast(K.not_equal(y_true, mask_value), K.floatx())
        # squared error contributes only at the unmasked positions
        masked_squared_error = K.square(mask_true * (y_true - y_pred))
        # divide by the number of unmasked values, not by the array dimension
        masked_mse = K.sum(masked_squared_error, axis=-1) / K.sum(mask_true, axis=-1)
        return masked_mse
    f.__name__ = 'Masked MSE (mask_value={})'.format(mask_value)
    return f

The function computes the MSE loss over all the values of the predicted output, except for those elements whose corresponding value in the true output is equal to a masking value (e.g. -1).

Two notes:

  • When computing the mean, the denominator must be the count of non-masked values, not the
    dimension of the array; that's why I'm not using K.mean(masked_squared_error, axis=-1) and am
    instead averaging manually.
  • The masking value must be a valid number (np.nan or np.inf will not do the job), which means
    you'll have to adapt your data so that it does not contain the mask_value (see the sketch right
    after these notes).
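For example, a rough (untested) sketch of that data-preparation step, replacing NaNs in a NumPy target array with the mask value before training; the array contents here are just toy values:

import numpy as np

MASK_VALUE = -1.0  # pick a value that cannot occur among the real targets

# toy targets where missing entries are encoded as NaN
y_train = np.array([[1.0, 2.0, np.nan],
                    [np.nan, 0.5, 3.0]])

# replace the NaNs with the mask value so masked_mse can detect them
y_train = np.where(np.isnan(y_train), MASK_VALUE, y_train)
# -> [[ 1. ,  2. , -1. ],
#     [-1. ,  0.5,  3. ]]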

In this example, the target output is [1, 1, 1, 1], but progressively more of the target values are masked (replaced with -1).

y_pred = K.constant([[ 1, 1, 1, 1], 
                     [ 1, 1, 1, 3],
                     [ 1, 1, 1, 3],
                     [ 1, 1, 1, 3],
                     [ 1, 1, 1, 3],
                     [ 1, 1, 1, 3]])
y_true = K.constant([[ 1, 1, 1, 1],
                     [ 1, 1, 1, 1],
                     [-1, 1, 1, 1],
                     [-1,-1, 1, 1],
                     [-1,-1,-1, 1],
                     [-1,-1,-1,-1]])

true = K.eval(y_true)
pred = K.eval(y_pred)
loss = K.eval(masked_mse(-1)(y_true, y_pred))

for i in range(true.shape[0]):
    print(true[i], pred[i], loss[i], sep='\t')

The expected output is:

[ 1.  1.  1.  1.]  [ 1.  1.  1.  1.]  0.0
[ 1.  1.  1.  1.]  [ 1.  1.  1.  3.]  1.0
[-1.  1.  1.  1.]  [ 1.  1.  1.  3.]  1.33333
[-1. -1.  1.  1.]  [ 1.  1.  1.  3.]  2.0
[-1. -1. -1.  1.]  [ 1.  1.  1.  3.]  4.0
[-1. -1. -1. -1.]  [ 1.  1.  1.  3.]  nan

This issue has been automatically marked as stale because it has not had recent activity. It will be closed after 30 days if no further activity occurs, but feel free to re-open a closed issue if needed.

Hi @baldassarreFe, thanks for the solution above. I am trying to apply it to a 3D tensor output but I keep getting nans. My y tensor has dimensions (batch_size x timesteps x features); essentially, my model predicts two sequences (features). Any suggestions on what I should modify?

@baldassarreFe, @sdimi: if y_true contains only missing values along the last axis for some sample (or timestep), the error is nan, as shown in the last line of the example above. If you adapt the function like this:

def masked_mse(mask_value):
    def f(y_true, y_pred):
        mask_true = K.cast(K.not_equal(y_true, mask_value), K.floatx())
        masked_squared_error = K.square(mask_true * (y_true - y_pred))
        # in case mask_true is 0 everywhere, the error would be nan, therefore divide by at least 1
        # this doesn't change anything as where sum(mask_true)==0, sum(masked_squared_error)==0 as well
        masked_mse = K.sum(masked_squared_error, axis=-1) / K.maximum(K.sum(mask_true, axis=-1), 1)
        return masked_mse
    f.__name__ = str('Masked MSE (mask_value={})'.format(mask_value))
    return f

the error will be 0 in that case instead of nan.
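For completeness, a minimal (untested) sketch of how such a loss can be plugged into a model; the layer sizes and optimizer are just placeholders:

from keras.models import Sequential
from keras.layers import Dense

# toy model with 4 outputs, matching the 4-element targets in the example above
model = Sequential([
    Dense(32, activation='relu', input_shape=(10,)),
    Dense(4)
])
model.compile(optimizer='adam', loss=masked_mse(-1))
# model.fit(x_train, y_train, ...)  # y_train uses -1 wherever a target value is missing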
