Pytorch_geometric: How to update edges?

Created on 20 Nov 2019 · 2 Comments · Source: rusty1s/pytorch_geometric

Hi,
Just like the node update (and message passing) x_j = f(x_i), I also want to update the edge features e_j = f(node1, node2, e_j), where node1 and node2 are the two nodes that edge e_j connects. I don't want to use line_graph in my case. Is there anything like MessagePassing (https://pytorch-geometric.readthedocs.io/en/latest/notes/create_gnn.html) for edges?

Most helpful comment

To the best of my knowledge, the best way of doing this is using the MetaLayer (depending on your use-case, simply ignore the GlobalModel part).
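To make the data flow concrete, here is a simplified, pure-PyTorch sketch of how a MetaLayer-style block dispatches to its sub-models: the edge model runs first on (source node, target node, edge) triples, and the node model then sees the already-updated edge features. This is an illustrative re-implementation under stated assumptions, not PyG's source; the class name `MiniMetaLayer` is hypothetical, and it assumes `edge_index[0]` holds source nodes and `edge_index[1]` target nodes.

```python
import torch


class MiniMetaLayer(torch.nn.Module):
    """Hypothetical, simplified stand-in for torch_geometric.nn.MetaLayer.

    Runs edge_model, then node_model, then global_model (each optional),
    passing the updated edge features on to the node model.
    """

    def __init__(self, edge_model=None, node_model=None, global_model=None):
        super().__init__()
        self.edge_model = edge_model
        self.node_model = node_model
        self.global_model = global_model

    def forward(self, x, edge_index, edge_attr=None, u=None, batch=None):
        row, col = edge_index  # assumed: row = source node ids, col = target node ids
        if self.edge_model is not None:
            # Each edge sees its two endpoint features and its own attributes.
            edge_batch = None if batch is None else batch[row]
            edge_attr = self.edge_model(x[row], x[col], edge_attr, u, edge_batch)
        if self.node_model is not None:
            # The node model receives the *updated* edge features.
            x = self.node_model(x, edge_index, edge_attr, u, batch)
        if self.global_model is not None:
            u = self.global_model(x, edge_index, edge_attr, u, batch)
        return x, edge_attr, u
```

With only an edge model supplied, the node features pass through unchanged and the global state stays `None`, which is the "ignore the GlobalModel part" case described above.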

For example, for my recent work, I used something close to the example from MetaLayer:

import torch
from torch import nn
import torch_geometric.nn as geom_nn
from torch_scatter import scatter_mean


class EdgeModel(torch.nn.Module):
    def __init__(self, n_features, n_edge_features, hiddens, n_targets, residuals):
        super().__init__()
        self.residuals = residuals  # residual requires n_targets == n_edge_features
        self.edge_mlp = nn.Sequential(
            nn.Linear(2 * n_features + n_edge_features, hiddens),
            nn.ReLU(),
            nn.Linear(hiddens, n_targets),
        )

    def forward(self, src, dest, edge_attr, u=None, batch=None):
        # src, dest: features of the two endpoint nodes of each edge, [E, n_features]
        # edge_attr: current edge features, [E, n_edge_features]
        out = torch.cat([src, dest, edge_attr], dim=1)
        out = self.edge_mlp(out)
        if self.residuals:
            out = out + edge_attr
        return out


class NodeModel(torch.nn.Module):
    def __init__(self, n_features, n_edge_features, hiddens, n_targets, residuals):
        super().__init__()
        self.residuals = residuals  # residual requires n_targets == n_features
        # Per-edge message from a neighbour's features and the (updated) edge features.
        self.node_mlp_1 = nn.Sequential(
            nn.Linear(n_features + n_edge_features, hiddens),
            nn.ReLU(),
            nn.Linear(hiddens, hiddens),  # outputs `hiddens`, matching node_mlp_2's input
        )
        # Combines each node's own features with its aggregated messages.
        self.node_mlp_2 = nn.Sequential(
            nn.Linear(hiddens + n_features, hiddens),
            nn.ReLU(),
            nn.Linear(hiddens, n_targets),
        )

    def forward(self, x, edge_index, edge_attr, u=None, batch=None):
        row, col = edge_index
        out = torch.cat([x[col], edge_attr], dim=1)
        out = self.node_mlp_1(out)
        # Average the per-edge messages at each node (grouped by `row`).
        out = scatter_mean(out, row, dim=0, dim_size=x.size(0))
        out = torch.cat([x, out], dim=1)
        out = self.node_mlp_2(out)
        if self.residuals:
            out = out + x
        return out


def build_layer(n_hiddens, residuals):
    # Edge model and node model only; no global model is used.
    return geom_nn.MetaLayer(
        edge_model=EdgeModel(n_hiddens, n_hiddens, n_hiddens, n_hiddens, residuals=residuals),
        node_model=NodeModel(n_hiddens, n_hiddens, n_hiddens, n_hiddens, residuals=residuals),
    )

This can then be called using x, edge_attr, _ = layer(x, edge_index, edge_attr=edge_attr).

Note that this didn't use or produce a global state, but did update the edge attributes.
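The node update above relies on `scatter_mean` from the `torch_scatter` package to average the per-edge messages at each node. To see what that step computes, or to avoid the extra dependency, a plain-PyTorch equivalent can be sketched with `Tensor.index_add_`. The helper name `scatter_mean_plain` is hypothetical; nodes that receive no message get zeros, which as far as I know matches `scatter_mean` when `dim_size` pads the output.

```python
import torch


def scatter_mean_plain(src, index, dim_size):
    """Mean of the rows of `src` grouped by `index` (hypothetical helper).

    src:   [E, F] per-edge messages
    index: [E]    node id each message is aggregated into (int64)
    Returns [dim_size, F]; rows with no incoming message are zero.
    """
    out = torch.zeros(dim_size, src.size(1), dtype=src.dtype, device=src.device)
    out.index_add_(0, index, src)  # per-node sum of messages
    count = torch.zeros(dim_size, dtype=src.dtype, device=src.device)
    count.index_add_(0, index, torch.ones_like(index, dtype=src.dtype))
    # Clamp avoids division by zero for isolated nodes; their sum is already 0.
    return out / count.clamp(min=1).unsqueeze(1)
```

In `NodeModel.forward`, `scatter_mean(out, row, dim=0, dim_size=x.size(0))` could then be replaced by `scatter_mean_plain(out, row, x.size(0))`.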

Great answer! Following up on this, updating edge features is quite trivial without any helper functions, e.g.:

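# self.mlp: any module mapping 2 * F_node + F_edge input features to the new edge size
row, col = edge_index  # source and target node of each edge
new_edge_attr = self.mlp(torch.cat([x[row], x[col], edge_attr], dim=-1))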

Providing a more elegant API in analogy to MessagePassing is interesting though. Will think about it!
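The two-line recipe above can be wrapped into a small, self-contained module. This is a sketch: the class name `EdgeUpdate`, the hidden size, and the optional residual connection are illustrative choices, not an existing PyG API. It assumes `edge_index[0]` holds source nodes and `edge_index[1]` target nodes, as in the snippet.

```python
import torch
from torch import nn


class EdgeUpdate(nn.Module):
    """Computes e_k' = MLP([x_src, x_dst, e_k]) for every edge k (hypothetical module)."""

    def __init__(self, node_dim, edge_dim, hidden_dim, residual=False):
        super().__init__()
        self.residual = residual
        self.mlp = nn.Sequential(
            nn.Linear(2 * node_dim + edge_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, edge_dim),  # keeps edge_dim so a residual is possible
        )

    def forward(self, x, edge_index, edge_attr):
        row, col = edge_index  # source and target node of each edge
        new_edge_attr = self.mlp(torch.cat([x[row], x[col], edge_attr], dim=-1))
        if self.residual:
            new_edge_attr = new_edge_attr + edge_attr
        return new_edge_attr
```

Unlike MessagePassing, nothing is aggregated here: the output has one row per edge, so it can be fed straight back in as `edge_attr` for the next layer.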
