Dear jax team,
I'd like to use jax alongside other tools running on the GPU in the same pipeline. Is there a way to "encapsulate" the usage of jax/XLA so that the GPU is freed afterwards? Even if I had to copy the DeviceArrays over into numpy manually.
Maybe something like:

```python
with jax.Block():
    result = some_jitted_fun(a, b, c)
    result = onp.copy(result)
```
I can imagine that the (design of the) handling of objects and their GPU memory is not straightforward, if not practically impossible. Could I at least tell jax to use GPU memory only incrementally instead of filling it completely on import?
Nice idea! We've had a few related requests recently, and I think we can provide better tools here. (Actually, JAX is pretty tiny, and the way it handles GPU memory (and all backend memory) is pretty straightforward, so we should have the right tools at our disposal!)
As to not filling GPU memory completely on import, though: have you taken a look at the GPU memory allocation note in the docs? You can prevent JAX from allocating everything up-front, or even control the fraction of GPU memory it allocates up-front. Could that help?
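For example, a minimal sketch (these flags must be set before `jax` is first imported in the process, and the fraction value is just an illustration):

```python
import os

# Must be set before jax is imported.
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"  # allocate on demand
# or, to preallocate only a fraction of GPU memory up-front:
# os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "0.5"

import jax.numpy as np  # the allocator now follows the flags above
```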
That's great to hear, thank you! A programmatic solution sometime in the future would be very cool, but I think `XLA_PYTHON_CLIENT_PREALLOCATE=false` could do the trick for now. But I assume it only affects pre-allocation, not freeing the memory afterwards?
I use jax in designated "blocks" in the pipeline, so freeing and re-allocating memory _should_ not be that bad for performance in my use case (if it is possible).
Edit: In my use case, the memory is not freed with `XLA_PYTHON_CLIENT_PREALLOCATE=false`. This leaves the GPU usable in principle by other tools, but it's not a great solution, TBH. So the programmatic solution would be very cool! :wink:
> But I assume it only affects pre-allocation, not freeing the memory afterwards?
Device memory for an array ought to be freed once all Python references to it drop, i.e. upon destruction of any corresponding DeviceArray. You could encourage this explicitly with `del my_device_array`, if Python scope isn't already lined up with your pipeline "blocks."
In your example, the line

```python
result = onp.copy(result)
```

will drop the only reference to a DeviceArray (from the previous line), and should clear the device memory associated with the value of `some_jitted_fun(a, b, c)`, for the same reason.
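A minimal sketch of that pattern (reusing the hypothetical `some_jitted_fun` from your example):

```python
import numpy as onp

device_result = some_jitted_fun(a, b, c)  # result lives in GPU memory
host_result = onp.copy(device_result)     # explicit copy to the host
del device_result                         # drop the last reference so the device buffer can be freed
```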
Thank you @froystig, that sounds like a great Pythonic solution! However, I don't see that behavior. In the snippet below, GPU memory is not freed after `del arr`. Am I missing something?
```python
%env CUDA_DEVICE_ORDER=PCI_BUS_ID
%env CUDA_VISIBLE_DEVICES=0
%env XLA_PYTHON_CLIENT_PREALLOCATE=false

import jax.numpy as np

arr = np.arange(int(1e9))
arr += 1
del arr  # GPU memory is still not freed here
```
`XLA_PYTHON_CLIENT_PREALLOCATE=false` only affects pre-allocation, so as you've observed, memory will never be released by the allocator (although it will be available for other DeviceArrays in the same process).
You could try setting `XLA_PYTHON_CLIENT_ALLOCATOR=platform` instead. This will be slower, but it will actually deallocate when a DeviceArray's buffer is released. I forgot to mention this in the GPU memory allocation note, my bad! I'll update the note to include this.
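A sketch of the snippet above with that flag (again, it must be set before `jax` is first imported, in a fresh process):

```python
import os

# Must be set before jax is imported.
os.environ["XLA_PYTHON_CLIENT_ALLOCATOR"] = "platform"  # slower, frees eagerly

import jax.numpy as np

arr = np.arange(int(1e9))
arr += 1
del arr  # with the platform allocator, the buffer is actually returned to the GPU
```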
Thank you @skye, that solved it!
I'm not against closing this, but since the runtime does increase quite a bit (17 s vs. 13 s in my case, about 30% slower), I think there is at least _some_ demand for block-wise preallocation with the memory cleared afterwards (similar to my crude snippet at the beginning).
That is, instead of instantly clearing memory every time, say, a function returns. For a nested framework this makes a noticeable performance difference.
Would it work to have a (slow) function call that tries to free unused memory? I say "try" because the default allocator allocates large regions, with multiple DeviceArrays possibly occupying a single region, so freeing one DeviceArray may not allow us to free the whole region. We could also have an even slower function that copies DeviceArrays around to free as much memory as possible.
I'm also not sure what you mean exactly by blockwise preallocation, can you explain the API you have in mind?
Thank you for your feedback! If "slow" means on the order of 0.1 s to 1 s, that would be great.
Although I would not call my crude idea an API, I thought of a `with` block that enables behavior like `XLA_PYTHON_CLIENT_PREALLOCATE=false` with `XLA_PYTHON_CLIENT_ALLOCATOR=default` inside the block. After the block, the memory can be garbage-collected as with `XLA_PYTHON_CLIENT_ALLOCATOR=platform`. Maybe this would be similar to your first suggestion; a hypothetical sketch is below.
With my very limited understanding of XLA memory handling, this would incrementally use GPU memory without the need to free every little piece after each small inner function call (avoiding fragmentation?), resulting in better performance. When done with the work and left with a handful of result arrays, copy them to the host and free the GPU. Please feel free to tell me if this does not make sense!
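A purely hypothetical sketch of that idea (nothing here is a real JAX API; `jax.memory_block` and `some_jitted_fun` are illustrative only):

```python
import numpy as onp

# Hypothetical: fast 'default'-style region allocation inside the block...
with jax.memory_block():
    result = some_jitted_fun(a, b, c)
    result = onp.copy(result)  # copy what we need back to the host
# ...and on exit, all device buffers allocated inside the block are
# released at once, as the 'platform' allocator would do per-buffer.
```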
I was asked to post a utility function I use to delete DeviceArrays in Colab:
```python
import gc

import jax


def reset_device_memory(delete_objs=True):
  """Free all tracked DeviceArray memory and delete objects.

  Args:
    delete_objs: bool: whether to delete all live DeviceValues or just free.

  Returns:
    Number of DeviceArrays that were manually freed.
  """
  dvals = (x for x in gc.get_objects() if isinstance(x, jax.xla.DeviceValue))
  n_deleted = 0
  for dv in dvals:
    # Skip DeviceConstants; only delete regular DeviceValues.
    if not isinstance(dv, jax.xla.DeviceConstant):
      try:
        dv._check_if_deleted()  # pylint: disable=protected-access
        dv.delete()
        n_deleted += 1
      except ValueError:
        pass
    if delete_objs:
      del dv
  del dvals
  gc.collect()
  return n_deleted
```
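For example (hypothetical workload; `some_jitted_fun` is a stand-in), copy out what you need first, then reset:

```python
import numpy as onp

result = onp.asarray(some_jitted_fun(a, b, c))  # keep a host copy of the output
n = reset_device_memory()  # free every live DeviceArray buffer on the device
print("freed %d DeviceArrays" % n)
```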
This and some memory-reporting utils are in a gist:
https://gist.github.com/levskaya/37f72b76bd5c72f9e5ce48ce154a9246
and in a public Colab:
https://colab.research.google.com/drive/1odOdMbbp-47WyDhjIfTDWukOBTSUt5Q6
FYI: the memory counting in the linked gist is no longer completely accurate, because e.g. `jax.numpy.zeros` does not actually allocate device memory but is counted anyway (you can make the reported memory usage arbitrarily large).