Is there currently a way to sample from new distributions which are similar to the ones already implemented in Pyro?
For instance, in the case of a multivariate truncated Gaussian we would need to define how to compute the gradient as a piecewise function and I am not sure where to do so.
hello, you might take a look here:
Hi @bmazoure you could also define a Rejector distribution for a truncated normal. This would require you to implement the total probability of acceptance log_scale as a function of your truncation plane. For example to truncate by ensuring sample[0] > min_x0, you should be able to define
class TruncatedMVN(dist.Rejector):
    """Multivariate normal truncated to the half-space ``x[..., 0] > min_x0``.

    Proposals are drawn from ``MultivariateNormal(loc, covariance_matrix)``
    and accepted only when their first coordinate exceeds ``min_x0``.

    :param loc: mean of the untruncated distribution (tensor; the last
        dimension is the event dimension).
    :param covariance_matrix: covariance of the untruncated distribution.
    :param min_x0: lower bound on the first coordinate (number or tensor).
    """

    def __init__(self, loc, covariance_matrix, min_x0):
        propose = dist.MultivariateNormal(loc, covariance_matrix)
        # Normal.cdf rejects non-tensor values when argument validation is
        # enabled, so promote plain Python numbers up front.
        min_x0 = torch.as_tensor(min_x0, dtype=loc.dtype, device=loc.device)

        def log_prob_accept(x):
            # Index the event (last) dimension: with batched proposals,
            # x[0] would select the first sample, not the first coordinate.
            return (x[..., 0] > min_x0).type_as(x).log()

        # The first coordinate is marginally Normal(loc_0, sqrt(cov_00)), so
        # the total acceptance probability is its upper tail mass at min_x0.
        scale_0 = torch.sqrt(covariance_matrix[..., 0, 0])
        marginal_0 = dist.Normal(loc[..., 0], scale_0)
        log_scale = torch.log(1 - marginal_0.cdf(min_x0))
        super(TruncatedMVN, self).__init__(propose, log_prob_accept, log_scale)
(Note this is available on Pyro dev branch, but not in the 0.1.2 release)
Will try these suggestions, thanks!
I'm trying to use a rejector to limit a Pareto distribution, which is wrapped from a torch distribution:
class TruncatedPareto(Rejector):
    """Pareto distribution restricted to ``x < upper_limit`` by rejection.

    Proposals are drawn from ``Pareto(scale, alpha)`` and accepted only when
    they fall below ``upper_limit``.

    :param scale: scale (lower bound of the support) of the Pareto proposal.
    :param alpha: shape parameter of the Pareto proposal.
    :param upper_limit: exclusive upper bound; it must be greater than
        ``scale``, otherwise the acceptance probability is not positive and
        ``log_scale`` is not finite.
    :param validate_args: forwarded to the ``Pareto`` proposal.
    """

    def __init__(self, scale, alpha, upper_limit, validate_args=None):
        propose = Pareto(scale, alpha, validate_args=validate_args)

        def log_prob_accept(x):
            return (x < upper_limit).type_as(x).log()

        # torch.as_tensor converts the value itself; torch.Tensor(n) would
        # instead allocate an uninitialized tensor of size n.
        alpha_t = torch.as_tensor(alpha)
        if not alpha_t.is_floating_point():
            alpha_t = alpha_t.to(torch.get_default_dtype())
        ratio = torch.as_tensor(
            scale / upper_limit, dtype=alpha_t.dtype, device=alpha_t.device
        )
        # log_scale is the total log-probability of ACCEPTANCE:
        #   P(X < upper_limit) = 1 - (scale / upper_limit) ** alpha
        # (alpha * log(scale / upper_limit) would be the rejection mass.)
        log_scale = torch.log1p(-ratio.pow(alpha_t))
        super(TruncatedPareto, self).__init__(propose, log_prob_accept, log_scale)
