Turing.jl: observe function / @observe macro?

Created on 25 Nov 2020  ·  15 Comments  ·  Source: TuringLang/Turing.jl

This is related to issue #1467 and the Discourse discussion "Turing: Observing sum of two random dice" (https://discourse.julialang.org/t/turing-observing-sum-of-two-random-dice/50486), where I (among other things) wondered about an @observe macro to simplify observing a condition in a model.

In the following function I've integrated what I know of "manual observing".

#
# Observe a variable.
#  Example:
#  # ...
#    observe(_sampler, _varinfo, d1+d2 == ss)
#  or 
#    observe(_sampler, _varinfo, d1+d2 == ss) && return
function observe(_sampler,_varinfo,cond)
    if !cond
        if _sampler isa Turing.Sampler{<:Union{PG,SMC}}
            produce(-Inf)
        else
            Turing.@addlogprob! -Inf
        end
    else
        if _sampler isa Turing.Sampler{<:Union{PG,SMC}}
            produce(0.0)
        end
    end
    # return !cond so we can use `observe(...) && return` to return quickly from the model 
    return !cond
end

Here's a simple test model:

@model two_dice(ss) = begin
    d1 ~ DiscreteUniform(1, 6)
    d2 ~ DiscreteUniform(1, 6)

    observe(_sampler,_varinfo, d1+d2 == ss) && return
end

model = two_dice(4)

chains = sample(model, MH(), 1000)
# chains = sample(model, PG(20), 1000)
# chains = sample(model, SMC(100), 1000)

display(chains)

# Check the values of d1 + d2 (should be 4)
display(unique(chains[:d1].+chains[:d2]))

This now works for MH and, due to the special handling of PG and SMC, it also works for these two samplers.
I.e. via post-processing we see that d1+d2 is now (nearly) always 4, as expected.
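
As a sanity check on what any of these samplers should converge to, the exact posterior for this small model can be enumerated by hand. The following is only a sketch in plain Julia (no Turing involved); the names outcomes and mean_d1 are made up for illustration.

```julia
# Enumerate all 36 equally likely outcomes of two fair six-sided dice
# and keep only those consistent with the observed sum.
ss = 4
outcomes = [(d1, d2) for d1 in 1:6 for d2 in 1:6 if d1 + d2 == ss]

# Under the uniform prior, every surviving outcome is equally likely.
posterior = 1 / length(outcomes)
for (d1, d2) in outcomes
    println("d1 = $d1, d2 = $d2, probability = $posterior")
end

# Exact posterior mean of d1 given the observed sum.
mean_d1 = sum(first.(outcomes)) / length(outcomes)
println("E[d1 | d1 + d2 == $ss] = $mean_d1")
```

For ss = 4 only (1, 3), (2, 2) and (3, 1) survive, so the posterior mean of both d1 and d2 is 2, which is the value the chain summaries should be close to.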

The drawback of this approach is that one has to add the parameters _sampler and _varinfo, which is a nuisance, e.g. one gets the warnings Warning: you are using the internal variable '_sampler' and Warning: you are using the internal variable '_varinfo'.

Is there a better way to write this function (or an @observe macro) so that one only has to write observe(d1+d2 == ss)?
(My macro-fu is currently not enough to do this.)

All 15 comments

Unfortunately, in my opinion this is not a general solution to the problem, since it only covers the PG and SMC algorithms and their specific implementations. For other samplers, the implementation of DynamicPPL.observe can look completely different. (It's an internal function, so defining another function called observe is a bit confusing; a different name might be good in any case in your example.) I think the better fix in such a case would be a Dirac or Delta distribution, so that you can properly write ss ~ Dirac(d1 + d2).

@devmotion Thanks for the explanation of the complexity of this. It's a great idea to write this as a distribution - if possible - since then it will be added to the chains as well (I hope).

And I realize that the name observe is not the best since there is already an "observe" concept in Turing. I'll rename it in my utils program...

It's totally fine to use this macro in your code, I just hope we can come up with something that works in an even more general setting in Turing (such as a distribution) :slightly_smiling_face:

Excellent. I cross my fingers about this. :-)

The example with a Dirac distribution:

using Turing, Distributions

struct Dirac{T} <: ContinuousUnivariateDistribution
    x::T
end
Distributions.logpdf(d::Dirac, x::Real) = x == d.x ? 0.0 : -Inf

@model function two_dice(ss)
    d1 ~ DiscreteUniform(1, 6)
    d2 ~ DiscreteUniform(1, 6)
    ss ~ Dirac(d1 + d2)
end

sample(two_dice(4), MH(), 1_000)
# sample(two_dice(4), SMC(), 1_000)
# sample(two_dice(4), PG(15), 1_000)

Thanks for this, David. It's much more elegant than my variant.

It's a pity that ss is not included in the chain:

# ...
Summary Statistics
  parameters      mean       std   naive_se      mcse        ess      rhat 
      Symbol   Float64   Float64    Float64   Float64    Float64   Float64 

          d1    1.9354    0.8312     0.0083    0.0241   904.2674    1.0004
          d2    2.0646    0.8312     0.0083    0.0241   904.2674    1.0004

Quantiles
  parameters      2.5%     25.0%     50.0%     75.0%     97.5% 
      Symbol   Float64   Float64   Float64   Float64   Float64 

          d1    1.0000    1.0000    2.0000    3.0000    3.0000
          d2    1.0000    1.0000    2.0000    3.0000    3.0000

Is there some way to fix this within your implementation of Dirac? Or is it the general problem that observed (and "derived") variables are not included in the chain?

Yes, it's a general thing: Turing does not save observations in the chain. You could obtain the data from the model by running

model = two_dice(4)
model.args

There is also some discussion about tracking values (e.g. transformations of samples) by e.g. using returned values. Currently this is not supported during sampling but you can obtain the returned values of a model after sampling with generated_quantities:

julia> @model function two_dice(ss)
           d1 ~ DiscreteUniform(1, 6)
           d2 ~ DiscreteUniform(1, 6)
           ss ~ Dirac(d1 + d2)
           return ss
       end
two_dice (generic function with 1 method)

julia> chain = sample(two_dice(4), PG(15), 1_000);
Sampling 100%|████████████████████████████████████████| Time: 0:00:03

julia> generated_quantities(two_dice(4), chain)

Thanks, David!

generated_quantities is a really nifty function, and it simplifies some of my programs. I'll check it out more.

And: tracking derived values (such as dd = d1 + d2) during sampling would be a really great feature.
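
Until derived values can be tracked during sampling, one workaround in the spirit of the generated_quantities example above is to return the derived value from the model and recover it afterwards. This is a sketch only: it assumes a Turing version that exports generated_quantities and a Distributions version that provides Dirac (0.24.7 or later); later Turing releases may expose this functionality under a different name, so check the documentation for your installed version. The variable name dd is just for illustration.

```julia
using Turing

@model function two_dice(ss)
    d1 ~ DiscreteUniform(1, 6)
    d2 ~ DiscreteUniform(1, 6)
    ss ~ Dirac(d1 + d2)
    # Return the derived quantity so it can be recovered after sampling.
    return d1 + d2
end

model = two_dice(4)
chain = sample(model, MH(), 1_000)

# Re-run the model body for every sample in the chain and collect
# the returned values (one entry per iteration and chain).
dd = generated_quantities(model, chain)
println(unique(dd))
```

The derived values still live outside the chain object, so they do not show up in the summary statistics.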

@hakank Today Dirac was added to Distributions (https://github.com/JuliaStats/Distributions.jl/pull/1231) and a new version of Distributions was released. I.e., with Distributions 0.24.7 you can now just use

julia> using Turing

julia> @model function two_dice(ss)
           d1 ~ DiscreteUniform(1, 6)
           d2 ~ DiscreteUniform(1, 6)
           ss ~ Dirac(d1 + d2)
       end
two_dice (generic function with 1 method)

julia> sample(two_dice(4), MH(), 1_000)

without any custom definitions or workarounds.

thanks, @devmotion for adding Dirac to Distributions - I'm a bit surprised that MH actually worked well for this problem.

@devmotion That's really great, David. Thanks!

I'll check it more when the new version of Distributions.jl is available to me (after I've run Pkg.update(), I still have Distributions v0.23.8 ).

> I'll check it more when the new version of Distributions.jl is available to me (after I've run Pkg.update(), I still have Distributions v0.23.8 ).

I assume some other package holds back Distributions. You can check why Pkg does not update Distributions by inspecting the output of

] add Distributions@0.24.7

in the Julia REPL.
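
The same check can also be scripted through the Pkg API instead of the REPL's package mode. This is a minimal sketch, assuming the version of interest is 0.24.7 as mentioned earlier in the thread; note that on success it really does change the active environment.

```julia
using Pkg

# Show which versions are currently installed in the active environment.
Pkg.status()

# Explicitly request the version; if another package holds Distributions
# back, the resolver error names the conflicting package.
try
    Pkg.add(name="Distributions", version="0.24.7")
catch err
    showerror(stdout, err)
end
```

If the request fails, the printed resolver log lists which package's compatibility bounds rule out the requested version.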

It seems to be Gadfly that is the culprit, but from what I can see, I've got the latest version (v1.3.1):

(@v1.5) pkg> add Distributions@0.24.7
  Resolving package versions...
ERROR: Unsatisfiable requirements detected for package Gadfly [c91e804a]:
 Gadfly [c91e804a] log:
 ├─possible versions are: [0.8.0, 1.0.0-1.0.1, 1.1.0, 1.2.0-1.2.1, 1.3.0-1.3.1] or uninstalled
 ├─restricted to versions * by an explicit requirement, leaving only versions [0.8.0, 1.0.0-1.0.1, 1.1.0, 1.2.0-1.2.1, 1.3.0-1.3.1]
 └─restricted by compatibility requirements with Distributions [31c24e10] to versions: uninstalled — no versions left
   └─Distributions [31c24e10] log:
     ├─possible versions are: [0.16.0-0.16.4, 0.17.0, 0.18.0, 0.19.1-0.19.2, 0.20.0, 0.21.0-0.21.3, 0.21.5-0.21.12, 0.22.0-0.22.6, 0.23.0-0.23.12, 0.24.0-0.24.7] or uninstalled
     └─restricted to versions 0.24.7 by an explicit requirement, leaving only versions 0.24.7

> thanks, @devmotion for adding Dirac to Distributions - I'm a bit surprised that MH actually worked well for this problem.

I think for this case it's because Dirac is drawn from a static MH, not RWMH.
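
To illustrate why a static (independence) proposal copes with a hard constraint like this, here is a small plain-Julia sketch. It is not Turing's implementation; it assumes that "static MH" means proposing fresh values from the prior at every step, and the names loglik and static_mh are made up for this example.

```julia
using Random

# Log-likelihood of the hard constraint: 0 when satisfied, -Inf otherwise.
loglik(d1, d2, ss) = d1 + d2 == ss ? 0.0 : -Inf

# Independence Metropolis-Hastings with the prior as proposal.
# Assumes 2 <= ss <= 12, otherwise the initialisation loop never ends.
function static_mh(ss, n; rng=Random.default_rng())
    # Draw from the prior until the constraint holds, so the current
    # state always has a finite log-likelihood.
    state = (rand(rng, 1:6), rand(rng, 1:6))
    while loglik(state..., ss) == -Inf
        state = (rand(rng, 1:6), rand(rng, 1:6))
    end

    samples = Vector{Tuple{Int,Int}}(undef, n)
    for i in 1:n
        # Propose both dice from the prior, independent of the current state.
        proposal = (rand(rng, 1:6), rand(rng, 1:6))
        # Prior and proposal terms cancel; only the likelihood ratio remains.
        log_alpha = loglik(proposal..., ss) - loglik(state..., ss)
        if log(rand(rng)) < log_alpha
            state = proposal
        end
        samples[i] = state
    end
    return samples
end

samples = static_mh(4, 1_000)
println(all(s -> sum(s) == 4, samples))
```

Because each proposal is a fresh draw from the prior, the chain can jump directly between (1, 3), (2, 2) and (3, 1). A random-walk proposal that perturbs one die at a time would always leave the constraint and be rejected.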

I have now tested this and it works great. Excellent!
