Dear JAX team,
thanks for all the amazing work you're doing!
I'm using jax.numpy.linalg.matrix_power but am running into an issue when trying to use it with jit. Here's a minimal example:
import jax.numpy as jnp
def minimal_example(F, x, n):
    """Return ``F**n @ x @ (F**n).T`` using ``jnp.linalg.matrix_power``.

    Args:
        F: square matrix to raise to the ``n``-th power.
        x: array multiplied between the two matrix powers.
        n: exponent; ``matrix_power`` converts it with ``operator.index``,
            so it has to be a concrete integer rather than a traced value.
    """
    return jnp.linalg.matrix_power(F, n) @ x @ jnp.linalg.matrix_power(F, n).T
from jax import jit  # the snippet calls jit() but only imported jax.numpy

sample_F = jnp.eye(2)
sample_x = jnp.array([2, 1])
sample_n = 2

# Works fine without JIT: n is a plain Python int here.
minimal_example(sample_F, sample_x, sample_n)

# Under jit all three arguments are traced, including the exponent n, and
# this call raises "TypeError: exponent must be an integer".
jit_fun = jit(minimal_example)
jit_fun(sample_F, sample_x, sample_n)
The last line produces an error. Here's the full trace:
FilteredStackTrace Traceback (most recent call last)
<ipython-input-163-a3fbaea19733> in <module>
----> 1 jit_fun(sample_F, sample_x, sample_n)
<ipython-input-159-ae3772d83a8f> in minimal_example(F, x, n)
2
----> 3 return jnp.linalg.matrix_power(F, n) @ x @ jnp.linalg.matrix_power(F, n).T
~/miniconda3/envs/jax/lib/python3.7/site-packages/jax/numpy/linalg.py in matrix_power(a, n)
75 except TypeError as err:
---> 76 raise TypeError("exponent must be an integer, got {}".format(n)) from err
77
FilteredStackTrace: TypeError: exponent must be an integer, got Traced<ShapedArray(int64[], weak_type=True)>with<DynamicJaxprTrace(level=0/1)>
The stack trace above excludes JAX-internal frames.
The following is the original exception that occurred, unmodified.
--------------------
The above exception was the direct cause of the following exception:
TypeError Traceback (most recent call last)
~/miniconda3/envs/jax/lib/python3.7/site-packages/jax/numpy/linalg.py in matrix_power(a, n)
73 try:
---> 74 n = operator.index(n)
75 except TypeError as err:
TypeError: 'DynamicJaxprTracer' object cannot be interpreted as an integer
The above exception was the direct cause of the following exception:
TypeError Traceback (most recent call last)
<ipython-input-163-a3fbaea19733> in <module>
----> 1 jit_fun(sample_F, sample_x, sample_n)
~/miniconda3/envs/jax/lib/python3.7/site-packages/jax/traceback_util.py in reraise_with_filtered_traceback(*args, **kwargs)
137 def reraise_with_filtered_traceback(*args, **kwargs):
138 try:
--> 139 return fun(*args, **kwargs)
140 except Exception as e:
141 if not is_under_reraiser(e):
~/miniconda3/envs/jax/lib/python3.7/site-packages/jax/api.py in f_jitted(*args, **kwargs)
215 backend=backend,
216 name=flat_fun.__name__,
--> 217 donated_invars=donated_invars)
218 return tree_unflatten(out_tree(), out)
219
~/miniconda3/envs/jax/lib/python3.7/site-packages/jax/core.py in bind(self, fun, *args, **params)
1160
1161 def bind(self, fun, *args, **params):
-> 1162 return call_bind(self, fun, *args, **params)
1163
1164 def process(self, trace, fun, tracers, params):
~/miniconda3/envs/jax/lib/python3.7/site-packages/jax/core.py in call_bind(primitive, fun, *args, **params)
1151 tracers = map(top_trace.full_raise, args)
1152 with maybe_new_sublevel(top_trace):
-> 1153 outs = primitive.process(top_trace, fun, tracers, params)
1154 return map(full_lower, apply_todos(env_trace_todo(), outs))
1155
~/miniconda3/envs/jax/lib/python3.7/site-packages/jax/core.py in process(self, trace, fun, tracers, params)
1163
1164 def process(self, trace, fun, tracers, params):
-> 1165 return trace.process_call(self, fun, tracers, params)
1166
1167 def post_process(self, trace, out_tracers, params):
~/miniconda3/envs/jax/lib/python3.7/site-packages/jax/core.py in process_call(self, primitive, f, tracers, params)
573
574 def process_call(self, primitive, f, tracers, params):
--> 575 return primitive.impl(f, *tracers, **params)
576 process_map = process_call
577
~/miniconda3/envs/jax/lib/python3.7/site-packages/jax/interpreters/xla.py in _xla_call_impl(fun, device, backend, name, donated_invars, *args)
555 def _xla_call_impl(fun: lu.WrappedFun, *args, device, backend, name, donated_invars):
556 compiled_fun = _xla_callable(fun, device, backend, name, donated_invars,
--> 557 *unsafe_map(arg_spec, args))
558 try:
559 return compiled_fun(*args)
~/miniconda3/envs/jax/lib/python3.7/site-packages/jax/linear_util.py in memoized_fun(fun, *args)
245 fun.populate_stores(stores)
246 else:
--> 247 ans = call(fun, *args)
248 cache[key] = (ans, fun.stores)
249
~/miniconda3/envs/jax/lib/python3.7/site-packages/jax/interpreters/xla.py in _xla_callable(fun, device, backend, name, donated_invars, *arg_specs)
630 abstract_args, arg_devices = unzip2(arg_specs)
631 if config.omnistaging_enabled:
--> 632 jaxpr, out_avals, consts = pe.trace_to_jaxpr_final(fun, abstract_args)
633 if any(isinstance(c, core.Tracer) for c in consts):
634 raise core.UnexpectedTracerError("Encountered an unexpected tracer.")
~/miniconda3/envs/jax/lib/python3.7/site-packages/jax/interpreters/partial_eval.py in trace_to_jaxpr_final(fun, in_avals)
1036 main.source_info = fun_sourceinfo(fun.f) # type: ignore
1037 main.jaxpr_stack = () # type: ignore
-> 1038 jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(fun, main, in_avals)
1039 del main
1040 return jaxpr, out_avals, consts
~/miniconda3/envs/jax/lib/python3.7/site-packages/jax/interpreters/partial_eval.py in trace_to_subjaxpr_dynamic(fun, main, in_avals)
1017 trace = DynamicJaxprTrace(main, core.cur_sublevel())
1018 in_tracers = map(trace.new_arg, in_avals)
-> 1019 ans = fun.call_wrapped(*in_tracers)
1020 out_tracers = map(trace.full_raise, ans)
1021 jaxpr, out_avals, consts = frame.to_jaxpr(in_tracers, out_tracers)
~/miniconda3/envs/jax/lib/python3.7/site-packages/jax/linear_util.py in call_wrapped(self, *args, **kwargs)
154
155 try:
--> 156 ans = self.f(*args, **dict(self.params, **kwargs))
157 except:
158 # Some transformations yield from inside context managers, so we have to
<ipython-input-159-ae3772d83a8f> in minimal_example(F, x, n)
1 def minimal_example(F, x, n):
2
----> 3 return jnp.linalg.matrix_power(F, n) @ x @ jnp.linalg.matrix_power(F, n).T
~/miniconda3/envs/jax/lib/python3.7/site-packages/jax/numpy/linalg.py in matrix_power(a, n)
74 n = operator.index(n)
75 except TypeError as err:
---> 76 raise TypeError("exponent must be an integer, got {}".format(n)) from err
77
78 if n == 0:
TypeError: exponent must be an integer, got Traced<ShapedArray(int64[], weak_type=True)>with<DynamicJaxprTrace(level=0/1)>
Is there some way around this? I suppose I could mark n as static (via static_argnums), but that would trigger a recompilation for every distinct value of n, which would be very inefficient for my application.
Thanks!
Martin, thanks for the kind words as always! It's really very encouraging.
I think we could revise the matrix_power implementation not to rely on Python control flow. That is, we can replace uses of Python control flow with lax.switch and lax.while_loop, so that we'd be able to stage it out with jit no problem. (If we use lax.while_loop for memory efficiency, we'd probably need to define a custom jvp.)
Do you use reverse-mode differentiation through this?
Thanks for the super-fast response Matt! That sounds great. Ideally I was planning to do reverse-mode auto-diff, but I realise that while_loop doesn't currently support that. I'm not sure how much slower forward-mode would be but I don't have a great number of parameters so it would probably be fine!
I think we can work out how to define a custom differentiation rule to make reverse-mode (and forward-mode) work efficiently. I mainly wanted to know if there was an easy way to get you un-stuck.
Can you use expm together with some matrix analogue of x^n = exp(n * log x)? Hmm, it seems that we don't have a matrix logarithm function...
Thanks Matt! Following your suggestion, I've made a while_loop based version:
from jax import jit
from jax.lax import while_loop


