Getting started with JAX for AI#

JAX is a Python library for accelerator-oriented array computation and program transformation, and is the engine behind cutting-edge AI research and production models at Google, Google DeepMind, and beyond.

Who is this tutorial for?#

This tutorial is for those who want to get started using JAX and JAX-based AI libraries - the JAX AI stack - to build and train a simple neural network model. It assumes some familiarity with numerical computing in Python with NumPy, as well as conceptual familiarity with defining, training, and evaluating machine learning models.

What does this tutorial cover?#

JAX focuses on array-based computation, and is at the core of a growing ecosystem of domain-specific tools. This tutorial introduces part of that JAX ecosystem designed for AI-related tasks, including:

  • Flax NNX: A machine learning library designed for defining and building scalable neural networks using JAX.

  • Optax: A high-performance function optimization library that comes with built-in optimizers and loss functions.

After working through this content, you may wish to visit the JAX documentation site for a deeper dive into the core JAX concepts.

Example: A simple neural network with Flax#

We’ll start with a quick example of what it looks like to use JAX with the Flax framework to define and train a very simple neural network to recognize hand-written digits.

Loading the data#

JAX can work with a variety of data loaders, including Grain, TensorFlow Datasets and TorchData, but for simplicity this example uses the well-known scikit-learn digits dataset.

from sklearn.datasets import load_digits
digits = load_digits()

print(f"{digits.data.shape=}")
print(f"{digits.target.shape=}")
digits.data.shape=(1797, 64)
digits.target.shape=(1797,)

This dataset consists of 8x8 pixelated images of hand-written digits and their corresponding labels. Let’s visualize a handful of them with matplotlib:

import matplotlib.pyplot as plt

fig, axes = plt.subplots(10, 10, figsize=(6, 6),
                         subplot_kw={'xticks':[], 'yticks':[]},
                         gridspec_kw=dict(hspace=0.1, wspace=0.1))

for i, ax in enumerate(axes.flat):
    ax.imshow(digits.images[i], cmap='binary', interpolation='gaussian')
    ax.text(0.05, 0.05, str(digits.target[i]), transform=ax.transAxes, color='green')
(Output: a 10x10 grid of digit images, each annotated with its label in green.)

Next, we split the dataset into a training and testing set, and convert these splits into jax.Arrays before we feed them into the model. We’ll use the jax.numpy module, which provides a familiar NumPy-style API around JAX operations:

from sklearn.model_selection import train_test_split
splits = train_test_split(digits.images, digits.target, random_state=0)
import jax.numpy as jnp
images_train, images_test, label_train, label_test = map(jnp.asarray, splits)
print(f"{images_train.shape=} {label_train.shape=}")
print(f"{images_test.shape=}  {label_test.shape=}")
images_train.shape=(1347, 8, 8) label_train.shape=(1347,)
images_test.shape=(450, 8, 8)  label_test.shape=(450,)

Defining the Flax model#

We can now use Flax NNX to create a simple feed-forward neural network by subclassing flax.nnx.Module, composed of flax.nnx.Linear layers with the scaled exponential linear unit (SELU) activation function, using the built-in flax.nnx.selu:

from flax import nnx

class SimpleNN(nnx.Module):

  def __init__(self, n_features: int = 64, n_hidden: int = 100, n_targets: int = 10,
               *, rngs: nnx.Rngs):
    self.n_features = n_features
    self.layer1 = nnx.Linear(n_features, n_hidden, rngs=rngs)
    self.layer2 = nnx.Linear(n_hidden, n_hidden, rngs=rngs)
    self.layer3 = nnx.Linear(n_hidden, n_targets, rngs=rngs)

  def __call__(self, x):
    x = x.reshape(x.shape[0], self.n_features) # Flatten images.
    x = nnx.selu(self.layer1(x))
    x = nnx.selu(self.layer2(x))
    x = self.layer3(x)
    return x

model = SimpleNN(rngs=nnx.Rngs(0))

nnx.display(model)  # Interactive display if penzai is installed.

Training the model#

With the SimpleNN model created and instantiated, we can now choose a loss function and an optimizer with the Optax package, and then define the training step function:

import jax
import optax

optimizer = nnx.Optimizer(model, optax.sgd(learning_rate=0.05))

def loss_fun(
    model: nnx.Module,
    data: jax.Array,
    labels: jax.Array):
  logits = model(data)
  loss = optax.softmax_cross_entropy_with_integer_labels(
    logits=logits, labels=labels
  ).mean()
  return loss, logits

@nnx.jit  # JIT-compile the function
def train_step(
    model: nnx.Module,
    optimizer: nnx.Optimizer,
    data: jax.Array,
    labels: jax.Array):
  loss_gradient = nnx.grad(loss_fun, has_aux=True)  # gradient transform!
  grads, logits = loss_gradient(model, data, labels)
  optimizer.update(grads)  # inplace update

Notice here the use of flax.nnx.jit and flax.nnx.grad, which are Flax NNX transformations built on jax.jit and jax.grad.

We will return to these transformations later in the tutorial.

Now that we have a training step function, let’s define a training loop to repeatedly perform this training step over the training data, periodically printing the loss against the test set to monitor convergence:

for i in range(301):  # 300 training epochs
  train_step(model, optimizer, images_train, label_train)
  if i % 50 == 0:  # Print metrics.
    loss, _ = loss_fun(model, images_test, label_test)
    print(f"epoch {i}: loss={loss:.2f}")
epoch 0: loss=20.18
epoch 50: loss=0.21
epoch 100: loss=0.14
epoch 150: loss=0.12
epoch 200: loss=0.12
epoch 250: loss=0.11
epoch 300: loss=0.11

After 300 training epochs, our model should have converged to a loss of around 0.1. We can check what this implies for the accuracy of the model’s predicted labels on the test set:

label_pred = model(images_test).argmax(axis=1)
num_matches = jnp.count_nonzero(label_pred == label_test)
num_total = len(label_test)
accuracy = num_matches / num_total
print(f"{num_matches} labels match out of {num_total}:"
      f" accuracy = {accuracy:%}")
433 labels match out of 450: accuracy = 96.222222%

The simple feed-forward network has achieved approximately 96% accuracy on the test set. We can do a similar visualization as above to review some examples that the model predicted correctly (in green) and incorrectly (in red):

fig, axes = plt.subplots(10, 10, figsize=(6, 6),
                         subplot_kw={'xticks':[], 'yticks':[]},
                         gridspec_kw=dict(hspace=0.1, wspace=0.1))

for i, ax in enumerate(axes.flat):
    ax.imshow(images_test[i], cmap='binary', interpolation='gaussian')
    color = 'green' if label_pred[i] == label_test[i] else 'red'
    ax.text(0.05, 0.05, str(label_pred[i]), transform=ax.transAxes, color=color)
(Output: a 10x10 grid of test images, each annotated with its predicted label: green where correct, red where incorrect.)

In this tutorial, we have just scratched the surface of JAX, Flax NNX, and Optax. The Flax NNX package includes a number of useful APIs for tracking metrics during training, which are featured in the Flax MNIST tutorial on the Flax website.
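
As a taste of what that looks like, here is a rough sketch of metric tracking with nnx.MultiMetric, following the pattern used in the Flax MNIST tutorial (the exact API details are an assumption here; refer to that tutorial for the authoritative version):

# Rough sketch; verify the metric API against the current Flax documentation.
metrics = nnx.MultiMetric(
    accuracy=nnx.metrics.Accuracy(),
    loss=nnx.metrics.Average('loss'),
)
loss, logits = loss_fun(model, images_test, label_test)
metrics.update(loss=loss, logits=logits, labels=label_test)
print(metrics.compute())  # dict with 'accuracy' and 'loss' entries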

Key JAX features#

The Flax NNX neural network API demonstrated above takes advantage of a number of key JAX features, designed into the library from the ground up. In particular:

  • JAX provides a familiar NumPy-like API for array computing. This means that when processing data and outputs, we can reach for APIs like jax.numpy.count_nonzero, which mirror the familiar APIs of the NumPy package; in this case numpy.count_nonzero.

  • JAX provides just-in-time (JIT) compilation. This means that we can implement our code easily in Python, but count on fast compiled execution on CPU, GPU, and TPU backends via the XLA compiler by wrapping the code with a simple jax.jit transformation.

  • JAX provides automatic differentiation (autodiff). This means that when fitting models, Optax and Flax can compute closed-form gradient functions for fast optimization of models, using the jax.grad transformation.

  • JAX provides automatic vectorization. While we didn’t use this directly in the code above, under the hood Flax takes advantage of JAX’s vectorized map (jax.vmap) to automatically convert loss and gradient functions into efficient batch-aware functions that are just as fast as hand-written versions. This makes JAX implementations simpler and less error-prone.

We will learn more about these features through brief examples in the following sections.

JAX NumPy interface#

The foundational array computing package in Python is NumPy, and JAX provides a matching API via the jax.numpy subpackage. Additionally, JAX arrays (jax.Array) behave much like NumPy arrays in their attributes, and in terms of indexing and broadcasting semantics.
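
For a quick illustration, here are a few of those NumPy-like behaviors on a small jax.Array:

import jax.numpy as jnp

a = jnp.arange(6.0).reshape(2, 3)
print(a.shape, a.dtype)  # familiar array attributes: (2, 3) float32
print(a[0, 1:])          # NumPy-style slicing and indexing
print(a + jnp.ones(3))   # broadcasting a (3,)-shaped array across the rows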

In the previous example, we used Flax’s built-in flax.nnx.selu implementation. We can also implement SELU using JAX’s NumPy API as follows:

import jax.numpy as jnp

def selu(x, alpha=1.67, lam=1.05):
  return lam * jnp.where(x > 0, x, alpha * jnp.exp(x) - alpha)

x = jnp.arange(5.0)
print(selu(x))
[0.        1.05      2.1       3.1499999 4.2      ]

Despite the broad similarities, be aware that JAX does have some well-motivated differences from NumPy that you can read about in 🔪 JAX – The Sharp Bits 🔪 on the JAX site.
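
One such difference worth knowing up front: JAX arrays are immutable, so in-place item assignment is replaced by functional updates. A minimal sketch:

x = jnp.arange(3)
# x[0] = 10 would raise an error, because jax.Array is immutable;
# the .at[...] syntax returns an updated copy instead:
y = x.at[0].set(10)
print(x)  # [0 1 2]  (unchanged)
print(y)  # [10 1 2]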

Just-in-time compilation#

As mentioned before, JAX is built on the XLA compiler, and allows sequences of operations to be just-in-time (JIT) compiled using the jax.jit transformation. In the neural network example above, we used the similar flax.nnx.jit transform, which has some special handling for Flax NNX objects for speed in neural network training.

Returning to the previously defined selu function in JAX, we can create a jax.jit-compiled version this way:

import jax
selu_jit = jax.jit(selu)

selu_jit is now a compiled version of the original function, which returns the same result to typical floating-point precision:

x = jnp.arange(1E6)
jnp.allclose(selu(x), selu_jit(x))  # results match
Array(True, dtype=bool)

We can use IPython’s %timeit magic to observe the speedup (note the use of jax.block_until_ready(), which we need to use to account for JAX’s asynchronous dispatch):

%timeit selu(x).block_until_ready()
2.95 ms ± 98 μs per loop (mean ± std. dev. of 7 runs, 100 loops each)
%timeit selu_jit(x).block_until_ready()
242 μs ± 1.39 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)

For this computation, running on CPU, jax.jit compilation gives an order of magnitude speedup. JAX’s documentation has more discussion of JIT compilation at Just-in-time compilation.

Automatic differentiation (autodiff)#

For efficient optimization of neural network models, fast gradient computations are essential. JAX enables this via its automatic differentiation transformations like jax.grad, which computes a closed-form gradient of a JAX function. In the neural network example, we used the similar flax.nnx.grad function, which has special handling for flax.nnx objects.

Here’s how to compute the gradient of a function with jax.grad:

x = jnp.float32(-1.0)
jax.grad(selu)(x)
Array(0.6450766, dtype=float32)

We can briefly check with a finite-difference approximation that this is giving the expected value:

eps = 1E-3
(selu(x + eps) - selu(x)) / eps
Array(0.64539903, dtype=float32)

Importantly, the automatic differentiation approach is both more accurate and efficient than computing numerical gradients. JAX’s documentation has more discussion of autodiff at Automatic differentiation and Advanced automatic differentiation.
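
Because jax.grad returns an ordinary JAX function, it composes with itself, so higher-order derivatives are simply nested calls. A quick sketch using the selu function from above:

# Second derivative of selu via nested jax.grad:
d2_selu = jax.grad(jax.grad(selu))
print(d2_selu(x))  # for x < 0 the negative branch is lam * alpha * exp(x),
                   # so the second derivative matches the first (about 0.645 here)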

Automatic vectorization#

In the training loop example earlier, we defined the loss function in terms of a single input data vector of shape n_features, but trained the model by passing batches of data (of shape [n_samples, n_features]). Rather than requiring a naive and slow loop over batches, Flax and Optax internals use JAX’s automatic vectorization via the jax.vmap transformation to construct a batched version of the computation automatically.

Consider a simple loss function that looks like this:

def loss(x: jax.Array, x0: jax.Array):
  return jnp.sum((x - x0) ** 2)

We can evaluate it on a single data vector this way:

x = jnp.arange(3.)
x0 = jnp.ones(3)
loss(x, x0)
Array(2., dtype=float32)

But if we attempt to evaluate it on a batch of vectors, it does not correctly return a batch of 4 losses:

batched_x = jnp.arange(12).reshape(4, 3)  # batch of 4 vectors
loss(batched_x, x0)  # wrong!
Array(386., dtype=float32)

The problem is that this loss function is not batch-aware. Without automatic vectorization, there are two ways we can address this:

  1. Re-write our loss function by hand to operate on batched data; however, as functions become more complicated, this becomes difficult and error-prone.

  2. Naively loop over unbatched calls to our original function; this is easy to code, but can be slow because it doesn’t take advantage of vectorized compute (see the sketch after this list).
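
For comparison, option 2 might look like the following (a quick illustrative sketch):

# Naive loop over the batch: correct, but not vectorized.
losses = jnp.stack([loss(x_i, x0) for x_i in batched_x])
print(losses)  # matches the batched result shown below: [2., 29., 110., 245.]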

The jax.vmap transformation offers a third way: it automatically transforms our original function into a batch-aware version, so we get the speed of option 1 with the ease of option 2:

loss_batched = jax.vmap(loss, in_axes=(0, None))  # batch x over axis 0, do not batch x0
loss_batched(batched_x, x0)
Array([  2.,  29., 110., 245.], dtype=float32)

In the neural network example earlier, both Flax and Optax make use of JAX’s jax.vmap to allow for efficient batched computations over our unbatched loss function.
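
A related pattern that combines the last two transformations is computing per-example gradients; here is a minimal sketch using the toy loss function above:

# One gradient of shape (3,) per example in the batch -> result has shape (4, 3).
# Cast to float because jax.grad requires real-valued (non-integer) inputs.
per_example_grad = jax.vmap(jax.grad(loss), in_axes=(0, None))
print(per_example_grad(batched_x.astype(jnp.float32), x0).shape)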

JAX’s documentation has more discussion of automatic vectorization at Automatic vectorization.