Pytorch_geometric: MessagePassing.jittable() remaining features and planning.

Created on 1 Jun 2020 · 14 Comments · Source: rusty1s/pytorch_geometric

This is an issue to track followup on #1256, #1257, #1258, #1259.

We can annotate below with further issue and PR #s, as progress is made.
Let's use this issue as a forum for discussion about what we want out of the jittable() interface.

During review there were a few functionality issues, or requests that came up (box checked if completed, with PR reference):

  • [x] - return a jittable copy of the conv layer instead of the new class (https://github.com/rusty1s/pytorch_geometric/commit/dc87faa6a9ae1b71dbbd6d34f834dc7053803fe6)
  • [x] - deal with overloads in a more graceful way
  • [ ] - add a utility that generates model config files for deployment platforms.

Here's a list of jittable convolutional ops (box is checked if tested and confirmed):

  • [x] Basic test in test/nn/conv/test_message_passing.py
  • [x] AGNNConv
  • [x] APPNP
  • [x] ARMAConv
  • [x] CGConv
  • [x] ChebConv
  • [x] DNAConv
  • [x] EdgeConv
  • [x] DynamicEdgeConv - #1366
  • [x] FeaStConv
  • [x] GATConv
  • [x] GatedGraphConv
  • [x] GCNConv
  • [x] GINConv
  • [x] GINEConv
  • [x] GMMConv
  • [x] GraphConv
  • [x] GravNetConv - #1366
  • [ ] HypergraphConv - NB: needs overall refactoring
  • [x] NNConv
  • [x] PointConv
  • [x] PPFConv
  • [x] RGCNConv - NB: this needs to wait for better handling of multiple propagate types
  • [x] SAGEConv
  • [x] SGConv
  • [x] SignedConv
  • [x] SplineConv
  • [x] TAGConv
  • [x] XConv

All 14 comments

@liaopeiyuan @pierthodo

@rusty1s you already changed "return a jittable copy of the conv layer instead of the new class" in one of the two PRs submitted so far? I didn't catch it if you did, could you put a PR or commit reference?

Hi everyone, I worked on a follow-up PR for the JIT interface in https://github.com/rusty1s/pytorch_geometric/pull/1309 (ready to merge), with the following features:

  • jittable does not need to trace anymore, instead users specify the types passed to propagate explicitly, e.g., via # propagate_type: (x: Tensor, edge_weight: Optional[Tensor])
  • It can now handle bipartite graphs and sparse tensors within the same JIT instance via Union. Note that Union is not natively supported by TorchScript. Instead, I cast each Union combination into its own @overload type.

All tests pass, but bipartite graph support/SparseTensors are only added for a couple of convs for now. I would like to merge this PR first before continuing working on it.
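To make the Union-to-@overload idea concrete, here is a minimal sketch of how every combination of Union members could be enumerated into separate concrete signatures. It is not the code from the PR: `expand_unions` and `is_optional` are hypothetical helpers, and `Tensor` / `SparseTensor` are empty placeholder classes so the snippet runs without PyTorch. Optional[...] is left untouched on the assumption that TorchScript handles it directly.

```python
import itertools
from typing import Dict, List, Optional, Tuple, Union, get_args, get_origin


class Tensor:  # placeholder standing in for torch.Tensor
    pass


class SparseTensor:  # placeholder standing in for torch_sparse.SparseTensor
    pass


def is_optional(hint: object) -> bool:
    """True for Optional[T], i.e. a two-member Union that contains None."""
    args = get_args(hint)
    return get_origin(hint) is Union and len(args) == 2 and type(None) in args


def expand_unions(hints: Dict[str, object]) -> List[Dict[str, object]]:
    """Return one concrete signature per combination of Union members."""
    names = list(hints)
    choices = []
    for name in names:
        hint = hints[name]
        if get_origin(hint) is Union and not is_optional(hint):
            choices.append(get_args(hint))  # one choice per Union member
        else:
            choices.append((hint,))  # non-Union hints have a single choice
    return [dict(zip(names, combo)) for combo in itertools.product(*choices)]


hints = {
    "x": Union[Tensor, Tuple[Tensor, Optional[Tensor]]],
    "edge_index": Union[Tensor, SparseTensor],
    "edge_weight": Optional[Tensor],
}
overloads = expand_unions(hints)
for signature in overloads:
    print(signature)
```

Two Unions with two members each yield four signatures, which is why the number of generated overloads grows multiplicatively with every additional Union argument.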

Here is a basic example:

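from typing import Union

from torch import Tensor
from torch.nn import Linear
from torch_sparse import SparseTensor, matmul
from torch_geometric.nn import MessagePassing


class MyConv(MessagePassing):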
    def __init__(self, in_channels: int, out_channels: int):
        """"""
        super(MyConv, self).__init__(aggr='add')

        self.lin_l = Linear(in_channels, out_channels)
        self.lin_r = Linear(in_channels, out_channels)

    def forward(self, x: Tensor,
                edge_index: Union[Tensor, SparseTensor]) -> Tensor:
        """"""
        # propagate_type: (x: Tensor)
        out = self.propagate(edge_index, x=x, size=None)
        return self.lin_l(out) + self.lin_r(x)

    def message(self, x_j: Tensor) -> Tensor:
        return x_j

    def message_and_aggregate(self, adj_t: SparseTensor, x: Tensor) -> Tensor:
        return matmul(adj_t, x, reduce=self.aggr)

And here is a more complex one that supports bipartite-graphs.

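from typing import Tuple, Union

from torch import Tensor
from torch.nn import Linear
from torch_sparse import SparseTensor, matmul
from torch_geometric.nn import MessagePassing
from torch_geometric.typing import OptTensor, Size


class MyConv(MessagePassing):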
    def __init__(self, in_channels: Union[int, Tuple[int, int]],
                 out_channels: int):
        """"""
        super(MyConv, self).__init__(aggr='add')

        if isinstance(in_channels, int):
            in_channels = (in_channels, in_channels)

        self.lin_l = Linear(in_channels[0], out_channels)
        self.lin_r = Linear(in_channels[1], out_channels)

    def forward(self,
                x: Union[Tensor, Tuple[Tensor, OptTensor]],
                edge_index: Union[Tensor, SparseTensor],
                edge_weight: OptTensor = None,
                size: Size = None) -> Tensor:
        """"""
        if isinstance(x, Tensor):
            x: Tuple[Tensor, OptTensor] = (x, x)

        # propagate_type: (x: Tuple[Tensor, OptTensor], edge_weight: OptTensor)
        out = self.propagate(edge_index, x=x, edge_weight=edge_weight,
                             size=size)
        out = self.lin_l(out)

        x_r = x[1]
        if x_r is not None:
            out += self.lin_r(x_r)

        return out

    def message(self, x_j: Tensor, edge_weight: OptTensor) -> Tensor:
        return x_j if edge_weight is None else x_j * edge_weight.view(-1, 1)

    def message_and_aggregate(self, adj_t: SparseTensor,
                              x: Tuple[Tensor, OptTensor]) -> Tensor:
        return matmul(adj_t, x[0], reduce=self.aggr)

Let me know what you think!
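For readers wondering what reading a `# propagate_type:` comment might involve, here is a rough sketch using only the standard library. It is an illustration, not the parser from the PR: `parse_propagate_type` and `split_top_level` are hypothetical names, and the hints are kept as plain strings rather than resolved to real types.

```python
import re
from typing import Dict, List, Optional


def split_top_level(text: str) -> List[str]:
    """Split on commas that are not nested inside [] or ()."""
    parts, current, depth = [], [], 0
    for ch in text:
        if ch in "[(":
            depth += 1
        elif ch in "])":
            depth -= 1
        if ch == "," and depth == 0:
            parts.append("".join(current).strip())
            current = []
        else:
            current.append(ch)
    tail = "".join(current).strip()
    if tail:
        parts.append(tail)
    return parts


def parse_propagate_type(source: str) -> Optional[Dict[str, str]]:
    """Return {argument name: type hint string}, or None if no comment is found."""
    match = re.search(r"#\s*propagate_type:\s*\((.*)\)\s*$", source, re.MULTILINE)
    if match is None:
        return None
    result = {}
    for part in split_top_level(match.group(1)):
        name, _, hint = part.partition(":")
        result[name.strip()] = hint.strip()
    return result


snippet = '''
        # propagate_type: (x: Tuple[Tensor, OptTensor], edge_weight: OptTensor)
        out = self.propagate(edge_index, x=x, edge_weight=edge_weight,
                             size=size)
'''
parsed = parse_propagate_type(snippet)
print(parsed)
```

The bracket-aware split matters because a hint such as `Tuple[Tensor, OptTensor]` contains a comma of its own.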

Hey - this is really cool, but I am not a big fan of parsing code comments to drive functionality.
Too much room for error, especially for users who might be new. Though given the constraints we have, it's a solid solution.

Maybe a better way to do this would be to specify a __propagate_signatures__ member variable containing a list of dicts of type hints.

class MyConv(MessagePassing):
    def __init__(args...):
        ...
        self.__propagate_signatures__ = [
            { 'x': Tensor, 'edge_index': Tensor, ...}
        ]

This way it naturally grows when we need to overload things, and is really properly part of the python, and still gets rid of the tracing entirely.
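A minimal sketch of how such a member variable could be consumed, assuming a much simplified base class: `SignatureCheckedBase` and `matching_signature` are hypothetical and only stand in for MessagePassing, and plain Python lists stand in for tensors so the snippet runs without PyTorch.

```python
from typing import Any, Dict, List


class SignatureCheckedBase:
    """Hypothetical stand-in for MessagePassing; it only checks propagate kwargs."""

    def __init__(self) -> None:
        # Names with leading AND trailing double underscores are not
        # name-mangled, so subclasses can overwrite this attribute directly.
        self.__propagate_signatures__: List[Dict[str, type]] = []

    def matching_signature(self, **kwargs: Any) -> Dict[str, type]:
        """Return the first declared signature that fits the given kwargs."""
        for signature in self.__propagate_signatures__:
            same_names = signature.keys() == kwargs.keys()
            if same_names and all(
                isinstance(kwargs[name], expected)
                for name, expected in signature.items()
            ):
                return signature
        raise TypeError(f"no propagate signature matches {sorted(kwargs)}")


class MyConv(SignatureCheckedBase):
    def __init__(self) -> None:
        super().__init__()
        # One dict per overload; `list` is a stand-in for Tensor here.
        self.__propagate_signatures__ = [
            {"x": list, "edge_index": list},
            {"x": list, "edge_index": list, "edge_weight": list},
        ]


conv = MyConv()
chosen = conv.matching_signature(x=[1.0], edge_index=[[0], [0]])
print(sorted(chosen))
```

Adding an overload is then a matter of appending another dict, which is the "naturally grows" property mentioned above.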

What do you think?

I think so, too. One reason why I opted for the current approach was because Python has a really similar type definition interface with # type: (...) -> .... Maybe I can implement both.
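For reference, the `# type: (...) -> ...` comment style from PEP 484 can still be read by the standard library: since Python 3.8, `ast.parse` accepts `type_comments=True`. The sketch below only shows that stdlib mechanism on a made-up function `scale`; it does not claim that the jittable implementation works this way.

```python
import ast

SOURCE = '''
def scale(x, weight):
    # type: (float, float) -> float
    return x * weight
'''

# With type_comments=True the parser attaches the comment text to the node.
tree = ast.parse(SOURCE, type_comments=True)
func = tree.body[0]

# The comment itself can be parsed into argument and return types.
signature = ast.parse(func.type_comment, mode="func_type")
arg_types = [node.id for node in signature.argtypes]
return_type = signature.returns.id

print(func.type_comment, arg_types, return_type)
```

A custom `# propagate_type:` comment has no such built-in support, so it has to be extracted from the source text by hand.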

Ah, I thought they had deprecated that in favor of the newer system!

Let's go for both and see which one ends up feeling better in practice?

Done :)

Hi @rusty1s and @lgray ,

This new JIT functionality is great! Do you guys also plan ONNX support? As far as I have seen, there are issues converting a jittable model to ONNX (https://github.com/pytorch/pytorch/issues/34002), and ideally I'd like to export a pytorch_geometric model to ONNX.

@chriss2401 ONNX is (quite) a bit less flexible than TorchScript, and would similarly need additional C++ bindings written to get all the ops into ONNX. I am not sure of the long term plans of pytorch to keep supporting ONNX nor to what degree they'd cover onnx features. Similarly, TF is also going with their own serialization / jit format more recently.

Another way to approach this would be, what are you trying to achieve with conversion to ONNX? There may be a TorchScript friendly way to achieve the same thing.

Going back to adding the ops to ONNX:

  • it's probably best if those are contributed by those who need it and then they are refined prior to merging. Otherwise it goes onto a long backlog of things that need to get done. :-)

@lgray thanks for the quick answer. I need to deploy my application in C#, and Microsoft has created a really nice library called onnxruntime for handling pre-trained models that are converted to ONNX for fast inference. What I like about ONNX is that it is agnostic to the training framework.

I am open to alternatives, but the only other thing I can think of at this point is to make a C# wrapper on top of PyTorch's C++ LibTorch library to run the jittable models (which is not that bad of a solution, it just needs some back and forth between C++ and C#).

@chriss2401 There appears to already be a somewhat mature solution for that: https://github.com/xamarin/TorchSharp

@lgray great, I will try to run one of your test jittable models using this. Thanks!

@chriss2401 just to make sure you don't run into problems - you'll have to build the libtorchscatter/cluster/etc. libraries yourself via cmake for those packages. Make sure the script module can load them, I think they just need to be in LD_LIBRARY_PATH and you're good.
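As a quick sanity check before loading a script module, something like the following could confirm that a library file is present in one of the LD_LIBRARY_PATH directories. This is only a sketch: `find_shared_library` is a hypothetical helper, the file name `libtorchscatter.so` is an assumption about what the cmake build produces, and the real dynamic loader also searches other locations that this check ignores.

```python
import os
import tempfile
from pathlib import Path
from typing import List, Mapping, Optional


def find_shared_library(
    filename: str, env: Optional[Mapping[str, str]] = None
) -> List[Path]:
    """Return every LD_LIBRARY_PATH entry that contains `filename`."""
    env = os.environ if env is None else env
    hits = []
    for directory in env.get("LD_LIBRARY_PATH", "").split(os.pathsep):
        if not directory:
            continue  # skip empty entries such as a trailing separator
        candidate = Path(directory) / filename
        if candidate.is_file():
            hits.append(candidate)
    return hits


# Demonstration with a temporary directory instead of a real build output.
with tempfile.TemporaryDirectory() as tmp:
    (Path(tmp) / "libtorchscatter.so").touch()
    found = find_shared_library("libtorchscatter.so", {"LD_LIBRARY_PATH": tmp})
    found_names = [path.name for path in found]

missing = find_shared_library("libtorchscatter.so", {"LD_LIBRARY_PATH": ""})
print(found_names, missing)
```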
