using Turing, LinearAlgebra
# Sample a 250x250 positive-definite matrix A from a Wishart distribution
# with identity scale matrix and 250 degrees of freedom (MvNormal(μ, A) uses A as its covariance)
dim2 = 250
A = rand(Wishart(dim2, Matrix{Float64}(I, dim2, dim2)))
d = MvNormal(zeros(dim2), A)
# 250-dimensional multivariate normal (MVN)
@model mdemo(d, N) = begin
    Θ = Vector(undef, N)
    for n=1:N
        Θ[n] ~ d
    end
end
# ForwardDiff AD
Turing.setadbackend(:forward_diff)
chain = sample(mdemo(d, 1), HMC(5000, 0.1, 5)) # takes 3 mins
# Flux AD
Turing.setadbackend(:reverse_diff)
chain = sample(mdemo(d, 1), HMC(5000, 0.1, 5)) # EST: more than 3 hrs
P.S. These runtimes were measured on the master branch.
@willtebbutt @xukai92
In this particular case, this is likely due to the way Flux's AD currently interacts with cholesky and the like. It's reasonable to expect that this poor performance won't be universal. To get an idea of how general the problem is, we really need a large suite of benchmarks ranging from large models down to small ones, such as the one you've constructed here, which uses only a single component distribution.
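To make the Cholesky dependence concrete: the MvNormal log-density needs log|Σ| and the quadratic form (x - μ)ᵀΣ⁻¹(x - μ), both of which are naturally computed from a Cholesky factor Σ = LLᵀ. The function below is a hypothetical plain-Julia sketch using only the LinearAlgebra standard library, not the Distributions.jl/PDMats.jl code. A reverse-mode AD has to differentiate through each of its steps (the factorization, the triangular solve, the log of the diagonal), so a missing gradient definition for any of them is a candidate for the slowdown.

```julia
using LinearAlgebra

# Hypothetical sketch: log-density of N(μ, Σ) at x via the Cholesky factor Σ = L * L'.
function mvnormal_logpdf(μ::AbstractVector, Σ::AbstractMatrix, x::AbstractVector)
    C = cholesky(Symmetric(Σ))          # throws PosDefException if Σ is not positive definite
    z = C.L \ (x .- μ)                  # whitened residual: sum(abs2, z) == (x-μ)' Σ⁻¹ (x-μ)
    logdetΣ = 2 * sum(log, diag(C.L))   # log|Σ| from the diagonal of L
    d = length(x)
    return -(d * log(2π) + logdetΣ + sum(abs2, z)) / 2
end

# Usage: compare against the textbook formula written with a dense solve and logdet.
Σ = [4.0 1.0; 1.0 3.0]
μ = [0.5, -1.0]
x = [1.0, 2.0]
r = x - μ
reference = -(2 * log(2π) + logdet(Σ) + dot(r, Σ \ r)) / 2
println(mvnormal_logpdf(μ, Σ, x), "  vs  ", reference)
```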
Thanks @willtebbutt. Here are a few more test/benchmarking examples for AD. These examples are relatively simple, but hopefully they will be useful for catching issues with our current AD and HMC implementations.
using Distributions, LinearAlgebra
using Turing
@model mdemo(d, N) = begin
    Θ = Vector(undef, N)
    for n=1:N
        Θ[n] ~ d
    end
end
# Example 1
dim2 = 250
A = rand(Wishart(dim2, Matrix{Float64}(I, dim2, dim2)))
d = MvNormal(zeros(dim2), A)
Turing.setadbackend(:forward_diff)
chain = sample(mdemo(d, 1), HMC(5000, 0.1, 5)) # takes 3 mins
Turing.setadbackend(:reverse_diff)
chain = sample(mdemo(d, 1), HMC(5000, 0.1, 5)) # EST: more than 3 hrs
# Example 2
dim2 = 250
d = Normal(0, 10*rand())
Turing.setadbackend(:forward_diff)
chain = sample(mdemo(d, dim2), HMC(5000, 0.1, 5)) # takes ~15 mins
Turing.setadbackend(:reverse_diff)
chain = sample(mdemo(d, dim2), HMC(5000, 0.1, 5)) # takes ~2mins
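Example 2 has 250 scalar parameters, and there ForwardDiff is slower than the reverse-mode backend. One plausible contributor is how forward-mode cost scales: a forward-mode gradient propagates a derivative seed through the whole log-density one input direction at a time, so it needs on the order of n evaluations for n inputs (ForwardDiff.jl batches several directions per pass in "chunks", but the number of passes still grows with n). The sketch below is a hypothetical minimal dual-number implementation, not ForwardDiff.jl, applied to the log-density of iid Normal(0, σ) values as in Example 2.

```julia
# Hypothetical minimal forward-mode AD with dual numbers (not ForwardDiff.jl's implementation).
struct Dual
    val::Float64   # primal value
    der::Float64   # derivative along the current seed direction
end
Base.:+(a::Dual, b::Dual) = Dual(a.val + b.val, a.der + b.der)
Base.:*(a::Dual, b::Dual) = Dual(a.val * b.val, a.der * b.val + a.val * b.der)  # product rule
Base.:*(a::Dual, c::Float64) = Dual(a.val * c, a.der * c)
Base.:-(a::Dual, c::Float64) = Dual(a.val - c, a.der)

# Log-density of n iid Normal(0, σ) values, written generically so it also accepts Duals.
function normal_loglik(x::AbstractVector, σ::Float64)
    acc = x[1] * x[1]
    for i in 2:length(x)
        acc = acc + x[i] * x[i]
    end
    n = length(x)
    return acc * (-1 / (2σ^2)) - n * (log(σ) + log(2π) / 2)
end

# Forward-mode gradient: one full evaluation of f per input dimension.
function forward_gradient(f, x::Vector{Float64})
    n = length(x)
    g = similar(x)
    for j in 1:n
        xd = [Dual(x[i], i == j ? 1.0 : 0.0) for i in 1:n]   # seed the j-th unit direction
        g[j] = f(xd).der
    end
    return g
end

x = [1.0, -2.0, 3.0]
g = forward_gradient(v -> normal_loglik(v, 2.0), x)   # analytic gradient is -x / σ^2
println(g)
```

With n = 250, the loop in forward_gradient evaluates the log-density 250 times per gradient, whereas reverse mode needs one forward and one backward sweep (at a higher per-operation overhead).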
Last I checked, Flux's AD was not type stable. I'm not sure whether that is a major performance issue, since Mike Innes is presumably aware of it, but I found that the only two reverse-mode AD options that give type-stable gradient functions are ReverseDiff and Zygote. Making a benchmark suite for all AD options for HMC seems like an interesting (research?) idea. My understanding is that Zygote, Capstan or a hybrid will replace the current Flux AD's internals at some point without changing the interface. So the advantage of using Flux's AD is not so much speed as API stability, compatibility with Flux, and "guaranteed" performance improvements as the internals get swapped out in the hopefully not-too-distant future.
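Type stability of a gradient function can be checked with @inferred from the Test standard library, which throws an error when the returned value's type differs from the compiler's inferred return type. The two functions below are hypothetical hand-written gradients standing in for AD output; running the same check on a gradient closure produced by each AD package would be one way to confirm the observation above.

```julia
using Test

# Hypothetical stand-ins for AD-generated gradient functions.
# Stable: the gradient of -sum(abs2, x)/2 is -x, always a Vector{Float64}.
stable_grad(x::Vector{Float64}) = -x

# Deliberately unstable: returns the Int 0 when sum(x) <= 0,
# so the inferred return type is Union{Int64, Vector{Float64}}.
unstable_grad(x::Vector{Float64}) = sum(x) > 0 ? -x : 0

x = [1.0, -2.0, 0.5]
@inferred stable_grad(x)                                 # passes: concrete inferred type matches
@test_throws ErrorException @inferred unstable_grad(x)   # inferred Union does not match Int64
```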
@MikeInnes is this a known issue?
For a better understanding, the first step would be to build a benchmarking script that automatically tests against all basic differentiable continuous distributions in Distributions.jl. Here is a possible benchmarking set: for each differentiable distribution, run and time the following tasks:
- the logpdf function, i.e. the one from Distributions.jl
- ADBackend.gradient(Distributions.logpdf, x)
- [...] logpdf, e.g. the one in ad.jl

I don't know Turing well, so it would be useful if you could break this down into a plain-Julia example that I can take a closer look at. Reverse mode certainly has more overhead than forward mode for problems with a very small input dimension, so if that's the case and you're calling gradient millions of times, those numbers could be reasonable. Other than that, I've generally found Tracker to be faster than other AD options.
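The benchmarking plan above (time logpdf, then time its gradient, for each distribution) can be sketched as a small harness. Everything here is hypothetical: BenchCase and run_bench are illustrative names, and the gradients are hand-written stand-ins for where each AD backend's gradient call would go. A warm-up call keeps JIT compilation out of the timings; for serious measurements a dedicated package such as BenchmarkTools.jl is the usual choice.

```julia
# Hypothetical benchmark entry: a log-density, its gradient, and the point to evaluate at.
struct BenchCase
    name::String
    logpdf::Function
    grad::Function
    x::Vector{Float64}
end

# Average time of one logpdf call and one gradient call per case, over `reps` repetitions.
function run_bench(cases::Vector{BenchCase}; reps::Int = 1_000)
    results = NamedTuple{(:name, :logpdf_s, :grad_s), Tuple{String, Float64, Float64}}[]
    for c in cases
        c.logpdf(c.x); c.grad(c.x)   # warm-up so compilation is not included in the timings
        t_f = @elapsed for _ in 1:reps
            c.logpdf(c.x)
        end
        t_g = @elapsed for _ in 1:reps
            c.grad(c.x)
        end
        push!(results, (name = c.name, logpdf_s = t_f / reps, grad_s = t_g / reps))
    end
    return results
end

# Stand-in cases with hand-written gradients; an AD backend's gradient would replace `grad`.
iid_normal = BenchCase("iid Normal(0, 1)",
    x -> -sum(abs2, x) / 2 - length(x) * log(2π) / 2,
    x -> -x,
    randn(250))
iid_exponential = BenchCase("iid Exponential(1)",
    x -> -sum(x),                   # valid for x > 0
    x -> fill(-1.0, length(x)),
    rand(250) .+ 0.1)

for r in run_bench([iid_normal, iid_exponential])
    println(rpad(r.name, 22), " logpdf: ", r.logpdf_s, " s   gradient: ", r.grad_s, " s")
end
```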
There's a good chance @willtebbutt's guess is right though. Flux's current set of gradient definitions is relatively ML-focused, so we don't have things like cholesky; those things will either break or (perhaps worse) fall back to a scalar implementation, which throws you off a performance cliff. The good news is that if that is the issue, it's very easy to add those definitions.
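For the failing MvNormal case, the gradient rule needed is small: ∂/∂x (xᵀΣ⁻¹x) = 2Σ⁻¹x, which can reuse an existing Cholesky factor. Below is a hypothetical plain-Julia sketch of such a rule (a forward function plus a pullback that scales the gradient by the incoming sensitivity Δ), checked against central finite differences. The names are illustrative, and registering the rule with Tracker is not shown.

```julia
using LinearAlgebra

# Hypothetical forward function: x' Σ⁻¹ x computed from a Cholesky factor C of Σ.
function invquad_chol(C::Cholesky, x::AbstractVector)
    z = C.L \ x
    return sum(abs2, z)
end

# Hypothetical pullback for the x argument: d/dx (x' Σ⁻¹ x) = 2 Σ⁻¹ x, scaled by Δ.
invquad_chol_pullback(C::Cholesky, x::AbstractVector, Δ::Real) = 2Δ .* (C \ x)

Σ = [4.0 1.0 0.0; 1.0 3.0 0.5; 0.0 0.5 2.0]   # positive definite
C = cholesky(Symmetric(Σ))
x = [1.0, -2.0, 0.5]

g = invquad_chol_pullback(C, x, 1.0)

# Central finite differences as an independent check of the hand-written rule.
h = 1e-6
unit(i) = [j == i ? 1.0 : 0.0 for j in eachindex(x)]
fd = [(invquad_chol(C, x + h * unit(i)) - invquad_chol(C, x - h * unit(i))) / (2h) for i in eachindex(x)]
println(g, "  vs  ", fd)
```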
Update:
using Turing, LinearAlgebra
using Flux.Tracker
# Sample a 250x250 positive-definite matrix A from a Wishart distribution
# with identity scale matrix and 250 degrees of freedom (MvNormal(μ, A) uses A as its covariance)
dim2 = 250
A = rand(Wishart(dim2, Matrix{Float64}(I, dim2, dim2)))
d = MvNormal(zeros(dim2), A)
xp1 = [param(0) for i=1:dim2]
logpdf(d, xp1) # This works
xp2 = param(zeros(dim2))
logpdf(d, xp2) # This doesn't work, see below for error message.
This returns the following error:
julia> logpdf(d, xp2)
ERROR: MethodError: no method matching invquad(::PDMats.PDMat{Float64,Array{Float64,2}}, ::TrackedArray{…,Array{Float64,1}})
Closest candidates are:
invquad(::PDMats.PDMat, ::Union{DenseArray{T,1}, ReinterpretArray{T,1,S,A} where S where A<:Union{SubArray{T,N,A,I,true} where I<:Tuple{AbstractUnitRange,Vararg{Any,N} where N} where A<:DenseArray where N where T, DenseArray}, ReshapedArray{T,1,A,MI} where MI<:Tuple{Vararg{SignedMultiplicativeInverse{Int64},N} where N} where A<:Union{ReinterpretArray{T,N,S,A} where S where A<:Union{SubArray{T,N,A,I,true} where I<:Tuple{AbstractUnitRange,Vararg{Any,N} where N} where A<:DenseArray where N where T, DenseArray} where N where T, SubArray{T,N,A,I,true} where I<:Tuple{AbstractUnitRange,Vararg{Any,N} where N} where A<:DenseArray where N where T, DenseArray}, SubArray{T,1,A,I,L} where L where I<:Tuple{Vararg{Union{Int64, AbstractRange{Int64}, AbstractCartesianIndex},N} where N} where A<:Union{ReinterpretArray{T,N,S,A} where S where A<:Union{SubArray{T,N,A,I,true} where I<:Tuple{AbstractUnitRange,Vararg{Any,N} where N} where A<:DenseArray where N where T, DenseArray} where N where T, ReshapedArray{T,N,A,MI} where MI<:Tuple{Vararg{SignedMultiplicativeInverse{Int64},N} where N} where A<:Union{ReinterpretArray{T,N,S,A} where S where A<:Union{SubArray{T,N,A,I,true} where I<:Tuple{AbstractUnitRange,Vararg{Any,N} where N} where A<:DenseArray where N where T, DenseArray} where N where T, SubArray{T,N,A,I,true} where I<:Tuple{AbstractUnitRange,Vararg{Any,N} where N} where A<:DenseArray where N where T, DenseArray} where N where T, DenseArray}} where T) at /Users/hg344/.julia/packages/PDMats/mL7bX/src/pdmat.jl:67
invquad(::PDMats.AbstractPDMat{T<:Real}, ::Union{DenseArray{S<:Real,2}, ReinterpretArray{S<:Real,2,S,A} where S where A<:Union{SubArray{T,N,A,I,true} where I<:Tuple{AbstractUnitRange,Vararg{Any,N} where N} where A<:DenseArray where N where T, DenseArray}, ReshapedArray{S<:Real,2,A,MI} where MI<:Tuple{Vararg{SignedMultiplicativeInverse{Int64},N} where N} where A<:Union{ReinterpretArray{T,N,S,A} where S where A<:Union{SubArray{T,N,A,I,true} where I<:Tuple{AbstractUnitRange,Vararg{Any,N} where N} where A<:DenseArray where N where T, DenseArray} where N where T, SubArray{T,N,A,I,true} where I<:Tuple{AbstractUnitRange,Vararg{Any,N} where N} where A<:DenseArray where N where T, DenseArray}, SubArray{S<:Real,2,A,I,L} where L where I<:Tuple{Vararg{Union{Int64, AbstractRange{Int64}, AbstractCartesianIndex},N} where N} where A<:Union{ReinterpretArray{T,N,S,A} where S where A<:Union{SubArray{T,N,A,I,true} where I<:Tuple{AbstractUnitRange,Vararg{Any,N} where N} where A<:DenseArray where N where T, DenseArray} where N where T, ReshapedArray{T,N,A,MI} where MI<:Tuple{Vararg{SignedMultiplicativeInverse{Int64},N} where N} where A<:Union{ReinterpretArray{T,N,S,A} where S where A<:Union{SubArray{T,N,A,I,true} where I<:Tuple{AbstractUnitRange,Vararg{Any,N} where N} where A<:DenseArray where N where T, DenseArray} where N where T, SubArray{T,N,A,I,true} where I<:Tuple{AbstractUnitRange,Vararg{Any,N} where N} where A<:DenseArray where N where T, DenseArray} where N where T, DenseArray}}) where {T<:Real, S<:Real} at /Users/hg344/.julia/packages/PDMats/mL7bX/src/generics.jl:101
Stacktrace:
[1] sqmahal(::MvNormal{Float64,PDMats.PDMat{Float64,Array{Float64,2}},Array{Float64,1}}, ::TrackedArray{…,Array{Float64,1}}) at /Users/hg344/.julia/packages/Distributions/WHjOk/src/multivariate/mvnormal.jl:263
[2] _logpdf(::MvNormal{Float64,PDMats.PDMat{Float64,Array{Float64,2}},Array{Float64,1}}, ::TrackedArray{…,Array{Float64,1}}) at /Users/hg344/.julia/packages/Distributions/WHjOk/src/multivariate/mvnormal.jl:113
[3] logpdf(::MvNormal{Float64,PDMats.PDMat{Float64,Array{Float64,2}},Array{Float64,1}}, ::TrackedArray{…,Array{Float64,1}}) at /Users/hg344/.julia/packages/Distributions/WHjOk/src/multivariates.jl:175
[4] top-level scope at none:0
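The error is a dispatch failure: the PDMats invquad methods listed under "Closest candidates" only accept dense or strided array types, and a TrackedArray is an AbstractArray but not one of those. xp1, by contrast, is an ordinary Vector whose elements happen to be tracked scalars, so it matches the dense method (at the cost of scalar-by-scalar tracking). The sketch below reproduces the mechanism with a hypothetical WrappedVector type standing in for TrackedArray; it does not use Tracker or PDMats.

```julia
using LinearAlgebra

# Hypothetical stand-in for Tracker's TrackedArray: an AbstractVector that is not a DenseArray.
struct WrappedVector{T} <: AbstractVector{T}
    data::Vector{T}
end
Base.size(v::WrappedVector) = size(v.data)
Base.getindex(v::WrappedVector, i::Int) = v.data[i]

# Narrow method, analogous to invquad(::PDMat, ::DenseVector): only dense vectors dispatch here.
quad_dense(S::Matrix{Float64}, x::DenseVector) = dot(x, S \ x)

# Generic method: any AbstractVector dispatches here. It collects to a plain Vector first;
# with a real tracked array that would drop tracking, so a real fix needs an AD rule instead.
quad_generic(S::AbstractMatrix, x::AbstractVector) = (y = collect(x); dot(y, S \ y))

S = [2.0 0.0; 0.0 4.0]
v = [1.0, 2.0]
w = WrappedVector(v)

println(quad_dense(S, v))     # works: v is a Vector, hence a DenseVector
println(quad_generic(S, w))   # works: generic AbstractVector method
try
    quad_dense(S, w)          # fails like logpdf(d, xp2)
catch err
    println(typeof(err))      # prints MethodError
end
```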
Closed in favour of #615