Jax: np.unwrap runtime explodes

Created on 15 Mar 2020 · 7 comments · Source: google/jax

Dear jax team,

I implemented np.unwrap by following the numpy source and replacing the in-place-modification bits with np.where and np.concatenate.

While the values are correct, the runtime explodes for x.size > 1e7 and is generally much worse than the numpy version. See the repro below.

Now, I don't want to waste your time with micro-optimizations for my specific code, but are there some general performance tips here? I guess the where and concatenate are problematic, but I wouldn't know how to improve on them within the JAX framework. Sorry for the code being so cryptic...

Repro:

import jax
import jax.numpy as np
import numpy as onp


@jax.partial(jax.jit, static_argnums=1)
def unwrap(p, axis):
    # Phase differences along the given axis.
    dd = np.diff(p, axis=axis)

    # Wrap the differences into [-pi, pi), mapping -pi back to pi
    # where the original difference was positive (as onp.unwrap does).
    ddmod = np.mod(dd + np.pi, 2 * np.pi) - np.pi
    ddmod = np.where(
        np.isclose(ddmod, -np.pi) & (dd > 0),
        np.pi,
        ddmod)

    # No correction wherever the jump is smaller than pi.
    ph_correct = np.where(
        np.abs(dd) < np.pi,
        0,
        ddmod - dd)

    # Accumulate the corrections and reattach the first element.
    up = np.concatenate((
        jax.lax.slice_in_dim(p, 0, 1, axis=axis),
        jax.lax.slice_in_dim(p, 1, None, axis=axis) + np.cumsum(ph_correct, axis=axis)
    ), axis=axis)

    return up


# OKAY: small input, result matches numpy
x = onp.random.randn(1000) * 10
assert onp.allclose(
    onp.unwrap(x),
    unwrap(x, 0),
    atol=1e-3, rtol=1e-3
)

# NOT OKAY: runtime explodes at this size
x = onp.random.randn(int(1e7)) * 10
unwrap(x, 0).block_until_ready()
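
For reference, a rough timing harness (my addition; absolute numbers will vary by machine and backend, and the first jitted call includes compilation):

import time

t0 = time.perf_counter()
onp.unwrap(x)
print("numpy:", time.perf_counter() - t0)

t0 = time.perf_counter()
unwrap(x, 0).block_until_ready()  # block so we time the actual computation
print("jax:  ", time.perf_counter() - t0)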


I think that the problem is in np.cumsum. The following:

size = int(1e7)
big_x = onp.random.randn(size)
np.cumsum(big_x, axis=0)

is much slower than onp.cumsum. I have checked on CPU so far.
@hawkinsp do you have any advice on what I should try next?

I think the issue is that the algorithm implemented by cumsum is quadratic time if XLA implements it naively:
https://github.com/google/jax/blob/6b157ff91cd9b0030e62b43e857fcecc32cfdf8b/jax/numpy/lax_numpy.py#L1544

I think only the TPU backend does something smarter than the naive implementation, so it's unsurprising that this takes forever with an input size of 1e7 on CPU and GPU.
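
For intuition, here's a minimal sketch (my illustration, not necessarily the exact implementation behind the link above) of a window-reduction cumsum: each of the n output positions re-reduces a window of up to n inputs, so a naive lowering does O(n^2) work.

from jax import lax
import jax.numpy as jnp

def windowed_cumsum(x):
    # Inclusive prefix sum of a 1D array via reduce_window: output i
    # reduces the window x[0:i+1], which is quadratic total work unless
    # the backend lowers it to a real scan.
    n = x.shape[0]
    return lax.reduce_window(
        x, jnp.zeros((), x.dtype), lax.add,
        window_dimensions=(n,), window_strides=(1,),
        padding=((n - 1, 0),))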

My personal temptation here would be to try implementing a Blelloch-style parallel sum scan algorithm using gather and scatter-add. There's a good blog post explaining them here: https://developer.nvidia.com/gpugems/gpugems3/part-vi-gpu-computing/chapter-39-parallel-prefix-sum-scan-cuda

I'm also curious what happens if you autodiff the Blelloch algorithm. Currently cumprod isn't arbitrarily differentiable, but I don't see any fundamental reason one cannot differentiate through the Blelloch algorithm at a 2n log(n) space cost, which doesn't seem unreasonable for reverse-mode autodiff.

Thank you for your insight and the link to the algorithm!

Here's a version of cumprod with asymptotically better complexity, although it only works for power-of-2 array sizes:

import math

import jax
import jax.numpy as jnp
from jax import lax


@jax.jit
def cumprod_v2(z):
  n = len(z)
  log2 = int(math.log2(n - 1))

  # Up-sweep: repeatedly combine adjacent pairs, stashing the
  # even-indexed elements of each level for the down-sweep.
  zs = []
  for _ in range(log2):
    z1 = lax.slice(z, (0,), (len(z),), (2,))
    zs.append(z1)
    z2 = lax.slice(z, (1,), (len(z),), (2,))
    z = z1 * z2
  zs.append(lax.slice(z, (0,), (len(z),), (2,)))

  # Down-sweep: interleave the partial prefixes back in. The zero
  # padding makes the addition act as an interleave, since z1 and z2
  # are nonzero at disjoint positions.
  dtype = jnp.dtype(z.dtype).type
  z = jnp.array([1], dtype=dtype)
  for w in reversed(zs):
    z1 = lax.pad(z, dtype(0), ((0, 1, 1),))
    z2 = lax.pad(z, dtype(0), ((1, 0, 1),))
    w = lax.pad(w, dtype(0), ((1, 0, 1),))
    z = z1 + z2 * w
  return z
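
As a quick sanity check (my addition): on a power-of-2 input, this returns the exclusive prefix product, i.e. out[0] == 1 and the result is offset by one relative to onp.cumprod:

x = onp.random.rand(2 ** 8) + 0.5
out = cumprod_v2(x)
assert out[0] == 1
assert onp.allclose(out[1:], onp.cumprod(x)[:-1], rtol=1e-4)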

It also has the advantage of most likely being much easier to differentiate than the current implementation.

cumsum is very similar; see the sketch below.
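
A minimal sketch of that cumsum analogue (my addition, following the same pattern: the combine becomes addition and the down-sweep starts from the additive identity; like the version above, it returns the exclusive scan and assumes power-of-2 sizes):

@jax.jit
def cumsum_v2(z):
  n = len(z)
  log2 = int(math.log2(n - 1))

  # Up-sweep: pairwise sums instead of pairwise products.
  zs = []
  for _ in range(log2):
    z1 = lax.slice(z, (0,), (len(z),), (2,))
    zs.append(z1)
    z2 = lax.slice(z, (1,), (len(z),), (2,))
    z = z1 + z2
  zs.append(lax.slice(z, (0,), (len(z),), (2,)))

  # Down-sweep: even positions keep the coarse prefix, odd positions
  # add the stashed element; zero padding again acts as an interleave.
  dtype = jnp.dtype(z.dtype).type
  z = jnp.array([0], dtype=dtype)
  for w in reversed(zs):
    z1 = lax.pad(z, dtype(0), ((0, 1, 1),))
    z2 = lax.pad(z, dtype(0), ((1, 0, 1),))
    w = lax.pad(w, dtype(0), ((1, 0, 1),))
    z = z1 + z2 + w
  return z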

I believe this is now fixed at head. Let me know how it goes!

It works and is super fast, thank you!
