Keras: Cannot use model.reset_states() with TensorFlow but works with Theano

Created on 31 Aug 2016 · 6 Comments · Source: keras-team/keras

For an RNN model, the batch generator occasionally calls model.reset_states() during model.fit_generator(). This works with Theano but results in the following error with TensorFlow:

ValueError: Tensor("Placeholder:0", shape=(16, 2048), dtype=float32) must be from the same graph as Tensor("Variable:0", shape=(16, 2048), dtype=float32_ref).

(I built TensorFlow from source just now and am using the latest Keras version.)

All 6 comments

Is anyone else experiencing this? Does reset_states() work with TensorFlow when it is called in the batch generator during fit_generator()? @fchollet, there is no test for this; the only existing tests call reset_states() between model.predict() calls.

I'm experiencing exactly the same thing; calling model.reset_states() in my batch generator results in the same ValueError.

@carlthome How did you solve this?

By the way, it works fine when you write a simple batch training loop yourself that uses the batch generator. This works:

# Create and compile a model here
...

# Batch generator gets model injected so it can
#    call model.reset_states() at appropriate times.
batch = batch_generator(model)

for epoch in range(num_epochs):

    for step in range(num_steps_per_epoch):

        # Call the generator to create a batch, this could call 
        #    model.reset_states() when required
        input_batch, output_batch, sample_weights_batch = next(batch)

        metrics = model.train_on_batch(
                input_batch, 
                output_batch, 
                sample_weight=sample_weights_batch)

This does not work (it raises the ValueError mentioned above):

model.fit_generator(batch, num_steps_per_epoch, epochs=num_epochs)
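A minimal sketch of a related pattern, assuming you control the training loop: instead of handing the model to the generator, the generator only yields a flag marking the first batch of a new sequence, and the loop calls reset_states() itself. The reset then always runs in the thread that owns the TensorFlow session and graph. `batch_generator`, `train`, the extra `reset_before` element and `StubStatefulModel` are hypothetical; the stub stands in for a compiled stateful Keras model so the sketch runs without Keras.

```python
from typing import Iterator, List, Optional, Tuple


class StubStatefulModel:
    """Hypothetical stand-in for a compiled stateful Keras model."""

    def __init__(self) -> None:
        self.reset_count = 0
        self.trained_batches: List[int] = []

    def reset_states(self) -> None:
        # A real stateful model would zero its RNN states here.
        self.reset_count += 1

    def train_on_batch(self, x: int, y: int, sample_weight: Optional[list] = None) -> float:
        # Record which batch was trained on; a real model would update weights.
        self.trained_batches.append(x)
        return 0.0


def batch_generator(
    num_batches: int, sequence_length: int
) -> Iterator[Tuple[int, int, Optional[list], bool]]:
    """Hypothetical generator yielding (inputs, targets, sample_weights, reset_before).

    Integers stand in for real arrays. reset_before marks the first batch of each
    new sequence instead of calling model.reset_states() inside the generator.
    """
    for step in range(num_batches):
        yield step, step, None, step % sequence_length == 0


def train(model: StubStatefulModel, num_epochs: int,
          num_steps_per_epoch: int, sequence_length: int) -> None:
    batches = batch_generator(num_epochs * num_steps_per_epoch, sequence_length)
    for _epoch in range(num_epochs):
        for _step in range(num_steps_per_epoch):
            x, y, weights, reset_before = next(batches)
            if reset_before:
                # The reset runs in the calling thread, i.e. the one that owns
                # the TensorFlow session and graph, not inside the generator.
                model.reset_states()
            model.train_on_batch(x, y, sample_weight=weights)


model = StubStatefulModel()
train(model, num_epochs=2, num_steps_per_epoch=6, sequence_length=4)
print(model.reset_count)  # 3 (resets before batches 0, 4 and 8)
```

Because the generator no longer holds a reference to the model, it would stay safe even if it were later moved to a background thread for prefetching.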

I started using TensorFlow. :/

Sorry for digging up this old issue, but Google led me here and I have now found a way to use reset_states() inside the generator function with Keras and TensorFlow.
The problem seems to be that the generator runs in its own thread (or, with multiprocessing, its own process), which does not use the same TensorFlow session and graph as the main code.

If you save the current TensorFlow session and graph in your "main" code and pass them to the generator:

import tensorflow as tf
import keras
tf_session = keras.backend.get_session()
tf_graph = tf.get_default_graph()

trainGen = myGen(..., model, tf_session, tf_graph)

you can reset the states of the model inside the generator by calling:

with tf_session.as_default():
    with tf_graph.as_default():
        model.reset_states()

Maybe the Keras code could also be changed to run generators with the same TensorFlow session and graph as the ones they were created in.
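Why passing the session and graph in helps: in TensorFlow 1.x the default graph and the default session are properties of the current thread, and by default fit_generator pulls batches from the generator on a worker thread, which does not inherit the as_default() contexts entered in the main thread. The sketch below is a pure-Python analogy of that behaviour, not TensorFlow's implementation; `_defaults`, `get_default()`, `as_default()` and `run_in_thread()` are hypothetical helpers.

```python
import threading
from contextlib import contextmanager
from typing import Callable, Iterator, List, Optional

# Thread-local storage: each thread sees its own stack of "default" objects,
# similar in spirit to TensorFlow 1.x's default graph/session stacks.
_defaults = threading.local()


def get_default() -> Optional[str]:
    stack = getattr(_defaults, "stack", [])
    return stack[-1] if stack else None


@contextmanager
def as_default(obj: Optional[str]) -> Iterator[Optional[str]]:
    if not hasattr(_defaults, "stack"):
        _defaults.stack = []
    _defaults.stack.append(obj)
    try:
        yield obj
    finally:
        _defaults.stack.pop()


def run_in_thread(fn: Callable[[], Optional[str]]) -> List[Optional[str]]:
    # Run fn on a fresh worker thread and collect its return value.
    results: List[Optional[str]] = []
    worker = threading.Thread(target=lambda: results.append(fn()))
    worker.start()
    worker.join()
    return results


with as_default("main_graph"):
    captured = get_default()  # captured in the main thread

    # The worker thread starts with an empty stack, so it sees no default.
    lost = run_in_thread(get_default)

    # Passing the captured object in and re-entering the context restores it,
    # mirroring `with tf_graph.as_default(): model.reset_states()`.
    def worker_with_context() -> Optional[str]:
        with as_default(captured):
            return get_default()

    restored = run_in_thread(worker_with_context)

print(captured, lost, restored)  # main_graph [None] ['main_graph']
```

In real TensorFlow 1.x a new thread falls back to the global default graph and to no default session, rather than to `None` for both, but the remedy is the same: enter the captured objects' as_default() contexts inside the thread.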

I'm having the same issue.

The workaround by @Ced4 doesn't work for me. Depending on the run, I either get this error on the model.reset_states() line:

InvalidArgumentError (see above for traceback): You must feed a value for placeholder tensor 'Placeholder_1' with dtype float and shape [100,10]

or it runs without errors but the model state doesn't actually reset: performance is as low as when I don't reset the state at all, unlike when I reset it from a callback.
