Turing.jl: Jags-style samplers

Created on 9 Sep 2019 · 15 Comments · Source: TuringLang/Turing.jl

I am opening this feature request after a discussion on Slack regarding the performance of PG (Particle Gibbs). For continuous parameters in particular, particles tend to get stuck. It's not clear to me to what extent this also happens for discrete parameters. Here is an example:

using Turing,Random,StatsPlots
@model model(y) = begin
    μ ~ Normal(0,10)
    σ ~ Truncated(Cauchy(0,1),0,Inf)
    for j in 1:length(y)
        y[j] ~ Normal(μ,σ)
    end
end
Random.seed!(3431)
y = rand(Normal(0,1),50)
chain = sample(model(y),PG(40,4000))
chain = chain[2001:end,:,:]
println(chain)
plot(chain)

[Image: fig4]

This required about 2.5 minutes to run on my system. Increasing the number of particles to 80 did not help much.

As a basis for comparison, here is the same model coded in Jags:

ENV["JAGS_HOME"] = "usr/bin/jags" #your path here
using Jags, StatsPlots, Random, Distributions
#cd(@__DIR__)
ProjDir = pwd()
Random.seed!(3431)

y = rand(Normal(0,1),50)

Model = "
model {
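      for (i in 1:length(y)) {
            # dnorm takes a precision (1/variance), not a standard deviation
            y[i] ~ dnorm(mu, 1/(sigma*sigma));
      }
      mu  ~ dnorm(0, 1/100);  # sd 10, i.e. precision 1/10^2
      sigma  ~ dt(0,1,1) T(0, );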
  }
"

monitors = Dict(
  "mu" => true,
  "sigma" => true,
  )

jagsmodel = Jagsmodel(
  name="Gaussian",
  model=Model ,
  monitor=monitors,
  ncommands=4, nchains=1,
  #deviance=true, dic=true, popt=true,
  pdir=ProjDir
  )

println("\nJagsmodel that will be used:")
jagsmodel |> display

data = Dict{String, Any}(
  "y" => y,
)

inits = [
  Dict("mu" => 0.0,"sigma" => 1.0,
  ".RNG.name" => "base::Mersenne-Twister")
]

println("Input observed data dictionary:")
data |> display
println("\nInput initial values dictionary:")
inits |> display
println()
#######################################################################################
#                                 Estimate Parameters
#######################################################################################
sim = jags(jagsmodel, data, inits, ProjDir)
sim = sim[5001:end,:,:]
plot(sim)

[Image: jags]

This required about 0.267 seconds on my machine, which is nearly a 600-fold speedup.

Here is a second example we found to perform poorly:

using Distributions
using Turing

n=500
p=20
X = rand(Float64, (n,p))
beta=[2.0 .^ (-i) for i in 0:(p-1)]
alpha=0
sigma=0.7
eps=rand(Normal(0, sigma), n)
y = alpha .+ X * beta + eps;

@model model(X, y) = begin

    n, p = size(X)

    alpha ~ Normal(0,1)
    sigma ~ Truncated(Cauchy(0,1),0,Inf)
    sigma_beta ~ Truncated(Cauchy(0,1),0,Inf)
    pind ~ Beta(2,8)

    beta = tzeros(Float64, p)
    betaT = tzeros(Float64, p)
    ind = tzeros(Int, p)

    for j in 1:p
        ind[j] ~ Bernoulli(pind)
        betaT[j] ~ Normal(0,sigma_beta)  # random effect
        beta[j] = ind[j] * betaT[j]
    end

    mu = tzeros(Float64, n)

    for i in 1:n
        mu[i] = alpha + X[i,:]' * beta 
        y[i] ~ Normal(mu[i], sigma)
    end

end

steps = 4000
chain = sample(model(X,y),PG(40,steps))

I think this would be a very useful addition. By adding Jags-style samplers, we could have the speed of Jags without the severe limitations of Jags. This would also provide Turing with an ability that Stan struggles to perform.

discussion

All 15 comments

Thanks for opening this issue. This is already on the priority list of the Turing team. Adding support for discrete variables, and combining different sampling algorithms to form more efficient inference engines, are among the original motivations of Turing. However, the challenge is not on the inference side: we can quickly implement the samplers currently available in JAGS. The real barrier is the compiler, which currently tracks only the values of random variables and ignores their dependencies. This lack of dependency information makes it hard to derive Gibbs conditionals automatically.
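To make "Gibbs conditionals" concrete, here is a minimal hand-written sketch in plain Julia for the first model in this issue (it does not use Turing or JAGS, and it is not how either of them is implemented). The conditional of mu given sigma and y is available in closed form because only the Normal prior and the likelihood mention mu. For sigma, the half-Cauchy prior gives no closed form, so this sketch assumes a random-walk Metropolis step on log(sigma). The function names `logcond_sigma` and `gibbs`, and the step size 0.2, are illustrative assumptions.

```julia
using Random, Statistics

# Log of the unnormalised conditional density of sigma given mu and y.
# Only the factors of the joint that mention sigma are included:
# the Normal likelihood and the half-Cauchy(0, 1) prior.
function logcond_sigma(sigma, mu, y)
    ss = sum(abs2, y .- mu)
    return -length(y) * log(sigma) - ss / (2 * sigma^2) - log1p(sigma^2)
end

function gibbs(rng, y, nsamples; step = 0.2)
    n = length(y)
    mu, sigma = 0.0, 1.0
    mus = zeros(nsamples)
    sigmas = zeros(nsamples)
    for t in 1:nsamples
        # Exact draw from mu | sigma, y (Normal(0, 10) prior, Normal likelihood).
        prec = 1 / 10^2 + n / sigma^2
        mu = (sum(y) / sigma^2) / prec + randn(rng) / sqrt(prec)

        # Random-walk Metropolis on log(sigma); the log(sigma) terms are the
        # Jacobian of the log transform.
        prop = sigma * exp(step * randn(rng))
        logratio = logcond_sigma(prop, mu, y) + log(prop) -
                   logcond_sigma(sigma, mu, y) - log(sigma)
        if log(rand(rng)) < logratio
            sigma = prop
        end

        mus[t] = mu
        sigmas[t] = sigma
    end
    return mus, sigmas
end

rng = MersenneTwister(3431)
y = randn(rng, 50)
mus, sigmas = gibbs(rng, y, 4000)
println("posterior mean of mu:    ", mean(mus[2001:end]))
println("posterior mean of sigma: ", mean(sigmas[2001:end]))
```

Writing `logcond_sigma` required knowing which factors of the joint depend on sigma. That is the dependency information the compiler would need in order to derive such conditionals automatically.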

One reason why dependency tracking is harder to implement in Turing than in libraries such as JAGS or Mamba.jl is that Turing takes a tracing (aka define-by-run) approach to defining models. Libraries like JAGS take a scripting (aka define-and-run) approach. The tracing approach is arguably more general and user-friendly: 1) it supports models with varying dimensionality, like Dirichlet processes; 2) it makes models easier to implement and debug.
Unfortunately, these properties also mean that the graphical model underlying a Turing program can be dynamic, i.e. both edges and the total number of nodes could vary during inference.
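As a toy illustration of why a define-by-run program can have a dynamic graph, the sketch below records random choices in a trace as the program executes. It is plain Julia, not Turing's tracing machinery, and `run_program!` is a hypothetical name. The number of recorded variables depends on an earlier draw, so it can only be known by running the program.

```julia
using Random

# A toy "define-by-run" program: every random choice is recorded in `trace`
# at the moment it is made, so the trace length depends on earlier draws.
function run_program!(trace, rng)
    k = rand(rng, 1:5)                 # number of components, itself random
    push!(trace, :k => k)
    for i in 1:k
        push!(trace, Symbol("x", i) => randn(rng))
    end
    return trace
end

rng = MersenneTwister(1)
lengths = [length(run_program!(Pair{Symbol,Any}[], rng)) for _ in 1:5]
println(lengths)   # the number of traced variables can differ between runs
```

A define-and-run system such as JAGS fixes the set of nodes before sampling starts, which is what makes its dependency graph static.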

To address these issues and add support for JAGS-style inference and other advanced inference methods in Turing, we have started several projects. Below are some related ongoing PRs and work:

  • @cpfiffer is working on a significant PR https://github.com/TuringLang/Turing.jl/pull/793 which will bring Turing one step closer towards plug-and-play inference.

  • Over the summer, @trappmartin, @phipsgabler and I started re-implementing the Turing compiler to support dynamic dependency tracking. This new compiler should enable JAGS-style Gibbs sampling and other advanced inference methods like message passing algorithms. If you're interested, please take a look at the following repo
    https://github.com/phipsgabler/DynamicComputationGraphs.jl
    and post your thoughts here.

As a side note, there is also an alternative approach to avoid dependency tracking. It requires the user to write their models in several smaller Turing programs, and run a different sampler on each Turing program, in a way similar to JAGS, then "glue" together inference results from these smaller models. It only requires a relatively small amount of work to support this approach after the MCMC Interface PR (https://github.com/TuringLang/Turing.jl/pull/793) is merged. I don't really like this approach because it requires the user to break one model into several smaller programs. But it loosely fits into the "models as code" philosophy, in the sense that it encourages modularity in modelling, and encourages building complex models by composing common modelling parts if possible.

Please let me know if any parts of the above plan are unclear, or if you have any thoughts or suggestions!

Thank you for taking the time to write a detailed reply. It looks like some really exciting new features are on the horizon. I realize that this might be difficult to answer, but do you have a rough idea of when Jags-style sampling might be implemented? Approximately six months, or a year? This will help me plan and prioritize some projects, including the benchmarking work I am doing with Rob. Thanks!

We're targeting 3-6 months, but it might take a bit longer.

For the record, the second example in the initial post (an important case for my work) takes about 2 hours to run, and the trace plots of some parameters look as follows:

[Image: Screenshot 2019-09-09 at 14 26 45]

On a related note, I also want to point out that the Hidden Markov Model from the tutorial produces very low effective sample sizes, mostly below 10.

Summary Statistics

│ Row │ parameters │ mean      │ std         │ naive_se    │ mcse       │ ess     │ r_hat    │
│     │ Symbol     │ Float64   │ Float64     │ Float64     │ Float64    │ Any     │ Any      │
├─────┼────────────┼───────────┼─────────────┼─────────────┼────────────┼─────────┼──────────┤
│ 1   │ T[1][1]    │ 0.60352   │ 0.0305084   │ 0.00096476  │ 0.00964355 │ 4.23888 │ 1.59418  │
│ 2   │ T[1][2]    │ 0.309543  │ 0.0206837   │ 0.000654076 │ 0.00630832 │ 6.18149 │ 1.26903  │
│ 3   │ T[1][3]    │ 0.086937  │ 0.0135024   │ 0.000426984 │ 0.00439716 │ 4.01606 │ 1.88707  │
│ 4   │ T[2][1]    │ 0.706185  │ 0.0210481   │ 0.0006656   │ 0.00628791 │ 6.92471 │ 1.04481  │
│ 5   │ T[2][2]    │ 0.253944  │ 0.0181811   │ 0.000574936 │ 0.00547099 │ 7.6714  │ 0.999274 │
│ 6   │ T[2][3]    │ 0.0398708 │ 0.00523195  │ 0.000165449 │ 0.00158937 │ 4.01606 │ 2.09816  │
│ 7   │ T[3][1]    │ 0.430283  │ 0.0183518   │ 0.000580334 │ 0.00535442 │ 4.60138 │ 1.64891  │
│ 8   │ T[3][2]    │ 0.450252  │ 0.0186215   │ 0.000588864 │ 0.00555454 │ 4.5526  │ 1.5442   │
│ 9   │ T[3][3]    │ 0.119464  │ 0.00988752  │ 0.000312671 │ 0.0029534  │ 7.08331 │ 1.00227  │
│ 10  │ m[1]       │ 2.30276   │ 0.16282     │ 0.00514881  │ 0.0352831  │ 6.55215 │ 1.03373  │
│ 11  │ m[2]       │ 0.991943  │ 0.0645865   │ 0.00204241  │ 0.0153109  │ 10.7751 │ 1.04687  │
│ 12  │ m[3]       │ 0.159171  │ 0.148796    │ 0.00470534  │ 0.0471829  │ 4.01606 │ 1.76961  │
│ 13  │ s[1]       │ 1.994     │ 0.0772656   │ 0.00244335  │ 0.006      │ 6.49518 │ 1.00505  │
│ 14  │ s[2]       │ 1.991     │ 0.113719    │ 0.0035961   │ 0.009      │ 7.81415 │ 1.00528  │
│ 15  │ s[3]       │ 1.993     │ 0.0834144   │ 0.00263779  │ 0.007      │ 6.96785 │ 1.00607  │
│ 16  │ s[4]       │ 1.991     │ 0.0944877   │ 0.00298796  │ 0.009      │ 6.4939  │ 1.00811  │

Yeah. PG seems to perform poorly on that model. I suppose the number of samples could be increased, but it would slow it down more.

Hi @yebai. Just out of curiosity, I was wondering if there are any status updates?

Hi @itsdfish, there has been promising progress towards this goal, e.g.

  • @mohamed82008 did a significant refactoring of Turing's compiler #965
  • the work on DynamicPPL, which is the new home to Turing's compiler and tracing data structures #1042
  • @devmotion recently implemented elliptical slice sampling #991, which enables ESS within Turing's Gibbs sampler

These PRs are gradually paving the way for a JAGS-style sampler. There is still one important missing part: being able to represent and manipulate dynamic computational graphs in order to derive Gibbs conditionals automatically. It is quite hard to implement this in a generic way, and @phipsgabler is still working on it in DynamicComputationGraphs.jl.

Also, @mohamed82008 found a way to use caching to speed up Gibbs substantially. This has a similar spirit to DynamicComputationGraphs in terms of saving unnecessary computation in Gibbs. See performance tips. We might automate this caching, or make it substantially easier to use (in fact, it's already easy to use) to provide efficient JAGS-style sampling.
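The general idea of saving unnecessary computation in Gibbs can be sketched on the regression example from the original post. This is a hypothetical illustration in plain Julia, not the caching mechanism described in Turing's performance tips; `full_mu` and `update_mu!` are made-up names. When a Gibbs step changes only `beta[j]`, the cached linear predictor can be corrected in O(n) instead of being recomputed in O(n * p).

```julia
using Random

# Recompute the full linear predictor from scratch: O(n * p) work.
full_mu(alpha, X, beta) = alpha .+ X * beta

# Correct a cached linear predictor in place after changing only beta[j]: O(n) work.
function update_mu!(mu, X, beta, j, newval)
    mu .+= (newval - beta[j]) .* view(X, :, j)
    beta[j] = newval
    return mu
end

rng = MersenneTwister(1)
n, p = 500, 20
X = rand(rng, n, p)
beta = randn(rng, p)

mu = full_mu(0.0, X, beta)          # computed once and cached
update_mu!(mu, X, beta, 3, 0.5)     # cheap update when only beta[3] changes

# The cached value agrees with a full recomputation up to floating-point error.
println(maximum(abs.(mu .- full_mu(0.0, X, beta))))
```

Knowing that `mu` depends on `beta[j]` only through column `j` of `X` is again a piece of dependency information, which is why automating this kind of caching is related to the dependency-tracking work.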

Perhaps improving the compiler to automate caching could be an interesting GSoC project?
@mohamed82008 @cpfiffer

That might be a good one yes. Refactoring Gibbs sampling using traits might also be a good one. Personally though, my availability this summer might be a bit limited because I am having my wedding in July. So it will be hard to commit to any work in July. Let's see. I can write the proposal for now and let's worry about mentoring logistics later.

Congrats @mohamed82008!

Thanks :)

As a short update, @phipsgabler is working on a PR for Turing that implements an interface for Gibbs conditionals. Feel free to comment and help if you feel like it.
See: https://github.com/TuringLang/Turing.jl/pull/1172

And in the near future, there will even be a JAGS-style Gibbs sampler. It needs a bit more work, but it seems that Philipp is making good progress.

Hello-

Out of curiosity, can you provide a status update? Thanks!

Sure.

We recently merged the PR that allows users to use custom Gibbs conditionals, and Philipp is currently finishing up his work on AutoGibbs, which automatically computes Gibbs conditionals for discrete RVs in any Turing model. The AutoGibbs code passes the tests for simpler models at the moment and will hopefully work for dynamic models soon too. It shouldn't take too long now.
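For a discrete variable, a Gibbs conditional can be computed by evaluating the joint density at each value in its support and normalising. The sketch below does this by hand for one indicator `ind[j]` of the regression model in the original post. It is plain Julia under stated assumptions, not the AutoGibbs implementation; `lognormpdf` and `prob_ind_one` are hypothetical helper names.

```julia
using Random

# Log density of Normal(m, s) at x, written out to avoid extra dependencies.
lognormpdf(x, m, s) = -0.5 * ((x - m) / s)^2 - log(s) - 0.5 * log(2pi)

# Probability that ind[j] == 1 given all other variables, obtained by
# evaluating the factors of the joint that depend on ind[j] at z = 0 and z = 1.
function prob_ind_one(j, ind, betaT, alpha, sigma, pind, X, y)
    logp = zeros(2)
    for z in 0:1
        beta = ind .* betaT
        beta[j] = z * betaT[j]
        mu = alpha .+ X * beta
        logprior = z == 1 ? log(pind) : log1p(-pind)
        logp[z + 1] = logprior + sum(lognormpdf.(y, mu, sigma))
    end
    m = maximum(logp)                     # subtract the maximum for numerical stability
    return exp(logp[2] - m) / sum(exp.(logp .- m))
end

# Synthetic data in which only the first predictor matters.
rng = MersenneTwister(1)
n = 200
X = rand(rng, n, 2)
y = 2.0 .* X[:, 1] .+ 0.1 .* randn(rng, n)

betaT = [2.0, 2.0]
ind = [1, 0]
p1 = prob_ind_one(1, ind, betaT, 0.0, 0.1, 0.2, X, y)
p2 = prob_ind_one(2, ind, betaT, 0.0, 0.1, 0.2, X, y)
println("P(ind[1] = 1 | rest) = ", p1)
println("P(ind[2] = 1 | rest) = ", p2)
```

A Gibbs step for `ind[j]` would then draw `rand(rng) < prob_ind_one(...)`. Doing this automatically requires working out which factors of the joint depend on each discrete variable.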
