Keras: Update moving_mean and moving_variance of BatchNormalization Layer use sess.run()

Created on 25 May 2017 · 5 comments · Source: keras-team/keras

I usually use just the layers from Keras and write the training code myself in TensorFlow.
I've found that if I don't use model.fit to train a model, the moving_mean and moving_variance of a BatchNormalization layer are never updated. That is, moving_mean always equals 0 and moving_variance always equals 1.
Here is an example of my model:

import keras
import tensorflow as tf
import keras.layers as kl
import keras.backend as K
import numpy as np

K.set_learning_phase(1)
model = keras.models.Sequential()
model.add(kl.InputLayer([784]))
model.add(kl.Dense(400))
model.add(kl.normalization.BatchNormalization())
model.add(kl.Activation('relu'))
model.add(kl.Dense(400))
model.add(kl.normalization.BatchNormalization())
model.add(kl.Activation('relu'))
model.add(kl.Dense(10,activation='sigmoid'))

When I use model.fit to train it, moving_mean and moving_variance are updated.

model.compile(loss='categorical_crossentropy', optimizer=keras.optimizers.Adam())
model.fit(x, y, batch_size=500, epochs=1)

But when I train it with raw TensorFlow code like the following:

train = tf.train.AdamOptimizer(0.001).minimize(loss, var_list=model.weights)
_train, err = sess.run([train, loss], {img: a, label: b})

In this way, moving_mean and moving_variance are not updated.
I know moving_mean and moving_variance appear in model.updates, but I don't know how to run those updates during training if I don't want to use model.fit.
Is there a simple solution?
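For reference, one common TF 1.x pattern is to make the fetched op depend on the update ops via tf.control_dependencies, so a single sess.run call runs both. The sketch below is not the asker's code: it is written against tf.compat.v1 (so it also runs on modern TF), and the variable and assign op are hypothetical stand-ins for the BN moving statistics rather than Keras's actual implementation.

```python
import numpy as np
import tensorflow.compat.v1 as tf  # TF 1.x graph/session API

tf.disable_eager_execution()

x = tf.placeholder(tf.float32, [None, 3])

# Stand-in for a BN moving statistic: a non-trainable running mean,
# blended with the batch mean using momentum 0.99 (the Keras default).
moving_mean = tf.Variable(np.zeros(3, np.float32), trainable=False)
batch_mean = tf.reduce_mean(x, axis=0)
update_op = tf.assign(moving_mean, 0.99 * moving_mean + 0.01 * batch_mean)

loss = tf.reduce_sum(tf.square(x))
with tf.control_dependencies([update_op]):
    # Fetching this tensor now triggers the update as a side effect.
    train_step = tf.identity(loss)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    data = np.full((8, 3), 5.0, np.float32)
    sess.run(train_step, {x: data})
    mean_val = sess.run(moving_mean)  # has drifted from 0 toward 5
```

With Keras, the list passed to tf.control_dependencies would come from model.updates instead of a hand-built assign op.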


All 5 comments

I mean: how do I update moving_mean and moving_variance directly using sess.run()?

@fchollet

@zzd1992, you might wanna take a look at the section "Collecting trainable weights and state updates" of https://blog.keras.io/keras-as-a-simplified-interface-to-tensorflow-tutorial.html.

Collecting trainable weights and state updates

Some Keras layers (stateful RNNs and BatchNormalization layers) have internal updates that need to be run as part of each training step. They are stored as a list of tensor tuples, layer.updates. You should generate assign ops for those, to be run at each training step. Here's an example:

from keras.layers import BatchNormalization

layer = BatchNormalization()(x)

update_ops = []
for old_value, new_value in layer.updates:
    update_ops.append(tf.assign(old_value, new_value))

Note that if you are using a Keras model (Model instance or Sequential instance), model.updates behaves in the same way (and collects the updates of all underlying layers in the model).

Actually, I think there's a small mistake in that tutorial, because "layer" there is just a TF tensor (the result of calling the layer on x), not the layer object itself.
(Or maybe the tutorial just hasn't been updated.)

You need to change it into something like this:

import tensorflow as tf
from keras.layers import BatchNormalization

layer = BatchNormalization()  # keep a reference to the layer object
blah = layer(x)               # calling the layer returns a tensor

update_ops = []
for old_value, new_value in layer.updates:
    update_ops.append(tf.assign(old_value, new_value))


Also, it seems that layer.updates already contains the assign ops, so a further change is needed:

from keras.layers import BatchNormalization

layer = BatchNormalization()
blah = layer(x)

update_ops = []
for assign_op in layer.updates:  # each entry is already an assign op
    update_ops.append(assign_op)


please correct me if I'm wrong :)
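For intuition, what each of those assign ops computes is an exponential moving average of the batch statistics. A plain numpy sketch of that update rule (assuming momentum 0.99, the Keras 2 default for BatchNormalization):

```python
import numpy as np

def update_moving_stats(moving_mean, moving_var, batch, momentum=0.99):
    """One BN bookkeeping step: blend the batch statistics into the
    running estimates, as the assign ops in layer.updates do."""
    batch_mean = batch.mean(axis=0)
    batch_var = batch.var(axis=0)
    new_mean = momentum * moving_mean + (1 - momentum) * batch_mean
    new_var = momentum * moving_var + (1 - momentum) * batch_var
    return new_mean, new_var

# Starting from the defaults (mean 0, variance 1), the stats only drift
# toward the data's statistics if this step is actually executed.
mean, var = np.zeros(3), np.ones(3)
batch = np.full((8, 3), 5.0)
mean, var = update_moving_stats(mean, var, batch)
```

This is why the asker sees moving_mean stuck at 0 and moving_variance at 1: if the assign ops are never run, the step above simply never happens.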

Is it possible to get a complete example of the above solution in order to clarify when the update_ops should be called?

For instance, given the code of zzd1992 in the first post and the proposed solution, a training step would be run using

_train, err, _ = sess.run([train, loss, update_ops], {img: a, label: b})

or do the update_ops need to be called separately?
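Not an authoritative answer, but in TF 1.x both styles work: the update ops can be added to the fetch list of the same sess.run call as the train op, or baked into the train op with tf.control_dependencies. A runnable sketch of the fetch-list style, using tf.compat.v1, with a hypothetical stand-in loss and assign op in place of the asker's model and model.updates (the label placeholder is omitted since the stand-in loss doesn't use it):

```python
import numpy as np
import tensorflow.compat.v1 as tf  # TF 1.x graph/session API

tf.disable_eager_execution()

img = tf.placeholder(tf.float32, [None, 3])
w = tf.Variable(np.ones(3, np.float32))
# Stand-in loss: squared distance between the batch mean and w.
loss = tf.reduce_sum(tf.square(tf.reduce_mean(img, axis=0) - w))
train = tf.train.AdamOptimizer(0.001).minimize(loss)

# Stand-in for model.updates: one assign op maintaining a running mean.
moving_mean = tf.Variable(np.zeros(3, np.float32), trainable=False)
update_ops = [tf.assign(moving_mean,
                        0.99 * moving_mean
                        + 0.01 * tf.reduce_mean(img, axis=0))]

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    a = np.full((8, 3), 5.0, np.float32)
    # One call runs the gradient step and the bookkeeping together.
    _, err, _ = sess.run([train, loss, update_ops], {img: a})
    mean_val = sess.run(moving_mean)  # updated alongside the train op
```

So no separate call is needed, as long as the ops appear somewhere in the fetches (or as control dependencies of something fetched).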
