Dear JAX people, I've been struggling again with a leaked tracer.
The last time I had a leaked tracer, it took me a few days to find it. This time, it's been a few days and I still can't find the problem, and I don't know how to make progress. Debugging leaked tracers is hard because the error is raised so far from where the tracer actually leaks.
In the code below, I've tried to modify JAX so that it names the tracers, to give me an idea of where each tracer is created. Unfortunately, the names don't seem to survive, probably because tracers are copied around and the name is lost along the way. Any insight on this would be really helpful.
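One plausible explanation, offered as an assumption rather than a confirmed reading of JAX's internals: a name set on an argument tracer lives on that one Python object, while every operation inside the traced function builds a new tracer. The toy model below uses hypothetical `ToyTracer` and `PropagatingTracer` classes (not JAX classes) to show the effect, and how recording provenance at operation time would keep a trail back to the argument.

```python
class ToyTracer:
    """Hypothetical stand-in for a tracer; this is not a JAX class."""

    def __init__(self, value, name=None):
        self.value = value
        self.name = name

    def _other_value(self, other):
        return other.value if isinstance(other, ToyTracer) else other

    def __mul__(self, other):
        # Every operation creates a brand-new object; the name is not copied.
        return ToyTracer(self.value * self._other_value(other))

    def __repr__(self):
        return f'{type(self).__name__}(name={self.name!r})'


class PropagatingTracer(ToyTracer):
    """Hypothetical variant that records provenance on each new object."""

    def __mul__(self, other):
        return PropagatingTracer(self.value * self._other_value(other),
                                 name=f'mul({self.name})')


x = ToyTracer(3.0, name='x_init')
print(x * 2)   # ToyTracer(name=None)

x = PropagatingTracer(3.0, name='x_init')
print(x * 2)   # PropagatingTracer(name='mul(x_init)')
```

In this model, naming only the inputs can never identify a tracer that was produced by an operation; the name has to be carried forward wherever new tracers are made.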
I realize that tracers are "deep internals," as Matt put it, and that you don't want to document them. Would it be possible to improve the debug messages instead, perhaps behind a flag?
```python
from functools import partial, wraps
from inspect import signature
from typing import Any, Callable

from jax import grad, jit, vjp
from jax.interpreters.partial_eval import JaxprTracer, Trace
from jax.tree_util import tree_map

__all__ = ['set_debug_reprs', 'djit', 'dgrad', 'dvjp']


def _id_str(obj: Any) -> str:
    # Low 20 bits of id(), always printed as exactly 5 hex digits.
    id_ = f'{id(obj) & 0xfffff:05x}'
    if hasattr(obj, '_name'):
        return f'{obj._name}:{id_}'
    return id_


def set_debug_reprs():
    # Monkey-patch the reprs so traces and tracers show their debug names.
    def trace_repr(self: Trace):
        level = f'{self.level}/{self.sublevel}'
        return f'{self.__class__.__name__}({_id_str(self)}:{level})'

    def jaxpr_tracer_repr(self: JaxprTracer):
        trace_id = _id_str(self._trace)
        return f'Tracer<{trace_id}::{_id_str(self)}>'

    Trace.__repr__ = trace_repr
    JaxprTracer.__repr__ = jaxpr_tracer_repr


def name_tracer(function_name, argument_name, tracer):
    # Tag a tracer with its argument name and its trace with the function name.
    if not isinstance(tracer, JaxprTracer):
        return tracer
    tracer._name = argument_name
    tracer._trace._name = function_name
    print(tracer)
    return tracer


def function_name(fun: Callable) -> str:
    # Unwrap functools.partial objects until a named function is found.
    if hasattr(fun, '__name__'):
        return fun.__name__
    if hasattr(fun, 'func'):
        return function_name(fun.func)
    return 'NAMELESS'


def name_all_tracers(fun):
    s = signature(fun)

    @wraps(fun)
    def new_fun(*args, **kwargs):
        # Name every tracer leaf of every argument after its parameter.
        bound_arguments = s.bind(*args, **kwargs)
        for name, value in bound_arguments.arguments.items():
            tree_map(partial(name_tracer, function_name(fun), name), value)
        return fun(*args, **kwargs)

    return new_fun


def djit(fun: Callable, *args, **kwargs):
    return jit(name_all_tracers(fun), *args, **kwargs)


def dgrad(fun: Callable, *args, **kwargs):
    return grad(name_all_tracers(fun), *args, **kwargs)


def dvjp(fun: Callable, *args, **kwargs):
    return vjp(name_all_tracers(fun), *args, **kwargs)
```
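`name_all_tracers` relies on `inspect.Signature.bind`. A quick check against a hypothetical `find_fixed_point` signature shows what `bound_arguments.arguments` contains: only the explicitly passed arguments, with `*args` collected into one tuple under its parameter name (so `tree_map` walks that tuple as a single tree). Defaulted parameters show up only after `apply_defaults()`, which should rarely matter here since defaults are not tracers.

```python
from inspect import signature


def find_fixed_point(f, x_init, *extra, tol=1e-6):
    # Hypothetical signature, used only to inspect binding behavior.
    return x_init


bound = signature(find_fixed_point).bind(abs, 1.0, 2.0, 3.0)
print(list(bound.arguments))    # ['f', 'x_init', 'extra']
print(bound.arguments['extra'])  # (2.0, 3.0)

bound.apply_defaults()
print(bound.arguments['tol'])    # 1e-06
```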
When debugging, I call `set_debug_reprs()` and use the shim:

```python
from jax import *
from .debugging import djit as jit, dgrad as grad, dvjp as vjp
```
With this, a `JaxprTracer` now prints as:

```
Tracer<find_fixed_point:84220::x_init:87810>
```
This shows the name of the function in which it was created (`find_fixed_point`), the parameter it corresponds to (`x_init`), and a short object id code for each, in case there are duplicate invocations.
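The id codes keep only the low 20 bits of `id()`, which makes them short but not unique, so they are a hint rather than a key. One detail worth checking when producing them: slicing the output of `hex()` is not fixed-width, because small values keep part of the `0x` prefix. A format spec with zero padding avoids that; `short_id` below is a hypothetical helper name.

```python
def short_id(obj) -> str:
    # Low 20 bits of id(), always rendered as exactly 5 hex digits.
    return f'{id(obj) & 0xfffff:05x}'


print(hex(0x1234)[-5:])  # x1234  (the 'x' from '0x' leaks in)
print(f'{0x1234:05x}')   # 01234
print(len(short_id(object())))  # 5
```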
It would also be nice to name the field when the tracer is part of a PyTree-like object, but that's hard to get unless the object exposes names for its fields, as `flax.struct.dataclass` (which was recently shown to me) does. (Perhaps the PyTree registry could optionally accept a third function that produces a tree of strings for naming purposes.)
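The field-naming idea amounts to flattening a tree into (path, leaf) pairs instead of bare leaves. Below is a minimal pure-Python sketch under stated assumptions: it only understands dicts, lists, tuples and standard dataclasses, and `named_leaves` and `SolverState` are hypothetical names, not JAX APIs. (Later JAX releases added key paths along these lines, e.g. `jax.tree_util.tree_flatten_with_path` and `register_pytree_with_keys`, if your version has them.)

```python
from dataclasses import dataclass, fields, is_dataclass
from typing import Any, Iterator, Tuple


def named_leaves(tree: Any, prefix: str = '') -> Iterator[Tuple[str, Any]]:
    """Yield (path, leaf) pairs for nested dicts, lists, tuples and dataclasses."""
    if isinstance(tree, dict):
        for key, child in tree.items():
            yield from named_leaves(child, f'{prefix}[{key!r}]')
    elif isinstance(tree, (list, tuple)):
        for i, child in enumerate(tree):
            yield from named_leaves(child, f'{prefix}[{i}]')
    elif is_dataclass(tree) and not isinstance(tree, type):
        # Dataclass instances expose field names, which become path segments.
        for field in fields(tree):
            yield from named_leaves(getattr(tree, field.name), f'{prefix}.{field.name}')
    else:
        # Anything else is treated as a leaf.
        yield prefix, tree


@dataclass
class SolverState:  # hypothetical PyTree-like container
    x: float
    history: list


state = SolverState(x=1.0, history=[0.5, 0.75])
for path, leaf in named_leaves({'state': state}, 'x_init'):
    print(path, leaf)
# x_init['state'].x 1.0
# x_init['state'].history[0] 0.5
# x_init['state'].history[1] 0.75
```

With a path per leaf, a name such as `x_init['state'].x` could be attached to each tracer instead of just the argument name.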
This is a fantastic idea. We've got to figure something out along these lines, otherwise some classes of errors are going to remain impossibly hard to find and debug.
Woohoo! Thank you. Please let me know if you have any ideas on how I can make this actually work, since this is blocking all of my work right now. I'm starting to dig into JAX's internals to see where the names are lost.