Incubator-mxnet: Group Norm

Created on 14 Jun 2018  ·  12 Comments  ·  Source: apache/incubator-mxnet

Group Norm is more accurate than Batch Norm for small batches, which is useful for many vision tasks; see the Group Normalization paper.
PyTorch and TensorFlow (code below) have implementations. Could anyone port this please? It would be really helpful for many of us, but I'm not sure how to implement it. Thanks!

def GroupNorm(x, gamma, beta, G, eps=1e-5):
    # x: input features with shape [N,C,H,W]
    # gamma, beta: scale and offset, with shape [1,C,1,1]
    # G: number of groups for GN
    N, C, H, W = x.shape
    x = tf.reshape(x, [N, G, C // G, H, W])
    mean, var = tf.nn.moments(x, [2, 3, 4], keep_dims=True)
    x = (x - mean) / tf.sqrt(var + eps)
    x = tf.reshape(x, [N, C, H, W])
    return x * gamma + beta

Figure 3. Python code of Group Norm based on TensorFlow.

Labels: Feature request, Operator


All 12 comments

Thanks for submitting this issue @smorrel1
@sandeep-krishnamurthy could you add label "Feature Request" to this issue?

So has Group Norm been added to MXNet yet?

Hi @kalyc @sandeep-krishnamurthy, if I want to contribute a Group Norm implementation in Gluon, what is the process?

@srikar2097 you can add a contrib block in Gluon. If you want some feedback on your plan to implement GroupNorm, you can subscribe to the dev list (https://mxnet.incubator.apache.org/community/) and send out an email for an RFC. If you already know how to implement it, feel free to submit a PR directly.

Well, I have implemented GroupNorm. It's slower than nn.BatchNorm, but it works (see the code below):

from mxnet.gluon import nn

class GroupNorm(nn.HybridBlock):
    """
    If the batch size is small, it's better to use GroupNorm instead of BatchNorm.
    GroupNorm achieves good results even at small batch sizes.
    Reference:
      https://arxiv.org/pdf/1803.08494.pdf
    """
    def __init__(self, num_channels, num_groups=32, eps=1e-5,
                 multi_precision=False, **kwargs):
        super(GroupNorm, self).__init__(**kwargs)

        with self.name_scope():
            self.weight = self.params.get('weight', grad_req='write',
                                          shape=(1, num_channels, 1, 1))
            self.bias = self.params.get('bias', grad_req='write',
                                        shape=(1, num_channels, 1, 1))
        self.C = num_channels
        self.G = num_groups
        self.eps = eps
        self.multi_precision = multi_precision

        assert self.C % self.G == 0

    def hybrid_forward(self, F, x, weight, bias):

        x_new = F.reshape(x, (0, self.G, -1))                                # (N,C,H,W) -> (N,G,H*W*C//G)

        if self.multi_precision:
            mean = F.mean(F.cast(x_new, "float32"),
                          axis=-1, keepdims=True)                            # (N,G,H*W*C//G) -> (N,G,1)
            mean = F.cast(mean, "float16")
        else:
            mean = F.mean(x_new, axis=-1, keepdims=True)

        centered_x_new = F.broadcast_minus(x_new, mean)                      # (N,G,H*W*C//G)

        if self.multi_precision:
            var = F.mean(F.cast(F.square(centered_x_new), "float32"),
                         axis=-1, keepdims=True)                             # (N,G,H*W*C//G) -> (N,G,1)
            var = F.cast(var, "float16")
        else:
            var = F.mean(F.square(centered_x_new), axis=-1, keepdims=True)

        x_new = F.broadcast_div(centered_x_new,
                                F.sqrt(var + self.eps)).reshape_like(x)       # (N,G,H*W*C//G) -> (N,C,H,W)
        x_new = F.broadcast_add(F.broadcast_mul(x_new, weight), bias)
        return x_new
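
A minimal usage sketch for the block above (shapes and group/channel counts are arbitrary, chosen only for illustration):

import mxnet as mx
from mxnet import nd

# Forward a small batch through the GroupNorm block defined above and check the shape.
net = GroupNorm(num_channels=64, num_groups=32)
net.initialize()
net.hybridize()                              # also exercises the symbolic (F = mx.sym) path
x = nd.random.normal(shape=(2, 64, 7, 7))
y = net(x)
print(y.shape)                               # (2, 64, 7, 7) -- same shape as the input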

Clearly there are several issues, for example:

  • An operator such as F.moments() (common in other frameworks) is not implemented in MXNet yet, so my implementation here might be slow; see the sketch after this list for how such an op would simplify the forward pass.
  • The use of reshape_like() seems unavoidable, which means the input tensor has to be kept around and costs RAM.
  • When training with mixed precision, the implementation above casts the FP16 input to FP32 to avoid loss of precision while computing the mean and variance. Casting an FP16 tensor to FP32 and back to FP16 wastes time (this extra step could be eliminated if the layer were implemented at the CUDA level).
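
To illustrate the first point, this is roughly what the forward pass could look like with a TensorFlow-style moments op; F.moments (with axes and keepdims arguments) is assumed here and does not exist in MXNet at the time of writing, and the multi_precision branches are omitted for brevity:

# Hypothetical rewrite of hybrid_forward, assuming an F.moments op with
# TensorFlow-like semantics (returns mean and variance over the given axes).
def hybrid_forward(self, F, x, weight, bias):
    x_new = F.reshape(x, (0, self.G, -1))                     # (N,C,H,W) -> (N,G,C//G*H*W)
    mean, var = F.moments(x_new, axes=2, keepdims=True)       # both statistics in one pass
    x_new = F.broadcast_div(F.broadcast_minus(x_new, mean),
                            F.sqrt(var + self.eps)).reshape_like(x)
    return F.broadcast_add(F.broadcast_mul(x_new, weight), bias)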

I think this layer is quite important, as not everyone has plenty of GPUs (if you do, then F.contrib.SyncBatchNorm() will work well).

P.S. A question for the MXNet authors:
There seems to be an op called F.SumSquare() (see: https://github.com/dmlc/gluon-cv/blob/0a699a5ccc21310c7ce41d4737f0de9f54fbf45a/gluoncv/model_zoo/syncbn.py#L206), which I guess is used for calculating the second-order moment. I couldn't find it in MXNet's API; does this op really exist?
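
For reference, the reason a fused sum-of-squares reduction would help: the second-order moment gives the variance in a single pass via var = E[x^2] - E[x]^2. A small NumPy illustration (not an MXNet API):

import numpy as np

# Variance from fused sum / sum-of-squares reductions: var = E[x^2] - E[x]^2.
x = np.random.randn(2, 32, 98).astype(np.float32)      # e.g. (N, G, C//G*H*W)
n = x.shape[-1]
s = x.sum(axis=-1, keepdims=True)
ss = (x * x).sum(axis=-1, keepdims=True)
mean = s / n
var = ss / n - mean ** 2
assert np.allclose(var, x.var(axis=-1, keepdims=True), atol=1e-4)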

@eric-haibin-lin Does the implementation above look OK?

@chi-hung I think the sumSquare op resides in @zhanghang1989 's fork. @zhanghang1989 could you confirm?

@haojin2

Attempting an implementation in the backend...

@chi-hung I think the sumSquare op resides in @zhanghang1989 's fork. @zhanghang1989 could you confirm?

Yes, it lives in a branch of my fork. We implemented it for SyncBN.

The implementation is in #14959 and almost done; only the gradient with respect to the data still needs some work.

@smorrel1, @haojin2 has implemented GroupNorm, so you may try that. Feel free to reopen the issue if you run into any problems.
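
For anyone finding this issue later: once that PR lands, usage should look roughly like the sketch below. The class name and constructor arguments (nn.GroupNorm, num_groups) are assumptions to be checked against the final API docs.

from mxnet import nd
from mxnet.gluon import nn

# Sketch only: assumes the new block is exposed as mxnet.gluon.nn.GroupNorm
# and takes the number of groups as a constructor argument.
gn = nn.GroupNorm(num_groups=32)
gn.initialize()
x = nd.random.normal(shape=(2, 64, 7, 7))
print(gn(x).shape)   # (2, 64, 7, 7)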
