Jax: Implementation of np.interp

Created on 25 Jul 2020 · 19 comments · Source: google/jax

Hi,

As mentioned here, I am happy to leverage some code I'm putting together for my own project (sorted interpolation) to provide JAX with an implementation of np.interp.

Modulo some high-level checks on the dimensionality of the arrays, plus the "period" argument (which is trivial to port to JAX), the code below is my best attempt so far.

I've checked that the results are almost equal to NumPy's and that the gradients match numerical ones. Do the org members/collaborators have any problem with the approach (or ideas to make it more JAX-y), or should I iron it out and send it over as a PR?

Adrien

from jax import jit, lax
import jax.numpy as jnp


@jit
def sorted_interp(x, xp, fp):
    # Assumes x is sorted in ascending order; xp, fp are the knots.
    x = jnp.atleast_1d(x)
    n = xp.shape[0]

    xp_0 = xp[0]
    fp_0 = fp[0]

    def inner_fun(args):
        x_i, j = args

        def cond_fun(state):
            is_continuing, *_ = state
            return is_continuing

        def body_fun(state):
            # state = (is_continuing, use_next_fp_j, curr_j, curr_xp_j, curr_fp_j)
            _, _, curr_j, curr_xp_j, curr_fp_j = state

            next_xp_j = xp[curr_j + 1]
            next_fp_j = fp[curr_j + 1]

            cond = x_i > next_xp_j

            def cond_true(_):
                # x_i lies beyond the next knot: either we have reached the
                # last knot (stop and clamp to fp[-1]) or we keep walking right.
                inner_cond = curr_j + 1 == n - 1

                def fun_true(_): return False, True, curr_j, next_xp_j, next_fp_j
                def fun_false(_): return True, False, curr_j + 1, next_xp_j, next_fp_j

                return lax.cond(inner_cond, fun_true, fun_false, None)

            def cond_false(_):
                # Bracketing interval found. If the two knots coincide
                # (repeated values in xp), the interpolation formula would
                # divide by zero, so use fp[j + 1] directly instead.
                inner_cond = curr_xp_j == next_xp_j

                def fun_true(_): return False, True, curr_j, next_xp_j, next_fp_j
                def fun_false(_): return False, False, curr_j, next_xp_j, next_fp_j

                return lax.cond(inner_cond, fun_true, fun_false, None)

            return lax.cond(cond, cond_true, cond_false, None)

        _, use_next_fp_j, new_j, *_ = lax.while_loop(
            cond_fun, body_fun, (True, False, j, xp[j], fp[j]))
        # We don't compute the result inside the loop, to allow for seamless
        # backward-mode differentiability.
        return new_j, lax.cond(
            use_next_fp_j,
            lambda _: fp[new_j + 1],
            lambda _: fp[new_j] + (fp[new_j + 1] - fp[new_j])
                                  * (x_i - xp[new_j]) / (xp[new_j + 1] - xp[new_j]),
            None)

    def scan_body(j, x_i):
        # Left of the first knot: clamp to fp[0]. Otherwise resume the search
        # from the interval found for the previous (smaller) x_i.
        return lax.cond(x_i <= xp_0, lambda *_: (j, fp_0), inner_fun, (x_i, j))

    _, f = lax.scan(scan_body, 0, x)
    return f


@jit
def _interp(x, xp, fp):
    x = jnp.atleast_1d(x)
    # Sort x, interpolate, then scatter the results back to the original order.
    argsort = jnp.argsort(x)
    sorted_res = sorted_interp(x[argsort], xp, fp)
    return jnp.empty_like(sorted_res).at[argsort].set(sorted_res)


def interp(x, xp, fp):
    # Do the input checks as in the NumPy version.
    return _interp(x, xp, fp)
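
For what it's worth, a minimal version of the checks described above could look like this (a sketch only: the sample points and tolerances are illustrative, not the ones I actually used):

import numpy as np
import jax.numpy as jnp
from jax import grad

xp = jnp.linspace(0.0, 10.0, 50)
fp = jnp.sin(xp)
x = jnp.linspace(0.5, 9.5, 200)

# Values should be almost equal to NumPy's reference implementation.
np.testing.assert_allclose(interp(x, xp, fp), np.interp(x, xp, fp), atol=1e-5)

# The gradient should match a numerical (central-difference) estimate,
# taken at a point that does not sit on a knot.
f = lambda t: interp(t, xp, fp)[0]
eps = 1e-3
numerical = (f(jnp.array([2.3 + eps])) - f(jnp.array([2.3 - eps]))) / (2 * eps)
np.testing.assert_allclose(grad(f)(jnp.array([2.3]))[0], numerical, rtol=1e-2)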
Labels: enhancement, question


This looks great! Note that there are a couple of additional arguments to np.interp that we should probably handle. I took a stab at a jax.numpy.interp a while back, but it stalled because I had trouble matching the behavior of np.interp in corner cases (period boundaries and repeated values, in particular).

I suspect your approach is more efficient for larger x arrays, but FWIW here's what I came up with:

def interp(x, xp, fp, left=None, right=None, period=None):
  x, xp, fp = map(np.asarray, (x, xp, fp))
  if period:
    x = x % period
    xp = xp % period
    i = np.argsort(xp)
    xp = xp[i]
    fp = fp[i]
    xp = np.concatenate([xp[-1:] - period, xp, xp[:1] + period])
    fp = np.concatenate([fp[-1:], fp, fp[1:]])

  i = np.clip(np.searchsorted(xp, x, side='right'), 1, len(xp) - 1)
  f = (fp[i - 1] *  (xp[i] - x) + fp[i] * (x - xp[i - 1])) / (xp[i] - xp[i - 1])

  if not period:
    if left is None:
      left = fp[0]
    if right is None:
      right = fp[-1]
    f = np.where(x < xp[0], left, f)
    f = np.where(x > xp[-1], right, f)
  return f

I concur about the additional arguments, but they are trivial to implement and I didn't want to draw attention away from the big piece of inner logic.

Also, there is probably a lot of commenting to do in there too... it's probably a bit complex at first sight.

My main concern with the sorted_interp approach is the maintenance burden of adding such a complex implementation. Ignoring optional parameters & edge cases, the core of the interpolation using searchsorted is just a few lines of code:

def interp(x, xp, fp):
  i = np.clip(np.searchsorted(xp, x, side='right'), 1, len(xp) - 1)
  return (fp[i - 1] * (xp[i] - x) + fp[i] * (x - xp[i - 1])) / (xp[i] - xp[i - 1])

If we're replacing this with a substantially more complicated implementation, we should make sure we're getting a commensurate performance increase. Another thought: if the searchsorted version proves too slow, maybe focusing our optimization effort on searchsorted would give more bang for the buck?

I very much agree with that. The only reason I offered it in the first place is that I needed the sorted version for myself anyway.
For the record, the NumPy core implementation is actually a hybrid between your method and mine: it uses a divide-and-conquer trick like the one in the np.searchsorted implementation, but it also leverages (at first sight) potential regular spacing in xp _and_ potential sorting in x.

Hi all,

While we're discussing the topic of interpolation, I wanted to throw in a link to this repository that has developed some pretty slick JAX-compatible interpolators that may be of interest: https://github.com/DifferentiableUniverseInitiative/jax_cosmo/blob/master/jax_cosmo/scipy/interpolate.py

In addition to np.interp, they have some more sophisticated interpolators like scipy.interpolate.InterpolatedUnivariateSpline.

There's an issue ticket in jax_cosmo where discussion of PR'ing into JAX was started: https://github.com/DifferentiableUniverseInitiative/jax_cosmo/issues/29#issuecomment-657019856

My main concern with the sorted_interp approach is the maintenance burden of adding such a complex implementation. [...]

I just did a rough test of your method (removing the left, right, and period arguments), and it is indeed substantially (about 10 times) slower. I suspect this is due to the way searchsorted is implemented in JAX, as it is "only" 3 to 5 times slower in raw NumPy, depending on the input size.

Note that np.interp is twice as fast as my implementation (probably due to the index "guessing" it does, and some contiguity it manages to ensure by not sorting the input array, but I haven't checked exactly).
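
For reference, a rough timing harness along those lines might look like this (the array sizes and repetition count are made up for illustration, and interp here is whichever JAX version is being tested):

import numpy as np
import jax.numpy as jnp
from timeit import timeit

xp = jnp.sort(jnp.array(np.random.rand(1_000)))
fp = jnp.sin(xp)
x = jnp.array(np.random.rand(10_000))

interp(x, xp, fp).block_until_ready()  # warm-up (triggers compilation if jitted)
print(timeit(lambda: interp(x, xp, fp).block_until_ready(), number=100))
print(timeit(lambda: np.interp(np.asarray(x), np.asarray(xp), np.asarray(fp)),
             number=100))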


Actually, I misspoke: your method is faster than mine when jitted, in contrast with the raw NumPy results.
Maybe my JAX loops are inefficient, then? Do you know how I can inspect the generated code to see what's going on under the hood?

Interesting, thanks for doing the benchmarks!

In general, loopy code in XLA will not be as fast as array/matrix operations, and the extent of the slowdown will vary depending on the accelerator (the hit is not so bad on CPU, but on GPU or TPU loopy code can be extremely slow).

One way to get a sense of the XLA code that's being generated is via the make_jaxpr function.
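
For example, something like this (jax.make_jaxpr is the public API; the input shapes are just illustrative):

import jax
import jax.numpy as jnp

xp = jnp.linspace(0.0, 1.0, 10)
fp = jnp.sin(xp)
x = jnp.linspace(0.1, 0.9, 5)

# Prints the jaxpr traced from sorted_interp; the while/cond/scan
# control-flow primitives appear explicitly in the (nested) output.
print(jax.make_jaxpr(sorted_interp)(x, xp, fp))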

Note that searchsorted itself is currently implemented as a while_loop, so it will suffer from this as well. I've been thinking of experimenting with changing this to a scan over binary search depth, which could potentially yield a vast improvement in efficiency for multiple searches.
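
As a rough illustration of that idea (a sketch under my own assumptions, not the proposed implementation), here is a fixed-depth binary search expressed as a lax.scan over the search depth, with side='right' semantics:

import numpy as np
import jax.numpy as jnp
from jax import lax

def searchsorted_right(a, v):
    # For each query in v, find the insertion index into sorted `a`
    # (side='right'). Every scan step halves all search intervals at
    # once, so ceil(log2(n)) + 1 steps suffice for any query.
    n = a.shape[0]
    depth = int(np.ceil(np.log2(n))) + 1

    def step(carry, _):
        lo, hi = carry
        done = lo >= hi
        mid = (lo + hi) // 2
        go_right = a[jnp.clip(mid, 0, n - 1)] <= v  # equal values go right
        lo = jnp.where(done, lo, jnp.where(go_right, mid + 1, lo))
        hi = jnp.where(done, hi, jnp.where(go_right, hi, mid))
        return (lo, hi), None

    lo0 = jnp.zeros(v.shape, dtype=jnp.int32)
    hi0 = jnp.full(v.shape, n, dtype=jnp.int32)
    (lo, _), _ = lax.scan(step, (lo0, hi0), None, length=depth)
    return lo

On a small example, searchsorted_right(jnp.array([1., 3., 5.]), jnp.array([2., 3., 6.])) gives [1, 2, 3], matching np.searchsorted with side='right'.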

Note that searchsorted itself is currently implemented as a while_loop, so it will suffer from this as well. I've been thinking of experimenting with changing this to a scan over binary search depth, which could potentially yield a vast improvement in efficiency for multiple searches.

I guess it isn't suffering from it as much, since each of the individual loops it runs is shallower than my single one.

I expect lax.scan is also going to suffer from the slowdown, then?

I made some searchsorted improvements here: #3873

So when using your method, the repeated values bit is easy to fix using either:

  • np.unique with return_index=True on xp at the very beginning
  • the following
i = np.clip(np.searchsorted(xp, x, side='right'), 1, len(xp) - 1)
xp_i = xp[i]
xp_i_1 = xp[i-1]
fp_i = fp[i]
fp_i_1 = fp[i-1]
f = np.where(xp_i > xp_i_1, (fp_i_1 * (xp_i - x) + fp_i * (x - xp_i_1)) / (xp_i - xp_i_1), fp_i)

I'm not sure which is the most efficient (how does np.unique behave on GPU/TPU? Its output shape depends on the values...)
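
For concreteness, the first option would be something like this (just a sketch; the value-dependent output shape is exactly what makes it awkward under jit):

# Deduplicate the knots before interpolating.
xp_unique, idx = np.unique(xp, return_index=True)  # output shape depends on values
fp_unique = fp[idx]
f = interp(x, xp_unique, fp_unique)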

I don't understand what the problem with the period boundaries would be. Seems like everything should work just fine to me.

Also, a very minor point, but this would probably be a bit more readable (and saves a multiplication, though that's secondary):

fp_i_1 + (x - xp_i_1) * (fp_i - fp_i_1) / (xp_i - xp_i_1)

I don't understand what the problem with the period boundaries would be. Seems like everything should work just fine to me.

For example:

xp = np.linspace(0, 10, 10)
fp = np.sin(xp)
x = np.linspace(0, 10, 100)

y1 = np.interp(x, xp, fp, period=5)
y2 = interp(x, xp, fp, period=5)

import matplotlib.pyplot as plt
plt.plot(xp % 5, fp, '.k', label='input')
plt.plot(x % 5, y1, '.', label='np.interp')
plt.plot(x % 5, y2, '.', label='interp')
plt.legend();

[figure "periodic": input points, np.interp, and interp plotted with period=5; the two interpolants visibly disagree]


I think you actually made a typo when you copy-pasted that part of the code from the numpy function :)

You wrote

xp = np.concatenate([xp[-1:] - period, xp, xp[:1] + period])
fp = np.concatenate([fp[-1:], fp, fp[1:]])

In the numpy code it's

xp = np.concatenate([xp[-1:] - period, xp, xp[0:1] + period])
fp = np.concatenate([fp[-1:], fp, fp[0:1]])

(note that I don't know why they wrote [0:1] instead of [:1])

When I fix it, it works just fine:

[screenshot: after the fix, interp matches np.interp on the periodic example]
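
For completeness, here is the earlier sketch with just that one fix applied (everything else unchanged):

def interp(x, xp, fp, left=None, right=None, period=None):
  x, xp, fp = map(np.asarray, (x, xp, fp))
  if period:
    x = x % period
    xp = xp % period
    i = np.argsort(xp)
    xp = xp[i]
    fp = fp[i]
    xp = np.concatenate([xp[-1:] - period, xp, xp[:1] + period])
    fp = np.concatenate([fp[-1:], fp, fp[:1]])  # fp[:1] rather than fp[1:]

  i = np.clip(np.searchsorted(xp, x, side='right'), 1, len(xp) - 1)
  f = (fp[i - 1] * (xp[i] - x) + fp[i] * (x - xp[i - 1])) / (xp[i] - xp[i - 1])

  if not period:
    if left is None:
      left = fp[0]
    if right is None:
      right = fp[-1]
    f = np.where(x < xp[0], left, f)
    f = np.where(x > xp[-1], right, f)
  return f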

If I'd copy-pasted, there wouldn't be a typo :grin:

Another set of eyes is always helpful – thanks!

I've prepared a PR with a searchsorted-based interp in #3949, because we've had some feature requests for it.

@AdrienCorenflos – Please consider it the reference implementation; if and when it's merged, feel free to prepare a PR improving on it. Hopefully the test suite that is part of that PR will be useful!

@jakevdp solved the above

I am leaving this here for reference: there might be a way to leverage an associative scan to implement this more efficiently:

https://www.researchgate.net/publication/225920730_A_parallel_method_for_fast_and_practical_high-order_newton_interpolation

Caveat being that I don't think it has been done in any mainstream library.

Adrien
