This aims to make our most common autoguides easier to use: `AutoDelta`, `AutoDiagonalNormal`, `AutoMultivariateNormal`, `AutoLowRankMultivariateNormal`.
Feel free to add to this list (anyone :smile:) if you have usability requests.
I have added this to the 1.0 milestone because it involves interface changes that will enable later refactoring (e.g. #2078) while preserving backwards compatibility.
- … `init_loc_fn`, and add examples in tutorials
- … `init_loc_fn`
- … `init_scale` args for `Auto*Normal` guides
- … `nn.Module`s so they are easier to `torch.load()`/`.save()` (… `nn.Module` #2078)
- … (… `lr *= 0.1` for the `auto_cov_factor` param of `AutoLowRankMultivariateNormal`).
- … `AutoLowRankMultivariateNormal`
- … `AutoGuideList` a bona fide `nn.ModuleList`
- … `nn.Module`s
- … `AutoGuide`s in the param store
- … `init_scale` from 1.0 to 0.1 (@fritzo often uses this in practice)

@martinjankowiak @stefanwebb let me know if you have any other suggestions!
@fritzo Could we have a method in `PyroModule` to get constrained values (without having to use `pyro.get_param_store()`)? Right now we can inspect unconstrained parameters through the `.named_parameters()` method, but it would be nice to support constrained parameters too. :)
@fehiepsi Sure, feel free to add a `.get_named_pyro_params()` or similar.
@fritzo I just discovered a bug in the dev branch :D
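```python
import torch
from pyro.nn import PyroModule, PyroParam


class Child(PyroModule):
    def __init__(self, x):
        super().__init__()
        self.a = PyroParam(torch.tensor(x))


class Family(PyroModule):
    def __init__(self):
        super().__init__()
        self.c = Child(1.0)

    def forward(self):
        return self.c.a


f = Family()
assert f() == 1  # the first call stores "c.a" = 1.0 in the global param store
f.c = Child(3.)  # overwrite the child module
assert f() == 3  # fails on the dev branch: the stale "c.a" value is returned
```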
@fehiepsi Good point. I believe that's because the `__delattr__` logic currently only removes `PyroParam` and `nn.Parameter` from the param store (`__delattr__` is triggered when `__setattr__` overwrites something). Do you want to try modifying `__delattr__` to correctly handle `PyroModule` and `nn.Module`? That seems like a good first PR contributing to `PyroModule` :smile:
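The mechanism can be illustrated with a toy, pure-Python model (not Pyro's actual code; every name here is hypothetical). A dict stands in for the global param store, params are stored under dotted full names such as `"c.a"`, and overwriting an attribute routes through `__delattr__`. With `fix_delattr = False` the old child's `"c.a"` entry survives and shadows the new child, reproducing the report; with `fix_delattr = True`, `__delattr__` also drops every entry under the removed submodule's prefix, which is one possible shape for the fix.

```python
# Toy model of the reported bug and one possible fix; NOT Pyro's code.
PARAM_STORE = {}  # stands in for Pyro's global param store


class Param:
    """Stand-in for PyroParam: only carries an initial value."""

    def __init__(self, init_value):
        self.init_value = init_value


class Module:
    """Toy PyroModule whose params live in PARAM_STORE under dotted names."""

    fix_delattr = True  # set to False to reproduce the reported bug

    def __init__(self):
        object.__setattr__(self, "_params", {})
        object.__setattr__(self, "_modules", {})
        object.__setattr__(self, "_prefix", "")

    def _fullname(self, name):
        return f"{self._prefix}.{name}" if self._prefix else name

    def _set_prefix(self, prefix):
        object.__setattr__(self, "_prefix", prefix)
        for name, child in self._modules.items():
            child._set_prefix(self._fullname(name))

    def __getattr__(self, name):
        # Only reached when normal lookup fails, i.e. for params and submodules.
        params = self.__dict__["_params"]
        if name in params:
            # Like pyro.param(): the first access stores the init value, later
            # accesses return whatever the store holds under the full name.
            return PARAM_STORE.setdefault(self._fullname(name), params[name].init_value)
        modules = self.__dict__["_modules"]
        if name in modules:
            return modules[name]
        raise AttributeError(name)

    def __setattr__(self, name, value):
        if name in self._params or name in self._modules:
            delattr(self, name)  # overwriting an attribute goes through __delattr__
        if isinstance(value, Param):
            self._params[name] = value
        elif isinstance(value, Module):
            value._set_prefix(self._fullname(name))
            self._modules[name] = value
        else:
            object.__setattr__(self, name, value)

    def __delattr__(self, name):
        fullname = self._fullname(name)
        if name in self._params:
            del self._params[name]
            PARAM_STORE.pop(fullname, None)
        elif name in self._modules:
            del self._modules[name]
            if self.fix_delattr:
                # The fix: also forget every stored param under the old submodule.
                stale = [k for k in PARAM_STORE if k.startswith(fullname + ".")]
                for key in stale:
                    del PARAM_STORE[key]
        else:
            object.__delattr__(self, name)


class Child(Module):
    def __init__(self, x):
        super().__init__()
        self.a = Param(x)


class Family(Module):
    def __init__(self):
        super().__init__()
        self.c = Child(1.0)

    def __call__(self):
        return self.c.a


def run(fix):
    """Return (value before, value after) overwriting f.c."""
    Module.fix_delattr = fix
    PARAM_STORE.clear()
    f = Family()
    before = f()  # stores "c.a" -> 1.0
    f.c = Child(3.0)
    return before, f()
```

Here `run(fix=False)` returns `(1.0, 1.0)`, matching the bug report, while `run(fix=True)` returns `(1.0, 3.0)`. Matching on the `fullname + "."` prefix (rather than a bare prefix) avoids accidentally dropping entries of a sibling such as `"cc.a"`.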