Turing.jl: Distributions not compatible with AD (ForwardDiff/Flux.Tracker)

Created on 6 Dec 2018 · 5 comments · Source: TuringLang/Turing.jl

  • [ ] MvNormal (not working with Flux), see #511
  • [x] Poisson, see #614
  • [x] Binomial, see #600
Labels: enhancement, help-wanted


All 5 comments

Three possible solutions:

  1. Define our own version of these functions for ForwardDiff and Flux.Tracker types. These should ideally be in Distributions.jl to avoid type piracy, with Requires.jl to avoid adding extra dependencies there,
  2. Make a pure Julia implementation of these distributions and PR them to Distributions.jl; Poisson at least seems to use ccall, or
  3. Use finite difference autodiff by default as Optim.jl does using DiffEqDiffTools.jl.

> Make a pure Julia implementation of these distributions and PR them to Distributions.jl; Poisson at least seems to use ccall

At least for MvNormal, the problem isn't that the implementation isn't pure Julia. The implementation in Distributions.jl places very strong type constraints on the mean vector in particular. I made a start on rectifying this for MvNormal (see this PR), but having dug slightly further into the internals, it looks like a reasonable amount of work would be required to sort things out properly without making breaking changes to Distributions.jl, which presumably isn't an option.

> Use finite difference autodiff by default as Optim.jl does using DiffEqDiffTools.jl.

I'm not keen on this because it's inaccurate in general, and not well suited to high-dimensional problems.

> Define our own version of these functions for ForwardDiff and Flux.Tracker types. These should ideally be in Distributions.jl to avoid type piracy, with Requires.jl to avoid adding extra dependencies there

We've already got this for Binomial and #614 provides something sensible for Poisson, so if someone turns that into a PR (@cpfiffer ?) that should be straightforward to resolve. MvNormal will remain slightly tricky without some (potentially slightly awkward) refactoring of Distributions.jl though.

@yebai @mohamed82008 what are your thoughts?

I prefer the solution of fixing these issues on the Distributions.jl and DiffRules side, through either pure Julia implementation of logpdf functions or customised gradient rules. This way these changes can also be helpful for other packages such as Mamba and Stheno.

For the short term, we can keep these (temporary) fixes, including #600 and #614, in the Turing repo and remove them once these issues are addressed in Distributions and the related AD packages.
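To make the "customised gradient rules" option concrete, here are the log-densities of the two discrete distributions mentioned in this issue together with the derivatives such a rule would have to return. These are the standard textbook formulas, not a transcription of #600 or #614; the symbols $\lambda$ (Poisson rate) and $\theta$ (Binomial success probability) are notation chosen for this sketch.

```latex
\begin{aligned}
\log p(k \mid \lambda) &= k \log \lambda - \lambda - \log k!,
  & \frac{\partial}{\partial \lambda} \log p(k \mid \lambda) &= \frac{k}{\lambda} - 1, \\
\log p(k \mid n, \theta) &= \log \binom{n}{k} + k \log \theta + (n - k) \log(1 - \theta),
  & \frac{\partial}{\partial \theta} \log p(k \mid n, \theta) &= \frac{k}{\theta} - \frac{n - k}{1 - \theta}.
\end{aligned}
```

In both cases the observed count $k$ is data, so the terms $\log k!$ and $\log \binom{n}{k}$ are constants with respect to the parameter and drop out of the gradient.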

> The implementation in Distributions.jl places very strong type constraints on the mean vector in particular.

When I am done with the compiler, I can give this a look if you haven't finished it by then.

> I'm not keen on this because it's inaccurate in general, and not well suited to high-dimensional problems.

Finite differencing and forward-mode differentiation have the same complexity, but the latter is obviously more accurate and can be faster by some non-trivial constant factor. I am not a huge fan of finite differencing myself, but it is a fault-proof fallback that lets me focus more on the model and less on generality, at least when playing around with Optim.jl for functions that are not fully generic.
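The accuracy point can be illustrated with a small, language-agnostic sketch in Python (not Julia, and not how Optim.jl or DiffEqDiffTools.jl implement it). It applies a one-sided finite difference to the Poisson log-pmf and compares the estimate with the analytic derivative `k / lam - 1`; the names `poisson_logpdf` and `forward_difference` and the test point are made up for the example.

```python
import math


def poisson_logpdf(lam, k):
    """log P(K = k) for K ~ Poisson(lam), written with plain arithmetic."""
    return k * math.log(lam) - lam - math.lgamma(k + 1)


def forward_difference(f, x, h):
    """One-sided finite-difference estimate of f'(x) with step h."""
    return (f(x + h) - f(x)) / h


lam, k = 3.5, 4          # arbitrary test point, chosen for illustration
exact = k / lam - 1.0    # analytic derivative of the log-pmf w.r.t. lam

errors = {}
for h in (1e-2, 1e-6, 1e-10):
    approx = forward_difference(lambda x: poisson_logpdf(x, k), lam, h)
    errors[h] = abs(approx - exact)
    print(f"h={h:.0e}  estimate={approx:.12f}  abs error={errors[h]:.2e}")
```

A large step suffers from truncation error and a very small step from floating-point cancellation, so the step size has to be tuned. Forward-mode AD has no step size: it propagates exact derivative rules and is accurate up to rounding.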

> We've already got this for Binomial and #614 provides something sensible for Poisson

I think that, at least for Poisson, we will also need a similar workaround for ForwardDiff, which is the default now.
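Why a generically written log-pmf is enough for forward-mode AD can be sketched outside Julia. The Python below is an illustration only: `Dual` is a hypothetical stand-in for a dual-number type such as the one ForwardDiff uses, and `generic_log` and `poisson_logpdf` are names invented for the example, not the workaround in #614. Because the log-pmf uses only arithmetic that `Dual` overloads, the derivative with respect to the rate comes out of an ordinary function call.

```python
import math


class Dual:
    """Minimal dual number val + der*eps (eps**2 == 0) for forward-mode AD."""

    def __init__(self, val, der=0.0):
        self.val = val
        self.der = der

    @staticmethod
    def lift(x):
        # Treat plain numbers as constants (zero derivative).
        return x if isinstance(x, Dual) else Dual(x)

    def __add__(self, other):
        other = Dual.lift(other)
        return Dual(self.val + other.val, self.der + other.der)

    __radd__ = __add__

    def __sub__(self, other):
        other = Dual.lift(other)
        return Dual(self.val - other.val, self.der - other.der)

    def __rsub__(self, other):
        return Dual.lift(other) - self

    def __mul__(self, other):
        other = Dual.lift(other)
        # Product rule for the derivative part.
        return Dual(self.val * other.val,
                    self.der * other.val + self.val * other.der)

    __rmul__ = __mul__


def generic_log(x):
    """log that accepts both plain floats and Dual numbers."""
    if isinstance(x, Dual):
        return Dual(math.log(x.val), x.der / x.val)
    return math.log(x)


def poisson_logpdf(lam, k):
    """Poisson log-pmf; lam may be a float or a Dual, k is an observed count."""
    return k * generic_log(lam) - lam - math.lgamma(k + 1)


# Seed the derivative part with 1.0 to differentiate with respect to lam.
result = poisson_logpdf(Dual(3.5, 1.0), 4)
print(result.val, result.der)
```

Here `result.der` matches the analytic derivative `k / lam - 1`. A version that hands `lam` to a routine accepting only plain floats (as `math.log` does in Python) would fail on a `Dual` input instead.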

Overall, I agree with @yebai that we need a quick fix, which can use some type piracy, and a long-term solution that is more Julian. We can perhaps do a similar workaround for MvNormal for now.
