Turing.jl: Flux based AD is slow

Created on 7 Sep 2018 · 9 Comments · Source: TuringLang/Turing.jl

using Turing, LinearAlgebra

# Sample a precision matrix A from a Wishart distribution
# with identity scale matrix and 250 degrees of freedom
dim2 = 250
A   = rand(Wishart(dim2, Matrix{Float64}(I, dim2, dim2)))
d   = MvNormal(zeros(dim2), A)

# 250-dimensional multivariate normal (MVN)
@model mdemo(d, N) = begin
    Θ = Vector(undef, N)
    for n = 1:N
        Θ[n] ~ d
    end
end


ForwardDiff AD

Turing.setadbackend(:forward_diff)
chain = sample(mdemo(d, 1), HMC(5000, 0.1, 5)) # takes 3 mins

Flux AD

Turing.setadbackend(:reverse_diff)
chain = sample(mdemo(d, 1), HMC(5000, 0.1, 5)) # EST: more than 3 hrs

All 9 comments

P.S. These runtimes were measured on the master branch.

@willtebbutt @xukai92

In this particular case, this is likely due to the way Flux's AD currently interacts with cholesky and the like. It's reasonable to expect that this poor performance won't be universal. To get an idea of how general this problem is, we really need a large suite of benchmarks that ranges from large models down to models such as the one you've constructed here, which uses only a single component distribution.

Thanks @willtebbutt. Here are a few more test/benchmarking examples for AD. These examples are relatively simple, but hopefully they will be useful for catching issues in our current AD and HMC implementations.

using Distributions, LinearAlgebra
using Turing

@model mdemo(d, N) = begin
    Θ = Vector(undef, N)
    for n = 1:N
        Θ[n] ~ d
    end
end

# Example 1

dim2 = 250
A   = rand(Wishart(dim2, Matrix{Float64}(I, dim2, dim2)))
d   = MvNormal(zeros(dim2), A)

Turing.setadbackend(:forward_diff)
chain = sample(mdemo(d, 1), HMC(5000, 0.1, 5)) # takes 3 mins

Turing.setadbackend(:reverse_diff)
chain = sample(mdemo(d, 1), HMC(5000, 0.1, 5)) # EST: more than 3 hrs

# Example 2
dim2 = 250
d = Normal(0, 10*rand())

Turing.setadbackend(:forward_diff)
chain = sample(mdemo(d, dim2), HMC(5000, 0.1, 5)) # takes ~15 mins

Turing.setadbackend(:reverse_diff)
chain = sample(mdemo(d, dim2), HMC(5000, 0.1, 5)) # takes ~2 mins

Last I checked, Flux AD was not type stable. I'm not sure whether that's a major performance issue, and Mike Innes is presumably aware of it, but I found that the only two reverse-mode AD options that give type-stable gradient functions are ReverseDiff and Zygote. Making a benchmark suite for all AD options for HMC seems like an interesting (research?) idea. My understanding is that Zygote, Capstan, or a hybrid will replace the current Flux AD's internals at some point without changing the interface. So the advantage of using Flux's AD is not necessarily speed so much as API stability, compatibility with Flux, and "guaranteed" performance improvements as internals get swapped out in the hopefully not too distant future.

@MikeInnes is this a known issue?

For a better understanding, the first step would be building a benchmarking script that automatically tests against all basic differentiable continuous distributions in Distributions.jl. Here is a possible benchmarking set.

For each differentiable distribution, run and time the following tasks:

  • native logpdf function, i.e. the one from Distributions.jl
  • the native AD, i.e. ADBackend.gradient(Distributions.logpdf, x)
  • the Turing version of logpdf, e.g. the one in ad.jl
  • the Turing version of gradient
  • Turing's HMC sampler.
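
A minimal sketch of the first two items in this list, assuming ForwardDiff as the AD backend and only univariate distributions. The helper name `time_logpdf_and_grad` and the choice of distributions are hypothetical, used here for illustration; the Turing-specific items (Turing's logpdf, Turing's gradient, the HMC sampler) are not covered.

```julia
using Distributions, ForwardDiff

# Hypothetical helper (not part of Turing or Distributions): time n evaluations
# of logpdf and of its ForwardDiff derivative for a univariate distribution d.
function time_logpdf_and_grad(d::ContinuousUnivariateDistribution; n::Int = 10_000)
    x = rand(d)
    f = y -> logpdf(d, y)

    # Warm-up calls so that compilation time is not included in the timings.
    f(x)
    ForwardDiff.derivative(f, x)

    # Accumulate the results so the loops are not optimized away.
    acc = 0.0
    t_logpdf = @elapsed for _ in 1:n
        acc += f(x)
    end
    t_grad = @elapsed for _ in 1:n
        acc += ForwardDiff.derivative(f, x)
    end
    return (dist = string(typeof(d)), t_logpdf = t_logpdf, t_grad = t_grad, acc = acc)
end

# Illustrative subset only; a real suite would loop over many more distributions.
for d in (Normal(0.0, 1.0), Laplace(0.0, 1.0), Cauchy(0.0, 1.0))
    r = time_logpdf_and_grad(d)
    println(r.dist, ": logpdf ", r.t_logpdf, " s, derivative ", r.t_grad, " s")
end
```

For more careful measurements, a dedicated benchmarking package would likely be a better choice than `@elapsed` loops; the loop form is used here only to keep the sketch dependency-light.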

I don't know Turing well, so it would be useful if you could break this down to a plain-Julia example so I can take a closer look. Reverse mode certainly has more overhead than forward for problems with a very small input dimension, so if that's the case and you're calling gradient millions of times, then those numbers could be reasonable. Other than that, I've generally found Tracker to be faster than other AD options.

There's a good chance @willtebbutt's guess is right though. Flux's current set of gradient definitions is relatively ML-focused, so we don't have things like cholesky; those things will either break or (perhaps worse) fall back to a scalar implementation, which throws you off a performance cliff. The good news is that if that is the issue, it's very easy to add those definitions.
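
One possible plain-Julia reduction of the first example, as a sketch: a zero-mean multivariate normal log-density written directly against a Cholesky factor, together with its analytic gradient. The names `mvnormal_logpdf` and `mvnormal_grad` are hypothetical helpers, not functions from Turing, Distributions, or Flux, and the matrix `X' * X` is only a stand-in for the Wishart draw in the original example. Timing these against an AD backend applied to the same function could help separate the cost of the log-density itself from AD overhead.

```julia
using LinearAlgebra

# Hypothetical helper (zero-mean case only):
# log-density of N(0, Σ) at x, given C = cholesky(Σ).
function mvnormal_logpdf(C::Cholesky, x::AbstractVector)
    z = C.L \ x   # solve L z = x, so dot(z, z) == x' Σ⁻¹ x
    return -0.5 * (length(x) * log(2π) + logdet(C) + dot(z, z))
end

# Analytic gradient of the log-density with respect to x: -Σ⁻¹ x.
mvnormal_grad(C::Cholesky, x::AbstractVector) = -(C \ x)

dim = 250
X = randn(dim, dim)
A = Symmetric(X' * X)   # stand-in for a Wishart(dim, I) draw (assumption)
C = cholesky(A)
x = randn(dim)

# Call once to compile, then time the second call.
mvnormal_logpdf(C, x); mvnormal_grad(C, x)
@time mvnormal_logpdf(C, x)
@time mvnormal_grad(C, x)
```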

Update:

using Turing, LinearAlgebra
using Flux.Tracker

# Sample a precision matrix A from a Wishart distribution
# with identity scale matrix and 250 degrees of freedom
dim2 = 250
A   = rand(Wishart(dim2, Matrix{Float64}(I, dim2, dim2)))
d   = MvNormal(zeros(dim2), A)

xp1 = [param(0) for i=1:dim2]
logpdf(d, xp1)  # This works

xp2 = param(zeros(dim2))
logpdf(d,  xp2) # This doesn't work, see below for error message.

This returns the following error:

julia> logpdf(d,  xp2)
ERROR: MethodError: no method matching invquad(::PDMats.PDMat{Float64,Array{Float64,2}}, ::TrackedArray{…,Array{Float64,1}})
Closest candidates are:
  invquad(::PDMats.PDMat, ::Union{DenseArray{T,1}, ReinterpretArray{T,1,S,A} where S where A<:Union{SubArray{T,N,A,I,true} where I<:Tuple{AbstractUnitRange,Vararg{Any,N} where N} where A<:DenseArray where N where T, DenseArray}, ReshapedArray{T,1,A,MI} where MI<:Tuple{Vararg{SignedMultiplicativeInverse{Int64},N} where N} where A<:Union{ReinterpretArray{T,N,S,A} where S where A<:Union{SubArray{T,N,A,I,true} where I<:Tuple{AbstractUnitRange,Vararg{Any,N} where N} where A<:DenseArray where N where T, DenseArray} where N where T, SubArray{T,N,A,I,true} where I<:Tuple{AbstractUnitRange,Vararg{Any,N} where N} where A<:DenseArray where N where T, DenseArray}, SubArray{T,1,A,I,L} where L where I<:Tuple{Vararg{Union{Int64, AbstractRange{Int64}, AbstractCartesianIndex},N} where N} where A<:Union{ReinterpretArray{T,N,S,A} where S where A<:Union{SubArray{T,N,A,I,true} where I<:Tuple{AbstractUnitRange,Vararg{Any,N} where N} where A<:DenseArray where N where T, DenseArray} where N where T, ReshapedArray{T,N,A,MI} where MI<:Tuple{Vararg{SignedMultiplicativeInverse{Int64},N} where N} where A<:Union{ReinterpretArray{T,N,S,A} where S where A<:Union{SubArray{T,N,A,I,true} where I<:Tuple{AbstractUnitRange,Vararg{Any,N} where N} where A<:DenseArray where N where T, DenseArray} where N where T, SubArray{T,N,A,I,true} where I<:Tuple{AbstractUnitRange,Vararg{Any,N} where N} where A<:DenseArray where N where T, DenseArray} where N where T, DenseArray}} where T) at /Users/hg344/.julia/packages/PDMats/mL7bX/src/pdmat.jl:67
  invquad(::PDMats.AbstractPDMat{T<:Real}, ::Union{DenseArray{S<:Real,2}, ReinterpretArray{S<:Real,2,S,A} where S where A<:Union{SubArray{T,N,A,I,true} where I<:Tuple{AbstractUnitRange,Vararg{Any,N} where N} where A<:DenseArray where N where T, DenseArray}, ReshapedArray{S<:Real,2,A,MI} where MI<:Tuple{Vararg{SignedMultiplicativeInverse{Int64},N} where N} where A<:Union{ReinterpretArray{T,N,S,A} where S where A<:Union{SubArray{T,N,A,I,true} where I<:Tuple{AbstractUnitRange,Vararg{Any,N} where N} where A<:DenseArray where N where T, DenseArray} where N where T, SubArray{T,N,A,I,true} where I<:Tuple{AbstractUnitRange,Vararg{Any,N} where N} where A<:DenseArray where N where T, DenseArray}, SubArray{S<:Real,2,A,I,L} where L where I<:Tuple{Vararg{Union{Int64, AbstractRange{Int64}, AbstractCartesianIndex},N} where N} where A<:Union{ReinterpretArray{T,N,S,A} where S where A<:Union{SubArray{T,N,A,I,true} where I<:Tuple{AbstractUnitRange,Vararg{Any,N} where N} where A<:DenseArray where N where T, DenseArray} where N where T, ReshapedArray{T,N,A,MI} where MI<:Tuple{Vararg{SignedMultiplicativeInverse{Int64},N} where N} where A<:Union{ReinterpretArray{T,N,S,A} where S where A<:Union{SubArray{T,N,A,I,true} where I<:Tuple{AbstractUnitRange,Vararg{Any,N} where N} where A<:DenseArray where N where T, DenseArray} where N where T, SubArray{T,N,A,I,true} where I<:Tuple{AbstractUnitRange,Vararg{Any,N} where N} where A<:DenseArray where N where T, DenseArray} where N where T, DenseArray}}) where {T<:Real, S<:Real} at /Users/hg344/.julia/packages/PDMats/mL7bX/src/generics.jl:101
Stacktrace:
 [1] sqmahal(::MvNormal{Float64,PDMats.PDMat{Float64,Array{Float64,2}},Array{Float64,1}}, ::TrackedArray{…,Array{Float64,1}}) at /Users/hg344/.julia/packages/Distributions/WHjOk/src/multivariate/mvnormal.jl:263
 [2] _logpdf(::MvNormal{Float64,PDMats.PDMat{Float64,Array{Float64,2}},Array{Float64,1}}, ::TrackedArray{…,Array{Float64,1}}) at /Users/hg344/.julia/packages/Distributions/WHjOk/src/multivariate/mvnormal.jl:113
 [3] logpdf(::MvNormal{Float64,PDMats.PDMat{Float64,Array{Float64,2}},Array{Float64,1}}, ::TrackedArray{…,Array{Float64,1}}) at /Users/hg344/.julia/packages/Distributions/WHjOk/src/multivariates.jl:175
 [4] top-level scope at none:0

Closed in favour of #615
