It is hard for non-expert users to make an intelligent choice here, e.g., see this tweet. If they don't care about the particular details, we could make jax.jacobian a generic version that makes a best effort to compute Jacobians in an intelligent way based on heuristics. Note that jacobian is currently just an alias for jacrev.
As a first pass, we could simply compute the size of input/output vectors and choose forward mode when size(outputs) >= size(inputs) and otherwise use reverse mode.
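A minimal sketch of that first-pass heuristic, assuming a function of a single array argument that returns a single array. The name `auto_jacobian` is hypothetical and not part of JAX; only `jax.eval_shape`, `jax.jacfwd` and `jax.jacrev` are existing calls.

```python
import math

import jax
import jax.numpy as jnp


def auto_jacobian(f):
    """Hypothetical helper: pick jacfwd or jacrev from input/output sizes."""
    def jac(x):
        # eval_shape evaluates f abstractly, so only shapes are computed here.
        out = jax.eval_shape(f, x)
        n_in = math.prod(jnp.shape(x))
        n_out = math.prod(out.shape)
        # Forward mode needs one pass per input, reverse mode one per output.
        chooser = jax.jacfwd if n_out >= n_in else jax.jacrev
        return chooser(f)(x)
    return jac


def f(x):
    # Example function with 3 inputs and 6 outputs, so forward mode is chosen.
    return jnp.concatenate([jnp.sin(x), x ** 2])


x = jnp.arange(3.0)
J = auto_jacobian(f)(x)
```

This only compares sizes; it does not try to estimate the actual cost of either pass.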
More sophisticated approaches could attempt to estimate the total cost of forward/reverse passes, or even attempt to do some sort of optimal hybrid approach.
@mattjj notes: the slickest way to do it might be to use linearize and then have a separate transpose if needed so we only trace once
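One way to read this suggestion, as a rough sketch rather than a proposed implementation: call `jax.linearize` once, then either push input basis vectors through the resulting linear map (forward style) or transpose it with `jax.linear_transpose` and pull back output basis vectors (reverse style). The example function `f` and the variable names are made up for illustration.

```python
import jax
import jax.numpy as jnp


def f(x):
    # Example function with 3 inputs and 2 outputs.
    return jnp.stack([x[0] * x[1], jnp.sin(x[2])])


x = jnp.array([1.0, 2.0, 3.0])

# Trace f once; f_jvp is the linear map v -> J @ v at the point x.
y, f_jvp = jax.linearize(f, x)

# Forward style: one column of J per input basis vector.
J_fwd = jax.vmap(f_jvp, out_axes=1)(jnp.eye(x.size))

# Reverse style: transpose the linear map, then one row of J per output
# basis vector. linear_transpose returns a tuple of input cotangents.
f_vjp = jax.linear_transpose(f_jvp, x)
(J_rev,) = jax.vmap(f_vjp)(jnp.eye(y.size))
```

Both Jacobians come from the same linearization of `f`, so a size-based heuristic could pick between the two branches after that single trace.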
@gbaydin might have some advice here from work on DiffSharp
Hi @mattjj @shoyer, I think the plan @shoyer described above is a very good start. In DiffSharp we've had this jacobian operator for quite some time that is using the heuristic of "use forward if num_inputs < num_outputs, use reverse otherwise". https://github.com/DiffSharp/DiffSharp/blob/master/src/DiffSharp/AD.Float32.fs#L4041
There jacobianTv is the transposed-Jacobian-vector product (reverse mode) and jacobianv is the Jacobian-vector product (forward mode).
Note that the code I linked to is quite dated. I'm currently reimplementing DiffSharp in the dev branch. https://github.com/DiffSharp/DiffSharp/tree/dev