Howdy! Is there an easy way (or a hard way) to configure Ax so that BoTorch runs on GPU? I'm optimizing a cheap function with lots of trials, so I'm spending many hours of CPU time in total waiting for Ax to generate candidates. The docs seem to say it's possible, but I haven't found anything straightforward like a 'device' flag.
I also know GPUs aren't magic bullets, so maybe this wouldn't get the perf benefits I'm hoping for? I'm guessing that for large GPs at least the O(N^3) covariance matrix inversion could run much faster on GPU and make up for any copying overhead.
Hi @leopd, what API are you using (https://ax.dev/docs/api.html)? Ax certainly supports running on the GPU, but we may not expose this in a very convenient way for the simpler APIs at this point (which means we should fix that).
> I'm guessing that for large GPs at least the O(N^3) covariance matrix inversion could run much faster on GPU and make up for any copying overhead.
Yes, this can certainly help speed things up. For larger training data sizes we also use GPyTorch's linear conjugate gradient (CG) method to solve the linear system. It only uses matrix-vector multiplications and therefore benefits a lot from running on a GPU.
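To make the point about CG concrete, here is a minimal pure-Python sketch of the textbook conjugate gradient iteration. It is not GPyTorch's implementation (which is batched and preconditioned); the function name `conjugate_gradient`, the `matvec` callback, and the tiny 2x2 system are all made up for illustration. The thing to notice is that the matrix is only ever touched through `matvec`, and that product is the operation a GPU parallelizes well.

```python
from typing import Callable, List

Vector = List[float]


def dot(a: Vector, b: Vector) -> float:
    """Inner product of two vectors."""
    return sum(x * y for x, y in zip(a, b))


def conjugate_gradient(
    matvec: Callable[[Vector], Vector],
    b: Vector,
    tol: float = 1e-10,
    max_iter: int = 1000,
) -> Vector:
    """Solve A x = b for a symmetric positive-definite A.

    A is never formed or factorized here; it is only accessed
    through `matvec`, which returns the product A @ v.
    """
    x = [0.0] * len(b)
    r = list(b)          # residual b - A x, with x = 0 initially
    p = list(r)          # current search direction
    rs_old = dot(r, r)
    for _ in range(max_iter):
        if rs_old ** 0.5 < tol:
            break
        Ap = matvec(p)   # the only place the matrix is used
        alpha = rs_old / dot(p, Ap)
        x = [xi + alpha * pi for xi, pi in zip(x, p)]
        r = [ri - alpha * api for ri, api in zip(r, Ap)]
        rs_new = dot(r, r)
        p = [ri + (rs_new / rs_old) * pi for ri, pi in zip(r, p)]
        rs_old = rs_new
    return x


# Hypothetical 2x2 stand-in for a (much larger) kernel matrix.
K = [[4.0, 1.0], [1.0, 3.0]]
b = [1.0, 2.0]


def k_matvec(v: Vector) -> Vector:
    return [dot(row, v) for row in K]


x = conjugate_gradient(k_matvec, b)
```

Each iteration costs one matrix-vector product plus a few vector updates, so the total work depends on how many iterations are needed rather than on a full O(N^3) factorization.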
I'm mostly using the "service API" but sometimes the "loop API" to get started. Haven't really dug into the "developer API".
Got it. One user-facing solution here would be to just pass optional kwargs torch_device and torch_dtype into the AxClient constructor. Would that be reasonable?
Alternatively, once https://github.com/pytorch/pytorch/issues/27878 goes in, we can also have AxClient use the global context.
Hi, @leopd! Here is how you can set correct torch_device and torch_dtype for a Service API (AxClient) optimization currently: https://gist.github.com/lena-kashtelyan/4ad4e19e8e8bb12de7ef5a27659625a2. I offer two solutions there, one a bit more principled and one a bit faster and hackier : )
We will also consider adding these settings to our user-facing APIs (e.g. as kwargs to AxClient, as @Balandat suggested above), but in the meantime the solutions in my notebook should be a fine workaround for you.
Closing the issue as my notebook should cover the question asked; @leopd, feel free to reopen for any follow-up!
Thanks, @lena-kashtelyan, that's exactly what I need!
Something more convenient would be nice, but there's always a trade-off in added complexity when expanding the API surface. Personally, I'd love to see pytorch/pytorch#27878 go in and make this universally configurable for all libraries, not just this one.