Running examples/hmm.py with the --jit option fails with the following error:
741 Loading data
784 ----------------------------------------
784 Training model_1 on 229 sequences
791 Step Loss
Traceback (most recent call last):
  File ".../pyro/examples/hmm.py", line 294, in <module>
    main(args)
  File ".../hmm.py", line 256, in main
    loss = svi.step(sequences, lengths, args, batch_size=args.batch_size)
  File "...\pyro\infer\svi.py", line 96, in step
    loss = self.loss_and_grads(self.model, self.guide, *args, **kwargs)
  File "...\pyro\infer\traceenum_elbo.py", line 431, in loss_and_grads
    differentiable_loss = self._differentiable_loss(*args)
  File "...\pyro\ops\jit.py", line 56, in __call__
    ret = self.compiled(param_list, *args, **kwargs)
RuntimeError: Only tuples, lists and Variables supported as JIT inputs, but got Namespace
I believe this is because the command-line arguments object args is passed positionally as an argparse Namespace, which the JIT does not accept as an input.
To reproduce, simply run examples/hmm.py with the --jit option.
There is no JIT support implemented in Pyro 0.2.1 with PyTorch 0.4.0. See the instructions for installing the pre-release versions. Note that the PyTorch JIT is not yet stable and is under active development, so things may change or break before the release.
@jpchen I am using the dev branch, sorry for being unclear.
@eb8680 Will try, thanks for the suggestion!
Thanks for the bug report, and you are right that args is the problematic bit when we try to JIT the svi.step call.
@jpchen - Reopening this, as this is an issue on the pytorch-1.0 branch for the reason that @ahmadsalim mentioned.
I think we just need to pass the args object as a kwarg rather than a positional arg:
- loss = svi.step(sequences, lengths, args, batch_size=args.batch_size)
+ loss = svi.step(sequences, lengths, args=args, batch_size=args.batch_size)
I'll send a PR shortly.
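To illustrate why moving args to a keyword argument helps, here is a toy, torch-free sketch. It assumes (based on the traceback above, not on any documented contract) that the JIT wrapper in pyro/ops/jit.py only hands positional arguments to the tracer as inputs, while keyword arguments bypass that input check. FakeTensor, fake_jit and step are hypothetical stand-ins invented for this sketch; they are not Pyro or PyTorch APIs.

```python
from argparse import Namespace
from typing import Any, Callable


class FakeTensor:
    """Hypothetical stand-in for torch.Tensor so the sketch runs without PyTorch."""

    def __init__(self, value: float) -> None:
        self.value = value


def _check_jit_input(x: Any) -> None:
    # Mirrors the message in the traceback: only tuples, lists and tensors may be traced inputs.
    if isinstance(x, (tuple, list)):
        for item in x:
            _check_jit_input(item)
    elif not isinstance(x, FakeTensor):
        raise RuntimeError(
            "Only tuples, lists and Variables supported as JIT inputs, "
            f"but got {type(x).__name__}"
        )


def fake_jit(fn: Callable[..., Any]) -> Callable[..., Any]:
    """Hypothetical tracer wrapper: positional args are validated as traced inputs,
    keyword args are forwarded unchecked (standing in for values the real wrapper
    would capture rather than trace)."""

    def wrapper(*args: Any, **kwargs: Any) -> Any:
        for arg in args:
            _check_jit_input(arg)
        return fn(*args, **kwargs)

    return wrapper


@fake_jit
def step(sequences, lengths, args=None, batch_size=None):
    # Toy "loss": sum of the sequence values, scaled by a config field read from args.
    return sum(s.value for s in sequences) * args.scale


sequences = [FakeTensor(1.0), FakeTensor(2.0)]
lengths = FakeTensor(2.0)
cli_args = Namespace(scale=2.0, batch_size=8)

# Before the fix: the Namespace is a positional (traced) input and is rejected.
positional_error = None
try:
    step(sequences, lengths, cli_args, batch_size=cli_args.batch_size)
except RuntimeError as err:
    positional_error = str(err)

# After the fix: the Namespace is passed as a keyword argument and is never traced.
loss = step(sequences, lengths, args=cli_args, batch_size=cli_args.batch_size)
```

One caveat the toy does not model: with torch.jit.trace, non-tensor Python values used inside the traced function are generally recorded as constants, so fields read from args are effectively fixed at the first compilation.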
@ahmadsalim: This should work on the pytorch-1.0 branch now.
I'm still seeing an error after #1445, but it's a different error:
ValueError: Expected all enumerated sample sites to share a common poutine.scale, but found 11 different scales.
It is unclear whether that's due to a bug in PyTorch or Pyro.