Keras ignores validation_data when it is provided from a TF Iterator

Created on 11 Jan 2018 · 3 Comments · Source: keras-team/keras

When validation_data is provided from a TensorFlow Iterator, Keras seems to ignore the parameter and use the training data anyway.

import tensorflow as tf
import keras

def _parse_function_x(filename):
    image = tf.random_uniform([tf.shape(filename)[0], 198, 198, 1])  # simulated image loading and manipulation
    image = tf.Print(image, [filename, tf.shape(image)], " - returning image for file -")
    return image

def _parse_function_y(label):
    return tf.one_hot(label, 10)

def _parse_function(filename, label):
    return _parse_function_x(filename), _parse_function_y(label)

flist = ["trimg1", "trimg2", "trimg3", "trimg4", "trimg5", "trimg6"]

filenames = tf.constant(flist)
labels = tf.constant([0, 5, 6, 1, 2, 3])

train_batch = 2
valid_batch = 3

train_x_dataset = tf.data.Dataset.from_tensor_slices((filenames))
train_x_dataset = train_x_dataset.repeat().batch(train_batch)
train_x_dataset = train_x_dataset.map(_parse_function_x)
it_train_x = train_x_dataset.make_one_shot_iterator()

train_y_dataset = tf.data.Dataset.from_tensor_slices((labels))
train_y_dataset = train_y_dataset.repeat().batch(train_batch)
train_y_dataset = train_y_dataset.map(_parse_function_y)
it_train_y = train_y_dataset.make_one_shot_iterator()

vlist = ["val1", "val2", "val3"]

valid_filenames = tf.constant(vlist)
valid_labels = tf.constant([3, 2, 5])
valid_dataset = tf.data.Dataset.from_tensor_slices((valid_filenames, valid_labels))
valid_dataset = valid_dataset.repeat().batch(valid_batch)
valid_dataset = valid_dataset.map(_parse_function)
it_valid = valid_dataset.make_one_shot_iterator()

model = keras.applications.resnet50.ResNet50(include_top=True, weights=None, input_tensor=it_train_x.get_next(), pooling=None, classes=10, input_shape=(198,198,1))

model.compile(optimizer='sgd',
              loss='categorical_crossentropy',
              metrics=['accuracy'],
              target_tensors=[it_train_y.get_next()])


model.fit(steps_per_epoch=len(flist) // train_batch, epochs=5, validation_data=it_valid.get_next(),
          validation_steps=len(vlist) // valid_batch, verbose=2)

Result:

2018-01-10 19:07:57.810850: I tensorflow/core/kernels/logging_ops.cc:79]  - returning image for file -[trimg1 trimg2][2 198 198...]
2018-01-10 19:07:58.372769: I tensorflow/core/kernels/logging_ops.cc:79]  - returning image for file -[trimg3 trimg4][2 198 198...]
2018-01-10 19:07:58.428576: I tensorflow/core/kernels/logging_ops.cc:79]  - returning image for file -[trimg5 trimg6][2 198 198...]
2018-01-10 19:07:58.744026: I tensorflow/core/kernels/logging_ops.cc:79]  - returning image for file -[trimg1 trimg2][2 198 198...]
2018-01-10 19:07:58.759726: I tensorflow/core/kernels/logging_ops.cc:79]  - returning image for file -[trimg3 trimg4][2 198 198...]
 - 4s - loss: 8.7110 - acc: 0.0000e+00 - val_loss: 2.8315 - val_acc: 0.0000e+00
Epoch 2/5
2018-01-10 19:07:58.815126: I tensorflow/core/kernels/logging_ops.cc:79]  - returning image for file -[trimg5 trimg6][2 198 198...]
2018-01-10 19:07:58.869334: I tensorflow/core/kernels/logging_ops.cc:79]  - returning image for file -[trimg1 trimg2][2 198 198...]
2018-01-10 19:07:58.923224: I tensorflow/core/kernels/logging_ops.cc:79]  - returning image for file -[trimg3 trimg4][2 198 198...]
 - 0s - loss: 12.0298 - acc: 0.0000e+00 - val_loss: 2.7070 - val_acc: 0.0000e+00
Epoch 3/5
2018-01-10 19:07:58.939015: I tensorflow/core/kernels/logging_ops.cc:79]  - returning image for file -[trimg5 trimg6][2 198 198...]
2018-01-10 19:07:59.005950: I tensorflow/core/kernels/logging_ops.cc:79]  - returning image for file -[trimg1 trimg2][2 198 198...]
2018-01-10 19:07:59.067022: I tensorflow/core/kernels/logging_ops.cc:79]  - returning image for file -[trimg3 trimg4][2 198 198...]
2018-01-10 19:07:59.120895: I tensorflow/core/kernels/logging_ops.cc:79]  - returning image for file -[trimg5 trimg6][2 198 198...]
 - 0s - loss: 12.9786 - acc: 0.0000e+00 - val_loss: 3.7159 - val_acc: 0.0000e+00
Epoch 4/5
2018-01-10 19:07:59.136508: I tensorflow/core/kernels/logging_ops.cc:79]  - returning image for file -[trimg1 trimg2][2 198 198...]
2018-01-10 19:07:59.190424: I tensorflow/core/kernels/logging_ops.cc:79]  - returning image for file -[trimg3 trimg4][2 198 198...]
2018-01-10 19:07:59.259350: I tensorflow/core/kernels/logging_ops.cc:79]  - returning image for file -[trimg5 trimg6][2 198 198...]
2018-01-10 19:07:59.319021: I tensorflow/core/kernels/logging_ops.cc:79]  - returning image for file -[trimg1 trimg2][2 198 198...]
2018-01-10 19:07:59.334779: I tensorflow/core/kernels/logging_ops.cc:79]  - returning image for file -[trimg3 trimg4][2 198 198...]
 - 0s - loss: 13.9429 - acc: 0.0000e+00 - val_loss: 5.9738 - val_acc: 0.0000e+00
Epoch 5/5
2018-01-10 19:07:59.388996: I tensorflow/core/kernels/logging_ops.cc:79]  - returning image for file -[trimg5 trimg6][2 198 198...]
2018-01-10 19:07:59.443311: I tensorflow/core/kernels/logging_ops.cc:79]  - returning image for file -[trimg1 trimg2][2 198 198...]
2018-01-10 19:07:59.507233: I tensorflow/core/kernels/logging_ops.cc:79]  - returning image for file -[trimg3 trimg4][2 198 198...]
 - 0s - loss: 11.9854 - acc: 0.1667 - val_loss: 1.3048 - val_acc: 0.5000

Process finished with exit code 0
The val1, val2, ... files seem to be ignored; however, Keras somehow still calculates val_loss, val_acc, etc.

All 3 comments

Yes, in this setup you are validating on whatever the data tensor input to the model yields. If your steps_per_epoch value matches the size of the training split of your dataset, then that's a correct way to do validation (i.e., you're validating on the end of the dataset and never training on that data).

At this time, it is pointless to pass a validation_data argument that is a data tensor, because the model cannot be rewired to use it as input. Your model already has an input hardwired into it: the tensor that you passed as the input_tensor argument.

In the future, we will make it possible to pass data tensors in fit, including for validation_data. But for the time being, that's not possible.
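The behavior described above can be reproduced without TensorFlow. The plain-Python sketch below is a hypothetical simulation, not Keras code: `repeat_batch` and `simulate_fit` are illustrative names, and the sketch assumes that every batch, training and validation alike, is pulled from the single iterator the model was built on, while `validation_data` is accepted but never read. Under that assumption, 5 epochs of 3 training steps plus 1 validation step pull 20 batches, all from the trimg files, which is consistent with the 20 log lines in the report.

```python
from itertools import cycle, islice


def repeat_batch(items, batch_size):
    """Endless batches over a repeated list, similar to Dataset.repeat().batch()."""
    source = cycle(items)
    while True:
        yield list(islice(source, batch_size))


def simulate_fit(train_files, batch_size, steps_per_epoch, validation_steps,
                 epochs, validation_data=None):
    """Hypothetical model of fit() on a model whose input is hardwired.

    The input iterator is created once, when the "model" is built. Validation
    steps draw from it too, so validation_data is never consulted.
    """
    hardwired_input = repeat_batch(train_files, batch_size)
    pulls = []
    for _ in range(epochs):
        for _ in range(steps_per_epoch):
            pulls.append(("train", next(hardwired_input)))
        for _ in range(validation_steps):
            # validation_data is deliberately unused, mirroring the report
            pulls.append(("val", next(hardwired_input)))
    return pulls


flist = ["trimg1", "trimg2", "trimg3", "trimg4", "trimg5", "trimg6"]
vlist = ["val1", "val2", "val3"]
pulls = simulate_fit(flist, 2, len(flist) // 2, len(vlist) // 3, 5,
                     validation_data=vlist)
for phase, batch in pulls[:4]:
    print(phase, batch)
# train ['trimg1', 'trimg2']
# train ['trimg3', 'trimg4']
# train ['trimg5', 'trimg6']
# val ['trimg1', 'trimg2']
```

The first "validation" batch is simply the next training batch, which is why val_loss is still reported even though no val file is ever loaded.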

Thank you @fchollet for your explanation. It would be great if the fit method threw an exception or warned the user that validation_data will be ignored in this scenario.
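The check requested here could look roughly like the following plain-Python sketch. `check_validation_data` and its `has_input_tensor` and `strict` parameters are hypothetical names, not part of the Keras API; a real implementation would have to determine from the model itself whether its inputs are bound to data tensors.

```python
import warnings


def check_validation_data(has_input_tensor, validation_data, strict=False):
    """Hypothetical guard: flag validation_data that a hardwired model cannot use.

    has_input_tensor: True if the model was built on a data tensor
        (e.g. via input_tensor=...), so its input cannot be rewired.
    validation_data: whatever the caller passed to fit().
    strict: raise ValueError instead of emitting a warning.
    Returns True only if validation_data would actually be used.
    """
    if validation_data is None:
        return False
    if has_input_tensor:
        msg = ("validation_data is ignored: the model's input is hardwired to a "
               "data tensor, so validation reads from the training input instead.")
        if strict:
            raise ValueError(msg)
        warnings.warn(msg, UserWarning)
        return False
    return True
```

Warning by default keeps existing scripts running, while `strict=True` corresponds to the "threw an exception" option.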

@fchollet, any news on supporting data tensors for validation_data in fit?
