A simple reproduction (on CPU) is
np.sum(np.ones(1000000) * 0.1).copy()
which gives the result 100958.34, whereas it is expected to return a number around 100000 ± 10.
I thought I had repro'd this earlier with your exact numbers, but I'm unable to do so now:
In [1]: import jax.numpy as np
In [2]: print np.sum(np.ones(int(1e6)) * 0.1)
/.../jax/lib/xla_bridge.py:122: UserWarning: No GPU found, falling back to CPU.
warnings.warn('No GPU found, falling back to CPU.')
100022.35
Would you mind double-checking that you're on the most recent jaxlib on PyPI (pip install --upgrade jaxlib) and seeing if your numbers are the same? Is the number in my repro still surprising?
@mattjj I just double-checked that I'm using the latest version of jaxlib. I also discovered that 100022.35 is the answer when I enable fast_math, while 100958.34 is the answer when I disable fast_math. It seems like an LLVM problem, but I'm not sure how to report it. In addition, 100022 is still surprising to me in float32 mode (compared to other frameworks).
Aha, perfect! I'll ask the XLA team on our internal bug tracker.
Thanks, Matt!
@jlebar of the XLA team made this acute observation:
Huh, this is counterintuitive to me, but I also suspect it's WAI (working as intended).
It's counterintuitive because I'd have expected that precision loss can't occur when the ratio of the total sum (1e5) to the smallest summand (0.1) is less than 2^24 (24 being the number of mantissa bits in f32). In this case, log2(1e5/0.1) ≈ 20 < 24, so... there should be no precision loss?
But the above analysis assumes that the summand and total sum are precisely-representable as floating-point numbers. My guess is that this behavior you're seeing is related to the fact that 0.1 can't be precisely expressed in float. If you change it to 0.125 or 0.0625, do you get a precise answer?
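To make the representability point concrete, here is a small sketch (plain NumPy, added for illustration; it is not part of the original thread) that prints the exact value each float32 literal actually stores:

import numpy as np
from decimal import Decimal

# 0.1 has no finite binary expansion, so float32 stores a nearby value.
print(Decimal(float(np.float32(0.1))))    # 0.100000001490116119384765625
# 0.125 = 2**-3 is exactly representable, so no rounding is introduced.
print(Decimal(float(np.float32(0.125))))  # 0.125

With 0.1, the running partial sums can rarely be represented exactly, so almost every addition rounds; with 0.125, every partial sum k * 0.125 for k up to 10^6 fits in float32's 24-bit mantissa, so the accumulation is exact.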
And indeed, as he predicts:
In [1]: import jax.numpy as np
In [2]: print np.sum(np.ones(int(1e6)) * 0.125)
jax/lib/xla_bridge.py:144: UserWarning: No GPU/TPU found, falling back to CPU.
warnings.warn('No GPU/TPU found, falling back to CPU.')
125000.0
The fact that enabling fast math makes the precision go _up_ seems like just a coincidence, since fast math can do a lot of things.
I think this might be working as intended, but if you have a problem around this example, maybe we should figure out a workaround. What do you think?
I don't have problems with that example, or with other numbers of the form 1/2^n. For all other numbers, the error is large. In PyTorch, the output of the original example is 100002.5; in NumPy, the output is 100000.1. So I believe there is something wrong with np.sum() on CPU, though I also think this is related to precision.
One option is for us to solve this in JAX, e.g. by lowering np.sum into a tree of sums. Another would be to press XLA:CPU to do it underneath HLO. We'll keep talking to them...
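For intuition, here is a rough sketch of what a tree-of-sums lowering could look like (my illustration only, not the lowering JAX actually uses; the chunk size 1024 and the helper name chunked_sum are made up): reshape into fixed-size chunks, sum each chunk, then sum the per-chunk partials, so no single accumulation runs over all one million elements.

import jax.numpy as np

def chunked_sum(x, chunk=1024):
    # Pad to a multiple of `chunk`, sum each chunk, then sum the partials.
    # Each accumulation now covers far fewer terms than one flat
    # left-to-right pass over the whole array.
    pad = (-x.shape[0]) % chunk
    x = np.pad(x, (0, pad))
    return np.sum(np.sum(x.reshape(-1, chunk), axis=1))

print(chunked_sum(np.ones(int(1e6)) * 0.1))  # should land much closer to 100000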
I just tested out the assumption that this is WAI with

import jax.numpy as np

s = 0
for i in range(10000):
    s = s + np.array(0.1, dtype=np.float32)

and got 999.9029 (not the exact value 1000, of course), which is equal to np.sum(np.ones(int(1e4)) * 0.1). So it seems like everything works as intended, and it is just that the imprecision accumulates faster than in other frameworks. IMO, this aligns with XLA's focus on making things fast, which comes with trade-offs like the one in this issue.
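As a cross-check (my addition, using plain NumPy rather than JAX), accumulating the very same float32 summands into a float64 accumulator keeps the error negligible, which suggests the drift comes from rounding in the float32 running sum rather than from the summands themselves:

import numpy as np  # plain NumPy, for comparison only

x = np.float32(0.1)
s32 = np.float32(0.0)
s64 = np.float64(0.0)
for _ in range(10000):
    s32 = s32 + x               # float32 running sum, as in the loop above
    s64 = s64 + np.float64(x)   # same summands, float64 running sum
print(s32)  # drifts visibly below 1000 (the loop above gives 999.9029)
print(s64)  # ~1000.0000149, i.e. only the tiny representation error of float32(0.1)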
I just built JAX with GPU support to test this issue. The precision on GPU is pretty good: the same sum returns array(100001.39, dtype=float32). Searching the internet, I found that @ekelsen has some insight into this problem. It seems that the CPU summation algorithm is not as good as the one on GPU.
np.sum uses a numerically stable pairwise summation algorithm (https://github.com/numpy/numpy/pull/3685)
TensorFlow's CPU implementation should also now do the same thing (the exact commit is somewhere in Eigen; ask Rasmus), as does the GPU implementation.
I think the issue here is that XLA does not do this. So basically, the problem has been fixed everywhere except XLA, and XLA should probably implement a numerically stable summation algorithm as well...
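For reference, the pairwise (divide-and-conquer) scheme mentioned above can be sketched in a few lines of plain Python; this is only an illustration, not NumPy's actual implementation, which switches to a blocked, unrolled loop below a cutoff for speed:

import numpy as np

def pairwise_sum(x):
    # Recursively split the array and sum the halves, so the two operands
    # of each addition have comparable magnitude.
    if x.size <= 128:
        return sum(x, np.float32(0.0))  # small base case: plain accumulation
    mid = x.size // 2
    return pairwise_sum(x[:mid]) + pairwise_sum(x[mid:])

x = np.full(10**6, 0.1, dtype=np.float32)
print(pairwise_sum(x))  # much closer to 100000 than a flat left-to-right sum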
I think the XLA:CPU folks are going to look into improving this!
I believe XLA/GPU is doing a tiled reduction but CPU is not, which explains why the precision is worse on CPU than GPU.
@fehiepsi Can you give a sense of how urgent this issue is to you?
I think even on GPU the tiles are accumulated atomically, which, aside from being non-deterministic, is also less stable than a final tree.
That is correct. Changing this would be a large project, as this is some of the most carefully-tuned and performance-sensitive code in XLA:GPU.
@hawkinsp Thanks, it is not urgent for me. The precision in fast-math mode is enough for my application. :)
Closing because the precision of np.sum is much better now.