Jax: Internal runtime error?

Created on 11 Nov 2020 · 7 Comments · Source: google/jax

Where should I start looking to debug this error?

Traceback (most recent call last):
  File "<string>", line 1, in <module>
  File "/home/neil/.cache/pypoetry/virtualenvs/cmm-tspD8tmv-py3.8/lib/python3.8/site-packages/typer/main.py", line 214, in __call__
    return get_command(self)(*args, **kwargs)
  File "/home/neil/.cache/pypoetry/virtualenvs/cmm-tspD8tmv-py3.8/lib/python3.8/site-packages/click/core.py", line 829, in __call__
    return self.main(*args, **kwargs)
  File "/home/neil/.cache/pypoetry/virtualenvs/cmm-tspD8tmv-py3.8/lib/python3.8/site-packages/click/core.py", line 782, in main
    rv = self.invoke(ctx)
  File "/home/neil/.cache/pypoetry/virtualenvs/cmm-tspD8tmv-py3.8/lib/python3.8/site-packages/click/core.py", line 1259, in invoke
    return _process_result(sub_ctx.command.invoke(sub_ctx))
  File "/home/neil/.cache/pypoetry/virtualenvs/cmm-tspD8tmv-py3.8/lib/python3.8/site-packages/click/core.py", line 1066, in invoke
    return ctx.invoke(self.callback, **ctx.params)
  File "/home/neil/.cache/pypoetry/virtualenvs/cmm-tspD8tmv-py3.8/lib/python3.8/site-packages/click/core.py", line 610, in invoke
    return callback(*args, **kwargs)
  File "/home/neil/.cache/pypoetry/virtualenvs/cmm-tspD8tmv-py3.8/lib/python3.8/site-packages/typer/main.py", line 497, in wrapper
    return callback(**use_params)  # type: ignore
  File "/home/neil/src/cmm/cmm/demo/pooling.py", line 63, in pooling
    p_trajectory = solution.train(1000)
  File "/home/neil/src/cmm/cmm/structure/solution/solution.py", line 111, in train
    augmented, trajectory = method(None,
  File "/home/neil/.cache/pypoetry/virtualenvs/cmm-tspD8tmv-py3.8/lib/python3.8/site-packages/jax/traceback_util.py", line 133, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
  File "/home/neil/.cache/pypoetry/virtualenvs/cmm-tspD8tmv-py3.8/lib/python3.8/site-packages/jax/api.py", line 217, in f_jitted
    out = xla.xla_call(
  File "/home/neil/.cache/pypoetry/virtualenvs/cmm-tspD8tmv-py3.8/lib/python3.8/site-packages/jax/core.py", line 1177, in bind
    return call_bind(self, fun, *args, **params)
  File "/home/neil/.cache/pypoetry/virtualenvs/cmm-tspD8tmv-py3.8/lib/python3.8/site-packages/jax/core.py", line 1168, in call_bind
    outs = primitive.process(top_trace, fun, tracers, params)
  File "/home/neil/.cache/pypoetry/virtualenvs/cmm-tspD8tmv-py3.8/lib/python3.8/site-packages/jax/core.py", line 1180, in process
    return trace.process_call(self, fun, tracers, params)
  File "/home/neil/.cache/pypoetry/virtualenvs/cmm-tspD8tmv-py3.8/lib/python3.8/site-packages/jax/core.py", line 579, in process_call
    return primitive.impl(f, *tracers, **params)
  File "/home/neil/.cache/pypoetry/virtualenvs/cmm-tspD8tmv-py3.8/lib/python3.8/site-packages/jax/interpreters/xla.py", line 556, in _xla_call_impl
    compiled_fun = _xla_callable(fun, device, backend, name, donated_invars,
  File "/home/neil/.cache/pypoetry/virtualenvs/cmm-tspD8tmv-py3.8/lib/python3.8/site-packages/jax/linear_util.py", line 251, in memoized_fun
    ans = call(fun, *args)
  File "/home/neil/.cache/pypoetry/virtualenvs/cmm-tspD8tmv-py3.8/lib/python3.8/site-packages/jax/interpreters/xla.py", line 703, in _xla_callable
    compiled = backend_compile(backend, built, options)
  File "/home/neil/.cache/pypoetry/virtualenvs/cmm-tspD8tmv-py3.8/lib/python3.8/site-packages/jax/interpreters/xla.py", line 344, in backend_compile
    return backend.compile(built_c, compile_options=options)
RuntimeError: Internal: Expected instruction to have shape equal to f32[7], actual shape is f32[6]:
%pad.231.clone = f32[6]{0} pad(f32[1]{0} %reshape.13664, f32[] %constant.18774), padding=6_0, metadata={op_type="scatter" op_name="jit(sample_trajectory)/scan/while/body/while/body/scatter[ dimension_numbers=ScatterDimensionNumbers(update_window_dims=(0,), inserted_window_dims=(0,), scatter_dims_to_operand_dims=(0,))\n                                                           indices_are_sorted=False\n                                                           unique_indices=False\n                                                           update_consts=(  ) ]" source_file="/home/neil/src/tjax/tjax/shims.py" source_line=53}

Failed after simplification
bug

All 7 comments

This is most likely an XLA compiler bug. We can report it to the XLA folks, but we'd need a way to reproduce it. Ideally that would be a small, self-contained Python reproduction, but we may also be able to reproduce it from the XLA IR, which you can dump by setting the environment variable XLA_FLAGS=--xla_dump_to=/tmp/somewhere and then sharing the contents of that directory.
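As a rough sketch of the dump step described above, the flag can also be set from inside the script rather than from the shell. The helper name `enable_xla_dump` is hypothetical (it is not part of JAX); only the `--xla_dump_to` flag and the example directory come from the comment above. It is safest to call it before `import jax`, so the variable is already in the environment when the XLA backend starts.

```python
import os


def enable_xla_dump(dump_dir, env=None):
    """Append --xla_dump_to=<dump_dir> to XLA_FLAGS, keeping other flags.

    Hypothetical helper for illustration; not part of JAX or XLA.
    """
    env = os.environ if env is None else env
    existing = env.get("XLA_FLAGS", "").split()
    # Drop any earlier dump target so only the new directory is used.
    kept = [flag for flag in existing if not flag.startswith("--xla_dump_to=")]
    kept.append(f"--xla_dump_to={dump_dir}")
    env["XLA_FLAGS"] = " ".join(kept)
    return env["XLA_FLAGS"]
```

Calling `enable_xla_dump("/tmp/somewhere")` at the top of the failing script, before JAX is imported, should have the same effect as exporting the variable in the shell.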

@hawkinsp Thanks for taking a look at this! Here's the XLA dump you requested:
xla_bad_instruction_shape.tar.gz

It would take me a few days of work to trim down 3500 lines to a minimum working example, but I've done it for other bugs, and I'll happily do it again if the XLA dump isn't enough. Please let me know.

Which jaxlib version are you using? We just uploaded jaxlib==0.1.57, and it's worth trying with that in case this bug has already been fixed.

@mattjj Thanks, but I'm on jaxlib 0.1.57 and my JAX is master.

Congrats on (probably) finding an XLA bug then :)

I didn't look at this very hard, but I think the necessary HLO module isn't in the dumps you attached. There appears to be a .ll file (i.e., LLVM IR) corresponding to each of the *_before*.txt files, but if the HLO -> LLVM lowering had failed, we'd expect to see an HLO module without its matching LLVM IR.
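The check described above (an HLO text dump with no matching LLVM IR) can be sketched as a small script. This is only an illustration under a loud assumption: that all files belonging to one module share the part of the file name before the first ".", for example a prefix such as "module_0001". The function name `modules_missing_llvm_ir` is hypothetical, and the grouping rule may need adjusting to match the names in your dump directory.

```python
import os
from collections import defaultdict


def modules_missing_llvm_ir(filenames):
    """Return module prefixes that have a .txt dump but no .ll file.

    ASSUMPTION: the text before the first "." in a file name identifies
    the module (e.g. "module_0001"). This naming rule is a guess, not a
    documented guarantee.
    """
    suffixes = defaultdict(set)
    for name in filenames:
        prefix = name.split(".", 1)[0]
        # Record the final extension (".txt", ".ll", ...) seen for this module.
        suffixes[prefix].add(os.path.splitext(name)[1])
    return sorted(
        prefix
        for prefix, seen in suffixes.items()
        if ".txt" in seen and ".ll" not in seen
    )
```

For example, `modules_missing_llvm_ir(os.listdir("/tmp/somewhere"))` would list the modules worth inspecting first. An empty result would be consistent with the observation above that every dumped module already has LLVM IR.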

@hawkinsp How do I get the HLO module? I just followed your instructions and then tarred the entire folder that was produced.
