Pytorch_geometric: How to implement other GraphSAGE Aggregator functions like Max Pool?

Created on 23 Apr 2020 · 10 Comments · Source: rusty1s/pytorch_geometric

Thanks for all the great work on the package and sorry if this has been asked before.

I was wondering if it is possible to use different aggregation functions, like Max-Pool or LSTM aggregation from the original GraphSAGE paper, with the existing SAGEConv layer. Looking at the documentation, it seems that only the mean aggregation is available.


Max-Pool is indeed possible, but we do not have an implementation for it yet. Feel free to add it if you like. A fast LSTM aggregation is currently not possible, but I am working on it.
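For reference, the LSTM aggregator in the GraphSAGE paper runs an LSTM over a random permutation of each node's neighbours. A naive sketch in plain PyTorch (the class name is hypothetical and this is not PyG code) has to loop over nodes in Python, which illustrates why a fast version is hard:

```python
import torch
from torch import nn


class LSTMAggregatorSketch(nn.Module):
    """Hypothetical, deliberately simple LSTM neighbour aggregator.

    Runs an LSTM over a random permutation of each node's neighbours, as in
    the GraphSAGE paper. The per-node Python loop is what makes it slow.
    """

    def __init__(self, in_channels: int, out_channels: int):
        super().__init__()
        self.lstm = nn.LSTM(in_channels, out_channels, batch_first=True)
        self.out_channels = out_channels

    def forward(self, x: torch.Tensor, edge_index: torch.Tensor) -> torch.Tensor:
        src, dst = edge_index  # an edge (j, i) sends neighbour j's features to node i
        rows = []
        for node in range(x.size(0)):
            neighbours = src[dst == node]
            if neighbours.numel() == 0:
                # Nodes without incoming edges get a zero aggregate.
                rows.append(x.new_zeros(self.out_channels))
                continue
            perm = neighbours[torch.randperm(neighbours.numel())]
            seq = x[perm].unsqueeze(0)  # shape [1, num_neighbours, in_channels]
            _, (h_n, _) = self.lstm(seq)
            rows.append(h_n[-1, 0])  # final hidden state of the last LSTM layer
        return torch.stack(rows)


torch.manual_seed(0)
x = torch.randn(4, 3)
edge_index = torch.tensor([[0, 1, 2], [1, 2, 1]])
agg = LSTMAggregatorSketch(3, 5)
out = agg(x, edge_index)
```

Because the neighbour order is shuffled on every call, the output is not deterministic unless the random seed is fixed.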

Hi, thanks so much for the reply. I would love to add it if I am able.
Would it be possible to implement it as in the paper (with a second parameter matrix W_pool) using the current SAGEConv layer? Or would I need to implement it from scratch using MessagePassing? Do you have any tips on where to start?

We should maybe first come up with a separate implementation and then decide how to merge the different variants. I suggest copying most of the SAGEConv layer as a starting point. And yes, you need to introduce W_pool for the max aggregation.
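A minimal sketch of the max-pool aggregator with its W_pool matrix, assuming plain PyTorch 1.12 or later (for `Tensor.scatter_reduce`) so that it runs without PyG. Each neighbour embedding goes through `W_pool` plus a bias and a ReLU, the results are reduced with an element-wise max per target node, and the aggregate is combined with the node's own features by a linear layer on their concatenation. The class and attribute names are hypothetical; this is not PyG's implementation:

```python
import torch
from torch import nn
import torch.nn.functional as F


class MaxPoolSAGESketch(nn.Module):
    """Hypothetical GraphSAGE layer with the paper's max-pool aggregator."""

    def __init__(self, in_channels: int, out_channels: int):
        super().__init__()
        self.pool = nn.Linear(in_channels, in_channels)       # W_pool and its bias
        self.lin = nn.Linear(2 * in_channels, out_channels)   # W applied to [self || aggregate]

    def forward(self, x: torch.Tensor, edge_index: torch.Tensor) -> torch.Tensor:
        src, dst = edge_index  # an edge (j, i) is stored as src=j, dst=i
        msg = F.relu(self.pool(x[src]))  # ReLU(W_pool h_j + b) for every edge
        index = dst.unsqueeze(-1).expand_as(msg)
        # Element-wise max over each node's incoming messages. Nodes without
        # neighbours keep the initial zeros because include_self=False.
        agg = msg.new_zeros(x.size(0), msg.size(-1)).scatter_reduce(
            0, index, msg, reduce="amax", include_self=False
        )
        # The paper's outer nonlinearity is left to the caller.
        return self.lin(torch.cat([x, agg], dim=-1))


torch.manual_seed(0)
layer = MaxPoolSAGESketch(3, 4)
x = torch.randn(4, 3)
edge_index = torch.tensor([[0, 2, 1], [1, 1, 2]])  # 0->1, 2->1, 1->2
out = layer(x, edge_index)
```

In PyG itself, the same idea would presumably become a `MessagePassing` subclass constructed with `aggr='max'` whose `message()` applies `W_pool`, in line with the suggestion to start from a copy of `SAGEConv`.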

Thanks - I'll hopefully get time to look at this over the next week.

Hey guys, is there an update on the two aggregator functions? Would be keen to know.

Hey, I am still planning to have a go at this as soon as possible; I have been inundated with work since the lockdown.
If you have any ideas, or would like to help, I would be most grateful.

It looks like you may just have to add an additional function here, maybe called scatter_maxpool or something?

torch-scatter does already support max reduction :)
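To illustrate the max reduction on its own: the sketch below swaps in plain PyTorch's `Tensor.scatter_reduce` (PyTorch 1.12 or later) for torch-scatter so it runs without extra packages; with torch-scatter installed, `torch_scatter.scatter(src, index, dim=0, reduce='max')` is the corresponding call. The helper name is hypothetical:

```python
import torch


def scatter_max_sketch(src: torch.Tensor, index: torch.Tensor, dim_size: int) -> torch.Tensor:
    """Row-wise max of `src` grouped by `index`; groups with no rows stay 0."""
    # Broadcast the 1-D group index over the feature dimensions of src.
    expanded = index.view(-1, *([1] * (src.dim() - 1))).expand_as(src)
    out = src.new_zeros((dim_size, *src.shape[1:]))
    return out.scatter_reduce(0, expanded, src, reduce="amax", include_self=False)


src = torch.tensor([[1.0, 5.0], [3.0, 2.0], [-4.0, 0.5]])
index = torch.tensor([0, 0, 2])
result = scatter_max_sketch(src, index, dim_size=3)
# expected rows: [3, 5], [0, 0], [-4, 0.5]
```

Using `include_self=False` matters here: group 2 keeps its negative maximum of -4 instead of being clipped to the zero initial value.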

Hi,
Is there any update on the max-pool aggregator? I would love to know.
Also, can't we just change the aggregator by passing an extra argument to the SAGEConv constructor in torch_geometric/nn/conv/sage_conv.py? Is there anything else that needs to be done? The aggregator value there is hard-coded as 'add'.
Thanks

Update 1: Okay, somehow my Jupyter notebook still shows the previous implementation of sage_conv.py. I checked the new code and see that the custom aggregator has been added. How are we supposed to pass it as an argument? Just like aggr='max', right? That reports this error:

TypeError: __init__() got multiple values for keyword argument 'aggr', Line 36 sage_conv.py

Update 2: Okay, I think I need to reinstall torch_geometric in python.
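The `TypeError` above is what Python raises when the same keyword reaches a call twice. A minimal pure-Python reproduction, assuming (as the message suggests) that the installed `SAGEConv` still passes a fixed `aggr` to its base class together with `**kwargs`; the class names are hypothetical stand-ins, not PyG code:

```python
class BaseLayer:
    """Hypothetical stand-in for a message-passing base class."""

    def __init__(self, aggr="add"):
        self.aggr = aggr


class OldStyleConv(BaseLayer):
    """Hypothetical subclass that hard-codes aggr, as an older layer might."""

    def __init__(self, **kwargs):
        # aggr is fixed here AND may also arrive inside kwargs -> duplicate keyword.
        super().__init__(aggr="mean", **kwargs)


class NewStyleConv(BaseLayer):
    """Hypothetical subclass that exposes aggr as its own argument."""

    def __init__(self, aggr="mean", **kwargs):
        super().__init__(aggr=aggr, **kwargs)


try:
    OldStyleConv(aggr="max")
    old_style_error = None
except TypeError as err:
    old_style_error = str(err)

print(old_style_error)               # mentions: got multiple values for keyword argument 'aggr'
print(NewStyleConv(aggr="max").aggr)  # max
```

So the error points at an installed version that does not yet accept `aggr` as a constructor argument, which fits the conclusion that a reinstall is needed.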

Update 3: Okay, I am unable to install the latest code! What to do? I have uninstalled and reinstalled all binaries and the pytorch_geometric repo (with the --no-cache-dir flag). It says PyG 1.6.1 is installed, but I still can't see this update.

HELP!!!!!! :/

Thanks very much!

You can install from source:

pip install git+https://github.com/rusty1s/pytorch_geometric.git
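Updates 1 and 3 look like a notebook that still imports an older copy. After installing from source, one way to check which copy the running kernel actually picks up is to print the module's `__version__` and `__file__`. The helper below is a hypothetical convenience; note that Python caches imported modules, so a Jupyter kernel has to be restarted before a reinstalled package is re-imported:

```python
import importlib


def describe_install(module_name: str) -> dict:
    """Report where a module is imported from and its __version__, if any."""
    try:
        module = importlib.import_module(module_name)
    except Exception as exc:  # broken binary installs can raise more than ImportError
        return {"installed": False, "error": repr(exc)}
    return {
        "installed": True,
        "version": getattr(module, "__version__", None),
        "file": getattr(module, "__file__", None),
    }


print(describe_install("torch_geometric"))
```

If `file` points at an unexpected site-packages directory, the notebook is probably running a different Python environment from the one `pip` installed into.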

or change the aggr flag after initialization:

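from torch_geometric.nn import SAGEConv

conv = SAGEConv(...)  # your usual in_channels, out_channels, etc.
conv.aggr = 'max'     # switch the aggregation after construction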
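Changing `conv.aggr` after construction presumably works because, in the PyG version discussed here, the base class reads `self.aggr` when it aggregates rather than fixing it at construction. A toy pure-PyTorch layer (hypothetical, not PyG's `MessagePassing`) that follows the same pattern shows the effect. In later PyG releases the aggregation may be resolved once during `__init__`, in which case passing `aggr='max'` to the constructor is the safer choice:

```python
import torch


class ToyAggregator(torch.nn.Module):
    """Hypothetical layer that looks up self.aggr on every forward call."""

    # Map the aggregation names used in the thread onto scatter_reduce modes.
    _REDUCE = {"add": "sum", "mean": "mean", "max": "amax"}

    def __init__(self, aggr: str = "mean"):
        super().__init__()
        self.aggr = aggr

    def forward(self, x: torch.Tensor, edge_index: torch.Tensor) -> torch.Tensor:
        src, dst = edge_index
        msg = x[src]
        index = dst.unsqueeze(-1).expand_as(msg)
        out = x.new_zeros(x.shape)
        return out.scatter_reduce(
            0, index, msg, reduce=self._REDUCE[self.aggr], include_self=False
        )


x = torch.tensor([[1.0], [3.0], [0.0]])
edge_index = torch.tensor([[0, 1], [2, 2]])  # nodes 0 and 1 both send to node 2
layer = ToyAggregator()
print(layer(x, edge_index)[2])  # mean of 1 and 3 is 2
layer.aggr = "max"              # same object, aggregation switched after init
print(layer(x, edge_index)[2])  # max of 1 and 3 is 3
```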