Suppose we have a collection of non-batched, single graphs with node features x and edges edge_index, where x differs from graph to graph but edge_index is always the same. Is there an efficient way to batch this list of single graphs?
My current strategy is to create an independent copy of edge_index for each Data object, pass it to the Data constructor, and then pass the resulting data_list to the Batch constructor, but this does not seem memory-efficient. What would be the "correct" way to do it?
This is an interesting question. Your approach certainly works, although it has a high memory footprint. You can reduce this quite a bit by not copying edge_index for each data object, e.g.,
data_list = [Data(x=x1, edge_index=edge_index), Data(x=x2, edge_index=edge_index)]
should work just fine (although the batch constructor is still forced to copy the edge indices when it combines the graphs). I do not know of any other way if one wants to maintain sparsity. As far as I know, PyTorch does not support batch-wise sparse matrix multiplication :(
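To make the copying concrete, here is a minimal sketch in plain PyTorch of what batching amounts to when every graph has the same N nodes and the same edge_index: node features are concatenated, and the shared edge_index is repeated once per graph with a node-id offset. The helper name `batch_shared_topology` is hypothetical; it mimics the idea behind the Batch constructor for this special case but is not PyG's implementation.

```python
import torch

def batch_shared_topology(xs, edge_index):
    """Hypothetical helper: batch B graphs that all share one edge_index.

    xs:         list of B node-feature tensors, each of shape [N, F]
    edge_index: shared LongTensor of shape [2, E]
    Returns (x, edge_index, batch) describing one disjoint-union graph.
    """
    num_graphs = len(xs)
    num_nodes = xs[0].size(0)
    # Node features are concatenated along the node dimension: [B * N, F].
    x = torch.cat(xs, dim=0)
    # Node-id offset of each graph: 0, N, 2N, ...
    offsets = torch.arange(num_graphs) * num_nodes
    # [1, 2, E] + [B, 1, 1] -> [B, 2, E]: one shifted copy of the index per graph.
    shifted = edge_index.unsqueeze(0) + offsets.view(-1, 1, 1)
    # Lay the copies side by side: [2, B * E]. This is new memory, even though
    # the input edge_index was shared by all graphs.
    batched = shifted.permute(1, 0, 2).reshape(2, -1)
    # Graph assignment of every node: [0] * N + [1] * N + ...
    batch = torch.arange(num_graphs).repeat_interleave(num_nodes)
    return x, batched, batch

edge_index = torch.tensor([[0, 1, 1, 2],
                           [1, 0, 2, 1]])  # 3-node path graph, E = 4
xs = [torch.randn(3, 8) for _ in range(2)]
x, ei, batch = batch_shared_topology(xs, edge_index)
print(x.shape)         # torch.Size([6, 8])
print(ei.tolist())     # [[0, 1, 1, 2, 3, 4, 4, 5], [1, 0, 2, 1, 4, 3, 5, 4]]
print(batch.tolist())  # [0, 0, 0, 1, 1, 1]
```

The batched index always has B * E columns, so sharing edge_index only saves the per-graph copies held before batching. PyG's real batching also handles graphs of different sizes and arbitrary attributes; this sketch covers only the equal-size, shared-topology case discussed here.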
A more elegant way would be to work on dense adjacency matrices instead. In this setting, you do have batch-wise matrix multiplication, and your memory footprint stays low because you only store a single adjacency matrix of shape 1 x N x N, which is broadcast against the B x N x F feature tensor. We also provide a variety of dense operators like DenseGCNConv.
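A minimal sketch of the dense route in plain PyTorch, assuming all graphs share the same N nodes: stack the features into a B x N x F tensor and keep one 1 x N x N adjacency; `torch.matmul` broadcasts the adjacency across the batch without materializing B copies. The symmetric GCN normalization D^{-1/2}(A + I)D^{-1/2} is written out by hand for illustration; the helper name `dense_propagate` is hypothetical and this is not the code of DenseGCNConv.

```python
import torch

def dense_propagate(x, adj):
    """Hypothetical helper: one GCN-style propagation step on dense inputs.

    x:   [B, N, F] node features for B graphs
    adj: [1, N, N] single adjacency shared by all graphs
    """
    n = adj.size(-1)
    # Add self-loops, then symmetric normalization D^{-1/2} (A + I) D^{-1/2}.
    a_hat = adj + torch.eye(n, dtype=adj.dtype, device=adj.device)
    deg_inv_sqrt = a_hat.sum(dim=-1).pow(-0.5)  # [1, N]
    a_norm = deg_inv_sqrt.unsqueeze(-1) * a_hat * deg_inv_sqrt.unsqueeze(-2)
    # [1, N, N] @ [B, N, F] broadcasts to [B, N, F]; a_norm is never copied B times.
    return torch.matmul(a_norm, x)

# Build the shared dense adjacency once from edge_index.
edge_index = torch.tensor([[0, 1, 1, 2],
                           [1, 0, 2, 1]])
N, B, F = 3, 4, 8
adj = torch.zeros(1, N, N)
adj[0, edge_index[0], edge_index[1]] = 1.0

x = torch.randn(B, N, F)  # per-graph features, stacked along a batch dimension
out = dense_propagate(x, adj)
print(out.shape)          # torch.Size([4, 3, 8])
```

The trade-off is that the adjacency costs N x N memory regardless of sparsity, which is cheap for small graphs but not for large ones.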
Sorry for reopening this issue. There is one more thing I would like to confirm:
The goal of writing data_list = [Data(x=x1, edge_index=edge_index), Data(x=x2, edge_index=edge_index)] is to avoid allocating new memory for each Data object's edges. However, according to the batching tutorial, the batch is built by "stacking" the edge_index of each graph (i.e., each Data object) diagonally into one huge adjacency matrix. Wouldn't that process still allocate extra memory for each graph's connections? I am confused about why the line of code above reduces memory usage.
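The "diagonal stacking" described in the batching tutorial can be checked directly: shifting each graph's copy of edge_index by i * N and concatenating the copies yields exactly the block-diagonal adjacency of the batch. A small self-contained sketch in plain PyTorch; the `to_dense` helper is hypothetical and exists only for the comparison.

```python
import torch

def to_dense(edge_index, num_nodes):
    """Hypothetical helper: dense [num_nodes, num_nodes] adjacency from edge_index."""
    adj = torch.zeros(num_nodes, num_nodes)
    adj[edge_index[0], edge_index[1]] = 1.0
    return adj

edge_index = torch.tensor([[0, 1, 1, 2],
                           [1, 0, 2, 1]])  # shared topology, N = 3, E = 4
N, B = 3, 2

# Disjoint union: shift graph i's node ids by i * N and concatenate the copies.
batched = torch.cat([edge_index + i * N for i in range(B)], dim=1)  # [2, B * E]

# The resulting adjacency is block-diagonal, one block per graph.
blocks = torch.block_diag(*[to_dense(edge_index, N) for _ in range(B)])
assert torch.equal(to_dense(batched, B * N), blocks)
print(batched.shape)  # torch.Size([2, 8])
```

Only the sparse [2, B * E] index is ever stored; the dense block matrix here is built purely to verify the equivalence.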
Yes, exactly: the batch constructor is forced to copy the edge indices, and sadly we cannot fix that. What I meant was that you do not need to do:
data_list = [Data(x=x1, edge_index=edge_index.clone()), Data(x=x2, edge_index=edge_index.clone())]
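The difference between the two constructions can be observed in plain PyTorch by comparing data pointers, assuming (as the answer relies on) that a Data object keeps a reference to the tensor it is given rather than copying it. Passing the same edge_index object keeps a single buffer, `.clone()` allocates one buffer per graph, and concatenating into a batch allocates a new buffer in either case. The helper `count_buffers` is hypothetical.

```python
import torch

def count_buffers(tensors):
    """Hypothetical helper: number of distinct data pointers among tensors."""
    return len({t.data_ptr() for t in tensors})

edge_index = torch.tensor([[0, 1, 1, 2],
                           [1, 0, 2, 1]])  # N = 3, E = 4
N, B = 3, 100

shared = [edge_index for _ in range(B)]          # B references, one buffer
cloned = [edge_index.clone() for _ in range(B)]  # B independent buffers

print(count_buffers(shared))  # 1
print(count_buffers(cloned))  # 100

# Batching always builds a new [2, B * E] index, whichever list it starts from.
batched = torch.cat([ei + i * N for i, ei in enumerate(shared)], dim=1)
print(batched.shape)                                 # torch.Size([2, 400])
print(batched.data_ptr() == edge_index.data_ptr())  # False
```

So sharing saves the B per-graph copies that would otherwise coexist with the batch, but the batch itself still costs B * E index columns.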