Jax: segment_sum primitives / advanced indexing in jax.ops.index_add

Created on 30 Apr 2019 · 5 Comments · Source: google/jax

Would it be useful to others (aside from me) to have better support for sorted and unsorted segment sums?

https://www.tensorflow.org/api_docs/python/tf/math/unsorted_segment_sum
https://www.tensorflow.org/api_docs/python/tf/math/segment_sum

In NumPy, one way to do unsorted segment sums is to sort by segment id and then call np.add.reduceat, but reduceat doesn't seem to be available in jax or autograd:

>>> import jax.numpy as np
>>> np.add
<function _one_to_one_binop.<locals>.<lambda> at 0x1430cdbf8>
>>> np.add.reduceat
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
AttributeError: 'function' object has no attribute 'reduceat'
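
For reference, the sort-then-reduceat approach mentioned above can be sketched in plain NumPy roughly as follows. This is a minimal sketch, not code from the issue: the helper name `unsorted_segment_sum_reduceat` is hypothetical, NumPy is imported as `onp` to avoid confusion with the `jax.numpy as np` alias used in this thread, and it assumes a non-empty 1-D input. Unlike tf.math.unsorted_segment_sum, it only returns the segment ids that actually occur.

```python
import numpy as onp  # plain NumPy, not jax.numpy


def unsorted_segment_sum_reduceat(data, segment_ids):
    """Hypothetical helper: sum `data` per segment id via sort + reduceat.

    Assumes non-empty 1-D inputs of equal length.
    """
    data = onp.asarray(data)
    segment_ids = onp.asarray(segment_ids)

    # Sort the data so that elements of the same segment are contiguous.
    order = onp.argsort(segment_ids, kind="stable")
    sorted_ids = segment_ids[order]
    sorted_data = data[order]

    # Index at which each run of equal segment ids starts.
    is_start = onp.r_[True, sorted_ids[1:] != sorted_ids[:-1]]
    starts = onp.flatnonzero(is_start)

    # reduceat sums sorted_data[starts[i]:starts[i + 1]] for each run.
    return sorted_ids[starts], onp.add.reduceat(sorted_data, starts)


ids, sums = unsorted_segment_sum_reduceat(onp.arange(5), [0, 0, 1, 0, 1])
print(ids, sums)  # segment 0 -> 0 + 1 + 3, segment 1 -> 2 + 4
```
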
enhancement

All 5 comments

Does jax.ops.index_add do what you need? It seems roughly equivalent to unsorted_segment_sum.

https://jax.readthedocs.io/en/latest/_autosummary/jax.ops.index_add.html#jax.ops.index_add

Ah, I guess it does not yet because it does not support advanced indexing. We should fix that.

Nice! Would it be hard to support the basic operators +, -, *, / on top of add? I'm not sure whether there are associativity assumptions in the underlying implementation.

Edit: Never mind. Now that I think about it a little more, I'm pretty sure - and / will be very hard to do correctly for more than a single operation (since JAX says that it applies all the updates).

After playing around a little bit with the numpy version of advanced index assignment, I take it the idea is to implement the unsorted/sorted segment sums using something similar to:

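import jax
import jax.numpy as np

ys = np.arange(5)
idxs = [0, 0, 1, 0, 1]  # segment id for each element of ys
sums = np.zeros(2)  # two slots, one per segment
# Updates to a repeated index should accumulate: slot 0 gets 0+1+3, slot 1 gets 2+4.
seg_sum = jax.ops.index_add(sums, jax.ops.index[idxs], ys)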

Edit: of course the JAX version will work, because it actually accumulates over repeated indices, whereas NumPy's advanced index assignment just applies the last element.
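
The NumPy behavior referred to in this edit can be checked with a small sketch like the one below. It uses plain NumPy (imported as `onp`, an alias chosen here to keep it apart from `jax.numpy as np`), and the variable names `buffered` and `accumulated` are illustrative only. `onp.add.at` is NumPy's unbuffered counterpart and is shown purely for comparison.

```python
import numpy as onp  # plain NumPy, not jax.numpy

ys = onp.arange(5.0)
idxs = [0, 0, 1, 0, 1]

# Advanced index assignment is buffered: for a repeated index,
# only the last update survives (index 0 keeps ys[3], index 1 keeps ys[4]).
buffered = onp.zeros(2)
buffered[idxs] += ys
print(buffered)  # [3. 4.]

# ufunc.at applies every update, so repeated indices accumulate.
accumulated = onp.zeros(2)
onp.add.at(accumulated, idxs, ys)
print(accumulated)  # [4. 6.]
```
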

Yup, that's the idea!
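
For readers on a recent JAX release, the same idea can be sketched as below. This is an illustration under assumptions, not part of the original thread: later JAX versions removed `jax.ops.index_add` in favor of the `.at[...]` indexed-update property, and also provide `jax.ops.segment_sum`. The alias `jnp` and the variable names are choices made here.

```python
import jax
import jax.numpy as jnp

ys = jnp.arange(5.0)
idxs = jnp.array([0, 0, 1, 0, 1])

# Scatter-add with the .at[] property: repeated indices accumulate.
via_at = jnp.zeros(2).at[idxs].add(ys)

# Unsorted segment sum; num_segments fixes the output length.
via_segment_sum = jax.ops.segment_sum(ys, idxs, num_segments=2)

# If the ids are already sorted, that can be declared as a hint.
sorted_ids = jnp.array([0, 0, 0, 1, 1])
sorted_ys = jnp.array([0.0, 1.0, 3.0, 2.0, 4.0])
via_sorted = jax.ops.segment_sum(
    sorted_ys, sorted_ids, num_segments=2, indices_are_sorted=True
)

# The operation is differentiable: the gradient of segment 0's sum
# is 1 for elements in segment 0 and 0 elsewhere.
grad_seg0 = jax.grad(
    lambda y: jax.ops.segment_sum(y, idxs, num_segments=2)[0]
)(ys)

print(via_at, via_segment_sum, via_sorted, grad_seg0)
```

Passing `num_segments` explicitly keeps the output shape static, which matters if the call is wrapped in `jax.jit`.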
