i've been having trouble porting the sparse regression example in numpyro to pyro. see code.
in particular:
_find_reasonable_step_size and modify it to never set the step_size larger than some appropriate threshold like 0.05. for whatever reason, the pyro adaptation (as opposed to the numpyro adaptation) pushes into a regime where the step size is too large. this then leads to bad covariance matrices, which cause the mvn log_prob to fail. are there known and/or intended differences between the two adaptation algorithms?
[pyro version: dev]
cc @neerajprad
@martinjankowiak @neerajprad It seems that we need to add a try/except around any potential_fn that involves a Cholesky factorization (I hit that issue with a GPyTorch example and was able to resolve it by using
def potential_fn(params):
    try:
        # Computes potential energy for a GPyTorch module
        model._load_raw_parameters(**params)
        output = model(train_x)
        log_joint = mll(output, train_y).sum()
        log_joint = log_joint + transform_jacobian(model, params)
        return -log_joint
    except Exception:  # return `nan` instead of raising an error if the Cholesky factorization fails
        # `0 * p.sum()` keeps the result connected to every parameter in the autograd graph
        r = 0
        for p in params.values():
            r = r + 0 * p.sum()
        return torch.tensor(float('nan')) + r
in that example). In PyTorch, a failed Cholesky factorization raises a RuntimeError, while in JAX it returns nan. Getting nan tells the adaptation scheme to lower step_size for us, so the next iterations are less likely to hit this issue.
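To make the RuntimeError-vs-nan point concrete, here is a minimal standalone sketch. The helper name `safe_neg_log_prob` is made up for illustration and is not part of Pyro or GPyTorch; it just wraps `torch.linalg.cholesky` so that a non-positive-definite covariance produces nan rather than an exception.

```python
import torch


def safe_neg_log_prob(loc, cov, value):
    """Return -log N(value | loc, cov), or nan if `cov` has no Cholesky factor.

    Hypothetical helper for illustration only.
    """
    try:
        # Raises a RuntimeError (subclass) when `cov` is not positive definite.
        scale_tril = torch.linalg.cholesky(cov)
        mvn = torch.distributions.MultivariateNormal(loc, scale_tril=scale_tril)
        return -mvn.log_prob(value)
    except RuntimeError:
        # Same trick as in potential_fn: add `0 * cov.sum()` so the nan
        # result stays connected to `cov` in the autograd graph.
        return torch.tensor(float("nan")) + 0 * cov.sum()


loc = torch.zeros(2)
value = torch.zeros(2)

good_cov = torch.eye(2)
bad_cov = torch.tensor([[1.0, 2.0], [2.0, 1.0]])  # eigenvalues 3 and -1

good_result = safe_neg_log_prob(loc, good_cov, value)
bad_result = safe_neg_log_prob(loc, bad_cov, value)
print(good_result, bad_result)
```

With the wrapper, the sampler sees a nan energy for the bad proposal and can reject it, instead of the whole run aborting on the exception.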
@martinjankowiak I don't think there are differences between the two schemes (I ported the Pyro one to NumPyro). In many cases, step_size is pushed into a large regime because at the end of each adaptation window (iter 100, 150, 250, or 450,...) we update the mass_matrix: if mass_matrix is scaled down, step_size will usually be scaled up. If the scaled-up step_size is bad (accept_prob ~ 0), it will be scaled down again over the next few iterations.
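The "scaled down in the next few iterations" behaviour can be sketched with a toy dual-averaging update in the style of Hoffman & Gelman's NUTS paper. This is an illustration only, not a copy of the Pyro or NumPyro code: the class name is made up, and the constants (target_accept=0.8, t0=10, kappa=0.75, gamma=0.05) are commonly used defaults that I am assuming here.

```python
import math


class DualAveragingStepSize:
    """Toy dual-averaging step size adapter (hypothetical, for illustration)."""

    def __init__(self, initial_step_size, target_accept=0.8,
                 t0=10.0, kappa=0.75, gamma=0.05):
        # Iterates are shrunk towards log(10 * initial_step_size).
        self.prox_center = math.log(10.0 * initial_step_size)
        self.target_accept = target_accept
        self.t0, self.kappa, self.gamma = t0, kappa, gamma
        self.t = 0
        self.g_avg = 0.0   # running average of (target_accept - accept_prob)
        self.x_avg = 0.0   # weighted average of log step sizes

    def update(self, accept_prob):
        """Consume one acceptance probability, return the next step size."""
        self.t += 1
        g = self.target_accept - accept_prob
        self.g_avg += (g - self.g_avg) / (self.t + self.t0)
        x = self.prox_center - math.sqrt(self.t) / self.gamma * self.g_avg
        w = self.t ** (-self.kappa)
        self.x_avg = (1 - w) * self.x_avg + w * x
        return math.exp(x)


# A step size that is too large gives accept_prob ~ 0 (nan energies are
# treated as rejections), so each update proposes a smaller step size.
adapter = DualAveragingStepSize(initial_step_size=1.0)
trace = [adapter.update(accept_prob=0.0) for _ in range(5)]
print(trace)
```

Note that the very first update can still propose a step size above the initial one, because the iterates start from 10x the initial value; it drops quickly once a few rejections accumulate.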
great, thanks @fehiepsi. my current solution is to do the following, which seems to work OK (and allows me to avoid using potential functions):
_nan = (-torch.ones(1)).sqrt()

class SafeMultivariateNormal(dist.MultivariateNormal):
    def __init__(self, loc, covariance_matrix=None, precision_matrix=None, scale_tril=None, validate_args=None):
        try:
            super(SafeMultivariateNormal, self).__init__(loc, covariance_matrix=covariance_matrix,
                                                         precision_matrix=precision_matrix, scale_tril=scale_tril,
                                                         validate_args=validate_args)
        except Exception:
            # construction failed (e.g. Cholesky error): fall back to a nan scale_tril
            # so that log_prob evaluates to nan instead of raising
            super(SafeMultivariateNormal, self).__init__(loc, covariance_matrix=None,
                                                         precision_matrix=precision_matrix,
                                                         scale_tril=covariance_matrix * _nan,
                                                         validate_args=validate_args)
will close this now and re-open if needed
@martinjankowiak - did you have to make any tweaks to find_reasonable_step_size or did the fix above work with our existing adaptation implementation?
no i didn't make any additional tweaks