Jax: Issue in readme for vmap

Created on 20 Dec 2020 · 3 comments · Source: google/jax

Hi great Jax team!

I just noticed what may be an unintended typo in the repository's README, in the explanation of vmap:

import jax.numpy as jnp

def predict(params, input_vec):
  assert input_vec.ndim == 1
  for W, b in params:
    output_vec = jnp.dot(W, input_vec) + b  # `input_vec` on the right-hand side!
    input_vec = jnp.tanh(output_vec)
  return output_vec

I guess it should be:

def predict(params, input_vec):
  assert input_vec.ndim == 1
  for W, b in params:
    output_vec = jnp.dot(W, input_vec) + b  # `input_vec` on the right-hand side!
    output_vec = jnp.tanh(output_vec) # I think this line was intended to be like this
  return output_vec
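
For reference, the README's vmap section maps this predict over a batch of inputs. Here is a minimal runnable sketch using the original predict from above; the parameter shapes and batch size are my own assumptions, purely for illustration:

import jax
import jax.numpy as jnp

# Hypothetical parameters for a 4 -> 8 -> 2 network.
key = jax.random.PRNGKey(0)
k1, k2, k3, k4 = jax.random.split(key, 4)
params = [
    (jax.random.normal(k1, (8, 4)), jax.random.normal(k2, (8,))),
    (jax.random.normal(k3, (2, 8)), jax.random.normal(k4, (2,))),
]

# Map predict over axis 0 of the inputs, broadcasting params (axis None).
batched_predict = jax.vmap(predict, in_axes=(None, 0))
inputs = jax.random.normal(key, (16, 4))      # batch of 16 input vectors
print(batched_predict(params, inputs).shape)  # (16, 2)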

Thanks for Jax!

Label: question


All 3 comments

Yup, makes sense: prediction_output = jnp.tanh(jnp.dot(W, input_vec) + b) 👍

Thanks for Jax!

Same here!


Hey, thanks for calling attention to this, and for unpacking it so clearly!

Actually, it is written as intended. The value of input_vec is used on the next iteration of the loop, so that one layer's output is the next layer's input.
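
Concretely, unrolling the loop for a hypothetical two-layer params = [(W1, b1), (W2, b2)]:

# iteration 1: output_vec = jnp.dot(W1, input_vec) + b1
#              input_vec  = jnp.tanh(output_vec)         # feeds layer 2
# iteration 2: output_vec = jnp.dot(W2, input_vec) + b2  # uses layer 1's activation
#              input_vec  = jnp.tanh(output_vec)         # computed but never used
# return output_vec                                      # so the final layer stays linear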

Perhaps it looks surprising at first because the last iteration's updated value of input_vec is never consumed. But writing it this way lets us avoid special-casing the last layer (to be linear, i.e. not to have an activation). That is, it saves us from writing something more like:

def predict(params, input_vec):
  assert input_vec.ndim == 1
  for W, b in params[:-1]:
    output_vec = jnp.dot(W, input_vec) + b
    input_vec = jnp.tanh(output_vec)
  W_final, b_final = params[-1]
  output_vec = jnp.dot(W_final, input_vec) + b_final
  return output_vec
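
As a quick sanity check, the two formulations agree numerically. A minimal sketch (the parameters here are hypothetical, for illustration only):

import jax
import jax.numpy as jnp

def predict_concise(params, input_vec):
  # The README's version: the loop threads one layer's output to the next.
  for W, b in params:
    output_vec = jnp.dot(W, input_vec) + b
    input_vec = jnp.tanh(output_vec)
  return output_vec

def predict_explicit(params, input_vec):
  # The special-cased version: the final layer is handled outside the loop.
  for W, b in params[:-1]:
    input_vec = jnp.tanh(jnp.dot(W, input_vec) + b)
  W_final, b_final = params[-1]
  return jnp.dot(W_final, input_vec) + b_final

# Hypothetical parameters for a 4 -> 8 -> 2 network.
key = jax.random.PRNGKey(0)
k1, k2 = jax.random.split(key)
params = [
    (jax.random.normal(k1, (8, 4)), jnp.zeros(8)),
    (jax.random.normal(k2, (2, 8)), jnp.zeros(2)),
]
x = jnp.ones(4)
assert jnp.allclose(predict_concise(params, x), predict_explicit(params, x))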

We chose to make the example code more concise. Moreover, any dead code is eliminated automatically under a jit, so we don't usually worry about things like an extra jnp.tanh evaluation!
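
One way to observe the dead code being dropped, assuming a recent JAX with the ahead-of-time lower/compile API (which postdates this 2020 issue), is to inspect the optimized program; this is a sketch, reusing predict_concise, params, and x from above:

import jax

compiled = jax.jit(predict_concise).lower(params, x).compile()
print(compiled.as_text())
# Expectation: the optimized HLO contains tanh only for the hidden
# layer(s); the unused tanh from the loop's final iteration is removed.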

Hope that makes sense. Thanks for the kind words, and for surfacing any possible issue you spot. It's really helpful to have extra checks; false positives are better than missed detections!

Thanks a lot for the detailed explanation!

