Pyro: Sampling from truncated Gaussian

Created on 20 Feb 2018 · 12 Comments · Source: pyro-ppl/pyro

Is there currently a way to sample from new distributions which are similar to the ones already implemented in Pyro?

For instance, in the case of a multivariate truncated Gaussian, we would need to define how to compute the gradient as a piecewise function, and I am not sure where to do so.

Hello, you might want to take a look here:

https://github.com/probtorch/pytorch/pull/121

Hi @bmazoure, you could also define a Rejector distribution for a truncated normal. This requires you to implement log_scale, the log of the total probability of acceptance, as a function of your truncation plane. For example, to truncate by ensuring sample[0] > min_x0, you should be able to define:

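import torch
import pyro.distributions as dist


class TruncatedMVN(dist.Rejector):
    def __init__(self, loc, covariance_matrix, min_x0):
        propose = dist.MultivariateNormal(loc, covariance_matrix)
        min_x0 = torch.as_tensor(min_x0, dtype=loc.dtype)

        def log_prob_accept(x):
            # log(1) = 0 inside the region, log(0) = -inf outside; x[..., 0] handles batches
            return (x[..., 0] > min_x0).type_as(x).log()

        # log_scale = log of the total acceptance probability,
        # i.e. the marginal mass of x[0] above min_x0
        scale_0 = torch.sqrt(covariance_matrix[0, 0])
        log_scale = torch.log1p(-dist.Normal(loc[0], scale_0).cdf(min_x0))
        super(TruncatedMVN, self).__init__(propose, log_prob_accept, log_scale)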

(Note: this is available on the Pyro dev branch but not in the 0.1.2 release.)
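To make the role of log_scale concrete, here is a dependency-free sketch in plain Python (not Pyro's implementation) of the same rejection scheme in one dimension, assuming a univariate Normal(loc, scale) proposal truncated to x > min_x0. The normalized log density follows the Rejector pattern log q(x) + log_prob_accept(x) - log_scale; the helper names are illustrative only.

```python
import math
import random


def normal_cdf(x, loc, scale):
    # Normal CDF written with the error function
    return 0.5 * (1.0 + math.erf((x - loc) / (scale * math.sqrt(2.0))))


def normal_log_prob(x, loc, scale):
    # Log density of Normal(loc, scale)
    return -0.5 * ((x - loc) / scale) ** 2 - math.log(scale) - 0.5 * math.log(2.0 * math.pi)


def truncated_log_prob(x, loc, scale, min_x0):
    # Rejector-style normalization: log q(x) + log_prob_accept(x) - log_scale
    log_scale = math.log1p(-normal_cdf(min_x0, loc, scale))  # log P(x > min_x0)
    if x <= min_x0:
        return -math.inf  # log_prob_accept(x) = log(0)
    return normal_log_prob(x, loc, scale) - log_scale  # log_prob_accept(x) = log(1) = 0


def rejection_sample(loc, scale, min_x0, rng):
    # Draw from the proposal until a draw lands in the accepted region
    while True:
        x = rng.gauss(loc, scale)
        if x > min_x0:
            return x


rng = random.Random(0)
loc, scale, min_x0 = 0.0, 1.0, 0.5
samples = [rejection_sample(loc, scale, min_x0, rng) for _ in range(2000)]

# Midpoint-rule integral of the normalized density over (min_x0, min_x0 + 10)
dx = 1e-3
total = sum(
    math.exp(truncated_log_prob(min_x0 + (i + 0.5) * dx, loc, scale, min_x0)) * dx
    for i in range(10_000)
)
```

The integral comes out at 1 only because log_scale is subtracted. Note that the expected number of proposals per accepted sample is 1 / P(accept), so plain rejection gets slow when the truncation cuts deep into a tail.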

Will try these suggestions, thanks!

I'm trying to use a Rejector to put an upper limit on a Pareto distribution, which Pyro wraps from the torch distribution:

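import torch
from pyro.distributions import Pareto, Rejector


class TruncatedPareto(Rejector):
    def __init__(self, scale, alpha, upper_limit, validate_args=None):
        propose = Pareto(scale, alpha, validate_args=validate_args)

        def log_prob_accept(x):
            # log(1) = 0 below upper_limit, log(0) = -inf at or above it
            return (x < upper_limit).type_as(x).log()

        # log_scale must be the log of the total acceptance probability:
        # P(X < upper_limit) = 1 - (scale / upper_limit) ** alpha
        alpha = torch.as_tensor(alpha, dtype=torch.get_default_dtype())
        log_scale = torch.log1p(-(scale / upper_limit) ** alpha)
        super(TruncatedPareto, self).__init__(propose, log_prob_accept, log_scale)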

I can sample from it by calling .sample(), but inside MCMC I get this error:

---------------------------------------------------------------------------
AttributeError                            Traceback (most recent call last)
<ipython-input-353-ebb20659df6e> in <module>
      2 kernel = NUTS(conditioned_model)
      3 mcmc = MCMC(kernel, num_samples=10, warmup_steps=2)
----> 4 mcmc.run(meta)
      5 posterior_samples = mcmc.get_samples()

/usr/local/miniconda3/envs/astro-pyro/lib/python3.8/site-packages/pyro/poutine/messenger.py in _context_wrap(context, fn, *args, **kwargs)
      9 def _context_wrap(context, fn, *args, **kwargs):
     10     with context:
---> 11         return fn(*args, **kwargs)
     12 
     13 

/usr/local/miniconda3/envs/astro-pyro/lib/python3.8/site-packages/pyro/infer/mcmc/api.py in run(self, *args, **kwargs)
    355         z_flat_acc = [[] for _ in range(self.num_chains)]
    356         with pyro.validation_enabled(not self.disable_validation):
--> 357             for x, chain_id in self.sampler.run(*args, **kwargs):
    358                 if num_samples[chain_id] == 0:
    359                     num_samples[chain_id] += 1

/usr/local/miniconda3/envs/astro-pyro/lib/python3.8/site-packages/pyro/infer/mcmc/api.py in run(self, *args, **kwargs)
    164             logger = initialize_logger(logger, "", progress_bar)
    165             hook_w_logging = _add_logging_hook(logger, progress_bar, self.hook)
--> 166             for sample in _gen_samples(self.kernel, self.warmup_steps, self.num_samples, hook_w_logging,
    167                                        i if self.num_chains > 1 else None,
    168                                        *args, **kwargs):

/usr/local/miniconda3/envs/astro-pyro/lib/python3.8/site-packages/pyro/infer/mcmc/api.py in _gen_samples(kernel, warmup_steps, num_samples, hook, chain_id, *args, **kwargs)
    108 
    109 def _gen_samples(kernel, warmup_steps, num_samples, hook, chain_id, *args, **kwargs):
--> 110     kernel.setup(warmup_steps, *args, **kwargs)
    111     params = kernel.initial_params
    112     # yield structure (key, value.shape) of params

/usr/local/miniconda3/envs/astro-pyro/lib/python3.8/site-packages/pyro/infer/mcmc/hmc.py in setup(self, warmup_steps, *args, **kwargs)
    264         self._warmup_steps = warmup_steps
    265         if self.model is not None:
--> 266             self._initialize_model_properties(args, kwargs)
    267         potential_energy = self.potential_fn(self.initial_params)
    268         self._cache(self.initial_params, potential_energy, None)

/usr/local/miniconda3/envs/astro-pyro/lib/python3.8/site-packages/pyro/infer/mcmc/hmc.py in _initialize_model_properties(self, model_args, model_kwargs)
    229 
    230     def _initialize_model_properties(self, model_args, model_kwargs):
--> 231         init_params, potential_fn, transforms, trace = initialize_model(
    232             self.model,
    233             model_args,

/usr/local/miniconda3/envs/astro-pyro/lib/python3.8/site-packages/pyro/infer/mcmc/util.py in initialize_model(model, model_args, model_kwargs, transforms, max_plate_nesting, jit_compile, jit_options, skip_jit_warnings, num_chains)
    371     model = poutine.enum(config_enumerate(model),
    372                          first_available_dim=-1 - max_plate_nesting)
--> 373     model_trace = poutine.trace(model).get_trace(*model_args, **model_kwargs)
    374     has_enumerable_sites = False
    375     prototype_samples = {}

/usr/local/miniconda3/envs/astro-pyro/lib/python3.8/site-packages/pyro/poutine/trace_messenger.py in get_trace(self, *args, **kwargs)
    183         Calls this poutine and returns its trace instead of the function's return value.
    184         """
--> 185         self(*args, **kwargs)
    186         return self.msngr.get_trace()

/usr/local/miniconda3/envs/astro-pyro/lib/python3.8/site-packages/pyro/poutine/trace_messenger.py in __call__(self, *args, **kwargs)
    163                                       args=args, kwargs=kwargs)
    164             try:
--> 165                 ret = self.fn(*args, **kwargs)
    166             except (ValueError, RuntimeError):
    167                 exc_type, exc_value, traceback = sys.exc_info()

/usr/local/miniconda3/envs/astro-pyro/lib/python3.8/site-packages/pyro/poutine/messenger.py in _context_wrap(context, fn, *args, **kwargs)
      9 def _context_wrap(context, fn, *args, **kwargs):
     10     with context:
---> 11         return fn(*args, **kwargs)
     12 
     13 

/usr/local/miniconda3/envs/astro-pyro/lib/python3.8/site-packages/pyro/poutine/messenger.py in _context_wrap(context, fn, *args, **kwargs)
      9 def _context_wrap(context, fn, *args, **kwargs):
     10     with context:
---> 11         return fn(*args, **kwargs)
     12 
     13 

/usr/local/miniconda3/envs/astro-pyro/lib/python3.8/site-packages/pyro/poutine/messenger.py in _context_wrap(context, fn, *args, **kwargs)
      9 def _context_wrap(context, fn, *args, **kwargs):
     10     with context:
---> 11         return fn(*args, **kwargs)
     12 
     13 

<ipython-input-349-8dfe588a1e6e> in model(meta)
---> 25         M = pyro.sample('M', TruncatedPareto(0.1, xi[g], 1.5))

/usr/local/miniconda3/envs/astro-pyro/lib/python3.8/site-packages/pyro/primitives.py in sample(name, fn, *args, **kwargs)
    111             msg["is_observed"] = True
    112         # apply the stack and return its return value
--> 113         apply_stack(msg)
    114         return msg["value"]
    115 

/usr/local/miniconda3/envs/astro-pyro/lib/python3.8/site-packages/pyro/poutine/runtime.py in apply_stack(initial_msg)
    199 
    200     for frame in stack[-pointer:]:
--> 201         frame._postprocess_message(msg)
    202 
    203     cont = msg["continuation"]

/usr/local/miniconda3/envs/astro-pyro/lib/python3.8/site-packages/pyro/poutine/messenger.py in _postprocess_message(self, msg)
    139         method_name = "_pyro_post_{}".format(msg["type"])
    140         if hasattr(self, method_name):
--> 141             return getattr(self, method_name)(msg)
    142         return None
    143 

/usr/local/miniconda3/envs/astro-pyro/lib/python3.8/site-packages/pyro/poutine/enum_messenger.py in _pyro_post_sample(self, msg)
    201         if value is None:
    202             return
--> 203         shape = value.shape[:value.dim() - msg["fn"].event_dim]
    204         dim_to_id = msg["infer"].setdefault("_dim_to_id", {})
    205         dim_to_id.update(self._param_dims.get(msg["name"], {}))

/usr/local/miniconda3/envs/astro-pyro/lib/python3.8/site-packages/pyro/distributions/torch_distribution.py in event_dim(self)
     51         :rtype: int
     52         """
---> 53         return len(self.event_shape)
     54 
     55     def shape(self, sample_shape=torch.Size()):

/usr/local/miniconda3/envs/astro-pyro/lib/python3.8/site-packages/torch/distributions/distribution.py in event_shape(self)
     70         Returns the shape of a single sample (without batching).
     71         """
---> 72         return self._event_shape
     73 
     74     @property

AttributeError: 'TruncatedPareto' object has no attribute '_event_shape'

@dobos It looks like Rejector is missing a call to super().__init__(...). This is a bug; I will push a fix.
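Independently of the missing super().__init__ call, log_scale in a Rejector is the log of the total acceptance probability. For a Pareto(scale, alpha) proposal truncated to x < upper_limit, that is log(1 - (scale / upper_limit) ** alpha); the expression alpha * log(scale / upper_limit) is instead the log of the rejection probability P(X ≥ upper_limit). A quick plain-Python Monte Carlo check, assuming inverse-CDF Pareto sampling; scale and upper_limit come from the pyro.sample call in the traceback, while alpha = 1.5 is a hypothetical stand-in for xi[g]:

```python
import math
import random


def pareto_sample(scale, alpha, rng):
    # Inverse-CDF sampling from F(x) = 1 - (scale / x) ** alpha, x >= scale
    u = rng.random()  # u in [0, 1), so 1 - u is in (0, 1]
    return scale / (1.0 - u) ** (1.0 / alpha)


def truncated_pareto_log_scale(scale, alpha, upper_limit):
    # log P(X < upper_limit) = log(1 - (scale / upper_limit) ** alpha)
    return math.log1p(-((scale / upper_limit) ** alpha))


scale, upper_limit = 0.1, 1.5  # values from TruncatedPareto(0.1, xi[g], 1.5)
alpha = 1.5  # hypothetical shape; the model uses xi[g]
rng = random.Random(0)
n = 100_000
accepted = sum(pareto_sample(scale, alpha, rng) < upper_limit for _ in range(n))
empirical = accepted / n
exact = math.exp(truncated_pareto_log_scale(scale, alpha, upper_limit))
reject = math.exp(alpha * math.log(scale / upper_limit))  # complementary probability
```

The empirical acceptance rate matches exact, and exact + reject sums to 1.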

Were truncated distributions ever added?

@GlastonburyC I believe @alicanb was working on TruncatedDistribution in https://github.com/pytorch/pytorch/pull/32377

Is there a way to implement a Distribution for sampling from a multivariate Gaussian truncated on a hyperplane? For example, one that ensures the sampled array sums to zero? @fritzo

@zoj613 You could easily implement a multivariate Gaussian truncated along a single hyperplane passing through the center by generalizing FoldedDistribution to a MultivariateFoldedDistribution; however I don't know an easy way to allow multiple truncations or to allow truncation along a single hyperplane that does not pass through the distribution center.
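MultivariateFoldedDistribution is not an existing Pyro class, so the following is only a plain-Python 2-D sketch of one way the folding idea could work, under our own assumptions. For x ~ N(mu, Sigma) and a hyperplane w · (x - mu) = 0 through the center, the Sigma-weighted reflection x ↦ x - 2 (w · (x - mu)) / (wᵀ Sigma w) · Sigma w maps the Gaussian onto itself, so folding every draw onto the side w · (x - mu) ≥ 0 yields the Gaussian truncated to that half-space (density 2 · N(x; mu, Sigma) there). A plain Euclidean mirror would not preserve the density unless Sigma is isotropic. The names mu, cov and w below are illustrative.

```python
import math
import random


def sample_mvn_2d(mu, cov, rng):
    # Cholesky factor of a 2x2 covariance: cov = L L^T
    l11 = math.sqrt(cov[0][0])
    l21 = cov[1][0] / l11
    l22 = math.sqrt(cov[1][1] - l21 ** 2)
    z1, z2 = rng.gauss(0.0, 1.0), rng.gauss(0.0, 1.0)
    return [mu[0] + l11 * z1, mu[1] + l21 * z1 + l22 * z2]


def fold_across_center(x, mu, cov, w):
    # Sigma-orthogonal reflection across w . (x - mu) = 0, applied only
    # to points on the negative side of the hyperplane
    d = [x[0] - mu[0], x[1] - mu[1]]
    proj = w[0] * d[0] + w[1] * d[1]
    if proj >= 0:
        return x
    sw = [cov[0][0] * w[0] + cov[0][1] * w[1], cov[1][0] * w[0] + cov[1][1] * w[1]]
    wsw = w[0] * sw[0] + w[1] * sw[1]  # w^T Sigma w > 0
    c = 2.0 * proj / wsw
    return [x[0] - c * sw[0], x[1] - c * sw[1]]


rng = random.Random(0)
mu = [1.0, -2.0]
cov = [[2.0, 0.6], [0.6, 1.0]]
w = [1.0, 1.0]  # illustrative normal vector of the truncating hyperplane
samples = [fold_across_center(sample_mvn_2d(mu, cov, rng), mu, cov, w) for _ in range(10_000)]
```

This only works because the hyperplane passes through mu (each half-space has mass 1/2 and the reflection is a symmetry of the Gaussian), which matches the limitation described above.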

@fritzo How can I sample from a normal but only have negative support?

@GlastonburyC For a negative Gaussian with zero mode, you could use a transformed HalfNormal

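import pyro.distributions as dist
from torch.distributions import constraints


class NegativeHalfNormal(dist.TransformedDistribution):
    support = constraints.less_than(0)

    def __init__(self, scale):
        # Flip HalfNormal(scale) onto the negative half-line via x -> 0 + (-1) * x
        base_dist = dist.HalfNormal(scale)
        transform = dist.transforms.AffineTransform(0., -1.)
        super().__init__(base_dist, transform)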
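As a dependency-free check of what this transform does (plain Python, not Pyro code): if Y ~ HalfNormal(scale), then X = -Y has density sqrt(2/π) / scale · exp(-x² / (2 scale²)) for x < 0, because the affine map with slope -1 has unit absolute Jacobian. Sampling amounts to negating |N(0, scale)|. The helper names are illustrative.

```python
import math
import random


def negative_half_normal_sample(scale, rng):
    # HalfNormal(scale) is |Normal(0, scale)|; AffineTransform(0., -1.) negates it
    return -abs(rng.gauss(0.0, scale))


def negative_half_normal_log_prob(x, scale):
    # Outside the support constraints.less_than(0)
    if x >= 0:
        return -math.inf
    # The map y -> -y has |Jacobian| = 1, so the HalfNormal log density carries over
    return 0.5 * math.log(2.0 / math.pi) - math.log(scale) - 0.5 * (x / scale) ** 2


rng = random.Random(0)
scale = 2.0
samples = [negative_half_normal_sample(scale, rng) for _ in range(20_000)]
sample_mean = sum(samples) / len(samples)
exact_mean = -scale * math.sqrt(2.0 / math.pi)  # mean of -HalfNormal(scale)
```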

If your loc parameter is nonzero, then you would need to wait for https://github.com/pytorch/pytorch/pull/32377

In practice I prefer FoldedDistribution over a truncated normal, as it is more numerically stable and has a qualitatively similar density. You could create a negatively supported folded normal via:

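import pyro.distributions as dist
from torch.distributions import constraints


class NegativeFoldedNormal(dist.TransformedDistribution):
    support = constraints.less_than(0)

    def __init__(self, loc, scale):
        # |Normal(loc, scale)| flipped onto the negative half-line
        base_dist = dist.FoldedDistribution(dist.Normal(loc, scale))
        transform = dist.transforms.AffineTransform(0., -1.)
        super().__init__(base_dist, transform)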
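To see how the two densities compare, here they are on y ≥ 0 (before the sign flip), written in plain Python with illustrative helper names. The folded normal adds the reflected half, φ((y - μ)/σ)/σ + φ((y + μ)/σ)/σ, while the truncated normal rescales one side, φ((y - μ)/σ) / (σ (1 - Φ(-μ/σ))). Both integrate to 1; they coincide exactly at μ = 0 and both approach Normal(μ, σ) as μ/σ grows.

```python
import math


def std_normal_pdf(z):
    return math.exp(-0.5 * z * z) / math.sqrt(2.0 * math.pi)


def std_normal_cdf(z):
    return 0.5 * (1.0 + math.erf(z / math.sqrt(2.0)))


def folded_normal_pdf(y, loc, scale):
    # Density of |Z| for Z ~ Normal(loc, scale), defined for y >= 0
    if y < 0:
        return 0.0
    return (std_normal_pdf((y - loc) / scale) + std_normal_pdf((y + loc) / scale)) / scale


def truncated_normal_pdf(y, loc, scale):
    # Normal(loc, scale) restricted to y >= 0 and renormalized
    if y < 0:
        return 0.0
    mass = 1.0 - std_normal_cdf(-loc / scale)
    return std_normal_pdf((y - loc) / scale) / (scale * mass)


def negative_folded_normal_pdf(x, loc, scale):
    # AffineTransform(0., -1.) flips the support to x <= 0 with unit Jacobian
    return folded_normal_pdf(-x, loc, scale)


def integrate(f, lo, hi, n=20_000):
    # Midpoint rule
    h = (hi - lo) / n
    return sum(f(lo + (i + 0.5) * h) for i in range(n)) * h
```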

Thanks @fritzo, that's super helpful of you. :)
