For an RNN model, the batch generator occasionally calls model.reset_states() during model.fit_generator(). This works with the Theano backend but results in the following error with TensorFlow:
ValueError: Tensor("Placeholder:0", shape=(16, 2048), dtype=float32) must be from the same graph as Tensor("Variable:0", shape=(16, 2048), dtype=float32_ref).
(I built TensorFlow from source just now and am using the latest Keras version.)
Is anyone else experiencing this? Does the TensorFlow backend support reset_states() being called from the batch generator during fit_generator()? @fchollet, there is no test for this; the only existing tests call reset_states() between model.predict() calls.
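For anyone trying to reproduce this, here is a minimal sketch of the setup (the layer sizes match the error message, but the reset condition, data, and hyperparameters are made up):

```python
import numpy as np
from keras.models import Sequential
from keras.layers import LSTM, Dense

batch_size, timesteps, features = 16, 8, 32

# Stateful RNN: states persist across batches until reset explicitly.
model = Sequential([
    LSTM(2048, batch_input_shape=(batch_size, timesteps, features),
         stateful=True),
    Dense(features),
])
model.compile(optimizer='rmsprop', loss='mse')

def batch_generator(model):
    while True:
        # Hypothetical sequence boundary: reset states before the batch.
        # With the TensorFlow backend this line raises the ValueError.
        model.reset_states()
        x = np.random.rand(batch_size, timesteps, features)
        y = np.random.rand(batch_size, features)
        yield x, y

model.fit_generator(batch_generator(model), steps_per_epoch=10, epochs=2)
```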
I'm experiencing exactly the same thing; calling model.reset_states() in my batch generator results in the same ValueError.
@carlthome How did you solve this?
By the way, when you write a simple batch training loop yourself that uses the batch generator, it works fine. So this works:
```python
# Create and compile a model here
...

# The batch generator gets the model injected so it can
# call model.reset_states() at the appropriate times.
batch = batch_generator(model)

for epoch in range(num_epochs):
    for step in range(num_steps_per_epoch):
        # Ask the generator for a batch; this may call
        # model.reset_states() when required.
        input_batch, output_batch, sample_weights_batch = next(batch)
        metrics = model.train_on_batch(
            input_batch,
            output_batch,
            sample_weight=sample_weights_batch)
```
This does not work (it raises the ValueError mentioned above):

```python
model.fit_generator(batch, num_steps_per_epoch, epochs=num_epochs)
```
I ended up just using plain TensorFlow instead. :/
Sorry for digging up this old issue, but Google led me here and I have now found a way to use reset_states() inside the generator function with Keras and TensorFlow.
The problem seems to be that the generator runs in its own Python thread (or process), which gets a fresh TensorFlow session and graph.
If you save the current TF session and graph in your "main" code and pass them to the generator:
```python
import tensorflow as tf
import keras

tf_session = keras.backend.get_session()
tf_graph = tf.get_default_graph()

trainGen = myGen(..., model, tf_session, tf_graph)
```
you can then reset the model's states inside the generator by calling:
```python
with tf_session.as_default():
    with tf_graph.as_default():
        model.reset_states()
```
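Putting it together, the generator might look something like this sketch (the `data` iterable and the sequence-boundary flag are assumptions about how your batches are produced):

```python
def myGen(data, model, tf_session, tf_graph):
    # 'data' is assumed to yield (inputs, targets, is_sequence_start)
    # tuples; adapt this to your own batching scheme.
    while True:
        for x_batch, y_batch, is_sequence_start in data:
            if is_sequence_start:
                # Make the model's original session and graph current
                # before touching its state variables.
                with tf_session.as_default():
                    with tf_graph.as_default():
                        model.reset_states()
            yield x_batch, y_batch
```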
Maybe the Keras code could also be changed so that generators run with the same TF session and graph as the ones they were created in.
I'm having the same issue.
The workaround by @Ced4 doesn't work for me. Sometimes I get this error on the model.reset_states() line:
InvalidArgumentError (see above for traceback): You must feed a value for placeholder tensor 'Placeholder_1' with dtype float and shape [100,10]
Or it runs without errors, but the model state doesn't actually reset: performance is as low as when I don't reset the state at all, in contrast to the better results I get when resetting via a callback.
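For reference, the callback approach that does work for me is roughly this (resetting once per epoch here; if you need resets at sequence boundaries, a hypothetical counter in on_batch_begin would do):

```python
import keras

class ResetStatesCallback(keras.callbacks.Callback):
    def on_epoch_begin(self, epoch, logs=None):
        # Runs in the same session/graph as training, so there is
        # no graph mismatch error here.
        self.model.reset_states()

model.fit_generator(batch, num_steps_per_epoch, epochs=num_epochs,
                    callbacks=[ResetStatesCallback()])
```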