Pyro: JIT for differentiable_loss

Created on 30 Nov 2018 · 3 comments · Source: pyro-ppl/pyro

Currently, we have a JIT version of loss_and_grads. It would be nice to have a JIT version of differentiable_loss too. I expect it would be similar.

Labels: jit, question


Just use pyro.ops.jit.trace:

loss = TraceEnum_ELBO(...).differentiable_loss
jit_loss = pyro.ops.jit.trace(functools.partial(loss, model, guide))

It's so simple. Thanks a lot @eb8680! I'll test it on some GP models and get back to you.

Note that pyro.ops.jit.trace is just a thin wrapper around torch.jit.trace that makes sure pyro.param statements are visible to the trace. If you get your parameters from an nn.Module and never interact with the parameter store, you might even be able to use torch.jit.trace directly.
