JAX: `jax.scipy.linalg.expm` causes an infinite loop inside two nested `fori_loop`/`scan`s.

Created on 27 Aug 2020 · 3 comments · Source: google/jax

This issue arose in discussion in #4114.

The following double loop of fixed length (six iterations in total), with a matrix exponential inside, causes JAX to hang and never terminate.

import jax
from jax import lax, jit
from jax.config import config
import jax.random as rnd
import time
import jax.numpy as np

# config.enable_omnistaging()
print(f"Number of devices: {jax.device_count()}")
print("Device:", jax.devices()[0].device_kind)


def potential(x):
    W = np.array([[0.0, x[0]], [x[1], 0.0]])
    return np.trace(jax.scipy.linalg.expm(W * W))


def leapfrog(q):
    def scan_fun(q, xs):
        V = potential(q)
        q -= V
        return q, (q, V)

    _, (qs, V) = lax.scan(scan_fun, q, xs=None, length=3)

    return qs[-1]


def hamiltonian_monte_carlo_fori(initial_position):
    initial_position = np.array(initial_position)
    n_samples = 2

    def scan_fun(carry, xs):
        q = carry
        q_new = leapfrog(q)
        return q_new, q_new

    _, samples = lax.scan(scan_fun, initial_position, None, length=n_samples)

    return samples

print("Starting Sampling")
t0 = time.time()
fori_samples = hamiltonian_monte_carlo_fori(
    rnd.normal(rnd.PRNGKey(0), shape=(2,)) * 3
).block_until_ready()

print(f"Finished Sampling, took {time.time() - t0}s")
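This is a quick sanity check that the loop structure itself is bounded. The sketch below keeps the same nesting as the reproduction: an outer `lax.scan` of length 2 around an inner `lax.scan` of length 3, both with `xs=None`. It replaces `potential` with a simple counter. The names `inner` and `outer` are illustrative stand-ins for `leapfrog` and `hamiltonian_monte_carlo_fori`.

```python
import jax.numpy as jnp
from jax import lax


def inner(count):
    # Same shape as leapfrog: lax.scan with xs=None and a fixed length of 3.
    def body(c, _):
        return c + 1, None

    count, _ = lax.scan(body, count, xs=None, length=3)
    return count


def outer(count):
    # Same shape as hamiltonian_monte_carlo_fori: 2 outer steps, each running inner.
    def body(c, _):
        c = inner(c)
        return c, c

    count, per_sample = lax.scan(body, count, xs=None, length=2)
    return count, per_sample


total, per_sample = outer(jnp.int32(0))
print(total, per_sample)  # 6 [3 6]
```

The counter stops at 6. This is consistent with the discussion below, which traces the hang to `expm` rather than to the scans.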
All 3 comments

I think this is expm hanging on invalid numeric input. The bug doesn't seem to require loop/scan. I've distilled the original example down to:

import numpy as onp
import scipy as osp
from jax import scipy as jsp

sp = jsp

def potential(x):
  W = onp.array([[0.0, x[0]], [x[1], 0.0]])
  return onp.trace(sp.linalg.expm(W * W))

def leapfrog(q):
  def scan_fun(q):
    V = potential(q)
    q -= V
    return q, (q, V)

  q, (q, V) = scan_fun(q)
  q, (q, V) = scan_fun(q)
  q, (q, V) = scan_fun(q)       # error happens here
  return q

leapfrog(onp.random.normal(size=2) * 3)

which hangs similarly. If I substitute sp = osp on line 5, I encounter:

Traceback (most recent call last):
  ...
  File ".../scipy/sparse/linalg/matfuncs.py", line 671, in _expm
    s = max(int(np.ceil(np.log2(eta_5 / theta_13))), 0)
ValueError: cannot convert float NaN to integer

The third invocation of potential receives [-inf -inf], so those inf values are handed to expm.
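You can work out by hand why the third call sees `-inf`. Take `W = [[0, a], [b, 0]]`. Its elementwise square `W * W` is `[[0, a**2], [b**2, 0]]`. Multiplying that matrix by itself gives `(a*b)**2` times the identity. So `expm(W * W) = cosh(|ab|) I + (sinh(|ab|)/|ab|) (W * W)`, and `potential` returns `2 * cosh(|ab|)`.

The NumPy/SciPy sketch below checks this closed form and replays the three `q -= V` updates. It assumes an illustrative start `q = [3, 3]`; the issue uses a random one. `potential_closed_form` is a hypothetical helper.

```python
import numpy as np
from scipy.linalg import expm


def potential_closed_form(q):
    # Hypothetical helper: trace(expm(W * W)) for W = [[0, a], [b, 0]] is 2 * cosh(|a * b|).
    a, b = q
    return 2.0 * np.cosh(abs(a * b))


# Check the closed form against SciPy on a finite input.
q0 = np.array([3.0, 3.0])  # illustrative start, not the random one from the issue
W = np.array([[0.0, q0[0]], [q0[1], 0.0]])
assert np.isclose(np.trace(expm(W * W)), potential_closed_form(q0))

# Replay the three "q -= V" updates and record what each potential call receives.
inputs = []
q = q0.copy()
with np.errstate(over="ignore"):  # cosh overflows to inf on the second call
    for _ in range(3):
        inputs.append(q.copy())
        q = q - potential_closed_form(q)

for call, q_in in enumerate(inputs, start=1):
    print(f"call {call}: potential receives {q_in}")
```

From this start, the steps go as follows:

- The first update subtracts roughly 8103.
- The second call needs `cosh` of about 6.6e7, which overflows float64 to `inf`.
- The third call receives `[-inf, -inf]`, matching the observation above.

Other random starts may take more or fewer steps to overflow.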

We could consider raising an error on input like this. The question is where to do so. Because expm reduces to a matrix solve, I suspect this attempts a solve on a badly ill-conditioned matrix. Maybe we should raise errors in our solve routines more generally.
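To illustrate the "expm reduces to a solve" point: Padé-based matrix exponentials compute `exp(A) ≈ Q(A)^{-1} P(A)` for polynomials `P` and `Q`. The sketch below uses the lowest-order [1/1] approximant, with `P = I + A/2` and `Q = I - A/2`, purely as an illustration. It is not JAX's implementation, which uses higher-order terms plus scaling and squaring. `pade11_expm` is a hypothetical name.

```python
import jax.numpy as jnp
from jax.scipy.linalg import expm


def pade11_expm(A):
    # Illustrative [1/1] Pade approximant: exp(A) ~ (I - A/2)^{-1} (I + A/2).
    # This toy version only shows that the last step is a linear solve Q X = P.
    I = jnp.eye(A.shape[0], dtype=A.dtype)
    P = I + A / 2
    Q = I - A / 2
    return jnp.linalg.solve(Q, P)


# On a small, well-behaved matrix the toy approximant tracks expm closely.
A = jnp.array([[0.0, 0.01], [0.02, 0.0]])
err = jnp.max(jnp.abs(pade11_expm(A) - expm(A)))
print(err)

# With non-finite input like the one in this issue, the matrix that would be
# handed to the solve already contains inf.
A_bad = jnp.array([[0.0, jnp.inf], [jnp.inf, 0.0]])
Q_bad = jnp.eye(2) - A_bad / 2
print(bool(jnp.all(jnp.isfinite(Q_bad))))  # False
```

The sketch deliberately does not run the solve on `A_bad`. The point is only that once `A` contains `inf`, the system handed to the solver is already non-finite.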

Standard SciPy's error isn't very clear either, but at least the program halts.
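The SciPy line that fails computes how many times scaling and squaring halves the matrix before the Padé step, roughly `ceil(log2(norm / theta))`. The sketch below shows why non-finite input stops plain Python: `int()` refuses both NaN and infinity. It uses a hypothetical `n_squarings` helper and an illustrative `theta`, not SciPy's real constants.

```python
import numpy as np


def n_squarings(norm_estimate, theta):
    # Hypothetical helper with the same shape as the failing SciPy line:
    # the number of halvings s such that norm_estimate / 2**s drops below theta.
    return max(int(np.ceil(np.log2(norm_estimate / theta))), 0)


theta = 4.0  # illustrative threshold, not SciPy's actual constant
print(n_squarings(100.0, theta))  # log2(25) is about 4.64, so this prints 5

for bad in (np.nan, np.inf):
    try:
        n_squarings(bad, theta)
    except (ValueError, OverflowError) as err:
        # NaN raises ValueError and inf raises OverflowError, so Python halts here.
        print(f"{bad}: {type(err).__name__}: {err}")
```

Under `jit`, as the next comment points out, nothing can raise at a point like this. A non-finite value flows on into later steps instead of halting the program.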

Paging @shoyer who knows expm better. Thoughts?

Well, we can't raise an error from inside JIT (at least at present), but we
should certainly detect invalid input and return NaN rather than loop
infinitely (unless error checking is explicitly disabled for
performance reasons).

(Sent by email in reply to Roy Frostig's comment of Fri, Sep 4, 2020: https://github.com/google/jax/issues/4162#issuecomment-687264439)

Good point. We could return early in a jit-friendly way using lax.cond, but I'm wary of affecting performance and introducing new global flags. This might also obscure an underlying solve-related bug that will only reappear later.

From what I can tell, the "hanging loop" isn't directly in our codebase. At some point we pass the inf/nan values to a lower-level linalg routine. It may be the getrf in our lax.lu_p translation, which comes from LAPACK via jaxlib, or it may be an XLA solve op after that. It seems better to check and return early at that lower-level location, wherever it turns out to be.
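Two suggestions in this thread can be combined in a user-side wrapper: returning early in a jit-friendly way with `lax.cond`, and returning NaN for invalid input. Here is a minimal sketch; `guarded_expm` is a hypothetical name, not a JAX API.

```python
import jax
import jax.numpy as jnp
from jax import lax
from jax.scipy.linalg import expm


def guarded_expm(A):
    # Hypothetical wrapper (not a JAX API): if any entry of A is inf or NaN,
    # skip expm and return a NaN-filled matrix of the same shape and dtype.
    all_finite = jnp.all(jnp.isfinite(A))
    return lax.cond(
        all_finite,
        expm,                                 # branch taken for finite input
        lambda M: jnp.full_like(M, jnp.nan),  # branch taken for inf/NaN input
        A,
    )


guarded = jax.jit(guarded_expm)
good = guarded(jnp.array([[0.0, 1.0], [1.0, 0.0]]))  # [[cosh 1, sinh 1], [sinh 1, cosh 1]]
bad = guarded(jnp.array([[0.0, jnp.inf], [jnp.inf, 0.0]]))
print(good)
print(bad)  # all NaN
```

Under `jit`, `lax.cond` runs only the selected branch, so `expm` never sees the `inf` input. There are two caveats:

- Under `vmap` with a batched predicate, `lax.cond` becomes a select that evaluates both branches, so this guard would no longer keep `expm` away from bad input.
- A check like this masks rather than fixes whatever goes wrong further down in the solve. This is the concern raised above.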
