Jax: Implement all numpy.fft operations

Created on 17 Dec 2019 · 15 Comments · Source: google/jax

As noted by https://twitter.com/peter_melchior/status/1206645602546790400 this is a notable gap in jax.numpy. We support the core fftn/ifftn routines, but basic routines like fft are the more obvious entry point, so we should make those work too.

Standard FFTs

  • [x] fft
  • [x] ifft
  • [x] fft2
  • [x] ifft2
  • [x] fftn
  • [x] ifftn

Real FFTs (see https://github.com/google/jax/issues/1839)

  • [x] rfft
  • [x] irfft
  • [x] rfft2
  • [x] irfft2
  • [x] rfftn
  • [x] irfftn

Hermitian FFTs (these seem less useful, but they are easy to implement)

  • [x] hfft
  • [x] ihfft

Helper routines

  • [x] fftfreq
  • [x] rfftfreq
  • [x] fftshift (https://github.com/google/jax/pull/1850)
  • [x] ifftshift (https://github.com/google/jax/pull/1850)

These would be good opportunities for new contributor(s):

  1. Easy: write fft/ifft/fft2/ifft2 in terms of fftn/ifftn.
  2. Easy: write real-valued FFTs in terms of the complex valued FFTs.
  3. Easy: write fftfreq/rfftfreq with jax.numpy functions.
  4. Harder: write real-valued FFTs and their derivatives in terms of XLA's RFFT/IRFFT.
  5. No idea: Hermitian FFTs. I don't think the underlying primitives for these exist in XLA, and I'm not sure how widely used these actually are (I've never needed them).
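A rough sketch of what the three "easy" items could look like, assuming only that jax.numpy.fft.fftn is available. The function names below are hypothetical and chosen for illustration; this is not the code that was eventually merged.

```python
import jax.numpy as jnp


def fft_via_fftn(x, axis=-1):
    # Item 1: a 1-D FFT is an n-D FFT restricted to a single axis.
    return jnp.fft.fftn(x, axes=(axis,))


def rfft_via_fft(x, axis=-1):
    # Item 2: for real input, compute the full complex FFT and keep only
    # the n // 2 + 1 non-negative frequencies (the rest are redundant).
    x = jnp.asarray(x)
    n = x.shape[axis]
    full = jnp.fft.fftn(x, axes=(axis,))
    return jnp.take(full, jnp.arange(n // 2 + 1), axis=axis)


def fftfreq_sketch(n, d=1.0):
    # Item 3: same ordering as numpy.fft.fftfreq, i.e. non-negative
    # frequencies first, then the negative ones.
    k = jnp.arange(n)
    k = jnp.where(k < (n + 1) // 2, k, k - n)
    return k / (d * n)
```

The rfft version does twice the necessary work, which is why item 4 (using XLA's RFFT directly) is listed separately.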

cc @dynamicwebpaige

enhancement good first issue

All 15 comments

I'll take fft and ifft on.

Please see https://github.com/google/jax/pull/1926 for those

Do we want to add TPU versions of all of these as well? Right now TPU only supports 1D FFTs, but the higher-dimensional ones can be implemented by applying the 1D FFT along each axis in turn.
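For illustration, the per-axis idea might look roughly like the following. This is a minimal sketch that relies on the multi-dimensional DFT being separable; `fftn_per_axis` is a hypothetical name, and nothing here is TPU-specific.

```python
import jax.numpy as jnp


def fftn_per_axis(x, axes=None):
    # The n-D DFT factorizes into independent 1-D DFTs, so transforming
    # one axis at a time gives the same result as a single n-D transform.
    x = jnp.asarray(x)
    if axes is None:
        axes = range(x.ndim)
    for axis in axes:
        x = jnp.fft.fft(x, axis=axis)
    return x
```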

I think higher dimensional FFTs on TPU would be very useful to have! The main question is where the logic should live -- in JAX or lower down in the XLA stack?

The dream would be in the XLA stack, but that is beyond my abilities. Perhaps we should poke the XLA team first and see what they think.

See b/146436467 for the Google internal feature request.

I would also be OK with adding a temporary workaround in JAX -- the code for this should be pretty minimal.

happy to take a look at hfft and ihfft

Hi,

is there any reference for hfft and ihfft? I can try to take a crack at it.

Thanks

Looking at the NumPy implementation is generally a good place to start

okay I will start with those

Thanks

Hi, I've started to look into hfft and ihfft and noticed there are no XLA primitive bindings for them. Should we wait for them to be added?

I believe NumPy implements hfft and ihfft in terms of other fft operations.
We could probably do the same for JAX.
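As a sketch of that approach, assuming the default normalization and the usual identities hfft(a, n) = n * irfft(conj(a), n) and ihfft(a, n) = conj(rfft(a, n)) / n. The `_sketch` names are hypothetical, and this is not necessarily how the merged version is written.

```python
import jax.numpy as jnp


def hfft_sketch(a, n=None, axis=-1):
    # `a` is assumed to be the non-negative-frequency half of a signal
    # with Hermitian symmetry; the output is real with length n.
    a = jnp.asarray(a)
    if n is None:
        n = 2 * (a.shape[axis] - 1)
    # irfft divides by n under the default norm, so multiply it back.
    return jnp.fft.irfft(jnp.conj(a), n=n, axis=axis) * n


def ihfft_sketch(a, n=None, axis=-1):
    # `a` is assumed to be real; the output has n // 2 + 1 entries.
    a = jnp.asarray(a)
    if n is None:
        n = a.shape[axis]
    return jnp.conj(jnp.fft.rfft(a, n=n, axis=axis)) / n
```

With these definitions, hfft_sketch(ihfft_sketch(x), n=len(x)) should recover a real 1-D input x up to floating-point error.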

Should I then define these directly in the fft.py file?

Sure

With the latest additions of hfft and ihfft merged, we now have a complete implementation of the numpy.fft module in JAX!

Thanks everyone for contributing!

Hi,

I was recently using fft2 from jax.numpy, and found out I couldn't use the 'ortho' norm. It could be nice to add it as well. I am willing to open a PR for this, since it should be relatively simple.
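Until a norm argument is available, one possible workaround is to rescale by hand. This is a sketch assuming the standard 'ortho' convention (divide the forward transform by sqrt(N), multiply the unnormalized-inverse convention accordingly); the function names are hypothetical.

```python
import math

import jax.numpy as jnp


def fft2_ortho(x, axes=(-2, -1)):
    # Orthonormal forward transform: scale the default (unnormalized)
    # fft2 by 1 / sqrt(N), where N is the number of transformed points.
    x = jnp.asarray(x)
    n = math.prod(x.shape[ax] for ax in axes)
    return jnp.fft.fft2(x, axes=axes) / math.sqrt(n)


def ifft2_ortho(x, axes=(-2, -1)):
    # The default ifft2 already divides by N, so multiplying by sqrt(N)
    # leaves the orthonormal 1 / sqrt(N) factor.
    x = jnp.asarray(x)
    n = math.prod(x.shape[ax] for ax in axes)
    return jnp.fft.ifft2(x, axes=axes) * math.sqrt(n)
```

With this scaling the transform preserves the sum of squared magnitudes (Parseval), which is a quick way to check the factor.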
