It was not immediately clear to me from the documentation how the BatchNorm operator works. In the original paper, during training, the operator normalizes with the batch statistics and accumulates them as moving averages, which are then used at inference time.
The MXNet BatchNorm operator has a "use_global_stats" flag which, if I understand correctly, adjusts that behavior. If set to true, it uses the global statistics from the auxiliary arrays; if set to false, it uses batch statistics.
Now, my question is: how does setting "is_train" to True/False in the forward pass affect the behavior of BatchNorm, in combination with the use_global_stats flag? For example, does setting use_global_stats to False override the is_train flag and cause the operator to use batch statistics every time? Or does "is_train" have no effect on BatchNorm at all?
https://github.com/dmlc/mxnet/blob/master/src/operator/batch_norm-inl.h#L95
if (ctx.is_train && !param_.use_global_stats) {
  // use batch statistics
} else {
  // use global stats
}
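So both is_train and use_global_stats are considered at the same time: batch statistics are used only when is_train is true and use_global_stats is false; in every other case the global statistics are used.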
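To make the four combinations explicit, here is a minimal Python sketch that mirrors the condition quoted above. The function name `uses_batch_stats` is hypothetical and only for illustration; it is not part of the MXNet API.

```python
def uses_batch_stats(is_train, use_global_stats):
    """Mirror of the C++ condition: ctx.is_train && !param_.use_global_stats.

    Hypothetical helper for illustration only, not an MXNet function.
    """
    return is_train and not use_global_stats


# Print which statistics each flag combination selects.
for is_train in (True, False):
    for use_global_stats in (True, False):
        source = "batch" if uses_batch_stats(is_train, use_global_stats) else "global"
        print(f"is_train={is_train!s:5}  use_global_stats={use_global_stats!s:5}  ->  {source} statistics")
```

Only one of the four combinations (is_train=True, use_global_stats=False) selects batch statistics.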
Thanks for the answer!
By the way, I have one more question. From the C++ code, I see that when forward() is called with "is_train=False", the BatchNorm operator always uses the global (moving-average) statistics, no matter what the "use_global_stats" flag says. Is there a simple way (from the Python interface) to make it use the batch statistics instead when the network is doing inference? I have a use case that uses BatchNorm for a purpose other than its original goal of reducing covariate shift, and this might require the operator to use the batch statistics during inference as well.
Is there any reason why you can't use is_train=True?
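To illustrate what that suggestion changes numerically, here is a rough pure-Python sketch of the normalization step for a single channel. It is an assumption-laden simplification, not the MXNet implementation: `batch_norm_1d` is a hypothetical name, the scale and shift parameters (gamma/beta) and the moving-average update are left out, and the `eps` default is arbitrary. Also keep in mind that is_train=True may change the behavior of other operators in the network (Dropout, for example).

```python
from statistics import fmean, pvariance


def batch_norm_1d(x, moving_mean, moving_var, is_train, use_global_stats, eps=1e-3):
    """Normalize one channel of values (hypothetical sketch, not MXNet code).

    gamma/beta and the moving-average update are omitted for brevity.
    """
    if is_train and not use_global_stats:
        # Statistics computed from the current batch.
        mean, var = fmean(x), pvariance(x)
    else:
        # Stored global (moving-average) statistics.
        mean, var = moving_mean, moving_var
    return [(v - mean) / (var + eps) ** 0.5 for v in x]


batch = [1.0, 2.0, 3.0]

# is_train=True with use_global_stats=False: normalized with batch statistics.
out_batch = batch_norm_1d(batch, moving_mean=0.0, moving_var=1.0,
                          is_train=True, use_global_stats=False)

# is_train=False: the stored statistics are used regardless of the flag.
out_global = batch_norm_1d(batch, moving_mean=0.0, moving_var=1.0,
                           is_train=False, use_global_stats=False)

print(out_batch)
print(out_global)
```

In the first call the output is centered on the batch itself; in the second it depends entirely on the stored moving_mean and moving_var.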
This issue is closed due to lack of activity in the last 90 days. Feel free to reopen if this is still an active issue. Thanks!