@jit
def matrix_power_while_inner(val, F):
    """Run one loop step: left-multiply the accumulator by F and count down.

    Args:
        val: ``(i, cur_val)`` pair holding the loop counter and the product
            accumulated so far.
        F: the matrix being raised to a power.

    Returns:
        The updated pair ``(i - 1, F @ cur_val)``.
    """
    i, cur_val = val
    return i - 1, F @ cur_val


@jit
def matrix_power_while(F, n):
    """Compute ``F**n`` by repeated multiplication inside ``lax.while_loop``.

    The exponent may be a traced value, so this works under ``jit`` without
    marking ``n`` as static. It performs ``n`` matrix multiplications.

    Args:
        F: square matrix (``F.shape[0]`` sets the size of the identity).
        n: integer exponent. For ``n <= 0`` the loop body never runs and the
            identity matrix is returned.

    Returns:
        The matrix power, in the same dtype as ``F``.
    """

    def cond_fun(val):
        # The counter starts at n - 1 and runs down to 0: exactly n steps.
        return val[0] >= 0

    def body_fun(val):
        return matrix_power_while_inner(val, F)

    # Build the identity in F's dtype: while_loop requires the carry to keep
    # one fixed type, and a default-float identity would break that for
    # complex F (and silently turn integer results into floats).
    init_val = (n - 1, jnp.eye(F.shape[0], dtype=F.dtype))
    return while_loop(cond_fun, body_fun, init_val)[1]
