Hi,
Just as node features are updated (via message passing) by x_j = f(x_i), I also want to update the edge features: e_j = f(node1, node2, e_j), where node1 and node2 are the two nodes that edge e_j connects. I don't want to use line_graph in my case. Is there anything like MessagePassing (https://pytorch-geometric.readthedocs.io/en/latest/notes/create_gnn.html) for edges?
To the best of my knowledge, the best way of doing this is using the MetaLayer (depending on your use-case, simply ignore the GlobalModel part).
For example, for my recent work, I used something close to the example from MetaLayer:
import torch
from torch import nn
import torch_geometric.nn as geom_nn  # alias assumed from the geom_nn.MetaLayer call below
from torch_scatter import scatter_mean  # assumed source of scatter_mean


class EdgeModel(torch.nn.Module):
    def __init__(self, n_features, n_edge_features, hiddens, n_targets, residuals):
        super().__init__()
        self.residuals = residuals
        self.edge_mlp = nn.Sequential(
            nn.Linear(2 * n_features + n_edge_features, hiddens),
            nn.ReLU(),
            nn.Linear(hiddens, n_targets),
        )

    def forward(self, src, dest, edge_attr, u=None, batch=None):
        # src, dest: features of the two endpoint nodes of each edge, shape [E, n_features]
        out = torch.cat([src, dest, edge_attr], 1)
        out = self.edge_mlp(out)
        if self.residuals:
            # Residual connection; requires n_targets == n_edge_features
            out = out + edge_attr
        return out


class NodeModel(torch.nn.Module):
    def __init__(self, n_features, n_edge_features, hiddens, n_targets, residuals):
        super().__init__()
        self.residuals = residuals
        self.node_mlp_1 = nn.Sequential(
            nn.Linear(n_features + n_edge_features, hiddens),
            nn.ReLU(),
            # Output width must equal the `hiddens` that node_mlp_2 expects
            nn.Linear(hiddens, hiddens),
        )
        self.node_mlp_2 = nn.Sequential(
            nn.Linear(hiddens + n_features, hiddens),
            nn.ReLU(),
            nn.Linear(hiddens, n_targets),
        )

    def forward(self, x, edge_index, edge_attr, u, batch):
        row, col = edge_index
        # Per-edge message: features of the node at `col` plus the edge features
        out = torch.cat([x[col], edge_attr], dim=1)
        out = self.node_mlp_1(out)
        # Average the messages of all edges that share the same `row` node
        out = scatter_mean(out, row, dim=0, dim_size=x.size(0))
        out = torch.cat([x, out], dim=1)
        out = self.node_mlp_2(out)
        if self.residuals:
            # Residual connection; requires n_targets == n_features
            out = out + x
        return out


# Method of the surrounding model class (not shown); `batchnorm` is unused in this excerpt.
def build_layer(self, n_hiddens, batchnorm, residuals):
    return geom_nn.MetaLayer(
        edge_model=EdgeModel(n_hiddens, n_hiddens, n_hiddens, n_hiddens, residuals=residuals),
        node_model=NodeModel(n_hiddens, n_hiddens, n_hiddens, n_hiddens, residuals=residuals),
    )
This can then be called using x, edge_attr, _ = layer(x, edge_index, edge_attr=edge_attr).
Note that this didn't use or produce a global state, but did update the edge attributes.
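As an aside on the aggregation step in NodeModel.forward: scatter_mean there is presumably the function from the torch_scatter package. For anyone who wants to see what that call computes, here is a minimal plain-PyTorch sketch. The helper name scatter_mean_rows and the toy tensors are made up for illustration, and it only handles the 2-D, dim=0 case used above.

```python
import torch


def scatter_mean_rows(src: torch.Tensor, index: torch.Tensor, dim_size: int) -> torch.Tensor:
    """Average the rows of `src` that share the same value in `index` (hypothetical helper)."""
    # Sum the rows of `src` into the output slot given by `index`
    sums = torch.zeros(dim_size, src.size(1), dtype=src.dtype, device=src.device)
    sums.index_add_(0, index, src)
    # Count how many rows landed in each slot
    counts = torch.zeros(dim_size, dtype=src.dtype, device=src.device)
    counts.index_add_(0, index, torch.ones_like(index, dtype=src.dtype))
    # Slots that received nothing keep a zero vector (count clamped to 1)
    return sums / counts.clamp(min=1).unsqueeze(1)


# Toy example: 4 per-edge messages aggregated onto 3 nodes; node 2 receives none
row = torch.tensor([0, 0, 1, 1])
messages = torch.tensor([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]])
agg = scatter_mean_rows(messages, row, dim_size=3)
print(agg)
```

Node 0 gets the mean of the first two messages, node 1 the mean of the last two, and node 2 stays at zero because no edge points to it.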
Great answer! Following up on this, updating edge features is quite trivial without any helper functions, e.g.:
row, col = edge_index
new_edge_attr = self.mlp(torch.cat([x[row], x[col], edge_attr], dim=-1))
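For completeness, here is one way the two lines above could be wrapped into a standalone module. This is only a sketch: the class name EdgeUpdate, the MLP layout, and the toy sizes are assumptions for illustration, not part of PyTorch Geometric.

```python
import torch
from torch import nn


class EdgeUpdate(nn.Module):
    """Hypothetical module computing e_j = f(node1, node2, e_j) for every edge."""

    def __init__(self, n_features: int, n_edge_features: int, hiddens: int):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(2 * n_features + n_edge_features, hiddens),
            nn.ReLU(),
            nn.Linear(hiddens, n_edge_features),
        )

    def forward(self, x, edge_index, edge_attr):
        # edge_index has shape [2, E]; its two rows hold the endpoints of each edge
        row, col = edge_index
        return self.mlp(torch.cat([x[row], x[col], edge_attr], dim=-1))


# Toy graph: 4 nodes with 8 features each, 4 edges with 3 features each
x = torch.randn(4, 8)
edge_index = torch.tensor([[0, 1, 2, 3], [1, 2, 3, 0]])
edge_attr = torch.randn(4, 3)

layer = EdgeUpdate(n_features=8, n_edge_features=3, hiddens=16)
new_edge_attr = layer(x, edge_index, edge_attr)
print(new_edge_attr.shape)
```

The output keeps the edge feature width, so the layer can be stacked or combined with a residual connection on edge_attr.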
Providing a more elegant API in analogy to MessagePassing is interesting though. Will think about it!