Pyro: Make the ParamStore an nn.Module?

Created on 23 Oct 2019 · 6 Comments · Source: pyro-ppl/pyro

Three recent issues, #2058, #2054 and #1839, suggest that serialization might be easier if the ParamStore were an nn.Module, e.g. so that it could be saved and loaded through torch::jit.

Questions 2019-10-22

  • Would this make it easier to jit-and-serve Pyro models from C++?
  • Could we replace param_store.load() / .save() with torch.load()/.save()?
  • Are there conflicts in the module-vs-dictlike interfaces?
  • For example, can we .clear() a module?
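The following is a rough sketch of what a module-based store might look like, to make the questions above concrete. ParamStoreModule, load_state and roundtrip are hypothetical names, not Pyro's ParamStoreDict API. The sketch also assumes that Pyro-style names such as "AutoDiagonalNormal.loc" never contain "__". It suggests some partial answers:

* Plain torch.save()/torch.load() of a state_dict() works.
* .clear() is easy to provide.
* There are at least two module-vs-dict conflicts. First, nn.Module rejects "." in parameter names. Second, load_state_dict() cannot create parameters that do not exist yet.

```python
import io

import torch
from torch import nn


class ParamStoreModule(nn.Module):
    """Hypothetical dict-like parameter store built on nn.Module (a sketch,
    not Pyro's actual ParamStoreDict)."""

    def __init__(self):
        super().__init__()
        self.params = nn.ParameterDict()

    @staticmethod
    def _key(name):
        # nn.Module rejects "." in parameter names, but Pyro-style names often
        # contain dots (e.g. "AutoDiagonalNormal.loc"), so escape them.
        # Assumption: user-supplied names never contain "__" themselves.
        return name.replace(".", "__")

    def setdefault(self, name, init_value):
        # Mirrors the pyro.param() contract: create on first use, else reuse.
        key = self._key(name)
        if key not in self.params:
            self.params[key] = nn.Parameter(init_value.detach().clone())
        return self.params[key]

    def __getitem__(self, name):
        return self.params[self._key(name)]

    def __contains__(self, name):
        return self._key(name) in self.params

    def __len__(self):
        return len(self.params)

    def clear(self):
        # Swapping in a fresh ParameterDict drops every registered parameter.
        self.params = nn.ParameterDict()

    def load_state(self, state):
        # load_state_dict() is strict about keys, so an empty store must first
        # create placeholder parameters matching the saved names.
        for key, value in state.items():
            name = key.split(".", 1)[1]  # strip the "params." prefix
            if name not in self.params:
                self.params[name] = nn.Parameter(torch.empty_like(value))
        self.load_state_dict(state)


def roundtrip(store):
    """Serialize with plain torch.save/torch.load and rebuild a new store."""
    buffer = io.BytesIO()
    torch.save(store.state_dict(), buffer)
    buffer.seek(0)
    restored = ParamStoreModule()
    restored.load_state(torch.load(buffer))
    return restored
```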

Update 2019-10-29

Now that #2102 is merged, one possible strategy is to recommend against the pyro.param() statement and instead rely entirely on pyro.module. By following this strategy, users can avoid serialization issues with Pyro's param store and rely entirely on native PyTorch nn.Module serialization.

  • Is the "avoid pyro.param" idiom compatible with NumPyro?

Tasks

  • [x] provide autoguide methods to init autoguide parameters @fritzo
  • [x] refactor autoguides to be nn.Modules @neerajprad

All 6 comments

Now that #2102 is merged, it appears a possible strategy may be to recommend against the pyro.param() statement, and instead rely entirely on pyro.module.

For autoguides, I think this solution may be a bit clunky, since we would need to write a separate wrapper class that registers every autoguide parameter (such as init_loc) as an nn.Parameter attribute. Maybe the ParamStoreModule solution would work better in that case?

Update: alternatively, our autoguides could themselves be nn.Modules!
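The sketch below makes the "separate wrapper" objection concrete. GuideParamWrapper is a hypothetical name, and the wrapper simply registers whatever tensors it is given (e.g. an init_loc) as nn.Parameter attributes. Every autoguide parameter must be threaded through such a wrapper by hand. Making the autoguides nn.Modules themselves would remove that bookkeeping.

```python
import torch
from torch import nn


class GuideParamWrapper(nn.Module):
    """Hypothetical wrapper: copies loose guide tensors (e.g. an init_loc)
    onto a module as nn.Parameter attributes so they serialize natively."""

    def __init__(self, **init_values):
        super().__init__()
        for name, value in init_values.items():
            # Each tensor is registered explicitly; this per-parameter
            # bookkeeping is what makes a separate wrapper feel clunky.
            self.register_parameter(name, nn.Parameter(value.detach().clone()))
```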

For autoguides

I was just facing that issue :smile: I am also looking into making the autoguides nn.Modules.

Though a bit different, the Parameterized class has some logic for handling parameters and buffers that may be helpful for you. :)

Let's see how this evolves; I might then be able to make Parameterized a subclass of ParamStoreModule.

@neerajprad What do you think of a mixed autoguide solution?

  1. Use @constrained in the static autoguides
- class AutoContinuous(AutoGuide):
+ class AutoContinuous(AutoGuide, nn.Module):
      ...
  class AutoDiagonalNormal(AutoContinuous):
+     @constrained
+     def scale(self):
+         return constraints.positive
      ...
  2. Use ConstrainedModule for the dynamic autoguides
- class AutoDelta(AutoGuide):
+ class AutoDelta(AutoGuide, ConstrainedModule):
      ...
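Below is one way the proposed @constrained decorator could behave, sketched with torch.distributions.transform_to. The decorated method returns a constraint. An unconstrained nn.Parameter is stored under a derived name, and attribute reads return the constrained value. The names constrained and DiagonalNormalParams, and the "_unconstrained" suffix, are assumptions for illustration only. The sketch covers only the static @constrained case, not ConstrainedModule.

```python
import torch
from torch import nn
from torch.distributions import constraints, transform_to


def constrained(constraint_fn):
    """Hypothetical decorator: the decorated method returns a constraint, and
    the resulting property stores an unconstrained nn.Parameter named
    "<name>_unconstrained" while reading and writing constrained values."""
    raw_name = constraint_fn.__name__ + "_unconstrained"

    def getter(self):
        transform = transform_to(constraint_fn(self))
        return transform(getattr(self, raw_name))

    def setter(self, value):
        transform = transform_to(constraint_fn(self))
        with torch.no_grad():
            raw = transform.inv(value)
        # nn.Module.__setattr__ registers the Parameter, so it appears in
        # parameters() and state_dict() under "<name>_unconstrained".
        setattr(self, raw_name, nn.Parameter(raw))

    return property(getter, setter)


class DiagonalNormalParams(nn.Module):
    """Hypothetical stand-in for AutoDiagonalNormal's loc/scale parameters."""

    def __init__(self, dim):
        super().__init__()
        self.loc = nn.Parameter(torch.zeros(dim))
        # A plain tensor (not a Parameter) falls through nn.Module.__setattr__
        # to the property setter created by @constrained.
        self.scale = torch.ones(dim)

    @constrained
    def scale(self):
        return constraints.positive
```

Only the unconstrained tensor is registered. Optimizers and state_dict() therefore see an unbounded parameter, while guide code reads self.scale as an already-positive value.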

That sounds reasonable to me. I am working on a PR to try this out.

Closing in favor of #2161
