Turing.jl: Is init_theta keyword working?

Created on 13 Apr 2021 · 13 Comments · Source: TuringLang/Turing.jl

I've been trying to start my samplers in "better" regions of the parameter space, since it's a bit slow to converge otherwise on certain parameters. But from what I can tell, whenever I pass init_theta as per the docs, whatever values I have there are just ignored.

MWE:

using Turing
using LinearAlgebra  # provides diagm

# Model definition.
@model demo(x) = begin
    u ~ MvNormal(zeros(2), ones(2))
    x ~ MvNormal(u, ones(2))
end

a = ones(2)
start_vals = [-10, -10]
m1 = sample(demo(a), MH(diagm([0.001, 0.001])), MCMCThreads(), 10, 4, init_theta=start_vals)
m2 = sample(demo(a), HMC(0.0001, 1), MCMCThreads(), 10, 4, init_theta=start_vals)

Array(m1)
Array(m2)

Here, I start the sampler in a far-away place. The MH proposals in m1 are tiny, so it should take quite a while to get away from the region around (-10,-10), but apparently the starting values are just being ignored. I wasn't sure if it was just an MH issue, so I tried HMC as well in the above example, and it also appears to ignore the supplied starting values.

After digging around some more, I tried changing the keyword to init_params instead of init_theta. This works for the MH code above, but not for HMC (oddly enough, since the HMC sampling code is where I found the init_params keyword in the first place). I also tried NUTS, and didn't see any evidence that init_params was being respected there, either (although it's much harder to tell with an algorithm like NUTS, which can move away from starting values so quickly).

So, it's not clear to me if this is just a case of the docs being in need of an update, or if there's something else wrong.

And in any case, if init_theta isn't a valid keyword anymore, perhaps supplying such a keyword should give an error or something? I just spent several days trying to debug code that I thought had some subtle identification error that was making my sampler "blow up" right at the start of sampling, when really it was just sample silently ignoring starting values set by a deprecated(?) keyword.
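To make the suggestion above concrete, here is a minimal sketch of a keyword guard in plain Julia. It is not how Turing.jl or DynamicPPL handle keywords: `check_keywords` and the `KNOWN_KEYWORDS` list are hypothetical names invented for this illustration, and a real sampler would need its own list of accepted keywords.

```julia
# Hypothetical guard, not part of Turing.jl: reject keywords that are not recognised.
# The list below is an assumption made for this illustration only.
const KNOWN_KEYWORDS = (:init_params, :progress)

function check_keywords(; kwargs...)
    # Collect every keyword name that is not in the accepted list.
    unknown = [k for k in keys(kwargs) if !(k in KNOWN_KEYWORDS)]
    isempty(unknown) ||
        throw(ArgumentError("unsupported keyword argument(s): " * join(unknown, ", ")))
    return nothing
end

check_keywords(init_params = [-10.0, -10.0])     # accepted, returns nothing

try
    check_keywords(init_theta = [-10.0, -10.0])  # rejected with an ArgumentError
catch err
    println(err.msg)
end
```

With a guard like this, a stale keyword such as init_theta would fail loudly at the call site instead of being dropped silently.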

FWIW, being able to catch stuff like this easily is one reason I strongly favor recording the initial values of samplers as part of the chain itself (per the discussion in #1282), something I always do if writing samplers by hand.
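As an illustration of what recording the initial values looks like in a hand-written sampler, here is a minimal random-walk Metropolis sketch in plain Julia. It does not use Turing; the function name `rw_metropolis`, the standard-normal target, and the step size are all assumptions chosen for this example.

```julia
using Random

# Log density of a standard bivariate normal, up to an additive constant.
logdensity(u) = -0.5 * sum(abs2, u)

# Random-walk Metropolis that stores the starting point as row 1 of the chain.
function rw_metropolis(rng, logp, init, nsamples; stepsize = 0.001)
    chain = Matrix{Float64}(undef, nsamples + 1, length(init))
    current = float.(init)
    chain[1, :] = current                  # row 1 is the initial value, not a draw
    for i in 1:nsamples
        proposal = current .+ stepsize .* randn(rng, length(current))
        # Accept with probability min(1, p(proposal) / p(current)).
        if log(rand(rng)) < logp(proposal) - logp(current)
            current = proposal
        end
        chain[i + 1, :] = current
    end
    return chain
end

chain = rw_metropolis(MersenneTwister(1), logdensity, [-10, -10], 10)
println(chain[1, :])   # the recorded starting point
```

Because the first row is the starting point itself, a silently ignored initial value shows up immediately when inspecting the chain.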

All 13 comments

Unfortunately, the documentation is largely outdated. It is called init_params (e.g., https://github.com/TuringLang/Turing.jl/blob/602aa5f23cde2985bf0d7d1b44d6c4f7c265422c/src/inference/hmc.jl#L148).

Edit: Ah, I see you already noticed this.

Did you keep the warmup phase in NUTS? Otherwise it is probably difficult to check whether your sampler started from your initial setting.

Regarding the implementation: it is actually not part of Turing.jl but defined in DynamicPPL.jl (https://github.com/TuringLang/DynamicPPL.jl/blob/9d4137eb33e83f34c484bf78f9a57f828b3c92a0/src/sampler.jl#L76). Maybe at some point it could be moved to AbstractMCMC.jl even. IIRC HMC just has to be handled in a special way since we set some default values if they are not provided by the user with the sampler, and therefore the keyword argument shows up there (and in the emcee sampler AFAIK) but not in any other sampler in Turing.jl.

With NUTS, I first used a warmup-phase of 1, but even with a warm-up phase of 0, it still seems to ignore the supplied values.

Can you add a print statement in https://github.com/TuringLang/Turing.jl/blob/602aa5f23cde2985bf0d7d1b44d6c4f7c265422c/src/inference/hmc.jl#L157 and check if your initial values show up there? If not there's a problem with how keyword arguments are forwarded and/or how we forward calls to DynamicPPL and AbstractMCMC. Otherwise it works as expected (but maybe not as it should) since currently we perform one initial step with AdvancedHMC before returning the first sample (https://github.com/TuringLang/Turing.jl/blob/602aa5f23cde2985bf0d7d1b44d6c4f7c265422c/src/inference/hmc.jl#L194-L220).

I did a println(theta) at line 158. The result was not the values I passed -- should have been [-10,-10], instead I got [1.563217652248911, -0.5123821037824636].

Oh, I assume the problem is that SampleFromUniform, which is used both for sampling the initial values and for rerunning the model to compute the log joint probability after we set the provided initial values, behaves very oddly: it always overwrites existing samples (https://github.com/TuringLang/DynamicPPL.jl/blob/9d4137eb33e83f34c484bf78f9a57f828b3c92a0/src/context_implementations.jl#L127-L128). I.e., the initial values that we set in https://github.com/TuringLang/DynamicPPL.jl/blob/9d4137eb33e83f34c484bf78f9a57f828b3c92a0/src/sampler.jl#L77 are resampled in https://github.com/TuringLang/DynamicPPL.jl/blob/9d4137eb33e83f34c484bf78f9a57f828b3c92a0/src/sampler.jl#L80.

It would make sense that the error only occurs with HMC, since AFAIK SampleFromUniform is only used by HMC.

IMO we should change the behaviour of SampleFromUniform and only resample values if they are clearly marked to be resampled.
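The difference between the two behaviours can be sketched with a toy variable store in plain Julia. Everything here is hypothetical: `ToyVarInfo`, `init_always!`, and `init_flagged!` are made-up names, not DynamicPPL's API, and the uniform range of [-2, 2] is an assumption for the example.

```julia
using Random

# Toy stand-in for a variable store; names are hypothetical, not DynamicPPL's API.
mutable struct ToyVarInfo
    values::Dict{Symbol,Float64}
    resample::Set{Symbol}    # variables explicitly flagged for resampling
end

# Behaviour described as problematic: every variable is overwritten.
function init_always!(rng, vi::ToyVarInfo)
    for name in collect(keys(vi.values))
        vi.values[name] = 4 * rand(rng) - 2   # uniform on [-2, 2] (assumed range)
    end
    return vi
end

# Proposed behaviour: overwrite only flagged variables, keep user-provided ones.
function init_flagged!(rng, vi::ToyVarInfo)
    for name in collect(vi.resample)
        vi.values[name] = 4 * rand(rng) - 2
        delete!(vi.resample, name)            # clear the flag once resampled
    end
    return vi
end

user = ToyVarInfo(Dict(:u1 => -10.0, :u2 => -10.0), Set{Symbol}())
init_flagged!(MersenneTwister(1), user)
println(user.values)          # user-provided values are untouched

overwritten = ToyVarInfo(Dict(:u1 => -10.0, :u2 => -10.0), Set{Symbol}())
init_always!(MersenneTwister(1), overwritten)
println(overwritten.values)   # both values now lie in [-2, 2]
```

In the flagged variant, user-provided initial values survive a rerun of the model unless they were explicitly marked for resampling.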

Alternatively, one could always resample when using a sampler and perform evaluation only with a dedicated evaluation context (this would probably require some refactoring along the lines of https://github.com/TuringLang/DynamicPPL.jl/issues/80). One has to be a bit careful, though, and probably special-case the evaluation context for every sampler, since depending on the sampler we store the log-likelihood, the log joint probability, or other things in varinfo.logp. The whole point of rerunning the model after setting the user-provided initial values is to ensure that varinfo.logp is consistent with the variables and equal to the value one would get if one had sampled those variables.

My hypothesis in https://github.com/TuringLang/Turing.jl/issues/1588#issuecomment-818709759 is correct: the problem is that SampleFromUniform resamples the user-provided initial values. Printing the values in DynamicPPL shows that the initial values are set correctly, and commenting out https://github.com/TuringLang/DynamicPPL.jl/blob/9d4137eb33e83f34c484bf78f9a57f828b3c92a0/src/sampler.jl#L80 fixes the problem for HMC (but, of course, that's not something we want to do).

Sounds like this is the same problem that @mateuszbaran had in #1563. Someone was also asking about this in the Slack channel a few weeks ago, if I remember correctly.

Yes, that looks like the same problem. Unfortunately all I can tell is that commenting out that line didn't improve sampling from my model (but fixing one problem with my data and giving better prior distributions did help).

I opened a PR with a temporary workaround in https://github.com/TuringLang/DynamicPPL.jl/pull/232.

The fix is available in the latest release of DynamicPPL. It should be installed automatically if you update your packages.
