JAX for PyTorch users
This is a quick overview of JAX and the JAX AI stack written for those who are familiar with PyTorch.
First, we cover how to manipulate JAX Arrays following the well-known PyTorch tensors tutorial. Next, we explore automatic differentiation with JAX, followed by how to build a model and optimize its parameters.
Finally, we will introduce jax.jit and compare it to its PyTorch counterpart, TorchScript.
Setup
Let’s get started by importing JAX and checking the installed version. For details on how to install JAX, check the installation guide.
import jax
import jax.numpy as jnp
print(jax.__version__)
0.4.34
JAX Array manipulation
In this section, we will learn about JAX Arrays and how to manipulate them compared to PyTorch tensors.
Initializing a JAX Array
The primary array object in JAX is the jax.Array, which is the JAX counterpart of torch.Tensor.
As with torch.Tensor, jax.Array objects are never constructed directly, but rather via array creation APIs that populate the new array with constant numbers, random numbers, or data drawn from lists, NumPy arrays, PyTorch tensors, and more.
Let’s see some examples of this.
To initialize an array from Python data:
# From data
data = [[1, 2, 3], [3, 4, 5]]
x_array = jnp.array(data)
assert isinstance(x_array, jax.Array)
print(x_array, x_array.shape, x_array.dtype)
[[1 2 3]
[3 4 5]] (2, 3) int32
Or from an existing NumPy array:
import numpy as np
np_array = np.array(data)
x_np = jnp.array(np_array)
assert isinstance(x_np, jax.Array)
print(x_np, x_np.shape, x_np.dtype)
# x_np is a copy of np_array
[[1 2 3]
[3 4 5]] (2, 3) int32
You can create arrays with the same shape and dtype as existing JAX Arrays:
x_ones = jnp.ones_like(x_array)
print(x_ones, x_ones.shape, x_ones.dtype)
x_zeros = jnp.zeros_like(x_array)
print(x_zeros, x_zeros.shape, x_zeros.dtype)
[[1 1 1]
[1 1 1]] (2, 3) int32
[[0 0 0]
[0 0 0]] (2, 3) int32
You can even initialize arrays with constants or random values. For example:
shape = (2, 3)
ones_tensor = jnp.ones(shape)
zeros_tensor = jnp.zeros(shape)
seed = 123
key = jax.random.key(seed)
rand_tensor = jax.random.uniform(key, shape)
print(f"Random Tensor: \n {rand_tensor} \n")
print(f"Ones Tensor: \n {ones_tensor} \n")
print(f"Zeros Tensor: \n {zeros_tensor}")
Random Tensor:
[[0.38492894 0.38952553 0.2153877 ]
[0.18297386 0.8140422 0.7754953 ]]
Ones Tensor:
[[1. 1. 1.]
[1. 1. 1.]]
Zeros Tensor:
[[0. 0. 0.]
[0. 0. 0.]]
JAX avoids implicit global random state and instead tracks state explicitly via a random key. If we create two random arrays using the same key, we will obtain two identical random arrays. We can also split the random key into multiple keys to create two different random arrays.
seed = 124
key = jax.random.key(seed)
rand_tensor1 = jax.random.uniform(key, (2, 3))
rand_tensor2 = jax.random.uniform(key, (2, 3))
assert (rand_tensor1 == rand_tensor2).all()
k1, k2 = jax.random.split(key, num=2)
rand_tensor1 = jax.random.uniform(k1, (2, 3))
rand_tensor2 = jax.random.uniform(k2, (2, 3))
assert (rand_tensor1 != rand_tensor2).all()
For further discussion of random numbers in NumPy and JAX, check this tutorial.
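A common pattern when a loop needs fresh random numbers at every iteration is to split the key once per step and carry one half forward. The sketch below uses a hypothetical helper, sample_arrays, to illustrate that the whole sequence is reproducible from the seed:

```python
import jax
import jax.numpy as jnp


def sample_arrays(seed: int, n: int, shape=(2, 3)):
    """Hypothetical helper: draw n uniform arrays by threading an explicit key."""
    key = jax.random.key(seed)
    samples = []
    for _ in range(n):
        # Split off a fresh subkey; keep the other half for the next iteration.
        key, subkey = jax.random.split(key)
        samples.append(jax.random.uniform(subkey, shape))
    return samples


run1 = sample_arrays(seed=0, n=3)
run2 = sample_arrays(seed=0, n=3)

# Same seed: the exact same sequence of arrays.
same = all(bool((a == b).all()) for a, b in zip(run1, run2))
# Different subkeys within one run: different arrays.
different = not bool((run1[0] == run1[1]).all())
print(same, different)
```

Unlike torch.manual_seed, nothing global is modified here, so the function gives the same result no matter what other random code runs before it.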
Finally, if you have a PyTorch tensor, you can use it to initialize a JAX Array:
import torch
x_torch = torch.rand(3, 4)
# Create JAX Array as a copy of x_torch tensor
x_jax = jnp.asarray(x_torch)
assert isinstance(x_jax, jax.Array)
print(x_jax, x_jax.shape, x_jax.dtype)
# Use dlpack to create JAX Array without copying
x_jax = jax.dlpack.from_dlpack(x_torch.to(device="cuda"), copy=False)
print(x_jax, x_jax.shape, x_jax.dtype)
[[0.30022258 0.9624368 0.22899538 0.54575473]
[0.05540031 0.41184962 0.20278037 0.20024061]
[0.7847725 0.2454623 0.22583973 0.11959136]] (3, 4) float32
[[0.30022258 0.9624368 0.22899538 0.54575473]
[0.05540031 0.41184962 0.20278037 0.20024061]
[0.7847725 0.2454623 0.22583973 0.11959136]] (3, 4) float32
Attributes of a JAX Array
Similarly to PyTorch tensors, JAX Array attributes describe the array’s shape, dtype and device:
x_jax = jnp.ones((3, 4))
print(f"Shape of tensor: {x_jax.shape}")
print(f"Datatype of tensor: {x_jax.dtype}")
print(f"Device tensor is stored on: {x_jax.device}")
Shape of tensor: (3, 4)
Datatype of tensor: float32
Device tensor is stored on: cuda:0
However, there are some notable differences between PyTorch tensors and JAX Arrays:
JAX Arrays are immutable
The default integer and float dtypes are int32 and float32
The default device corresponds to the available accelerator, e.g. cuda:0 if one or multiple GPUs are available.
try:
    x_jax[0, 0] = 100.0
except TypeError as e:
    print(e)
x_torch = torch.tensor([1, 2, 3, 4])
x_jax = jnp.array([1, 2, 3, 4])
print(f"Default integer dtypes, PyTorch: {x_torch.dtype} and Jax: {x_jax.dtype}")
x_torch = torch.zeros(3, 4)
x_jax = jnp.zeros((3, 4))
print(f"Default float dtypes, PyTorch: {x_torch.dtype} and Jax: {x_jax.dtype}")
print(f"Default devices, PyTorch: {x_torch.device} and Jax: {x_jax.device}")
'<class 'jaxlib.xla_extension.ArrayImpl'>' object does not support item assignment. JAX arrays are immutable. Instead of ``x[idx] = y``, use ``x = x.at[idx].set(y)`` or another .at[] method: https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.ndarray.at.html
Default integer dtypes, PyTorch: torch.int64 and Jax: int32
Default float dtypes, PyTorch: torch.float32 and Jax: float32
Default devices, PyTorch: cpu and Jax: cuda:0
For some discussion of JAX’s alternative to in-place mutation, refer to https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.ndarray.at.html.
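As a short illustration of the .at[] property, each method below returns a new array and leaves the original untouched. The values are arbitrary. Inside a jax.jit-compiled function, XLA can often perform such updates in place when the input buffer is not reused.

```python
import jax.numpy as jnp

x = jnp.arange(5.0)              # [0, 1, 2, 3, 4]

# PyTorch: x[1] = 10
y = x.at[1].set(10.0)            # [0, 10, 2, 3, 4]
# PyTorch: x[2:4] += 1
y = y.at[2:4].add(1.0)           # [0, 10, 3, 4, 4]
# PyTorch: x[4] *= 3
y = y.at[4].multiply(3.0)        # [0, 10, 3, 4, 12]
# Element-wise minimum with the current value: min(0, -2) = -2
y = y.at[0].min(-2.0)            # [-2, 10, 3, 4, 12]

print(x)  # the original array is unchanged
print(y)
```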
Devices and accelerators
Using the PyTorch API, we can check whether we have GPU accelerators available with torch.cuda.is_available(). In JAX, we can check available devices as follows:
print(f"Available devices given a backend (gpu or tpu or cpu): {jax.devices()}")
# Define CPU and CUDA devices
cpu_device = jax.devices("cpu")[0]
cuda_device = jax.devices("cuda")[0]
print(cpu_device, cuda_device)
Available devices given a backend (gpu or tpu or cpu): [CudaDevice(id=0), CudaDevice(id=1)]
TFRT_CPU_0 cuda:0
Let’s briefly explore how to create arrays on CPU and CUDA devices.
# create an array on CPU and check the device
x_cpu = jnp.ones((3, 4), device=cpu_device)
print(x_cpu.device)
# create an array on GPU
x_gpu = jnp.ones((3, 4), device=cuda_device)
print(x_gpu.device)
TFRT_CPU_0
cuda:0
In PyTorch, we are used to device placement always being explicit. JAX can operate this way via explicit device placement as above, but unless the device is specified, the array remains uncommitted, i.e. it is stored on the default device but can be moved implicitly to other devices when necessary:
x = jnp.ones((3, 4))
x.device, (x_cpu + x).device
(CudaDevice(id=0), CpuDevice(id=0))
However, if we perform a computation on two arrays committed to different devices, e.g. CPU and CUDA, an error will be raised, similarly to PyTorch.
try:
    x_cpu + x_gpu
except ValueError as e:
    print(e)
To move an array from one device to another, we can use the jax.device_put function:
x = jnp.ones((3, 4))
x_cpu = jax.device_put(x, device=jax.devices("cpu")[0])
x_cuda = jax.device_put(x_cpu, device=jax.devices("cuda")[0])
print(f"{x.device} -> {x_cpu.device} -> {x_cuda.device}")
cuda:0 -> TFRT_CPU_0 -> cuda:0
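To control placement without passing device= to every call, jax.default_device can be used as a context manager for newly created arrays. The minimal sketch below uses only the CPU device so that it also runs on machines without a GPU. It shows that the result of a computation on arrays placed on the CPU stays there:

```python
import jax
import jax.numpy as jnp

cpu = jax.devices("cpu")[0]

# Arrays created inside this block are placed on the CPU by default.
with jax.default_device(cpu):
    a = jnp.ones((3, 4))

# Explicitly commit an array to the CPU.
b = jax.device_put(jnp.ones((3, 4)), cpu)

# The result of the computation is placed on the same device as its inputs.
c = a + b
print(a.devices(), b.devices(), c.devices())
```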
Operations on JAX Arrays
There is a large list of operations (arithmetic, linear algebra, matrix manipulation, etc.) that can be performed directly on JAX Arrays. The JAX API contains several important modules:
jax.numpy provides NumPy-like functions
jax.scipy provides SciPy-like functions
jax.nn provides common functions for neural networks: activations, softmax, one-hot encoding, etc.
jax.lax provides low-level XLA APIs
…
More details on available ops can be found in the API reference.
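A few one-line examples, one per module listed above, with arbitrary inputs. Note that jax.lax functions are stricter than jax.numpy: they do not perform implicit type promotion, so both operands passed to lax.add below are float32 arrays.

```python
import jax
import jax.numpy as jnp
from jax import lax
from jax.scipy.special import logsumexp

x = jnp.array([0.0, 0.0])

exp_x = jnp.exp(x)                              # jax.numpy: NumPy-like API
lse = logsumexp(x)                              # jax.scipy: log(e^0 + e^0) = log(2)
onehot = jax.nn.one_hot(jnp.array([0, 2]), 3)   # jax.nn: neural-network helpers
summed = lax.add(x, jnp.ones_like(x))           # jax.lax: low-level primitive

print(exp_x, lse, onehot, summed, sep="\n")
```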
All operations can be run on CPUs, GPUs or TPUs. By default, JAX Arrays are created on an accelerator device when one is available, while PyTorch tensors are created on CPUs.
We can now try out some array operations and check for similarities between the JAX, NumPy and PyTorch APIs.
Standard NumPy-like indexing and slicing:
tensor = jnp.ones((3, 4))
print(f"First row: {tensor[0]}")
print(f"First column: {tensor[:, 0]}")
print(f"Last column: {tensor[..., -1]}")
# Equivalent PyTorch op: tensor[:, 1] = 0
tensor = tensor.at[:, 1].set(0)
print(tensor)
First row: [1. 1. 1. 1.]
First column: [1. 1. 1.]
Last column: [1. 1. 1.]
[[1. 0. 1. 1.]
[1. 0. 1. 1.]
[1. 0. 1. 1.]]
Note the particular out-of-bounds indexing behaviour in JAX: in indexing (retrieval) operations, an out-of-bounds index is clamped to the bounds of the array instead of raising an error.
print(jnp.arange(10)[11])
9
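Update operations behave differently: an out-of-bounds update is dropped rather than clamped. When clamping is not the desired behaviour for a read, the .at[...].get() method accepts a mode argument; the sketch below uses mode="fill" to return a sentinel value instead:

```python
import jax.numpy as jnp

x = jnp.arange(10)

# Retrieval: the out-of-bounds index 11 is clamped to 9, no error is raised.
clamped = x[11]

# Update: the out-of-bounds update is dropped, so the array is unchanged.
updated = x.at[11].set(100)

# Retrieval with an explicit mode: return fill_value for out-of-bounds indices.
filled = x.at[11].get(mode="fill", fill_value=-1)

print(clamped, updated, filled)
```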
Join arrays, similarly to torch.cat. Note the kwarg name: axis vs dim.
t1 = jnp.concat([tensor, tensor, tensor], axis=1)
print(t1)
[[1. 0. 1. 1. 1. 0. 1. 1. 1. 0. 1. 1.]
[1. 0. 1. 1. 1. 0. 1. 1. 1. 0. 1. 1.]
[1. 0. 1. 1. 1. 0. 1. 1. 1. 0. 1. 1.]]
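jnp.concatenate, the NumPy-style name, joins arrays along an existing axis just like jnp.concat above, while the counterpart of torch.stack, which joins arrays along a new axis, is jnp.stack. A minimal shape comparison with arbitrary inputs:

```python
import jax.numpy as jnp

a = jnp.ones((3, 4))
b = jnp.zeros((3, 4))

cat = jnp.concatenate([a, b], axis=0)  # existing axis grows: (6, 4)
stk = jnp.stack([a, b], axis=0)        # new leading axis: (2, 3, 4)

print(cat.shape, stk.shape)
```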
Arithmetic operations. The operations below compute the matrix multiplication of two arrays; y1 and y2 will have the same value.
# ``tensor.T`` returns the transpose of a tensor
y1 = tensor @ tensor.T
y2 = jnp.matmul(tensor, tensor.T)
assert (y1 == y2).all()
# This computes the element-wise product. z1, z2 will have the same value
z1 = tensor * tensor
z2 = jnp.multiply(tensor, tensor)
assert (z1 == z2).all()
Single-element arrays. If you have a one-element array, for example from aggregating all values of an array into one value, you can convert it to a Python numerical value using .item():
agg = tensor.sum()
agg_value = agg.item()
print(agg_value, isinstance(agg_value, float), isinstance(agg, jax.Array))
9.0 True True
JAX follows NumPy in exposing a number of reduction and other operations as array methods:
jax_array = jnp.ones((2, 3))
jax_array.sum(), jax_array.mean(), jax_array.min(), jax_array.max(), jax_array.dot(jax_array.T), # ...
tensor = torch.ones(2, 3)
tensor.sum(), tensor.mean(), tensor.min(), tensor.max(), tensor.matmul(tensor.T), # ...
PyTorch exposes many more methods on its tensor object than either JAX or NumPy does on their respective array objects. Here are some examples of methods only available in PyTorch:
tensor.sigmoid(), tensor.softmax(dim=1), tensor.sin(), # ...
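In JAX, the equivalents of these methods are free functions. A minimal sketch of the three calls above, assuming the same (2, 3) array of ones; note again axis instead of dim:

```python
import jax
import jax.numpy as jnp

x = jnp.ones((2, 3))

s = jax.nn.sigmoid(x)            # PyTorch: tensor.sigmoid()
p = jax.nn.softmax(x, axis=1)    # PyTorch: tensor.softmax(dim=1)
z = jnp.sin(x)                   # PyTorch: tensor.sin()

print(s, p, z, sep="\n")
```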
Automatic differentiation with JAX
In this section, we will learn about the fundamental applications of automatic differentiation (autodiff) in JAX. JAX has a pretty general autodiff system, and its API has inspired the torch.func module in PyTorch, previously known as “functorch” (JAX-like composable function transforms for PyTorch).
In PyTorch, there is an API to turn on the recording of the operations graph for automatic differentiation (e.g., the requires_grad argument and tensor.backward()), but in JAX, automatic differentiation is a functional operation, i.e., there is no need to mark arrays with a flag to enable gradient tracking.
Let us follow the PyTorch autodiff tutorial and consider the simplest one-layer neural network, with input x, parameters w and b, and some loss function. In JAX, this can be defined in the following way:
import jax
import jax.numpy as jnp
# Input tensor
x = jnp.ones(5)
# Target output
y_true = jnp.zeros(3)
# Initialize random parameters
seed = 123
key = jax.random.key(seed)
key, w_key, b_key = jax.random.split(key, 3)
w = jax.random.normal(w_key, (5, 3))
b = jax.random.normal(b_key, (3, ))
# model function
def predict(x, w, b):
    return jnp.matmul(x, w) + b
# Criterion or loss function
def compute_loss(w, b, x, y_true):
    y_pred = predict(x, w, b)
    return jnp.mean((y_true - y_pred) ** 2)
loss = compute_loss(w, b, x, y_true)
print(loss)
6.5595226
In our example network, w and b are the parameters to optimize, and we need to be able to compute the gradients of the loss function with respect to those variables. To do that, we apply the jax.grad function to the compute_loss function:
# Differentiate `compute_loss` with respect to the 0 and 1 positional arguments:
w_grad, b_grad = jax.grad(compute_loss, argnums=(0, 1))(w, b, x, y_true)
print(f'{w_grad=}')
print(f'{b_grad=}')
w_grad=Array([[-1.6753345, 1.7790363, 1.6656275],
[-1.6753345, 1.7790363, 1.6656275],
[-1.6753345, 1.7790363, 1.6656275],
[-1.6753345, 1.7790363, 1.6656275],
[-1.6753345, 1.7790363, 1.6656275]], dtype=float32)
b_grad=Array([-1.6753345, 1.7790363, 1.6656275], dtype=float32)
# Compute w_grad, b_grad and loss value:
loss_value, (w_grad, b_grad) = jax.value_and_grad(compute_loss, argnums=(0, 1))(w, b, x, y_true)
print(f'{w_grad=}')
print(f'{b_grad=}')
print(f'{loss_value=}')
print(f'{compute_loss(w, b, x, y_true)=}')
w_grad=Array([[-1.6753345, 1.7790363, 1.6656275],
[-1.6753345, 1.7790363, 1.6656275],
[-1.6753345, 1.7790363, 1.6656275],
[-1.6753345, 1.7790363, 1.6656275],
[-1.6753345, 1.7790363, 1.6656275]], dtype=float32)
b_grad=Array([-1.6753345, 1.7790363, 1.6656275], dtype=float32)
loss_value=Array(6.5595226, dtype=float32)
compute_loss(w, b, x, y_true)=Array(6.5595226, dtype=float32)
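Since x is a vector of ones, every row of w_grad equals b_grad. This can be checked against the hand-derived gradient of the mean squared error over the n = 3 outputs, dL/db = 2/n (y_pred - y_true) and dL/dw = outer(x, dL/db). The sketch below repeats the setup so that it runs on its own:

```python
import jax
import jax.numpy as jnp

# Same setup as above.
x = jnp.ones(5)
y_true = jnp.zeros(3)
key = jax.random.key(123)
key, w_key, b_key = jax.random.split(key, 3)
w = jax.random.normal(w_key, (5, 3))
b = jax.random.normal(b_key, (3,))


def compute_loss(w, b, x, y_true):
    y_pred = jnp.matmul(x, w) + b
    return jnp.mean((y_true - y_pred) ** 2)


w_grad, b_grad = jax.grad(compute_loss, argnums=(0, 1))(w, b, x, y_true)

# Hand-derived gradients of the mean squared error (n = 3 outputs):
#   dL/db = 2/n * (y_pred - y_true),   dL/dw = outer(x, dL/db)
residual = jnp.matmul(x, w) + b - y_true
b_grad_manual = 2.0 / y_true.size * residual
w_grad_manual = jnp.outer(x, b_grad_manual)

print(jnp.allclose(b_grad, b_grad_manual), jnp.allclose(w_grad, w_grad_manual))
```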
jax.grad and PyTrees
JAX introduced the PyTree abstraction, a uniform system for handling nested containers of array values (e.g. Python containers like dicts, tuples and lists), and its functional API works seamlessly on these containers. Let us consider an example where we gather our example network parameters into a dictionary:
net_params = {
"weights": w,
"bias": b,
}
def compute_loss2(net_params, x, y_true):
    y_pred = predict(x, net_params["weights"], net_params["bias"])
    return jnp.mean((y_true - y_pred) ** 2)
jax.value_and_grad(compute_loss2, argnums=0)({"weights": w, "bias": b}, x, y_true)
(Array(6.5595226, dtype=float32),
{'bias': Array([-1.6753345, 1.7790363, 1.6656275], dtype=float32),
'weights': Array([[-1.6753345, 1.7790363, 1.6656275],
[-1.6753345, 1.7790363, 1.6656275],
[-1.6753345, 1.7790363, 1.6656275],
[-1.6753345, 1.7790363, 1.6656275],
[-1.6753345, 1.7790363, 1.6656275]], dtype=float32)})
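Because the gradient has the same PyTree structure as the parameters, an optimizer step can be written leaf by leaf with jax.tree_util.tree_map. A minimal sketch of one plain SGD step on the dictionary above; the learning rate is an illustrative value, and plain SGD stands in for the optimizers used later:

```python
import jax
import jax.numpy as jnp


def predict(x, w, b):
    return jnp.matmul(x, w) + b


def compute_loss2(net_params, x, y_true):
    y_pred = predict(x, net_params["weights"], net_params["bias"])
    return jnp.mean((y_true - y_pred) ** 2)


key = jax.random.key(123)
key, w_key, b_key = jax.random.split(key, 3)
net_params = {
    "weights": jax.random.normal(w_key, (5, 3)),
    "bias": jax.random.normal(b_key, (3,)),
}
x, y_true = jnp.ones(5), jnp.zeros(3)

# grads is a dict with the same keys and shapes as net_params.
loss, grads = jax.value_and_grad(compute_loss2)(net_params, x, y_true)

# One SGD step, applied to every leaf of the PyTree (illustrative learning rate).
learning_rate = 0.01
new_params = jax.tree_util.tree_map(
    lambda p, g: p - learning_rate * g, net_params, grads
)
new_loss = compute_loss2(new_params, x, y_true)
print(loss, new_loss)
```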
The functional API in JAX also allows us to compute higher-order gradients easily by calling jax.grad multiple times on the function. We will not cover this topic in this tutorial; for more details, we suggest reading the JAX automatic differentiation tutorial.
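As a brief illustration, nesting jax.grad on a scalar function gives its successive derivatives, and jax.hessian computes the full matrix of second derivatives for a function of a vector. The functions below are arbitrary examples:

```python
import jax
import jax.numpy as jnp


def f(x):
    return x ** 3


df = jax.grad(f)               # f'(x)   = 3x^2
d2f = jax.grad(jax.grad(f))    # f''(x)  = 6x
d3f = jax.grad(d2f)            # f'''(x) = 6

# Hessian of sum(v^2) is 2 * identity.
hess = jax.hessian(lambda v: jnp.sum(v ** 2))(jnp.ones(3))

print(df(2.0), d2f(2.0), d3f(2.0))
print(hess)
```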
Build and train a model
In this section, we will learn how to build a simple model using Flax (the flax.nnx API) and optimize its parameters using training data provided by a PyTorch dataloader.
Model creation with Flax is very similar to model creation in PyTorch with the torch.nn module. In this example, we will build the ResNet18 model.
Build ResNet18 model
# To install Flax: `pip install -U flax treescope optax`
import jax
import jax.numpy as jnp
from flax import nnx
class BasicBlock(nnx.Module):
    def __init__(
        self, in_planes: int, out_planes: int, do_downsample: bool = False, *, rngs: nnx.Rngs
    ):
        strides = (2, 2) if do_downsample else (1, 1)
        self.conv1_bn1 = nnx.Sequential(
            nnx.Conv(
                in_planes, out_planes, kernel_size=(3, 3), strides=strides,
                padding="SAME", use_bias=False, rngs=rngs,
            ),
            nnx.BatchNorm(out_planes, momentum=0.9, epsilon=1e-5, rngs=rngs),
        )
        self.conv2_bn2 = nnx.Sequential(
            nnx.Conv(
                out_planes, out_planes, kernel_size=(3, 3), strides=(1, 1),
                padding="SAME", use_bias=False, rngs=rngs,
            ),
            nnx.BatchNorm(out_planes, momentum=0.9, epsilon=1e-5, rngs=rngs),
        )
        if do_downsample:
            self.conv3_bn3 = nnx.Sequential(
                nnx.Conv(
                    in_planes, out_planes, kernel_size=(1, 1), strides=(2, 2),
                    padding="VALID", use_bias=False, rngs=rngs,
                ),
                nnx.BatchNorm(out_planes, momentum=0.9, epsilon=1e-5, rngs=rngs),
            )
        else:
            self.conv3_bn3 = lambda x: x
    def __call__(self, x: jax.Array):
        out = self.conv1_bn1(x)
        out = nnx.relu(out)
        out = self.conv2_bn2(out)
        out = nnx.relu(out)
        shortcut = self.conv3_bn3(x)
        out += shortcut
        out = nnx.relu(out)
        return out
class ResNet18(nnx.Module):
    def __init__(self, num_classes: int, *, rngs: nnx.Rngs):
        self.num_classes = num_classes
        self.conv1_bn1 = nnx.Sequential(
            nnx.Conv(
                3, 64, kernel_size=(3, 3), strides=(1, 1), padding="SAME",
                use_bias=False, rngs=rngs,
            ),
            nnx.BatchNorm(64, momentum=0.9, epsilon=1e-5, rngs=rngs),
        )
        self.layer1 = nnx.Sequential(
            BasicBlock(64, 64, rngs=rngs), BasicBlock(64, 64, rngs=rngs),
        )
        self.layer2 = nnx.Sequential(
            BasicBlock(64, 128, do_downsample=True, rngs=rngs), BasicBlock(128, 128, rngs=rngs),
        )
        self.layer3 = nnx.Sequential(
            BasicBlock(128, 256, do_downsample=True, rngs=rngs), BasicBlock(256, 256, rngs=rngs),
        )
        self.layer4 = nnx.Sequential(
            BasicBlock(256, 512, do_downsample=True, rngs=rngs), BasicBlock(512, 512, rngs=rngs),
        )
        self.fc = nnx.Linear(512, self.num_classes, rngs=rngs)
    def __call__(self, x: jax.Array):
        x = self.conv1_bn1(x)
        x = nnx.relu(x)
        x = nnx.max_pool(x, (2, 2), strides=(2, 2))
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        x = nnx.avg_pool(x, (x.shape[1], x.shape[2]))
        x = x.reshape((x.shape[0], -1))
        x = self.fc(x)
        return x
model = ResNet18(10, rngs=nnx.Rngs(0))
# Visualize the model architecture
nnx.display(model)
Let us test the model on dummy data:
x = jnp.ones((4, 32, 32, 3))
y_pred = model(x)
y_pred.shape
(4, 10)
Note that the input array is explicitly in the channels-last memory format, with shape (batch, height, width, channels). In PyTorch, the typical input tensor to a neural network uses the channels-first format by default and would have shape (4, 3, 32, 32).
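If your data comes in PyTorch's channels-first layout, it can be converted with a transpose before being fed to the model. A minimal sketch with a dummy NCHW batch; jnp.moveaxis is an equivalent alternative:

```python
import numpy as np
import jax.numpy as jnp

# Dummy batch in PyTorch's default channels-first layout: (N, C, H, W).
x_nchw = np.ones((4, 3, 32, 32), dtype=np.float32)

# Move the channel axis last: (N, H, W, C), the layout expected by the model above.
x_nhwc = jnp.transpose(x_nchw, (0, 2, 3, 1))
x_nhwc_alt = jnp.moveaxis(x_nchw, 1, -1)

print(x_nhwc.shape)
```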
Dataflow using Torchvision and PyTorch data loaders
Let us now set up training and test data using the CIFAR10 dataset from torchvision.
We will create torch dataloaders with collate functions returning NumPy Arrays instead of PyTorch tensors.
Since JAX is a multithreaded framework, using it in multiple processes can cause issues. For this reason, we will avoid creating JAX Arrays in the dataloaders.
As an alternative, one can use grain for data loading and PIX for image data augmentations.
# CIFAR10 training/testing datasets setup
import numpy as np
from torchvision.transforms import v2 as T
from torchvision.datasets import CIFAR10
def to_np_array(pil_image):
    return np.asarray(pil_image)
def normalize(image):
    mean = np.array([0.485, 0.456, 0.406])
    std = np.array([0.229, 0.224, 0.225])
    image = image.astype(np.float32) / 255.0
    return (image - mean) / std
train_transforms = T.Compose([
T.Pad(4),
T.RandomCrop(32, fill=128),
T.RandomHorizontalFlip(),
T.Lambda(to_np_array),
T.Lambda(normalize),
])
test_transforms = T.Compose([
T.Lambda(to_np_array),
T.Lambda(normalize),
])
train_dataset = CIFAR10("./data", train=True, download=True, transform=train_transforms)
test_dataset = CIFAR10("./data", train=False, download=False, transform=test_transforms)
Files already downloaded and verified
# Data loaders setup
from torch.utils.data import DataLoader
batch_size = 512
def np_arrays_collate_fn(list_of_datapoints):
    list_of_images = [dp[0] for dp in list_of_datapoints]
    list_of_targets = [dp[1] for dp in list_of_datapoints]
    return np.stack(list_of_images, axis=0), np.asarray(list_of_targets)
train_loader = DataLoader(
train_dataset, batch_size=batch_size, num_workers=4, shuffle=True, collate_fn=np_arrays_collate_fn,
)
test_loader = DataLoader(
test_dataset, batch_size=batch_size, num_workers=4, shuffle=False, collate_fn=np_arrays_collate_fn,
)
# Let us check training dataloader:
trl_iter = iter(train_loader)
batch = next(trl_iter)
print(batch[0].shape, batch[0].dtype, batch[1].shape, batch[1].dtype)
(512, 32, 32, 3) float64 (512,) int64
/opt/conda/lib/python3.12/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.
self.pid = os.fork()
Note: when executing the code above, you may see this warning: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock. This warning can be ignored, as the dataloader worker processes do not use JAX.
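The batch printed above has dtype float64 even though normalize casts the image to float32: mean and std are created as float64 arrays, and NumPy promotes the result. The sketch below, using a dummy all-zero image in place of a CIFAR10 sample, shows the promotion and one way to keep float32 by creating the constants with dtype=np.float32. With JAX's default settings (64-bit mode disabled), jnp.asarray converts float64 input to float32 anyway, so this mainly halves the size of the batches produced by the dataloader.

```python
import numpy as np

# Dummy stand-in for a decoded 32x32 RGB image.
image = np.zeros((32, 32, 3), dtype=np.uint8)

mean64 = np.array([0.485, 0.456, 0.406])                    # float64 by default
mean32 = np.array([0.485, 0.456, 0.406], dtype=np.float32)  # explicit float32

scaled = image.astype(np.float32) / 255.0  # stays float32

promoted = scaled - mean64  # float32 with float64 -> float64
kept = scaled - mean32      # float32 with float32 -> float32

print(promoted.dtype, kept.dtype)
```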
Model training
Let us now define the optimizer, loss function, and train and test steps using the Flax API. PyTorch users should find code written with the Flax NNX API very similar to PyTorch.
import optax
learning_rate = 0.005
momentum = 0.9
optimizer = nnx.Optimizer(model, optax.adamw(learning_rate, b1=momentum))  # momentum is AdamW's b1
def compute_loss_and_logits(model: nnx.Module, batch):
    logits = model(batch[0])
    loss = optax.softmax_cross_entropy_with_integer_labels(
        logits=logits, labels=batch[1]
    ).mean()
    return loss, logits
@nnx.jit
def train_step(model: nnx.Module, optimizer: nnx.Optimizer, metrics: nnx.MultiMetric, batch):
"""Train for a single step."""
# convert numpy arrays to jnp.array on GPU
x, y_true = jnp.asarray(batch[0]), jnp.asarray(batch[1])
grad_fn = nnx.value_and_grad(compute_loss_and_logits, has_aux=True)
(loss, logits), grads = grad_fn(model, (x, y_true))
metrics.update(loss=loss, logits=logits, labels=y_true) # In-place updates.
optimizer.update(grads) # In-place updates.
return loss
@nnx.jit
def eval_step(model: nnx.Module, metrics: nnx.MultiMetric, batch):
    # Convert NumPy arrays to JAX arrays (on the default device, e.g. the GPU)
    x, y_true = jnp.asarray(batch[0]), jnp.asarray(batch[1])
    loss, logits = compute_loss_and_logits(model, (x, y_true))
    metrics.update(loss=loss, logits=logits, labels=y_true)  # In-place updates.
Readers may note the nnx.jit decorator on the train_step and eval_step functions, which is used to JIT-compile them. JIT compilation in JAX is explored in the last section of this tutorial.
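For comparison, the same pattern can be written with plain JAX transformations instead of the NNX wrappers. The sketch below is hypothetical: a linear classifier on random data, a hand-written cross-entropy in place of optax, and plain SGD in place of AdamW. It shows how has_aux=True makes jax.value_and_grad return the (loss, logits) pair together with the gradients, and how jax.jit compiles the whole step:

```python
import jax
import jax.numpy as jnp

LEARNING_RATE = 0.1  # illustrative value


def loss_and_logits(params, x, labels):
    """Mean softmax cross-entropy with integer labels; logits are the aux output."""
    logits = x @ params["w"] + params["b"]
    log_probs = jax.nn.log_softmax(logits, axis=-1)
    # Pick the log-probability of the correct class for each example.
    loss = -jnp.take_along_axis(log_probs, labels[:, None], axis=-1).mean()
    return loss, logits


@jax.jit
def train_step(params, x, labels):
    # has_aux=True: differentiate the first output only and pass logits through.
    grad_fn = jax.value_and_grad(loss_and_logits, has_aux=True)
    (loss, logits), grads = grad_fn(params, x, labels)
    # Plain SGD update applied to every leaf of the params PyTree.
    new_params = jax.tree_util.tree_map(
        lambda p, g: p - LEARNING_RATE * g, params, grads
    )
    return new_params, loss, logits


# Random toy data: 32 examples, 8 features, 4 classes.
k1, k2, k3 = jax.random.split(jax.random.key(0), 3)
x = jax.random.normal(k1, (32, 8))
labels = jax.random.randint(k2, (32,), 0, 4)
params = {"w": 0.01 * jax.random.normal(k3, (8, 4)), "b": jnp.zeros(4)}

losses = []
for _ in range(50):
    params, loss, logits = train_step(params, x, labels)
    losses.append(float(loss))

print(f"first loss: {losses[0]:.4f}, last loss: {losses[-1]:.4f}")
```

Unlike the NNX version, the parameters are returned explicitly instead of being updated in place.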
# Define helper object to compute train/test metrics
metrics = nnx.MultiMetric(
accuracy=nnx.metrics.Accuracy(),
loss=nnx.metrics.Average('loss'),
)
metrics_history = {
'train_loss': [],
'train_accuracy': [],
'test_loss': [],
'test_accuracy': [],
}
# Start the training
num_epochs = 3
for epoch in range(num_epochs):
    metrics.reset()  # Reset the metrics before training.
    model.train()  # Set model to training mode: e.g. update batch statistics
    for step, batch in enumerate(train_loader):
        loss = train_step(model, optimizer, metrics, batch)
        print(f"\r[train] epoch: {epoch + 1}/{num_epochs}, iteration: {step}, batch loss: {loss.item():.4f}", end="")
    print("\r", end="")
    for metric, value in metrics.compute().items():  # Compute the metrics.
        metrics_history[f'train_{metric}'].append(value)  # Record the metrics.
    metrics.reset()  # Reset the metrics for the test set.
    # Compute the metrics on the test set after each training epoch.
    model.eval()  # Set model to evaluation mode: e.g. use stored batch statistics
    for test_batch in test_loader:
        eval_step(model, metrics, test_batch)
    # Log the test metrics.
    for metric, value in metrics.compute().items():
        metrics_history[f'test_{metric}'].append(value)
    metrics.reset()  # Reset the metrics for the next training epoch.
    print(
        f"[train] epoch: {epoch + 1}/{num_epochs}, "
        f"loss: {metrics_history['train_loss'][-1]:0.4f}, "
        f"accuracy: {metrics_history['train_accuracy'][-1] * 100:0.4f}"
    )
    print(
        f"[test] epoch: {epoch + 1}/{num_epochs}, "
        f"loss: {metrics_history['test_loss'][-1]:0.4f}, "
        f"accuracy: {metrics_history['test_accuracy'][-1] * 100:0.4f}"
        "\n"
    )
[train] epoch: 1/3, iteration: 97, batch loss: 1.4222
[train] epoch: 1/3, loss: 1.7809, accuracy: 38.3060
[test] epoch: 1/3, loss: 1.4749, accuracy: 46.9300
[train] epoch: 2/3, iteration: 96, batch loss: 1.1357
[train] epoch: 2/3, loss: 1.2433, accuracy: 55.02208
[test] epoch: 2/3, loss: 1.2523, accuracy: 57.0800
[train] epoch: 3/3, loss: 0.9977, accuracy: 64.42802
[test] epoch: 3/3, loss: 0.9752, accuracy: 65.6920
Further reading
For more details about the Flax NNX API, how to save and load the model’s state, and the available optimizers, we suggest checking out the links below:
Other AI/ML tutorials to check out:
Just-In-Time (JIT) compilation in JAX
PyTorch users are well acquainted with eager-mode execution in PyTorch, i.e. operations are executed one by one without any high-level optimization across groups of operations. Similarly, almost everywhere in this tutorial we used JAX in eager mode as well.
In PyTorch 1.0, TorchScript was introduced to optimize and serialize PyTorch models by capturing the execution graph into TorchScript programs, which can then be run independently of Python, e.g. from a C++ program.
In JAX, there is a similar transformation: jax.jit. It performs JIT compilation of a Python function for efficient execution with XLA. Behind the scenes, jax.jit wraps the inputs in tracers and traces the function to record all JAX operations. By default, jax.jit compiles the function on the first call and reuses the cached compiled XLA code on subsequent calls with inputs of the same shapes and dtypes.
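The caching behaviour can be observed with a Python side effect, which runs only while the function is being traced. In the sketch below (scaled_sum, plain and trace_count are hypothetical names), a second call with the same shape and dtype reuses the compiled code, while a new shape triggers a new trace; jax.make_jaxpr prints the operations recorded during tracing:

```python
import jax
import jax.numpy as jnp

trace_count = 0


@jax.jit
def scaled_sum(x):
    global trace_count
    trace_count += 1  # Python side effect: runs only during tracing
    return jnp.sum(x * 2.0)


scaled_sum(jnp.ones(3))   # first call: trace and compile
scaled_sum(jnp.zeros(3))  # same shape and dtype: cached executable, no retrace
scaled_sum(jnp.ones(4))   # new shape: traced and compiled again
print(trace_count)


def plain(x):
    return jnp.sum(x * 2.0)


# Show the recorded operations (the jaxpr) for a (3,) float32 input.
print(jax.make_jaxpr(plain)(jnp.ones(3)))
```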
def matmul_relu_add(x, y):
    # Note: despite the function name, `x * y` is an element-wise product, not a matmul.
    z = x * y
    return jax.nn.relu(z) + x
key = jax.random.key(123)
key1, key2 = jax.random.split(key)
x = jax.random.uniform(key1, (2500, 3000))
y = jax.random.uniform(key2, (2500, 3000))
%%timeit
matmul_relu_add(x, y)
320 μs ± 5.71 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
jit_matmul_relu = jax.jit(matmul_relu_add)
# Warm-up: compile the function
_ = jit_matmul_relu(x, y)
%%timeit
jit_matmul_relu(x, y)
93.7 μs ± 903 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
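One caveat when timing JAX code: operations are dispatched asynchronously, so a timing loop that does not wait for results can partly measure dispatch time rather than computation. A common practice is to call block_until_ready() on each output. The sketch below is a plain-Python alternative to %%timeit along those lines; bench and mul_relu_add are hypothetical names, and the measured numbers will depend on your hardware:

```python
import time

import jax
import jax.numpy as jnp


def mul_relu_add(x, y):
    return jax.nn.relu(x * y) + x


key1, key2 = jax.random.split(jax.random.key(123))
x = jax.random.uniform(key1, (2500, 3000))
y = jax.random.uniform(key2, (2500, 3000))


def bench(fn, *args, n=10):
    """Average wall-clock time per call, waiting for each result to be ready."""
    fn(*args).block_until_ready()  # warm-up (includes compilation for jitted fns)
    start = time.perf_counter()
    for _ in range(n):
        fn(*args).block_until_ready()
    return (time.perf_counter() - start) / n


eager_time = bench(mul_relu_add, x, y)
jit_time = bench(jax.jit(mul_relu_add), x, y)
print(f"eager: {eager_time * 1e6:.1f} us, jit: {jit_time * 1e6:.1f} us")
```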