Three recent issues (#2058, #2054, #1839) suggest that serialization might be easier if the ParamStore were an nn.Module, e.g. so that it could be saved with torch.jit.save() and loaded with torch.jit.load().
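A minimal sketch of what an nn.Module-based param store could look like, assuming plain PyTorch only. `ParamStoreModule` is the name used later in this thread, but the class body and its `get_param`, `clear`, and `load_state_dict` methods below are hypothetical and not Pyro's API. Because nn.Module parameter names cannot contain ".", Pyro-style names such as "AutoDiagonalNormal.loc" are mangled before registration.

```python
import io

import torch
from torch import nn


class ParamStoreModule(nn.Module):
    """Hypothetical param store that is itself an nn.Module."""

    @staticmethod
    def _key(name):
        # nn.Module parameter names may not contain ".", so mangle them.
        return name.replace(".", "__")

    def get_param(self, name, init_tensor=None):
        # Mimic the pyro.param idiom: create on first call, return it afterwards.
        key = self._key(name)
        if key not in self._parameters:
            if init_tensor is None:
                raise KeyError(name)
            self.register_parameter(key, nn.Parameter(init_tensor.detach().clone()))
        return self._parameters[key]

    def clear(self):
        # Remove every registered parameter from the module.
        for key in list(self._parameters):
            delattr(self, key)

    def load_state_dict(self, state_dict, strict=True):
        # A fresh store is empty, so create placeholders for the saved keys
        # before letting nn.Module copy the saved values into them.
        for key, value in state_dict.items():
            if key not in self._parameters:
                self.register_parameter(key, nn.Parameter(torch.empty_like(value)))
        return super().load_state_dict(state_dict, strict=strict)


store = ParamStoreModule()
loc = store.get_param("AutoDiagonalNormal.loc", torch.zeros(3))

# Serialize with plain torch.save / torch.load instead of a custom format.
buffer = io.BytesIO()
torch.save(store.state_dict(), buffer)
buffer.seek(0)
restored = ParamStoreModule()
restored.load_state_dict(torch.load(buffer))
```

The placeholder step in `load_state_dict` is needed because parameters in this design are created lazily, the way pyro.param sites are, so a freshly constructed store has nothing for nn.Module to load into.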
Could param_store.load() / .save() be replaced with torch.load() / torch.save()?
How would one .clear() a module?
How do we keep the "pyro.param" idiom compatible with NumPyro?

@neerajprad Now that #2102 is merged, one possible strategy is to recommend against pyro.param() statements and instead rely entirely on pyro.module. By following this strategy, users can avoid serialization issues with Pyro's param store and rely entirely on native PyTorch nn.Module serialization.
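To illustrate the native serialization path this strategy relies on, here is a sketch using only PyTorch: learnable values live as nn.Parameter attributes on a module, and torch.save / load_state_dict handle persistence. In Pyro such a module would be registered with pyro.module; that call is omitted so the snippet runs without Pyro, and `RegressionModel` is a hypothetical example.

```python
import io

import torch
from torch import nn


class RegressionModel(nn.Module):
    """Hypothetical model whose learnable values are nn.Parameters, not pyro.param sites."""

    def __init__(self):
        super().__init__()
        self.weight = nn.Parameter(torch.randn(2))
        self.bias = nn.Parameter(torch.zeros(()))

    def forward(self, x):
        return x @ self.weight + self.bias


model = RegressionModel()

# Save and restore with native PyTorch serialization; no param store involved.
buffer = io.BytesIO()
torch.save(model.state_dict(), buffer)
buffer.seek(0)
fresh = RegressionModel()
fresh.load_state_dict(torch.load(buffer))
```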
For autoguides, I think this solution may be a bit clunky, since we would need to write a separate wrapper class that sets all the autoguide parameters, such as init_loc, as nn.Parameter attributes. Maybe the ParamStoreModule solution would work better in that case?
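A sketch of such a separate wrapper, assuming the autoguide's initial values are available as a dict of tensors. `AutoGuideParams` and the `init_scale` entry are hypothetical; note that the wrapper stores raw values and does nothing about constraints such as positivity of a scale, which is part of what makes this approach clunky.

```python
import torch
from torch import nn


class AutoGuideParams(nn.Module):
    """Hypothetical wrapper exposing an autoguide's values as nn.Parameter attributes."""

    def __init__(self, init_values):
        super().__init__()
        for name, value in init_values.items():
            # Each initial value (e.g. init_loc) becomes a learnable parameter.
            self.register_parameter(name, nn.Parameter(value.detach().clone()))


wrapper = AutoGuideParams({"init_loc": torch.zeros(3), "init_scale": torch.ones(3)})
print(sorted(name for name, _ in wrapper.named_parameters()))
```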
Update: alternatively, our autoguides themselves should be nn.Modules!
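If the autoguide is itself an nn.Module, no separate wrapper is needed: parameters are registered where they are created, and any torch optimizer finds them through .parameters(). The toy class below is hypothetical and far simpler than Pyro's autoguides; it draws one reparameterized sample from a diagonal Normal, keeping the scale positive via softplus, and uses a stand-in loss rather than an ELBO.

```python
import torch
from torch import nn
from torch.distributions import Normal
from torch.nn.functional import softplus


class ToyDiagonalNormalGuide(nn.Module):
    """Hypothetical autoguide that is an nn.Module (not Pyro's AutoDiagonalNormal)."""

    def __init__(self, dim):
        super().__init__()
        self.loc = nn.Parameter(torch.zeros(dim))
        self.raw_scale = nn.Parameter(torch.zeros(dim))  # unconstrained

    def forward(self):
        scale = softplus(self.raw_scale)  # map to positive values
        return Normal(self.loc, scale).rsample()


guide = ToyDiagonalNormalGuide(3)
optimizer = torch.optim.Adam(guide.parameters(), lr=0.01)
sample = guide()
sample.pow(2).sum().backward()  # stand-in loss, not an ELBO
optimizer.step()
```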
I was just facing that issue :smile: I am also looking into making the autoguides nn.Modules.
Though a bit different, the Parameterized class has some logic to handle parameters and buffers, which can be helpful for you guys. :)
Let's see how this evolves; I might then be able to make Parameterized a subclass of ParamStoreModule.
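For context on the parameters-versus-buffers distinction: in PyTorch both end up in state_dict() and are therefore serialized, but only parameters are returned by .parameters() and updated by optimizers. The sketch below is plain PyTorch, not Pyro's Parameterized class; `ScaledKernel` and its attributes are hypothetical.

```python
import torch
from torch import nn


class ScaledKernel(nn.Module):
    """Hypothetical module with one learnable parameter and one fixed buffer."""

    def __init__(self):
        super().__init__()
        self.lengthscale = nn.Parameter(torch.ones(1))        # trained
        self.register_buffer("jitter", torch.tensor([1e-6]))  # saved, not trained


kernel = ScaledKernel()
param_names = [name for name, _ in kernel.named_parameters()]
state_keys = list(kernel.state_dict())
```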
@neerajprad What do you think of a mixed autoguide solution?
Use @constrained in the static autoguides:

```diff
- class AutoContinuous(AutoGuide):
+ class AutoContinuous(AutoGuide, nn.Module):
      ...

  class AutoDiagonalNormal(AutoContinuous):
+     @constrained
+     def scale(self):
+         return constraints.positive
      ...
```

Use a ConstrainedModule for the dynamic autoguides:

```diff
- class AutoDelta(AutoGuide):
+ class AutoDelta(AutoGuide, ConstrainedModule):
      ...
```
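One possible reading of the two pieces, as a sketch: `constrained` and `ConstrainedModule` are the names proposed above, but these implementations are guesses, not Pyro code, and `ToyDiagonalNormal` is hypothetical. Both store an unconstrained nn.Parameter named `<name>_unconstrained` and map it into the constraint's support with torch.distributions.transform_to. The decorator suits attributes known when the class is written; ConstrainedModule registers them at runtime, as a dynamic autoguide such as AutoDelta would need per sample site.

```python
import torch
from torch import nn
from torch.distributions import constraints, transform_to


def constrained(constraint_fn):
    """Hypothetical decorator: turn a method returning a constraint into a property."""
    name = constraint_fn.__name__

    def getter(self):
        raw = getattr(self, name + "_unconstrained")
        return transform_to(constraint_fn(self))(raw)

    return property(getter)


class ToyDiagonalNormal(nn.Module):
    """Static case: the constrained attribute is declared in the class body."""

    def __init__(self, dim):
        super().__init__()
        self.loc = nn.Parameter(torch.zeros(dim))
        self.scale_unconstrained = nn.Parameter(torch.zeros(dim))

    @constrained
    def scale(self):
        return constraints.positive


class ConstrainedModule(nn.Module):
    """Dynamic case: constrained parameters are registered at runtime."""

    def __init__(self):
        super().__init__()
        self._constraints = {}

    def set_constrained(self, name, value, constraint):
        # Store the value in unconstrained space as a regular nn.Parameter.
        raw = transform_to(constraint).inv(value)
        self.register_parameter(name + "_unconstrained", nn.Parameter(raw.detach()))
        self._constraints[name] = constraint

    def __getattr__(self, name):
        # Only reached when normal lookup fails, e.g. for constrained names.
        table = self.__dict__.get("_constraints", {})
        if name in table:
            raw = super().__getattr__(name + "_unconstrained")
            return transform_to(table[name])(raw)
        return super().__getattr__(name)


guide = ToyDiagonalNormal(2)
delta = ConstrainedModule()
delta.set_constrained("scale", torch.tensor([2.0]), constraints.positive)
```

Because only the unconstrained tensors are nn.Parameters, an optimizer can update them freely while the constrained view (guide.scale, delta.scale) always stays in the valid support.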
That sounds reasonable to me. I am working on a PR to try this out.
Closing in favor of #2161