Group Normalization (GN) is more accurate than Batch Normalization (BN) with small batches, which makes it useful for many vision tasks. See the paper Group Normalization (https://arxiv.org/pdf/1803.08494.pdf).
PyTorch and TensorFlow (code below) already have implementations. Could anyone port this to MXNet, please? It would be really helpful for many of us, but I'm not sure how to implement it. Thanks!
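import tensorflow as tf

def GroupNorm(x, gamma, beta, G, eps=1e-5):
    # x: input features with shape [N, C, H, W]
    # gamma, beta: scale and offset, with shape [1, C, 1, 1]
    # G: number of groups for GN
    N, C, H, W = x.shape
    # Split the C channels into G groups of C // G channels each
    x = tf.reshape(x, [N, G, C // G, H, W])
    # Mean and variance per sample and per group, over (C // G, H, W)
    # (the paper's TF1-era code spells this argument keep_dims)
    mean, var = tf.nn.moments(x, [2, 3, 4], keepdims=True)
    x = (x - mean) / tf.sqrt(var + eps)
    x = tf.reshape(x, [N, C, H, W])
    return x * gamma + beta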
Figure 3. Python code of Group Norm based on TensorFlow.
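For readers without TensorFlow at hand, here is a minimal NumPy sketch of the same forward computation. It is an illustrative reference, not code from the paper or from any framework, and the function name `group_norm` is hypothetical. It reshapes the input into groups, normalizes each (sample, group) slice using the biased variance (as `tf.nn.moments` does), and then applies the per-channel scale and offset.

```python
import numpy as np

def group_norm(x, gamma, beta, G, eps=1e-5):
    """Group Norm forward pass on an (N, C, H, W) array (illustrative sketch)."""
    N, C, H, W = x.shape
    assert C % G == 0, "C must be divisible by G"
    # Split channels into G groups: (N, G, C//G, H, W)
    xg = x.reshape(N, G, C // G, H, W)
    # Per-sample, per-group statistics over (C//G, H, W); var is the biased variance
    mean = xg.mean(axis=(2, 3, 4), keepdims=True)
    var = xg.var(axis=(2, 3, 4), keepdims=True)
    xg = (xg - mean) / np.sqrt(var + eps)
    # Restore the original layout and apply the per-channel affine transform
    return xg.reshape(N, C, H, W) * gamma + beta

rng = np.random.default_rng(0)
x = rng.normal(size=(2, 8, 4, 4))
gamma = np.ones((1, 8, 1, 1))
beta = np.zeros((1, 8, 1, 1))
y = group_norm(x, gamma, beta, G=4)

# With gamma = 1 and beta = 0, each (sample, group) slice has mean ~0 and variance ~1
yg = y.reshape(2, 4, -1)
```

A reference like this is handy for numerically checking a port against the original before worrying about speed.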
Thanks for submitting this issue @smorrel1
@sandeep-krishnamurthy could you add label "Feature Request" to this issue?
So has Group Norm been added to MXNet yet?
Hi @kalyc @sandeep-krishnamurthy, if I want to contribute a Group Norm implementation to Gluon, what is the process?
@srikar2097 you can add a contrib block in Gluon. If you want some feedback on your plan to implement GroupNorm, you can subscribe to the dev list (https://mxnet.incubator.apache.org/community/) and send out an RFC email. If you already know how to implement it, feel free to submit a PR directly.
Well, I have implemented GroupNorm. It's slower than nn.BatchNorm, but it works (code below):
from mxnet.gluon import nn


class GroupNorm(nn.HybridBlock):
    """
    If the batch size is small, it's better to use GroupNorm instead of BatchNorm.
    GroupNorm achieves good results even at small batch sizes.
    Reference:
        https://arxiv.org/pdf/1803.08494.pdf
    """
    def __init__(self, num_channels, num_groups=32, eps=1e-5,
                 multi_precision=False, **kwargs):
        super(GroupNorm, self).__init__(**kwargs)
        with self.name_scope():
            # gamma starts at 1 and beta at 0, so the layer starts as a pure normalization
            self.weight = self.params.get('weight', grad_req='write',
                                          shape=(1, num_channels, 1, 1), init='ones')
            self.bias = self.params.get('bias', grad_req='write',
                                        shape=(1, num_channels, 1, 1), init='zeros')
        self.C = num_channels
        self.G = num_groups
        self.eps = eps
        # multi_precision assumes float16 inputs: statistics are accumulated in float32
        self.multi_precision = multi_precision
        assert self.C % self.G == 0, "num_channels must be divisible by num_groups"

    def hybrid_forward(self, F, x, weight, bias):
        # (N,C,H,W) -> (N,G,H*W*C//G); 0 copies the batch dimension unchanged
        x_new = F.reshape(x, (0, self.G, -1))
        if self.multi_precision:
            # (N,G,H*W*C//G) -> (N,G,1), computed in float32
            mean = F.mean(F.cast(x_new, "float32"), axis=-1, keepdims=True)
            mean = F.cast(mean, "float16")
        else:
            mean = F.mean(x_new, axis=-1, keepdims=True)
        centered_x_new = F.broadcast_minus(x_new, mean)  # (N,G,H*W*C//G)
        if self.multi_precision:
            # Cast before squaring so that large values do not overflow float16
            var = F.mean(F.square(F.cast(centered_x_new, "float32")),
                         axis=-1, keepdims=True)  # (N,G,H*W*C//G) -> (N,G,1)
            var = F.cast(var, "float16")
        else:
            var = F.mean(F.square(centered_x_new), axis=-1, keepdims=True)
        # Normalize, then (N,G,H*W*C//G) -> (N,C,H,W)
        x_new = F.broadcast_div(centered_x_new, F.sqrt(var + self.eps)).reshape_like(x)
        # Per-channel affine transform: x * gamma + beta
        x_new = F.broadcast_add(F.broadcast_mul(x_new, weight), bias)
        return x_new
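One way to sanity-check the grouping used in the block above (`F.reshape(x, (0, self.G, -1))` followed by statistics over the last axis): the GN paper notes that a single group reduces to Layer Norm, and one channel per group reduces to Instance Norm. The NumPy sketch below uses hypothetical helper names, omits the affine parameters, and checks both extremes against the same reshape-and-normalize logic.

```python
import numpy as np

def group_norm(x, G, eps=1e-5):
    # Same grouping as the Gluon block: reshape to (N, G, -1), normalize the last axis
    N = x.shape[0]
    xg = x.reshape(N, G, -1)
    mean = xg.mean(axis=-1, keepdims=True)
    var = xg.var(axis=-1, keepdims=True)
    return ((xg - mean) / np.sqrt(var + eps)).reshape(x.shape)

def layer_norm(x, eps=1e-5):
    # Statistics per sample over (C, H, W)
    mean = x.mean(axis=(1, 2, 3), keepdims=True)
    var = x.var(axis=(1, 2, 3), keepdims=True)
    return (x - mean) / np.sqrt(var + eps)

def instance_norm(x, eps=1e-5):
    # Statistics per sample and per channel over (H, W)
    mean = x.mean(axis=(2, 3), keepdims=True)
    var = x.var(axis=(2, 3), keepdims=True)
    return (x - mean) / np.sqrt(var + eps)

rng = np.random.default_rng(1)
x = rng.normal(size=(3, 6, 5, 5))
C = x.shape[1]
same_as_ln = np.allclose(group_norm(x, G=1), layer_norm(x))
same_as_in = np.allclose(group_norm(x, G=C), instance_norm(x))
```

Because NumPy arrays (like MXNet NDArrays) are row-major, reshaping (N, C, H, W) to (N, G, -1) puts consecutive channels in the same group, which is the grouping the paper uses.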
Clearly there are several issues, for example:
- F.moments() (quite common) is not implemented in MXNet yet, so my implementation here might be slow.
- reshape_like() seems unavoidable, so the input tensor has to be kept around, which costs memory.

I think this layer is quite important, as not everyone has plenty of GPUs (if you do, F.contrib.SyncBatchNorm() will work well).
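The small-batch argument can be seen directly: GN statistics are computed per sample, so a sample's output does not depend on what else is in the batch, whereas training-mode BN shares per-channel statistics across the whole batch. A minimal NumPy sketch, assuming a simplified BN with no running averages and no affine parameters (helper names are hypothetical):

```python
import numpy as np

def group_norm(x, G, eps=1e-5):
    # Per-sample statistics: each sample in the batch is normalized on its own
    N = x.shape[0]
    xg = x.reshape(N, G, -1)
    mean = xg.mean(axis=-1, keepdims=True)
    var = xg.var(axis=-1, keepdims=True)
    return ((xg - mean) / np.sqrt(var + eps)).reshape(x.shape)

def batch_norm_train(x, eps=1e-5):
    # Simplified training-mode BN: per-channel statistics over (N, H, W)
    mean = x.mean(axis=(0, 2, 3), keepdims=True)
    var = x.var(axis=(0, 2, 3), keepdims=True)
    return (x - mean) / np.sqrt(var + eps)

rng = np.random.default_rng(2)
batch = rng.normal(size=(8, 4, 3, 3))
first = batch[:1]  # the same first sample, as a batch of size 1

# Does the first sample's output change when the rest of the batch is removed?
gn_independent = np.allclose(group_norm(batch, G=2)[:1], group_norm(first, G=2))
bn_independent = np.allclose(batch_norm_train(batch)[:1], batch_norm_train(first))
```

Here `gn_independent` holds while `bn_independent` does not: with a batch of one, BN estimates each channel's statistics from only H*W values, which is the regime where GN is expected to help.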
P.S. A question for the MXNet authors:
There seems to be an op called F.SumSquare() (see: https://github.com/dmlc/gluon-cv/blob/0a699a5ccc21310c7ce41d4737f0de9f54fbf45a/gluoncv/model_zoo/syncbn.py#L206), which I guess is used to compute the second-order moment. I couldn't find it in MXNet's API. Does this op really exist?
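On the second-order moment: assuming an op like `SumSquare` returns a sum of squares (an assumption based on its name and the SyncBN context, not on a documented MXNet API), the mean and variance can be recovered in a single pass from the sum and the sum of squares, using Var[x] = E[x^2] - E[x]^2. A NumPy sketch with a hypothetical helper:

```python
import numpy as np

def moments_one_pass(xg):
    """Mean and biased variance over the last axis from sum and sum of squares."""
    m = xg.shape[-1]
    s = xg.sum(axis=-1, keepdims=True)              # first-order sum
    ss = np.square(xg).sum(axis=-1, keepdims=True)  # second-order sum (sum of squares)
    mean = s / m
    var = ss / m - np.square(mean)                  # E[x^2] - E[x]^2
    return mean, np.maximum(var, 0.0)               # clamp tiny negatives from rounding

rng = np.random.default_rng(3)
xg = rng.normal(loc=2.0, size=(2, 4, 50))  # (N, G, elements per group)
mean, var = moments_one_pass(xg)
ref_mean = xg.mean(axis=-1, keepdims=True)
ref_var = xg.var(axis=-1, keepdims=True)
```

The one-pass form can lose precision through cancellation when the mean is large relative to the spread, which matters most in float16; that is one reason to accumulate in float32, as the `multi_precision` branch of the Gluon block does.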
@eric-haibin-lin Is the implementation above OK?
@chi-hung I think the SumSquare op resides in @zhanghang1989's fork. @zhanghang1989, could you confirm?
@haojin2
Attempting an implementation in the backend...
@chi-hung I think the SumSquare op resides in @zhanghang1989's fork. @zhanghang1989, could you confirm?
Yes, it lives in a branch of my fork. We implemented it for SyncBN.
The implementation is in #14959 and almost done; only the gradient with respect to the data still needs some work.
@smorrel1, @haojin2 has implemented GroupNorm, so you can try it. Feel free to reopen the issue if you run into any problems.
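For reference, the gradient with respect to the input has a closed form per (sample, group), the same as for Batch Norm but using group statistics: with g = dL/dy * gamma and x_hat the normalized input, dx = (g - mean(g) - x_hat * mean(g * x_hat)) / sqrt(var + eps), where the means run over the group. The NumPy sketch below (hypothetical helper names; not the code from #14959) implements this formula and checks it against central finite differences.

```python
import numpy as np

EPS = 1e-5

def gn_forward(x, gamma, beta, G):
    N = x.shape[0]
    xg = x.reshape(N, G, -1)
    mean = xg.mean(axis=-1, keepdims=True)
    inv_std = 1.0 / np.sqrt(xg.var(axis=-1, keepdims=True) + EPS)
    x_hat = ((xg - mean) * inv_std).reshape(x.shape)
    return x_hat * gamma + beta, (x_hat, inv_std)

def gn_backward_data(dy, gamma, G, cache):
    """Gradient of the loss with respect to x, given dL/dy."""
    x_hat, inv_std = cache
    N = dy.shape[0]
    # Fold the per-channel scale into the upstream gradient, then work per group
    g = (dy * gamma).reshape(N, G, -1)
    xh = x_hat.reshape(N, G, -1)
    dx = inv_std * (g - g.mean(axis=-1, keepdims=True)
                    - xh * (g * xh).mean(axis=-1, keepdims=True))
    return dx.reshape(dy.shape)

rng = np.random.default_rng(4)
N, C, H, W, G = 2, 4, 3, 3, 2
x = rng.normal(size=(N, C, H, W))
gamma = rng.normal(size=(1, C, 1, 1))
beta = rng.normal(size=(1, C, 1, 1))
dy = rng.normal(size=(N, C, H, W))  # upstream gradient; the loss is L = sum(y * dy)

_, cache = gn_forward(x, gamma, beta, G)
dx = gn_backward_data(dy, gamma, G, cache)

# Central finite differences of L(x) = sum(gn_forward(x) * dy)
num = np.zeros_like(x)
h = 1e-6
for idx in np.ndindex(x.shape):
    xp = x.copy()
    xp[idx] += h
    xm = x.copy()
    xm[idx] -= h
    lp = np.sum(gn_forward(xp, gamma, beta, G)[0] * dy)
    lm = np.sum(gn_forward(xm, gamma, beta, G)[0] * dy)
    num[idx] = (lp - lm) / (2 * h)

max_err = np.abs(dx - num).max()
```

The gradients for gamma and beta are simpler: sum dy * x_hat and dy over (N, H, W) for each channel.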