# Returns True:
jnp.allclose(matrix_power_while(F, 10), jnp.linalg.matrix_power(F, 10))
I'll give that a go with forward mode for now. Let me know if this looks like a reasonable way to go for a new version of matrix_power; I'd be happy to try to adapt it.
Is there an upper-bound on the exponent? We should probably just write something in terms of lax.scan (together with lax.cond for "early exit"), where the length is ceil(log2(upper_bound_on_exponent)). Even if the upper-bound is 2**32, I'm guessing you can tolerate storing 32 copies of your array. WDYT?
Hey Matt, there's definitely an upper bound which I know in advance, and it's certainly much smaller than 2**32, probably less than 2**12 actually. Here's a new version:
from functools import partial

import jax.numpy as jnp
from jax import jit
from jax.lax import cond, scan
from jax.numpy import divmod
n = 140
@jit
def scan_fun(carry, xs):
    """Perform one step of binary exponentiation (square-and-multiply).

    Args:
        carry: ``(n, z, result)`` where ``n`` is the part of the exponent not
            yet consumed, ``z`` is the current repeated square of the base
            and ``result`` is the product accumulated so far.
        xs: unused per-step scan input (the scan is driven by ``length``).

    Returns:
        ``((new_n, new_z, new_result), None)``.
    """
    n, z, result = carry
    new_n, bit = divmod(n, 2)
    # Multiply the current square into the result only when this bit is set.
    new_result = cond(bit != 0, lambda x: z @ x, lambda x: x, result)
    # Once the exponent is exhausted, pass z through unchanged. Both cond
    # branches must return the same shape and dtype, which a freshly
    # allocated default-float array does not guarantee.
    new_z = cond(new_n != 0, lambda z: z @ z, lambda z: z, z)
    return (new_n, new_z, new_result), None


@partial(jit, static_argnames=("upper_limit",))
def matrix_power_scan(F, n, upper_limit=32):
    """Compute ``F**n`` by binary exponentiation with a fixed-length scan.

    Args:
        F: square matrix.
        n: integer exponent, may be traced. Assumed non-negative and below
            ``2**upper_limit``; bits beyond ``upper_limit`` are never read.
        upper_limit: number of scan steps, i.e. the number of exponent bits
            processed. It is static because ``scan`` needs a concrete length.

    Returns:
        The matrix power, in the same dtype as ``F``.
    """
    # TODO: I think we can avoid setting the third carry element to eye and save one matrix multiply
    # The identity uses F's dtype so the scan carry and both cond branches
    # below agree on one type.
    init_carry = n, F, jnp.eye(F.shape[0], dtype=F.dtype)

    def run_scan(_):
        (_, _, power), _ = scan(scan_fun, init_carry, None, length=upper_limit)
        return power

    # Shortcut for n == 1: return F itself without running the scan.
    return cond(n == 1, lambda _: F, run_scan, F)
# Returns True
jnp.allclose(matrix_power_scan(F, n), jnp.linalg.matrix_power(F, n))
Thanks for the pointer. I don't think this is completely ideal (I'm pretty sure it does one matrix multiply too many; see the comments), but it seems to work, and it'll be much faster than my hopelessly inefficient naive version. I should have thought of the log trick! Let me know if you have any thoughts and whether this makes sense.