Is there a way to do instance normalization using batch indices after a GAT layer?
Hello @Magnetar99,
I remember that the graph indices, and the membership of nodes and edges (which node and which edge belongs to which graph), can be found with torch_geometric.data.Batch.
Perhaps just adding a torch.nn.InstanceNorm1d will work.
Batching indices are the key to the issue! Ah, I see: what you are worried about is torch_geometric.data.Batch. If that Batch can show you the node membership (which node and which edge belongs to which graph), then this problem can be solved.
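For reference, the node-to-graph mapping that torch_geometric.data.Batch keeps in its batch attribute is just an index vector. A minimal sketch in plain PyTorch of what it looks like, assuming two hypothetical graphs with 3 and 2 nodes whose node features are stacked along dimension 0 (Batch builds this vector itself when it collates graphs; the sizes and edges here are made up):

```python
import torch

# Hypothetical node counts: graph 0 has 3 nodes, graph 1 has 2 nodes
num_nodes_per_graph = torch.tensor([3, 2])

# Node features of both graphs stacked into a single [N, D] tensor (N = 5, D = 4)
x = torch.randn(int(num_nodes_per_graph.sum()), 4)

# batch[i] is the index of the graph that node i belongs to
graph_ids = torch.arange(num_nodes_per_graph.numel())
batch = torch.repeat_interleave(graph_ids, num_nodes_per_graph)
print(batch.tolist())  # [0, 0, 0, 1, 1]

# Edge membership follows from the source node of each edge (hypothetical edges)
edge_index = torch.tensor([[0, 1, 3], [1, 2, 4]])
print(batch[edge_index[0]].tolist())  # [0, 0, 1]
```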
torch.nn.InstanceNorm1d (Sorry, I cannot think of the reason why an InstanceNorm1d is used here... I have no theoretical knowledge, I can only explain the mechanics):
The GAT layer (torch_geometric.nn.GATConv) takes x and edge_index as input and returns a "processed" x. The edge_index, i.e. the structure of the graph, stays the same as the input edge_index, since GATConv does no pooling (graph coarsening). See the source code of the forward method in GATConv.
After that, the x returned by GATConv is still a torch.Tensor (probably of shape [N, D], where N is the number of nodes in the (single) graph and D is the dimension of each node's features).
Then, torch.nn.InstanceNorm1d can be applied to that x.
yours sincerely,
@wmf1997
Hi, and thank you @WMF1997 for helping out. Note that, sadly, you cannot just use InstanceNorm1d. This would actually be equivalent to applying batch norm, since the batch and node dimensions are shared: the nodes of all graphs are stacked along the first dimension. Hence, you need to compute the mean and std of each example (graph) independently. The good thing is that this is quite trivial to achieve with the help of torch-scatter:
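from torch_scatter import scatter_mean, scatter_std
# Per-graph statistics: one row per graph in the batch
mean = scatter_mean(x, batch, dim=0)
std = scatter_std(x, batch, dim=0)
# Index with batch to broadcast back to the nodes; eps guards against a zero std
out = (x - mean[batch]) / (std[batch] + 1e-5)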
We should provide our own operator for this kind of normalization :)
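If torch-scatter is not available, the same per-graph statistics can be gathered with plain PyTorch. A minimal sketch, assuming float node features x of shape [N, D] and a long batch vector; graph_instance_norm is a hypothetical helper name, and it uses the biased variance with eps inside the square root:

```python
import torch


def graph_instance_norm(x: torch.Tensor, batch: torch.Tensor, eps: float = 1e-5) -> torch.Tensor:
    """Normalize every channel of every graph independently (hypothetical helper).

    x:     [N, D] float node features of all graphs, stacked along dim 0
    batch: [N] long tensor with the graph index of every node
    """
    num_graphs = int(batch.max()) + 1
    # Nodes per graph, shape [B, 1]; clamp avoids dividing by zero for unused graph ids
    count = torch.bincount(batch, minlength=num_graphs).clamp(min=1).to(x.dtype).view(-1, 1)
    # Per-graph, per-channel mean, shape [B, D]
    total = torch.zeros(num_graphs, x.size(1), dtype=x.dtype, device=x.device)
    mean = total.index_add_(0, batch, x) / count
    # Per-graph, per-channel biased variance, shape [B, D]
    diff = x - mean[batch]
    var = torch.zeros_like(mean).index_add_(0, batch, diff * diff) / count
    # mean[batch] and var[batch] broadcast each graph's statistics back to its nodes
    return diff / torch.sqrt(var[batch] + eps)


torch.manual_seed(0)
x = torch.randn(5, 4, dtype=torch.float64)
batch = torch.tensor([0, 0, 0, 1, 1])
out = graph_instance_norm(x, batch)
```

Statistics are never mixed across graph boundaries here. The only difference from the scatter_std version is the variance estimate: scatter_std uses the unbiased estimate by default.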
BatchNorm and InstanceNorm modules have now been added to the master branch, see here.
@rusty1s Thanks for the amazing work. I'm having a little trouble understanding the meaning of InstanceNorm as implemented in pytorch_geometric.
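from torch_scatter import scatter_add
from torch_geometric.utils import degree

# batch: [N] graph index of every node; batch_size: number of graphs
batch_size = int(batch.max()) + 1
count = degree(batch, batch_size, dtype=x.dtype).view(-1, 1)
tmp = scatter_add(x, batch, dim=0, dim_size=batch_size)
mean = tmp / count.clamp(min=1)
tmp = (x - mean[batch])
tmp = scatter_add(tmp * tmp, batch, dim=0, dim_size=batch_size)
var = tmp / count.clamp(min=1)
unbiased_var = tmp / (count - 1).clamp(min=1)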
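To make these quantities concrete, here is a small trace in plain PyTorch. It assumes one feature channel and two hypothetical graphs with 3 nodes and 1 node, with index_add_ standing in for degree and scatter_add:

```python
import torch

# One feature channel: graph 0 has node values 1, 2, 3; graph 1 has a single node with value 5
x = torch.tensor([[1.0], [2.0], [3.0], [5.0]])
batch = torch.tensor([0, 0, 0, 1])
batch_size = 2

# Nodes per graph, the quantity degree(batch, batch_size) returns, shape [B, 1]
count = torch.zeros(batch_size).index_add_(0, batch, torch.ones(4)).view(-1, 1)

# Per-graph sum and mean of the node features
tmp = torch.zeros(batch_size, 1).index_add_(0, batch, x)
mean = tmp / count.clamp(min=1)

# Per-graph sum of squared deviations, then biased and unbiased variance
tmp = x - mean[batch]
tmp = torch.zeros(batch_size, 1).index_add_(0, batch, tmp * tmp)
var = tmp / count.clamp(min=1)
unbiased_var = tmp / (count - 1).clamp(min=1)
```

Graph 0 gives count 3, mean 2, var 2/3 and unbiased_var 1. The single-node graph 1 gets var 0 and unbiased_var 0, because (count - 1).clamp(min=1) turns its zero denominator into 1 instead of producing a NaN.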
Is count the number of vertices in a batch?
For each batch, do you compute the sum of the node features, and then the mean & std?
Then each batch is normalized separately?
If I were to implement LayerNorm instead, would I simply compute the sum across the feature dimension as well, and modify the std to account for the channels too?
Could you clarify?
Instance normalization normalizes each example (here, each graph) independently. For batch_size=1, BatchNorm and InstanceNorm therefore compute the same thing. This is also tested here. I think you are talking about LayerNorm? For that, you can simply use the PyTorch implementation.
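The batch_size=1 case can also be checked numerically. This is a minimal sketch that assumes training-mode statistics (biased variance, no running averages, no affine parameters), so that torch.nn.functional.batch_norm can stand in for BatchNorm. graph_instance_norm is a hypothetical helper, not the PyG module:

```python
import torch
import torch.nn.functional as F


def graph_instance_norm(x, batch, eps=1e-5):
    # Per-graph, per-channel mean and biased variance, computed with index_add_
    num_graphs = int(batch.max()) + 1
    count = torch.bincount(batch, minlength=num_graphs).clamp(min=1).to(x.dtype).view(-1, 1)
    mean = torch.zeros(num_graphs, x.size(1), dtype=x.dtype).index_add_(0, batch, x) / count
    diff = x - mean[batch]
    var = torch.zeros_like(mean).index_add_(0, batch, diff * diff) / count
    return diff / torch.sqrt(var[batch] + eps)


def batch_norm(x, eps=1e-5):
    # Training-mode batch norm: statistics over all N nodes, no running stats, no affine
    return F.batch_norm(x, None, None, training=True, eps=eps)


torch.manual_seed(0)
x = torch.randn(6, 4, dtype=torch.float64)

# A single graph: both normalizations see exactly the same set of nodes
one_graph = torch.zeros(6, dtype=torch.long)
print(torch.allclose(graph_instance_norm(x, one_graph), batch_norm(x)))  # True

# Two graphs: InstanceNorm uses each graph's own statistics, so the results differ
two_graphs = torch.tensor([0, 0, 0, 1, 1, 1])
print(torch.allclose(graph_instance_norm(x, two_graphs), batch_norm(x)))  # False
```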
Thanks for the pointer! My confusion came from thinking about the analogy between graph & image, and how graphs were flattened.
But your comment made me realize that InstanceNorm indeed normalizes each graph in the batch separately. I don't think LayerNorm would work directly, because its affine parameters have an extra dimension.
Oh, you are right. LayerNorm needs special treatment, too.
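Here is why the stock module does not apply. torch.nn.LayerNorm(D) on an [N, D] node matrix normalizes every node over its own D channels, so it never looks at the other nodes of the same graph. A quick check, assuming random features for a single hypothetical graph with 5 nodes:

```python
import torch

torch.manual_seed(0)
x = torch.randn(5, 4)  # 5 nodes, 4 channels

# normalized_shape=4: statistics are computed over the last dimension only
layer_norm = torch.nn.LayerNorm(4, elementwise_affine=False)
out = layer_norm(x)

# Each node (row) ends up with zero mean over its own channels
row_means = out.mean(dim=1)
```

A graph-level LayerNorm instead needs one mean and variance over all 5 x 4 entries. With torch.nn.LayerNorm that would mean normalized_shape=(num_nodes, 4), whose weight and bias are tied to a fixed number of nodes.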
Here is an implementation of LayerNorm, based on the InstanceNorm implementation. It might be useful when the relative channel statistics are important to preserve (which rules out InstanceNorm) but the batches are too small for BatchNorm, or when training recurrent/NLP models.
A few design choices:
import torch
from torch import nn
from torch_scatter import scatter_add
from torch_geometric.utils import degree


class LayerNorm(nn.Module):
    """Graph-wise LayerNorm: normalizes each graph over all of its nodes and channels."""

    def __init__(self, in_channels, eps=1e-5, affine=False):
        super(LayerNorm, self).__init__()
        self.in_channels = in_channels
        self.eps = eps
        # Weight and bias always exist; without affine they stay frozen at 1 and 0
        self.weight = torch.nn.Parameter(torch.ones(1, in_channels))
        self.bias = torch.nn.Parameter(torch.zeros(1, in_channels))
        if not affine:
            self.weight.requires_grad = False
            self.bias.requires_grad = False

    def forward(self, x, batch=None):
        """x: [N, C] node features, batch: [N] graph index of every node."""
        if batch is None:
            # Treat all nodes as a single graph
            batch = torch.zeros(x.size(0), dtype=torch.long, device=x.device)
        batch_size = batch.max().item() + 1

        # Entries per graph: nodes in the graph times channels, shape [B, 1]
        count = degree(batch, batch_size, dtype=x.dtype).view(-1, 1).clamp(min=1) * x.size(1)

        # One scalar mean per graph over all of its nodes and channels, shape [B, 1]
        tmp = scatter_add(x, batch, dim=0, dim_size=batch_size)
        mean = tmp.sum(dim=1, keepdim=True) / count

        # Biased variance per graph, shape [B, 1]
        mean_diff = x - mean[batch]
        tmp = scatter_add(mean_diff * mean_diff, batch, dim=0, dim_size=batch_size)
        var = tmp.sum(dim=1, keepdim=True) / count

        # eps keeps graphs with constant features from dividing by zero
        out = mean_diff / torch.sqrt(var[batch] + self.eps)
        return out * self.weight + self.bias
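The statistics this module computes can be cross-checked without torch-scatter. Below is a minimal sketch in plain PyTorch, assuming the non-affine case (weight 1, bias 0) and the same biased variance; graph_layer_norm_reference is a hypothetical helper. Every graph should come out with zero mean and (up to eps) unit variance over all of its nodes and channels together. Because each graph gets a single shift and scale, the relative offsets between its channels are kept.

```python
import torch


def graph_layer_norm_reference(x, batch, eps=1e-5):
    """Per-graph LayerNorm over all nodes and channels (hypothetical reference helper)."""
    out = torch.empty_like(x)
    for g in batch.unique():
        mask = batch == g
        xg = x[mask]
        # One scalar mean and biased variance per graph, over N_g x D entries
        mean = xg.mean()
        var = ((xg - mean) ** 2).mean()
        out[mask] = (xg - mean) / torch.sqrt(var + eps)
    return out


torch.manual_seed(0)
x = torch.randn(7, 3, dtype=torch.float64)
batch = torch.tensor([0, 0, 0, 0, 1, 1, 1])
out = graph_layer_norm_reference(x, batch)
```

The Python loop over graphs makes this suitable for testing rather than training.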
Are you interested in sending a PR?
I can do so. Aside from the bias/weight gradients, does the code look correct?
Yes, this looks super good already :)