Pyro: Cholesky error in GaussianHMM, _sequential_gaussian_tensordot()

Created on 16 Aug 2019 · 6 Comments · Source: pyro-ppl/pyro

I'm seeing Cholesky errors in _sequential_gaussian_tensordot() when computing GaussianHMM.log_prob(), including cases where the precision matrices have strongly negative eigenvalues. I'd like to figure out how to numerically stabilize gaussian_tensordot().

One issue that makes this difficult to debug is that torch.eig() often segfaults due to a reference-counting bug: https://github.com/pytorch/pytorch/issues/24450

To reproduce

import torch
from pyro.distributions.hmm import _sequential_gaussian_tensordot
gaussian = torch.load("gaussian.pkl", weights_only=False)  # pickled Gaussian object; file is attached (zipped) to this issue
_sequential_gaussian_tensordot(gaussian)

gaussian.pkl.zip

bug

All 6 comments

So far this appears to be a deficiency of torch.distributions.constraints.lower_cholesky, whereby the optimized parameters become numerically singular. My immediate workaround was to switch from full-covariance parameters to diagonal-covariance parameters. However, we should look into more stable ways of optimizing positive definite matrices. I'll leave this issue open for discussion.
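A minimal sketch of the two parameterizations contrasted here, written against torch.distributions. The names raw_tril and raw_scale stand for hypothetical unconstrained parameters being optimized; they are illustrative, not Pyro's internals. transform_to(constraints.lower_cholesky) maps any square matrix to a lower-triangular one with an exponentiated (hence positive) diagonal, while the diagonal workaround keeps only a positive scale vector, whose precision 1 / scale**2 needs no Cholesky factorization at all.

```python
import torch
from torch.distributions import (Independent, MultivariateNormal, Normal,
                                 constraints, transform_to)

torch.manual_seed(0)
dim = 4
loc = torch.zeros(dim)

# Full-covariance parameterization: unconstrained square matrix -> lower Cholesky factor.
raw_tril = torch.randn(dim, dim)  # hypothetical unconstrained parameter
scale_tril = transform_to(constraints.lower_cholesky)(raw_tril)
full = MultivariateNormal(loc, scale_tril=scale_tril)

# Diagonal-covariance workaround: unconstrained vector -> positive scales.
raw_scale = torch.randn(dim)  # hypothetical unconstrained parameter
scale = raw_scale.exp()
diag = Independent(Normal(loc, scale), 1)

# The diagonal model's precision is simply 1 / scale**2, so no Cholesky is needed.
diag_precision = torch.diag_embed(scale.pow(-2))

# Sanity check: the diagonal model matches an MVN whose scale_tril is diag(scale).
x = torch.randn(dim)
same = MultivariateNormal(loc, scale_tril=torch.diag_embed(scale))
print(full.log_prob(x), diag.log_prob(x), same.log_prob(x))
```

The diagonal version trades away correlations for a precision that is positive by construction.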

@fritzo Your point is interesting! I tried to see how that issue arises:

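import torch
torch.manual_seed(2)
# Random lower-triangular factor with a positive (exp) diagonal, as lower_cholesky produces
x = torch.randn(40, 40).tril(-1) + torch.diag_embed(torch.randn(40).exp())
print("tril diag\n", x.diag())
# torch.eig() has since been removed from PyTorch; torch.linalg.eigvals replaces it
print("cov eigen values\n", torch.linalg.eigvals(x.matmul(x.T)).real)
# Triangular solve of x @ x_inv = I
x_inv_stable = torch.linalg.solve_triangular(x, torch.eye(40), upper=False)
print("tril stable inverse\n", x_inv_stable.diagonal(dim1=-2, dim2=-1))
x_inv = x.inverse()  # general-purpose inverse, as used by the precision_matrix property
print("tril inverse\n", x_inv.diagonal(dim1=-2, dim2=-1))
precision = x_inv.transpose(-2, -1).matmul(x_inv)
precision_stable = x_inv_stable.transpose(-2, -1).matmul(x_inv_stable)
for name, p in [("stable", precision_stable), ("inverse", precision)]:
    try:
        print(torch.linalg.cholesky(p))
    except RuntimeError as e:  # the original run hit a singular-matrix error for both
        print(name, "cholesky failed:", e)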

which returns

tril diag
 tensor([ 1.8164,  0.9008,  0.5388,  2.4168,  0.4780,  0.2259,  0.3260,  8.2935,
         0.7697,  2.7750,  1.0188,  0.1855,  0.4467,  0.6841,  1.8005,  0.8780,
         0.3297,  2.1109,  0.1622,  0.0466,  0.1453,  4.5444, 11.1870,  0.3835,
         1.8799,  0.6463,  0.6641,  0.2685,  4.1876,  1.1715,  0.9467,  2.2795,
         1.2949,  1.0002,  3.1797,  0.2366, 11.5137,  4.9844,  0.9715,  1.8482])
cov eigen values
 tensor([ 1.7721e+02,  1.7453e+02,  1.1969e+02,  1.0392e+02,  9.6891e+01,
         8.2600e+01,  7.2500e+01,  6.7973e+01,  5.4561e+01,  4.8541e+01,
         4.5381e+01,  3.7133e+01,  2.9066e+01,  2.7167e+01,  1.9756e+01,
         1.9204e+01,  1.5435e+01,  1.3110e+01,  1.0481e+01,  9.8482e+00,
         8.0088e+00,  7.1055e+00,  6.3819e+00,  4.9923e+00,  4.1197e+00,
         2.8929e+00,  2.6459e+00,  2.0562e+00,  1.2387e+00,  9.0811e-01,
         5.9837e-01,  3.6503e-01,  2.8105e-01,  1.0950e-01,  1.2662e-02,
         4.7462e-04,  5.6004e-05,  1.7590e-05, -1.3351e-06,  2.5402e-06])
tril stable inverse
 tensor([ 0.5505,  1.1102,  1.8561,  0.4138,  2.0919,  4.4262,  3.0679,  0.1206,
         1.2992,  0.3604,  0.9815,  5.3896,  2.2389,  1.4619,  0.5554,  1.1389,
         3.0333,  0.4737,  6.1656, 21.4709,  6.8805,  0.2201,  0.0894,  2.6072,
         0.5320,  1.5473,  1.5057,  3.7250,  0.2388,  0.8536,  1.0563,  0.4387,
         0.7722,  0.9998,  0.3145,  4.2262,  0.0869,  0.2006,  1.0294,  0.5411])
tril inverse
 tensor([  0.8577,   1.7351,   2.4454,   0.0628,   4.4942,  -8.6023,  -0.4101,
          0.1345,   2.5317,  -0.2871,   1.2715, -16.7043,  -6.0439,  -2.6443,
          0.5777,  -0.8966,  -5.0385,   0.6307, -13.4730, -28.8008,  -3.9316,
          0.2600,   0.1201,  -6.0625,   0.9570,   1.2705,  -4.9844,   4.0703,
          0.2386,  -0.2930,   0.1016,   0.4507,  -0.1250,  -0.0625,   0.3105,
         -9.5000,  -0.2500,  -0.2344,  -1.0000,  -2.5000])

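We can see that the current implementation of the precision_matrix property is not stable because it relies on torch.inverse, which produces an inaccurate triangular inverse: the inverse of a lower-triangular matrix should have diagonal 1 / x.diag() (e.g. 1 / 1.8164 = 0.5505, which the triangular solve gives), yet torch.inverse returns 0.8577 there and even negative diagonal entries elsewhere. Regardless, the precision matrix is numerically singular in both versions.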
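For comparison, here is a sketch of two ways to get the precision matrix from a covariance Cholesky factor without a general-purpose inverse, on a small well-conditioned float64 example (size and seed are arbitrary). It also checks a useful sanity test for the outputs above: the inverse of a lower-triangular L is lower triangular with diagonal 1 / L.diag(). torch.cholesky_inverse(L) returns (L @ L.T)^-1 from the factor; whether it behaves better than torch.inverse on the ill-conditioned 40 x 40 case is not tested here.

```python
import torch

torch.manual_seed(0)
dim = 5
# Lower-triangular factor with positive diagonal, as produced by lower_cholesky.
L = torch.randn(dim, dim, dtype=torch.float64).tril(-1) + torch.diag_embed(
    torch.rand(dim, dtype=torch.float64) + 0.5)
eye = torch.eye(dim, dtype=torch.float64)

# Triangular solve L @ L_inv = I instead of a general-purpose inverse.
L_inv = torch.linalg.solve_triangular(L, eye, upper=False)
precision_tri = L_inv.T @ L_inv  # (L @ L.T)^-1 = L^-T @ L^-1

# cholesky_inverse computes (L @ L.T)^-1 directly from the lower factor L.
precision_chol = torch.cholesky_inverse(L)

# Reference: invert the covariance explicitly.
precision_ref = torch.linalg.inv(L @ L.T)
```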

Regarding other parameterizations, there are two that might work:

  • LKJ way: scale_tril = D @ L, where D is a positive diagonal matrix and L is the Cholesky factor of a correlation matrix. This is somewhat similar to the current lower_cholesky implementation but may help avoid bad scale_tril matrices.
  • Low-rank way: cov = W @ W.T + D. This perturbs your current diagonal solution with an additional low-rank factor matrix W. I think we can relax the current assertions that check whether input distributions are MVN/IndependentNormal so that LowRankMVN is also allowed. I can make a PR for it if you want to try this way.
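Both parameterizations can be sketched with existing torch.distributions pieces, assuming a PyTorch release that provides CorrCholeskyTransform. The unconstrained inputs (raw_corr, W, d) are hypothetical parameters; this only shows the constructions, not whether either one stays well conditioned under optimization.

```python
import torch
from torch.distributions import LowRankMultivariateNormal, MultivariateNormal
from torch.distributions.transforms import CorrCholeskyTransform

torch.manual_seed(0)
dim, rank = 6, 2
loc = torch.zeros(dim)

# LKJ way: scale_tril = D @ L, with L the Cholesky factor of a correlation matrix.
raw_corr = torch.randn(dim * (dim - 1) // 2)  # hypothetical unconstrained parameter
L = CorrCholeskyTransform()(raw_corr)         # L @ L.T has a unit diagonal
D = torch.diag_embed(torch.randn(dim).exp())  # positive diagonal scales
lkj_mvn = MultivariateNormal(loc, scale_tril=D @ L)

# Low-rank way: cov = W @ W.T + diag(d), with d > 0 keeping cov positive definite.
W = torch.randn(dim, rank)  # hypothetical cov_factor
d = torch.randn(dim).exp()  # hypothetical cov_diag
low_rank_mvn = LowRankMultivariateNormal(loc, cov_factor=W, cov_diag=d)
```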

In the meantime, I will try to see whether any of these parameterizations is stable enough.

@fehiepsi I like the low-rank idea. What if we implemented a PositiveDefiniteTransform as something like:

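import torch
from torch.distributions.transforms import Transform

class PositiveDefiniteTransform(Transform):
    ...  # domain, codomain, _inverse and log_abs_det_jacobian omitted in this sketch
    def _call(self, x):
        assert x.size(-1) == x.size(-2)  # assume a square (full-rank) factor
        y = x.matmul(x.transpose(-1, -2))  # symmetric positive semidefinite
        # jitter proportional to the Frobenius norm of y, so it scales with y
        jitter = 1e-6 * y.norm(dim=(-2, -1)).unsqueeze(-1).unsqueeze(-1)
        y = y + jitter * torch.eye(y.size(-1), dtype=y.dtype, device=y.device)
        return y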

That is, what if we replaced most uses of constraints.lower_cholesky with constraints.positive_definite, since positive definiteness is a little easier to enforce on the actual positive definite matrix than on its Cholesky factor?
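The same idea written as a plain function rather than a Transform, so it can be run directly. positive_definite and rel_jitter are hypothetical names for this sketch. Even for a rank-deficient factor x, where x @ x.T alone is singular, the relative jitter makes the result positive definite, so the extra cholesky() in the forward pass succeeds.

```python
import torch

def positive_definite(x: torch.Tensor, rel_jitter: float = 1e-6) -> torch.Tensor:
    """Hypothetical helper: map a square matrix x to x @ x.T plus relative jitter."""
    assert x.size(-1) == x.size(-2)
    y = x @ x.transpose(-1, -2)
    # Jitter proportional to the Frobenius norm of y, so it scales with y.
    jitter = rel_jitter * torch.linalg.matrix_norm(y)[..., None, None]
    eye = torch.eye(y.size(-1), dtype=y.dtype, device=y.device)
    return y + jitter * eye

# A rank-deficient factor (row 2 = 2 * row 1), so x @ x.T alone is singular.
x = torch.tensor([[1.0, 2.0, 3.0],
                  [2.0, 4.0, 6.0],
                  [1.0, 0.0, 1.0]], dtype=torch.float64)
pd = positive_definite(x)
scale_tril, info = torch.linalg.cholesky_ex(pd)  # info == 0 means success
```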

The only downside I see is that forward computations would require an extra cholesky(), backward computations would require its derivative, and both might be expensive.

@fritzo Apart from the computational overhead of those Cholesky decompositions, I think adding jitter would be fine, and it is more stable than my other variants. Adding a jitter term proportional to y.norm() makes sense to me, but I don't know whether it works well in practice. At least it has the nice property that the condition number is invariant to scaling the matrix by a constant factor c > 0: since cA + 1e-6 ||cA|| I = c (A + 1e-6 ||A|| I), we get condition_number(cA + 1e-6 ||cA|| I) = condition_number(A + 1e-6 ||A|| I).
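A quick numerical check of this scale-invariance argument, assuming the Frobenius norm for ||A|| (the argument holds for any norm, since ||cA|| = c ||A||). jittered is a hypothetical helper, and A is deliberately singular so that the jitter alone sets the smallest eigenvalue.

```python
import torch

def jittered(a: torch.Tensor, rel: float = 1e-6) -> torch.Tensor:
    # Hypothetical helper: A + rel * ||A||_F * I
    eye = torch.eye(a.size(-1), dtype=a.dtype)
    return a + rel * torch.linalg.matrix_norm(a) * eye

torch.manual_seed(0)
x = torch.randn(5, 3, dtype=torch.float64)
A = x @ x.T  # 5 x 5 but rank 3, so A itself is singular

c = 1e4
cond_A = torch.linalg.cond(jittered(A))
cond_cA = torch.linalg.cond(jittered(c * A))
```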

I tried to implement the LKJ way, but it also suffered from the issue that L @ L.T is singular even though L.diag() is positive. The low-rank version is more stable, but for high-dimensional matrices (I tested with dim=40), a small-rank cov_factor also leads to singularity issues.
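A tiny illustration of how a triangular factor with a strictly positive diagonal can still give a numerically singular L @ L.T. The 2 x 2 matrix is a contrived example chosen so that, in float32, the small diagonal entry is rounded away when forming the product, while float64 keeps it. torch.linalg.cholesky_ex reports failure through a nonzero info instead of raising.

```python
import torch

# Lower-triangular factor with a strictly positive diagonal...
L = torch.tensor([[1.0, 0.0],
                  [1e4, 1e-3]])
# ...but in float32, 1e8 + 1e-6 rounds to 1e8 when forming L @ L.T,
# so the computed product is exactly singular.
cov32 = L @ L.T
_, info32 = torch.linalg.cholesky_ex(cov32)

# The same computation in float64 keeps enough digits to stay positive definite.
cov64 = L.double() @ L.double().T
_, info64 = torch.linalg.cholesky_ex(cov64)
print(info32.item(), info64.item())  # nonzero info means the factorization failed
```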

That made me rethink the problem. I read some of the Kalman filter literature to see how people deal with singularity problems. There are at least three alternatives:

  • In the Kalman filter, to deal with the issue of subtracting two positive definite matrices (e.g. during marginalization), there is a Joseph form that converts the subtraction into a sum. I was able to derive a similar form for the information filter by leveraging the "Joseph" identity (with A and R symmetric)
    A - A B (Bt A B + R)^-1 Bt A = (I - C Bt) A (I - B Ct) + C R Ct,
    where C = A B (Bt A B + R)^-1

Here, B plays the role of the transition matrix of an affine transform between two variables (e.g. in the matrix_and_mvn_to_gaussian function). However, in our implementation, B is absorbed into the product A B, and the Bt A B term is not available because it has already been added to some other precision matrix.

  • According to the book Kalman Filtering: Theory and Practice with MATLAB, using square-root forms (Cholesky or LDLt) is a better solution. The formulas can be found in the paper (which you sent me). It uses some sort of "triangularization" (I think it is a QR decomposition?). However, I haven't been able to write down the details of this approach yet.

  • A third method is a combination of the above two approaches, which is more intuitive to me. By allowing a precision matrix to be represented as H = M Mt (M need not be lower triangular), I can keep part of the transition matrix B, which the first approach needs, during matrix_and_mvn_to_gaussian. In particular, we have

gaussian.precision = [[B @ mvn_prec @ Bt, -B @ mvn_prec], [-mvn_prec @ Bt, mvn_prec]]

hence, writing mvn_prec = mvn_prec_sqrt @ mvn_prec_sqrt.T,

gaussian.precision_sqrt = [[-B @ mvn_prec_sqrt], [mvn_prec_sqrt]]

This seems enough to resolve the issue of subtracting two positive definite matrices that I mentioned above, while still taking advantage of square-root filters. Indeed, the trickiest marginalization operator (Paa - Pab @ inv(Pbb) @ Pba) under this representation becomes, after splitting M into row blocks Ma and Mb, Ma @ N @ Mat with C = Mb, so it only requires a square root of a matrix of the form

N = I - Ct @ inv(C @ Ct) @ C

where C is an m x n matrix such that C @ Ct is non-singular. Using the Joseph identity (or by expanding N @ N directly), we can see that N is idempotent; it is also symmetric, so N @ Nt = N and N is its own square root!! I think this is a nice discovery, so I am following up on this approach. I'll make a PR for it soon, hoping I didn't make a mistake along the way. :D
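The identities in this approach are easy to sanity-check numerically. The following sketch uses float64 and random well-conditioned matrices (random_spd and all sizes are arbitrary test choices) and assumes the convention y = x @ B + noise, which reproduces the gaussian.precision block structure written above. It checks (1) the Joseph-type identity, (2) that stacking -B @ mvn_prec_sqrt on top of mvn_prec_sqrt gives a square root of that joint precision (note the minus sign, matching the -B @ mvn_prec off-diagonal blocks), and (3) that the Schur complement equals Ma @ N @ Ma.T with C = Mb, where N is symmetric and idempotent. This is a numerical check, not a proof, and not the eventual PR's implementation.

```python
import torch

torch.manual_seed(0)
dt = torch.float64
n, m = 4, 3

def random_spd(k: int) -> torch.Tensor:
    # Random symmetric positive definite matrix for testing.
    x = torch.randn(k, k, dtype=dt)
    return x @ x.T + k * torch.eye(k, dtype=dt)

# 1) Joseph-type identity:
#    A - A B (Bt A B + R)^-1 Bt A = (I - C Bt) A (I - B Ct) + C R Ct
A, R = random_spd(n), random_spd(m)
B = torch.randn(n, m, dtype=dt)
S = B.T @ A @ B + R
C = A @ B @ torch.linalg.inv(S)
I = torch.eye(n, dtype=dt)
lhs = A - A @ B @ torch.linalg.inv(S) @ B.T @ A
rhs = (I - C @ B.T) @ A @ (I - B @ C.T) + C @ R @ C.T

# 2) Joint precision of (x, y) for y = x @ B + noise, noise precision mvn_prec.
mvn_prec = random_spd(m)
mvn_prec_sqrt = torch.linalg.cholesky(mvn_prec)  # mvn_prec = sqrt @ sqrt.T
precision = torch.cat([
    torch.cat([B @ mvn_prec @ B.T, -B @ mvn_prec], dim=1),
    torch.cat([-mvn_prec @ B.T, mvn_prec], dim=1),
], dim=0)
precision_sqrt = torch.cat([-B @ mvn_prec_sqrt, mvn_prec_sqrt], dim=0)

# 3) Marginalizing b out of H = M @ M.T, with M split into row blocks Ma, Mb:
#    Paa - Pab inv(Pbb) Pba = Ma @ N @ Ma.T, N = I - Mb.T inv(Mb Mb.T) Mb.
M = torch.randn(n + m, n + m, dtype=dt)  # hypothetical square-root factor
H = M @ M.T
Ma, Mb = M[:n], M[n:]
Paa, Pab, Pbb = H[:n, :n], H[:n, n:], H[n:, n:]
schur = Paa - Pab @ torch.linalg.solve(Pbb, Pab.T)
Ik = torch.eye(n + m, dtype=dt)
N = Ik - Mb.T @ torch.linalg.solve(Mb @ Mb.T, Mb)
```

Because N is symmetric and idempotent, Ma @ N is itself a square root of the Schur complement.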

@fehiepsi My current thinking is that Gaussian is probably already sufficiently stable, and that we would be better off fixing the bad inputs to mvn_to_gaussian() that arise from badly conditioned parameters with support constraints.lower_cholesky.

@fritzo Though fixing bad parameters of constraints.lower_cholesky is a good way to make sure precision matrices are positive definite, working with positive definite matrices directly might lead to the same issues as described in the Kalman filter literature. A nice property of the square-root form is that we don't need to explicitly preserve symmetry and positive definiteness. At least, a GaussianS implementation will not involve any cholesky/inverse operators. If MVN allows a prec_scale_tril input, then we won't need any cholesky/inverse operator on the path from the parameter space to the loss in GaussianHMM. Instead, the corresponding operators in square-root form are qr and triangular_solve.
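A minimal sketch of the QR-based "triangularization" used in square-root filters, under the H = M Mt representation discussed above: adding two precisions concatenates their square roots, and a QR decomposition turns the result back into a triangular factor without forming H or calling cholesky. The matrices are random placeholders; this is the generic technique, not the GaussianS code.

```python
import torch

torch.manual_seed(0)
dt = torch.float64
k = 4

# Two precision square roots (not necessarily triangular): H1 = M1 @ M1.T, H2 = M2 @ M2.T.
M1 = torch.randn(k, 3, dtype=dt)
M2 = torch.randn(k, 5, dtype=dt)

# Adding precisions concatenates square roots: H1 + H2 = [M1, M2] @ [M1, M2].T.
M = torch.cat([M1, M2], dim=-1)

# Triangularize with QR: M.T = Q @ R, so M @ M.T = R.T @ Q.T @ Q @ R = R.T @ R.
_, R = torch.linalg.qr(M.T, mode="r")
L = R.T  # lower-triangular square root of H1 + H2 (diagonal signs may be negative)

# Solving H @ z = b then only needs two triangular solves with L.
b = torch.randn(k, 1, dtype=dt)
w = torch.linalg.solve_triangular(L, b, upper=False)
z = torch.linalg.solve_triangular(L.T, w, upper=True)
```

Unlike a Cholesky factor, L here is only unique up to the signs of its diagonal, which does not matter for L @ L.T or for the triangular solves.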
