What should I start looking at to debug this error?
Traceback (most recent call last):
  File "<string>", line 1, in <module>
  File "/home/neil/.cache/pypoetry/virtualenvs/cmm-tspD8tmv-py3.8/lib/python3.8/site-packages/typer/main.py", line 214, in __call__
    return get_command(self)(*args, **kwargs)
  File "/home/neil/.cache/pypoetry/virtualenvs/cmm-tspD8tmv-py3.8/lib/python3.8/site-packages/click/core.py", line 829, in __call__
    return self.main(*args, **kwargs)
  File "/home/neil/.cache/pypoetry/virtualenvs/cmm-tspD8tmv-py3.8/lib/python3.8/site-packages/click/core.py", line 782, in main
    rv = self.invoke(ctx)
  File "/home/neil/.cache/pypoetry/virtualenvs/cmm-tspD8tmv-py3.8/lib/python3.8/site-packages/click/core.py", line 1259, in invoke
    return _process_result(sub_ctx.command.invoke(sub_ctx))
  File "/home/neil/.cache/pypoetry/virtualenvs/cmm-tspD8tmv-py3.8/lib/python3.8/site-packages/click/core.py", line 1066, in invoke
    return ctx.invoke(self.callback, **ctx.params)
  File "/home/neil/.cache/pypoetry/virtualenvs/cmm-tspD8tmv-py3.8/lib/python3.8/site-packages/click/core.py", line 610, in invoke
    return callback(*args, **kwargs)
  File "/home/neil/.cache/pypoetry/virtualenvs/cmm-tspD8tmv-py3.8/lib/python3.8/site-packages/typer/main.py", line 497, in wrapper
    return callback(**use_params) # type: ignore
  File "/home/neil/src/cmm/cmm/demo/pooling.py", line 63, in pooling
    p_trajectory = solution.train(1000)
  File "/home/neil/src/cmm/cmm/structure/solution/solution.py", line 111, in train
    augmented, trajectory = method(None,
  File "/home/neil/.cache/pypoetry/virtualenvs/cmm-tspD8tmv-py3.8/lib/python3.8/site-packages/jax/traceback_util.py", line 133, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
  File "/home/neil/.cache/pypoetry/virtualenvs/cmm-tspD8tmv-py3.8/lib/python3.8/site-packages/jax/api.py", line 217, in f_jitted
    out = xla.xla_call(
  File "/home/neil/.cache/pypoetry/virtualenvs/cmm-tspD8tmv-py3.8/lib/python3.8/site-packages/jax/core.py", line 1177, in bind
    return call_bind(self, fun, *args, **params)
  File "/home/neil/.cache/pypoetry/virtualenvs/cmm-tspD8tmv-py3.8/lib/python3.8/site-packages/jax/core.py", line 1168, in call_bind
    outs = primitive.process(top_trace, fun, tracers, params)
  File "/home/neil/.cache/pypoetry/virtualenvs/cmm-tspD8tmv-py3.8/lib/python3.8/site-packages/jax/core.py", line 1180, in process
    return trace.process_call(self, fun, tracers, params)
  File "/home/neil/.cache/pypoetry/virtualenvs/cmm-tspD8tmv-py3.8/lib/python3.8/site-packages/jax/core.py", line 579, in process_call
    return primitive.impl(f, *tracers, **params)
  File "/home/neil/.cache/pypoetry/virtualenvs/cmm-tspD8tmv-py3.8/lib/python3.8/site-packages/jax/interpreters/xla.py", line 556, in _xla_call_impl
    compiled_fun = _xla_callable(fun, device, backend, name, donated_invars,
  File "/home/neil/.cache/pypoetry/virtualenvs/cmm-tspD8tmv-py3.8/lib/python3.8/site-packages/jax/linear_util.py", line 251, in memoized_fun
    ans = call(fun, *args)
  File "/home/neil/.cache/pypoetry/virtualenvs/cmm-tspD8tmv-py3.8/lib/python3.8/site-packages/jax/interpreters/xla.py", line 703, in _xla_callable
    compiled = backend_compile(backend, built, options)
  File "/home/neil/.cache/pypoetry/virtualenvs/cmm-tspD8tmv-py3.8/lib/python3.8/site-packages/jax/interpreters/xla.py", line 344, in backend_compile
    return backend.compile(built_c, compile_options=options)
RuntimeError: Internal: Expected instruction to have shape equal to f32[7], actual shape is f32[6]:
%pad.231.clone = f32[6]{0} pad(f32[1]{0} %reshape.13664, f32[] %constant.18774), padding=6_0, metadata={op_type="scatter" op_name="jit(sample_trajectory)/scan/while/body/while/body/scatter[ dimension_numbers=ScatterDimensionNumbers(update_window_dims=(0,), inserted_window_dims=(0,), scatter_dims_to_operand_dims=(0,))\n indices_are_sorted=False\n unique_indices=False\n update_consts=( ) ]" source_file="/home/neil/src/tjax/tjax/shims.py" source_line=53}
Failed after simplification
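The verifier's complaint can be checked by hand. In HLO text, `padding=6_0` on a rank-1 `pad` means 6 elements of low (leading) padding and 0 of high (trailing) padding, so padding the `f32[1]` operand should yield 1 + 6 + 0 = 7 elements. That is the `f32[7]` the verifier expects, while the instruction itself declares `f32[6]`. The sketch below encodes XLA's per-dimension pad-size rule; `pad_output_size` is a hypothetical helper, not an XLA or JAX API. It only confirms the arithmetic and does not identify which pass produced the bad shape, although "Failed after simplification" hints that the mismatch appeared after an HLO simplification pass rather than in the graph JAX emitted.

```python
def pad_output_size(operand: int, low: int, high: int, interior: int = 0) -> int:
    """Size of one dimension after an XLA Pad.

    `low`/`high` add (or, if negative, remove) elements at the edges;
    `interior` inserts that many elements between each adjacent pair.
    """
    if operand == 0:
        # No elements means no gaps for interior padding to fill.
        return low + high
    return low + operand + (operand - 1) * interior + high


# %pad.231.clone pads an f32[1] operand with padding=6_0 (low=6, high=0).
expected = pad_output_size(operand=1, low=6, high=0)
declared = 6  # the f32[6] the instruction claims to produce
print(expected, declared)  # 7 6
```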
This seems most likely to be an XLA compiler bug. We can report it to the XLA folks, but we'd need a way to reproduce it. Ideally this would be a small, self-contained Python reproduction, but we may also be able to reproduce it from the XLA IR, which you can dump by setting the environment variable XLA_FLAGS=--xla_dump_to=/tmp/somewhere and sharing the resulting files.
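If it's easier to set the flag from inside the Python entry point than from the shell, a minimal sketch is below. It assumes the script controls when jax is first imported; `enable_xla_dump` is a hypothetical helper, not part of JAX. It appends to any existing XLA_FLAGS rather than overwriting them.

```python
import os
import sys
import warnings


def enable_xla_dump(dump_dir: str) -> str:
    """Append --xla_dump_to=<dump_dir> to XLA_FLAGS and return the new value."""
    # XLA parses XLA_FLAGS once, typically when the first computation is
    # compiled; setting it before importing jax is the simplest way to be
    # sure it is picked up.
    if "jax" in sys.modules:
        warnings.warn("jax is already imported; XLA_FLAGS may be ignored")
    flag = f"--xla_dump_to={dump_dir}"
    existing = os.environ.get("XLA_FLAGS", "")
    os.environ["XLA_FLAGS"] = f"{existing} {flag}".strip()
    return os.environ["XLA_FLAGS"]
```

Calling `enable_xla_dump("/tmp/somewhere")` at the very top of the entry point, before `import jax`, should have the same effect as prefixing the shell command with `XLA_FLAGS=--xla_dump_to=/tmp/somewhere`.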
@hawkinsp Thanks for taking a look at this! Here's the XLA dump you requested:
xla_bad_instruction_shape.tar.gz
It would take me a few days of work to trim the 3500 lines down to a minimal working example, but I've done it for other bugs, and I'll happily do it again if the XLA dump isn't enough. Please let me know.
Which jaxlib version are you using? We just uploaded jaxlib==0.1.57, and it's worth trying with that in case this bug has already been fixed.
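One way to answer the version question without guessing which virtualenv is active is to query the installed package metadata. This sketch uses the standard library's `importlib.metadata` (Python 3.8+, matching the py3.8 virtualenv in the traceback); `installed_version` is a hypothetical helper name.

```python
from importlib import metadata
from typing import Optional


def installed_version(dist_name: str) -> Optional[str]:
    """Return the installed version of a distribution, or None if absent."""
    try:
        return metadata.version(dist_name)
    except metadata.PackageNotFoundError:
        return None


for name in ("jax", "jaxlib"):
    print(name, installed_version(name) or "not installed")
```

For a JAX checkout installed from master, this may only report the version string declared in the checkout, not the commit, so the commit hash is still worth mentioning separately.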
@mattjj Thanks, but I'm already on jaxlib 0.1.57, and my JAX is built from master.
Congrats on (probably) finding an XLA bug then :)
I didn't look at this very hard, but I think the HLO module that fails to compile isn't in the dumps you attached. There appear to be .ll files (i.e., LLVM IR) corresponding to all of the *_before*.txt files, but if HLO -> LLVM lowering had failed, we'd expect to see an HLO module without its matching LLVM IR.
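The check described here, looking for an HLO dump with no matching LLVM IR, can be scripted. This sketch assumes a module's HLO text and `.ll` files share a `module_NNNN.` filename prefix and that HLO dumps contain "before" and end in `.txt`, as the comment above describes; that naming is an assumption about the dump layout, so adjust the regex if your files differ. `modules_missing_llvm_ir` and `scan_dump_dir` are hypothetical helpers.

```python
import re
from pathlib import Path
from typing import Iterable, List, Set

# Assumed naming, e.g. "module_0042.jit_f.before_optimizations.txt"
# alongside "module_0042.jit_f.ir-no-opt.ll".
MODULE_ID = re.compile(r"^(module_\d+)\.")


def modules_missing_llvm_ir(filenames: Iterable[str]) -> List[str]:
    """Return module IDs that have a *before*.txt HLO dump but no .ll file."""
    hlo_ids: Set[str] = set()
    llvm_ids: Set[str] = set()
    for name in filenames:
        match = MODULE_ID.match(name)
        if match is None:
            continue  # not a per-module dump file
        module_id = match.group(1)
        if name.endswith(".ll"):
            llvm_ids.add(module_id)
        elif "before" in name and name.endswith(".txt"):
            hlo_ids.add(module_id)
    return sorted(hlo_ids - llvm_ids)


def scan_dump_dir(dump_dir: str) -> List[str]:
    """Apply the check to every file in an XLA dump directory."""
    return modules_missing_llvm_ir(p.name for p in Path(dump_dir).iterdir())
```

An empty result from `scan_dump_dir("/tmp/somewhere")` would be consistent with the situation described above: every dumped HLO module has its LLVM IR, so the module that failed was apparently never written out.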
@hawkinsp How do I get the HLO module? I just followed your instructions and then tarred the entire folder that was produced.