I am training a sequence-to-sequence model with TimeDistributed(Dense()) as the final layer. As far as I can see, if I compile my model with metrics=['accuracy'], the accuracy is calculated as the average across all timesteps.
What I would like to have is an accuracy metric that looks at _full sequences_ instead of individual timesteps, i.e., the accuracy for an output sequence should be 1 iff _all_ timesteps have been predicted correctly, and 0 otherwise.
Is it possible to write a function to do this which can be used with metrics=...?
I have a really hard time figuring out how to write a custom function for use with metrics=..., since I've never worked with tensor functions directly, and I'm also unsure what exactly gets passed to that function. Any help would be appreciated.
Okay, I figured this out by trial & error, and it's incredibly simple really.
For example, take this accuracy function defined in keras/metrics.py:
```python
def categorical_accuracy(y_true, y_pred):
    return K.mean(K.equal(K.argmax(y_true, axis=-1),
                          K.argmax(y_pred, axis=-1)))
```
To get the accuracy per sequence, do:
```python
def categorical_accuracy_per_sequence(y_true, y_pred):
    return K.mean(K.min(K.equal(K.argmax(y_true, axis=-1),
                                K.argmax(y_pred, axis=-1)), axis=-1))
```
I assume that y_true and y_pred are tensors with the shape of the model's output. K.equal(...) checks whether the predictions are correct and returns a tensor of booleans with one entry per timestep. We want to reduce this to a tensor that has True iff _all_ timesteps in a sequence are True, and False otherwise. It turns out the simplest way to do this is K.min(..., axis=-1) (the minimum over the timestep axis acts as a logical AND), and then just take the mean of that over the batch.
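To see the difference between the two reductions, here is a minimal sketch that reproduces the metric's logic with NumPy instead of the Keras backend (the toy batch of two one-hot sequences is made up for illustration):

```python
import numpy as np

# Toy batch with shape (batch, timesteps, classes):
# sequence 0 is predicted correctly at every timestep,
# sequence 1 is wrong at its first timestep.
y_true = np.array([
    [[1, 0], [0, 1], [1, 0]],
    [[0, 1], [1, 0], [0, 1]],
])
y_pred = np.array([
    [[0.9, 0.1], [0.2, 0.8], [0.7, 0.3]],
    [[0.6, 0.4], [0.9, 0.1], [0.3, 0.7]],
])

# (batch, timesteps) boolean tensor of per-timestep correctness
correct = np.equal(np.argmax(y_true, axis=-1), np.argmax(y_pred, axis=-1))

# Per-timestep accuracy averages over everything: 5 of 6 timesteps -> 0.8333...
per_timestep = correct.mean()

# Per-sequence accuracy: min over the timestep axis is True only if
# every timestep matched, then average over the batch: 1 of 2 -> 0.5
per_sequence = correct.min(axis=-1).mean()

print(per_timestep, per_sequence)
```

The min-over-timesteps reduction is what turns "fraction of correct timesteps" into "fraction of fully correct sequences".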