Pytorch_geometric: Instance Normalization

Created on 18 Sep 2019  ·  11 Comments  ·  Source: rusty1s/pytorch_geometric

Is there a way to do instance normalization using batch indices after a GAT layer?

All 11 comments

Hello @Magnetar99 :

  1. I remember that the graph indices, and which graph each node and edge belongs to, can be found via torch_geometric.data.Batch.

  2. Perhaps just adding a torch.nn.InstanceNorm1d would work.

  3. If I write up the code (a proper example), I will add it to the end of this reply~

Some explanations (perhaps not that good):

batching_indices is the key to the issue!

Ah.... I see! What you are worried about is torch_geometric.data.Batch! If that Batch can show you the node assignments (which node and which edge belongs to which graph), then this problem can be solved.
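
To make the idea of node "belongings" concrete, here is a small sketch of the kind of batch assignment vector being described. The graph sizes are made up for illustration, and the vector is built by hand in plain PyTorch rather than taken from a real Batch object.

```python
import torch

# Hypothetical mini-batch of three graphs with 3, 2 and 4 nodes.
nodes_per_graph = torch.tensor([3, 2, 4])

# batch[i] is the index of the graph that node i belongs to.
batch = torch.repeat_interleave(torch.arange(len(nodes_per_graph)), nodes_per_graph)
print(batch)  # tensor([0, 0, 0, 1, 1, 2, 2, 2, 2])

# Indices of the nodes that belong to graph 1.
nodes_of_graph_1 = (batch == 1).nonzero().view(-1)
print(nodes_of_graph_1)  # tensor([3, 4])
```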

torch.nn.InstanceNorm1d

(Sorry, I cannot think of the reason why an InstanceNorm1d is used here... I have no theoretical knowledge, I am just explaining...)

The GAT layer (torch_geometric.nn.GATConv) takes x and edge_index as input and returns a "processed" x. (edge_index, i.e. the graph structure, stays the same as the input's edge_index, since no pooling (graph cut) is defined in GATConv.) See the source code of the forward method in GATConv.

After that, the x returned by GATConv is still a torch.Tensor, perhaps of shape [N, D], where N is the number of nodes in the graph (the only graph~) and D is the dimension of one node's feature.

Then, torch.nn.InstanceNorm1d can be applied to that x.

Yours sincerely,
@wmf1997

Hi, and thank you @WMF1997 for helping out. Note that, unfortunately, you cannot simply use InstanceNorm1d. This would actually be equivalent to applying batch norm, since all graphs in a mini-batch are stacked along the node dimension. Hence, you need to compute the mean and std of each example independently. The good news is that this is quite easy to achieve with the help of torch-scatter:

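from torch_scatter import scatter_mean, scatter_std

# batch[i] is the index of the graph that node i belongs to
mean = scatter_mean(x, batch, dim=0)
std = scatter_std(x, batch, dim=0)
out = (x - mean[batch]) / std[batch]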

We should provide our own operator for this kind of normalization :)
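
For readers without torch-scatter installed, the same per-graph statistics can be sketched in plain PyTorch. This is only an illustrative sketch: the helper name `instance_norm_per_graph` and the `eps` term are assumptions added here, not part of the snippet above, and it uses the biased variance.

```python
import torch


def instance_norm_per_graph(x, batch, eps=1e-5):
    """Normalize node features separately for every graph in a mini-batch.

    x:     [N, C] node features of all graphs stacked together
    batch: [N] long tensor, batch[i] = index of the graph that node i belongs to
    """
    num_graphs = int(batch.max()) + 1
    # Number of nodes per graph, shaped [num_graphs, 1] for broadcasting.
    count = torch.bincount(batch, minlength=num_graphs).clamp(min=1)
    count = count.to(x.dtype).view(-1, 1)

    # Per-graph, per-channel mean: sum node features per graph, divide by count.
    mean = torch.zeros(num_graphs, x.size(1), dtype=x.dtype, device=x.device)
    mean.index_add_(0, batch, x)
    mean = mean / count

    # Per-graph, per-channel (biased) variance.
    centered = x - mean[batch]
    var = torch.zeros_like(mean)
    var.index_add_(0, batch, centered * centered)
    var = var / count

    return centered / torch.sqrt(var[batch] + eps)


x = torch.tensor([[1.0, 10.0], [3.0, 30.0], [100.0, 0.0], [300.0, 4.0]])
batch = torch.tensor([0, 0, 1, 1])  # two graphs with two nodes each
out = instance_norm_per_graph(x, batch)
```

After this call, the rows belonging to each graph have roughly zero mean and unit variance per channel, even though the two graphs live on very different scales.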

BatchNorm and InstanceNorm modules are now added to the master branch, see here.

@rusty1s Thanks for the amazing work. I'm having a little trouble understanding the meaning of InstanceNorm as implemented in pytorch_geometric.

count = degree(batch, batch_size, dtype=x.dtype).view(-1, 1)
tmp = scatter_add(x, batch, dim=0, dim_size=batch_size)
mean = tmp / count.clamp(min=1)
tmp = (x - mean[batch])
tmp = scatter_add(tmp * tmp, batch, dim=0, dim_size=batch_size)
var = tmp / count.clamp(min=1)
unbiased_var = tmp / (count - 1).clamp(min=1)

Is count the number of vertices in a batch?
For each batch, do you compute the sum of the node features, then compute the mean & std?
Then, do you normalize each batch separately?

If I were to implement LayerNorm instead, would I simply sum across the feature dimension as well, and modify the std to account for the channels?
Could you clarify?

Instance normalization normalizes each example independently. For batch_size=1, both BatchNorm and InstanceNorm compute the same thing. This is also tested here. I think you are talking about LayerNorm? For that, you can simply use the PyTorch implementation.
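
The statement about batch_size=1 can be checked with a quick sketch in plain PyTorch. It does not use the pytorch_geometric modules: the per-graph statistics are computed by hand with boolean masks (the helper `per_graph_norm` is a hypothetical name), and `F.batch_norm` in training mode without running statistics stands in for batch normalization.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
x = torch.randn(6, 4)  # 6 nodes, 4 channels
eps = 1e-5


def per_graph_norm(x, batch, eps):
    # Normalize every graph with its own per-channel mean and biased variance.
    out = torch.empty_like(x)
    for g in batch.unique():
        mask = batch == g
        mean = x[mask].mean(dim=0, keepdim=True)
        var = x[mask].var(dim=0, unbiased=False, keepdim=True)
        out[mask] = (x[mask] - mean) / torch.sqrt(var + eps)
    return out


# Batch norm over all nodes (training mode, no running statistics, no affine).
bn = F.batch_norm(x, None, None, training=True, eps=eps)

# One graph in the batch: per-graph statistics are the statistics of all nodes.
single = torch.zeros(6, dtype=torch.long)
same = torch.allclose(per_graph_norm(x, single, eps), bn, atol=1e-5)

# Two graphs in the batch: each graph now gets its own statistics.
two = torch.tensor([0, 0, 0, 1, 1, 1])
differs = not torch.allclose(per_graph_norm(x, two, eps), bn, atol=1e-5)
```

With a single graph the two results agree; once the batch holds more than one graph, the per-graph result departs from plain batch norm.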

Thanks for the pointer! My confusion came from thinking about the analogy between graph & image, and how graphs were flattened.

But your comment made me realize that InstanceNorm is indeed normalizing separately per example in the batch. I don't think LayerNorm would work directly, due to the affine parameters having an extra dimension.

Oh, you are right. LayerNorm needs special treatment, too.
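
One way to see the difference: torch.nn.LayerNorm(C) applied to an [N, C] node feature matrix normalizes each node's feature vector on its own, whereas a graph-level variant pools its statistics over all nodes and channels of a graph. The following sketch uses made-up data and a single graph, and is meant as an illustration only.

```python
import torch

torch.manual_seed(0)
x = torch.randn(5, 4) * 3 + 2  # 5 nodes of one graph, 4 channels

# Built-in layer norm: statistics are taken over the channels of each node.
per_node = torch.nn.LayerNorm(4, elementwise_affine=False)(x)
print(per_node.mean(dim=1))  # approximately 0 for every node

# Graph-level layer norm: one mean and one variance for the whole graph.
mean = x.mean()
var = x.var(unbiased=False)
per_graph = (x - mean) / torch.sqrt(var + 1e-5)
print(per_graph.mean())       # approximately 0 for the graph as a whole
print(per_graph.mean(dim=1))  # individual nodes are generally not centered
```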

Here is an implementation of LayerNorm, based on the InstanceNorm implementation. It might be useful in cases where relative channel statistics are important to preserve (so InstanceNorm is not an option), but the batches are too small for BatchNorm, or if you are training recurrent/NLP models.

A few design choices:

  1. Weights and Bias are of shape (1, C), since graphs can be of variable shape
  2. Weights are always applied, but in non-affine mode they are not trainable. In general I try to avoid if conditions in the forward pass. But if there are edge cases (I guess in GANs sometimes people enable then disable gradients etc), I can change it.
  3. Removed a few redundant computations in the InstanceNorm code.

import torch
from torch_scatter import scatter_add
from torch import nn
from torch_geometric.utils import degree


class LayerNorm(nn.Module):
    def __init__(self, in_channels, eps=1e-5, affine=False):
        super(LayerNorm, self).__init__()
        self.in_channels = in_channels
        self.eps = eps
        self.weight = torch.nn.Parameter(torch.ones(1, in_channels))
        self.bias = torch.nn.Parameter(torch.zeros(1, in_channels))
        if not affine:
            self.weight.requires_grad = False
            self.bias.requires_grad = False

    def forward(self, x, batch=None):
        """Normalize x over all nodes and channels of each graph given by batch."""
        if batch is None:
            batch = torch.zeros(x.size(0), dtype=torch.long, device=x.device)

        batch_size = batch.max().item() + 1

        count = degree(batch, batch_size, dtype=x.dtype).view(-1, 1).clamp(min=1) * x.shape[1]
        tmp = scatter_add(x, batch, dim=0, dim_size=batch_size)
        mean = tmp.sum(dim=1, keepdim=True) / count

        mean_diff = (x - mean[batch])
        tmp = scatter_add(mean_diff * mean_diff, batch, dim=0, dim_size=batch_size).sum(dim=1, keepdim=True)
        var = tmp / count

        # Add eps inside the square root to avoid dividing by zero for constant graphs.
        out = (mean_diff / torch.sqrt(var[batch] + self.eps)) * self.weight + self.bias
        return out
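
As a sanity check for a module like the one above, the expected output can be computed with a slow but explicit loop over graphs. This is a sketch in plain PyTorch: `reference_graph_layer_norm` is a hypothetical helper name, it leaves out the weight and bias, and it adds `eps` inside the square root, which is an assumption about the intended behaviour.

```python
import torch


def reference_graph_layer_norm(x, batch, eps=1e-5):
    # For every graph, pool mean and biased variance over all of its nodes
    # and all channels, then normalize the graph's node features with them.
    out = torch.empty_like(x)
    for g in range(int(batch.max()) + 1):
        mask = batch == g
        if not mask.any():
            continue  # skip graph indices that have no nodes
        xg = x[mask]
        mean = xg.mean()
        var = xg.var(unbiased=False)
        out[mask] = (xg - mean) / torch.sqrt(var + eps)
    return out


x = torch.tensor([[1.0, 3.0], [5.0, 7.0], [10.0, 10.0], [30.0, 30.0], [20.0, 20.0]])
batch = torch.tensor([0, 0, 1, 1, 1])  # graph 0 has 2 nodes, graph 1 has 3
expected = reference_graph_layer_norm(x, batch)
```

Comparing `expected` with the module's output via torch.allclose on a few random inputs would be a cheap regression test.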

Are you interested in sending a PR?

I can do so. Aside from the bias/weight gradients, does the code look correct?

Yes, this looks super good already :)
