Pyro: troubles with HMC warm-up

Created on 20 Jul 2019 · 5 Comments · Source: pyro-ppl/pyro

I've been having trouble porting the sparse regression example from NumPyro to Pyro. See code.

In particular:

  • I can't get it to work at all with vanilla HMC (which works fine in NumPyro).
  • I _can_ get it to work with NUTS, but to do that I have to go into _find_reasonable_step_size and modify it so that it never sets the step_size larger than some appropriate threshold, like 0.05. For whatever reason, the Pyro adaptation (as opposed to the NumPyro adaptation) pushes into a regime where the step size is too large. This leads to bad covariance matrices, which cause the MVN log_prob to fail. Are there known and/or intended differences between the two adaptation algorithms?
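
The step-size search that _find_reasonable_step_size performs can be sketched as the doubling/halving heuristic from the NUTS paper (Hoffman and Gelman, Algorithm 4): take one leapfrog step, then keep doubling (or halving) the step size until the one-step acceptance probability crosses a target. This is a minimal sketch on a hypothetical 1-D toy potential, not Pyro's code; the 0.8 target, the max_iters guard, and the hypothetical max_step_size cap (standing in for the manual 0.05 limit described above) are assumptions.

```python
import math
import random


def leapfrog(q, p, step_size, grad_u):
    """One leapfrog (velocity Verlet) step for a 1-D system with unit mass."""
    p = p - 0.5 * step_size * grad_u(q)
    q = q + step_size * p
    p = p - 0.5 * step_size * grad_u(q)
    return q, p


def find_reasonable_step_size(q, potential, grad_u, step_size=1.0,
                              max_step_size=None, target=0.8, max_iters=50, seed=0):
    """Double or halve step_size until the one-step acceptance probability
    crosses `target`. `max_step_size` is a hypothetical cap, not a Pyro option."""
    rng = random.Random(seed)
    threshold = math.log(target)

    def direction(eps):
        p = rng.gauss(0.0, 1.0)  # fresh momentum for every trial step
        h0 = potential(q) + 0.5 * p * p
        q_new, p_new = leapfrog(q, p, eps, grad_u)
        h1 = potential(q_new) + 0.5 * p_new * p_new
        # A nan energy makes this comparison False, so nan pushes the step size down.
        return 1 if threshold < -(h1 - h0) else -1

    d = direction(step_size)
    scale = 2.0 ** d  # 2 to grow, 1/2 to shrink
    for _ in range(max_iters):  # guard against never crossing the target
        step_size *= scale
        if direction(step_size) != d:
            break
    if max_step_size is not None:
        step_size = min(step_size, max_step_size)
    return step_size


def potential(q):
    # Toy standard normal that returns nan for |q| >= 3, mimicking a failed log_prob.
    return 0.5 * q * q if abs(q) < 3 else float("nan")


def grad_u(q):
    return q


print(find_reasonable_step_size(0.5, potential, grad_u))
print(find_reasonable_step_size(0.5, potential, grad_u, max_step_size=0.05))
```

Capping only the returned value leaves the search itself unchanged, which is roughly what a hard threshold inside the heuristic amounts to.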

[pyro version: dev]

discussion


cc @neerajprad

@martinjankowiak @neerajprad It seems that we need to add a try/except around potential_fn when it involves a Cholesky factorization (I faced this issue with the GPyTorch example and was able to resolve it by using

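import torch

def potential_fn(params):
    try:
        # Compute the potential energy (negative log joint) for a GPyTorch module.
        # `model`, `mll`, `train_x`, `train_y` and `transform_jacobian` come from the GPyTorch example.
        model._load_raw_parameters(**params)
        output = model(train_x)
        log_joint = mll(output, train_y).sum()
        log_joint = log_joint + transform_jacobian(model, params)
        return -log_joint
    except RuntimeError:  # return `nan` instead of raising if the Cholesky factorization fails
        # Adding 0 * p.sum() for every parameter keeps the nan result attached
        # to the autograd graph, so gradients (zeros) can still be computed.
        r = 0
        for p in params.values():
            r = r + 0 * p.sum()
        return torch.tensor(float('nan')) + r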

in that example). In PyTorch, a failed Cholesky factorization raises a RuntimeError, whereas in JAX it returns nan. A nan potential energy tells the adaptation scheme to lower the step size for us, so subsequent iterations are less likely to hit this issue.
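
A self-contained sketch of the same pattern using only PyTorch, assuming a hypothetical toy potential gaussian_potential with a single correlation parameter rho (not part of the GPyTorch example). When rho makes the covariance indefinite, torch.linalg.cholesky raises, and the wrapper returns a nan that stays on the autograd graph. It assumes a PyTorch recent enough to have torch.linalg (1.8 or later); torch.linalg.LinAlgError is a subclass of RuntimeError, so the except clause catches it.

```python
import torch


def gaussian_potential(params):
    # Hypothetical toy model: a fixed point x under MVN(0, [[1, rho], [rho, 1]]).
    rho = params["rho"]
    one = torch.ones(())
    cov = torch.stack([torch.stack([one, rho]), torch.stack([rho, one])])
    scale_tril = torch.linalg.cholesky(cov)  # raises if cov is not positive definite
    mvn = torch.distributions.MultivariateNormal(torch.zeros(2), scale_tril=scale_tril)
    return -mvn.log_prob(torch.tensor([0.5, -0.5]))


def safe_potential(params):
    try:
        return gaussian_potential(params)
    except RuntimeError:  # includes torch.linalg.LinAlgError
        # nan plus 0 * p.sum() keeps the result on the autograd graph.
        zero = sum(0 * p.sum() for p in params.values())
        return torch.tensor(float("nan")) + zero


good = {"rho": torch.tensor(0.5, requires_grad=True)}
bad = {"rho": torch.tensor(1.5, requires_grad=True)}  # covariance eigenvalues 2.5 and -0.5
print(safe_potential(good))  # finite
print(safe_potential(bad))   # nan instead of an exception
```

Calling backward() on the nan result gives a zero gradient for rho rather than an error, so an HMC/NUTS step can treat the proposal as rejected instead of crashing.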

@martinjankowiak I don't think there are differences between the two schemes (I ported the Pyro one to NumPyro). In many cases, step_size is pushed into a large regime because at the end of each adaptation window (iteration 100, 150, 250, 450, ...) we update the mass_matrix: if the mass_matrix is scaled down, the step_size will usually be scaled up. If the scaled-up step_size is bad (accept_prob ~ 0), it will be scaled down again over the next few iterations.
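
The iteration numbers above match a Stan-style windowed warmup schedule, where slow (mass-matrix) adaptation windows double in length between a fixed initial and final buffer. A minimal sketch, assuming 1000 warmup steps and buffer sizes of 75 (start), 50 (end) and 25 (first window); these are Stan's defaults, and treating them as Pyro's defaults here is an assumption. At each returned iteration the mass matrix would be re-estimated and the step size re-tuned.

```python
def adaptation_window_ends(warmup_steps=1000, start_buffer=75, end_buffer=50,
                           initial_window=25):
    """Iterations at which each slow (mass-matrix) adaptation window ends,
    using a Stan-style doubling schedule."""
    end_window_start = warmup_steps - end_buffer  # final fast-adaptation buffer
    ends = []
    start, size = start_buffer, initial_window
    while start < end_window_start:
        if 3 * size <= end_window_start - start:
            next_size = 2 * size  # room for this window plus a doubled one
        else:
            size = end_window_start - start  # stretch the last window up to the buffer
            next_size = size
        start += size
        ends.append(start)
        size = next_size
    return ends


print(adaptation_window_ends())  # [100, 150, 250, 450, 950]
```

Because each window is twice as long as the previous one, the mass-matrix estimate changes most often early in warmup, which is also when a rescaled step size is most likely to overshoot.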

Great, thanks @fehiepsi. My current solution is the following, which seems to work OK (and lets me avoid using potential functions):

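import torch
import pyro.distributions as dist

# A tensor holding nan: the square root of -1 is nan in PyTorch.
_nan = (-torch.ones(1)).sqrt()

class SafeMultivariateNormal(dist.MultivariateNormal):
    def __init__(self, loc, covariance_matrix=None, precision_matrix=None, scale_tril=None, validate_args=None):
        try:
            super().__init__(loc, covariance_matrix=covariance_matrix,
                             precision_matrix=precision_matrix, scale_tril=scale_tril,
                             validate_args=validate_args)
        except (RuntimeError, ValueError):
            # The Cholesky factorization (RuntimeError) or argument validation (ValueError)
            # rejected the covariance. Only the covariance_matrix case is handled here.
            if covariance_matrix is None:
                raise
            # An all-nan scale_tril makes log_prob return nan instead of raising;
            # validate_args=False keeps the constraint check from rejecting the nan factor.
            super().__init__(loc, scale_tril=covariance_matrix * _nan, validate_args=False)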
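
A quick way to check the fallback without Pyro installed is a plain function over torch.distributions.MultivariateNormal, which Pyro's dist.MultivariateNormal wraps. The helper name safe_mvn is hypothetical; the sketch shows that a non-positive-definite covariance yields a distribution whose log_prob is nan rather than an exception.

```python
import torch
from torch.distributions import MultivariateNormal


def safe_mvn(loc, covariance_matrix):
    try:
        return MultivariateNormal(loc, covariance_matrix=covariance_matrix)
    except (RuntimeError, ValueError):
        # Validation (ValueError) or the Cholesky factorization (RuntimeError) rejected
        # the covariance; an all-nan scale_tril makes log_prob return nan instead.
        nan = torch.tensor(float("nan"))
        return MultivariateNormal(loc, scale_tril=covariance_matrix * nan, validate_args=False)


loc = torch.zeros(2)
ok = safe_mvn(loc, torch.eye(2))
broken = safe_mvn(loc, torch.tensor([[1.0, 2.0], [2.0, 1.0]]))  # eigenvalues 3 and -1
print(ok.log_prob(loc))      # finite
print(broken.log_prob(loc))  # nan
```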

Will close this now and reopen if needed.

@martinjankowiak - did you have to make any tweaks to find_reasonable_step_size or did the fix above work with our existing adaptation implementation?

No, I didn't make any additional tweaks.
