Jax: utility for closure-conversion in higher-order functions with custom derivatives

Created on 18 Dec 2020 · 7 comments · Source: google/jax

When a higher-order python function is associated with a custom derivative (e.g. via custom_vjp), differentiation doesn't handle the closures of its function arguments. A workaround is to accept auxiliary arguments and thread them back into input functions that otherwise would have closed over them. An example sketch:

@custom_vjp
def minimize(objective_fn, x0, objective_aux_args=()):
  # ...
  t = objective_fn(..., *objective_aux_args)
  # ...

An alternative approach, taken by jax.experimental.ode.odeint, is to stage function arguments out to jaxprs, converting closures in the process. We could try to factor this setup into a common utility for use in other similar applications.
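
For concreteness, here is a minimal sketch of what such a conversion does, assuming an API along the lines of the closure_convert that odeint uses internally (exposed as jax.closure_convert in later JAX releases); the names outer and f are made up for illustration:

import jax
import jax.numpy as jnp

def outer(a, x):
  f = lambda y: a * jnp.sin(y)  # f closes over `a`
  # Trace f to a jaxpr and hoist the closed-over values involved in
  # differentiation (here, the tracer for `a`) out into `consts`.
  converted_f, consts = jax.closure_convert(f, x)
  # The converted function takes the hoisted values as explicit trailing args.
  return converted_f(x, *consts)

print(jax.grad(outer)(2.0, 1.0))  # d/da [a * sin(x)] = sin(1.0)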

enhancement

All 7 comments

The workaround based on auxiliary arguments has several disadvantages. It can be tricky to enforce, and it can require inconvenient code restructuring in callers. For more on why odeint takes the approach that it does, see #2718, #3557, #3558, and the PR resolving them: #3562

I'm currently struggling with this problem for a custom_jvp for a higher-order derivative from a solver. It would be really cool to have a docstring for closure_convert to make it clear how to use it in other places. :-)

Since this is an issue specific to custom derivatives, I'm wondering if we ought to offer it as part of custom_vjp, custom_gradient, etc., in a way that lets you write:

@partial(custom_vjp, closure_convert_argnums=0)
def minimize(objective_fn, x0):
  # ...

@rpadams Would that work for your solver?
@mattjj Having written ode.closure_convert and custom_vjp, what do you think?

Thanks for the fast reply @froystig! If I'm understanding correctly, I think this would resolve my situation, in which the thing I'm writing a custom_vjp for is a solver that consumes multiple levels of closures. Threading the arguments through this seems daunting.

More context: I have an energy function that depends on a Jacobian (1st deriv), that is part of a Lagrangian which gives an ODE via Euler-Lagrange (2nd deriv), which I then use an implicit solver for (3rd deriv). Then I'd like to get gradients back through that stack (4th deriv) without having to backprop through my implicit (in the "backward Euler" sense) solver using the implicit (in the "implicit function theorem" sense) gradient. There are many closures/partials/lambdas along the way that I think are not playing nicely with the outer-loop custom_jvp... (However, that last level thing is not mathematically painful because I did most of the IFT work with the second-order solver.)

Threading the arguments through this seems daunting.

Based on your context, it seems that what you'd find daunting is having to pass around extra arguments explicitly on the way down to calling the solver, to avoid forming closures anywhere. Is that correct? By contrast, would you be OK threading the arguments around _within_ the solver implementation, supposing they were extracted from the incoming closures for you?

An amendment to my previous comment: it might make more sense to return to the original idea and offer a closure_convert utility directly, rather than as an option to custom_vjp, so that you can control the placement and threading of arguments in your solver implementation. In my sketch above, minimize has implicit arguments (via the hypothetical closure conversion) that aren't apparent in its signature.

Altogether here's roughly how using this would look:

def minimize(objective_fn, x0):
  converted_fn, consts = closure_convert(objective_fn, x0)
  return _minimize(converted_fn, x0, consts)

@partial(custom_vjp, nondiff_argnums=(0,))
def _minimize(objective_fn, x0, objective_aux_args):
  z = objective_fn(x0, *objective_aux_args)
  # ...
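
For completeness, here is a self-contained toy rendering of this pattern, assuming the jax.closure_convert utility that later landed from #5244; the naive gradient-descent inner loop, the implicit-function-theorem backward rule, and the specific names (minimize, _minimize, loss) are illustrative only, not a recommended solver:

from functools import partial
import jax
import jax.numpy as jnp

def minimize(objective_fn, x0):
  # Hoist values the objective closed over into explicit `consts`.
  converted_fn, consts = jax.closure_convert(objective_fn, x0)
  return _minimize(converted_fn, x0, *consts)

@partial(jax.custom_vjp, nondiff_argnums=(0,))
def _minimize(objective_fn, x0, *consts):
  # Deliberately naive inner solver: a fixed number of gradient steps.
  grad_x = jax.grad(objective_fn)
  x = x0
  for _ in range(200):
    x = x - 0.1 * grad_x(x, *consts)
  return x

def _minimize_fwd(objective_fn, x0, *consts):
  x_star = _minimize(objective_fn, x0, *consts)
  return x_star, (x_star, x0, consts)

def _minimize_bwd(objective_fn, res, x_bar):
  # Implicit function theorem: grad_x f(x*, consts) = 0 at the optimum, so
  # d x*/d consts = -H^{-1} d(grad_x f)/d consts, with H the Hessian in x.
  x_star, x0, consts = res
  grad_x = jax.grad(objective_fn)
  H = jax.jacfwd(grad_x)(x_star, *consts)
  u = jnp.linalg.solve(H, x_bar)
  _, vjp_consts = jax.vjp(lambda *c: grad_x(x_star, *c), *consts)
  consts_bar = vjp_consts(-u)
  # The converged solution doesn't depend on the initial iterate.
  return (jnp.zeros_like(x0),) + consts_bar

_minimize.defvjp(_minimize_fwd, _minimize_bwd)

a = jnp.array([1.0, 2.0, 3.0])

def loss(a):
  objective = lambda x: jnp.sum((x - a) ** 2)  # closes over `a`
  x_star = minimize(objective, jnp.zeros_like(a))
  return jnp.sum(x_star ** 2)

print(jax.grad(loss)(a))  # x* == a at the optimum, so the gradient is 2 * a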

Assuming we're thinking about this the same way, it's that there are many closures going into forming the objective function minimized by the solver. Threading arguments around just the solver itself isn't too bad --- I'm basically doing that already in order to avoid re-jitting my Levenberg-Marquardt implementation every time step.

I should say that part of what I'm confused about is the statement in the docs:

A limitation to this approach is that the argument f can't close over any values involved in differentiation.

It seems like "values involved in differentiation" necessarily covers a lot of ground, i.e., essentially everything that's gone into the objective function, no? I'm interpreting this as "you can't use closures/lambda/partials in building your objective function", but maybe that's overly broad? In my case I'm making pretty extensive use of a generator pattern, e.g., generate_lagrangian that hands back a Lagrangian function that I hand to a generate_euler_lagrange that gives a function I can hand to a generate_time_stepper that constructs an objective for the implicit Euler, etc. I think you guys can appreciate that this is a pretty "SICM" kind of thing I'm doing. :-)

Good! That makes sense. And indeed, that's the kind of stuff that we'd like to support.

In the "writing a solver" setting: custom derivative rules only see formal arguments to the solver, which is why closures (which carry with them extra values, not passed as explicit formal arguments) are an issue in the presence of custom AD specifically, but are not an issue elsewhere. As a solver writer, you can now work around this: the closure_convert utility in #5244 is meant to help you pull implicit arguments out of the closure and turn them into formal arguments before crossing the custom derivative (solver) boundary.
