Hi everyone
Love the project! I'm using JAX to write differentiable physics simulators for my master's thesis. All my code runs fine on CPU, but I'd like to speed things up by using my GPU.
My setup is as follows:
I installed JAX with GPU support as described in the README. My simple JAX programs run fine on GPU. However, when I try to run my more complex simulations, the program starts as usual and then, after a few seconds, this error appears:
Process finished with exit code 137 (interrupted by signal 9: SIGKILL)
All my print statements work up until the line where grad(complex_loss_function)(...) is evaluated. When I google the error, some people suggest excessive memory usage, but that does not seem to be the case here: when I run my simulation and watch the %MEM of python (using the top command), it never exceeds 10%. The %CPU of python does spike to 100% just before the program is terminated.
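As an aside for anyone trying to measure this: top's %MEM only reports resident memory as a fraction of physical RAM. A small stdlib-only sketch (peak_rss_mb is a helper I made up for illustration, not part of JAX) that logs the process's own peak resident set size from inside the script, e.g. right before the grad call:

```python
import resource

def peak_rss_mb():
    # ru_maxrss is the peak resident set size of this process.
    # On Linux it is reported in kilobytes (on macOS, in bytes).
    return resource.getrusage(resource.RUSAGE_SELF).ru_maxrss / 1024

print(f"peak RSS so far: {peak_rss_mb():.1f} MB")
```

Note this only sees resident memory, not virtual reservations, so a huge total-vm can still go unnoticed here.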
I also tried building Jax from source, but the same issue remains.
I'm not sure what I could do next. I'd gladly investigate this issue further if someone could give me some pointers.
Just to be completely clear: the exact same program runs fine, with the same grad(complex_loss_function)(...) and without SIGKILL, when I run it in an environment with the Jax CPU only version.
Code to reproduce the issue (sorry that it's not more minimal). It runs for about 10 seconds before my OS kills it.
Maybe try running your script on GPU and checking dmesg or /var/log/syslog after it dies? I have no idea what is causing it to be killed, unfortunately, but maybe there are some clues in there...
Thanks for the tip! I found this in the dmesg output:
[ 829.166013] Out of memory: Kill process 4449 (python) score 650 or sacrifice child
[ 829.166039] Killed process 4449 (python) total-vm:18418520kB, anon-rss:5601772kB, file-rss:72656kB, shmem-rss:10240kB
If I'm interpreting this correctly, Python seems to be requesting more than 18GB of memory.
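In case it helps with interpretation, the kB fields in that kernel line can be pulled apart mechanically. A throwaway sketch (the parse_oom_kill helper and its regex are mine, just an assumption about the log's field format, not anything from the kernel or JAX):

```python
import re

OOM_LINE = ("[  829.166039] Killed process 4449 (python) "
            "total-vm:18418520kB, anon-rss:5601772kB, "
            "file-rss:72656kB, shmem-rss:10240kB")

def parse_oom_kill(line):
    # Extract every "name:<number>kB" field into a dict of ints (kB).
    return {k: int(v) for k, v in re.findall(r"(\S+?):(\d+)kB", line)}

fields = parse_oom_kill(OOM_LINE)
print(fields["total-vm"] // 1024, "MB virtual")   # 17986 MB virtual
print(fields["anon-rss"] // 1024, "MB resident")  # 5470 MB resident
```

So the process had reserved ~18GB of address space but was actually holding ~5.5GB of RAM when the OOM killer fired.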
18GB of virtual memory (total-vm), but only ~6GB of actual resident memory (anon-rss). When I run your script locally, I see ~14GB max usage on either GPU or CPU (which is kind of weird). Do you have that much memory free on your machine? Probably the fastest way forward for now is to make sure you do :) It'd also be good to understand why it's using so much memory, but I won't have time to dig deeper until after the holidays.
What do you use to measure memory usage? For the CPU version I tried the memory-profiler package, which tells me the program uses at most 671MB of memory. This seems like a reasonable amount for the script I am trying to compile. The compile time is about 100 seconds.
For the GPU version: I have 8GB of RAM and had a 600MB swapfile, so that explains why my OS was killing it. However, I've increased my swapfile to 16GB and I still can't get the program to compile. It just keeps running for more than 20 minutes and then my PC freezes completely.
Thanks already for the help. Happy holidays!
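Before retrying a compile that large, it can help to confirm how much RAM and swap the machine actually has free. A Linux-only sketch (meminfo is a hypothetical helper that just reads the standard /proc/meminfo format):

```python
def meminfo():
    # Parse /proc/meminfo into {field name: value in kB}. Linux only.
    info = {}
    with open("/proc/meminfo") as f:
        for line in f:
            key, _, rest = line.partition(":")
            info[key.strip()] = int(rest.split()[0])
    return info

m = meminfo()
print("MemAvailable:", m["MemAvailable"] // 1024, "MB")
print("SwapTotal:   ", m["SwapTotal"] // 1024, "MB")
```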
I was just using htop to look at the memory usage. I'm guessing that the memory-profiler module is under-reporting, given that the process is being killed, although I have no idea why.
Your PC is probably freezing because it's thrashing swap. IMO it's better for the process to be killed than for the machine to start swapping.
I am surprised it's using this much memory just to compile... I will try to investigate more.
I am also struggling with a SIGKILL involving grad of scan. With the help of some print statements I've been able to track it down to two places in the Python code that both involve GetTupleElement, i.e. a print statement before the call to GetTupleElement goes through, but the print statement immediately after it does not. Narrowing it down further would require rebuilding jaxlib, which I don't really have time to do right now.
The first place is in xla_destructure. The call to xla_destructure is made at the end of jaxpr_subcomp, and in my case when it is fatal it's always while handling a while primitive (i.e. eqn.primitive.name == "while"). The number of outputs extracted from the tuple is a few hundred in my case.
The second place is in the translation rule for while loops. Again the number of outputs is a few hundred.
Which of these two locations is the immediate cause of the SIGKILL depends on my hyperparameters, but it's pretty suspicious that an ostensibly lightweight operation like tuple indexing should consistently be the one to make the fatal allocation.
What I said above doesn't seem to hold for @Victorlouisdg's code; it crashes elsewhere and I'm not sure where. FWIW, I did find that much less memory is required (8GB vs 20GB) if I remove the nested jits (i.e. all the @jit decorations except for the one on grad).
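To illustrate what I mean by removing the nested jits, here's a toy stand-in (step and the loss functions are made-up examples, not the actual simulator code). Both versions compute the same value; the flat version just hands XLA one top-level computation instead of jitting every helper separately:

```python
import jax
import jax.numpy as jnp

# Nested version: every helper carries its own @jax.jit.
@jax.jit
def step(x):
    return jnp.sin(x) * x

@jax.jit
def loss_nested(x):
    for _ in range(3):
        x = step(x)
    return jnp.sum(x)

# Flat version: a single jit at the top; the inner calls are
# traced inline into one XLA computation.
def step_plain(x):
    return jnp.sin(x) * x

@jax.jit
def loss_flat(x):
    for _ in range(3):
        x = step_plain(x)
    return jnp.sum(x)

x = jnp.ones(4)
assert jnp.allclose(loss_nested(x), loss_flat(x))
```

Whether this changes peak memory on your setup is something to measure, not something I'd promise; it happened to help a lot in my case.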
@cooijmanstim thanks for digging in! Can you provide a repro for your SIGKILL?
I think I found the issue with @Victorlouisdg's code: JAX is serializing huge amounts of metadata into the XLA program generated via jit here (str(eqn) can be quite large). I'll spin up a PR to remove that metadata for now.
cc @jekbradbury
@cooijmanstim if you're able to install jax from github (I use pip install -e . inside the jax checkout), can you try pulling https://github.com/google/jax/pull/1943 in and see if that fixes your issue too?
We merged it into master last night, but didn't update pypi. I just updated pypi to jax 0.1.56, so pip install --upgrade jax should pull in #1943. (No need to update jaxlib.)
Hey that does fix my issue. Good catch! I was stuck on this for a long time, many thanks for resolving it!