Jax: Please consider facilitating debugging leaked tracers

Created on 13 May 2020 · 2 Comments · Source: google/jax

Dear JAX people, I've been struggling again with a leaked tracer.

The last time I had a leaked tracer, it took me a few days to find it. This time, it's been a few days and I still can't find the problem. I don't know how to make progress here. Debugging leaked tracers is hard because the error happens so far from the problem.

In the code below, I've tried to modify JAX so that it names the tracers, to give me an idea of where each tracer is created. Unfortunately, the names don't seem to be retained, probably because tracers are copied around and the name is lost along the way. Any insight on this would be really helpful.

I realize that tracers are "deep internals" as Matt put it. I realize you don't want to document them. Would it be possible to improve the debug messages produced? Maybe with a flag?

from inspect import signature
from typing import Any, Callable, Union

from functools import wraps, partial
from jax.interpreters.partial_eval import JaxprTracer, Trace
from jax import jit, grad, vjp
from jax.tree_util import tree_map

__all__ = ['set_debug_reprs', 'djit', 'dgrad', 'dvjp']


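def _id_str(obj: Any) -> str:
    # Low 20 bits of the object's id as five zero-padded hex digits; slicing
    # hex(...) could leak part of the '0x' prefix for small values.
    id_ = f'{id(obj) & 0xfffff:05x}'
    if hasattr(obj, '_name'):
        return f'{obj._name}:{id_}'
    return id_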


def set_debug_reprs():
    def trace_repr(self: Trace):
        level = f'{self.level}/{self.sublevel}'
        return f'{self.__class__.__name__}({_id_str(self)}:{level})'

    def jaxpr_tracer_repr(self: JaxprTracer):
        trace = self._trace
        trace_id = _id_str(trace)
        return f'Tracer<{trace_id}::{_id_str(self)}>'

    Trace.__repr__ = trace_repr
    JaxprTracer.__repr__ = jaxpr_tracer_repr


def name_tracer(function_name, argument_name, tracer):
    if not isinstance(tracer, JaxprTracer):
        return tracer
    tracer._name = argument_name
    tracer._trace._name = function_name
    print(tracer)
    return tracer


def function_name(fun: Callable) -> str:
    if hasattr(fun, '__name__'):
        return fun.__name__
    if hasattr(fun, 'func'):
        return function_name(fun.func)
    return "NAMELESS"


def name_all_tracers(fun):
    s = signature(fun)
    @wraps(fun)
    def new_fun(*args, **kwargs):
        bound_arguments = s.bind(*args, **kwargs)
        for name, value in bound_arguments.arguments.items():
            tree_map(partial(name_tracer, function_name(fun), name), value)
        return fun(*args, **kwargs)
    return new_fun


def djit(fun: Callable, *args, **kwargs):
    new_fun = name_all_tracers(fun)
    return jit(new_fun, *args, **kwargs)


def dgrad(fun, *args, **kwargs):
    new_fun = name_all_tracers(fun)
    return grad(new_fun, *args, **kwargs)


def dvjp(fun, *args, **kwargs):
    new_fun = name_all_tracers(fun)
    return vjp(new_fun, *args, **kwargs)

When debugging, I call set_debug_reprs() and use the shim:

from jax import *
from .debugging import djit as jit, dgrad as grad, dvjp as vjp

With this, a JaxprTracer now prints out as:

Tracer<find_fixed_point:84220::x_init:87810>

This shows the function in which the tracer was created (find_fixed_point), the parameter it corresponds to (x_init), and a short object-id code for the trace and for the tracer, in case there are duplicate invocations.
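When scanning long debug logs, it can help to split these reprs back into their parts. The following is a minimal sketch, assuming the exact repr format produced by set_debug_reprs() above; TRACER_REPR and parse_tracer_repr are hypothetical names introduced here, not part of JAX.

```python
import re
from typing import Dict, Optional

# Matches 'Tracer<function:trace_id::argument:tracer_id>'. Both names are
# optional, because unnamed objects print only their id codes.
TRACER_REPR = re.compile(
    r'Tracer<'
    r'(?:(?P<function>[^:<>]+):)?(?P<trace_id>[0-9a-fx]+)'
    r'::'
    r'(?:(?P<argument>[^:<>]+):)?(?P<tracer_id>[0-9a-fx]+)'
    r'>'
)


def parse_tracer_repr(text: str) -> Optional[Dict[str, Optional[str]]]:
    """Return the parts of the first debug repr found in text, or None."""
    match = TRACER_REPR.search(text)
    return match.groupdict() if match else None


named = parse_tracer_repr('Tracer<find_fixed_point:84220::x_init:87810>')
unnamed = parse_tracer_repr('Tracer<84220::87810>')
print(named)
print(unnamed)
```

A tracer whose name was lost still parses, but its function and argument fields come back as None, which makes unnamed tracers easy to filter for.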

It would have been nice to also name the field when the tracer is part of a PyTree-like object, but that's hard to get unless the PyTree-like object exposes names for its fields, as flax.struct.dataclass (which was recently shown to me) does. (Perhaps the PyTree registry could optionally accept a third function that produces a tree of strings for naming purposes.)
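To make the field-naming idea concrete, here is a rough sketch of what such a naming function might produce. It is plain Python and an assumption throughout: named_leaves and State are hypothetical, the helper only understands dataclasses, dicts, lists and tuples, and it does not consult JAX's PyTree registry.

```python
from dataclasses import dataclass, fields, is_dataclass
from typing import Any, Iterator, Tuple


def named_leaves(value: Any, prefix: str) -> Iterator[Tuple[str, Any]]:
    """Yield (path_name, leaf) pairs for a nested container.

    Hypothetical helper: handles dataclass instances, dicts, lists and
    tuples only; anything else is treated as a leaf.
    """
    if is_dataclass(value) and not isinstance(value, type):
        for f in fields(value):
            yield from named_leaves(getattr(value, f.name), f'{prefix}.{f.name}')
    elif isinstance(value, dict):
        for key in sorted(value):
            yield from named_leaves(value[key], f'{prefix}[{key!r}]')
    elif isinstance(value, (list, tuple)):
        for i, item in enumerate(value):
            yield from named_leaves(item, f'{prefix}[{i}]')
    else:
        yield prefix, value


@dataclass
class State:  # hypothetical example container
    position: float
    momenta: dict


state = State(position=1.0, momenta={'p': 2.0, 'q': [3.0, 4.0]})
names = [name for name, _ in named_leaves(state, 'x_init')]
print(names)
```

With names like these, name_tracer could receive the full path of each leaf instead of only the top-level argument name.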

All 2 comments

This is a fantastic idea. We've got to figure something out along these lines, otherwise some classes of errors are going to remain impossibly hard to find and debug.

Woohoo! Thank you. Please let me know if you have any ideas on how I can make this actually work as this is what's blocking me full time. I'm starting to try to figure out JAX's internals to see where the names are lost.
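One plausible way for the names to disappear, sketched with stand-in classes rather than JAX's real ones: if a tracer is rebuilt from its payload at some point, any attribute attached to the old instance after construction is not carried over. FakeTracer, rewrap and rewrap_keeping_name are hypothetical and only illustrate the suspected mechanism; they say nothing about where JAX actually does this.

```python
class FakeTracer:
    """Stand-in for a tracer; NOT JAX's JaxprTracer."""

    def __init__(self, value):
        self.value = value

    def __repr__(self):
        name = getattr(self, '_name', None)
        return f'Tracer<{name}>' if name else 'Tracer<?>'


def rewrap(tracer):
    # Builds a fresh object from the payload only, so ad-hoc attributes such
    # as _name that were attached to the old instance are dropped.
    return FakeTracer(tracer.value)


def rewrap_keeping_name(tracer):
    # Same as rewrap, but explicitly forwards the debug name if present.
    new = FakeTracer(tracer.value)
    if hasattr(tracer, '_name'):
        new._name = tracer._name
    return new


original = FakeTracer(1.0)
original._name = 'x_init'
print(repr(original), repr(rewrap(original)), repr(rewrap_keeping_name(original)))
```

If this is what happens, the name would have to be forwarded at every place a tracer is reconstructed, or stored somewhere that survives reconstruction.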