I can sample from it by calling .sample() but inside an MCMC, I get the error:
---------------------------------------------------------------------------
AttributeError Traceback (most recent call last)
<ipython-input-353-ebb20659df6e> in <module>
2 kernel = NUTS(conditioned_model)
3 mcmc = MCMC(kernel, num_samples=10, warmup_steps=2)
----> 4 mcmc.run(meta)
5 posterior_samples = mcmc.get_samples()
/usr/local/miniconda3/envs/astro-pyro/lib/python3.8/site-packages/pyro/poutine/messenger.py in _context_wrap(context, fn, *args, **kwargs)
9 def _context_wrap(context, fn, *args, **kwargs):
10 with context:
---> 11 return fn(*args, **kwargs)
12
13
/usr/local/miniconda3/envs/astro-pyro/lib/python3.8/site-packages/pyro/infer/mcmc/api.py in run(self, *args, **kwargs)
355 z_flat_acc = [[] for _ in range(self.num_chains)]
356 with pyro.validation_enabled(not self.disable_validation):
--> 357 for x, chain_id in self.sampler.run(*args, **kwargs):
358 if num_samples[chain_id] == 0:
359 num_samples[chain_id] += 1
/usr/local/miniconda3/envs/astro-pyro/lib/python3.8/site-packages/pyro/infer/mcmc/api.py in run(self, *args, **kwargs)
164 logger = initialize_logger(logger, "", progress_bar)
165 hook_w_logging = _add_logging_hook(logger, progress_bar, self.hook)
--> 166 for sample in _gen_samples(self.kernel, self.warmup_steps, self.num_samples, hook_w_logging,
167 i if self.num_chains > 1 else None,
168 *args, **kwargs):
/usr/local/miniconda3/envs/astro-pyro/lib/python3.8/site-packages/pyro/infer/mcmc/api.py in _gen_samples(kernel, warmup_steps, num_samples, hook, chain_id, *args, **kwargs)
108
109 def _gen_samples(kernel, warmup_steps, num_samples, hook, chain_id, *args, **kwargs):
--> 110 kernel.setup(warmup_steps, *args, **kwargs)
111 params = kernel.initial_params
112 # yield structure (key, value.shape) of params
/usr/local/miniconda3/envs/astro-pyro/lib/python3.8/site-packages/pyro/infer/mcmc/hmc.py in setup(self, warmup_steps, *args, **kwargs)
264 self._warmup_steps = warmup_steps
265 if self.model is not None:
--> 266 self._initialize_model_properties(args, kwargs)
267 potential_energy = self.potential_fn(self.initial_params)
268 self._cache(self.initial_params, potential_energy, None)
/usr/local/miniconda3/envs/astro-pyro/lib/python3.8/site-packages/pyro/infer/mcmc/hmc.py in _initialize_model_properties(self, model_args, model_kwargs)
229
230 def _initialize_model_properties(self, model_args, model_kwargs):
--> 231 init_params, potential_fn, transforms, trace = initialize_model(
232 self.model,
233 model_args,
/usr/local/miniconda3/envs/astro-pyro/lib/python3.8/site-packages/pyro/infer/mcmc/util.py in initialize_model(model, model_args, model_kwargs, transforms, max_plate_nesting, jit_compile, jit_options, skip_jit_warnings, num_chains)
371 model = poutine.enum(config_enumerate(model),
372 first_available_dim=-1 - max_plate_nesting)
--> 373 model_trace = poutine.trace(model).get_trace(*model_args, **model_kwargs)
374 has_enumerable_sites = False
375 prototype_samples = {}
/usr/local/miniconda3/envs/astro-pyro/lib/python3.8/site-packages/pyro/poutine/trace_messenger.py in get_trace(self, *args, **kwargs)
183 Calls this poutine and returns its trace instead of the function's return value.
184 """
--> 185 self(*args, **kwargs)
186 return self.msngr.get_trace()
/usr/local/miniconda3/envs/astro-pyro/lib/python3.8/site-packages/pyro/poutine/trace_messenger.py in __call__(self, *args, **kwargs)
163 args=args, kwargs=kwargs)
164 try:
--> 165 ret = self.fn(*args, **kwargs)
166 except (ValueError, RuntimeError):
167 exc_type, exc_value, traceback = sys.exc_info()
/usr/local/miniconda3/envs/astro-pyro/lib/python3.8/site-packages/pyro/poutine/messenger.py in _context_wrap(context, fn, *args, **kwargs)
9 def _context_wrap(context, fn, *args, **kwargs):
10 with context:
---> 11 return fn(*args, **kwargs)
12
13
/usr/local/miniconda3/envs/astro-pyro/lib/python3.8/site-packages/pyro/poutine/messenger.py in _context_wrap(context, fn, *args, **kwargs)
9 def _context_wrap(context, fn, *args, **kwargs):
10 with context:
---> 11 return fn(*args, **kwargs)
12
13
/usr/local/miniconda3/envs/astro-pyro/lib/python3.8/site-packages/pyro/poutine/messenger.py in _context_wrap(context, fn, *args, **kwargs)
9 def _context_wrap(context, fn, *args, **kwargs):
10 with context:
---> 11 return fn(*args, **kwargs)
12
13
<ipython-input-349-8dfe588a1e6e> in model(meta)
---> 25 M = pyro.sample('M', TruncatedPareto(0.1, xi[g], 1.5))
/usr/local/miniconda3/envs/astro-pyro/lib/python3.8/site-packages/pyro/primitives.py in sample(name, fn, *args, **kwargs)
111 msg["is_observed"] = True
112 # apply the stack and return its return value
--> 113 apply_stack(msg)
114 return msg["value"]
115
/usr/local/miniconda3/envs/astro-pyro/lib/python3.8/site-packages/pyro/poutine/runtime.py in apply_stack(initial_msg)
199
200 for frame in stack[-pointer:]:
--> 201 frame._postprocess_message(msg)
202
203 cont = msg["continuation"]
/usr/local/miniconda3/envs/astro-pyro/lib/python3.8/site-packages/pyro/poutine/messenger.py in _postprocess_message(self, msg)
139 method_name = "_pyro_post_{}".format(msg["type"])
140 if hasattr(self, method_name):
--> 141 return getattr(self, method_name)(msg)
142 return None
143
/usr/local/miniconda3/envs/astro-pyro/lib/python3.8/site-packages/pyro/poutine/enum_messenger.py in _pyro_post_sample(self, msg)
201 if value is None:
202 return
--> 203 shape = value.shape[:value.dim() - msg["fn"].event_dim]
204 dim_to_id = msg["infer"].setdefault("_dim_to_id", {})
205 dim_to_id.update(self._param_dims.get(msg["name"], {}))
/usr/local/miniconda3/envs/astro-pyro/lib/python3.8/site-packages/pyro/distributions/torch_distribution.py in event_dim(self)
51 :rtype: int
52 """
---> 53 return len(self.event_shape)
54
55 def shape(self, sample_shape=torch.Size()):
/usr/local/miniconda3/envs/astro-pyro/lib/python3.8/site-packages/torch/distributions/distribution.py in event_shape(self)
70 Returns the shape of a single sample (without batching).
71 """
---> 72 return self._event_shape
73
74 @property
AttributeError: 'TruncatedPareto' object has no attribute '_event_shape'
@dobos It looks like Rejector is missing a call to super().__init__(...). This is a bug, will push a fix.
Were truncated distributions ever added?
@GlastonburyC I believe @alicanb was working on TruncatedDistribution in https://github.com/pytorch/pytorch/pull/32377
Is there a way to implement a Distribution for sampling from a multivariate Gaussian truncated on a hyperplane? For example, making sure that the sampled array sums to zero? @fritzo
@zoj613 You could easily implement a multivariate Gaussian truncated along a single hyperplane passing through the center by generalizing FoldedDistribution to a MultivariateFoldedDistribution; however I don't know an easy way to allow multiple truncations or to allow truncation along a single hyperplane that does not pass through the distribution center.
@fritzo How can I sample from a normal but only have negative support?
@GlastonburyC For a negative Gaussian with zero mode, you could use a transformed HalfNormal
class NegativeHalfNormal(dist.TransformedDistribution):
    """Half-normal distribution reflected onto the negative reals.

    Built by negating samples of ``HalfNormal(scale)``.

    :param scale: scale parameter passed to the underlying ``HalfNormal``.
    """

    support = constraints.less_than(0)

    def __init__(self, scale):
        # Affine map x -> 0 + (-1) * x flips the half-normal's sign.
        negate = dist.transforms.AffineTransform(0., -1.)
        super().__init__(dist.HalfNormal(scale), negate)
If your loc parameter is nonzero then you would need to wait for https://github.com/pytorch/pytorch/pull/32377
In practice I prefer FoldedDistribution over truncated normal, as it is more numerically stable and has qualitatively similar density. You could create a negatively-supported folded-normal via
class NegativeFoldedNormal(dist.TransformedDistribution):
    """Folded normal distribution reflected onto the negative reals.

    Built by folding ``Normal(loc, scale)`` with ``FoldedDistribution`` and
    then negating the result.

    :param loc: location parameter of the underlying ``Normal``.
    :param scale: scale parameter of the underlying ``Normal``.
    """

    support = constraints.less_than(0)

    def __init__(self, loc, scale):
        base_dist = dist.FoldedDistribution(dist.Normal(loc, scale))
        # Affine map x -> 0 + (-1) * x flips the folded normal's sign.
        transform = dist.transforms.AffineTransform(0., -1.)
        super().__init__(base_dist, transform)
Thanks @fritzo, that's super helpful of you. :)
Most helpful comment
@GlastonburyC For a negative Gaussian with zero mode, you could use a transformed HalfNormal.
If your loc parameter is nonzero then you would need to wait for https://github.com/pytorch/pytorch/pull/32377
In practice I prefer FoldedDistribution over truncated normal, as it is more numerically stable and has qualitatively similar density. You could create a negatively-supported folded-normal via