I have an inference model that works robustly in PyStan with default settings (NUTS init, NUTS sampling). It is a Dirichlet process mixture with beta-binomial mixands. It also has a constraint: the mean of the mixture distribution is fixed exactly to a value specified by the user. Satisfying this constraint involves two steps of Newton's method.
I have ported the Stan model to PyMC3 and stared at it for a long time: I'm pretty sure it's exactly the same model. However, when I run it in PyMC3 with default settings (ADVI init with 200000 iterations, NUTS sampling), it usually stops accepting NUTS proposals after about step 50. If I switch the init to NUTS, it complains about the positivity of the covariance matrix (I assume because of the lack of movement). I've tried changing a bunch of the settings with no more success. Advice appreciated!
Here is the model, and all of the relevant code (including stan model) can be found in this gist.
import numpy as np
import pymc3 as pm
import theano.tensor as tt

def stick_breaking(beta):
    # Fraction of the stick still unbroken before each break
    portion_remaining = tt.concatenate([[1], tt.extra_ops.cumprod(1 - beta)[:-1]])
    return beta * portion_remaining

def scale_nu(weights, locations, mu0):
    # Initial guess for the shift h on the logit scale
    h = pm.math.logit(mu0) - tt.dot(weights, locations)
    # First Newton step on f(h) = dot(weights, sigmoid(h + locations)) - mu0
    tmp = pm.math.sigmoid(h + locations)
    h = h - (tt.dot(weights, tmp) - mu0) / (tt.dot(weights, tmp * (1-tmp)))
    # Second Newton step
    tmp = pm.math.sigmoid(h + locations)
    h = h - (tt.dot(weights, tmp) - mu0) / (tt.dot(weights, tmp * (1-tmp)))
    return pm.math.sigmoid(h + locations)

with pm.Model() as model:
    alpha = pm.Gamma('alpha', 1, 1)
    q = pm.Beta('q', 1, alpha, shape=data1['Nc'])
    w = pm.Deterministic('w', stick_breaking(q))
    nu_star = pm.Normal('nu_star', 0, 1.9, shape=data1['Nc'])
    t = pm.Uniform('t', 0, 1, shape=data1['Nc'])
    nu = pm.Deterministic('nu', scale_nu(w, nu_star, data1['mu']))
    obs = pm.Mixture('obs',
        w,
        pm.BetaBinomial.dist(
            alpha = 1/(t * (1-nu)) - nu,
            beta = 1/(t * nu) - (1-nu),
            n = data1['Nbin']
        ),
        observed=data1['x']
    )
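For readers who want to check the mean constraint outside Theano, here is a minimal NumPy sketch of the same two Newton steps that `scale_nu` performs. The name `scale_nu_np`, the target `mu0 = 0.3`, and the random test inputs are illustrative assumptions, not taken from the gist.

```python
import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def logit(p):
    return np.log(p) - np.log1p(-p)

def scale_nu_np(weights, locations, mu0, n_steps=2):
    """NumPy mirror of scale_nu: shift the locations by h so that the
    weighted mean of sigmoid(h + locations) approaches mu0."""
    h = logit(mu0) - np.dot(weights, locations)   # initial guess
    for _ in range(n_steps):
        p = sigmoid(h + locations)
        f = np.dot(weights, p) - mu0              # constraint residual
        df = np.dot(weights, p * (1.0 - p))       # derivative of f with respect to h
        h = h - f / df                            # Newton update
    return sigmoid(h + locations)

rng = np.random.default_rng(0)
weights = rng.dirichlet(np.ones(5))        # illustrative weights that sum to one
locations = rng.normal(0.0, 1.9, size=5)   # same prior scale as nu_star
mu0 = 0.3                                  # hypothetical target mean

nu = scale_nu_np(weights, locations, mu0)
print(abs(np.dot(weights, nu) - mu0))      # residual left after two Newton steps
```

Note that the residual is measured against weights that sum to one; with weights that fall short of one, the same update solves a slightly different equation.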
EDIT: (Hopefully things are sufficiently documented -- let me know if something needs clarifying)
@AustinRochford has worked a bunch with these models, maybe he has an idea.
I'm travelling through Sunday, but I'll take a look next week.
Update: I have added a section to the gist where I try to initialize PyMC3 using the posterior found by Stan. I pick a random point from Stan's trace, compute the dense covariance matrix of Stan's trace, and plug both into the NUTS step instance. I also copy the step size that Stan ended up with (as a technical aside, I'm not sure if Stan also uses a scaled step size, but in any case, 46**0.25 is of order unity).
Since this also doesn't work, I'm inclined to believe the problem is not one of HMC/NUTS tuning, but I'm too new to this to speculate further.
with model:
    step = pm.NUTS(scaling=cov, is_cov=True, step_scale=0.106771 * (46**0.25))
    trace = pm.sample(1000, init=None, step=step, start=start)
@ihincks I had a quick try at your model and made some minimal changes:
The Stan model uses an invlogit, and you used pm.math.sigmoid (is there any particular reason?). After changing that, the model seems to run fine; you can have a look at the gist.
You can also try sorting the weights to avoid the potential multimodality of the mixture model:
w_ = stick_breaking(q)
w = pm.Deterministic('w', tt.sort(w_))
@junpenglao sigmoid is pretty much the same as invlogit. invlogit tries to avoid overflows by adding an epsilon somewhere. I think using sigmoid should be preferred, as it makes life much easier for theano – it has a specialized implementation and a couple of optimizations for it.
@ihincks The problem seems to be w. There are two problems, actually. First, your implementation using stick breaking leads to a strange prior on the weights: the first weight is likely to be much larger than the others. Unless you really want that behaviour, you could correct this by adding the determinant of the inverse Jacobian of the transformation to the logp, but it would be much easier to just use pm.Dirichlet (which does just this internally):
w = pm.Dirichlet('w', a=np.ones(data1['Nc']))
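To make the difference between the two priors concrete, the following NumPy sketch draws weights both ways and compares their means. It is only an illustration under assumed values (`K = 5`, `alpha = 1`, 20000 draws); none of these numbers come from the gist.

```python
import numpy as np

rng = np.random.default_rng(0)
K, alpha, n_draws = 5, 1.0, 20000   # illustrative values only

# Stick breaking: beta_k ~ Beta(1, alpha), w_k = beta_k * prod_{j<k} (1 - beta_j)
beta = rng.beta(1.0, alpha, size=(n_draws, K))
remaining = np.concatenate(
    [np.ones((n_draws, 1)), np.cumprod(1.0 - beta, axis=1)[:, :-1]], axis=1)
w_stick = beta * remaining

# Flat Dirichlet: exchangeable, so every weight has mean 1/K
w_dir = rng.dirichlet(np.ones(K), size=n_draws)

print(w_stick.mean(axis=0))  # decreasing, roughly 0.5**(k+1) when alpha = 1
print(w_dir.mean(axis=0))    # roughly 1/K for every component
```

Under stick breaking the expected weight of component k is (1/(1+alpha)) * (alpha/(1+alpha))**k, which is why the first component dominates on average.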
The second problem is that the w do not add to 1. Mixture checks that and returns a logp of -inf in this case, so NUTS can't move at all (we should print a more informative error message in cases like this!). Again, just use pm.Dirichlet to fix that. Internally, it only looks at n-1 unconstrained variables and infers the last one to make the sum 1.
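The shortfall is easy to see with a small NumPy mirror of the `stick_breaking` function from the question. The helper name `stick_breaking_np` and the choice of four fractions equal to 0.5 are illustrative assumptions.

```python
import numpy as np

def stick_breaking_np(beta):
    # Same arithmetic as the Theano stick_breaking in the question
    remaining = np.concatenate([[1.0], np.cumprod(1.0 - beta)[:-1]])
    return beta * remaining

beta = np.full(4, 0.5)               # four breaks, each taking half of what is left
w = stick_breaking_np(beta)          # [0.5, 0.25, 0.125, 0.0625]
leftover = np.prod(1.0 - beta)       # part of the stick that is never assigned

print(w.sum(), leftover)             # 0.9375 0.0625
```

In general the weights sum to 1 minus the product of all (1 - beta_k), so the sum only reaches one if some fraction is exactly 1.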
@aseyboldt the odd prior for the stick breaking process can sometimes be useful (in the case of Dirichlet processes, depending on what you are trying to accomplish).
@ihincks I believe the hack to get around the probabilities from the stick breaking process not summing to one is to increase the number of components before truncation (data1['Nc'] in your example). If the number of components is large enough, the sum will eventually be close enough to one for Mixture to be happy. I have been meaning to contribute a TruncatedStickBreaking distribution for some time to make this work exactly, with any number of components. Will look into that soon.
@aseyboldt @ihincks see the above PR for something that may work; unfortunately my hotel WiFi is too slow to build the development Docker image, so I won't be able to test it for a few days.
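As a rough guide to how many components that takes: with independent fractions drawn from Beta(1, alpha), the expected unassigned mass after K breaks is (alpha / (1 + alpha))**K. The sketch below uses that closed form; the helper names and the tolerance `1e-6` are hypothetical, and the tolerance that Mixture actually requires is not stated in this thread.

```python
import math

def expected_leftover(alpha, n_components):
    # E[prod_k (1 - beta_k)] for independent beta_k ~ Beta(1, alpha)
    return (alpha / (1.0 + alpha)) ** n_components

def min_components(alpha, tol):
    # Smallest K whose expected leftover mass is at most tol
    return math.ceil(math.log(tol) / math.log(alpha / (1.0 + alpha)))

for k in (5, 10, 20, 50):
    print(k, expected_leftover(1.0, k))

print(min_components(1.0, 1e-6))   # 20, since 0.5**20 is about 9.5e-7
```

This is an expectation for a fixed alpha; individual draws can leave more mass, and in the model above alpha is itself random, so larger values of alpha need more components.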
Thanks for your help everyone! In particular thanks for pointing out that the weights don't sum to one; I took the stick_breaking function from here, and I didn't inspect it closely enough.
@AustinRochford, yes, you are correct: I really do want stick breaking rather than Dirichlet.
I will see if I can get your TruncatedStickBreaking working.
@AustinRochford Interesting and good to know. :-)
Just a spontaneous idea I didn't think much about: Would it maybe make sense to explicitly model the (log of the) remaining length? Then you could say something like this:
remaining = pm.Beta.dist(alpha=10, beta=0)
w = pm.TruncatedStickBreaking('w', alpha=1, remaining_dist=remaining)
pm.Mixture(w, tt.stack(normal_cases, model_for_remaining))
My (non-expert) understanding is that for truncated stick breaking processes, it is standard practice to assume the last probability causes a sum to 1. Equivalently, the last stick breaking fraction is not drawn from Beta, but is rather set to 1 deterministically. See, for example, section 1.1.1 of this paper, right after Eqn 3.
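A minimal NumPy sketch of that convention: draw or choose K-1 fractions, append a final fraction of 1, and the weights sum to one exactly. The function name `truncated_stick_breaking_np` and the example fractions are illustrative assumptions.

```python
import numpy as np

def truncated_stick_breaking_np(beta):
    """beta holds the first K-1 breaking fractions; the K-th is fixed to 1,
    so the last component takes whatever is left of the stick."""
    beta_full = np.append(beta, 1.0)
    remaining = np.concatenate([[1.0], np.cumprod(1.0 - beta_full)[:-1]])
    return beta_full * remaining

w = truncated_stick_breaking_np(np.array([0.5, 0.5, 0.5]))
print(w)        # weights 0.5, 0.25, 0.125, 0.125
print(w.sum())  # 1.0
```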
The following (slightly modified from the PR) seems to be working, pasted into the notebook:
class TruncatedStickBreaking(pm.Continuous):
    def __init__(self, alpha=1., transform=pm.distributions.transforms.stick_breaking, *args, **kwargs):
        kwargs.setdefault('shape', 1)
        super(TruncatedStickBreaking, self).__init__(transform=transform, *args, **kwargs)
        self.alpha = alpha
        # Mean of Beta(1, alpha) for the free fractions; the last fraction is fixed to 1
        mean_breaking_fractions = tt.concatenate([1 / (1 + alpha) * tt.ones(kwargs['shape']-1), [1.]])
        self.mean = mean_breaking_fractions * tt.concatenate([[1], tt.extra_ops.cumprod(1 - mean_breaking_fractions)[:-1]])
        # the following mode is wrong
        self.mode = self.mean
        self._beta_like = pm.Beta.dist(1., self.alpha)

    def logp(self, value):
        # Recover the breaking fractions from the weights
        beta_value = value[:-1] / (1. - tt.concatenate([[0.], value[:-2].cumsum()]))
        return self._beta_like.logp(beta_value)
The transformation is necessary, right? It is a bit unfortunate that the first thing logp does is undo the transform.
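For reference, the `beta_value` line in `logp` above maps weights back to breaking fractions. The NumPy sketch below mirrors that one line so the mapping can be checked by hand; the function name `weights_to_fractions` and the example weights are assumptions for illustration.

```python
import numpy as np

def weights_to_fractions(w):
    """Recover the K-1 free breaking fractions from K weights that sum to one."""
    used = np.concatenate([[0.0], np.cumsum(w[:-2])])   # mass used before each break
    return w[:-1] / (1.0 - used)

w = np.array([0.5, 0.25, 0.125, 0.125])
print(weights_to_fractions(w))   # every fraction is 0.5
```

The last weight carries no information of its own, since its fraction is fixed to 1, which is why only K-1 values are returned.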
@ihincks your understanding is correct in general, but it's not how I've implemented these things so far. I will change that soon.