For my application, I am considering updating edge features along with node features. Let's suppose for now that in each step (i.e. each network layer) the edge attributes are updated from their current values and the features of the two nodes the edge connects (e.g. using an MLP).
I have followed the discussion here, which suggests a way of caching the edge attributes. I am also thinking of overriding the forward() method such that it returns both the updated node and edge features, which are then passed into the next layer.
However, I am struggling to come up with an efficient way of implementing the edge attribute update based on the current edge attribute values and those of the two endpoint nodes. In other words, how can I access the relevant node features efficiently? Could pytorch_scatter perhaps come to the rescue here?
Any ideas would be much appreciated. Thanks!
Updating edge features is actually quite easy with the use of edge_index:
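```python
import torch

src, dst = edge_index  # source and target node index of every edge
# Gather each edge's endpoint features and concatenate them with the edge's own features.
edge_repr = torch.cat([x[src], edge_attr, x[dst]], dim=-1)
edge_attr = self.mlp(edge_repr)  # self.mlp maps the concatenation to the new edge features
```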
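To see why this is efficient, here is a small standalone check on a toy graph (the tensors are made-up illustration data). Indexing `x` with `src` and `dst` gathers the endpoint features of every edge in one vectorized operation, so no loop over edges and no scatter is needed for this step.

```python
import torch

# Toy graph with 3 nodes and 3 directed edges: 0->1, 1->2, 2->0 (hypothetical data).
x = torch.tensor([[1.0, 10.0],
                  [2.0, 20.0],
                  [3.0, 30.0]])           # node features, shape [num_nodes, 2]
edge_index = torch.tensor([[0, 1, 2],     # row 0: source node of each edge
                           [1, 2, 0]])    # row 1: target node of each edge
edge_attr = torch.tensor([[0.5], [0.6], [0.7]])  # edge features, shape [num_edges, 1]

src, dst = edge_index
# Advanced indexing copies one row of x per edge, with no Python loop.
x_src = x[src]   # features of each edge's source node, shape [num_edges, 2]
x_dst = x[dst]   # features of each edge's target node, shape [num_edges, 2]
edge_repr = torch.cat([x_src, edge_attr, x_dst], dim=-1)  # shape [num_edges, 5]
print(edge_repr)
```

Row `i` of `edge_repr` is `[x[src[i]], edge_attr[i], x[dst[i]]]`, which is exactly the input an edge MLP needs.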
You can apply this in addition to propagating nodes via MessagePassing.propagate, and return the results as a tuple in forward.
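A minimal sketch of such a layer, written in plain PyTorch so it runs without torch_geometric. The class name `NodeEdgeLayer`, the MLP sizes, and the node update rule (sum the updated edge features arriving at each node with `index_add_`, then apply a node MLP) are all hypothetical choices; `index_add_` here stands in for what `MessagePassing.propagate` would do in PyG. The point it illustrates is that `forward` returns `(x, edge_attr)`, so the next layer receives both.

```python
import torch
from torch import nn


class NodeEdgeLayer(nn.Module):
    """Hypothetical layer that updates both node and edge features.

    Edges: MLP over [x_src, edge_attr, x_dst].
    Nodes: sum of the updated edge features entering each node (index_add_,
    a plain-PyTorch stand-in for MessagePassing.propagate), then an MLP.
    """

    def __init__(self, node_dim: int, edge_dim: int, hidden: int):
        super().__init__()
        self.edge_mlp = nn.Sequential(
            nn.Linear(2 * node_dim + edge_dim, hidden), nn.ReLU(), nn.Linear(hidden, edge_dim)
        )
        self.node_mlp = nn.Sequential(
            nn.Linear(node_dim + edge_dim, hidden), nn.ReLU(), nn.Linear(hidden, node_dim)
        )

    def forward(self, x, edge_index, edge_attr):
        src, dst = edge_index
        # 1) Update every edge from its own features and its two endpoint nodes.
        edge_attr = self.edge_mlp(torch.cat([x[src], edge_attr, x[dst]], dim=-1))
        # 2) Sum the updated edge features arriving at each target node.
        agg = torch.zeros(x.size(0), edge_attr.size(1), dtype=x.dtype, device=x.device)
        agg.index_add_(0, dst, edge_attr)
        # 3) Update nodes from their old features plus the aggregated messages.
        x = self.node_mlp(torch.cat([x, agg], dim=-1))
        # Return both so the next layer receives updated node AND edge features.
        return x, edge_attr


# Usage on random toy data: stack two layers and thread the tuple through.
torch.manual_seed(0)
x = torch.randn(4, 8)
edge_index = torch.tensor([[0, 1, 2, 3], [1, 2, 3, 0]])
edge_attr = torch.randn(4, 3)
layers = nn.ModuleList([NodeEdgeLayer(8, 3, 16) for _ in range(2)])
for layer in layers:
    x, edge_attr = layer(x, edge_index, edge_attr)
print(x.shape, edge_attr.shape)
```

Updating the edges first means the node step already sees the new edge features; swapping the order is an equally valid design and only changes which features each step consumes.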
Hope this helps!
Can't believe I missed something so obvious.
It definitely helped, thanks a lot!
I need to do almost the same thing, but with a subtle difference: I want to compute edge representations of the graph as well as node representations. In each step, I want to average the features of the edges coming into a node and concatenate that average with the features of each of the node's outgoing edges. Then I want to apply a maximum function followed by an MLP to obtain the new representation of each edge. Could somebody help me implement this mechanism?
Do you have any ideas @rusty1s?
I'm not sure if I understand you correctly, but this might do the trick:
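```python
import torch
from torch_scatter import scatter_mean

src, dst = edge_index  # edges point from src to dst
# Average the features of the edges entering each node.
mean_incoming_edge_attr = scatter_mean(edge_attr, dst, dim=0, dim_size=x.size(0))
# Concatenate each edge with the mean incoming-edge features of its source node.
edge_attr = torch.cat([edge_attr, mean_incoming_edge_attr[src]], dim=-1)
...
```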
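A dependency-free sketch of the whole edge update, assuming the PyG convention that `edge_index[0]` holds source nodes and `edge_index[1]` target nodes. The helper `mean_per_node` is hypothetical: it uses `index_add_` plus `bincount` in place of `torch_scatter.scatter_mean` and should behave like it, with nodes that have no incoming edges getting a zero vector. The toy graph and MLP sizes are made up. The "maximum" step from the question is left out because the post does not say what the maximum is taken over; if it means max-pooling over a node's edges, `Tensor.scatter_reduce` with `reduce="amax"` would be the analogous reduction.

```python
import torch
from torch import nn


def mean_per_node(values: torch.Tensor, index: torch.Tensor, num_nodes: int) -> torch.Tensor:
    """Hypothetical plain-PyTorch stand-in for scatter_mean(values, index, dim=0, dim_size=num_nodes).

    Rows of `values` are averaged per entry of `index`; empty groups yield zeros.
    """
    sums = torch.zeros(num_nodes, values.size(1), dtype=values.dtype, device=values.device)
    sums.index_add_(0, index, values)
    counts = torch.bincount(index, minlength=num_nodes).clamp(min=1).to(values.dtype)
    return sums / counts.unsqueeze(-1)


# Toy directed graph (hypothetical data): edges 0->1, 2->1, 1->3.
edge_index = torch.tensor([[0, 2, 1],
                           [1, 1, 3]])
edge_attr = torch.tensor([[1.0, 0.0],
                          [3.0, 2.0],
                          [5.0, 5.0]])
num_nodes = 4
src, dst = edge_index

# Mean of the edges entering each node (grouped by target).
mean_in = mean_per_node(edge_attr, dst, num_nodes)
# Each edge leaves its source node, so pair it with that node's incoming mean.
edge_repr = torch.cat([edge_attr, mean_in[src]], dim=-1)  # shape [num_edges, 4]

# Hypothetical MLP producing the new edge features.
edge_mlp = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))
new_edge_attr = edge_mlp(edge_repr)
print(mean_in)
print(edge_repr)
```

Here node 1 receives edges 0 and 1, so its incoming mean is `[2.0, 1.0]`, and edge 2 (which leaves node 1) becomes `[5.0, 5.0, 2.0, 1.0]` before the MLP.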
Thank you for the code, @rusty1s. It's not exactly what I was looking for, but it was quite helpful and I got the idea. Thanks!