Would it be useful to others (aside from me) to have better support for sorted and unsorted segment sums?
https://www.tensorflow.org/api_docs/python/tf/math/unsorted_segment_sum
https://www.tensorflow.org/api_docs/python/tf/math/segment_sum
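To pin down what the two linked ops compute, here is a minimal plain-loop sketch in NumPy. The helper name `segment_sum_reference` is hypothetical. It assumes the semantics as I read the TF docs: `out[k]` is the sum of all `data[j]` with `segment_ids[j] == k`, and empty segments are 0. The sorted variant requires non-decreasing ids and infers the segment count from the largest id, while the unsorted variant takes ids in any order plus an explicit `num_segments`.

```python
import numpy as np

def segment_sum_reference(data, segment_ids, num_segments):
    """Hypothetical helper: out[k] = sum of data[j] over all j with segment_ids[j] == k."""
    data = np.asarray(data)
    # Empty segments stay 0; trailing dims of data are preserved.
    out = np.zeros((num_segments,) + data.shape[1:], dtype=data.dtype)
    for row, k in zip(data, segment_ids):
        out[k] += row
    return out

# Sorted case (segment_sum style): ids are non-decreasing and the number of
# segments is taken as the largest id + 1.
sorted_ids = [0, 0, 1, 3]
print(segment_sum_reference([1, 2, 3, 4], sorted_ids, max(sorted_ids) + 1))  # [3 3 0 4]

# Unsorted case (unsorted_segment_sum style): any order, explicit num_segments.
print(segment_sum_reference([1, 2, 3, 4], [3, 0, 1, 0], 4))  # [6 3 0 1]
```

The loop is slow but unambiguous, which makes it handy as a test oracle for faster versions.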
In NumPy, one way to do unsorted segment sums is to sort and then call np.add.reduceat, but this doesn't seem to be available in JAX or autograd:
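A minimal NumPy sketch of the sort-then-reduceat approach, assuming 1-D integer segment ids and an explicit `num_segments`. The helper name `unsorted_segment_sum_reduceat` is hypothetical. `np.unique(..., return_index=True)` supplies the start position of each id in the sorted array, which is exactly the index list `np.add.reduceat` expects; ids that never appear are left at 0.

```python
import numpy as np

def unsorted_segment_sum_reduceat(data, segment_ids, num_segments):
    """Hypothetical helper: sort by segment id, then sum each contiguous run."""
    data = np.asarray(data)
    segment_ids = np.asarray(segment_ids)
    # Stable sort so rows with equal ids keep their original relative order.
    order = np.argsort(segment_ids, kind="stable")
    sorted_ids = segment_ids[order]
    sorted_data = data[order]
    # present_ids: distinct ids; starts: first position of each in sorted_ids.
    present_ids, starts = np.unique(sorted_ids, return_index=True)
    # reduceat sums sorted_data[starts[i]:starts[i + 1]] (last run goes to the end).
    partial = np.add.reduceat(sorted_data, starts, axis=0)
    # Segments with no entries stay 0.
    out = np.zeros((num_segments,) + data.shape[1:], dtype=data.dtype)
    out[present_ids] = partial
    return out

print(unsorted_segment_sum_reduceat(np.arange(5), [0, 0, 1, 0, 1], 2))  # [4 6]
```

If the ids are already sorted (the segment_sum case), the argsort step can be dropped.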
>>> import jax.numpy as np
>>> np.add
<function _one_to_one_binop.<locals>.<lambda> at 0x1430cdbf8>
>>> np.add.reduceat
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
AttributeError: 'function' object has no attribute 'reduceat'
Does jax.ops.index_add do what you need? It seems roughly equivalent to unsorted_segment_sum.
https://jax.readthedocs.io/en/latest/_autosummary/jax.ops.index_add.html#jax.ops.index_add
Ah, I guess it doesn't yet, because index_add does not support advanced indexing. We should fix that.
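For comparison, NumPy's closest analogue of an index_add-style scatter-add is `np.add.at`, which is unbuffered, so repeated indices accumulate. A minimal sketch of unsorted_segment_sum built on it (the helper name `unsorted_segment_sum_add_at` is hypothetical), with no sorting needed:

```python
import numpy as np

def unsorted_segment_sum_add_at(data, segment_ids, num_segments):
    """Hypothetical helper: scatter-add each row of data into its segment slot."""
    data = np.asarray(data)
    out = np.zeros((num_segments,) + data.shape[1:], dtype=data.dtype)
    # np.add.at applies every update, so repeated ids add up instead of overwriting.
    np.add.at(out, np.asarray(segment_ids), data)
    return out

print(unsorted_segment_sum_add_at(np.arange(5), [0, 0, 1, 0, 1], 2))  # [4 6]
```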
Nice! Would it be hard to support the other basic operators (-, *, /) in addition to add? I'm not sure whether the underlying implementation makes any associativity assumptions.
Edit: never mind. Now that I think about it a little more, I'm pretty sure - and / will be very hard to do correctly for more than a single update (since JAX says that it applies all the updates).
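After playing around a bit with NumPy's version of advanced-index assignment, I take it the idea is to implement the unsorted/sorted segment sums using something like:
import jax
import jax.numpy as np
ys = np.arange(5)                  # values to sum
idxs = np.array([0, 0, 1, 0, 1])   # segment id of each value
sums = np.zeros(2)                 # one output slot per segment
seg_sum = jax.ops.index_add(sums, jax.ops.index[idxs], ys)  # [4., 6.]
# (newer JAX releases removed jax.ops.index_add; the equivalent is sums.at[idxs].add(ys))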
Edit: of course, the JAX version works because it actually accumulates repeated indices, as opposed to NumPy's assignment, which just applies the last update for each index.
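The difference is easy to reproduce in plain NumPy with the same ys and idxs. A buffered fancy-index update (`a[idxs] += ys`) gathers, adds, and scatters back once, so each repeated slot keeps only one of its updates. `np.add.at` applies every update, which is the accumulating behavior index_add provides.

```python
import numpy as np

ys = np.arange(5)
idxs = np.array([0, 0, 1, 0, 1])

# Buffered update: a[idxs] is read once, ys is added, and the result is
# written back, so repeated indices do not accumulate.
buffered = np.zeros(2)
buffered[idxs] += ys

# Unbuffered scatter-add: every (index, value) pair is applied.
accumulated = np.zeros(2)
np.add.at(accumulated, idxs, ys)

print(buffered)      # only one update per slot survives
print(accumulated)   # [4. 6.]
```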
Yup, that's the idea!