Pyro: Getting the posterior predictive samples for a model

Created on 22 Apr 2018 · 4 Comments · Source: pyro-ppl/pyro

Consider a simple model like:

import pyro
from pyro.distributions import Binomial, Uniform

def model(at_bats, hits):
    phi_prior = Uniform(hits.new_tensor(0), hits.new_tensor(1))
    phi = pyro.sample("phi", phi_prior)
    return pyro.sample("obs", Binomial(at_bats, phi), obs=hits)

I have collected some posterior samples into an empirical distribution and want to look at the posterior predictive distribution. Note that I cannot just look at the marginal distribution of "_RETURN", because the return value is always the observed value of "obs" (i.e. hits). One way to do this would be to change the model to:

def model(at_bats, hits):
    phi_prior = Uniform(hits.new_tensor(0), hits.new_tensor(1))
    phi = pyro.sample("phi", phi_prior)
    pyro.sample("obs", Binomial(at_bats, phi), obs=hits)
    return Binomial(at_bats, phi).sample()
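When phi is drawn from the posterior, the fresh draw returned by this modified model follows the posterior predictive distribution, which for this conjugate model has a closed form that is handy for checking any sampling-based answer. Writing n for at_bats, y for hits and ỹ for a replicated observation, and using the fact that Uniform(0, 1) is Beta(1, 1):

```latex
p(\phi \mid y) = \mathrm{Beta}(\phi \mid 1 + y,\ 1 + n - y)

p(\tilde y \mid y) = \int_0^1 \mathrm{Binomial}(\tilde y \mid n, \phi)\, p(\phi \mid y)\, d\phi
                   = \binom{n}{\tilde y} \frac{B(\tilde y + 1 + y,\ n - \tilde y + 1 + n - y)}{B(1 + y,\ 1 + n - y)}

\mathbb{E}[\tilde y \mid y] = \frac{n (1 + y)}{n + 2}
```

In words: draw phi from the posterior, then draw a fresh Binomial(n, phi); the observed hits enter only through the posterior over phi.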

This seems like somewhat of a hack. Another solution is to use a poutine to convert the observe sites to sample sites, condition the existing sample sites on the joint posterior, and then draw samples from this model, but that is too circuitous and inefficient. Yet another solution would be to have observe sites affect only log_prob and not sampling (I'm not sure this would really work). Perhaps there is a simpler way? @eb8680, what would you suggest?
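Stripped of the Pyro API, the computation being asked for is two-stage ancestral sampling: for each posterior sample of phi, draw a fresh Binomial(at_bats, phi) value, which is what the `return Binomial(at_bats, phi).sample()` line in the modified model does. The plain-Python sketch below uses only the standard library, not Pyro; the data are made up, and the exact conjugate posterior Beta(1 + hits, 1 + at_bats - hits) stands in (as an assumption) for the empirical distribution of posterior samples mentioned above.

```python
import random


def sample_binomial(n, p, rng):
    """One Binomial(n, p) draw, computed as the number of successes in n trials."""
    return sum(rng.random() < p for _ in range(n))


def posterior_predictive(phi_samples, at_bats, rng):
    """Ancestral sampling: one fresh replicate of "obs" per posterior draw of phi."""
    return [sample_binomial(at_bats, phi, rng) for phi in phi_samples]


at_bats, hits = 45, 18  # made-up data for illustration
rng = random.Random(0)

# Stand-in for the empirical distribution of posterior samples: with a
# Uniform(0, 1) = Beta(1, 1) prior, the exact posterior is
# Beta(1 + hits, 1 + at_bats - hits), so we sample it directly here.
phi_samples = [rng.betavariate(1 + hits, 1 + at_bats - hits) for _ in range(4000)]
replicates = posterior_predictive(phi_samples, at_bats, rng)

# The sample mean should be close to at_bats * (hits + 1) / (at_bats + 2).
print(sum(replicates) / len(replicates))
```

With real MCMC output, `phi_samples` would simply be the collected posterior draws; the second stage does not change.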


All 4 comments

Let me see if I've grokked @eb8680's approach :smile: Could you block-trace-replay?

tr = poutine.trace(poutine.block(model, hide=["obs"])).get_trace(at_bats, hits)
pp = poutine.replay(model, trace=tr)(at_bats, hits)
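Note that `tr` is recorded from a forward run of the model with "obs" hidden, so the phi it stores is a draw from the Uniform(0, 1) prior, not from the posterior; replaying it can therefore give at best prior predictive samples. A standard-library-only sketch (made-up data, the conjugate posterior used as a shortcut, and ignoring the separate problem that "obs" stays observed on replay) shows how much wider that distribution is for this model:

```python
import random


def sample_binomial(n, p, rng):
    """One Binomial(n, p) draw as the number of successes in n trials."""
    return sum(rng.random() < p for _ in range(n))


def variance(xs):
    m = sum(xs) / len(xs)
    return sum((x - m) ** 2 for x in xs) / len(xs)


at_bats, hits = 45, 18  # made-up data for illustration
rng = random.Random(1)
N = 4000

# phi recorded from a forward run of the model: a draw from the Uniform(0, 1) prior.
prior_pred = [sample_binomial(at_bats, rng.random(), rng) for _ in range(N)]

# phi drawn from the exact Beta(1 + hits, 1 + at_bats - hits) posterior instead.
post_pred = [
    sample_binomial(at_bats, rng.betavariate(1 + hits, 1 + at_bats - hits), rng)
    for _ in range(N)
]

print("prior predictive variance:    ", variance(prior_pred))
print("posterior predictive variance:", variance(post_pred))
```

Under these assumptions the prior predictive is uniform over 0..at_bats (variance ((n + 1)^2 - 1) / 12, about 176 for n = 45), while the posterior predictive concentrates around hits.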

> Could you block-trace-replay?

@fritzo - I'm not sure that would work, since the replayed model will still be constrained by the "obs" site. I think we need a way to remove the "obs" site from the model itself.

@fritzo that's the right idea, except it's the prior predictive (and obs is still constrained)! What you really want to do according to the current API after #1019 is something like this:

import pyro
import pyro.poutine as poutine
from pyro.distributions import Binomial, Uniform

def model(at_bats, hits):
    phi_prior = Uniform(hits.new_tensor(0), hits.new_tensor(1))
    phi = pyro.sample("phi", phi_prior)
    return pyro.sample("obs", Binomial(at_bats, phi))

conditioned_model = pyro.condition(model, data={"obs": hits})

posterior_traces = MCMC(conditioned_model, ...).run(at_bats, hits).exec_traces

# sample a random trace from the posterior somehow,
# e.g. by taking a random index into posterior_traces,
# and then remove the "obs" node so that it only contains latent variables
# (sample_and_prune_trace is a placeholder, not a Pyro function)
posterior_trace = sample_and_prune_trace(posterior_traces)

# replay the posterior trace against the unconstrained model
pp = poutine.replay(model, trace=posterior_trace)(at_bats, hits)

The awkwardness of this answer is exactly the motivation for the new interfaces and behaviors in the design doc, and for using condition to separate the specification of model and data.
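To make the mechanics of this recipe concrete without depending on any particular Pyro version, here is a toy, pure-Python imitation of the handlers involved. The `sample`, `condition`, `trace` and `replay` below are simplified stand-ins written for this sketch (lowercase only to echo the poutine names), not Pyro's implementations, which carry much more state. Exact rejection sampling replaces MCMC purely as an assumption of the sketch, workable here because the data are a single discrete count.

```python
import random

_STACK = []  # active toy handlers, outermost first


class Handler:
    """Toy effect handler: sits on _STACK while the wrapped callable runs."""
    def __init__(self, fn):
        self.fn = fn
    def process(self, msg):  # may fix msg["value"] before any draw
        pass
    def postprocess(self, msg):  # sees the site's final value
        pass
    def __call__(self, *args):
        _STACK.append(self)
        try:
            return self.fn(*args)
        finally:
            _STACK.pop()


def sample(name, draw):
    """A sample site: a handler may fix its value; otherwise call draw()."""
    msg = {"name": name, "value": None, "is_observed": False}
    for handler in reversed(_STACK):  # innermost handler first
        handler.process(msg)
    if msg["value"] is None:
        msg["value"] = draw()
    for handler in _STACK:
        handler.postprocess(msg)
    return msg["value"]


class condition(Handler):
    """Clamp the named sites to data and mark them as observed."""
    def __init__(self, fn, data):
        super().__init__(fn)
        self.data = data
    def process(self, msg):
        if msg["name"] in self.data:
            msg["value"] = self.data[msg["name"]]
            msg["is_observed"] = True


class replay(Handler):
    """Reuse the values stored in a recorded trace for the sites it contains."""
    def __init__(self, fn, nodes):
        super().__init__(fn)
        self.nodes = nodes
    def process(self, msg):
        if msg["name"] in self.nodes:
            msg["value"] = self.nodes[msg["name"]]["value"]


class trace(Handler):
    """Run the wrapped callable and return {site name: {"value", "is_observed"}}."""
    def __call__(self, *args):
        self.nodes = {}
        super().__call__(*args)
        return self.nodes
    def postprocess(self, msg):
        self.nodes[msg["name"]] = {"value": msg["value"], "is_observed": msg["is_observed"]}


def model(at_bats, rng):
    phi = sample("phi", rng.random)  # Uniform(0, 1) prior
    # Binomial(at_bats, phi) as a count of successes; note: no obs= argument.
    return sample("obs", lambda: sum(rng.random() < phi for _ in range(at_bats)))


at_bats, hits = 45, 18  # made-up data for illustration
rng = random.Random(0)

# condition() makes "obs" observed without touching the model definition.
conditioned_trace = trace(condition(model, {"obs": hits}))(at_bats, rng)

# Stand-in for MCMC: exact rejection sampling, keeping forward runs of the
# unconditioned model whose "obs" matches the data.
posterior_traces = []
while len(posterior_traces) < 200:
    tr = trace(model)(at_bats, rng)
    if tr["obs"]["value"] == hits:
        posterior_traces.append(tr)

# Pick a posterior trace, drop its "obs" node, and replay the remaining latent
# values against the unconditioned model so that "obs" is drawn afresh.
replicates = []
for _ in range(2000):
    latent_only = {k: v for k, v in rng.choice(posterior_traces).items() if k != "obs"}
    replicates.append(trace(replay(model, latent_only))(at_bats, rng)["obs"]["value"])
```

Because the model itself never mentions the data, the same definition serves both as the conditioned model for inference and as the generator for replicates.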

> posterior_trace = sample_and_prune_trace(posterior_traces)

To make it a bit more pleasant looking, we can change __call__ to return traces with obs sites removed (as per your suggestion in the PR). ~Still, the fact that we need to write our model differently and use the condition poutine to do some basic checks seems like an area for improvement in the model specification API. That will need to be explained separately to users.~

EDIT: After spending some time with this, I think using condition is actually the least hacky and most explicit way to generate samples from the model. I'll make this change in the tutorial.
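The `sample_and_prune_trace` placeholder quoted above can be sketched in a few lines. This is a hypothetical helper, not a Pyro function, and it assumes each trace is a plain dict mapping a site name to a dict with at least "value" and "is_observed" keys rather than a real Pyro trace object:

```python
import random


def sample_and_prune_trace(traces, rng=random):
    """Hypothetical helper (not a Pyro function): pick one trace uniformly at
    random and return a copy that keeps only its unobserved (latent) sites.

    Each trace is assumed to be a plain dict mapping a site name to a dict
    with at least "value" and "is_observed" keys.
    """
    chosen = traces[rng.randrange(len(traces))]
    return {name: dict(site) for name, site in chosen.items() if not site["is_observed"]}


# Two made-up posterior traces in the assumed format.
posterior_traces = [
    {"phi": {"value": 0.41, "is_observed": False}, "obs": {"value": 18, "is_observed": True}},
    {"phi": {"value": 0.37, "is_observed": False}, "obs": {"value": 18, "is_observed": True}},
]
posterior_trace = sample_and_prune_trace(posterior_traces, random.Random(0))
print(posterior_trace)  # only the "phi" site remains
```

Pruning on the observed flag rather than on the literal name "obs" keeps the helper usable for models with several observed sites, and copying the kept sites leaves the stored posterior traces untouched.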
