Currently, we have a JIT version of loss_and_grads. It would be nice to have one for differentiable_loss too. I expect the implementation would be similar?
Just use pyro.ops.jit.trace:
loss = TraceEnum_ELBO(...).differentiable_loss
jit_loss = pyro.ops.jit.trace(functools.partial(loss, model, guide))
It is so simple. Thanks a lot @eb8680! I'll test it on some GP models and get back to you.
Note that pyro.ops.jit.trace is just a thin wrapper around torch.jit.trace that makes sure pyro.param statements are visible. If you get your parameters from an nn.Module and never interact with the parameter store, you might even be able to use torch.jit.trace directly.
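For the nn.Module case, here is a minimal sketch of what calling torch.jit.trace directly might look like. It does not use Pyro at all: QuadraticLoss, its weight parameter, and example_x are hypothetical stand-ins for a real differentiable loss, chosen only to show that the traced output matches the eager one and stays differentiable.

```python
import torch
from torch import nn


class QuadraticLoss(nn.Module):
    """Hypothetical toy loss whose only parameter lives on the module."""

    def __init__(self):
        super().__init__()
        self.weight = nn.Parameter(torch.tensor([1.0, 2.0]))

    def forward(self, x):
        # Scalar loss: sum of squared weighted inputs.
        return ((self.weight * x) ** 2).sum()


module = QuadraticLoss()
example_x = torch.tensor([3.0, 4.0])  # example input used only for tracing

# Trace the module once; later calls reuse the recorded graph.
jit_loss = torch.jit.trace(module, (example_x,))

eager_value = module(example_x)
traced_value = jit_loss(example_x)
```

Because tracing records the operations for one example input, this only fits losses without data-dependent Python control flow. If the model or guide calls pyro.param, stick with pyro.ops.jit.trace as suggested above.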