Hi,
As mentioned here, I am happy to leverage some code I'm putting together for my own project (sorted interpolation) to provide JAX with an implementation of np.interp.
Modulo some high-level checks on the dimensionality of the arrays, plus the "period" argument (which is trivial to port to JAX), the code below is my best attempt so far.
I've tested that the results are almost equal to the NumPy ones, and that the gradients match numerical ones. Do the org members/collaborators have any problem with the approach (or ideas to make it more JAX-y), or should I iron it out, put it in a PR, and send it over?
Adrien
import jax.numpy as jnp
from jax import jit, lax


@jit
def sorted_interp(x, xp, fp):
    # Assumes x is sorted in ascending order.
    x = jnp.atleast_1d(x)
    n = xp.shape[0]
    xp_0 = xp[0]
    fp_0 = fp[0]

    def inner_fun(args):
        x_i, j = args

        def cond_fun(state):
            is_continuing, *_ = state
            return is_continuing

        def body_fun(state):
            _, _, curr_j, curr_xp_j, curr_fp_j = state
            next_xp_j = xp[curr_j + 1]
            next_fp_j = fp[curr_j + 1]
            cond = x_i > next_xp_j

            def cond_true(_):
                # x_i lies beyond the next knot: either we have reached the
                # last knot (stop and use its value) or we keep scanning right.
                inner_cond = curr_j + 1 == n - 1

                def fun_true(_):
                    return False, True, curr_j, next_xp_j, next_fp_j

                def fun_false(_):
                    return True, False, curr_j + 1, next_xp_j, next_fp_j

                return lax.cond(inner_cond, fun_true, fun_false, None)

            def cond_false(_):
                # x_i falls in the current interval: stop, and short-circuit
                # to the next knot's value when x_i lands exactly on it.
                inner_cond = x_i == next_xp_j

                def fun_true(_):
                    return False, True, curr_j, next_xp_j, next_fp_j

                def fun_false(_):
                    return False, False, curr_j, next_xp_j, next_fp_j

                return lax.cond(inner_cond, fun_true, fun_false, None)

            return lax.cond(cond, cond_true, cond_false, None)

        _, use_next_fp_j, new_j, *_ = lax.while_loop(
            cond_fun, body_fun, (True, False, j, xp[j], fp[j]))
        # We don't compute the result inside the loop to allow for seamless
        # backward-mode differentiability.
        return new_j, lax.cond(
            use_next_fp_j,
            lambda _: fp[new_j + 1],
            lambda _: fp[new_j] + (fp[new_j + 1] - fp[new_j]) * (x_i - xp[new_j]) / (xp[new_j + 1] - xp[new_j]),
            None)

    def body_fun(j, x_i):
        return lax.cond(x_i <= xp_0, lambda *_: (j, fp_0), inner_fun, (x_i, j))

    _, f = lax.scan(body_fun, 0, x)
    return f


@jit
def _interp(x, xp, fp):
    x = jnp.atleast_1d(x)
    argsort = jnp.argsort(x)
    sorted_res = sorted_interp(x[argsort], xp, fp)
    return jnp.empty_like(sorted_res).at[argsort].set(sorted_res)


def interp(x, xp, fp):
    # Do the checks like in the numpy version
    return _interp(x, xp, fp)
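For illustration, a quick toy check (made-up values; should agree with np.interp):

xp = jnp.array([0.0, 1.0, 2.0])
fp = jnp.array([0.0, 2.0, 3.0])
print(interp(jnp.array([1.5, 0.5]), xp, fp))  # expected: [2.5, 1.0]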
This looks great! Note that there are a couple of additional arguments to np.interp that we should probably handle. I took a stab a while back at a jax.numpy.interp, but it stalled because I had trouble matching the behavior of np.interp in corner cases (period boundaries & repeated values, in particular).
I suspect your approach is more efficient for larger x arrays, but FWIW here's what I came up with:
import numpy as np

def interp(x, xp, fp, left=None, right=None, period=None):
    x, xp, fp = map(np.asarray, (x, xp, fp))
    if period:
        x = x % period
        xp = xp % period
        i = np.argsort(xp)
        xp = xp[i]
        fp = fp[i]
        xp = np.concatenate([xp[-1:] - period, xp, xp[:1] + period])
        fp = np.concatenate([fp[-1:], fp, fp[1:]])
    i = np.clip(np.searchsorted(xp, x, side='right'), 1, len(xp) - 1)
    f = (fp[i - 1] * (xp[i] - x) + fp[i] * (x - xp[i - 1])) / (xp[i] - xp[i - 1])
    if not period:
        if left is None:
            left = fp[0]
        if right is None:
            right = fp[-1]
        f = np.where(x < xp[0], left, f)
        f = np.where(x > xp[-1], right, f)
    return f
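For instance, a quick sanity check in the non-periodic case (toy values):

xp = np.array([1.0, 2.0, 3.0])
fp = np.array([3.0, 2.0, 0.0])
print(interp(np.array([1.5, 2.5]), xp, fp))     # expected: [2.5, 1.0]
print(np.interp(np.array([1.5, 2.5]), xp, fp))  # expected: [2.5, 1.0]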
I concur about the additional arguments, but they are trivially implemented and I didn't want to take attention away from the big piece of inner logic.
Also, there is probably a lot of commenting to do in there too... It's probably a bit complex at first sight.
My main concern with the sorted_interp approach is the maintenance burden of adding such a complex implementation. Ignoring optional parameters & edge cases, the core of the interpolation using searchsorted is just a few lines of code:
def interp(x, xp, fp):
    i = clip(searchsorted(xp, x, side='right'), 1, len(xp) - 1)
    return (fp[i - 1] * (xp[i] - x) + fp[i] * (x - xp[i - 1])) / (xp[i] - xp[i - 1])
If we're replacing this with a substantially more complicated implementation, we should make sure we're getting a commensurate performance increase. Another thought: if the searchsorted version proves too slow, maybe focusing our optimization effort on searchsorted would give more bang for the buck?
I very much agree with that. The only reason I offered it in the first place is that I needed the sorted one for myself anyway.
For the record, the NumPy core implementation is actually a hybrid of your method and mine: they use a divide-and-conquer trick like the one you would find in the np.searchsorted implementation, but one that also leverages (at first sight) potential regular spacing in xp _and_ potential sorting in x.
Hi all,
While we're discussing the topic of interpolation, I wanted to throw in a link to this repository that has developed some pretty slick JAX-compatible interpolators that may be of interest: https://github.com/DifferentiableUniverseInitiative/jax_cosmo/blob/master/jax_cosmo/scipy/interpolate.py
In addition to np.interp, they have some more sophisticated interpolators like scipy.interpolate.InterpolatedUnivariateSpline.
There's an issue ticket in jax_cosmo where discussion of PR'ing into JAX was started: https://github.com/DifferentiableUniverseInitiative/jax_cosmo/issues/29#issuecomment-657019856
I just did a rough test of your method (removing the left, right, and period arguments), and it does indeed seem substantially (10 times) slower. I would expect that this is due to the way searchsorted is implemented in JAX, as it is "only" between 3 and 5 times slower in raw NumPy, depending on the input size.
Note that np.interp is twice as fast as mine (probably due to the index "guessing" they do, and some contiguity they manage to ensure by not sorting the input array, but I haven't checked exactly).
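A sketch of the kind of timing harness this involves (shapes and values are illustrative, and it assumes the sorted_interp above is in scope; block_until_ready() matters because JAX dispatches asynchronously):

import timeit
import numpy as np
import jax.numpy as jnp

xp = jnp.linspace(0.0, 1.0, 1_000)
fp = jnp.sin(xp)
x = jnp.sort(jnp.asarray(np.random.rand(10_000)))

sorted_interp(x, xp, fp).block_until_ready()  # warm-up: trigger compilation
t = timeit.timeit(lambda: sorted_interp(x, xp, fp).block_until_ready(), number=10)
print(f"{t / 10 * 1e3:.2f} ms per call")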
Actually, I just said something wrong: your method is faster than mine when jitted, in contradiction with the raw NumPy results.
Maybe my JAX loops are inefficient then? Do you know how I can check the generated code to see what's going on under the hood?
Interesting, thanks for doing the benchmarks!
In general, loopy code in XLA will not be as fast as array/matrix operations, and the extent of the slowdown will vary depending on the accelerator (the hit is not so bad on CPU, but on GPU or TPU loopy code can be extremely slow).
One way to get a sense of the XLA code that's being generated is via the make_jaxpr function.
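For example (a minimal sketch with a stand-in function, not your interp):

import jax
import jax.numpy as jnp

def f(x):
    return jnp.sin(x) ** 2

# Prints the jaxpr (JAX's intermediate representation) traced from f,
# which shows the primitive operations XLA will be asked to compile.
print(jax.make_jaxpr(f)(jnp.arange(3.0)))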
Note that searchsorted itself is currently implemented as a while_loop, so it will suffer from this as well. I've been thinking of experimenting with changing this to a scan over binary search depth, which could potentially yield a vast improvement in efficiency for multiple searches.
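Roughly, the idea would look something like this (just a sketch of the approach, not JAX's actual implementation; searchsorted_scan is a made-up name): run a fixed number of bisection steps, one scan iteration per level of the binary search, updating the bounds for all queries in lockstep.

import numpy as np
import jax.numpy as jnp
from jax import lax

def searchsorted_scan(a, v):
    # Batched equivalent of np.searchsorted(a, v, side='left'),
    # expressed as a scan over binary-search depth.
    n = a.shape[0]
    lo = jnp.zeros(v.shape, dtype=jnp.int32)
    hi = jnp.full(v.shape, n, dtype=jnp.int32)
    # Enough halvings to shrink any interval of length n to zero.
    n_steps = max(1, int(np.ceil(np.log2(n))) + 1)

    def step(carry, _):
        lo, hi = carry
        mid = (lo + hi) // 2
        active = lo < hi
        go_right = active & (a[jnp.clip(mid, 0, n - 1)] < v)
        lo = jnp.where(go_right, mid + 1, lo)
        hi = jnp.where(active & ~go_right, mid, hi)
        return (lo, hi), None

    (lo, _), _ = lax.scan(step, (lo, hi), None, length=n_steps)
    return lo

# e.g. searchsorted_scan(jnp.array([0., 1., 2.]), jnp.array([0.5, 2.5])) -> [1, 3]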
I guess it doesn't suffer from it as much, since each of its individual loops is shallower than my single deep one.
I expect lax.scan is also going to suffer from the slowdown then?
I made some searchsorted improvements here: #3873
So when using your method, the repeated-values bit is easy to fix using either np.unique with return_index=True on xp at the very beginning, or a guard in the interpolation itself:

i = np.clip(np.searchsorted(xp, x, side='right'), 1, len(xp) - 1)
xp_i = xp[i]
xp_i_1 = xp[i - 1]
fp_i = fp[i]
fp_i_1 = fp[i - 1]
f = np.where(xp_i > xp_i_1, (fp_i_1 * (xp_i - x) + fp_i * (x - xp_i_1)) / (xp_i - xp_i_1), fp_i)

Not sure which is the most efficient (how does np.unique behave on GPU/TPU? Its output shape depends on the values...).
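To illustrate the corner case the guard handles (toy values: trailing duplicate knots plus a query past the range make the unguarded formula divide by zero):

import numpy as np

xp = np.array([0.0, 1.0, 2.0, 2.0])
fp = np.array([0.0, 1.0, 4.0, 4.0])
x = 3.0
i = np.clip(np.searchsorted(xp, x, side='right'), 1, len(xp) - 1)  # i == 3
print(xp[i] - xp[i - 1])  # 0.0 -> nan/inf without the where() guard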
I don't understand what the problem with the period boundaries would be. Seems like everything should work just fine to me.
Also, very minor, but this would probably be a bit more readable (and save a multiplication, though that's secondary):
fp_i_1 + (x - xp_i_1) * (fp_i - fp_i_1) / (xp_i - xp_i_1)
I don't understand what the problem with the period boundaries would be. Seems like everything should work just fine to me.
For example:
xp = np.linspace(0, 10, 10)
fp = np.sin(xp)
x = np.linspace(0, 10, 100)
y1 = np.interp(x, xp, fp, period=5)
y2 = interp(x, xp, fp, period=5)
import matplotlib.pyplot as plt
plt.plot(xp % 5, fp, '.k', label='input')
plt.plot(x % 5, y1, '.', label='np.interp')
plt.plot(x % 5, y2, '.', label='interp')
plt.legend();
[plot comparing 'input', 'np.interp', and 'interp': the two implementations visibly disagree]
I think you actually made a typo when you copy-pasted that part of the code from the numpy function :)
You wrote
xp = np.concatenate([xp[-1:] - period, xp, xp[:1] + period])
fp = np.concatenate([fp[-1:], fp, fp[1:]])
In the numpy code it's
xp = np.concatenate([xp[-1:] - period, xp, xp[0:1] + period])
fp = np.concatenate([fp[-1:], fp, fp[0:1]])
(note that I don't know why they wrote [0:1] instead of [:1])
When I fix it, it works just fine:
[plot: with the fix, interp matches np.interp]
If I'd copy-pasted, there wouldn't be a typo :grin:
Another set of eyes is always helpful – thanks!
I've prepared a PR with a searchsorted-based interp in #3949, because we've had some feature requests for it.
@AdrienCorenflos – Please consider this to be a reference implementation, and if and when it's merged, feel free to prepare a PR improving on it. Hopefully the test suite that is part of that PR will be useful!
@jakevdp solved the above.
I am leaving this here for reference: there might be a way to leverage associative scan to implement this more efficiently. The caveat is that I don't think it has been done in any mainstream library.
Adrien