I have an issue in my model where the running_var parameters of some BatchNorm layers are NaN after the dummy forward pass I use to initialize the parameters. While debugging, I discovered that the value of running_var depends on the context I use. I assume this is a bug, as a model should behave the same no matter which context is used. Here is a minimal reproducible example:
import mxnet as mx
from mxnet.gluon import nn
from mxnet.gluon.block import HybridBlock
import numpy as np

def _conv3x3(channels, stride, in_channels):
    return nn.Conv2D(channels, kernel_size=3, strides=stride, padding=1,
                     use_bias=False, in_channels=in_channels)

def get_dummy_data(ctx):
    data_shape = (1, 3, 32, 32)  # cifar like
    shapes = ((1,) + data_shape[1:], (1,))
    return [mx.nd.array(np.zeros(shape), ctx=ctx) for shape in shapes]

def check_net_params_for_nan(net, ctx):
    has_nan = False
    for name, param in net.collect_params().items():
        if np.isnan(param.data(ctx).asnumpy()).any():
            print('Param {} has nan values!'.format(name))
            has_nan = True
        if 'running_var' in name:
            print('Batchnorm running var values {}'.format(param.data(ctx).asnumpy()))
    return has_nan

class test_net(HybridBlock):
    def __init__(self, **kwargs):
        super(test_net, self).__init__(**kwargs)
        self.body = nn.HybridSequential(prefix='')
        self.body.add(_conv3x3(64, 1, 0))
        with self.body.name_scope():
            self.body.add(_conv3x3(64, 2, 64))
            self.body.add(nn.BatchNorm())
            self.body.add(nn.Activation('relu'))
            self.body.add(_conv3x3(64, 1, 64))
            self.body.add(nn.BatchNorm())

    def hybrid_forward(self, F, x, *args, **kwargs):
        return self.body(x)

num_gpus = 2
ctx = [mx.gpu(i) for i in range(num_gpus)] if num_gpus > 0 else [mx.cpu()]
net = test_net()
net.initialize(mx.init.Xavier(), ctx=ctx)
# dummy forward pass to initialize layers
with mx.autograd.record():
    # to make sure all params are initialized. Needs to re-run for networks that only execute
    # layers with a certain probability to make sure all layers are initialized
    for i in range(100):
        data, label = get_dummy_data(ctx[0])
        output = net(data)
assert not check_net_params_for_nan(net, ctx[0])
If I set num_gpus in this code to 0, it uses the CPU and the output shows that all running_var values are 1.
If I set num_gpus to 1 or 2, it uses GPUs and all running_var values are 2.6561329e-05.
Can anyone reproduce this? I am using MXNet 1.3.1 on Ubuntu, built from source.
I submitted https://github.com/apache/incubator-mxnet/issues/9216 a while ago, which may be related to this issue.
@mxnet-label-bot update [Bug, Gluon]
This bug is still present in 1.4.1: the BatchNorm operator does not support training on CPU. That is quite critical in my opinion.
Wow, this is pretty crazy; all sorts of frequently used models will fail on CPU.
To clarify, it is only an issue during training: the running variance is not computed correctly. If you load a GPU-trained model on CPU, inference is still correct.
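A minimal sketch of that workflow, assuming a Gluon model whose parameters were saved after GPU training (the network choice and file name 'resnet18_gpu.params' are placeholders):

import mxnet as mx
from mxnet import gluon

# Load parameters that were saved on a GPU machine, e.g. with
# net.save_parameters('resnet18_gpu.params'), and run inference on CPU.
# A plain forward pass (no autograd scope) uses the stored
# running_mean / running_var, so the predictions stay correct.
net = gluon.model_zoo.vision.get_model('resnet18_v1', classes=10)
net.load_parameters('resnet18_gpu.params', ctx=mx.cpu())

x = mx.nd.random.uniform(shape=(1, 3, 224, 224), ctx=mx.cpu())
out = net(x)  # inference mode: BatchNorm uses its running statistics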
I have just encountered the same or a similar problem. I am training on a GPU using the Python API. However, my inference target machine does not have a GPU and uses the C++ API on a CPU. My inference results are all NaN. When I grouped all symbols into the output, I could see that the problem starts after the first batch normalization layer. When I run the same C++ code on a machine with a GPU, everything is fine.
Well, batch normalization also makes sense for training smaller MLPs, and some people do that (we have a project aiming to do exactly that). And people will try to debug bigger models on CPU before pushing them out to expensive GPU instances.
I really hope this is not something very hard to fix?
@PatricZhao @TaoLv is there any success using BN for training on CPU?
Sure. We have training benchmarks for the MKL-DNN backend. @juliusshufan Can you share some recent training trends here? Better to have BN in the model. Thanks.
@TaoLv @eric-haibin-lin Per the nightly convergence tracking, ResNet50 with CIFAR-10 shows the training trend converging.

The problem is not the training performance; the problem is that the learnt running variance, used at inference time outside an autograd scope, is messed up and leads to erroneous predictions.
@ThomasDelteil Thanks for your prompt response. So far, training and inference based on pre-trained models are covered. Could you share a (minimal) reproduction case? Thanks.
Inference with a network trained from scratch on CPU is likely to be erroneous, though not guaranteed to be; inference of a pretrained-on-GPU model on CPU is fine because the running values have been computed correctly.
I'll share with you tomorrow an example of a training script that works on GPU but not on CPU.
import mxnet as mx
from mxnet import autograd, gluon

for ctx in [mx.cpu(), mx.gpu()]:
    layer = gluon.nn.BatchNorm()
    layer.initialize(ctx=ctx)
    for i in range(100):
        data = mx.nd.random.normal(loc=10, scale=2, shape=(1, 3, 224, 224), ctx=ctx)
        with autograd.record():
            out = layer(data)
    print(ctx, layer.running_var.data().asnumpy(), layer.running_mean.data().asnumpy())
cpu(0) [1. 1. 1.] [0. 0. 0.]
gpu(0) [4.010342 4.002672 3.9972866] [10.002233 9.998462 10.000072]
As you can see, the running variance and mean are erroneous on CPU: the variance is actually always ones and the mean zeros.
edit: it seems that running_mean and running_var are updated in the forward pass on GPU but not on CPU.
# same loop as above, but now also calling backward()
for ctx in [mx.cpu(), mx.gpu()]:
    layer = gluon.nn.BatchNorm()
    layer.initialize(ctx=ctx)
    for i in range(100):
        data = mx.nd.random.normal(loc=10, scale=2, shape=(1, 3, 224, 224), ctx=ctx)
        with autograd.record():
            out = layer(data)
        out.backward()
    print(ctx, layer.running_var.data().asnumpy(), layer.running_mean.data().asnumpy())
cpu(0) [3.9925063 4.008966 3.9975138] [10.000259 10.001338 9.999197]
gpu(0) [3.9989533 3.995771 4.009656 ] [10.001726 9.998586 9.999382]
import mxnet as mx
from mxnet import nd, autograd, gluon
import numpy as np

def transform(data, label):
    return nd.transpose(data.astype(np.float32), (2, 0, 1)) / 255, label.astype(np.float32)

trainset = gluon.data.vision.FashionMNIST(train=True, transform=transform)
train_data = gluon.data.DataLoader(dataset=trainset, batch_size=50, shuffle=True)
SCE = gluon.loss.SoftmaxCrossEntropyLoss()

for ctx in [mx.cpu(), mx.gpu()]:
    net = gluon.model_zoo.vision.get_model('resnet18_v1', pretrained=False, classes=10)
    # Parameter initialization
    net.initialize(mx.init.Xavier(magnitude=2.24), ctx=ctx, force_reinit=True)
    trainer = gluon.Trainer(params=net.collect_params(), optimizer='sgd',
                            optimizer_params={'learning_rate': .01, 'wd': 0.0001, 'momentum': 0.9})

    # Training
    for i, (data, label) in enumerate(train_data):
        data = data.as_in_context(ctx)
        label = label.as_in_context(ctx)
        with autograd.record():
            output = net(data)
            loss = SCE(output, label)
        loss.backward()
        trainer.step(data.shape[0])
        if i == 20:
            break

    # Training accuracy under autograd
    accuracy = mx.metric.Accuracy()
    for i, (data, label) in enumerate(train_data):
        with autograd.record():
            output = net(data.as_in_context(ctx))
        accuracy.update(label, output)
        if i == 5:
            break
    print("Train accuracy so far, under autograd training scope evaluation:", accuracy.get(), ctx)

    # Training accuracy outside autograd
    accuracy = mx.metric.Accuracy()
    for i, (data, label) in enumerate(train_data):
        output = net(data.as_in_context(ctx))
        accuracy.update(label, output)
        if i == 5:
            break
    print("Train accuracy so far, outside autograd training scope evaluation:", accuracy.get(), ctx)
Being inside or outside the autograd scope should only affect the BatchNorm operator, which uses either the per-batch statistics or the computed running mean and variance (see the sketch after the output below). We can see that the CPU version is a lot worse than the GPU one.
Train accuracy so far, under autograd training scope evaluation: ('accuracy', 0.6333333333333333) cpu(0)
Train accuracy so far, outside autograd training scope evaluation: ('accuracy', 0.09666666666666666) cpu(0)
Train accuracy so far, under autograd training scope evaluation: ('accuracy', 0.6833333333333333) gpu(0)
Train accuracy so far, outside autograd training scope evaluation: ('accuracy', 0.5466666666666666) gpu(0)
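A minimal sketch of that BatchNorm-only difference (the layer, input shape, and distribution are arbitrary): the same input gives different outputs depending on whether the forward pass is recorded, because training mode normalizes with the batch statistics while inference mode normalizes with the running statistics.

import mxnet as mx
from mxnet import autograd, gluon

bn = gluon.nn.BatchNorm()
bn.initialize(ctx=mx.cpu())
x = mx.nd.random.normal(loc=5, scale=3, shape=(8, 4, 16, 16))

with autograd.record():
    y_train = bn(x)   # training mode: normalized with the batch mean/var
y_infer = bn(x)       # inference mode: normalized with running_mean/running_var

# With freshly initialized running stats (mean=0, var=1) the two outputs differ.
print(y_train.mean().asscalar(), y_infer.mean().asscalar())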
@ThomasDelteil thanks for the example. We will follow up on this case. @wuxun-zhang @juliusshufan
Thanks @pengzhao-intel, I was planning to look into this issue. Please let me know if you already have work in progress.
@sandeep-krishnamurthy Thanks for asking. From a quick pass through this issue, the difference seems to be caused by different computation methods in the CPU and GPU backends: per-iteration statistics versus accumulated running statistics. I also tried ResNet50 training on CIFAR; the training and validation accuracy are both okay on CPU.
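For reference, a rough NumPy sketch of the moving-average ("accumulation") update the running statistics are expected to follow, using the gluon.nn.BatchNorm default momentum of 0.9 (variable names are illustrative). With inputs drawn from N(10, 2) it converges towards the mean of about 10 and variance of about 4 printed in the earlier GPU runs:

import numpy as np

momentum = 0.9                        # gluon.nn.BatchNorm default
running_mean, running_var = 0.0, 1.0  # BatchNorm initial running statistics

for _ in range(100):
    batch = np.random.normal(loc=10, scale=2, size=(1, 3, 224, 224))
    # fold the per-batch statistics into the running values
    running_mean = momentum * running_mean + (1 - momentum) * batch.mean()
    running_var = momentum * running_var + (1 - momentum) * batch.var()

print(running_mean, running_var)      # roughly 10 and 4 after 100 iterations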
Has this bug been fixed in the latest versions?
Sorry for the late response. We will look into this issue now and get back soon :)
Hi @adrianloy,
Your issue still exists on the 1.6 branch. On the master branch (at 81747710c), "running_var" is calculated only in the backward pass on both the CPU and GPU backends, so your test gives the same results in both contexts:
Batchnorm running var values [1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.
1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.
1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]
Hi @ThomasDelteil,
Regarding the script from https://github.com/apache/incubator-mxnet/issues/14357#issuecomment-497487102:
As I mentioned in my previous comment, on the master branch (at 8174771) the running variables in BatchNorm are calculated only during the backward pass.
Still, there are some differences in the results between the CPU and GPU backends. One reason is that the input tensors differ, since mx.nd.random.normal() produces different results on the two backends. According to the documentation (https://mxnet.apache.org/api/python/docs/api/mxnet/random/index.html), this is expected behavior:
Random number generators in MXNet are device specific. mx.random.seed(seed_state) sets the state of each generator using seed_state and the device id. Therefore, random numbers generated from different devices can be different even if they are seeded using the same seed.
To produce identical random number sequences independent of the device id, set optional ctx argument. This produces the same sequence of random numbers independent of the device id, but the sequence can be different on different kind of devices as MXNet's random number generators for CPU and GPU use different algorithms.
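A small sketch of that documented behavior (it assumes a GPU is available): the same global seed still produces different numbers on CPU and GPU.

import mxnet as mx

mx.random.seed(42)
print(mx.nd.random.normal(shape=(3,), ctx=mx.cpu()).asnumpy())
print(mx.nd.random.normal(shape=(3,), ctx=mx.gpu(0)).asnumpy())  # different values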
So, for comparison purposes I moved the tensor generation to NumPy.
The second issue I observe is a synchronization problem for the running variables. For now I added a workaround to read the final result from the backward pass (I am not sure whether it is an issue for a real network). The script then looks as follows:
import mxnet as mx
from mxnet import gluon
from mxnet import autograd
import numpy as np

seed = np.random.randint(np.iinfo(np.int32).max)
# seed = 0
print("seed:", seed)

shape = (1, 3, 224, 224)
layers = 100
dataNumpy = {}
np.random.seed(seed)
for i in range(layers):
    dataNumpy[i] = np.random.normal(loc=10, scale=2, size=shape)

for ctx in [mx.cpu(), mx.gpu()]:
    layer2 = gluon.nn.BatchNorm()
    layer2.initialize(ctx=ctx)
    for i in range(layers):
        data2 = mx.nd.array(dataNumpy[i], ctx=ctx)
        with autograd.record():
            out = layer2(data2)
        out.backward()

    # workaround for synchronization issue
    var1 = layer2.running_var.data().asnumpy()
    for t in range(1, 10):
        var2 = layer2.running_var.data().asnumpy()
        if (var1 != var2).any():
            print(ctx, "- DIFF in running_var reads:\n 0 :", var1, "\n ", t, ":", var2)
            break

    print(ctx, layer2.running_var.data().asnumpy(), layer2.running_mean.data().asnumpy())
For the test above I receive almost the same values for both backends:
seed: 791821049
cpu(0) [3.9977632 3.999108 4.0007195] [10.000481 10.000663 9.999296]
gpu(0) [3.997764 3.9991088 4.0007195] [10.000478 10.000664 9.999295]
The difference is so small that I think it can be neglected (it is just a rounding difference between the two backends).
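For instance, a simple tolerance check on the values printed above (rtol chosen loosely) confirms they agree:

import numpy as np

cpu_var = np.array([3.9977632, 3.999108, 4.0007195])
gpu_var = np.array([3.997764, 3.9991088, 4.0007195])
assert np.allclose(cpu_var, gpu_var, rtol=1e-5)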
Hi @ThomasDelteil ,
Regarding the training script from https://github.com/apache/incubator-mxnet/issues/14357#issuecomment-497508722:
I didn't manage to change the script to produce the same result on each run, so instead I ran the test procedure 999 times and computed the mean value:
import mxnet as mx
from mxnet import nd, autograd, gluon
import numpy as np

def transform(data, label):
    return nd.transpose(data.astype(np.float32), (2, 0, 1)) / 255, label.astype(np.float32)

trainset = gluon.data.vision.FashionMNIST(train=True)
trainset = trainset.transform(transform)
train_data = gluon.data.DataLoader(dataset=trainset, batch_size=50, shuffle=True)
SCE = gluon.loss.SoftmaxCrossEntropyLoss()

under_res = {}
under_sum = {}
outside_res = {}
outside_sum = {}
for ctx in [mx.gpu(), mx.cpu()]:
    under_sum[ctx] = 0.0
    outside_sum[ctx] = 0.0

for t in range(1, 1000):
    for ctx in [mx.cpu(), mx.gpu()]:
        net = gluon.model_zoo.vision.get_model('resnet18_v1', pretrained=False, classes=10)
        # Parameter initialization
        net.initialize(mx.init.Xavier(magnitude=2.24), ctx=ctx, force_reinit=True)
        trainer = gluon.Trainer(params=net.collect_params(), optimizer='sgd',
                                optimizer_params={'learning_rate': .01, 'wd': 0.0001, 'momentum': 0.9})

        # Training
        for i, (data, label) in enumerate(train_data):
            data = data.as_in_context(ctx)
            label = label.as_in_context(ctx)
            with autograd.record():
                output = net(data)
                loss = SCE(output, label)
            loss.backward()
            trainer.step(data.shape[0])
            if i == 20:
                break

        # Training accuracy under autograd
        accuracy = mx.gluon.metric.Accuracy()
        for i, (data, label) in enumerate(train_data):
            with autograd.record():
                output = net(data.as_in_context(ctx))
            accuracy.update(label, output)
            if i == 5:
                break
        under_res[ctx] = accuracy.get()[1]
        under_sum[ctx] += under_res[ctx]

        # Training accuracy outside autograd
        accuracy = mx.gluon.metric.Accuracy()
        for i, (data, label) in enumerate(train_data):
            output = net(data.as_in_context(ctx))
            accuracy.update(label, output)
            if i == 5:
                break
        outside_res[ctx] = accuracy.get()[1]
        outside_sum[ctx] += outside_res[ctx]

    for ctx in [mx.cpu(), mx.gpu()]:
        print("Test {:3} Accuracy for {}: under autograd: {:.6f} mean: {:.6f}, outside autograd: {:.6f} mean: {:.6f}".format(
            t, ctx,
            under_res[ctx],
            under_sum[ctx] / t,
            outside_res[ctx],
            outside_sum[ctx] / t))
It shows that, statistically, GPU and CPU give similar results:
Test 999 Accuracy for cpu(0): under autograd: 0.583333 mean: 0.588709, outside autograd: 0.173333 mean: 0.192819
Test 999 Accuracy for gpu(0): under autograd: 0.673333 mean: 0.589486, outside autograd: 0.170000 mean: 0.191892
Please see the log for full data: test_03_master_fixed_sync.txt
@adrianloy, @PatricZhao,
could you close the issue, as the results on master (959143696) are the same on GPU and CPU (see the comments above for details)?
I do see that there is no difference between CPU and GPU anymore, so I guess I can close this. What I don't really understand is why the accuracy outside autograd is so much lower than under autograd (for both contexts). Is this expected behaviour? It still sounds like a bug to me, but maybe I am missing something obvious.
Hi @adrianloy,
Using autograd.record() with the default train_mode=True causes BatchNorm to use the current batch to calculate the mean and variance in the forward pass. These batch statistics are then used to produce the output (instead of the running_mean and running_var values accumulated during training in the backward pass). This way the output is "adjusted" to the current batch, which results in better accuracy values during the checking phase.
If you disable training mode in autograd:
with autograd.record(False):
    output = net(data.as_in_context(ctx))
you will receive the same accuracy under and outside autograd.
So in my opinion (unfortunately I am not an expert yet), autograd shouldn't be used at all when checking the training accuracy.
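If the forward pass has to stay inside autograd.record() for some other reason, an alternative sketch (assuming the net, data, and ctx from the scripts above) is to wrap it in predict_mode(), which also makes BatchNorm fall back to its running statistics:

from mxnet import autograd

with autograd.record():
    with autograd.predict_mode():
        # recorded for autograd, but BatchNorm uses running_mean / running_var
        output = net(data.as_in_context(ctx))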