Pytorch_geometric: Use PyTorch Geometric to implement Tree-LSTM.

Created on 11 Mar 2019 · 6 Comments · Source: rusty1s/pytorch_geometric

I wonder whether I can use this framework to implement Tree-LSTM.

Improved Semantic Representations From Tree-Structured Long Short-Term Memory Networks [ACL 2015] https://arxiv.org/pdf/1503.00075.pdf

It can be implemented in the DGL framework with an extra function, dgl.prop_nodes_topo(g), under which "messages start from leaves of the tree, and propagate/processed upwards until they reach the roots."

https://docs.dgl.ai/en/latest/tutorials/models/2_small_graph/3_tree-lstm.html

I am hoping this framework can reduce the training time of Tree-LSTM, so:
Does this framework already provide an implementation of Tree-LSTM?
If not, does it have a function similar to dgl.prop_nodes_topo(g)?
Or could you give me a brief guide on how to implement it with this framework?

All 6 comments

I also have a similar question.

pytorch_geometric is reported to be faster than DGL, but I am also not sure how to propagate messages sequentially, as Tree-LSTM requires, using pytorch_geometric.

We currently do not provide sequential propagation of messages. I'm sorry! However, given a traversal order, this is easy to implement. E.g., given a (batched) tree

     6
   /   \
  4     5
 / \   / \
0   1 2   3

with traversal order masks

orders = [torch.tensor([1, 1, 1, 1, 0, 0, 0], dtype=torch.bool),
          torch.tensor([0, 0, 0, 0, 1, 1, 0], dtype=torch.bool)]

this can be implemented by the MessagePassing interface via

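from torch_geometric.nn import MessagePassing

class TreeLSTM(MessagePassing):
    def __init__(self):  # constructor arguments (e.g. layer sizes) omitted
        # Sum incoming messages. With flow='target_to_source', messages go
        # from edge_index[1] (child) to edge_index[0] (parent).
        super().__init__(aggr='add', flow='target_to_source')
        # define linear layers here

    def forward(self, x, edge_index, orders):
        for order in orders:
            # keep only the edges whose sending node is active in this step
            mask = order[edge_index[1]].bool()
            x = self.propagate(edge_index[:, mask], x=x)
        return x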

For clarity, I omitted any clutter (linear layers, dropout, ...). I hope this helps. However, I'm not a big fan of treating "regular" structures like trees the same as arbitrary graphs. IMO, this is best implemented with its own dedicated CUDA kernel.
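
The traversal-order masks used above have to come from somewhere. As a rough sketch only (the function name and the (child, parent) edge layout are assumptions for illustration, not part of pytorch_geometric), they could be derived from the edge list in plain Python, in the spirit of a leaves-first topological traversal:

```python
def traversal_order_masks(num_nodes, edges):
    """Return one boolean mask per step, leaves first.

    `edges` is a list of (child, parent) pairs (an assumed layout).
    Mask k marks the nodes that send a message to their parent at step k;
    roots never send, so they appear in no mask.
    """
    parent = {c: p for c, p in edges}
    children_left = [0] * num_nodes        # children not yet processed
    for _, p in edges:
        children_left[p] += 1

    frontier = [v for v in range(num_nodes) if children_left[v] == 0]  # leaves
    masks = []
    while frontier:
        senders = {v for v in frontier if v in parent}
        if not senders:
            break
        masks.append([v in senders for v in range(num_nodes)])
        next_frontier = []
        for v in sorted(senders):
            p = parent[v]
            children_left[p] -= 1
            if children_left[p] == 0:      # all children done: parent is ready
                next_frontier.append(p)
        frontier = next_frontier
    return masks


# The 7-node example tree: leaves 0-3, inner nodes 4 and 5, root 6.
edges = [(0, 4), (1, 4), (2, 5), (3, 5), (4, 6), (5, 6)]
orders = traversal_order_masks(7, edges)
```

Each mask can then be wrapped with torch.tensor(mask, dtype=torch.bool). Note that in an unbalanced tree a parent may receive messages over several steps under this scheme, so a real model would likely need to accumulate them before updating.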

"IMO, this is best implemented with its own dedicated CUDA kernel."

Any recommendations on how to implement it with "its own dedicated CUDA kernel"?

Thanks a lot for your help.

What I mean is that propagation for an N-ary tree need not be defined the same way as propagation in graph neural networks (because we can exploit the fact that every node has the same number of neighbors). DGL can handle regular graphs quite well, but performs worse on general graphs due to its separate mailbox system for different degrees. For regular graphs, it is best to define propagation as a dense [num_nodes, num_neighbors, num_features] mechanism. It might be even better to implement all of this from scratch with the PyTorch Extension API to avoid the O(log N) Python loop for extra speedups. In contrast, if you want to define Child-Sum Tree-LSTMs, the general graph propagation scheme from my previous reply is the way to go (and should be as fast as it can be).
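
To make the dense idea concrete, here is a minimal sketch for a binary tree, assuming a precomputed child-index table and per-level node lists (the names dense_tree_sum, children and levels are hypothetical). The sum over children only stands in for a real Tree-LSTM cell, and the loop over levels is the Python loop mentioned above:

```python
import torch


def dense_tree_sum(x, children, levels):
    """Bottom-up propagation over an N-ary tree using dense gathers.

    x:        [num_nodes, num_features] node features
    children: [num_nodes, N] child indices (rows of leaves are never read)
    levels:   list of 1-D index tensors of internal nodes, lowest level first
    """
    x = x.clone()
    for level in levels:
        child_feats = x[children[level]]   # [len(level), N, num_features]
        x[level] = child_feats.sum(dim=1)  # placeholder for the LSTM cell
    return x


# The 7-node example tree: leaves 0-3, inner nodes 4 and 5, root 6.
x = torch.zeros(7, 2)
x[:4] = torch.tensor([[1.0, 0.0], [2.0, 0.0], [3.0, 1.0], [4.0, 1.0]])
children = torch.tensor([[0, 0]] * 4 + [[0, 1], [2, 3], [4, 5]])
levels = [torch.tensor([4, 5]), torch.tensor([6])]
out = dense_tree_sum(x, children, levels)
```

Because every internal node has exactly N children, the gather yields a regular [num_nodes_in_level, N, num_features] block with no scatter or per-degree bucketing needed.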

Perhaps I am misunderstanding Tree-LSTMs, but do they necessarily operate on N-ary trees and not arbitrary trees?
I get that if you have N-ary trees for a fixed N, then pytorch geometric seems like overkill, but what about when you have trees where each node has a different number of children?

I don't think they necessarily have to, but you may need to apply some padding nonetheless for torch.nn.LSTM to work.
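
One way to read the padding remark, sketched under assumptions (pad_children and the pad index of -1 are illustrative choices, not an existing API): pad each node's child list to a common width and keep a mask of the real entries, so that trees with varying numbers of children still fit a fixed-size layout.

```python
def pad_children(child_lists, pad_index=-1):
    """Pad ragged per-node child lists to a common width.

    Returns (padded, mask): `padded` has one fixed-width row per node,
    and `mask` is True where the entry is a real child.
    """
    width = max((len(c) for c in child_lists), default=0)
    padded = [list(c) + [pad_index] * (width - len(c)) for c in child_lists]
    mask = [[i < len(c) for i in range(width)] for c in child_lists]
    return padded, mask


# Node 0 has three children, node 1 has one, nodes 2-4 are leaves.
padded, mask = pad_children([[1, 2, 3], [4], [], [], []])
```

For tensors of child features, torch.nn.utils.rnn.pad_sequence serves a similar purpose; the mask is what lets a downstream layer ignore the padded slots.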
