Keras has a NumPy backend that is currently used for unit tests, but could be used with JAX. A little notebook demonstrating integration would be useful.
A branch that enables simple Sequential models at test time (runs and is correct) and for training (runs but is not currently correct) is here: https://github.com/alexbw/keras/tree/jax-backend. There is a more in-depth discussion in the Keras GitHub issue tracker. This is a small test case:
import os
os.environ['KERAS_BACKEND'] = 'jax'
from keras.models import Sequential
from keras.layers import Dense
from keras.datasets import mnist
from keras.utils import np_utils
from jax import device_put
import jax.numpy as np
import numpy as onp
## MNIST:
onp.random.seed(0)
(X_train, y_train), (X_test, y_test) = mnist.load_data()
X_train = X_train.astype('float32')
X_test = X_test.astype('float32')
X_train /= 255
X_train = X_train.reshape(-1, 28*28)
X_test /= 255
X_test = X_test.reshape(-1, 28*28)
Y_train = np_utils.to_categorical(y_train, 10)
Y_test = np_utils.to_categorical(y_test, 10)
model = Sequential()
# --------------------------------------------------
d = Dense(units=64, activation='relu', input_dim=28*28)
model.add(d)
d = Dense(units=10, activation='softmax')
model.add(d)
# --------------------------------------------------
out = model.predict(X_train[:1])
# --------------------------------------------------
out = model.predict(X_train[:10])
# --------------------------------------------------
model.compile(loss='categorical_crossentropy',
              optimizer='sgd',
              metrics=['accuracy'])
model.train_on_batch(X_train, Y_train)
# --------------------------------------------------
model.fit(X_train, Y_train, epochs=5, batch_size=32)
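For comparison, here is a minimal sketch of the same two-layer model (Dense 64 with relu, then Dense 10 with softmax) written directly in JAX. It is not code from the branch. The helper names (`init_params`, `predict`, `loss`, `sgd_step`), the weight initialization, and the random stand-in data are all assumptions made for illustration, and it may be useful as a reference when checking whether the backend's training step is correct.

```python
import jax
import jax.numpy as jnp

def init_params(key, sizes=(784, 64, 10)):
    """Small random weights and zero biases for each Dense layer (assumed init)."""
    params = []
    for n_in, n_out in zip(sizes[:-1], sizes[1:]):
        key, sub = jax.random.split(key)
        w = 0.01 * jax.random.normal(sub, (n_in, n_out))
        params.append((w, jnp.zeros(n_out)))
    return params

def predict(params, x):
    """Dense(64, relu) followed by Dense(10, softmax)."""
    (w1, b1), (w2, b2) = params
    h = jax.nn.relu(x @ w1 + b1)
    return jax.nn.softmax(h @ w2 + b2)

def loss(params, x, y):
    """Mean categorical cross-entropy against one-hot labels y."""
    probs = predict(params, x)
    return -jnp.mean(jnp.sum(y * jnp.log(probs + 1e-7), axis=-1))

@jax.jit
def sgd_step(params, x, y, lr=0.01):
    """One plain SGD update: p <- p - lr * dL/dp for every parameter."""
    grads = jax.grad(loss)(params, x, y)
    return jax.tree_util.tree_map(lambda p, g: p - lr * g, params, grads)

# Random stand-in data with the MNIST shapes; swap in X_train / Y_train to compare.
key, data_key = jax.random.split(jax.random.PRNGKey(0))
params = init_params(key)
x = jax.random.uniform(data_key, (32, 784))
y = jax.nn.one_hot(jnp.arange(32) % 10, 10)

loss_before = loss(params, x, y)
for _ in range(20):
    params = sgd_step(params, x, y, 0.1)
loss_after = loss(params, x, y)
print(float(loss_before), float(loss_after))
```

If the loss from `model.train_on_batch` does not fall in a similar way on the same batch, that would point at the gradient or update plumbing rather than the forward pass.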
I think the proof-of-concept has been made. I'm going to close this for now, since the Keras team expects big changes to the project for TF2.0 Eager-mode integration. A lot of the changes I made will be made redundant, and real JAX integration will probably be made easier once that lands.
Some learnings are:
K.add in a new foo_backend.py file, but requires significant plumbing in several different places.
Good learnings! Thanks for investigating this and for the detailed report.
Also adding a link to the feature request on keras-team/keras.
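As a rough sketch of what the per-function part of a `foo_backend.py` file could look like with JAX, the fragment below wraps a few `jax.numpy` calls behind backend-style names. The function names and signatures are assumptions loosely modeled on the names mentioned in this thread, not the actual Keras backend interface or the code in the branch, and the plumbing mentioned above (variables, placeholders, the training function) is not covered.

```python
# Hypothetical fragment of a JAX-flavoured foo_backend.py (illustrative only).
import jax
import jax.numpy as jnp

def add(x, y):
    """Elementwise addition."""
    return jnp.add(x, y)

def dot(x, y):
    """Matrix or vector product."""
    return jnp.dot(x, y)

def relu(x):
    """Rectified linear unit."""
    return jnp.maximum(x, 0)

def softmax(x, axis=-1):
    """Softmax along the given axis."""
    return jax.nn.softmax(x, axis=axis)

# A Dense layer's forward pass expressed through the wrappers above.
x = jnp.ones((2, 3))
w = jnp.ones((3, 4))
b = jnp.zeros(4)
out = softmax(add(dot(x, w), b))
print(out.shape)
```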