Turing.jl: Wrong prediction results on multivariate params

Created on 10 Jul 2020 · 12 comments · Source: TuringLang/Turing.jl

The predict function outputs wrong results depending on how the multivariate parameters are constructed. Below is a simple linear regression problem that always converges fine, but whose predict results depend on how the coef parameter is constructed.

using Turing, Plots, StatsPlots
using Turing.Inference: predict

@model function simple_linear(x, y)
    intercept ~ Normal(0,1)

    ## this corrupts `predict` output
    coef ~ MvNormal(2, 1)

    ## this alternative also
    # coef ~ filldist(Normal(0,1), 2)

    ## but this version works fine
    # coef = Vector(undef, 2)
    # for i in axes(coef, 1)
    #     coef[i] ~ Normal(0,1)
    # end

    ## this works too
    # coef1 ~ Normal(0,1)
    # coef2 ~ Normal(0,1)
    # coef = [coef1, coef2]

    coef = reshape(coef, 1, size(x,1))

    mu = intercept .+ coef * x |> vec

    error ~ truncated(Normal(0,1), 0, Inf)

    y ~ MvNormal(mu, error)
end


# simple linear transformation
x = randn(2, 100)
y = [1 + 2 * a + 3 * b for (a,b) in eachcol(x)]

chain = sample(simple_linear(x, y), NUTS(), 1000)

# model converges fine
plot(chain) |> display
@show chain

p = predict(simple_linear(x, missing), chain)

# prediction correctness depends on how multivariate params were constructed
@show y[1]
@show p["y[1]"].value.data |> mean # should be close to y[1] above
@show p["y[1]"].value.data |> std # should be close to 0.0

I'm trying this on Julia 1.4.1 with Turing 0.13.0.
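Because y in this example is an exact linear function of x (there is no noise term), the expectations in the comments above can be checked independently of Turing: an ordinary least-squares fit in plain Julia recovers the intercept and both coefficients, so a correct posterior predictive for y[1] should sit on y[1] with a spread near zero. A minimal sketch using only the standard library; the seed is arbitrary.

```julia
using LinearAlgebra, Random

Random.seed!(1)                            # arbitrary seed, for reproducibility only
x = randn(2, 100)
y = [1 + 2 * a + 3 * b for (a, b) in eachcol(x)]

# Design matrix: a column of ones for the intercept, then one column per feature.
X = hcat(ones(size(x, 2)), permutedims(x))
β = X \ y                                  # least-squares solution

residuals = y .- X * β
@show β                                    # ≈ [1.0, 2.0, 3.0]
@show maximum(abs, residuals)              # ≈ 0 up to floating-point error
```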

All 12 comments

I can confirm that I managed to reproduce the reported issue. Does anyone know what this might be related to? @TuringLang/turing

It seems like the following lines are causing the issue: https://github.com/TuringLang/Turing.jl/blob/6db59629f1f189f63350aef9ce4fe6c0bebdaba1/src/inference/Inference.jl#L755-L757

If a variable is a vector, then vn will just be the symbol for the vector rather than the symbols corresponding to the _indices_ of the vector. So the check vn_str ∈ c.name_map.parameters evaluates to false, since vn_str will, in this particular case, be "coef", while c.name_map.parameters contains "coef[1]" and "coef[2]". It seems like this was introduced by some upstream changes, as this worked just fine when I originally implemented this functionality.

@devmotion @cpfiffer Do any of you have an idea of the "appropriate" functionality to use to ensure that we can only set the values that are present?

The following snippet demonstrates the issue:

julia> x = randn(2, 100);

julia> y = [1 + 2 * a + 3 * b for (a,b) in eachcol(x)];

julia> m = simple_linear(x, y);

julia> chain = sample(m, NUTS(), 1000);

julia> var_info = Turing.VarInfo(m);

julia> c = chain[1];

julia> v = :coef;

julia> md = var_info.metadata;

julia> vn = first(md[v].vns)
coef

julia> c.name_map.parameters
4-element Array{String,1}:
 "coef[1]"
 "coef[2]"
 "error"
 "intercept"
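The mismatch can be reproduced without Turing at all. The sketch below works only on plain strings copied from the output above and shows why an exact membership test skips a multivariate parameter, while a prefix test on "coef[" would find its element names. It illustrates the failure mode; it is not the code in Inference.jl.

```julia
# Parameter names as stored in the chain (copied from the REPL output above).
params = ["coef[1]", "coef[2]", "error", "intercept"]

# For `coef ~ MvNormal(2, 1)` the VarName carries no indexing, so its string is just "coef".
vn_str = "coef"

# Exact membership check: the whole-vector name never appears in the chain.
@show vn_str ∈ params                      # false

# The values are there, but only under the element names.
element_names = filter(n -> startswith(n, vn_str * "["), params)
@show element_names                        # ["coef[1]", "coef[2]"]
```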

I think there needs to be a String(vn) function in DynamicPPL that adds indexing to the VarName. Calling Symbol(vn) maps to this:

https://github.com/TuringLang/DynamicPPL.jl/blob/275ccc8791be7d5dc4eb74b6e0c19de247f56d74/src/varname.jl#L72

I can't actually find where we've overloaded string(vn) to append the indexing -- we used to have that functionality, but now I cannot seem to find where it went.

It needs to be added back in so Symbol(vn) = Symbol(String(vn, all_parts = true)) = Symbol("coef[1]").

Yeah exactly. I also can't seem to be able to find the previous overload :confused:

But it seems like the vns field is redundant now, given that it now appears to be VarName(v, ()), where v is the symbol corresponding to the random variable in the model, e.g. coef. Pretty sure vns used to be ["coef[1]", "coef[2]"], no?

It's now fixed on tor/issue-1352 (I'll make a PR asap):

julia> using Turing
[ Info: Precompiling Turing [fce5fe82-541a-59a6-adf8-730c64b5f9a0]

julia> using Turing.Inference: predict

julia> @model function simple_linear(x, y)
           intercept ~ Normal(0,1)

           ## now works
           coef ~ MvNormal(2, 1)

           ## now works
           # coef ~ filldist(Normal(0,1), 2)

           ## but this version works fine
           # coef = Vector(undef, 2)
           # for i in axes(coef, 1)
           #     coef[i] ~ Normal(0,1)
           # end

           ## this works too
           # coef1 ~ Normal(0,1)
           # coef2 ~ Normal(0,1)
           # coef = [coef1, coef2]

           coef = reshape(coef, 1, size(x,1))

           mu = intercept .+ coef * x |> vec

           error ~ truncated(Normal(0,1), 0, Inf)

           y ~ MvNormal(mu, error)
       end;

julia> # simple linear transformation
       x = randn(2, 100);

julia> y = [1 + 2 * a + 3 * b for (a,b) in eachcol(x)];

julia> m = simple_linear(x, y);

julia> chain = sample(m, NUTS(), 1000);
┌ Info: Found initial step size
└   ϵ = 0.00625

julia> p = predict(simple_linear(x, missing), chain);

julia> # prediction correctness depends on how multivariate params were constructed
       @show y[1]
y[1] = -1.5372850579938522
-1.5372850579938522

julia> @show p["y[1]"].data |> mean # should be close to y[1] above
(p["y[1]"]).data |> mean = -1.537285141344423
-1.537285141344423

julia> @show p["y[1]"].data |> std # should be close to 0.0
(p["y[1]"]).data |> std = 1.2985395670414417e-6
1.2985395670414417e-6
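Conceptually, predict takes each posterior draw, fixes the parameters to it, and samples y from the model (here with y passed as missing). The sketch below reproduces that loop by hand for y[1] in plain Julia. The "posterior draws" are synthetic values placed near the true parameters, not NUTS output, so this only illustrates why a correct predict gives a mean at y[1] and a spread near zero for this noise-free data.

```julia
using Random, Statistics

Random.seed!(42)                                 # arbitrary seed
x = randn(2, 100)
y = [1 + 2 * a + 3 * b for (a, b) in eachcol(x)]

# Synthetic "posterior draws" (an assumption, not sampler output): tightly
# concentrated around the true values, with a tiny noise scale σ.
ndraws = 1000
intercept = 1 .+ 1e-6 .* randn(ndraws)
coef = [2.0 3.0] .+ 1e-6 .* randn(ndraws, 2)     # one row per draw
σ = fill(1e-6, ndraws)

# For each draw: fix the parameters, recompute the mean of y[1], then sample it.
y1_pred = [intercept[i] + coef[i, 1] * x[1, 1] + coef[i, 2] * x[2, 1] + σ[i] * randn()
           for i in 1:ndraws]

@show y[1]
@show mean(y1_pred)     # close to y[1]
@show std(y1_pred)      # close to 0.0
```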

I can't actually find where we've overloaded string(vn) to append the indexing -- we used to have that functionality, but now I cannot seem to find where it went.

It needs to be added back in so Symbol(vn) = Symbol(String(vn, all_parts = true)) = Symbol("coef[1]").

An overload of string is not needed (it was intentionally left undefined when @phipsgabler refactored it), since string falls back to the output of show, which is defined (as suggested, if I understand you correctly).

Right, string falls back to show, which ought to already serialize all parts of the indexing: string(@varname(x[1][2])) == "x[1][2]" (if it doesn't, it's a bug, but I can't see why it wouldn't). getsym is what returns just the name without indexing, as a symbol.

I only left in the Symbol conversion because I knew that it was used somewhere else. It'd be much more elegant, IMHO, to use VarNames directly everywhere and keep show only for printing/debugging, possibly together with a more refined discussion of, and a publicly documented interface for, subsumes.
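The string-to-show fallback is ordinary Julia behaviour: string calls print, and print falls back to the two-argument show unless a type overrides it. The sketch below uses a hypothetical SketchVarName type (not DynamicPPL's VarName) to show that defining show alone yields "coef[1]"-style strings, that Symbol(string(vn)) then gives Symbol("coef[1]"), and that a getsym-like accessor returns only the bare symbol.

```julia
# Hypothetical stand-in for DynamicPPL's VarName: a symbol plus its indexing.
struct SketchVarName
    sym::Symbol
    indexing::Tuple          # () for `coef`, ((1,),) for `coef[1]`
end

# Only `show` is defined; `print` and therefore `string` fall back to it.
function Base.show(io::IO, vn::SketchVarName)
    print(io, vn.sym)
    for ix in vn.indexing
        print(io, "[", join(ix, ","), "]")
    end
end

# Hypothetical analogue of `getsym`: the bare name, without indexing.
sketch_getsym(vn::SketchVarName) = vn.sym

whole  = SketchVarName(:coef, ())
elem   = SketchVarName(:coef, ((1,),))
nested = SketchVarName(:x, ((1,), (2,)))

@show string(whole)          # "coef"
@show string(elem)           # "coef[1]"
@show string(nested)         # "x[1][2]"
@show Symbol(string(elem))   # Symbol("coef[1]")
@show sketch_getsym(elem)    # :coef
```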


Hi, so I was running into some issues with predict exactly as described here, and then found this thread, which is closely related.

It seems like this approach

coef = Vector(undef, 2)
for i in axes(coef, 1)
    coef[i] ~ Normal(0,1)
end

does not actually work (anymore). The first post implied that it used to work (presumably before this change?), which was odd.

I can confirm the other approaches do actually work. However, this particular approach (the one that doesn't work) is certainly the most convenient in my case, because I have multilevel priors, which I now additionally pass in through a dict. Some of those priors may need different initializations, so the only easy way to set them up is through a for-loop.

As a side note, to try the model above with the for-loop prior initialization, I had to make some additional changes too: using the model exactly as given above, but with the third method uncommented, gave me this error:
ERROR: LoadError: ArgumentError: type does not have a definite number of fields

So I fixed that error for now by simply not using the coef reshape, and doing:

mu = intercept .+ coef[1] * x[1,:] .+ coef[2]*x[2,:] |> vec
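The workaround above computes the same mean vector as the original reshape-based line; it only avoids the reshape. Below is a quick plain-Julia check of that equivalence, assuming illustrative fixed values 2.0 and 3.0 for the coefficients. No Turing is involved, so it does not reproduce the ArgumentError itself. Note that Vector(undef, 2) creates a Vector{Any}, which is the container the for-loop model fills.

```julia
using Random

Random.seed!(0)                           # arbitrary seed
x = randn(2, 100)
intercept = 1.0
coef = Vector(undef, 2)                   # Vector{Any}, as in the for-loop model
coef[1], coef[2] = 2.0, 3.0               # illustrative values, not posterior draws

# Original formulation: 1×2 row times the 2×100 design matrix, flattened.
mu_reshape = intercept .+ reshape(coef, 1, size(x, 1)) * x |> vec

# Workaround formulation from the comment above, element by element.
mu_manual = intercept .+ coef[1] * x[1, :] .+ coef[2] * x[2, :] |> vec

@show eltype(coef)                        # Any
@show all(mu_reshape .≈ mu_manual)        # true
```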

Does anyone have any idea why this method of constructing priors no longer works (to give the correct results) with the predict method?

I can confirm that the following now has issues:

coef = Vector(undef, 2)
for i in axes(coef, 1)
    coef[i] ~ Normal(0,1)
end

Thank you @mgmverburg for bringing attention to this!

The issue comes down to https://github.com/TuringLang/Turing.jl/blob/b9db77c493a4a4c1d7e5782bc7bbf99a0269a420/src/inference/Inference.jl#L618-L625 because if:

  1. You use the above implementation, Symbol.(md[:coef].vns) is [Symbol("coef[1]"), Symbol("coef[2]")] and MCMCChains.namesingroup(c, Symbol("coef[1]")) is going to be empty.
  2. You instead use coef ~ MvNormal(2, 1), Symbol.(md[:coef].vns) is going to be [:coef] and then you get the correct call MCMCChains.namesingroup(c, :coef).
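The two cases can be mimicked with plain strings. The group lookup below is a hypothetical stand-in that keeps names of the form "sym[...]"; it is an assumption modelled on the behaviour described above, not MCMCChains' own code. It shows why a per-element name like coef[1] finds nothing, and why matching names that are either equal to the variable name or start with "name[" (the rule used by the hotfix's _setval_kernel! further down) handles both cases.

```julia
chain_names = ["coef[1]", "coef[2]", "error", "intercept"]

# Hypothetical group lookup (assumption): keep names that start with "sym[".
group_names(names, sym::Symbol) = filter(n -> startswith(n, string(sym, "[")), names)

# Case 2: `coef ~ MvNormal(2, 1)` gives a single VarName, `coef`.
@show group_names(chain_names, :coef)                 # ["coef[1]", "coef[2]"]

# Case 1: the for-loop model gives one VarName per element, e.g. `coef[1]`.
@show group_names(chain_names, Symbol("coef[1]"))     # String[] (empty)

# "Equal to the name, or starts with name[" covers both cases.
matching(names, vn_str) = filter(n -> n == vn_str || startswith(n, vn_str * "["), names)
@show matching(chain_names, "coef")                   # ["coef[1]", "coef[2]"]
@show matching(chain_names, "coef[1]")                # ["coef[1]"]
```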

I'm not sure how this got through, though. Either I completely forgot to check this implementation and the comment above saying ## but this version works fine was maybe referring to how it worked in the past, or something changed upstream that broke it. But if I had to wager, I now think I just brainfarted back then. Either way, I definitely messed up by not adding all the above cases to the test suite. And looking now, I see that I even overwrote the previous model used for testing!!! That model actually worked, so no harm done, but it still wasn't intentional.

I got a meeting for the next hour, but I'll get this sorted ASAP afterwards :+1:

EDIT: The "hotfix" below should no longer be used; this has been fixed in a later release.

Here's a "hotfix" for the issue:

using Turing
import Random

function Turing.Inference.transitions_from_chain(
    rng::Random.AbstractRNG,
    model::Turing.Model,
    chain::MCMCChains.Chains;
    sampler = DynamicPPL.SampleFromPrior()
)
    vi = Turing.VarInfo(model)

    chain_idx = 1
    transitions = map(1:length(chain)) do sample_idx
        # NEW! Uses the "recent" improvement to `setval!` to do the job, together with the change to `_setval_kernel!` below.
        DynamicPPL.setval!(vi, chain, sample_idx, chain_idx)
        model(rng, vi, sampler)

        # Convert `VarInfo` into `NamedTuple` and save
        theta = DynamicPPL.tonamedtuple(vi)
        lp = Turing.getlogp(vi)
        Turing.Inference.Transition(theta, lp)
    end

    return transitions
end

function DynamicPPL._setval_kernel!(vi::DynamicPPL.AbstractVarInfo, vn::DynamicPPL.VarName, values, keys)
    string_vn = string(vn)
    string_vn_indexing = string_vn * "["
    indices = findall(keys) do x
        string_x = string(x)
        return string_x == string_vn || startswith(string_x, string_vn_indexing)
    end
    if !isempty(indices)
        sorted_indices = sort!(indices; by=i -> string(keys[i]), lt=DynamicPPL.NaturalSort.natural)
        val = mapreduce(vcat, sorted_indices) do i
            values[i]
        end
        DynamicPPL.setval!(vi, val, vn)
        DynamicPPL.settrans!(vi, false, vn)
    else
        # NEW! If `vn` is not present in `keys`, i.e. no value was given, we assume it should be resampled.
        # Alternatively, whether to resample or to warn could be controlled by a keyword argument.
        DynamicPPL.set_flag!(vi, vn, "del")
    end
end
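The core of the hotfix is the lookup in _setval_kernel!: collect every chain key that equals the variable's name or starts with "name[", sort the matches naturally, and concatenate their values. The natural sort matters because plain string ordering puts "coef[10]" ahead of "coef[1]" and "coef[2]". A plain-Julia sketch of those three steps; the regex-based index_of is a simplified stand-in for DynamicPPL.NaturalSort.natural (single integer indices only), and the keys and values are illustrative.

```julia
# Illustrative keys/values: a 10-element `coef` (value i stored under "coef[i]")
# plus a scalar `error`, listed in no particular order.
keys_   = vcat(["coef[$i]" for i in 10:-1:1], ["error"])
values_ = vcat([Float64(i) for i in 10:-1:1], [0.1])

# Step 1: indices of keys equal to the name or starting with "name[".
matching_indices(ks, vn_str) = findall(k -> k == vn_str || startswith(k, vn_str * "["), ks)
idx = matching_indices(keys_, "coef")

# Step 2: plain string sorting is lexicographic, so "coef[10]" comes first...
lexi = sort(idx; by = i -> keys_[i])
@show keys_[first(lexi)]                   # "coef[10]"

# ...while ordering by the parsed integer index (simplified natural sort) gives 1, 2, ..., 10.
index_of(k) = parse(Int, match(r"\[(\d+)\]$", k).captures[1])
natural = sort(idx; by = i -> index_of(keys_[i]))

# Step 3: concatenate the values in that order, as the hotfix does with mapreduce(vcat, ...).
val = mapreduce(i -> values_[i], vcat, natural)
@show val                                  # [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0]
```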

This requires one PR each to Turing.jl and DynamicPPL.jl, but I'll try to get those up today.

Aight, so finally the issue has been resolved.
It unfortunately took quite a while because there were issues related to the upgrade to Julia 1.6, etc. But it should be good now :)
