Jax: Unimplemented NumPy core functions

Created on 11 Dec 2018 · 26 Comments · Source: google/jax

If you implement a feature with a PR, you get to ring the bell! 🔔
(list generated with np.UNIMPLEMENTED_FUNCS from #69)

High level categories:

  • np.linalg: #1999
  • np.fft: #1877
  • searching: #2080
  • sorting: #2079
  • sets: #2078
  • NaN functions: #2077

Other stuff:

  • [x] np.alltrue
  • [x] np.percentile
  • [x] np.deg2rad
  • [x] np.quantile
  • [x] np.degrees
  • [x] np.empty_like
  • [x] np.trapz
  • [x] np.digitize
  • [x] np.empty
  • [x] np.meshgrid
  • [x] np.exp2
  • [ ] np.delete
  • [ ] np.insert
  • [x] np.append
  • [x] np.fabs
  • [x] np.arange
  • [x] np.float_power
  • [ ] np.find_common_type
  • [x] np.take
  • [x] np.fmax
  • [ ] np.choose
  • [x] np.fmin
  • [x] np.fmod
  • [x] np.frexp
  • [x] np.inner
  • [x] np.gcd
  • [x] np.heaviside
  • [ ] np.putmask (note: this modifies array in-place)
  • [x] np.hypot
  • [x] np.isinf
  • [x] np.isnan
  • [x] np.lcm
  • [x] np.isposinf
  • [x] np.ldexp
  • [x] np.einsum
  • [x] np.histogram_bin_edges
  • [x] np.isneginf
  • [x] np.histogram
  • [x] np.fix
  • [x] np.histogramdd
  • [x] np.full_like
  • [x] np.flip
  • [x] np.count_nonzero
  • [x] np.log10
  • [x] np.asarray
  • [x] np.iscomplex
  • [x] np.average
  • [ ] np.asanyarray
  • [x] np.isreal
  • [x] np.bincount
  • [ ] np.ascontiguousarray
  • [x] np.iscomplexobj
  • [x] np.log2
  • [x] np.piecewise
  • [x] np.isrealobj
  • [x] np.select
  • [x] np.logaddexp
  • [ ] np.require (deals with C/F contiguity not relevant for jax)
  • [ ] np.copy
  • [x] np.nan_to_num
  • [x] np.logaddexp2
  • [x] np.gradient
  • [x] np.diff
  • [x] np.unpackbits
  • [x] np.take_along_axis
  • [x] np.interp
  • [ ] np.common_type
  • [x] np.correlate
  • [x] np.unwrap
  • [x] np.apply_along_axis
  • [x] np.apply_over_axes
  • [x] np.convolve
  • [x] np.ix_
  • [x] np.outer
  • [x] np.trim_zeros
  • [x] np.tensordot
  • [x] np.dstack
  • [x] np.roll
  • [x] np.rollaxis
  • [x] np.modf
  • [x] np.array_split
  • [x] np.pad
  • [ ] np.min_scalar_type
  • [x] np.hsplit
  • [x] np.result_type
  • [x] np.array_str
  • [x] np.unravel_index
  • [x] np.vsplit
  • [x] np.dsplit
  • [x] np.broadcast_to
  • [x] np.promote_types
  • [x] np.positive
  • [x] np.broadcast_arrays
  • [x] np.kron
  • [x] np.can_cast
  • [x] np.cov
  • [x] np.tile
  • [x] np.rad2deg
  • [x] np.radians
  • [x] np.reciprocal
  • [x] np.linspace
  • [x] np.logspace
  • [x] np.corrcoef
  • [x] np.geomspace
  • [x] np.blackman
  • [x] np.rint
  • [x] np.hanning
  • [x] np.hamming
  • [x] np.signbit
  • [x] np.atleast_3d
  • [x] np.i0
  • [x] np.kaiser
  • [ ] np.spacing
  • [x] np.issubsctype
  • [ ] np.poly
  • [x] np.sinc
  • [x] np.sqrt
  • [x] np.roots
  • [x] np.issubdtype
  • [x] np.block
  • [x] np.diag_indices
  • [ ] np.polyint
  • [x] np.square
  • [x] np.diag_indices_from
  • [x] np.median
  • [x] np.polyder
  • [ ] np.polyfit
  • [x] np.arccos
  • [x] np.tan
  • [x] np.arccosh
  • [ ] np.resize
  • [x] np.polyval
  • [x] np.arcsin
  • [x] np.polyadd
  • [x] np.diagonal
  • [x] np.polysub
  • [x] np.trace
  • [x] np.trunc
  • [x] np.cross
  • [x] np.polymul
  • [x] np.indices
  • [ ] np.polydiv
  • [x] np.shape
  • [x] np.sum
  • [x] np.any
  • [x] np.rot90
  • [x] np.all
  • [x] np.identity
  • [x] np.cumsum
  • [x] np.fliplr
  • [x] np.flipud
  • [x] np.ptp
  • [x] np.amax
  • [x] np.eye
  • [x] np.array_equal
  • [x] np.array_equiv
  • [x] np.cbrt
  • [x] np.amin
  • [x] np.diag
  • [x] np.diagflat
  • [x] np.ravel_multi_index
  • [x] np.prod
  • [x] np.tri
  • [x] np.arcsinh
  • [x] np.cumprod
  • [x] np.tril
  • [x] np.ndim
  • [x] np.triu
  • [x] np.arctan
  • [x] np.size
  • [x] np.vander
  • [x] np.arctan2
  • [x] np.around
  • [x] np.histogram2d
  • [x] np.arctanh
  • [x] np.mask_indices
  • [x] np.tril_indices
  • [x] np.tril_indices_from
  • [x] np.triu_indices
  • [x] np.product
  • [x] np.triu_indices_from
  • [x] np.copysign
  • [x] np.cumproduct
  • [x] np.sometrue
Labels: enhancement, good first issue

All 26 comments

The list needs a bit of refinement, since some things were already implemented and others are out of scope, like those dealing with loading/saving/string manipulation. I deleted a few of the out-of-scope ones, and marked as done some things that were already included and some things that are coming in #96.

Is there a guide on how to write customized primitive functions (or is that even possible)? I am looking for the equivalent of autograd's defvjp.

A quick glance at the code suggests everything is delegated to XLA -- if an operator is nontrivial, does it have to be implemented in XLA first?

Great question. It's possible, and JAX's internals are actually very similar to Autograd's in this respect, but the API is a bit different. We need to write up how to do this, and maybe add a convenience layer to the API.

There are a few different use cases that we'd want to cover. One is that you just want to define a custom VJP for a function that is otherwise implemented in terms of NumPy code, like in the Autograd tutorial section. But another use case is to define a custom primitive and VJP for some external routine, like a Cython or Fortran function that isn't implemented in terms of NumPy. If you have an intended use case, does it fit into one of those categories, or is it another?
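
For a concrete picture of the first use case, here is a minimal sketch using the jax.custom_vjp convenience API that JAX later added (the API being discussed in this thread predates it, so treat this purely as an illustration, not the mechanism described here):

```python
import jax
import jax.numpy as jnp

@jax.custom_vjp
def log1pexp(x):
    # Primal computation, written in ordinary NumPy-style code.
    return jnp.log1p(jnp.exp(x))

def log1pexp_fwd(x):
    # Forward rule: return the primal output plus residuals for the backward pass.
    return log1pexp(x), x

def log1pexp_bwd(x, g):
    # Backward rule: numerically stable derivative, sigmoid(x) times the cotangent.
    return (g * jax.nn.sigmoid(x),)

log1pexp.defvjp(log1pexp_fwd, log1pexp_bwd)

print(jax.grad(log1pexp)(100.0))  # ~1.0, without overflow in the gradient
```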

Let's track this in #116. If you have a simplified example of what you want to do, post it there and we'll use it to help guide our convenience API and/or explanation.

Looking at contributing a function or two from this list.

What is the approach to functions with nonstatically-sized return values? For example, setxor1d.

My understanding is XLA needs to compile for each specific shape. Is there a pattern for avoiding worst-possible-case behavior, or should functions like setxor1d be removed from this list?

@mattjj I am looking to take a stab at the issue. How do I start?

@souravsingh Awesome! In general, lots of these are fairly easy to do. Pick one you like off the list, and take a look at previous PRs implementing numpy ops. For example, I just sent out https://github.com/google/jax/pull/298 adding np.cumprod/cumsum support. Usually, it's just a question of implementing a numpy op in terms of the primitives in the lax.py library.
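
To illustrate that pattern, here is a hypothetical sketch of one of the simpler list items written in terms of lax primitives (not necessarily how jax.numpy actually implements it):

```python
import jax.numpy as jnp
from jax import lax

def fix(x):
    # Round toward zero, like np.fix (assumes floating-point input).
    x = jnp.asarray(x)
    return jnp.where(x >= 0, lax.floor(x), lax.ceil(x))

print(fix(jnp.array([-1.7, -0.2, 0.2, 1.7])))  # [-1. -0.  0.  1.]
```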

@mattjj I am looking to implement np.unique. Is this something which is a priority at the moment?

Sounds like a great idea! The core team isn't working on any numpy operations at the moment (we're mostly focused on improving performance in the XLA runtime, parallel computation, and Cloud TPU support) so that's a great place to make contributions.

np.unique is tricky though because AIUI the shape of the output depends on the values of the input. Since XLA can only express computations with static shapes (meaning independent of the input values), we'd have to re-compile the XLA computation for every new value of the input, and moreover it might be tricky to express in terms of XLA HLO (i.e. in terms of lax library calls).

@hawkinsp any thoughts on implementing np.unique?
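
(For reference, the static-shape constraint can be worked around by fixing the output size ahead of compilation and padding; current jax.numpy eventually grew a `size` argument for exactly this. A sketch, assuming a recent JAX version:)

```python
import jax.numpy as jnp

x = jnp.array([3, 1, 3, 2, 1])
# `size` makes the output shape static; missing slots are padded with fill_value.
u = jnp.unique(x, size=5, fill_value=-1)
print(u)  # [ 1  2  3 -1 -1]
```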

@mattjj Since this would be my first attempt at this, I should probably look at numpy functions that are less tricky than unique?

Any suggestions would be really helpful.

@mattjj Apologies for the constant posts. I have decided to implement np.diff, which seems fairly straightforward. I was just wondering whether there is a guide to writing the test cases in the lax_numpy_test.py file. Thanks!
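
(For illustration, first-order differences can be expressed with plain slicing; a minimal sketch covering only the last axis, not the full np.diff signature:)

```python
import jax.numpy as jnp

def diff(a):
    # First-order difference along the last axis: a[i+1] - a[i].
    a = jnp.asarray(a)
    return a[..., 1:] - a[..., :-1]

print(diff(jnp.array([1, 2, 4, 7])))  # [1 2 3]
```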

+1 to adding np.cov

+1 to np.percentile

@murphyk np.cov coming in #983

@murphyk quantile and percentile are now present, at least for interpolation='linear' and floating-point types.
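
A quick usage sketch of that supported case (linear interpolation on floating-point inputs):

```python
import jax.numpy as jnp

x = jnp.array([1.0, 2.0, 3.0, 4.0])
print(jnp.percentile(x, 25.0))   # 1.75 (linear interpolation)
print(jnp.quantile(x, 0.5))      # 2.5
```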

fyi, np.rank is deprecated in favor of np.ndim

For anyone interested in np.convolve: I've commented on possibly using lax.conv. It would require reshaping and flipping the inputs. I've described it in #1561.
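
Roughly, the reshape-and-flip idea looks like this; a hedged sketch of mode='full' in terms of lax.conv_general_dilated, not necessarily how the eventual implementation does it:

```python
import jax.numpy as jnp
from jax import lax

def convolve_full(a, v):
    # XLA convolution is really correlation, so flip the kernel,
    # and reshape the 1-D inputs to (batch, channel, width).
    a = jnp.asarray(a, dtype=jnp.float32)
    v = jnp.asarray(v, dtype=jnp.float32)
    lhs = a[None, None, :]          # (N=1, C=1, W)
    rhs = v[::-1][None, None, :]    # flipped kernel, (O=1, I=1, W)
    pad = v.shape[0] - 1            # 'full' mode pads by len(v) - 1 on each side
    out = lax.conv_general_dilated(lhs, rhs,
                                   window_strides=(1,),
                                   padding=[(pad, pad)])
    return out[0, 0]

print(convolve_full(jnp.array([1., 2., 3.]), jnp.array([0., 1., 0.5])))
# [0.  1.  2.5 4.  1.5], matching np.convolve([1, 2, 3], [0, 1, 0.5])
```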

@jessebett I just created a PR for it #1831

I'd also like to get involved in this process of completing the remaining numpy ops. As a first pass, I am looking at implementing something really simple, to dip my toes in the water so I can be more familiar with the codebase before tackling something more serious!

I will look at np.alen unless there are any issues with that :)

Created a PR for alen and isscalar in #1924

I took another pass at deleting irrelevant functions:

  • anything currently deprecated by numpy
  • anything that relies upon details of NumPy's internal array storage (strides, etc)
  • anything related to np.matrix.
  • anything that modifies arguments in place

I think most of the rest should be relevant, though some will be tricky (e.g., with output dependent shapes).

Hi - I'm interested in getting nanmean done. Am I right in suggesting that this could be done by somehow piggybacking off nansum? If so, are there any tips that may be worth heeding in this regard?

NumPy itself should be a pretty good model for most NaN functions; all of its NaN functions are written in pure Python.
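
Following that model, a rough sketch of nanmean built on nansum (ignoring dtype and keepdims handling):

```python
import jax.numpy as jnp

def nanmean(a, axis=None):
    # Sum the non-NaN values and divide by the count of non-NaN entries.
    a = jnp.asarray(a)
    total = jnp.nansum(a, axis=axis)
    count = jnp.sum(~jnp.isnan(a), axis=axis)
    return total / count

print(nanmean(jnp.array([1.0, jnp.nan, 3.0])))  # 2.0
```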

To make things a little easier to keep track of, I split some functions off into a handful of sub-issues:

  • np.linalg: #1999
  • np.fft: #1877
  • searching: #2080
  • sorting: #2079
  • sets: #2078
  • NaN functions: #2077

Hi, should np.asanyarray be removed since it simply becomes np.asarray when there is no subclassing?

np.cast is not implemented in JAX. This is a dict of lambdas that allows casting to certain types.

@jameskirkpatrick I agree that np.cast exists in my copy of NumPy, but it is not a documented API as far as I can tell. Are we sure it's an intentional, non-deprecated NumPy API? @shoyer any thoughts?
