Pyro: [bug] `--jit` option fails for HMM example

Created on 8 Oct 2018 · 7 comments · Source: pyro-ppl/pyro

Issue Description

Running the examples/hmm.py example with the --jit option fails with the following error:

      741 Loading data
      784 ----------------------------------------
      784 Training model_1 on 229 sequences
      791 Step  Loss
Traceback (most recent call last):
  File ".../pyro/examples/hmm.py", line 294, in <module>
    main(args)
  File ".../hmm.py", line 256, in main
    loss = svi.step(sequences, lengths, args, batch_size=args.batch_size)
  File "...\pyro\infer\svi.py", line 96, in step
    loss = self.loss_and_grads(self.model, self.guide, *args, **kwargs)
  File "...\pyro\infer\traceenum_elbo.py", line 431, in loss_and_grads
    differentiable_loss = self._differentiable_loss(*args)
  File "...\pyro\ops\jit.py", line 56, in __call__
    ret = self.compiled(param_list, *args, **kwargs)
RuntimeError: Only tuples, lists and Variables supported as JIT inputs, but got Namespace

I believe this is because the command-line arguments object args is passed to svi.step as a positional argument of type Namespace, which the JIT does not accept as an input.
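To make the suspected cause concrete, here is a dependency-free sketch that mimics the input check described by the error message. The `Tensor` class and `check_jit_inputs` function are hypothetical stand-ins, not PyTorch or Pyro APIs; the sketch only assumes the rule quoted in the traceback (tuples, lists and Variables are accepted as positional inputs, anything else is rejected).

```python
import argparse


class Tensor:
    """Hypothetical stand-in for a torch tensor/Variable so the sketch runs without torch."""

    def __init__(self, data):
        self.data = data


def check_jit_inputs(*inputs):
    """Hypothetical check mirroring the error message: each positional input must be
    a tuple, a list or a tensor; anything else is rejected."""
    for x in inputs:
        if not isinstance(x, (tuple, list, Tensor)):
            raise RuntimeError(
                "Only tuples, lists and Variables supported as JIT inputs, "
                "but got " + type(x).__name__
            )


# argparse returns its parsed options as an argparse.Namespace object
parser = argparse.ArgumentParser()
parser.add_argument("--jit", action="store_true")
args = parser.parse_args(["--jit"])

sequences, lengths = Tensor([[0, 1]]), Tensor([2])
check_jit_inputs(sequences, lengths)  # tensors are accepted

try:
    check_jit_inputs(sequences, lengths, args)  # the Namespace is rejected
except RuntimeError as err:
    print(err)  # Only tuples, lists and Variables supported as JIT inputs, but got Namespace
```

Under that assumption, sequences and lengths pass while the Namespace produced by argparse is rejected, which matches the reported error.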

Environment

  • OS and python version: Windows 10, Python 3.6
  • PyTorch version: 0.4.0
  • Pyro version: 0.2.1

Code Snippet

Simply run examples/hmm.py with option --jit.


All 7 comments

There is no JIT support implemented in Pyro 0.2.1 with PyTorch 0.4.0. See the instructions to install the pre-release versions here. Note that the PyTorch JIT is not yet stable and is under active development, so things may change or break up to the release.

@jpchen I am using the dev branch, sorry for being unclear.

@eb8680 Will try, thanks for the suggestion!

Thanks for the bug report, and you are right that args is the problematic bit when we try to JIT the svi.step call.

@jpchen - Reopening this, as this is an issue on the pytorch-1.0 branch for the reason that @ahmadsalim mentioned.

I think we just need to pass the args object as a kwarg rather than a positional arg:

- loss = svi.step(sequences, lengths, args, batch_size=args.batch_size)
+ loss = svi.step(sequences, lengths, args=args, batch_size=args.batch_size)

I'll send a PR shortly.
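A minimal sketch of why passing args by keyword can avoid the error, assuming the JIT wrapper traces only positional arguments and binds keyword arguments in a closure. `jit_call`, `model` and `Tensor` are hypothetical names for illustration; this is not Pyro's pyro.ops.jit implementation.

```python
import argparse
import functools


class Tensor:
    """Hypothetical stand-in for a torch tensor so the sketch runs without torch."""

    def __init__(self, data):
        self.data = data


def jit_call(fn, *args, **kwargs):
    """Hypothetical JIT wrapper: positional arguments are treated as traced inputs
    and must be tensors, tuples or lists; keyword arguments are bound into a
    closure (via functools.partial) and never checked."""
    for arg in args:
        if not isinstance(arg, (tuple, list, Tensor)):
            raise RuntimeError("unsupported JIT input: " + type(arg).__name__)
    bound = functools.partial(fn, **kwargs)  # keyword arguments are captured, not traced
    return bound(*args)


def model(sequences, lengths, args, batch_size=None):
    # args is an ordinary parameter, so callers may pass it positionally or by keyword
    return len(sequences.data), args.jit, batch_size


args = argparse.Namespace(jit=True, batch_size=8)
seqs, lens = Tensor([[0, 1], [1, 0]]), Tensor([2, 2])

# Passing args positionally would make it a traced input and raise RuntimeError:
#   jit_call(model, seqs, lens, args, batch_size=args.batch_size)
result = jit_call(model, seqs, lens, args=args, batch_size=args.batch_size)
print(result)  # (2, True, 8)
```

Because the model declares args as an ordinary parameter, Python accepts it either positionally or by keyword, so only the call site changes. In a real tracing JIT, values captured through a closure like this would typically be treated as constants rather than traced inputs.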

@ahmadsalim: This should work on the pytorch-1.0 branch now.

I'm still seeing an error after #1445, but it's a different error:

ValueError: Expected all enumerated sample sites to share a common poutine.scale, but found 11 different scales.

It is unclear whether that's due to a bug in PyTorch or Pyro.
