Pyro: Move .expand_by() and .expand() logic upstream

Created on 26 Aug 2018 · 6 Comments · Source: pyro-ppl/pyro

We've implemented a generic TorchDistribution.expand_by() method and patched individual distributions to support an .expand() method. Could one or both of these be moved upstream to PyTorch?

Tasks

  • [x] Assess feasibility and create refactoring plan
  • [x] https://github.com/pytorch/pytorch/pull/11341 Move .expand() upstream
  • [ ] Drop .expand_by() in Pyro
  • [ ] Remove .expand() method patches from Pyro
  • [ ] Remove ReshapedDistribution from Pyro
refactor

All 6 comments

cc @neerajprad @alicanb

I would be hesitant about creating a new Distribution subclass upstream; I fear it might confuse people. Can't we implement it by expanding the necessary tensors?

Minimally, I think we can start by moving the .expand (and .expand_by) methods upstream, so that a distribution's batch shape (and the corresponding parameters) can optionally be expanded after it has been constructed. This would let us use the upstream classes directly, with no additional patching needed. It would also be nice to profile and see how much this saves us compared with:

  • constructing a new distribution instance: I think the savings here should be substantial, but since the broadcasting logic has been converted to C, they may be smaller than they used to be.
  • appending to the sample shape (as we do in ReshapedDistribution): if expanding the distribution parameters is no more expensive than appending to the sample shape (due to the broadcasting logic being in C now, for instance), we don't need to move ReshapedDistribution upstream, and can instead keep light wrappers in Pyro that call Independent and .expand_by/.expand.
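For reference, here is a minimal sketch of what expanding after construction looks like with the `.expand()` method as it exists in current `torch.distributions`. The `expand_by` function is a hypothetical standalone helper written only to illustrate the prepend-to-batch-shape behaviour discussed here; it is not the Pyro implementation.

```python
import torch
from torch.distributions import Normal

# A distribution with batch_shape == (3,).
base = Normal(torch.zeros(3), torch.ones(3))

# Expand the batch shape after construction; this returns a new
# distribution instance and leaves `base` untouched.
expanded = base.expand(torch.Size([5, 3]))


def expand_by(dist, sample_shape):
    """Hypothetical helper: prepend `sample_shape` to the batch shape."""
    return dist.expand(torch.Size(sample_shape) + dist.batch_shape)


# Equivalent to the explicit call above.
expanded_by = expand_by(base, (5,))

print(base.batch_shape, expanded.batch_shape, expanded_by.batch_shape)
```

Timing this against constructing a fresh `Normal` from pre-expanded tensors would be one way to run the comparison suggested above.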

I like the idea of returning a shallow copy of the distribution, as Tensor.expand does. Maybe we should refactor to first remove ReshapedDistribution from Pyro (using Independent instead), and then move the now-standalone methods upstream.

EDIT I've updated the task list to reflect this plan.
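As a rough illustration of that plan, the sketch below combines `.expand()` with `torch.distributions.Independent` to get the kind of reshaping that ReshapedDistribution provided. It assumes a PyTorch version where `.expand()` is available upstream, and the shapes are chosen purely for illustration.

```python
import torch
from torch.distributions import Independent, Normal

base = Normal(torch.zeros(3), torch.ones(3))  # batch_shape (3,), event_shape ()

# Expand the batch shape, then reinterpret the rightmost batch
# dimension as an event dimension.
dist = Independent(base.expand(torch.Size([4, 3])), 1)

x = dist.sample()
log_p = dist.log_prob(x)  # summed over the reinterpreted event dimension

print(dist.batch_shape, dist.event_shape, x.shape, log_p.shape)
```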

@neerajprad is this issue basically finished in our pytorch-0.5.0 branch?

That's right. This can be closed once our pytorch-0.5.0 branch gets merged (or even, right now, since it is code complete).
