Pyro: Make autoguides more usable

Created on 6 Nov 2019 · 5 comments · Source: pyro-ppl/pyro

This aims to make our most common autoguides easier to use: AutoDelta, AutoDiagonalNormal, AutoMultivariateNormal, AutoLowRankMultivariateNormal.

Feel free to add to this list (anyone :smile:) if you have usability requests.

I have added this to the 1.0 milestone because it involves interface changes that will enable later refactoring (e.g. #2078) while preserving backwards compatibility.

Tasks

  • [x] #2119 document use of custom init_loc_fn, and add examples in tutorials
  • [x] #2119 add error checking logic in init_loc_fn
  • [x] #2119 add easier init_scale args for Auto*Normal guides
  • [x] #2126 make autoguides into nn.Modules so they are easier to torch.load()/.save()
    (this is an alternative to making the param store an nn.Module #2078 )
  • [x] allow non-python-identifiers as site names in autoguides, which was broken by #2126
  • make it easier to set optim configs per-guide-param (e.g. I often lower the learning rate lr *= 0.1 for the auto_cov_factor param of AutoLowRankMultivariateNormal).
  • [x] #2127 Stabilize parameterization of AutoLowRankMultivariateNormal
  • [x] #2129 Support pickling of easyguides
  • [x] #2137 Make AutoGuideList a bona fide nn.ModuleList
  • [x] #2140 make easyguides into nn.Modules
  • [x] #2148 Fix saving of AutoGuides in the param store
  • [ ] #2156 Change default init_scale from 1.0 to 0.1 (@fritzo often uses this in practice)

All 5 comments

@martinjankowiak @stefanwebb let me know if you have any other suggestions!

@fritzo Could we have a method in PyroModule to get constrained values (without having to use pyro.get_param_store())? Right now we can inspect unconstrained parameters through the .named_parameters() method, but it would be nice to have support for constrained parameters too. :)

@fehiepsi sure, feel free to add a .get_named_pyro_params() or similar.

@fritzo I just discovered a bug on the dev branch :D

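import torch
import pyro
from pyro.nn import PyroModule, PyroParam

class Child(PyroModule):
    def __init__(self, x):
        super().__init__()
        self.a = PyroParam(torch.tensor(x))

class Family(PyroModule):
    def __init__(self):
        super().__init__()
        self.c = Child(1.0)

    def forward(self):
        return self.c.a

pyro.clear_param_store()
f = Family()
assert f() == 1   # first call registers "c.a" = 1.0 in the param store
f.c = Child(3.)   # overwrite the whole submodule
assert f() == 3   # fails on dev: the stale "c.a" = 1.0 is still in the param store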

@fehiepsi good point. I believe that's because the __delattr__ logic currently only removes PyroParam and nn.Parameter attributes from the param store (__delattr__ is triggered when __setattr__ overwrites something). Do you want to try modifying __delattr__ to correctly handle PyroModule and nn.Module? That seems like a good first PR to PyroModule :smile:
