Turing.jl: Multi-threaded sampling with reversediff backend and rdcache gives bad samples

Created on 23 Sep 2020 · 17 Comments · Source: TuringLang/Turing.jl

The following are the posterior plots for the coin flip Turing example:
(image: posterior density plots)
As we can see, the posterior from multi-threaded sampling with rdcache enabled is not at all close to the others.
Code to reproduce this:

# Load the modules
using Turing, MCMCChains, Distributions, StatsPlots, Random
using ReverseDiff, Memoization, Zygote

println("loaded modules")
# Set the true probability of heads in a coin.
p_true = 0.5

# Iterate from having seen 0 observations to 100 observations.
Ns = 0:100
Random.seed!(12)
data = rand(Bernoulli(p_true), last(Ns))
# define the model
@model coinflip(y) = begin
    p ~ Beta(1, 1)
    N = length(y)
    for n in 1:N
        y[n] ~ Bernoulli(p)
    end
end
# parameters
num_chains = 10
iterations = 100
model_coin = coinflip(data)

# sampling
chain_serial = mapreduce(c -> sample(model_coin, NUTS(0.65), iterations), chainscat, 1:num_chains)
chain_multithread = sample(model_coin, NUTS(0.65), MCMCThreads(), iterations, num_chains)

# now enable Zygote backend
Turing.setadbackend(:zygote)
chain_zymultithread = sample(model_coin, NUTS(0.65), MCMCThreads(), iterations, num_chains)

# now enable ReverseDiff backend without rdcache
Turing.setadbackend(:reversediff)
Turing.setrdcache(false)
chain_rdmultithread = sample(model_coin, NUTS(0.65), MCMCThreads(), iterations, num_chains)

# with rdcache
Turing.setrdcache(true)
chain_rdcachemultithread = sample(model_coin, NUTS(0.65), MCMCThreads(), iterations, num_chains)

# plot the combined density of all those chains
density(chain_serial[:p][:], label="serial", legend=:topleft)
density!(chain_multithread[:p][:], label="multi-thread-normal", legend=:topleft)
density!(chain_zymultithread[:p][:], label="multi-thread-zygote", legend=:topleft)
density!(chain_rdmultithread[:p][:], label="multi-thread-reversediff", legend=:topleft)
density!(chain_rdcachemultithread[:p][:], label="multi-thread-rdcache", legend=:topleft)

savefig("mwe.png")

My configuration

  JULIA_NUM_THREADS=4
  [31c24e10] Distributions v0.23.12
  [ced4e74d] DistributionsAD v0.6.9
  [c7f686f2] MCMCChains v4.2.1
  [6fafb56a] Memoization v0.1.4
  [37e2e3b7] ReverseDiff v1.4.3
  [f3b207a7] StatsPlots v0.14.13
  [fce5fe82] Turing v0.14.3
  [e88e6eb3] Zygote v0.5.7

Also posted in the Slack channel.

All 17 comments

@mohamed82008

Maybe related to the other ReverseDiff/Memoization issues (see, e.g., https://github.com/TuringLang/Turing.jl/issues/1393). According to its README, Memoization is also not thread-safe (see https://github.com/marius311/Memoization.jl), so I'm not surprised that multithreaded sampling is problematic. Does ReverseDiff work without Memoization (you have to restart the Julia process to make sure there's no memoized stuff floating around anymore)?

@devmotion Yes, it works fine. I will update the plot and the MWE to include it.

Then it seems the issue here is that Memoization is not thread-safe.

Thank you @devmotion. I was wondering if we could have a warning or error message for multi-threaded sampling when rdcache is enabled. It could be something like this:
Warning: Memoization isn't thread-safe. Please don't use rdcache with multi-threaded sampling.

I think maybe it's possible to fix the problem on our side by using a Dict for memoization and adding the thread id to the list of keys. So I would like to check that first before adding a warning.
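
A minimal sketch of that idea, not Turing's actual implementation: a plain Dict keyed by `(thread id, key)` so that each thread gets its own cached object. The names `BUFFER_CACHE`, `BUFFER_CACHE_LOCK` and `thread_local_buffer` are hypothetical, and a scratch vector stands in for whatever would really be memoized. Note that `threadid()` is only a stable key while a task stays on one thread, which is why the loop uses the `:static` schedule.

```julia
using Base.Threads: threadid, @threads

# Hypothetical cache: maps (thread id, user key) to an object owned by that thread.
const BUFFER_CACHE = Dict{Tuple{Int,Symbol},Vector{Float64}}()
# Dict itself is not thread-safe, so lookups and insertions are guarded by a lock.
const BUFFER_CACHE_LOCK = ReentrantLock()

# Return the buffer owned by the calling thread for `key`, creating it on first use.
function thread_local_buffer(key::Symbol, n::Integer)
    lock(BUFFER_CACHE_LOCK) do
        get!(() -> zeros(n), BUFFER_CACHE, (threadid(), key))
    end
end

# Each iteration mutates only the buffer belonging to the thread it runs on.
ran_on = zeros(Int, 8)
@threads :static for i in 1:8
    buf = thread_local_buffer(:scratch, 4)
    buf .+= 1.0
    ran_on[i] = threadid()
end
```

The lock protects only the Dict; the cached buffers need no lock because no two threads ever receive the same one.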

Thanks for the suggestion. Could you point me to an example that does this?

I think this is fixable too. I am actually surprised this is not working at all, because even a race condition shouldn't affect the result: all the threads are writing the same compiled tape.

> Could you point me to an example that does this?

It's something in Turing that has to be changed (if it fixes the issue).

Hmm ok I think I know what's going on. It's not the memoization, it's ReverseDiff. The compiled tape has cache fields re-used every time the tape is differentiated. Different compiled tapes for different threads should solve the problem here.
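
To illustrate the point about separate tapes, here is a rough sketch using ReverseDiff directly. It is not the fix that landed in Turing. The function `f` is a toy stand-in for a log density, and building one tape per chain (rather than per thread) is an assumption made here to keep the example simple: it still guarantees that no two concurrent workers share a tape's internal cache.

```julia
using ReverseDiff

# Toy stand-in for a log density; its gradient is 2x.
f(x) = sum(x .* x)

nchains = 4
inputs = [fill(Float64(c), 3) for c in 1:nchains]
grads = [zeros(3) for _ in 1:nchains]

# Build one compiled tape per chain up front, serially.
# Each compiled tape owns its own internal buffers.
tapes = [ReverseDiff.compile(ReverseDiff.GradientTape(f, zeros(3))) for _ in 1:nchains]

# Every worker differentiates through its own tape, so nothing is overwritten
# by another worker in the middle of a forward or reverse pass.
Threads.@threads for c in 1:nchains
    ReverseDiff.gradient!(grads[c], tapes[c], inputs[c])
end
```

Replacing `tapes[c]` with a single shared compiled tape would reproduce the kind of sharing described above, where concurrent gradient calls write into the same cache fields.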

Yep, that's what I suggested above :slightly_smiling_face:

Yes I was agreeing with you :)

Closing this as #1414 fixes it.

Turing 0.14.4 is available now which should contain the fix for this problem.

@devmotion The repository still shows 0.14.3 as the latest release. Does it take some time to update?

Tags for the git repo are created automatically at midnight every day. For the Julia package manager it is only relevant that https://github.com/JuliaRegistries/General/pull/21909 was merged - as soon as the registry is updated users are able to update Turing with Pkg.
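
For reference, once the registry has been updated, picking up the new release from the Julia REPL would look roughly like this (assuming Turing is already in the active environment and nothing else pins it to an older version):

```julia
using Pkg

# Refresh the registry and update Turing to the newest compatible version.
Pkg.update("Turing")

# Show which version is now installed.
Pkg.status("Turing")
```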

Thanks for the clarification!
