Hello,
I wanted to learn JAX. I understand it is an autodiff package that allows differentiation of arbitrary programs.
Could you point me to an example where I can do a simple linear regression for, say, a 10-dimensional variable?
Even high-level code for linear regression with the JAX API would be very helpful. Any accessible introduction that explains the differences from Autograd or Tangent would also be very helpful. @mattjj
Thank you.
Thanks for the question! Luckily someone just wrote a blog post that includes linear regression in JAX, both by gradient descent and by directly solving the normal equations: see Example 3 in this post by @ColCarroll.
Regarding the difference from Autograd, JAX is basically "Autograd 2" built around XLA rather than directly on NumPy. The difference from Tangent is essentially the same as the difference between Autograd and Tangent: while Tangent reads and writes Python source code, Autograd and JAX do the much simpler thing of tracing program behavior.
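For a quick taste, here's a minimal sketch (not from the post above) of what the tracing-based API looks like in practice: you pass an ordinary Python function to grad and get back a new function that computes its derivative.
import jax.numpy as np
from jax import grad

def tanh_squared(x):
    return np.tanh(x) ** 2

# grad traces tanh_squared and returns a function computing d/dx tanh(x)**2
dtanh_squared = grad(tanh_squared)
print(dtanh_squared(1.0))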
Feel free to reopen if that wasn't what you were looking for.
@mattjj Thank you.
The linked tutorial does not help me in understanding a few things.
Please consider this linear regression example for a 1-D variable:
import numpy as np
import matplotlib.pyplot as plt
# Prepare data: input = 0..99, output = input squared + 2
input = np.array(list(range(100)))
output = np.array([x**2 + 2.0 for x in list(range(100))])
# Define Model: a linear model
a = 1
b = 1
# Define loss : squared loss
def loss(y, y_hat):
    return 0.5*(y - y_hat)**2

losses = []

# Train model
for e in range(1000):
    epoch_loss = 0
    for x, y in zip(input, output):
        y_hat = a*x + b
        epoch_loss += loss(y, y_hat)
        # Manual grad computation -> autograd will replace this
        a = a - 0.0001*(-x)*(y - y_hat)
        b = b - 0.0001*(-1)*(y - y_hat)
    losses.append(epoch_loss)
    print("Epoch {0} Training loss = {1}".format(e, epoch_loss))
# Summary
plt.plot(list(range(1,len(losses)+1)), losses, '-', label='Loss vs Epochs')
plt.legend()
plt.show()
# Make Predictions on new data
test_input = np.array(list(range(100,150)))
test_output = np.array([x**2+ 2.0 for x in list(range(100,150))])
model_predictions = np.array([a*x + b for x in list(range(100,150))])
plt.plot(input, output, '*', label='Training Data')
plt.plot(input, [a*x+b for x in input], '-', label='Training Prediction')
plt.plot(test_input, test_output, 'ro', label='Test Data')
plt.plot(test_input, model_predictions, '-', label='Test Prediction')
plt.legend()
plt.show()
Now, I want to do the gradient computation using Autograd and then JAX for this same code.
I tried:
import autograd.numpy as np
from autograd import grad
from autograd import elementwise_grad
and then using,
elementwise_grad(loss)(y, y_hat)
in the loop, but I should not expect this to work. It gives the value 1, for example, whereas what I want is the gradients of the loss with respect to my two parameters a and b.
How could I compute the gradients using Autograd and JAX for this simple case?
@ColCarroll, could you help with how to go about this? My goal is to learn using this 1-D case, then extend it to n dimensions, then try more complex functions like neural networks.
Thank you for the help.
It looks like I do not have the option to reopen the issue. @mattjj, should I open a new issue?
Hey @nvidiaman -- here is your model working. A few notes:
- You are fitting y = x^2 with a line, which will not go well!
- I renamed input and output to features and labels: input is a built-in Python function!
- If the learning rate is raised from 1e-10 to 1e-5, everything diverges for various reasons.

import jax.numpy as np
from jax import value_and_grad, grad
import matplotlib.pyplot as plt
# Prepare data: features = 0..99, labels = features squared + 2
features = np.arange(100)
labels = features ** 2 + 2.
# Define Model: a linear model
a = 1.
b = 1.
# Define loss : squared loss
def make_loss(targets, features):
    def loss(a, b):
        return 0.5*np.sum((targets - a * features - b)**2)
    return loss

losses = []
loss = make_loss(labels, features)

# Train model
for e in range(100):
    epoch_loss, (da, db) = value_and_grad(loss, (0, 1))(a, b)
    a -= 1e-10 * da
    b -= 1e-10 * db
    losses.append(epoch_loss)
    print("Epoch {0} Training loss = {1}".format(e, epoch_loss))
@ColCarroll thank you so much !!
I want to ask a few questions:
Why two functions in make_loss?
You did not make use of grad; how do grad and value_and_grad differ? Since JAX does not have much documentation, I also want to ask why (0, 1) is passed in the value_and_grad call.
Sorry for another request, but could you also show how Autograd can be made to work in the above case?
I think Colin wrote make_loss in a standard Python idiom for staging/currying functions using lexical closure.
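As a tiny illustration of that idiom (made-up names, nothing JAX-specific): the outer call fixes some values, and the returned inner function closes over them, so its arguments are exactly the things you want grad to differentiate with respect to.
def make_adder(n):      # the outer call fixes n ...
    def add(x):         # ... and the returned function closes over it
        return x + n
    return add

add_three = make_adder(3)
print(add_three(4))     # 7
In make_loss, the data plays the role of n and the parameters a, b play the role of x, which is why grad(loss) gives derivatives with respect to the parameters rather than the data.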
JAX doesn't have a lot of documentation yet, but luckily the docstrings are pretty clear about grad vs value_and_grad. Both docstrings include an explanation of the argnums argument too.
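Roughly, based on those docstrings (a minimal sketch with a made-up function f): grad(f) gives you just the derivative, value_and_grad(f) gives you the function's value and the derivative in one call (saving a second forward pass), and argnums picks which positional argument(s) to differentiate with respect to, so (0, 1) in Colin's snippet means the gradients with respect to both a and b.
from jax import grad, value_and_grad

def f(a, b):
    return a ** 2.0 + 3.0 * b

df_da = grad(f, argnums=0)                           # derivative with respect to a only
print(df_da(2.0, 1.0))                               # 4.0

val, (da, db) = value_and_grad(f, argnums=(0, 1))(2.0, 1.0)
print(val, da, db)                                   # 7.0 4.0 3.0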
Autograd has a nearly identical API to JAX, including a value_and_grad function and an analogous import autograd.numpy as np NumPy wrapper, so we probably don't need to get Colin to spell out the details for us on that one :)
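For completeness, here's a minimal sketch of an Autograd version (my own adaptation of Colin's snippet, not code from this thread), calling grad once per parameter since Autograd's grad takes an argument index as its second argument:
import autograd.numpy as anp
from autograd import grad

features = anp.arange(100.0)
labels = features ** 2 + 2.0

def loss(a, b):
    return 0.5 * anp.sum((labels - a * features - b) ** 2)

dloss_da = grad(loss, 0)    # derivative with respect to a
dloss_db = grad(loss, 1)    # derivative with respect to b

a, b = 1.0, 1.0
for e in range(100):
    a -= 1e-10 * dloss_da(a, b)
    b -= 1e-10 * dloss_db(a, b)
    print("Epoch {0} Training loss = {1}".format(e, loss(a, b)))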
Let's try to keep the questions here about JAX specifically, and make sure to check the documentation carefully before asking.
Thanks for the robust discussion, and particularly @ColCarroll for some exquisitely clean code and thorough answers!
@ColCarroll: It took me quite a while to figure out why the loss did not match the value I expected. You defined make_loss as make_loss(targets, features) but called it as make_loss(features, targets), with the argument order reversed.
Wow! Good eye! Another reason to write tests and not use GH issues as an IDE 🤣
I updated the code snippet. Hopefully this does not confuse anyone reading your comment in the future!