Tracing a GraphUNet model with torch.jit.trace fails with a type error raised inside torch_sparse:
[omitted]/torch_sparse/matmul.py in spspmm(src, other, reduce)
94 if reduce == 'sum' or reduce == 'add':
---> 95 return spspmm_sum(src, other)
[omitted]/torch_sparse/matmul.py in spspmm_sum(src, other)
82 rowptrC, colC, valueC = torch.ops.torch_sparse.spspmm_sum(
---> 83 rowptrA, colA, valueA, rowptrB, colB, valueB, K)
RuntimeError: unsupported output type: Tensor?
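For context, the failing operator multiplies two sparse matrices stored in CSR form (rowptr, col, value) and sums the products. This matches the spspmm_sum(rowptrA, colA, valueA, rowptrB, colB, valueB, K) call in the traceback. The value inputs and the valueC output are optional, which is presumably the Tensor? type that the tracer rejects. Below is a minimal pure-Python sketch of the same computation. The function name is hypothetical, and two behaviors are our assumptions about torch_sparse rather than its documented behavior: a missing value array is treated as all ones, and only the sparsity pattern is returned when both value arrays are missing.

```python
from typing import List, Optional, Tuple


def spspmm_sum_csr(
    rowptr_a: List[int], col_a: List[int], value_a: Optional[List[float]],
    rowptr_b: List[int], col_b: List[int], value_b: Optional[List[float]],
    k: int,
) -> Tuple[List[int], List[int], Optional[List[float]]]:
    """Hypothetical CSR sparse-sparse product C = A @ B with sum reduction.

    A is M x N (len(rowptr_a) == M + 1), B is N x K. Returns C in CSR form.
    """
    rowptr_c = [0]
    col_c: List[int] = []
    value_c: List[float] = []
    for i in range(len(rowptr_a) - 1):
        # Dense scratch row of width K for accumulating row i of C.
        acc = [0.0] * k
        touched = [False] * k
        for p in range(rowptr_a[i], rowptr_a[i + 1]):
            j = col_a[p]
            # Assumption: a missing value array means every stored entry is 1.
            a_ij = value_a[p] if value_a is not None else 1.0
            for q in range(rowptr_b[j], rowptr_b[j + 1]):
                c = col_b[q]
                b_jc = value_b[q] if value_b is not None else 1.0
                acc[c] += a_ij * b_jc
                touched[c] = True
        # Emit the nonzero pattern of row i in column order.
        for c in range(k):
            if touched[c]:
                col_c.append(c)
                value_c.append(acc[c])
        rowptr_c.append(len(col_c))
    # The value output is optional (likely the "Tensor?" in the error message).
    # Assumption: it is None when neither input carried values (pattern only).
    if value_a is None and value_b is None:
        return rowptr_c, col_c, None
    return rowptr_c, col_c, value_c


if __name__ == "__main__":
    # A = [[2, 3], [0, 4]] in CSR form, multiplied by itself.
    rowptr, col, value = [0, 2, 3], [0, 1, 1], [2.0, 3.0, 4.0]
    print(spspmm_sum_csr(rowptr, col, value, rowptr, col, value, 2))
```

Multiplying A = [[2, 3], [0, 4]] by itself prints ([0, 2, 3], [0, 1, 1], [4.0, 18.0, 16.0]), which is [[4, 18], [0, 16]] in CSR form.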
Steps to reproduce the behavior:
import os.path as osp
import torch
import torch.nn.functional as F
from torch_geometric.datasets import Planetoid
from torch_geometric.nn import GraphUNet
from torch_geometric.utils import dropout_adj
dataset = 'Cora'
path = osp.join('..', 'data', dataset)
dataset = Planetoid(path, dataset)
data = dataset[0]
class Net(torch.nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        pool_ratios = [2000 / data.num_nodes, 0.5]
        self.unet = GraphUNet(1433, 32, 7,
                              depth=3, pool_ratios=pool_ratios)

    def forward(self, x, edge_index):
        e, _ = dropout_adj(edge_index, p=0.2,
                           force_undirected=True,
                           num_nodes=2708,
                           training=self.training)
        d1 = F.dropout(x, p=0.92, training=self.training)
        u = self.unet(d1, e)
        print(type(u))
        r = F.log_softmax(u, dim=1)
        return r
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model, data = Net().to(device), data.to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=0.001)
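# Trace in eval mode: the training flag is baked into the traced graph.
model.eval()
# data is already on `device`; calling .cuda() here would fail on CPU-only machines.
inp = (data.x, data.edge_index)
traced_model = torch.jit.trace(model, inp)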
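The pool_ratios in the reproduction are chosen so that the first pooling step keeps about 2000 of Cora's 2708 nodes. As we understand PyG's GraphUNet, it pads a pool_ratios list shorter than depth by repeating the last entry, and each TopKPooling layer keeps ceil(ratio * N) nodes. The sketch below reproduces that bookkeeping in plain Python. Both rules are our reading of the library, not a documented guarantee, and the helper names are hypothetical.

```python
import math
from typing import List, Sequence


def pad_ratios(ratios: Sequence[float], depth: int) -> List[float]:
    """Cut or pad the ratios to exactly `depth` entries.

    Assumption: mirrors how GraphUNet expands a short pool_ratios list by
    repeating its last entry.
    """
    ratios = list(ratios)
    if len(ratios) >= depth:
        return ratios[:depth]
    return ratios + [ratios[-1]] * (depth - len(ratios))


def pooled_sizes(num_nodes: int, ratios: Sequence[float], depth: int) -> List[int]:
    """Node count at every level of the encoder.

    Assumption: each top-k pooling step keeps ceil(ratio * N) nodes.
    """
    sizes = [num_nodes]
    for ratio in pad_ratios(ratios, depth):
        sizes.append(math.ceil(ratio * sizes[-1]))
    return sizes


if __name__ == "__main__":
    cora_nodes = 2708
    print(pooled_sizes(cora_nodes, [2000 / cora_nodes, 0.5], depth=3))
```

Under these assumptions, the script prints [2708, 2000, 1000, 500] for depth=3.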
Expected behavior: torch.jit.trace returns a torch.jit.TopLevelTracedModule without raising an error.
I'm compiling models written with torch_geometric for deployment on a Jetson Nano with TVM, which requires tracing the model with JIT first.
Thanks for this issue. I will look into this. We are currently in the process of providing JIT support for all PyTorch Geometric modules, so please stay tuned!
Do you have any pointers on where it might be going wrong? I will be working closely on integrating torch_geometric with TVM over the next few months, so I may be able to help with some of these issues.
We are currently making all convolution layers jittable (see here), but tracing should generally already work. In your case, the problem might be in torch-sparse.
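To see where torch-sparse enters: as far as we can tell from the PyG source, GraphUNet augments the graph before every pooling step. It adds self-loops, squares the adjacency matrix with torch_sparse's spspmm, and removes the self-loops again, so that nodes two hops apart stay connected after pooling. That product is the spspmm_sum call in the traceback. The sketch below reproduces the same edge-level effect on a plain Python edge list. The function is a hypothetical stand-in, not PyG's augment_adj itself.

```python
from collections import defaultdict
from typing import Dict, List, Tuple

Edge = Tuple[int, int]


def augment_adj(edges: List[Edge], num_nodes: int) -> Dict[Edge, float]:
    """Return the weighted edges of (A + I)^2 with the diagonal removed.

    Hypothetical stand-in for GraphUNet's augmentation step; in PyG the
    squaring is done by torch_sparse's spspmm on edge_index and edge_weight.
    """
    # Adjacency lists of A + I with unit weights (self-loops reset to 1).
    adj: Dict[int, Dict[int, float]] = defaultdict(dict)
    for src, dst in edges:
        adj[src][dst] = 1.0
    for i in range(num_nodes):
        adj[i][i] = 1.0
    # (A + I)^2: sum the weights of all two-step walks i -> j -> k.
    squared: Dict[Edge, float] = defaultdict(float)
    for i in range(num_nodes):
        for j, w_ij in adj[i].items():
            for k, w_jk in adj[j].items():
                squared[(i, k)] += w_ij * w_jk
    # Drop the self-loops again.
    return {e: w for e, w in squared.items() if e[0] != e[1]}


if __name__ == "__main__":
    # Undirected path graph 0 - 1 - 2.
    print(augment_adj([(0, 1), (1, 0), (1, 2), (2, 1)], 3))
```

On the path graph 0-1-2, the result contains the new two-hop edges (0, 2) and (2, 0) with weight 1.0. The original edges are kept with weight 2.0.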