Jax: seg fault happens when running CPU code using GPU-supported jaxlib

Created on 23 Aug 2019 · 11 Comments · Source: google/jax

The following repro script, which is taken from #1239,

import jax
from jax.config import config; config.update('jax_platform_name', 'cpu')
import jax.numpy as np
from jax import random, lax, jit


def welford_covariance():
    def init_fn(size):
        return np.zeros(size), np.zeros(size), 0

    def update_fn(sample, state):
        mean, m2, n = state
        n = n + 1
        delta_pre = sample - mean
        mean = mean + delta_pre / n
        delta_post = sample - mean
        m2 = m2 + delta_pre * delta_post
        return mean, m2, n

    def final_fn(state):
        mean, m2, n = state
        cov = m2 / (n - 1)
        cov_inv_sqrt = np.sqrt(np.reciprocal(cov))
        return cov, cov_inv_sqrt

    return init_fn, update_fn, final_fn


def warmup_adapter():
    mm_init, mm_update, mm_final = welford_covariance()

    def init_fn(z, rng, mass_matrix_size):
        inverse_mass_matrix = np.ones(mass_matrix_size)
        mass_matrix_sqrt = inverse_mass_matrix
        mm_state = mm_init(mass_matrix_size)
        return (inverse_mass_matrix, mass_matrix_sqrt, mm_state, rng)

    def _update_at_window_end(z, rng_ss, state):
        inverse_mass_matrix, mass_matrix_sqrt, mm_state, rng = state
        inverse_mass_matrix, mass_matrix_sqrt = mm_final(mm_state)
        mm_state = mm_init(inverse_mass_matrix.shape[-1])
        return (inverse_mass_matrix, mass_matrix_sqrt, mm_state, rng)

    def update_fn(t, accept_prob, z, state):
        inverse_mass_matrix, mass_matrix_sqrt, mm_state, rng = state
        rng, rng_ss = random.split(rng)
        state = (inverse_mass_matrix, mass_matrix_sqrt, mm_state, rng)
        state = lax.cond(t < 10,
                         (z, rng_ss, state), lambda args: _update_at_window_end(*args),
                         state, lambda x: x)
        return state

    return init_fn, update_fn


wa_init, wa_update = warmup_adapter()
wa_update = jit(wa_update)  # uncomment this will make it fast

z = np.ones(3)
wa_state = wa_init(z, random.PRNGKey(0), mass_matrix_size=3)
import time
for t in range(10):
    tic = time.time()
    wa_state = wa_update(t, 0.1 * t, z, wa_state)
    print(time.time() - tic)

causes

[mutex.cc : 419] RAW: Lock blocking 0x7fbae0006640   @ 0x7fbbc9577eaf 0x7fbbc95785a6 0x7fbbc6a81236 0x7fbbc9576913 0x7fbbc9463f12 0x7fbbf463c421
[mutex.cc : 419] RAW: Unlock 0x7fbae0006640   @ 0x7fbbc6a80ae2 0x7fbbc9578088 0x7fbbc95785a6 0x7fbbc6a81236 0x7fbbc9576913 0x7fbbc9463f12 0x7fbbf463c421

using https://storage.googleapis.com/jax-releases/cuda100/jaxlib-0.1.23-cp36-none-linux_x86_64.whl.

This issue only started happening after the recent refactoring of JAX (the one related to removing tuples). The code runs fine if we change config.update('jax_platform_name', 'cpu') to config.update('jax_platform_name', 'gpu') or if we use the jaxlib from PyPI.
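For readers unfamiliar with the update_fn inside welford_covariance in the repro script, the following is a minimal pure-Python sketch of the same running mean / running sum-of-squares update, applied to scalars instead of JAX arrays. The function names (welford_init, welford_update, welford_final) and the sample values are illustrative assumptions, not part of the original script, and the sketch has nothing to do with the segfault itself.

```python
import statistics


def welford_init():
    # State is (mean, m2, n): running mean, running sum of squared
    # deviations from the mean, and number of samples seen so far.
    return 0.0, 0.0, 0


def welford_update(sample, state):
    # Same arithmetic as update_fn in the repro script, on plain floats.
    mean, m2, n = state
    n += 1
    delta_pre = sample - mean    # deviation from the old mean
    mean += delta_pre / n
    delta_post = sample - mean   # deviation from the updated mean
    m2 += delta_pre * delta_post
    return mean, m2, n


def welford_final(state):
    # Unbiased sample variance, matching m2 / (n - 1) in final_fn.
    _, m2, n = state
    return m2 / (n - 1)


samples = [2.0, 4.0, 4.0, 5.0, 7.0]  # illustrative data
state = welford_init()
for s in samples:
    state = welford_update(s, state)
variance = welford_final(state)

# The streaming result should agree with the two-pass sample variance.
print(variance, statistics.variance(samples))
```

The repro script applies the same update elementwise to arrays of size mass_matrix_size, so it tracks a per-dimension variance rather than a full covariance matrix.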

cc @mattjj @skye

bug

All 11 comments

One of the amazing XLA:{C,G}PU team members @cheshire tried to repro this internally (meaning against the most up-to-date XLA code) and the bug didn't appear. The optimistic explanation for that is this bug is already fixed and we just need to push out an updated jaxlib. The pessimistic explanation is that it's just hard to repro.

We'll update jaxlib wheels today and hope that fixes things.

Thanks @mattjj, that's great news!

@fehiepsi we just uploaded jaxlib 0.1.25 wheels. Does it still repro?

@mattjj The issue still happens, though I don't see the mutex.cc messages anymore. When I uncomment the line wa_update = jit(wa_update), the loop runs the first step and then gets stuck there (all GPU memory is allocated). I tried restarting my computer and creating a new conda environment to test, but the issue still happens. Are you able to replicate the issue on your machine?

It turns out that we accidentally updated to an XLA version that is still a few days out of date, so we're going to update it one more time just to rule out that category of explanations. (I haven't tried to repro externally on the new jaxlib wheels yet, but will with 0.1.26.)

This is at the top of our priority list, but just so I understand how many fires we should try to light under people, is this currently blocking your work or your users'?

Hi @mattjj, please don't worry about it. We currently pin the jax version to 0.1.41, so this issue isn't blocking our work. :)

We just pushed jaxlib 0.1.26 wheels, but unfortunately the bug is still there: I was able to repro using your script on a fresh cloud VM.

We'll follow up with the XLA team and see if we can track this down!

I think I spotted it! I had been thinking this was a jaxlib issue, but I tried rolling back jax a bit and the segfault went away. I realized it must have been xla.py accidentally mixing memory pointers on one device with those on another, or something like that, so I spotted a bug fixed in 434d175 (another of mine, I'm afraid, from the big rewrite!).

I'm not sure why we weren't able to repro this internally. I'm working on a CI test we can use for this to avoid a regression, but right now I'm able to verify that this fixes the issue on a GPU cloud VM.

This test can catch it but only if run in isolation (i.e. not with all the other tests, because of how platform stuff is cached I think):

  def test_instantiate_device_constant_set_platform(self):
    # cf. issue 1241
    try:
      prev_platform = FLAGS.jax_platform_name
      FLAGS.jax_platform_name = "cpu"
      jit(lambda x: x)(np.zeros(2))  # doesn't segfault
    finally:
      FLAGS.jax_platform_name = prev_platform
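
Since the test above reportedly only catches the bug when run in isolation, one way to get that isolation from inside a larger suite is to run the check in a fresh Python interpreter and inspect the exit code. The following is a rough sketch using only the standard library. The helper name run_isolated and both snippets are hypothetical placeholders: the crashing snippet simulates a segfault by sending a signal to itself, and a real check would instead set jax_platform_name to "cpu" and call jit(lambda x: x)(np.zeros(2)) as in the test above.

```python
import subprocess
import sys


def run_isolated(snippet: str, timeout: float = 60.0) -> int:
    """Run `snippet` in a fresh Python interpreter and return its exit code.

    A fresh process starts with no cached state from earlier tests.
    On POSIX, a negative return code -N means the child was killed by
    signal N (for example, -11 for SIGSEGV).
    """
    result = subprocess.run(
        [sys.executable, "-c", snippet],
        capture_output=True,
        timeout=timeout,
    )
    return result.returncode


# Placeholder snippets (assumptions for illustration only).
HEALTHY = "print('ok')"
CRASHING = "import os, signal; os.kill(os.getpid(), signal.SIGSEGV)"

healthy_code = run_isolated(HEALTHY)
crashing_code = run_isolated(CRASHING)
print("healthy exit code:", healthy_code)
print("crashing exit code:", crashing_code)
```

A test built this way would assert that the exit code is 0. The trade-off is interpreter start-up time per check, which is why this is better suited to a handful of regression tests than to a whole suite.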

Thanks @mattjj ! That should be the reason.

Uploaded jax 0.1.43 to PyPI with this fix.
