So some take away points from the great community call today (thanks @femtomc and @cscherrer for organizing this):
tilde and dot_tilde functions. I don't know how it will turn out but I think I found the right entry point.
Omega can support inference techniques currently impossible in Turing but on a limited subset of models, i.e. the ones compatible with the Ω struct, I think. CC: @zenna
Omega's inference algorithms on the class of models that it supports. The main change required here will be to define the random variables to be instances of the random variable type in Omega. Then the sample function can simply call the cond function or something similar in Omega, iiuc. Again, more details need to be figured out here, but it should be doable as a summer or winter project.
It was mentioned today that one limitation of the current PPLs is that we cannot condition on a "function" of the model's "observations" instead of the "observations" themselves. I am using "observations" here to refer to the thing on the LHS of ~. I think this might be possible today for some functions, namely bijectors. I think we can define a transformed distribution using Bijectors.jl and use that to observe for example 2y instead of y or log(y), etc. A tutorial here may be all we need. CC: @torfjelde
ProbabilityModels can be implemented at the distribution level, making them available to other PPLs. CC: @chriselrod. I think it may help to have a macro that can let us more easily define a "complex" distribution that returns a named tuple, together with its custom adjoint. It would be interesting if we can automatically compose and inline rules in ChainRules explicitly at this level to make a bigger chain rule for our "complex distribution". The goal of this is to minimize the time spent in Zygote's type unstable parts by going through a single primitive instead of multiple AD primitives, kind of like a function barrier but for adjoints. CC: @oxinabox.
> I think it may help to have a macro that can let us more easily define a "complex" distribution that returns a named tuple, together with its custom adjoint.
How about Soss.@model? ;)
Ad (5): this should be easy (at least to write) in my proposed IR, or in any system that separates variable names from sampling, so that you can have a normal assignment to a named variable:
{x} ~ Normal()
{log_x} = log({x})
It would still have to be supported by the PPL evaluating the model, of course. This is also the kind of thing that could allow constraining intermediate transformations, something I believe @zenna mentioned at some point.
Ad (6): that could even be a combinator in Measures.jl, couldn't it?
> How about Soss.@model? ;)
Yes! But it would be nice to add the custom adjoint part to that as well and have Chris Elrod implement some heavily optimized logpdf methods for those distributions or the distribution generators.
> Ad (6): that could even be a combinator in Measures.jl, couldn't it?
Yes, either Measures or Soss would be a good place for this.
For multivariate normal, there's a lot of value in representing the covariance
as a tensor product. Something like Kronecker.jl, but without the flattening.
Then
x ~ MvNormal(Σ) |> iid(n)
would reduce to
x ~ NewGaussian(I(n) ⊗ Σ)
and the covariance between the (i,a)th and (j,b)th elements of x
is
0 if i≠j
Σ[a,b] if i=j
and if Σ is k×k, rand(NewGaussian(I(n) ⊗ Σ)) would generate an n×k matrix.
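The block structure described above can be checked numerically. This is a minimal sketch using only the standard library's `kron`; the sizes, the values in `Σ`, and the helper `flat_index` are made up for illustration, and neither `NewGaussian` nor `iid` is implemented here.

```julia
using LinearAlgebra

n, k = 3, 2                       # n iid draws, each k-dimensional (illustrative sizes)
Σ = [2.0 0.5; 0.5 1.0]            # k×k covariance of a single draw

# Dense version of I(n) ⊗ Σ: block diagonal with n copies of Σ.
C = kron(Matrix{Float64}(I, n, n), Σ)

# Position of element a of draw i when the draws are stacked one after another.
# (Hypothetical helper, only used for this check.)
flat_index(i, a) = (i - 1) * k + a

# cov(x[i][a], x[j][b]) is Σ[a, b] when i == j and 0 otherwise.
for i in 1:n, j in 1:n, a in 1:k, b in 1:k
    expected = i == j ? Σ[a, b] : 0.0
    @assert C[flat_index(i, a), flat_index(j, b)] == expected
end
```

A lazy tensor-product type would presumably avoid materializing `C` at all; the dense matrix is only used here to make the indexing rule concrete.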
> Yes, either Measures or Soss would be a good place for this.
I was sort of half-joking originally, but defining a distribution (or a measure, more generally) over named tuples is sort of Soss's whole deal. Also, @tpapp's TransformVariables is really good at this, if your starting point is a distribution over ℝⁿ.
One reason I'm already a huge fan of Measures.jl is the nice "multiple parametrizations" idea. So actually, there could be
x ~ NewGaussian(diag = Σ)
falling back to
NewGaussian(cov = I(length(Σ)) ⊗ Σ)
or similarly.
Like in Stan, I'd prefer to point people towards using Normal(::AbstractVector, ::LinearAlgebra.AbstractTriangular) when reasonable to skip the Cholesky factorization.
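To illustrate what skipping the factorization buys, here is a small sketch that evaluates a multivariate normal log density directly from a lower-triangular factor and compares it with a version that starts from the full covariance. The function names are hypothetical and this is not the Distributions.jl or ProbabilityModels API; it assumes the covariance is `L * L'`.

```julia
using LinearAlgebra

# Log density of N(m, L*L') at y, given the lower-triangular factor L directly.
# Only a triangular solve is needed, no factorization.
function mvnormal_logpdf_chol(y::AbstractVector, m::AbstractVector, L::LowerTriangular)
    z = L \ (y - m)
    P = length(y)
    return -0.5 * dot(z, z) - sum(log, diag(L)) - 0.5 * P * log(2π)
end

# Reference version that works from the covariance matrix itself.
function mvnormal_logpdf_cov(y::AbstractVector, m::AbstractVector, Σ::AbstractMatrix)
    r = y - m
    return -0.5 * dot(r, Σ \ r) - 0.5 * logdet(Σ) - 0.5 * length(y) * log(2π)
end

L = LowerTriangular([1.0 0.0; 0.5 2.0])
m = [0.0, 1.0]
y = [0.3, -0.2]

lp_chol = mvnormal_logpdf_chol(y, m, L)
lp_cov = mvnormal_logpdf_cov(y, m, Matrix(L * L'))
@assert isapprox(lp_chol, lp_cov; atol = 1e-10)
```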
This is an unoptimized but simple and dependency-free (aside from standard library) implementation for calculating the density and gradients:
using LinearAlgebra
# y is a P x N matrix containing N total P-dimensional observations
# m is a vector of length P containing the means
# l is the lower triangular Cholesky factor of the covariance matrix
# dy, dm, and dl are the corresponding gradients
# b is preallocated memory with the same size and shape as y
# the returned log density omits the additive constant -N*P/2 * log(2π)
function ldnorm!(dy, dm, dl, b, y, m, l)
    b .= m .- y
    LAPACK.trtrs!('L', 'N', 'N', l, b)    # b = l \ (m .- y)
    lp = dot(b, b); dy .= b
    LAPACK.trtrs!('L', 'T', 'N', l, dy)   # dy = l' \ b
    #@avx should be a little faster, and let you move the 0-assignment into the loop nest
    dm .= 0
    @inbounds for n ∈ 1:size(y,2)
        @simd for p ∈ 1:size(y,1)
            dm[p] -= dy[p,n]
        end
    end
    BLAS.gemm!('N', 'T', 1.0, dy, b, 0.0, dl)   # dl = dy * b'
    N = size(y,2); lp *= -0.5
    @inbounds for p ∈ 1:size(dl,1) #@avx instead would be much faster
        # derivative of -N * log(l[p,p]) with respect to l[p,p] is -N / l[p,p]
        dl[p,p] -= N / l[p,p]
        lp -= N * log(l[p,p])
    end
    lp
end
The optimized code I wrote that does only a single pass over the memory is much messier and currently broken, although it handles a variety of situations, such as where the mean is defined as X * beta (fusing it into the rest of the calculations).
Unfortunately, LAPACK doesn't include a version of trtrs! for solving NxP with a PxP matrix, even though that situation is much easier to SIMD and therefore faster.
Hence why the above uses PxN instead.
I could work on rewriting and cleaning it up, but I think long term I'd rather get LoopVectorization to be able to handle the kind of optimizations and loop dependency structures needed.
One of the major goals of that in the context of a PPL would be to let it convert high level expressions like logpdfs and linear algebra into loop nests, and then let loop optimization code optimize the combined expressions.
> One reason I'm already a huge fan of Measures.jl is the nice "multiple parametrizations" idea. So actually, there could be
> x ~ NewGaussian(diag = Σ)
> falling back to
> NewGaussian(cov = I(length(Σ)) ⊗ Σ)
> or similarly.
Isn't the same behaviour achieved by MvNormal without keyword arguments by dispatching on the type of the covariance matrix, such as AbstractVector, UniformScaling, Diagonal, etc.? IMO an annoyance with MvNormal is just that in the end the type of the covariance matrix has to be an AbstractPDMat (although optimizations for diagonal matrices etc. exist there as well). For the general case, it seems there is a PR to Distributions that adds keyword arguments to constructors: https://github.com/JuliaStats/Distributions.jl/pull/823
Oh wait, sorry I had misread @phipsgabler 's comment. I think it's a little weird how Distributions.jl works in this case. It has MvNormal as a distribution over vectors, but then taking multiple samples magically puts it into a matrix. It just makes the whole thing awkward to reason about. I'd rather have the size as part of the distribution explicitly.
> IMO an annoyance with MvNormal is just that in the end the type of the covariance matrix has to be an AbstractPDMat (although optimizations for diagonal matrices etc. exist there as well).
Yep, I agree. Also, it's often handy to allow positive semidefinite covariance. Tim Holy has an approach to this in PositiveFactorizations.jl that seems promising.
> For the general case, it seems there is a PR to Distributions that adds keyword arguments to constructors: JuliaStats/Distributions.jl#823
This looks nice, but it doesn't really seem to allow for different parameterizations. Instead, it uses keyword arguments to coerce to the standard parameterization.
A big advantage of reparameterization is to make computation cheaper, but this doesn't seem to help in that regard.
> It has MvNormal as a distribution over vectors, but then taking multiple samples magically puts it into a matrix.
I agree, there is some inconsistency in how rand works here. However, the API of Distributions also accepts AbstractArrays of AbstractVectors as samples of multivariate distributions, e.g., for evaluating the loglikelihood or in-place sampling with rand!.
Thanks @mohamed82008!
> It was mentioned today that one limitation of the current PPLs is that we cannot condition on a "function" of the model's "observations" instead of the "observations" themselves. I am using "observations" here to refer to the thing on the LHS of ~. I think this might be possible today for some functions, namely bijectors. I think we can define a transformed distribution using Bijectors.jl and use that to observe for example 2y instead of y or log(y), etc. A tutorial here may be all we need.
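For the bijective case mentioned in the quote, observing 2y instead of y comes down to a change of variables: log p_Z(z) = log p_Y(f⁻¹(z)) + log|d f⁻¹(z)/dz|. Below is a dependency-free sketch of that identity; `normal_logpdf` and `transformed_logpdf` are illustrative names, not the Bijectors.jl API.

```julia
# Univariate normal log density, written out to keep the sketch dependency-free.
normal_logpdf(y, μ, σ) = -0.5 * ((y - μ) / σ)^2 - log(σ) - 0.5 * log(2π)

# Log density of z = f(y) for a bijective f, given f⁻¹ and log|d f⁻¹(z)/dz|.
transformed_logpdf(logp_y, finv, logabsjac_finv, z) = logp_y(finv(z)) + logabsjac_finv(z)

# Observe z = 2y where y ~ Normal(0, 1): f⁻¹(z) = z / 2 and |d f⁻¹/dz| = 1/2.
logp_y(y) = normal_logpdf(y, 0.0, 1.0)
z_obs = 1.3
lp_z = transformed_logpdf(logp_y, z -> z / 2, z -> -log(2), z_obs)

# 2y is Normal(0, 2), so the two log densities should agree.
@assert isapprox(lp_z, normal_logpdf(z_obs, 0.0, 2.0); atol = 1e-12)
```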
In Turing we decide what's "observed" based on whether or not it's present in args for the model and is not missing, right? More specifically https://github.com/TuringLang/DynamicPPL.jl/blob/405546f5f034a9c78e3687e05f3713b998cdbf0c/src/compiler.jl#L6..L19
So I don't think the issue is related to whether or not the transformation is bijective, but rather that we can't handle something like
@model function demo(x)
    y = f(x)
    y ~ Likelihood()
end
because we don't know that LHS is a function of the inputs.
Also, it's worth pointing out that it's still possible to condition on observations by using @addlogprob!, i.e.
@model function demo(x)
    y = f(x)
    @addlogprob! logpdf(Likelihood(), y)
end
Alternatively we can make it look nicer by doing something like
@model function demo(x)
    y = f(x)
    @observe y ~ Likelihood()
end
which simply expands to the above.
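As a rough illustration of such an expansion, here is a toy macro that rewrites `lhs ~ dist` into an explicit log-density accumulation. Everything in it is a stand-in: `ToyNormal`, `toy_logpdf`, `@toy_observe`, and the `logp` accumulator are hypothetical and are not part of Turing or DynamicPPL, where `@addlogprob!` would do the accumulation.

```julia
# Hypothetical stand-ins for a distribution and its log density (not Turing's API).
struct ToyNormal
    μ::Float64
    σ::Float64
end
toy_logpdf(d::ToyNormal, y) = -0.5 * ((y - d.μ) / d.σ)^2 - log(d.σ) - 0.5 * log(2π)

# Rewrites `lhs ~ dist` into `logp[] = logp[] + toy_logpdf(dist, lhs)`.
# Assumes the caller has a `Ref` named `logp` in scope.
macro toy_observe(ex)
    if !(ex isa Expr && ex.head === :call && length(ex.args) == 3 && ex.args[1] === :~)
        error("expected an expression of the form `lhs ~ dist`")
    end
    lhs, dist = esc(ex.args[2]), esc(ex.args[3])
    logp = esc(:logp)
    return :($(logp)[] = $(logp)[] + toy_logpdf($dist, $lhs))
end

function demo_logp(x)
    logp = Ref(0.0)
    y = 2x                              # y is a function of the input, not an argument
    @toy_observe y ~ ToyNormal(0.0, 1.0)
    return logp[]
end
```

Calling `demo_logp(1.0)` accumulates the log density of `ToyNormal(0.0, 1.0)` at `y = 2.0`, which is the behaviour the `@observe` proposal is after.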
> I think it may help to have a macro that can let us more easily define a "complex" distribution that returns a named tuple, together with its custom adjoint.
I really like going in this direction. Distributions for named tuples is something that keeps coming up as something that would be super-useful to have (also added support for transforming something like that in Bijectors.jl recently).
> Also, @tpapp's TransformVariables is really good at this, if your starting point is a distribution over ℝⁿ.
Btw, we've also added support for transforming NamedTuple in Bijectors.jl too now, so I'd be keen to hear what you TransformVariables.jl people think of the approach. Would be nice if we could converge on a joint effort :)
Also, I completely missed the meeting! Sorry! Been in the process of moving and getting sorted for starting my studies, so completely forgot about it.
> Btw, we've also added support for transforming NamedTuple in Bijectors.jl too now, so I'd be keen to hear what you TransformVariables.jl people think of the approach. Would be nice if we could converge on a joint effort :)
THIS IS SO GREAT!!!!
Link: https://github.com/TuringLang/Bijectors.jl/pull/95
Some background for others: It's often important to have the support of a variable depend stochastically on the value of another. It's not just a niche thing either, it comes up any time you want variables to be ordered in some way. And _this_ is a big deal because exchangeability in posteriors is just a mess, causing most samplers to mode-switch all over the place, which in turn screws up most diagnostics.
So yeah, :100:
> It would be interesting if we can automatically compose and inline rules in ChainRules explicitly at this level to make a bigger chain rule for our "complex distribution". The goal of this is to minimize the time spent in Zygote's type unstable parts by going through a single primitive instead of multiple AD primitives, kind of like a function barrier but for adjoints. CC: @oxinabox.
Yes, this should be possible.
I have long intended to apply basically the same idea to Flux.Chain when it is used with only simple layers that we have rules for.
Right now because we don't have https://github.com/JuliaDiff/ChainRulesCore.jl/issues/68
you need to know you have rules for every part; and you need to know you have rules from the start, can't back out later
So what I am proposing is practically a symbolic AD layer on top of ChainRules to write a chain rule for a function that we really really care about and that we can define in global scope (annotated with a macro) once for our users to use over and over again. For example, the logpdf of an MvNormal which is a well-defined mathematical formula but I don't want to sit down and manually derive the reverse and forward chain rules for it. And I don't want Zygote to scratch its head trying to connect all the chain rules at runtime. I want the symbolic AD to do it for me at parse-time generating a single chain rule for the whole function that Zygote can then use as a primitive.
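To make the "single primitive" idea concrete, here is a hand-written example of what such a fused rule could look like for an iid normal log-likelihood: one function that returns the value together with one pullback for all inputs. It loosely mirrors the (value, pullback) shape of a ChainRules `rrule` but deliberately does not depend on ChainRules or any symbolic tool; the function name is hypothetical and the derivatives were derived by hand for this sketch.

```julia
# Fused "primal + pullback" for
#   f(y, μ, σ) = Σᵢ [ -½((yᵢ - μ)/σ)² - log σ - ½ log 2π ].
function normal_loglik_with_pullback(y::AbstractVector, μ::Real, σ::Real)
    z = (y .- μ) ./ σ
    n = length(y)
    lp = -0.5 * sum(abs2, z) - n * log(σ) - 0.5 * n * log(2π)
    # One closure computes every input cotangent from the output cotangent Δ.
    function pullback(Δ)
        dy = (-Δ / σ) .* z
        dμ = Δ * sum(z) / σ
        dσ = Δ * (sum(abs2, z) - n) / σ
        return dy, dμ, dσ
    end
    return lp, pullback
end

y0, μ0, σ0 = [0.1, -0.4, 1.2], 0.3, 1.5
lp0, back = normal_loglik_with_pullback(y0, μ0, σ0)
dy0, dμ0, dσ0 = back(1.0)

# Central finite differences as a sanity check on the hand-derived gradients.
h = 1e-6
fd_μ = (normal_loglik_with_pullback(y0, μ0 + h, σ0)[1] -
        normal_loglik_with_pullback(y0, μ0 - h, σ0)[1]) / (2h)
fd_σ = (normal_loglik_with_pullback(y0, μ0, σ0 + h)[1] -
        normal_loglik_with_pullback(y0, μ0, σ0 - h)[1]) / (2h)
@assert isapprox(dμ0, fd_μ; atol = 1e-6)
@assert isapprox(dσ0, fd_σ; atol = 1e-6)
```

The proposal in the thread is to have a tool generate this kind of fused rule automatically, so that the AD system sees one primitive instead of many small ones.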
Since we are talking AD and probprog, CC: @ChrisRackauckas.
> you need to know you have rules for every part; and you need to know you have rules from the start, can't back out later
I think even with this limitation, the proposal is still useful.
> So what I am proposing is practically a symbolic AD layer on top of ChainRules to write a chain rule for a function that we really really care about and that we can define in global scope (annotated with a macro) once for our users to use over and over again. For example, the logpdf of an MvNormal which is a well-defined mathematical formula but I don't want to sit down and manually derive the reverse and forward chain rules for it. And I don't want Zygote to scratch its head trying to connect all the chain rules at runtime. I want the symbolic AD to do it for me at parse-time generating a single chain rule for the whole function that Zygote can then use as a primitive.
For symbolic AD here, you're just asking for ModelingToolkit?
> you're just asking for ModelingToolkit?
If MTK can connect ChainRules like lego to make a bigger chain rule, then yes. I really should play more with MTK.
> like lego
Or well...like a chain!
It's built on DiffRules right now, but with @oxinabox 's changes to ChainRules it could probably use ChainRules now.
Awesome! If this happens, I think it can solve like 90% of the AD needs if one is willing to work in global scope.
So for a bigger picture, I imagine a world where more functions are annotated with @buildchain or something from MTK to define a complex chain rule symbolically for this function. Once ChainRules can call back to AD, this complex chain rule can make calls to Zygote if no chain rule exists for one of the sub-functions in the main function.
Then the entry point for AD can be Zygote (or ChainRules really at this point). Zygote calls ChainRules making use of MTK-generated ChainRules where possible, which themselves can call back into Zygote if needed. A beautiful harmony of symbolic AD (MTK), runtime-based AD (Zygote) and manual differentiation (ChainRules) all working together to avoid Zygote's type instability!
That's exactly how we're using it in DiffEq, defining Jacobians, vjps, etc. to augment the AD world. The nice thing is that symbolic derivatives are more efficient than AD derivatives in the sparse case: coloring gets close, but is never as efficient as directly defining entries of the matrix. So we mainly use it to build big sparse things, but yes it handles the case the other source-to-source ADs like Zygote are bad at (scalarized stuff).
> The nice thing is that symbolic derivatives are more efficient than AD derivatives in the sparse case:
But it also needs to be taught when to give up. For example, if a function is too big. I don't know if MTK unrolls loops or tries to work with it as a single node. If it's the former, then it will be pretty taxing on the compiler and it can take a very long time to parse for long loops.
It doesn't give up. We leave it to user choice. We have cases where 500 second compile time is a reduction in compute time by 80 hours though... so it's really a choice and I think any heuristic is wrong. Instead, a heuristic on that should exist (and be overridable) in Turing, or some high level "glue AD".
Sure but at least it needs to be taught how to give up and we can tell it when :sweat_smile:
For example, setting a time limit.
Use a task and do the computation on a task and kill the task if it goes over a time limit. That's what we plan to do in AutoOptimize.jl.
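A minimal sketch of the time-limit idea using only Base Julia. Note one deliberate substitution: instead of forcibly killing the task, this version uses a cooperative stop flag that the work function is assumed to check, since interrupting an arbitrary running task is not generally safe. The names `run_with_timeout`, `slow_work`, and `fast_work` are hypothetical, and this is not how AutoOptimize.jl is implemented.

```julia
# Runs `work(stop)` on a task and gives up after `seconds`.
# `work` is assumed to poll `stop[]` regularly and return early once it is true.
function run_with_timeout(work, seconds::Float64)
    stop = Threads.Atomic{Bool}(false)
    task = @async work(stop)
    status = timedwait(() -> istaskdone(task), seconds; pollint = 0.01)
    if status === :ok
        return (:done, fetch(task))
    end
    stop[] = true          # ask the task to stop, then wait for it to wind down
    wait(task)
    return (:timed_out, nothing)
end

# A long computation that checks the flag between steps.
function slow_work(stop)
    acc = 0
    for i in 1:10_000
        stop[] && return acc
        sleep(0.01)
        acc += i
    end
    return acc
end

fast_work(stop) = 42

result_fast = run_with_timeout(fast_work, 1.0)
result_slow = run_with_timeout(slow_work, 0.1)
```

A caller such as a PPL could then fall back to another differentiation strategy whenever the first element of the result is `:timed_out`.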
Nice!
I haven't had the time to get it into a working state, but this much at least does work
julia> using ProbabilityModels, LinearAlgebra, DistributionParameters
julia> sg = quote
           σ ~ Gamma(1.0, 0.05)
           L ~ LKJ(2.0)
           β ~ Normal(10.0) # μ = 0
           μ ~ Normal(100.0)
           σL = Diagonal(σ) * L
           Y₁ ~ Normal( X₁, β, μ', σL )
           Y₂ ~ Normal( X₂, β[1:7,:], μ', σL )
       end;
A model definition API would be something like @model ModelName begin followed by the contents of the above quote. It would define a struct named ModelName with a field that would hold a named tuple. The named tuple holding a value would make it known data/prior (unless it is Missing), otherwise it'd be a parameter.
The macro would also define methods for functions that dispatch on the struct to calculate the logdensity and gradient.
The macro would use read_model to read the model:
julia> m = ProbabilityModels.read_model(sg, Main);
Generating data to create an example named tuple:
julia> N, K₁, K₂, P = 100, 10, 7, 12;
julia> μ = rand(P);
julia> σU = rand(P, (3P)>>1) |> x -> cholesky(Hermitian(x * x')).U;
julia> X₁ = rand(N, K₁);
julia> X₂ = rand(N, K₂);
julia> β = randn(K₁, P);
julia> Y₁ = mul!(rand(N, P) * σU, X₁, β, 1.0, 1.0);
julia> Y₂ = mul!(rand(N, P) * σU, X₂, view(β, 1:K₂, :), 1.0, 1.0);
julia> datant = (
           Y₁ = Y₁, Y₂ = Y₂, X₁ = X₁, X₂ = X₂,
           μ = RealVector{12}(), β = RealMatrix{10,12}(),
           σ = RealVector{12,0.0,Inf}(), L = CorrelationMatrixCholesyFactor{12}()
       );
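The data-versus-parameter rule described earlier (a field holding a value is data or a fixed prior input, while a descriptor or `missing` marks a parameter) can be sketched with plain dispatch. `ParamSpec`, `is_parameter`, `split_names`, and `ToyModel` are hypothetical names chosen for this illustration; `ParamSpec` merely stands in for descriptors like `RealVector{12}()` and none of this is the ProbabilityModels implementation.

```julia
# Hypothetical stand-in for descriptors such as RealVector{12}() in the example data.
struct ParamSpec{N} end

is_parameter(::Any) = false          # a concrete value: data or a fixed prior input
is_parameter(::Missing) = true       # missing values are treated as unknowns
is_parameter(::ParamSpec) = true     # a type-level descriptor marks a parameter

# Splits the field names of a named tuple into parameters and data.
function split_names(nt::NamedTuple)
    params = Tuple(k for k in keys(nt) if is_parameter(nt[k]))
    data = Tuple(k for k in keys(nt) if !is_parameter(nt[k]))
    return (params = params, data = data)
end

# A model struct with a single named-tuple field, as in the description.
struct ToyModel{NT<:NamedTuple}
    vars::NT
end

toy = ToyModel((Y = [1.0, 2.0], X = [0.5, 0.7], μ = ParamSpec{2}(), σ = missing))
roles = split_names(toy.vars)
```

Because the classification depends only on the field types, the same decision could in principle be made at compile time inside a generated function.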
A model description would be inserted into the generated function methods for logdensity and gradients.
These can then use the type information from the named tuple to generate code, e.g. for the logdensity:
julia> ProbabilityModels.preprocess!(m, typeof(datant));
julia> ProbabilityModels.ReverseDiffExpressions.lower(m)
quote
(var"###STACK##POINTER###", var"##targetconstrained#263") = ReverseDiffExpressions.stack_pointer_call(DistributionParameters.constrain, var"###STACK##POINTER###", var"#θ#", 120, RealArray{Tuple{12}, -Inf, Inf, 0}())
var"##TARGET###0#" = ReverseDiffExpressionsBase.first(var"##targetconstrained#263")
(var"###STACK##POINTER###", var"##targetconstrained#265") = ReverseDiffExpressions.stack_pointer_call(DistributionParameters.constrain, var"###STACK##POINTER###", var"#θ#", 132, RealArray{Tuple{12}, 0.0, Inf, 0}())
σ = ReverseDiffExpressionsBase.second(var"##targetconstrained#265")
(var"###STACK##POINTER###", var"####TARGET###1##267") = ReverseDiffExpressions.stack_pointer_call(logdensity, var"###STACK##POINTER###", Gamma{(true, false, false, false)}(), σ, 1.0, 0.05, var"##ONE##")
var"##TARGET###1#" = ReverseDiffExpressions.vadd!(var"##TARGET###0#", var"####TARGET###1##267")
(var"###STACK##POINTER###", var"##targetconstrained#266") = ReverseDiffExpressions.stack_pointer_call(DistributionParameters.constrain, var"###STACK##POINTER###", var"#θ#", 144, CorrelationMatrixCholesyFactor{12}())
L = ReverseDiffExpressionsBase.second(var"##targetconstrained#266")
(var"###STACK##POINTER###", var"####TARGET###2##268") = ReverseDiffExpressions.stack_pointer_call(logdensity, var"###STACK##POINTER###", LKJ{(true, false, false)}(), L, 2.0, var"##ONE##")
var"##TARGET###2#" = ReverseDiffExpressions.vadd!(var"##TARGET###1#", var"####TARGET###2##268")
(var"###STACK##POINTER###", var"##targetconstrained#264") = ReverseDiffExpressions.stack_pointer_call(DistributionParameters.constrain, var"###STACK##POINTER###", var"#θ#", 0, RealArray{Tuple{10,12}, -Inf, Inf, 0}())
β = ReverseDiffExpressionsBase.second(var"##targetconstrained#264")
(var"###STACK##POINTER###", var"####TARGET###3##269") = ReverseDiffExpressions.stack_pointer_call(logdensity, var"###STACK##POINTER###", Normal{(true, false, false)}(), β, 10.0, var"##ONE##")
var"##TARGET###3#" = ReverseDiffExpressions.vadd!(var"##TARGET###2#", var"####TARGET###3##269")
μ = ReverseDiffExpressionsBase.second(var"##targetconstrained#263")
(var"###STACK##POINTER###", var"####TARGET###4##270") = ReverseDiffExpressions.stack_pointer_call(logdensity, var"###STACK##POINTER###", Normal{(true, false, false)}(), μ, 100.0, var"##ONE##")
var"##TARGET###4#" = ReverseDiffExpressions.vadd!(var"##TARGET###3#", var"####TARGET###4##270")
Y₁ = (var"#DATA#").Y₁
X₁ = (var"#DATA#").X₁
(var"###STACK##POINTER###", var"##LHS#259#0#") = ReverseDiffExpressions.stack_pointer_call(LoopVectorization.adjoint, var"###STACK##POINTER###", μ)
(var"###STACK##POINTER###", var"##LHS#258") = ReverseDiffExpressions.stack_pointer_call(Diagonal, var"###STACK##POINTER###", σ)
(var"###STACK##POINTER###", σL) = ReverseDiffExpressions.stack_pointer_call(LoopVectorization.:*, var"###STACK##POINTER###", var"##LHS#258", L)
(var"###STACK##POINTER###", var"####TARGET###5##271") = ReverseDiffExpressions.stack_pointer_call(logdensity, var"###STACK##POINTER###", Normal{(false, false, true, true, true, false)}(), Y₁, X₁, β, var"##LHS#259", σL, var"##ONE##")
var"##TARGET###5#" = ReverseDiffExpressions.vadd!(var"##TARGET###4#", var"####TARGET###5##271")
Y₂ = (var"#DATA#").Y₂
X₂ = (var"#DATA#").X₂
var"##LHS#260" = Base.view(β, StaticUnitRange{1, 7}(), :)
(var"###STACK##POINTER###", var"####TARGET###6##272") = ReverseDiffExpressions.stack_pointer_call(logdensity, var"###STACK##POINTER###", Normal{(false, false, true, true, true, false)}(), Y₂, X₂, var"##LHS#260", var"##LHS#259", σL, var"##ONE##")
var"##TARGET###6#" = ReverseDiffExpressions.vadd!(var"##TARGET###5#", var"####TARGET###6##272")
var"####TARGET###7##273" = ReverseDiffExpressionsBase.first(var"##targetconstrained#264")
var"##TARGET###7#" = ReverseDiffExpressions.vadd!(var"##TARGET###6#", var"####TARGET###7##273")
var"####TARGET###8##274" = ReverseDiffExpressionsBase.first(var"##targetconstrained#265")
var"##TARGET###8#" = ReverseDiffExpressions.vadd!(var"##TARGET###7#", var"####TARGET###8##274")
var"####TARGET###275" = ReverseDiffExpressionsBase.first(var"##targetconstrained#266")
var"##TARGET##" = ReverseDiffExpressions.vadd!(var"##TARGET###8#", var"####TARGET###275")
vsum(var"##TARGET##")
end
Verbose, but a few notes:
StackPointers.StackPointer so that functions can opt in to be non-allocating (using the stack pointer for working memory, and incrementing it if they need to allocate). Requires escaping to be undefined behavior/not allowed.
constrain, to constrain the input vector, are sorted to try and maximize the number of aligned loads and stores.
Similarly, it can differentiate the model method:
julia> dm = ProbabilityModels.ReverseDiffExpressions.differentiate(m);
julia> ProbabilityModels.ReverseDiffExpressions.lower(dm)
quote
(var"###STACK##POINTER###", var"##constrainpullbacktup#302") = ReverseDiffExpressions.stack_pointer_call(constrain_pullback!, var"###STACK##POINTER###", var"#∇#", var"#θ#", 132, RealArray{Tuple{12}, 0.0, Inf, 0}())
var"##targetconstrained#265##BAR##" = ReverseDiffExpressionsBase.second(var"##constrainpullbacktup#302")
var"σ##BAR###0#" = ReverseDiffExpressionsBase.second(var"##targetconstrained#265##BAR##")
var"##targetconstrained#265" = ReverseDiffExpressionsBase.first(var"##constrainpullbacktup#302")
σ = ReverseDiffExpressionsBase.second(var"##targetconstrained#265")
(var"###STACK##POINTER###", var"##rrule_LHS#285") = ReverseDiffExpressions.stack_pointer_call(ChainRules.rrule, var"###STACK##POINTER###", Diagonal, σ)
var"##LHS#258" = Base.first(var"##rrule_LHS#285")
(var"###STACK##POINTER###", var"##temp#289") = ReverseDiffExpressions.stack_pointer_call(LoopVectorization.adjoint, var"###STACK##POINTER###", var"##LHS#258")
Y₁ = (var"#DATA#").Y₁
X₁ = (var"#DATA#").X₁
(var"###STACK##POINTER###", var"##constrainpullbacktup#300") = ReverseDiffExpressions.stack_pointer_call(constrain_pullback!, var"###STACK##POINTER###", var"#∇#", var"#θ#", 0, RealArray{Tuple{10,12}, -Inf, Inf, 0}())
var"##targetconstrained#264##BAR##" = ReverseDiffExpressionsBase.second(var"##constrainpullbacktup#300")
var"β##BAR###0#" = ReverseDiffExpressionsBase.second(var"##targetconstrained#264##BAR##")
var"##targetconstrained#264" = ReverseDiffExpressionsBase.first(var"##constrainpullbacktup#300")
β = ReverseDiffExpressionsBase.second(var"##targetconstrained#264")
(var"###STACK##POINTER###", var"##tup#280") = ReverseDiffExpressions.stack_pointer_call(∂logdensity!, var"###STACK##POINTER###", (var"β##BAR###0#", nothing, nothing), Normal{(true, false, false)}(), β, 10.0, var"##ONE##")
var"##∂tup#281" = ReverseDiffExpressionsBase.second(var"##tup#280")
var"##β##BAR###1##306" = ReverseDiffExpressionsBase.first(var"##∂tup#281")
var"β##BAR###1#" = ReverseDiffExpressions.vadd!(var"β##BAR###0#", var"##β##BAR###1##306")
(var"###STACK##POINTER###", var"##constrainpullbacktup#298") = ReverseDiffExpressions.stack_pointer_call(constrain_pullback!, var"###STACK##POINTER###", var"#∇#", var"#θ#", 120, RealArray{Tuple{12}, -Inf, Inf, 0}())
var"##targetconstrained#263##BAR##" = ReverseDiffExpressionsBase.second(var"##constrainpullbacktup#298")
var"μ##BAR###0#" = ReverseDiffExpressionsBase.second(var"##targetconstrained#263##BAR##")
var"##targetconstrained#263" = ReverseDiffExpressionsBase.first(var"##constrainpullbacktup#298")
μ = ReverseDiffExpressionsBase.second(var"##targetconstrained#263")
(var"###STACK##POINTER###", var"##rrule_LHS#291") = ReverseDiffExpressions.stack_pointer_call(ChainRules.rrule, var"###STACK##POINTER###", LoopVectorization.adjoint, μ)
var"##LHS#259" = Base.first(var"##rrule_LHS#291")
(var"###STACK##POINTER###", var"##constrainpullbacktup#304") = ReverseDiffExpressions.stack_pointer_call(constrain_pullback!, var"###STACK##POINTER###", var"#∇#", var"#θ#", 144, CorrelationMatrixCholesyFactor{12}())
var"##targetconstrained#266##BAR##" = ReverseDiffExpressionsBase.second(var"##constrainpullbacktup#304")
var"L##BAR###0#" = ReverseDiffExpressionsBase.second(var"##targetconstrained#266##BAR##")
var"##targetconstrained#266" = ReverseDiffExpressionsBase.first(var"##constrainpullbacktup#304")
L = ReverseDiffExpressionsBase.second(var"##targetconstrained#266")
(var"###STACK##POINTER###", σL) = ReverseDiffExpressions.stack_pointer_call(LoopVectorization.:*, var"###STACK##POINTER###", var"##LHS#258", L)
(var"###STACK##POINTER###", var"##tup#294") = ReverseDiffExpressions.stack_pointer_call(∂logdensity!, var"###STACK##POINTER###", (nothing, nothing, var"β##BAR###1#", nothing, nothing, nothing), Normal{(false, false, true, true, true, false)}(), Y₁, X₁, β, var"##LHS#259", σL, var"##ONE##")
var"##∂tup#295" = ReverseDiffExpressionsBase.second(var"##tup#294")
var"##β##BAR###307" = ReverseDiffExpressionsBase.third(var"##∂tup#295")
var"β##BAR##" = ReverseDiffExpressions.vadd!(var"β##BAR###1#", var"##β##BAR###307")
var"##LHS#260##BAR###0#" = Base.view(var"β##BAR##", StaticUnitRange{1, 7}(), :)
Y₂ = (var"#DATA#").Y₂
X₂ = (var"#DATA#").X₂
var"##LHS#260" = Base.view(β, StaticUnitRange{1, 7}(), :)
var"##LHS#259##BAR###0#" = ReverseDiffExpressionsBase.fourth(var"##∂tup#295")
var"σL##BAR###0#" = ReverseDiffExpressionsBase.fifth(var"##∂tup#295")
(var"###STACK##POINTER###", var"##tup#296") = ReverseDiffExpressions.stack_pointer_call(∂logdensity!, var"###STACK##POINTER###", (nothing, nothing, var"##LHS#260##BAR###0#", var"##LHS#259##BAR###0#", var"σL##BAR###0#", nothing), Normal{(false, false, true, true, true, false)}(), Y₂, X₂, var"##LHS#260", var"##LHS#259", σL, var"##ONE##")
var"##∂tup#297" = ReverseDiffExpressionsBase.second(var"##tup#296")
var"####LHS#260##BAR###308" = ReverseDiffExpressionsBase.third(var"##∂tup#297")
var"##LHS#260##BAR##" = ReverseDiffExpressions.vadd!(var"##LHS#260##BAR###0#", var"####LHS#260##BAR###308")
(var"###STACK##POINTER###", var"##tup#294#1#") = ReverseDiffExpressions.stack_pointer_call(∂logdensity!, var"###STACK##POINTER###", (nothing, nothing, var"β##BAR##", var"##LHS#259##BAR###0#", var"σL##BAR###0#", nothing), Normal{(false, false, true, true, true, false)}(), Y₁, X₁, β, var"##LHS#259", σL, var"##ONE##")
var"##σL##BAR###309" = ReverseDiffExpressionsBase.fifth(var"##∂tup#297")
var"σL##BAR##" = ReverseDiffExpressions.vadd!(var"σL##BAR###0#", var"##σL##BAR###309")
(var"###STACK##POINTER###", var"##L##BAR###1##310") = ReverseDiffExpressions.stack_pointer_call(LoopVectorization.vmul, var"###STACK##POINTER###", var"##temp#289", var"σL##BAR##")
var"L##BAR###1#" = ReverseDiffExpressions.vadd!(var"L##BAR###0#", var"##L##BAR###1##310")
(var"###STACK##POINTER###", var"##tup#278") = ReverseDiffExpressions.stack_pointer_call(∂logdensity!, var"###STACK##POINTER###", (var"L##BAR###1#", nothing, nothing), LKJ{(true, false, false)}(), L, 2.0, var"##ONE##")
var"##∂tup#279" = ReverseDiffExpressionsBase.second(var"##tup#278")
var"##L##BAR###311" = ReverseDiffExpressionsBase.first(var"##∂tup#279")
var"L##BAR##" = ReverseDiffExpressions.vadd!(var"L##BAR###1#", var"##L##BAR###311")
(var"###STACK##POINTER###", var"##nothing#305") = ReverseDiffExpressions.stack_pointer_call(constrain_reverse!, var"###STACK##POINTER###", var"##targetconstrained#266##BAR##", CorrelationMatrixCholesyFactor{12}())
(var"###STACK##POINTER###", var"##tup#276") = ReverseDiffExpressions.stack_pointer_call(∂logdensity!, var"###STACK##POINTER###", (var"σ##BAR###0#", nothing, nothing, nothing), Gamma{(true, false, false, false)}(), σ, 1.0, 0.05, var"##ONE##")
var"##∂tup#277" = ReverseDiffExpressionsBase.second(var"##tup#276")
var"##σ##BAR###1##312" = ReverseDiffExpressionsBase.first(var"##∂tup#277")
var"σ##BAR###1#" = ReverseDiffExpressions.vadd!(var"σ##BAR###0#", var"##σ##BAR###1##312")
var"##temp#286" = last(var"##rrule_LHS#285")
(var"###STACK##POINTER###", var"##temp#288") = ReverseDiffExpressions.stack_pointer_call(LoopVectorization.adjoint, var"###STACK##POINTER###", L)
(var"###STACK##POINTER###", var"##LHS#258##BAR##") = ReverseDiffExpressions.stack_pointer_call(LoopVectorization.vmul, var"###STACK##POINTER###", var"σL##BAR##", var"##temp#288")
(var"###STACK##POINTER###", var"##temp#287") = ReverseDiffExpressions.stack_pointer_call(callunthunk, var"###STACK##POINTER###", var"##temp#286", var"##LHS#258##BAR##")
var"##σ##BAR###313" = LoopVectorization.second(var"##temp#287")
var"σ##BAR##" = ReverseDiffExpressions.vadd!(var"σ##BAR###1#", var"##σ##BAR###313")
(var"###STACK##POINTER###", var"##nothing#303") = ReverseDiffExpressions.stack_pointer_call(constrain_reverse!, var"###STACK##POINTER###", var"##targetconstrained#265##BAR##", RealArray{Tuple{12}, 0.0, Inf, 0}())
(var"###STACK##POINTER###", var"##nothing#301") = ReverseDiffExpressions.stack_pointer_call(constrain_reverse!, var"###STACK##POINTER###", var"##targetconstrained#264##BAR##", RealArray{Tuple{10,12}, -Inf, Inf, 0}())
(var"###STACK##POINTER###", var"##tup#282") = ReverseDiffExpressions.stack_pointer_call(∂logdensity!, var"###STACK##POINTER###", (var"μ##BAR###0#", nothing, nothing), Normal{(true, false, false)}(), μ, 100.0, var"##ONE##")
var"##∂tup#283" = ReverseDiffExpressionsBase.second(var"##tup#282")
var"##μ##BAR###1##314" = ReverseDiffExpressionsBase.first(var"##∂tup#283")
var"μ##BAR###1#" = ReverseDiffExpressions.vadd!(var"μ##BAR###0#", var"##μ##BAR###1##314")
var"##temp#292" = last(var"##rrule_LHS#291")
var"####LHS#259##BAR###315" = ReverseDiffExpressionsBase.fourth(var"##∂tup#297")
var"##LHS#259##BAR##" = ReverseDiffExpressions.vadd!(var"##LHS#259##BAR###0#", var"####LHS#259##BAR###315")
(var"###STACK##POINTER###", var"##temp#293") = ReverseDiffExpressions.stack_pointer_call(callunthunk, var"###STACK##POINTER###", var"##temp#292", var"##LHS#259##BAR##")
var"##μ##BAR###316" = LoopVectorization.second(var"##temp#293")
var"μ##BAR##" = ReverseDiffExpressions.vadd!(var"μ##BAR###1#", var"##μ##BAR###316")
(var"###STACK##POINTER###", var"##nothing#299") = ReverseDiffExpressions.stack_pointer_call(constrain_reverse!, var"###STACK##POINTER###", var"##targetconstrained#263##BAR##", RealArray{Tuple{12}, -Inf, Inf, 0}())
var"##TARGET###0#" = ReverseDiffExpressionsBase.first(var"##tup#276")
var"####TARGET###1##317" = ReverseDiffExpressionsBase.first(var"##tup#278")
var"##TARGET###1#" = ReverseDiffExpressions.vadd!(var"##TARGET###0#", var"####TARGET###1##317")
var"####TARGET###2##318" = ReverseDiffExpressionsBase.first(var"##tup#280")
var"##TARGET###2#" = ReverseDiffExpressions.vadd!(var"##TARGET###1#", var"####TARGET###2##318")
var"####TARGET###3##319" = ReverseDiffExpressionsBase.first(var"##tup#282")
var"##TARGET###3#" = ReverseDiffExpressions.vadd!(var"##TARGET###2#", var"####TARGET###3##319")
var"####TARGET###4##320" = ReverseDiffExpressionsBase.first(var"##tup#294")
var"##TARGET###4#" = ReverseDiffExpressions.vadd!(var"##TARGET###3#", var"####TARGET###4##320")
var"####TARGET###321" = ReverseDiffExpressionsBase.first(var"##tup#296")
var"##TARGET##" = ReverseDiffExpressions.vadd!(var"##TARGET###4#", var"####TARGET###321")
vsum(var"##TARGET##")
end
It checks whether gradients are considered known (and I'd implement gradients for all supported distributions), otherwise it would fall back to another library.
I don't think the above code would run, because I'm pretty sure I didn't implement all the needed methods.
My longer term ambition here is to be able to convert many distributions and linear algebra routines into equivalent loop-based representations, and then use LoopVectorization to reorder, fuse them, etc (still working on those optimizations).
Cool! LoopVectorization as an optimization pass sounds very interesting. This will require the symbolic version of a chain rule though which is not stored by default in ChainRules. Transforming linear algebra expressions to loop-based expressions sounds like a task Tullio can probably do with a little meta-programming help.
Look at the code generated in https://mtk.sciml.ai/dev/tutorials/auto_parallel/ to get a sense of what MTK is currently doing. The current push is to get non-scalar forms as well, but that's what we have so far.
It'd be cool to add Tullio / indexing notation support directly.
If we just add support manually, e.g. want to say sum with 1 argument -> loop that sums the numbers, we could also have predefined LoopSets representing the loop and skip the Expr -> LoopSet conversion step.
Although, it may be better to do the conversion, because Expr -> LoopSet promises a stable API, and I do not have one for the internal representation of the LoopSet, or for constructing one. While it'd be a good idea to create or formalize one for the latter, that hasn't happened yet.
Yes and Tullio + a macro can help with that using a few if statements checking what the types of A and b are in A * b and writing the appropriate loop expression in each branch.
The model definition would define the likelihood/gradient functions as @generated, to delay the compilation until all compile-time info is available. So the dimensionality of A and b should be known to the DSL.
If some of the axes are of static size, it should know those sizes, too.
I think it'd be easier to delay the substitution until then. Otherwise, you'd need a lot of branches, e.g. A or b could also be a Float64 instead of some sort of AbstractArray.
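A small sketch of why delaying the substitution to a generated function avoids runtime branches: inside the generator the arguments are types, so the choice between a scalar product and an explicit loop nest is made once per type signature. The function name `mul_sketch` is hypothetical, and the emitted loop is plain Julia rather than a LoopVectorization `LoopSet` or a Tullio expression.

```julia
# Inside the generator, `A` and `b` are the argument *types*, so these branches
# are resolved when the method is compiled, not on every call.
@generated function mul_sketch(A, b)
    if A <: Number || b <: Number
        # Scalar case: nothing to lower, just multiply.
        return :(A * b)
    elseif A <: AbstractMatrix && b <: AbstractVector
        # Matrix-vector case: emit an explicit loop nest.
        return quote
            out = zeros(promote_type(eltype(A), eltype(b)), size(A, 1))
            @inbounds for j in axes(A, 2), i in axes(A, 1)
                out[i] += A[i, j] * b[j]
            end
            out
        end
    else
        # Anything else falls back to the generic method.
        return :(A * b)
    end
end

mv = mul_sketch([1.0 2.0; 3.0 4.0], [1.0, 1.0])
sc = mul_sketch(2.0, 3.0)
```

With statically sized arrays, the generator could also read the sizes from the types and specialize the loop bounds, which is the situation described above.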
> Otherwise, you'd need a lot of branches, e.g. A or b could also be a Float64 instead of some sort of AbstractArray.
A macro can do this no problem. Each loop can be annotated with @avx which can go on and use a generated function to find the sizes. Obviously dead branches will be eliminated by the Julia compiler because they are all type based.
Or maybe the entire body can be annotated to optimize the whole expression for each type scenario. But I imagine this is somewhat out of LoopVectorization's comfort zone.