jax.jacobian could automatically choose between forward and reverse mode

Created on 2 Mar 2020 · 2 Comments · Source: google/jax

It is hard for non-expert users to make an intelligent choice here, e.g., see this tweet. If they don't care about the particular details, we could make jax.jacobian a generic version that makes a best effort to compute Jacobians in an intelligent way based on heuristics. Note that jacobian is currently just an alias for jacrev.

As a first pass, we could simply compute the size of input/output vectors and choose forward mode when size(outputs) >= size(inputs) and otherwise use reverse mode.
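A minimal sketch of that size-based first pass, assuming a function of one array argument that returns one array. `auto_jacobian` is a hypothetical name, not an existing JAX API; it reads the output shape with `jax.eval_shape` (which does not run the computation) and then dispatches to `jax.jacfwd` or `jax.jacrev`.

```python
import math

import jax
import jax.numpy as jnp


def auto_jacobian(fun):
    """Hypothetical wrapper: choose jacfwd or jacrev from input/output sizes."""

    def jac(x):
        # Abstractly evaluate fun to get the output shape without computing it.
        out = jax.eval_shape(fun, x)
        n_in = math.prod(x.shape)
        n_out = math.prod(out.shape)
        # Forward mode needs one JVP per input element,
        # reverse mode needs one VJP per output element.
        mode = jax.jacfwd if n_out >= n_in else jax.jacrev
        return mode(fun)(x)

    return jac


# Illustrative functions (not from the issue).
def tall(x):   # R^3 -> R^9, so forward mode is chosen
    return jnp.outer(x, x).ravel()


def wide(x):   # R^3 -> R^1, so reverse mode is chosen
    return jnp.sum(x ** 2)[None]


x = jnp.arange(3.0) + 1.0
J_tall = auto_jacobian(tall)(x)   # shape (9, 3)
J_wide = auto_jacobian(wide)(x)   # shape (1, 3)
```

Both modes return the same Jacobian, so the choice only affects cost. Pytree inputs or outputs would need the sizes summed over all leaves, which this sketch does not attempt.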

More sophisticated approaches could attempt to estimate the total cost of forward/reverse passes, or even attempt to do some sort of optimal hybrid approach.

@mattjj notes: the slickest way to do it might be to use linearize and then have a separate transpose if needed so we only trace once
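One way this idea might look, as a rough sketch rather than the implementation proposed in the issue: `jax.linearize` traces the function once and returns the linear map v -> J v, and `jax.linear_transpose` turns that map into u -> J^T u only when reverse mode looks cheaper. The name `jacobian_via_linearize` is hypothetical, and the sketch assumes 1-D array inputs and outputs.

```python
import jax
import jax.numpy as jnp


def jacobian_via_linearize(fun, x):
    """Hypothetical helper: trace once, transpose only if reverse mode is cheaper."""
    # Single trace of fun; f_jvp is the linear map v -> J @ v evaluated at x.
    y, f_jvp = jax.linearize(fun, x)
    n_in, n_out = x.size, y.size

    if n_out >= n_in:
        # Forward mode: push each input basis vector through f_jvp.
        # Row i of the result is J @ e_i, i.e. column i of J.
        cols = jax.vmap(f_jvp)(jnp.eye(n_in, dtype=x.dtype))
        return cols.T
    # Reverse mode: transpose the linear map, then pull back output basis vectors.
    f_vjp = jax.linear_transpose(f_jvp, x)
    # Row j of the result is J^T @ e_j, i.e. row j of J.
    (rows,) = jax.vmap(f_vjp)(jnp.eye(n_out, dtype=y.dtype))
    return rows


# Illustrative functions (not from the issue).
def square(x):   # R^3 -> R^3, forward branch
    return jnp.sin(x) * jnp.sum(x)


def wide(x):     # R^3 -> R^2, reverse branch
    return jnp.stack([jnp.sum(x ** 2), jnp.sum(jnp.sin(x))])


x = jnp.array([0.5, 1.0, 2.0])
J_square = jacobian_via_linearize(square, x)   # shape (3, 3)
J_wide = jacobian_via_linearize(wide, x)       # shape (2, 3)
```

The point of this design is that the mode decision happens after tracing, so the output size is already known from `y` and no second trace is needed.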

All 2 comments

@gbaydin might have some advice here from work on DiffSharp

Hi @mattjj @shoyer, I think the plan @shoyer described above is a very good start. In DiffSharp we've had this jacobian operator for quite some time that is using the heuristic of "use forward if num_inputs < num_outputs, use reverse otherwise". https://github.com/DiffSharp/DiffSharp/blob/master/src/DiffSharp/AD.Float32.fs#L4041

There jacobianTv is the transposed-Jacobian-vector product (reverse mode) and jacobianv is the Jacobian-vector product (forward mode).
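For readers more familiar with JAX than DiffSharp, the same two products can be illustrated with `jax.jvp` and `jax.vjp`. This is only an analogy under the assumption that the DiffSharp operators behave as described above; the function `f` and the vectors are made up for the example.

```python
import jax
import jax.numpy as jnp


def f(x):
    # R^3 -> R^2; the Jacobian is [[x1, x0, 0], [0, 0, cos(x2)]].
    return jnp.stack([x[0] * x[1], jnp.sin(x[2])])


x = jnp.array([1.0, 2.0, 0.0])

# Forward mode: Jacobian-vector product J @ v, with v in input space.
v = jnp.array([1.0, 0.0, 0.0])
_, jv = jax.jvp(f, (x,), (v,))    # first column of J: [2., 0.]

# Reverse mode: transposed-Jacobian-vector product J^T @ u, with u in output space.
u = jnp.array([1.0, 0.0])
_, f_vjp = jax.vjp(f, x)
(jtu,) = f_vjp(u)                 # first row of J: [2., 1., 0.]
```

A full Jacobian takes one such product per input element in forward mode, or one per output element in reverse mode, which is where the size heuristic comes from.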

Note that the code I linked to is quite dated. I'm currently reimplementing DiffSharp in the dev branch. https://github.com/DiffSharp/DiffSharp/tree/dev
