It would be nice to support HMC/NUTS-within-Gibbs / online HMC (e.g. HMC with subsampling), but this may require some changes to the HMC and NUTS interfaces. This issue tracks those changes.
- … *args,**kwargs across HMC steps. … potential_fn() closure created by _pe_maker(). It's easy enough to pass references rather than tensors (e.g. [x] rather than a tensor x), but that breaks jit compilation.
  - Workaround: define a custom potential_fn, transforms, and set .initial_params.
- … potential_energy in .sample(). … params (e.g. when new data is added via a side channel into potential_fn()). Maybe a .clear_cache() arg and logic to recompute the results of ._fetch_from_cache()?
- … adapt_window steps when provided, and when warmup_steps=float('inf').
  - Workaround: kernel.setup(warmup_steps=9999999999), forcing an effectively infinite adaptation phase. To set a non-exponential schedule for mass matrix updates, set the kernel._adapter._adaptation_schedule property.

@neerajprad @fehiepsi maybe we can incorporate these into your current refactoring work.
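To see why the caching item matters once data arrives through a side channel, here is a toy sketch in plain Python. CachingKernel and its clear_cache() method are hypothetical stand-ins for the proposal above, not Pyro's HMC class or API. The data lives in a list so it can be swapped by reference, and the cache keyed only on parameter values silently returns a stale potential energy until it is cleared.

```python
class CachingKernel:
    """Hypothetical stand-in for a kernel that caches potential energy.

    Not Pyro's HMC: clear_cache() is the method proposed above.
    """

    def __init__(self, potential_fn):
        self.potential_fn = potential_fn
        self._cache = {}

    def _fetch_from_cache(self, params):
        # Keyed on parameter values only, so data changes are invisible here.
        key = tuple(sorted(params.items()))
        if key not in self._cache:
            self._cache[key] = self.potential_fn(params)
        return self._cache[key]

    def clear_cache(self):
        # Force the next lookup to recompute against the current data.
        self._cache.clear()


# Data passed by reference (a list holding a list), so it can change between steps.
data_ref = [[1.0, 2.0]]


def potential_fn(params):
    # Negative Gaussian log-likelihood with unit variance, up to a constant.
    mu = params["mu"]
    return sum(0.5 * (x - mu) ** 2 for x in data_ref[0])


kernel = CachingKernel(potential_fn)
before = kernel._fetch_from_cache({"mu": 0.0})  # 0.5 * (1 + 4) = 2.5
data_ref[0] = data_ref[0] + [3.0]               # new data via a side channel
stale = kernel._fetch_from_cache({"mu": 0.0})   # still 2.5: the cache is stale
kernel.clear_cache()
fresh = kernel._fetch_from_cache({"mu": 0.0})   # 0.5 * (1 + 4 + 9) = 7.0
```

Passing data through a mutable container is what keeps the closure flexible here, which is also the pattern the first item notes does not survive jit compilation.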
Workaround: define a custom potential_fn, transforms, and set .initial_params.
Are you able to get this to work? I think we may need some changes here, since by default HMC assumes that MCMC runs on all the parameters passed to sample. It makes similar assumptions for potential_fn, which should take a dict of unconstrained samples as its argument. While potential_fn can be bound in a closure with changing args and kwargs, we won't be able to JIT it.
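As a rough illustration of a hand-written potential_fn over a dict of unconstrained samples, the sketch below uses an assumed toy model (sigma ~ Exponential(1), x_i ~ Normal(0, sigma)) in plain Python rather than Pyro. The positive parameter is sampled as u = log(sigma), so the potential energy is the negative log joint density plus the negative log-Jacobian of the exp transform.

```python
import math

# Assumed toy data; in the online setting this could change between steps.
data = [0.5, -1.0, 2.0]


def log_normal(x, loc, scale):
    # Log density of Normal(loc, scale) at x.
    return (-0.5 * ((x - loc) / scale) ** 2
            - math.log(scale) - 0.5 * math.log(2 * math.pi))


def potential_fn(unconstrained):
    """Potential energy over a dict of unconstrained values (toy model)."""
    u = unconstrained["sigma"]   # u = log(sigma), lives on the whole real line
    sigma = math.exp(u)          # map back to the constrained space sigma > 0
    log_prior = -sigma           # log Exponential(1) density for sigma > 0
    log_lik = sum(log_normal(x, 0.0, sigma) for x in data)
    log_abs_det_jac = u          # d sigma / d u = exp(u), so log|J| = u
    return -(log_prior + log_lik + log_abs_det_jac)
```

Including the Jacobian term is what makes HMC on u target the intended posterior over sigma.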
Are you able to get [the workaround] to work?
Yes, I believe this should work, but I haven't verified it end-to-end. My plan is to:
- … potential_fn() closure.
- … potential_fn() by hand before passing it to HMC, taking care to handle both MCMC params and non-MCMC dynamic params.

I don't think it's worth building the extra JIT handling into the general HMC machinery. It seems cleaner to use the flexibility of the potential_fn() arg.

@fritzo About warming up indefinitely, do you mean collecting samples from the warmup phase? If so, I think we can just support collecting samples from the warmup phase, and either define a very large warmup_steps or, when warmup_steps=inf, make a generator that keeps generating window lengths (currently, each window is 2x the length of the previous one). Currently, the mass matrix is calculated from the samples in a window (rather than from the Hessian of the potential fn at the last sample).
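Estimating the mass matrix from the samples in a window can be sketched with Welford's running variance, here for a diagonal inverse mass matrix in plain Python. The shrinkage toward 1e-3 with weight n / (n + 5) follows Stan's diagonal adaptation and is an assumption about what any particular library does; the class name is hypothetical.

```python
class WindowedVariance:
    """Welford running mean/variance over the samples of one adaptation window."""

    def __init__(self, dim):
        self.n = 0
        self.mean = [0.0] * dim
        self.m2 = [0.0] * dim  # running sum of squared deviations

    def update(self, sample):
        self.n += 1
        for i, x in enumerate(sample):
            delta = x - self.mean[i]
            self.mean[i] += delta / self.n
            self.m2[i] += delta * (x - self.mean[i])

    def inverse_mass_matrix_diag(self, regularize=True):
        if self.n < 2:
            raise ValueError("need at least two samples in the window")
        # Unbiased sample variance of each coordinate.
        var = [m / (self.n - 1) for m in self.m2]
        if regularize:
            # Stan-style shrinkage toward a small constant (assumed here).
            w = self.n / (self.n + 5.0)
            var = [w * v + 1e-3 * (5.0 / (self.n + 5.0)) for v in var]
        return var

    def reset(self):
        # Start a fresh estimate at the beginning of the next window.
        dim = len(self.mean)
        self.n = 0
        self.mean = [0.0] * dim
        self.m2 = [0.0] * dim
```

At the end of each slow window the estimate would be used as the new inverse mass matrix and then reset, so later windows do not mix in samples drawn under older settings.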
do you mean to collect samples from warmup phase?
Yes, I plan to discard a few samples at the beginning but then start collecting samples while continuing to warmup (i.e. continuing to adapt the mass matrix and stepsize etc.).
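The loop described here (discard a few draws, then keep every draw while adaptation continues) can be sketched with a toy sampler. This is not HMC/NUTS: random-walk Metropolis stands in for the kernel, and a Robbins-Monro update of the log step size with a decaying gain stands in for the step-size adaptation HMC/NUTS would use. The function name and the 0.44 acceptance target (a common choice for 1-D random-walk Metropolis) are assumptions for illustration.

```python
import math
import random


def online_rwm(log_prob, x0, num_steps, num_discard, target_accept=0.44, seed=0):
    """Toy random-walk Metropolis that never stops adapting its step size.

    The first num_discard draws are thrown away; every later draw is kept
    while adaptation continues.
    """
    rng = random.Random(seed)
    x, log_step = x0, 0.0
    samples = []
    for t in range(1, num_steps + 1):
        proposal = x + math.exp(log_step) * rng.gauss(0.0, 1.0)
        log_ratio = log_prob(proposal) - log_prob(x)
        accept_prob = 1.0 if log_ratio >= 0 else math.exp(log_ratio)
        if rng.random() < accept_prob:
            x = proposal
        # Robbins-Monro update with a decaying gain: grow the step when the
        # acceptance probability is above target, shrink it when below.
        log_step += (accept_prob - target_accept) / t ** 0.6
        if t > num_discard:
            samples.append(x)
    return samples, math.exp(log_step)


# Target: a standard normal, up to a constant.
samples, step = online_rwm(lambda x: -0.5 * x * x, 0.0, 2000, 100)
```

The decaying gain makes adaptation diminish over time, which is the usual condition for adaptive MCMC to keep targeting the right distribution while it keeps adapting.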
Then I think we can just add a discard_warmup flag to control that behaviour. WDYT @neerajprad?
I don't think a discard_warmup flag will help. To do online MCMC, we'll already need to use an HMC or NUTS kernel directly rather than the higher-level MCMC interface. Also, I believe you'll need a new adapt_window parameter or something, since when I set warmup_steps=float('inf'), we'll need to choose some window size other than float('inf') / 5 :smile: (sorry if I'm misunderstanding!)
@fritzo If you use MCMC, then discard_warmup=False would do just that: collect the warmup samples for you. If you don't use MCMC, there's no need for that flag. Currently, the window sizes go as follows: 75, 25, 50, 100, 200, 400, 800, 1600, 3200, ..., 50, so either you define a large warmup_steps or we can make a generator to generate window lengths (rather than storing them in a list). If you want fixed window lengths (rather than 25, 50, 100, ...), we can support a window_length_generator arg in the adaptation scheme so that you can define your own method to generate window lengths.
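A generator version of that schedule could look like the sketch below, written in plain Python. It reproduces the quoted sequence 75, 25, 50, 100, ... without end and drops the terminal buffer of 50, since that buffer never arrives when warmup is infinite. Both generators are hypothetical illustrations of the proposed window_length_generator idea, not an existing Pyro argument.

```python
import itertools


def doubling_windows(init_buffer=75, first_window=25):
    """Yield warmup phase lengths without end (for warmup_steps=inf)."""
    yield init_buffer             # initial buffer: adapt the step size only
    length = first_window
    while True:
        yield length              # slow window: mass matrix re-estimated at its end
        length *= 2


def fixed_windows(init_buffer=75, window=100):
    """Hypothetical alternative generator with constant-length windows."""
    yield init_buffer
    while True:
        yield window


print(list(itertools.islice(doubling_windows(), 6)))
# [75, 25, 50, 100, 200, 400]
```

Because the lengths are produced lazily, an adapter consuming this generator never has to divide an infinite warmup_steps into windows up front.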
@fehiepsi I think your window_length_generator arg would work; an adaptation_schedule list would also work fine for me. As a near-term workaround, I believe I can simply set nuts._adapter._adaptation_schedule after calling nuts.setup(), right?
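For the list-based route, here is a sketch of building a fixed-width schedule as a list of inclusive (start, end) windows: an initial buffer, constant-width slow windows, and a terminal buffer. The AdaptWindow record and fixed_schedule helper are hypothetical; whether the private _adaptation_schedule attribute expects exactly this shape is an assumption that should be checked against the adapter's source before assigning to it.

```python
from collections import namedtuple

# Hypothetical window record; the adapter's real private format may differ.
AdaptWindow = namedtuple("AdaptWindow", ["start", "end"])


def fixed_schedule(num_steps, init_buffer=75, window=100, term_buffer=50):
    """Inclusive [start, end] windows covering num_steps warmup steps.

    Slow windows have constant width; the last one is truncated so the
    terminal buffer still fits.
    """
    if num_steps < init_buffer + term_buffer:
        raise ValueError("num_steps too small for the requested buffers")
    schedule = [AdaptWindow(0, init_buffer - 1)]
    start = init_buffer
    term_start = num_steps - term_buffer  # first step of the terminal buffer
    while start < term_start:
        end = min(start + window, term_start) - 1
        schedule.append(AdaptWindow(start, end))
        start = end + 1
    schedule.append(AdaptWindow(term_start, num_steps - 1))
    return schedule
```

Unlike the doubling schedule, every slow window here sees the same number of samples, so mass matrix updates happen at a steady cadence.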
Yes, I believe that will work perfectly. :smiley:
Closing as this now seems to be working (with workarounds). There's still room to make the workarounds easier, but I am unblocked :smile: