We've implemented a generic TorchDistribution.expand_by() method and patched individual distributions to support an .expand() method. Could one or both of these be moved upstream to PyTorch?
.expand() upstream
.expand_by() in Pyro
.expand() method patches from Pyro
ReshapedDistribution from Pyro

cc @neerajprad @alicanb
I would be hesitant about creating a new Distribution subclass upstream; I fear it might confuse people. Can't we implement it by expanding the necessary tensors?
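For illustration, here is a minimal sketch of that tensor-expansion approach, assuming a Normal distribution. The helper name `expand_normal` is hypothetical and is not part of Pyro or PyTorch; it only expands `loc` and `scale` with `Tensor.expand` and builds a new distribution from the resulting views.

```python
import torch
from torch.distributions import Normal


def expand_normal(dist, batch_shape):
    """Hypothetical helper: return a Normal whose parameters are expanded views."""
    batch_shape = torch.Size(batch_shape)
    # Tensor.expand returns a view, so no parameter data is copied here.
    loc = dist.loc.expand(batch_shape)
    scale = dist.scale.expand(batch_shape)
    return Normal(loc, scale)


base = Normal(torch.zeros(3), torch.ones(3))
expanded = expand_normal(base, (5, 3))
print(expanded.batch_shape)     # torch.Size([5, 3])
print(expanded.sample().shape)  # torch.Size([5, 3])
```

A sketch like this needs one such function per distribution family, since each family names its parameters differently, and that per-family cost is the main trade-off against a generic wrapper class.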
Minimally, I think we can start by moving the .expand (and .expand_by) methods to optionally expand the distribution's batch shape (and resulting parameters) after it has been constructed. This will allow us to use these classes directly, with no additional patching needed. It will also be nice to profile and see how much this saves us vs.:
ReshapedDistribution upstream, and have them be light wrappers in Pyro that call Independent and .expand_by/.expand. I like the idea of returning a shallow copy of the distribution, as torch.expand does. Maybe we should refactor to first remove ReshapedDistribution from Pyro (using Independent instead), and then move the now-standalone methods upstream.
EDIT: I've updated the task list to reflect this plan.
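To make the plan a bit more concrete, here is a rough sketch of the two pieces it relies on: the view semantics of `Tensor.expand` that a shallow-copy `.expand()` would mirror, and `Independent` reinterpreting batch dimensions as event dimensions. The choice of Normal and the shapes are illustrative assumptions, not code from the Pyro branch.

```python
import torch
from torch.distributions import Independent, Normal

# Tensor.expand returns a view that shares storage with the original tensor.
loc = torch.zeros(3)
loc_expanded = loc.expand(5, 3)
print(loc_expanded.data_ptr() == loc.data_ptr())  # True

# Build a distribution with batch_shape (5, 3) from the expanded views.
base = Normal(loc_expanded, torch.ones(3).expand(5, 3))

# Independent reinterprets the rightmost batch dimension as an event dimension.
dist = Independent(base, 1)
print(dist.batch_shape, dist.event_shape)  # torch.Size([5]) torch.Size([3])

# log_prob now sums over the event dimension, leaving only the batch shape.
x = dist.sample()
print(dist.log_prob(x).shape)  # torch.Size([5])
```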
@neerajprad is this issue basically finished in our pytorch-0.5.0 branch?
That's right. This can be closed once our pytorch-0.5.0 branch gets merged (or even right now, since it is code-complete).