This is related to issue #1467 and the Discourse discussion "Turing: Observing sum of two random dice" (https://discourse.julialang.org/t/turing-observing-sum-of-two-random-dice/50486), where I (among other things) wondered about an @observe macro to simplify observing a condition in a model.
In the following function I've integrated what I know of "manual observing".
#
# Observe a variable.
# Example:
# # ...
# observe(_sampler, _varinfo, d1+d2 == ss)
# or
# observe(_sampler, _varinfo, d1+d2 == ss) && return
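function observe(_sampler, _varinfo, cond)
    if !cond
        # Condition violated: this execution of the model must be rejected.
        if _sampler isa Turing.Sampler{<:Union{PG,SMC}}
            # Particle samplers: each observation yields (produces) its log-likelihood
            # to the particle's task; `produce` comes from Libtask.
            produce(-Inf)
        else
            # Other samplers: add -Inf to the accumulated log-probability.
            Turing.@addlogprob! -Inf
        end
    else
        if _sampler isa Turing.Sampler{<:Union{PG,SMC}}
            # Particle samplers still expect the observation to produce a value.
            produce(0.0)
        end
    end
    # return !cond so we can use `observe(...) && return` to return quickly from the model
    return !cond
end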
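The -Inf branch above effectively restricts the prior to the states where cond holds, so conditioning on d1+d2 == ss targets the same distribution as plain rejection sampling from the prior. A minimal Turing-free sketch of that equivalence, assuming two fair dice (rejection_two_dice is a hypothetical helper, not a Turing function):

```julia
using Random, Statistics

# Hypothetical helper (not part of Turing): draw n samples of (d1, d2) from two fair
# dice conditioned on d1 + d2 == ss by rejection, i.e. every prior draw that the
# observe function would give log-probability -Inf is simply discarded.
function rejection_two_dice(rng::AbstractRNG, ss::Integer, n::Integer)
    2 <= ss <= 12 || throw(ArgumentError("ss must be between 2 and 12"))
    d1s, d2s = Int[], Int[]
    while length(d1s) < n
        d1, d2 = rand(rng, 1:6), rand(rng, 1:6)  # prior: DiscreteUniform(1, 6)
        if d1 + d2 == ss                         # keep only draws satisfying the condition
            push!(d1s, d1)
            push!(d2s, d2)
        end
    end
    return d1s, d2s
end

d1s, d2s = rejection_two_dice(MersenneTwister(1), 4, 1_000)
unique(d1s .+ d2s)   # [4]
mean(d1s)            # close to 2
```

Rejection is only practical here because the model is tiny; the point is that the samplers should reproduce the same conditional distribution.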
Here's a simple test model:
@model two_dice(ss) = begin
    d1 ~ DiscreteUniform(1, 6)
    d2 ~ DiscreteUniform(1, 6)
    observe(_sampler, _varinfo, d1+d2 == ss) && return
end
model = two_dice(4)
chains = sample(model, MH(), 1000)
# chains = sample(model, PG(20), 1000)
# chains = sample(model, SMC(100), 1000)
display(chains)
# Check the values of d1 + d2 (should be 4)
display(unique(chains[:d1].+chains[:d2]))
This now works for MH, and due to the special handling of PG and SMC it also works for these two samplers.
I.e., via post-processing we see that d1+d2 is now (nearly) always 4, as expected.
The drawback of this approach is that one has to pass the parameters _sampler and _varinfo, which is a nuisance; e.g. one gets the warnings "Warning: you are using the internal variable '_sampler'" and "Warning: you are using the internal variable '_varinfo'".
Is there a better way to write this function (or an @observe macro) so that one only has to write observe(d1+d2 == ss)?
(My macro foo is currently not enough to do this.)
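One general Julia pattern for the macro question is a macro that escapes a name from the caller's scope, so the user never passes that state explicitly. As far as I can tell, this is also why Turing.@addlogprob! inside the observe function needs no _varinfo argument: it picks up whatever _varinfo is in scope. The sketch below is hypothetical and Turing-free (@constrain, logp and condition_logp are illustrative names):

```julia
# Hypothetical macro: add 0.0 or -Inf to a variable named `logp` that must
# already exist in the scope where the macro is used. esc(:logp) makes the
# name refer to the caller's variable instead of a hygienic macro-local one.
macro constrain(cond)
    return :( $(esc(:logp)) += $(esc(cond)) ? 0.0 : -Inf )
end

# Hypothetical stand-in for a model body: only the log-probability contributed
# by the condition d1 + d2 == ss (the uniform dice priors are left out).
function condition_logp(d1, d2, ss)
    logp = 0.0
    @constrain d1 + d2 == ss
    return logp
end

condition_logp(1, 3, 4)  # 0.0
condition_logp(2, 3, 4)  # -Inf
```

Inside a Turing model the expansion would target Turing.@addlogprob! rather than a local logp; I have not checked how that interacts with the PG/SMC produce branch, so treat this as a pattern, not a drop-in @observe.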
Unfortunately, in my opinion this is not a general solution to the problem, since it only covers the PG and SMC algorithms and their specific implementations. For other samplers the implementation of DynamicPPL.observe can look completely different. (DynamicPPL.observe is an internal function, so another function named observe is a bit confusing; a different name might be good in your example in any case.) I think the better fix in such a case would be a Dirac or Delta distribution, so that you can properly write ss ~ Dirac(d1 + d2).
@devmotion Thanks for the explanation of the complexity of this. It's a great idea to write this as a distribution - if possible - since then it will be added to the chains as well (I hope).
And I realize that the name observe is not the best since there is already an "observe" concept in Turing. I'll rename it in my utils program...
It's totally fine to use this macro in your code, I just hope we can come up with something that works in an even more general setting in Turing (such as a distribution) :slightly_smiling_face:
Excellent. I cross my fingers about this. :-)
The example with a Dirac distribution:
using Turing, Distributions
# A point mass at x: log-density 0 at x and -Inf everywhere else.
struct Dirac{T} <: ContinuousUnivariateDistribution
    x::T
end
Distributions.logpdf(d::Dirac, x::Real) = x == d.x ? 0.0 : -Inf
@model function two_dice(ss)
    d1 ~ DiscreteUniform(1, 6)
    d2 ~ DiscreteUniform(1, 6)
    ss ~ Dirac(d1 + d2)
end
sample(two_dice(4), MH(), 1_000)
# sample(two_dice(4), SMC(), 1_000)
# sample(two_dice(4), PG(15), 1_000)
Thanks for this, David. It's much more elegant than my variant.
It's a pity that ss is not included in the chain:
# ...
Summary Statistics
  parameters      mean       std   naive_se      mcse        ess      rhat
      Symbol   Float64   Float64    Float64   Float64    Float64   Float64
          d1    1.9354    0.8312     0.0083    0.0241   904.2674    1.0004
          d2    2.0646    0.8312     0.0083    0.0241   904.2674    1.0004
Quantiles
  parameters      2.5%     25.0%     50.0%     75.0%     97.5%
      Symbol   Float64   Float64   Float64   Float64   Float64
          d1    1.0000    1.0000    2.0000    3.0000    3.0000
          d2    1.0000    1.0000    2.0000    3.0000    3.0000
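These statistics can be checked against the exact posterior. Given d1 + d2 == 4 the only possible states are (1, 3), (2, 2) and (3, 1), each with probability 1/3, so each die has posterior mean 2 and standard deviation sqrt(2/3) ≈ 0.8165, close to the reported means and 0.8312. A small plain-Julia enumeration (no Turing involved; posterior_two_dice is a hypothetical name) confirms this:

```julia
using Statistics

# Hypothetical helper: exact posterior support of (d1, d2) for two fair dice given
# d1 + d2 == ss. With a uniform prior and a 0/-Inf log-likelihood (the Dirac
# constraint), every consistent pair is equally likely.
function posterior_two_dice(ss::Integer)
    return [(d1, d2) for d1 in 1:6 for d2 in 1:6 if d1 + d2 == ss]
end

states = posterior_two_dice(4)       # [(1, 3), (2, 2), (3, 1)]
d1_vals = first.(states)             # [1, 2, 3]
mean(d1_vals)                        # 2.0
std(d1_vals; corrected = false)      # sqrt(2/3) ≈ 0.8165
```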
Is there some way to fix this within your implementation of Dirac? Or is it a general problem that observed (and "derived") variables are not included in the chain?
Yes, it's a general thing, Turing does not save observations in the chain. You could obtain the data from the model by running
model = two_dice(4)
model.args
There is also some discussion about tracking values (e.g. transformations of samples), for instance by using returned values. Currently this is not supported during sampling, but you can obtain the returned values of a model after sampling with generated_quantities:
julia> @model function two_dice(ss)
           d1 ~ DiscreteUniform(1, 6)
           d2 ~ DiscreteUniform(1, 6)
           ss ~ Dirac(d1 + d2)
           return ss
       end
two_dice (generic function with 1 method)
julia> chain = sample(two_dice(4), PG(15), 1_000);
Sampling 100%|██████████████████████████████████████████| Time: 0:00:03
julia> generated_quantities(two_dice(4), chain)
Thanks, David!
generated_quantities is a really nifty function, and it simplifies some of my programs. I'll check it out more.
And tracking derived values (such as dd = d1 + d2) during sampling would be a really great feature.
@hakank Today Dirac was added to Distributions (https://github.com/JuliaStats/Distributions.jl/pull/1231) and a new version of Distributions was released. I.e., with Distributions 0.24.7 you can now just use
julia> using Turing
julia> @model function two_dice(ss)
           d1 ~ DiscreteUniform(1, 6)
           d2 ~ DiscreteUniform(1, 6)
           ss ~ Dirac(d1 + d2)
       end
two_dice (generic function with 1 method)
julia> sample(two_dice(4), MH(), 1_000)
without any custom definitions or workarounds.
thanks, @devmotion for adding Dirac to Distributions - I'm a bit surprised that MH actually worked well for this problem.
@devmotion That's really great, David. Thanks!
I'll check it more when the new version of Distributions.jl is available to me (after running Pkg.update(), I still have Distributions v0.23.8).
I assume some other package holds back Distributions. You can check why Pkg does not update Distributions by inspecting the output of
] add Distributions@0.24.7
in the Julia REPL.
It seems to be Gadfly that is the culprit, but as far as I can see I've got the latest version (v1.3.1):
(@v1.5) pkg> add Distributions@0.24.7
Resolving package versions...
ERROR: Unsatisfiable requirements detected for package Gadfly [c91e804a]:
Gadfly [c91e804a] log:
 ├─possible versions are: [0.8.0, 1.0.0-1.0.1, 1.1.0, 1.2.0-1.2.1, 1.3.0-1.3.1] or uninstalled
 ├─restricted to versions * by an explicit requirement, leaving only versions [0.8.0, 1.0.0-1.0.1, 1.1.0, 1.2.0-1.2.1, 1.3.0-1.3.1]
 └─restricted by compatibility requirements with Distributions [31c24e10] to versions: uninstalled ? no versions left
   └─Distributions [31c24e10] log:
     ├─possible versions are: [0.16.0-0.16.4, 0.17.0, 0.18.0, 0.19.1-0.19.2, 0.20.0, 0.21.0-0.21.3, 0.21.5-0.21.12, 0.22.0-0.22.6, 0.23.0-0.23.12, 0.24.0-0.24.7] or uninstalled
     └─restricted to versions 0.24.7 by an explicit requirement, leaving only versions 0.24.7
thanks, @devmotion for adding Dirac to Distributions - I'm a bit surprised that MH actually worked well for this problem.
I think it works in this case because the dice are sampled with a static MH proposal (drawn from the prior), not with a random-walk MH (RWMH).
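To illustrate the static-proposal point, here is a plain-Julia sketch of an independence ("static") Metropolis-Hastings sampler that proposes both dice fresh from their prior at every step. This is an assumption about what MH() does here, not a reading of Turing's code, and static_mh_two_dice and loglik are hypothetical names. With the prior as proposal, the acceptance ratio reduces to the likelihood ratio, so a move into a state with d1 + d2 != ss (log-likelihood -Inf) is always rejected, while every valid pair is reachable in a single jump:

```julia
using Random, Statistics

# Hard constraint in the spirit of ss ~ Dirac(d1 + d2): 0 if satisfied, -Inf otherwise.
loglik(d1, d2, ss) = d1 + d2 == ss ? 0.0 : -Inf

# Hypothetical sketch: independence MH for the two-dice model. Proposals come from the
# DiscreteUniform(1, 6) priors, so log acceptance ratio = loglik(new) - loglik(old).
function static_mh_two_dice(rng::AbstractRNG, ss::Integer, n::Integer)
    d1, d2 = rand(rng, 1:6), rand(rng, 1:6)      # initial state from the prior
    samples = Vector{Tuple{Int,Int}}(undef, n)
    for i in 1:n
        p1, p2 = rand(rng, 1:6), rand(rng, 1:6)  # proposal, independent of the current state
        logratio = loglik(p1, p2, ss) - loglik(d1, d2, ss)
        # If both states violate the constraint, logratio is NaN and the comparison is
        # false, so the chain stays put until it first proposes a valid pair.
        if log(rand(rng)) < logratio
            d1, d2 = p1, p2
        end
        samples[i] = (d1, d2)
    end
    return samples
end

samples = static_mh_two_dice(MersenneTwister(42), 4, 10_000)
kept = samples[501:end]                 # discard a short burn-in
unique(s[1] + s[2] for s in kept)       # [4]
```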
I have now tested this and it works great. Excellent!