Keras: NaN loss when using Euclidean distance for a Siamese network (due to K.sqrt)

Created on 17 Jan 2018 · 4 comments · Source: keras-team/keras

While training a Siamese network, I came across the following code and used it in my program:

from keras import backend as K

def euclidean_distance(vects):
    x, y = vects
    return K.sqrt(K.sum(K.square(x - y), axis=1, keepdims=True))
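For context, here is roughly how such a distance function is wired into a Siamese model via a Lambda layer; the layer sizes and names below are illustrative, not from my actual program:

from keras.layers import Input, Dense, Lambda
from keras.models import Model

# Shared embedding applied to both branches (weights are reused).
base = Dense(32, activation='relu')
input_a = Input(shape=(16,))
input_b = Input(shape=(16,))

# The Lambda layer receives [embedding_a, embedding_b] as `vects`.
distance = Lambda(euclidean_distance,
                  output_shape=lambda shapes: (shapes[0][0], 1))(
    [base(input_a), base(input_b)])
model = Model([input_a, input_b], distance)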

But it produces a NaN loss when the two inputs are identical. (Interestingly, this only shows up when the GPU is not used; with the GPU, the loss isn't NaN, but the network is not trainable.) However, when I add a small dummy bias, the network works:

def euclidean_distance(vects):
    x, y = vects
    # The small constant keeps the argument of K.sqrt away from zero.
    return K.sqrt(K.sum(K.square(x - y), axis=1, keepdims=True) + 0.01)
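A gentler variant of the same workaround clamps the squared distance with the backend epsilon instead of adding a constant bias to every pair, so ordinary distances are left untouched (a sketch, not the code from my program):

def euclidean_distance_clamped(vects):
    x, y = vects
    sum_square = K.sum(K.square(x - y), axis=1, keepdims=True)
    # K.maximum keeps the sqrt input strictly positive without
    # shifting the distance for normal, non-identical pairs.
    return K.sqrt(K.maximum(sum_square, K.epsilon()))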

I'm just wondering why K.sqrt() can't take a zero tensor as input.

Btw, I'm using Keras 2.1.2 and tensorflow-gpu 1.4.0.

Most helpful comment

Maybe it's because the derivative of sqrt(x) is infinite at 0? Does it happen during evaluate() or only during fit()?

All 4 comments

This is usually related to numerical stability. However, what is more interesting to me is why your Siamese network training would hit this strange case in the first place, i.e. two identical inputs.

I double-checked the source code of K.sqrt for both the TensorFlow and Theano implementations; both already clip the input to non-negative values, so K.sqrt should accept zero tensors as input.

def sqrt(x):
    """Element-wise square root.
    # Arguments
        x: Tensor or variable.
    # Returns
        A tensor.
    """
    zero = _to_tensor(0., x.dtype.base_dtype)
    inf = _to_tensor(np.inf, x.dtype.base_dtype)
    x = tf.clip_by_value(x, zero, inf)
    return tf.sqrt(x)


def sqrt(x):
    x = T.clip(x, 0., np.inf)
    return T.sqrt(x)

@rex-yue-wu Thank you for your reply! My dataset is generated from images through spatial transformations; I didn't realize it contained such cases until I ran into this problem.

Applying the sqrt function directly to a zero tensor, as you say, doesn't generate any problem, but the NaN still shows up during training... For now I've just removed these cases, since identical input images won't generate any loss anyway.

Maybe it's because the derivative of sqrt(x) is infinite at 0? Does it happen during evaluate() or only during fit()?

@ozabluda I think you are right! It happens only during fit(), so the problem is caused by an infinite gradient coming from this special case.
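For reference, the infinite gradient is easy to reproduce: K.sqrt clips the forward value to be non-negative, but the derivative 1/(2·sqrt(x)) is still evaluated at 0, so backpropagation produces inf (which typically becomes NaN once multiplied by a zero elsewhere in the chain). A minimal sketch using the TF 1.x graph API:

import tensorflow as tf

x = tf.constant(0.0)
# Same shape as K.sqrt: clip the forward value, then take the root.
y = tf.sqrt(tf.clip_by_value(x, 0.0, float('inf')))
# Backward pass: 1 / (2 * sqrt(0)) -> inf.
grad = tf.gradients(y, x)[0]

with tf.Session() as sess:
    print(sess.run([y, grad]))  # [0.0, inf]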

Thanks for all your help :")
