Keras: Issue with BatchNormalization when used with load_weights

Created on 8 Aug 2016 · 15 comments · Source: keras-team/keras

Hi,

I believe there is an error in how the load_weights function restores the weights of a BatchNormalization layer. I was training an architecture with a BatchNormalization layer in it and saved the weights via a checkpoint callback. When I ran predict() on test data after loading the weights, I observed that the predictions for the same samples varied when I changed the number of test data points. I suspect the BatchNormalization layer is computing its normalization statistics on the fly from the test batch, instead of using the statistics saved during training, which is what it should be doing.

When I took the BatchNormalization layer out and repeated the same procedure, the issue disappeared.

Most helpful comment

With the mode setting removed from the BatchNormalization layer in Keras 2.x, this solution no longer applies. Is there any other way to remedy it? Why don't gamma and beta get saved like any other weights? I'm currently hitting the same issue: my validation performance after loading weights and calling predict() is much worse than my validation performance during the call to fit_generator().

All 15 comments

What mode was your BatchNormalization layer operating in?

mode=2

BatchNormalization mode=2 is used when you have a shared BatchNormalization layer, i.e. one that processes more than one input in your model. In this mode, you are correct: at test time, statistics are computed over the test batches. If you want more consistent behavior, use mode=0 if possible. Otherwise, use a larger batch size for more stable estimates of the normalization statistics at test time.
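For reference, a sketch of the two modes in the Keras 1.x API (layer sizes and data are illustrative):

```python
import numpy as np
from keras.models import Sequential
from keras.layers import Dense
from keras.layers.normalization import BatchNormalization

def build(mode):
    model = Sequential()
    model.add(Dense(32, input_dim=10, activation='relu'))
    model.add(BatchNormalization(mode=mode))
    model.add(Dense(1, activation='sigmoid'))
    model.compile(optimizer='adam', loss='binary_crossentropy')
    return model

X = np.random.rand(100, 10)

m2 = build(mode=2)  # normalizes with each batch's own statistics, even in predict()
print(np.allclose(m2.predict(X, batch_size=100)[:10],
                  m2.predict(X[:10], batch_size=10)))  # False: depends on the batch

m0 = build(mode=0)  # uses running averages at test time
print(np.allclose(m0.predict(X, batch_size=100)[:10],
                  m0.predict(X[:10], batch_size=10)))  # True: batch-independent
```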

I see - thank you for your help.

With the mode setting removed from the BatchNormalization layer in Keras 2.x, this solution no longer applies. Is there any other way to remedy it? Why don't gamma and beta get saved like any other weights? I'm currently hitting the same issue: my validation performance after loading weights and calling predict() is much worse than my validation performance during the call to fit_generator().
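A sketch for inspecting what a Keras 2.x BatchNormalization layer actually reports as its weights (layer sizes are illustrative):

```python
from keras.models import Sequential
from keras.layers import Dense, BatchNormalization

model = Sequential([
    Dense(64, input_shape=(20,), activation='relu'),
    BatchNormalization(),
    Dense(1, activation='sigmoid'),
])

bn = model.layers[1]
print([w.name for w in bn.weights])
# Lists gamma, beta, moving_mean and moving_variance: all four are part of
# the layer's weights, so save_weights()/load_weights() should include them.
```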

I ran into the same issue recently. Any solution? I even used the training data for testing, but the result is not reasonable at all. It's not predicting anything, even when I set the test batch size to match the training batch size.

@nicholaslocascio @ysyyork Were either of you able to solve this issue?

Is this problem resolved? I am using Keras 2.1.5. Accuracy is reported as excellent during training, but when I do a simple predict with the trained model on the same training data, the predictions are awful. What flags do I need to set, and where?

I guess the reason may be that the batch normalization layer uses moments learned during the training phase to make predictions at inference time. Since model.save_weights() probably saves only the weights (the moments of batch normalization layers don't sound like weights, do they?), your batch normalization ends up not working. Using model.save() to save your model instead could help.
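A sketch of what I mean, persisting the full model instead of weights only (model, data, and file name are illustrative):

```python
import numpy as np
from keras.models import Sequential, load_model
from keras.layers import Dense, BatchNormalization

model = Sequential([
    Dense(16, input_shape=(8,), activation='relu'),
    BatchNormalization(),
    Dense(1, activation='sigmoid'),
])
model.compile(optimizer='adam', loss='binary_crossentropy')
model.fit(np.random.rand(64, 8), np.random.randint(0, 2, (64, 1)),
          epochs=1, verbose=0)

model.save('model_full.h5')              # architecture + weights + optimizer state
restored = load_model('model_full.h5')

X = np.ones((4, 8))
print(np.allclose(model.predict(X), restored.predict(X)))  # True if state round-trips
```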

@bmiselis model.save_weights() saves all batchnorm params properly (I have checked); see also https://github.com/bonlime/keras-deeplab-v3-plus/issues/5. So it has to be something different.
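For reference, a sketch of that check, comparing the batchnorm arrays before and after a save_weights()/load_weights() round trip (model and file name are illustrative):

```python
import numpy as np
from keras.models import Sequential
from keras.layers import Dense, BatchNormalization

model = Sequential([
    Dense(16, input_shape=(8,), activation='relu'),
    BatchNormalization(),
    Dense(1, activation='sigmoid'),
])
model.compile(optimizer='adam', loss='binary_crossentropy')
model.fit(np.random.rand(64, 8), np.random.randint(0, 2, (64, 1)),
          epochs=1, verbose=0)  # train a little so the moving statistics move

bn = model.layers[1]
before = bn.get_weights()  # [gamma, beta, moving_mean, moving_variance]

model.save_weights('bn_check.h5')
model.load_weights('bn_check.h5')

after = bn.get_weights()
print(all(np.array_equal(a, b) for a, b in zip(before, after)))  # True
```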

Nope, I have the same issue even if I make a prediction right after training, without saving and loading weights at all.


I also get different results between training and evaluation, even on the training data. Any recommendations?

I am having the same issue as well. BatchNorm seems to make the predict() results inconsistent, i.e., they depend on the number of data points. Anybody with a solution? This really makes any implementation of networks with BatchNorm unreliable.
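One workaround that comes up for this class of problem is forcing the Keras learning phase to test mode before building the model, so BatchNormalization always uses its moving statistics. A sketch, where build_model(), the weights file, and X_test are placeholders for your own code (whether it helps depends on the root cause):

```python
from keras import backend as K

K.set_learning_phase(0)              # 0 = test phase; set before constructing any layers

model = build_model()                # hypothetical builder for your architecture
model.load_weights('checkpoint.h5')  # illustrative file name
preds = model.predict(X_test)        # X_test: your data; BN now uses moving statistics
```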

I have made a separate issue thread about the BatchNorm layer: #12400
