Pyro: Boundary values in a Dirichlet sample cause score error

Created on 29 Jul 2019 · 8 comments · Source: pyro-ppl/pyro

Issue Description

Because the support of the Dirichlet distribution is constrained to the simplex, sampling can return boundary values of exactly 0 or 1. Scoring such a value can then evaluate to inf.

What looks like a solution to this issue, the clamping of samples with no gradients with _dirichlet_sample_nograd, was removed in #17488.

(similar to closed issue #704)

Environment

OS: Ubuntu 18.04.2
Python: 3.6.7
PyTorch: 1.2.0.dev20190728+cpu
Pyro: 0.3.4

Code Snippet

import torch
import pyro.distributions as dist

dist.Dirichlet(torch.tensor([0.5, 0.5]), validate_args=True).log_prob(torch.tensor([0., 1.]))

All 8 comments

Thanks for the detailed bug report and references, @ajrcampbell. Regarding your code snippet, current behavior appears correct to me: a Dirichlet distribution with parameters [0.5, 0.5] indeed has infinite density on the boundary. Do you expect instead a validation error?

Note that the torch.distributions.constraints library does not distinguish open vs closed sets. This is a shortcoming of the library that is complex to resolve. Any principled suggestion is welcome, and if there is sufficient practical advantage we could try to overhaul the library.
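To make the boundary behaviour concrete, here is a minimal sketch using plain torch.distributions (which pyro.distributions builds on). The concentration [2.0, 2.0] is chosen only for illustration and does not appear in the report. For a component with concentration below 1, the log-density term (alpha - 1) * log(x) tends to +inf as x approaches 0; for concentration above 1 it tends to -inf. The closed simplex constraint accepts the boundary point in both cases.

```python
import torch
from torch.distributions import Dirichlet, constraints

boundary = torch.tensor([0.0, 1.0])

# The simplex constraint is closed: a point with an exact zero still passes.
print(constraints.simplex.check(boundary))

# Concentration < 1: (alpha - 1) * log(0) = (-0.5) * (-inf) = +inf.
sparse = Dirichlet(torch.tensor([0.5, 0.5]))
print(sparse.log_prob(boundary))

# Concentration > 1: (alpha - 1) * log(0) = (+1) * (-inf) = -inf.
peaked = Dirichlet(torch.tensor([2.0, 2.0]))
print(peaked.log_prob(boundary))
```

So the infinite score is a property of the density at the boundary, not of validation; whether it shows up as +inf or -inf depends on the concentration.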

@fritzo I agree the code snippet shows the correct behaviour.

What I was trying to say was that under the present implementation of the Dirichlet, the simplex constraint allows boundary values that can cause scoring to be inf. This makes the Dirichlet unstable. Does that make it a bug?

You are right that torch.distributions.constraints does not have an open simplex constraint. It needs to be something like:

class _SimplexOpen(Constraint):
    def check(self, value):
        return (value > 0).all() & ((value.sum(-1, True) - 1).abs() < 1e-6).all()
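A runnable variant of the check above might look like the following sketch. OpenSimplex is a hypothetical name and is not part of torch.distributions.constraints. The difference from the proposal is that it reduces over the last dimension only, so a batch of values gets one result per row instead of a single boolean for the whole batch.

```python
import torch
from torch.distributions import constraints


class OpenSimplex(constraints.Constraint):
    """Hypothetical constraint: strictly positive entries that sum to one."""

    event_dim = 1  # the constraint applies to the last dimension as a whole

    def check(self, value):
        # Reduce over the last dimension only, so batches get one result per row.
        positive = (value > 0).all(dim=-1)
        normalized = (value.sum(-1) - 1).abs() < 1e-6
        return positive & normalized


open_simplex = OpenSimplex()
batch = torch.tensor([[0.0, 1.0], [0.3, 0.7]])
print(open_simplex.check(batch))  # first row fails, second row passes
```

Using such a constraint as the support of a distribution would turn an infinite score into a validation error, which is the trade-off raised earlier in the thread.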

Another possibility could be to allow a lambda function to be passed to the primitive pyro.sample in order to clip, add constant, etc, on the fly when sampling?

I have built my own workaround using a custom elbo but it is not very elegant. It involves looping through the model and guide traces and checking sample nodes.

I like the idea of allowing sample modifications inside pyro.sample statements. One way you could do this now is via a distribution subclass, e.g. the following might work:

class ClippedDirichlet(Dirichlet):
    def rsample(self, sample_shape=torch.Size()):
        value = super(ClippedDirichlet, self).rsample(sample_shape)
        return value.clamp_(min=1e-8, max=1-1e-8)

Note that clamping should already be done internally. I'm not sure why the internal clamping is failing in your case. cc @neerajprad
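A hedged variant of the subclass above, written against torch.distributions.Dirichlet and using an out-of-place clamp. The in-place clamp_ modifies the tensor returned by the parent rsample, which may conflict with autograd if the parent saves that tensor for its backward pass; the out-of-place version avoids the question. The concentration and sample count in the usage lines are illustrative assumptions, and validate_args=False is set there only as a precaution, because clamping without renormalizing can move the sum slightly away from 1.

```python
import torch
from torch.distributions import Dirichlet


class ClippedDirichlet(Dirichlet):
    """Dirichlet whose reparameterized samples are kept off the boundary."""

    def rsample(self, sample_shape=torch.Size()):
        value = super().rsample(sample_shape)
        # Out-of-place clamp: leaves the tensor returned by the parent untouched.
        return value.clamp(min=1e-8, max=1 - 1e-8)


concentration = torch.full((4,), 0.1, requires_grad=True)
d = ClippedDirichlet(concentration, validate_args=False)
x = d.rsample((1000,))
log_p = d.log_prob(x)
print(x.min().item() > 0, torch.isfinite(log_p).all().item())

log_p.sum().backward()  # gradients still reach the concentration
print(concentration.grad.shape)
```

Note that in float32 the upper bound 1 - 1e-8 rounds to 1.0, so only the lower bound has an effect here; that is enough for scoring, since log(1) is finite.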

@ajrcampbell - As @fritzo mentioned, the Dirichlet samples should already be clamped and shouldn't result in inf values when being scored. If you find that to be the case, let us know (that's clearly a bug!). If you are doing further manipulations on the samples before scoring them, then you'll need to do this clamping yourself.

@fritzo the distribution subclass looks like a really nice solution!

@neerajprad I do not think I am doing any manipulations on the sample before scoring.

My model is complex, but the following is a simple example adapted from one of the tutorials that reproduces the bug:

# Latent Dirichlet Allocation

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributions import constraints

import pyro
from pyro.infer import SVI, Trace_ELBO
from pyro.optim import Adam
import pyro.distributions as dist

seed = 1234
torch.manual_seed(seed)
pyro.set_rng_seed(seed)

pyro.clear_param_store()
device = torch.device('cpu')

num_topics = 10
num_docs = 100
vocab_size = 1000

num_epochs = 1
learning_rate = 0.01

# inference network
class Net(nn.Module):
    def __init__(self, in_features, out_features):
        super(Net, self).__init__()
        self.dense = nn.Linear(in_features, out_features)

    def forward(self, x):
        x = self.dense(x)
        return F.softmax(x, dim=1)

def model(docs):
    num_words = torch.sum(docs, dim=1).cpu().detach().numpy()

    # globals
    beta = torch.ones(vocab_size, device=device) / vocab_size
    with pyro.plate('topics', num_topics):
        phi = pyro.sample('phi', dist.Dirichlet(beta))

    # locals
    alpha = torch.ones(num_topics, device=device) / num_topics
    for d in pyro.plate('documents', num_docs):
        theta_d = pyro.sample('theta_{}'.format(d), dist.Dirichlet(alpha))

        num = num_words[d]
        # collapse latent variables (discrete assignment variables)
        prob = torch.matmul(theta_d, phi)
        prob.div_(prob.sum())

        pyro.sample('x_{}'.format(d), dist.Multinomial(num, prob), obs=docs[d])

def guide(docs):
    pyro.module('net', net)

    # globals
    beta = pyro.param('beta', (torch.ones(num_topics, vocab_size, device=device) / vocab_size)*10, constraint=constraints.positive)
    with pyro.plate('topics', num_topics):
        pyro.sample('phi', dist.Dirichlet(beta))

    # locals 
    theta = net.forward(docs)
    for d in pyro.plate('documents', num_docs):
        pyro.sample('theta_{}'.format(d), dist.Delta(theta[d]))

docs = torch.randint(1, 10, (num_docs, vocab_size)).float()
net = Net(vocab_size, num_topics)
elbo = Trace_ELBO(max_plate_nesting=1, vectorize_particles=True)
opt = Adam({'lr': learning_rate})
svi = SVI(model, guide, opt, loss=elbo)

# alternate between training and testing
for epoch in range(num_epochs):
    for phase in ['train', 'test']:
        if phase == 'train':
            loss = svi.step(docs)
        else:
            loss = svi.evaluate_loss(docs)
        print(phase, 'elbo', loss)

Output:

train elbo 3556732.5736370087
test elbo -inf

# loss = -ELBO = E_q[log(q(z|x)) - log(p(x|z)) - log(p(z))]
guide_trace = pyro.poutine.trace(guide).get_trace(docs)
model_trace = pyro.poutine.trace(pyro.poutine.replay(model, trace=guide_trace)).get_trace(docs)

# E_q[log(p(x|z))] + log(p(z))
model_log_p_sum = model_trace.log_prob_sum().cpu().detach().numpy()
# log(q(z|x))
guide_log_p_sum = guide_trace.log_prob_sum().cpu().detach().numpy()

print('test elbo = ', -model_log_p_sum + guide_log_p_sum)
print('model part:', model_log_p_sum)
print('guide part:', guide_log_p_sum)

Output:

test elbo =  -inf
model part: inf
guide part: 552352.44

# inspect the replayed Dirichlet sample sites in the model trace
for name, site in model_trace.nodes.items():
    if site['type'] == 'sample':
        if any(param in name for param in ['theta', 'phi']):
            log_prob_sum = site['fn'].log_prob(site['value']).sum()
            value = site['value'].cpu().detach().numpy()
            # check for invalid boundary conditions
            if (value == 0.0).any():
                print(name)
                print('min', value.min(), 'max', value.max())



Output:



theta_68
min 0.0 max 0.83488476

@ajrcampbell Thanks again for the clear example. My guess is that your inference net in the guide is producing boundary values, i.e. it is not _sample_dirichlet() that produces the boundary values. Under that hypothesis, the following should suffice:

  class Net(nn.Module):
      ...
      def forward(self, x):
          x = self.dense(x)
-         return F.softmax(x, dim=1)
+         return F.softmax(x, dim=1).clamp(min=1e-8, max=1-1e-8)

Thanks for such a clear and easy-to-reproduce bug report! I checked that @fritzo's hypothesis is indeed correct: one of the network outputs for theta_68 is exactly 0.
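For readers wondering how a softmax can return an exact zero: in float32, exp underflows to 0 once a logit sits far enough below the row maximum. The sketch below uses made-up logits to show the effect, scores the result under a Dirichlet with concentration 0.1 (the value of alpha in the model above, 1 / num_topics), and then applies the clamp. The renormalization step is an optional addition that is not part of the suggested fix; it keeps each row summing to one after clamping.

```python
import torch
import torch.nn.functional as F
from torch.distributions import Dirichlet

# Illustrative logits with a large gap; exp(-200) underflows to 0 in float32.
logits = torch.tensor([[0.0, -200.0]])
theta = F.softmax(logits, dim=1)
print(theta)  # the second entry is exactly zero

prior = Dirichlet(torch.tensor([0.1, 0.1]))
print(prior.log_prob(theta))  # infinite, since the concentration is below 1

# Clamp away from zero, then renormalize so each row still sums to one.
eps = 1e-8
safe = theta.clamp(min=eps)
safe = safe / safe.sum(dim=1, keepdim=True)
print(prior.log_prob(safe))  # finite
```

Because the guide passes the network output through dist.Delta, the model scores that exact value under its Dirichlet prior, so a single zero entry is enough to make the model term infinite.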

@fritzo I cannot believe I missed this! @neerajprad and @fritzo, thank you for taking the time to explain and show where I was going wrong.
