Introduction to Data Loaders on GPU with JAX#
This tutorial explores different data loading strategies for using JAX on a single GPU. While JAX doesn’t include a built-in data loader, it seamlessly integrates with popular data loading libraries, including:
PyTorch DataLoader
TensorFlow Datasets (TFDS)
Grain
Hugging Face Datasets
You’ll see how to use each of these libraries to efficiently load data for a simple image classification task using the MNIST dataset.
Compared to CPU-based loading, working with a GPU introduces specific considerations like transferring data to the GPU using device_put, managing larger batch sizes for faster processing, and efficiently utilizing GPU memory. Unlike multi-device setups, this guide focuses on optimizing data handling for a single GPU.
If you’re looking for CPU-specific data loading advice, see Data Loaders on CPU.
If you’re looking for a multi-device data loading strategy, see Data Loaders on Multi-Device Setups.
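As a quick preview of the single-GPU workflow, moving a host (CPU) batch onto the GPU is one call to device_put. A minimal, standalone sketch (the array shape here is illustrative, not tied to the MNIST data used below):

import numpy as np
import jax

host_batch = np.random.rand(128, 784).astype(np.float32)    # batch prepared on the host (CPU)
device_batch = jax.device_put(host_batch, jax.devices()[0])  # explicit copy onto the single GPU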
Import JAX API#
import jax
import jax.numpy as jnp
from jax import grad, jit, vmap, random, device_put
Checking GPU Availability for JAX#
jax.devices()
[CudaDevice(id=0)]
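If the output lists only CPU devices, JAX will silently fall back to running on the CPU. An optional guard (not part of the original notebook) makes that assumption explicit:

assert jax.default_backend() == 'gpu', 'No GPU backend detected; JAX would run on the CPU'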
Setting Hyperparameters and Initializing Parameters#
You’ll define hyperparameters for your model and data loading, including layer sizes, learning rate, batch size, and the data directory. You’ll also initialize the weights and biases for a fully-connected neural network.
# A helper function to randomly initialize weights and biases
# for a dense neural network layer
def random_layer_params(m, n, key, scale=1e-2):
    w_key, b_key = random.split(key)
    return scale * random.normal(w_key, (n, m)), scale * random.normal(b_key, (n,))

# Function to initialize network parameters for all layers based on defined sizes
def init_network_params(sizes, key):
    keys = random.split(key, len(sizes))
    return [random_layer_params(m, n, k) for m, n, k in zip(sizes[:-1], sizes[1:], keys)]
layer_sizes = [784, 512, 512, 10] # Layers of the network
step_size = 0.01 # Learning rate
num_epochs = 8 # Number of training epochs
batch_size = 128 # Batch size for training
n_targets = 10 # Number of classes (digits 0-9)
num_pixels = 28 * 28 # Each MNIST image is 28x28 pixels
data_dir = '/tmp/mnist_dataset' # Directory for storing the dataset
# Initialize network parameters using the defined layer sizes and a random seed
params = init_network_params(layer_sizes, random.PRNGKey(0))
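To verify the initialization, you can inspect the parameter pytree with tree_map; a quick optional check:

from jax.tree_util import tree_map

# Each layer is a (weights, biases) pair; for layer_sizes [784, 512, 512, 10] expect
# [((512, 784), (512,)), ((512, 512), (512,)), ((10, 512), (10,))]
print(tree_map(lambda p: p.shape, params))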
Model Prediction with Auto-Batching#
In this section, you’ll define the predict function for your neural network. This function computes the output of the network for a single input image.
To efficiently process multiple images simultaneously, you’ll use vmap, which allows you to vectorize the predict function and apply it across a batch of inputs. This technique, called auto-batching, improves computational efficiency by leveraging hardware acceleration.
from jax.scipy.special import logsumexp
def relu(x):
    return jnp.maximum(0, x)

def predict(params, image):
    # per-example predictions
    activations = image
    for w, b in params[:-1]:
        outputs = jnp.dot(w, activations) + b
        activations = relu(outputs)
    final_w, final_b = params[-1]
    logits = jnp.dot(final_w, activations) + final_b
    return logits - logsumexp(logits)
# Make a batched version of the `predict` function
batched_predict = vmap(predict, in_axes=(None, 0))
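To see what vmap buys you, compare shapes on a dummy input; a short illustrative check (the random images are placeholders, not MNIST data):

# A single flattened image -> a single prediction of shape (10,)
random_flattened_image = random.normal(random.PRNGKey(1), (num_pixels,))
print(predict(params, random_flattened_image).shape)

# A batch of 32 flattened images -> predictions of shape (32, 10)
random_flattened_images = random.normal(random.PRNGKey(1), (32, num_pixels))
print(batched_predict(params, random_flattened_images).shape)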
Utility and Loss Functions#
You’ll now define utility functions for:
One-hot encoding: Converts class indices to binary vectors.
Accuracy calculation: Measures the performance of the model on the dataset.
Loss computation: Calculates the difference between predictions and targets.
To optimize performance:
grad is used to compute gradients of the loss function with respect to the network parameters.
jit compiles the update function, enabling faster execution by leveraging JAX’s XLA compilation.
device_put is used to transfer the dataset to the GPU.
import time
def one_hot(x, k, dtype=jnp.float32):
    """Create a one-hot encoding of x of size k."""
    return jnp.array(x[:, None] == jnp.arange(k), dtype)
def accuracy(params, images, targets):
    """Calculate the accuracy of predictions."""
    target_class = jnp.argmax(targets, axis=1)
    predicted_class = jnp.argmax(batched_predict(params, images), axis=1)
    return jnp.mean(predicted_class == target_class)
def loss(params, images, targets):
    """Calculate the loss between predictions and targets."""
    preds = batched_predict(params, images)
    return -jnp.mean(preds * targets)
@jit
def update(params, x, y):
    """Update the network parameters using gradient descent."""
    grads = grad(loss)(params, x, y)
    return [(w - step_size * dw, b - step_size * db)
            for (w, b), (dw, db) in zip(params, grads)]
def reshape_and_one_hot(x, y):
    """Reshape and one-hot encode the inputs."""
    x = jnp.reshape(x, (len(x), num_pixels))
    y = one_hot(y, n_targets)
    return x, y
def train_model(num_epochs, params, training_generator, data_loader_type='streamed'):
    """Train the model for a given number of epochs, using device_put to move each batch to the GPU."""
    for epoch in range(num_epochs):
        start_time = time.time()
        for x, y in training_generator() if data_loader_type == 'streamed' else training_generator:
            x, y = reshape_and_one_hot(x, y)
            x, y = device_put(x), device_put(y)
            params = update(params, x, y)
        print(f"Epoch {epoch + 1} in {time.time() - start_time:.2f} sec: "
              f"Train Accuracy: {accuracy(params, train_images, train_labels):.4f}, "
              f"Test Accuracy: {accuracy(params, test_images, test_labels):.4f}")
Loading Data with PyTorch DataLoader#
This section shows how to load the MNIST dataset using PyTorch’s DataLoader, convert the data to NumPy arrays, and apply transformations to flatten and cast images.
!pip install torch torchvision
Requirement already satisfied: torch in /usr/local/lib/python3.10/dist-packages (2.5.1+cu121)
Requirement already satisfied: torchvision in /usr/local/lib/python3.10/dist-packages (0.20.1+cu121)
Requirement already satisfied: filelock in /usr/local/lib/python3.10/dist-packages (from torch) (3.16.1)
Requirement already satisfied: typing-extensions>=4.8.0 in /usr/local/lib/python3.10/dist-packages (from torch) (4.12.2)
Requirement already satisfied: networkx in /usr/local/lib/python3.10/dist-packages (from torch) (3.4.2)
Requirement already satisfied: jinja2 in /usr/local/lib/python3.10/dist-packages (from torch) (3.1.4)
Requirement already satisfied: fsspec in /usr/local/lib/python3.10/dist-packages (from torch) (2024.9.0)
Requirement already satisfied: sympy==1.13.1 in /usr/local/lib/python3.10/dist-packages (from torch) (1.13.1)
Requirement already satisfied: mpmath<1.4,>=1.1.0 in /usr/local/lib/python3.10/dist-packages (from sympy==1.13.1->torch) (1.3.0)
Requirement already satisfied: numpy in /usr/local/lib/python3.10/dist-packages (from torchvision) (1.26.4)
Requirement already satisfied: pillow!=8.3.*,>=5.3.0 in /usr/local/lib/python3.10/dist-packages (from torchvision) (11.0.0)
Requirement already satisfied: MarkupSafe>=2.0 in /usr/local/lib/python3.10/dist-packages (from jinja2->torch) (3.0.2)
import numpy as np
from jax.tree_util import tree_map
from torch.utils import data
from torchvision.datasets import MNIST
def numpy_collate(batch):
    """Collate function to convert a batch of PyTorch data into NumPy arrays."""
    return tree_map(np.asarray, data.default_collate(batch))

class NumpyLoader(data.DataLoader):
    """Custom DataLoader to return NumPy arrays from a PyTorch Dataset."""
    def __init__(self, dataset, batch_size=1,
                 shuffle=False, sampler=None,
                 batch_sampler=None, num_workers=0,
                 pin_memory=False, drop_last=False,
                 timeout=0, worker_init_fn=None):
        super(self.__class__, self).__init__(dataset,
                                             batch_size=batch_size,
                                             shuffle=shuffle,
                                             sampler=sampler,
                                             batch_sampler=batch_sampler,
                                             num_workers=num_workers,
                                             collate_fn=numpy_collate,
                                             pin_memory=pin_memory,
                                             drop_last=drop_last,
                                             timeout=timeout,
                                             worker_init_fn=worker_init_fn)

class FlattenAndCast(object):
    """Transform class to flatten and cast images to float32."""
    def __call__(self, pic):
        return np.ravel(np.array(pic, dtype=jnp.float32))
Load Dataset with Transformations#
Standardize the data by flattening the images, casting them to float32, and ensuring consistent data types.
mnist_dataset = MNIST(data_dir, download=True, transform=FlattenAndCast())
Full Training Dataset for Accuracy Checks#
Convert the entire training dataset to JAX arrays.
train_images = np.array(mnist_dataset.data).reshape(len(mnist_dataset.data), -1)
train_labels = one_hot(np.array(mnist_dataset.targets), n_targets)
Get Full Test Dataset#
Load and process the full test dataset.
mnist_dataset_test = MNIST(data_dir, download=True, train=False)
test_images = jnp.array(mnist_dataset_test.data.numpy().reshape(len(mnist_dataset_test.data), -1), dtype=jnp.float32)
test_labels = one_hot(np.array(mnist_dataset_test.targets), n_targets)
print('Train:', train_images.shape, train_labels.shape)
print('Test:', test_images.shape, test_labels.shape)
Train: (60000, 784) (60000, 10)
Test: (10000, 784) (10000, 10)
Training Data Generator#
Define a generator function using PyTorch’s DataLoader for batch training.
Setting num_workers > 0 enables multi-process data loading, which can accelerate data loading for larger datasets or intensive preprocessing tasks. Experiment with different values to find the optimal setting for your hardware and workload.
Note: When setting num_workers > 0, you may see the following RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock. This warning can be safely ignored since data loaders do not use JAX within the forked processes.
def pytorch_training_generator(mnist_dataset):
    return NumpyLoader(mnist_dataset, batch_size=batch_size, num_workers=0)
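Before training, it is worth confirming that the loader yields NumPy arrays of the expected shape; a quick optional check:

sample_x, sample_y = next(iter(pytorch_training_generator(mnist_dataset)))
print(type(sample_x), sample_x.shape, sample_x.dtype)  # expect a NumPy array of shape (128, 784), float32
print(type(sample_y), sample_y.shape)                  # expect a NumPy array of shape (128,)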
Training Loop (PyTorch DataLoader)#
The training loop uses the PyTorch DataLoader to iterate through batches and update model parameters.
train_model(num_epochs, params, pytorch_training_generator(mnist_dataset), data_loader_type='iterable')
Epoch 1 in 20.23 sec: Train Accuracy: 0.9158, Test Accuracy: 0.9195
Epoch 2 in 14.64 sec: Train Accuracy: 0.9372, Test Accuracy: 0.9385
Epoch 3 in 3.91 sec: Train Accuracy: 0.9492, Test Accuracy: 0.9467
Epoch 4 in 3.85 sec: Train Accuracy: 0.9569, Test Accuracy: 0.9532
Epoch 5 in 4.48 sec: Train Accuracy: 0.9631, Test Accuracy: 0.9577
Epoch 6 in 4.03 sec: Train Accuracy: 0.9675, Test Accuracy: 0.9617
Epoch 7 in 3.86 sec: Train Accuracy: 0.9708, Test Accuracy: 0.9652
Epoch 8 in 4.57 sec: Train Accuracy: 0.9736, Test Accuracy: 0.9671
Loading Data with TensorFlow Datasets (TFDS)#
This section demonstrates how to load the MNIST dataset using TFDS, fetch the full dataset for evaluation, and define a training generator for batch processing. GPU usage is explicitly disabled for TensorFlow.
import tensorflow_datasets as tfds
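The section intro mentions that GPU usage is disabled for TensorFlow so the input pipeline does not reserve GPU memory needed by JAX. The notebook does not show that cell; one common way to do it looks like this:

import tensorflow as tf

# Hide all GPUs from TensorFlow so the tf.data pipeline stays on the CPU
# and does not compete with JAX for GPU memory.
tf.config.set_visible_devices([], device_type='GPU')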
Fetch Full Dataset for Evaluation#
Load the dataset with tfds.load, convert it to NumPy arrays, and process it for evaluation.
# tfds.load returns tf.Tensors (or tf.data.Datasets if batch_size != -1)
mnist_data, info = tfds.load(name="mnist", batch_size=-1, data_dir=data_dir, with_info=True)
mnist_data = tfds.as_numpy(mnist_data)
train_data, test_data = mnist_data['train'], mnist_data['test']
# Full train set
train_images, train_labels = train_data['image'], train_data['label']
train_images = jnp.reshape(train_images, (len(train_images), num_pixels))
train_labels = one_hot(train_labels, n_targets)
# Full test set
test_images, test_labels = test_data['image'], test_data['label']
test_images = jnp.reshape(test_images, (len(test_images), num_pixels))
test_labels = one_hot(test_labels, n_targets)
print('Train:', train_images.shape, train_labels.shape)
print('Test:', test_images.shape, test_labels.shape)
Train: (60000, 784) (60000, 10)
Test: (10000, 784) (10000, 10)
Define the Training Generator#
Create a generator function to yield batches of data for training.
def training_generator():
    # as_supervised=True gives us the (image, label) as a tuple instead of a dict
    ds = tfds.load(name='mnist', split='train', as_supervised=True, data_dir=data_dir)
    # You can build up an arbitrary tf.data input pipeline
    ds = ds.batch(batch_size).prefetch(1)
    # tfds.as_numpy converts the tf.data.Dataset into an iterable of NumPy arrays
    return tfds.as_numpy(ds)
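The pipeline above iterates over MNIST in a fixed order; if you want batches reshuffled, tf.data can shuffle before batching. A variant sketch (buffer size and seed are illustrative):

def shuffled_training_generator():
    ds = tfds.load(name='mnist', split='train', as_supervised=True, data_dir=data_dir)
    # Shuffle with a finite buffer, then batch and prefetch as before
    ds = ds.shuffle(10_000, seed=0).batch(batch_size).prefetch(1)
    return tfds.as_numpy(ds)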
Training Loop (TFDS)#
Use the training generator in a custom training loop.
train_model(num_epochs, params, training_generator)
Epoch 1 in 20.86 sec: Train Accuracy: 0.9253, Test Accuracy: 0.9268
Epoch 2 in 8.56 sec: Train Accuracy: 0.9428, Test Accuracy: 0.9413
Epoch 3 in 5.40 sec: Train Accuracy: 0.9532, Test Accuracy: 0.9511
Epoch 4 in 3.86 sec: Train Accuracy: 0.9598, Test Accuracy: 0.9555
Epoch 5 in 3.88 sec: Train Accuracy: 0.9652, Test Accuracy: 0.9601
Epoch 6 in 10.35 sec: Train Accuracy: 0.9692, Test Accuracy: 0.9631
Epoch 7 in 4.39 sec: Train Accuracy: 0.9726, Test Accuracy: 0.9650
Epoch 8 in 4.77 sec: Train Accuracy: 0.9753, Test Accuracy: 0.9669
Loading Data with Grain#
This section demonstrates how to load MNIST data using Grain, a data-loading library. You’ll define a custom dataset class for Grain and set up a Grain DataLoader for efficient training.
Install Grain
!pip install grain
Requirement already satisfied: grain in /usr/local/lib/python3.10/dist-packages (0.2.2)
Requirement already satisfied: absl-py in /usr/local/lib/python3.10/dist-packages (from grain) (1.4.0)
Requirement already satisfied: array-record in /usr/local/lib/python3.10/dist-packages (from grain) (0.5.1)
Requirement already satisfied: cloudpickle in /usr/local/lib/python3.10/dist-packages (from grain) (3.1.0)
Requirement already satisfied: dm-tree in /usr/local/lib/python3.10/dist-packages (from grain) (0.1.8)
Requirement already satisfied: etils[epath,epy] in /usr/local/lib/python3.10/dist-packages (from grain) (1.10.0)
Requirement already satisfied: jaxtyping in /usr/local/lib/python3.10/dist-packages (from grain) (0.2.36)
Requirement already satisfied: more-itertools>=9.1.0 in /usr/local/lib/python3.10/dist-packages (from grain) (10.5.0)
Requirement already satisfied: numpy in /usr/local/lib/python3.10/dist-packages (from grain) (1.26.4)
Requirement already satisfied: fsspec in /usr/local/lib/python3.10/dist-packages (from etils[epath,epy]->grain) (2024.9.0)
Requirement already satisfied: importlib_resources in /usr/local/lib/python3.10/dist-packages (from etils[epath,epy]->grain) (6.4.5)
Requirement already satisfied: typing_extensions in /usr/local/lib/python3.10/dist-packages (from etils[epath,epy]->grain) (4.12.2)
Requirement already satisfied: zipp in /usr/local/lib/python3.10/dist-packages (from etils[epath,epy]->grain) (3.21.0)
Import Required Libraries (the MNIST dataset comes from torchvision)
import numpy as np
import grain.python as pygrain
from torchvision.datasets import MNIST
Define Dataset Class#
Create a custom dataset class to load MNIST data for Grain.
class Dataset:
    def __init__(self, data_dir, train=True):
        self.data_dir = data_dir
        self.train = train
        self.load_data()

    def load_data(self):
        self.dataset = MNIST(self.data_dir, download=True, train=self.train)

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, index):
        img, label = self.dataset[index]
        return np.ravel(np.array(img, dtype=np.float32)), label
Initialize the Dataset#
mnist_dataset = Dataset(data_dir)
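Each record is a flattened float32 image plus an integer label; a quick optional sanity check:

sample_img, sample_label = mnist_dataset[0]
print(sample_img.shape, sample_img.dtype, sample_label)  # expect (784,) float32 and an integer class label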
Get the full train and test dataset#
# Convert training data to JAX arrays and encode labels as one-hot vectors
train_images = jnp.array([mnist_dataset[i][0] for i in range(len(mnist_dataset))], dtype=jnp.float32)
train_labels = one_hot(np.array([mnist_dataset[i][1] for i in range(len(mnist_dataset))]), n_targets)
# Load test dataset and process it
mnist_dataset_test = MNIST(data_dir, download=True, train=False)
test_images = jnp.array([np.ravel(np.array(mnist_dataset_test[i][0], dtype=np.float32)) for i in range(len(mnist_dataset_test))], dtype=jnp.float32)
test_labels = one_hot(np.array([mnist_dataset_test[i][1] for i in range(len(mnist_dataset_test))]), n_targets)
print("Train:", train_images.shape, train_labels.shape)
print("Test:", test_images.shape, test_labels.shape)
Train: (60000, 784) (60000, 10)
Test: (10000, 784) (10000, 10)
Initialize PyGrain DataLoader#
sampler = pygrain.SequentialSampler(
    num_records=len(mnist_dataset),
    shard_options=pygrain.NoSharding())  # Single-device, no sharding

def pygrain_training_generator():
    return pygrain.DataLoader(
        data_source=mnist_dataset,
        sampler=sampler,
        operations=[pygrain.Batch(batch_size)],
    )
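The SequentialSampler visits records in a fixed order. For shuffled epochs, Grain also offers an index sampler; a sketch assuming the pygrain.IndexSampler API (shuffle, seed, and num_epochs keyword arguments):

shuffled_sampler = pygrain.IndexSampler(
    num_records=len(mnist_dataset),
    num_epochs=1,                        # one pass per generator call
    shard_options=pygrain.NoSharding(),  # single device, no sharding
    shuffle=True,
    seed=0,
)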
Training Loop (Grain)#
Run the training loop using the Grain DataLoader.
train_model(num_epochs, params, pygrain_training_generator)
Epoch 1 in 15.65 sec: Train Accuracy: 0.9158, Test Accuracy: 0.9195
Epoch 2 in 15.03 sec: Train Accuracy: 0.9372, Test Accuracy: 0.9385
Epoch 3 in 14.93 sec: Train Accuracy: 0.9492, Test Accuracy: 0.9467
Epoch 4 in 11.56 sec: Train Accuracy: 0.9569, Test Accuracy: 0.9532
Epoch 5 in 9.33 sec: Train Accuracy: 0.9631, Test Accuracy: 0.9577
Epoch 6 in 9.31 sec: Train Accuracy: 0.9675, Test Accuracy: 0.9617
Epoch 7 in 9.78 sec: Train Accuracy: 0.9708, Test Accuracy: 0.9652
Epoch 8 in 9.80 sec: Train Accuracy: 0.9736, Test Accuracy: 0.9671
Loading Data with Hugging Face#
This section demonstrates loading MNIST data using the Hugging Face datasets library. You’ll format the dataset for JAX compatibility, prepare flattened images and one-hot-encoded labels, and define a training generator.
Install the Hugging Face datasets library.
!pip install datasets
Requirement already satisfied: datasets in /usr/local/lib/python3.10/dist-packages (3.1.0)
Requirement already satisfied: filelock in /usr/local/lib/python3.10/dist-packages (from datasets) (3.16.1)
Requirement already satisfied: numpy>=1.17 in /usr/local/lib/python3.10/dist-packages (from datasets) (1.26.4)
Requirement already satisfied: pyarrow>=15.0.0 in /usr/local/lib/python3.10/dist-packages (from datasets) (17.0.0)
Requirement already satisfied: dill<0.3.9,>=0.3.0 in /usr/local/lib/python3.10/dist-packages (from datasets) (0.3.8)
Requirement already satisfied: pandas in /usr/local/lib/python3.10/dist-packages (from datasets) (2.2.2)
Requirement already satisfied: requests>=2.32.2 in /usr/local/lib/python3.10/dist-packages (from datasets) (2.32.3)
Requirement already satisfied: tqdm>=4.66.3 in /usr/local/lib/python3.10/dist-packages (from datasets) (4.66.6)
Requirement already satisfied: xxhash in /usr/local/lib/python3.10/dist-packages (from datasets) (3.5.0)
Requirement already satisfied: multiprocess<0.70.17 in /usr/local/lib/python3.10/dist-packages (from datasets) (0.70.16)
Requirement already satisfied: fsspec<=2024.9.0,>=2023.1.0 in /usr/local/lib/python3.10/dist-packages (from fsspec[http]<=2024.9.0,>=2023.1.0->datasets) (2024.9.0)
Requirement already satisfied: aiohttp in /usr/local/lib/python3.10/dist-packages (from datasets) (3.11.2)
Requirement already satisfied: huggingface-hub>=0.23.0 in /usr/local/lib/python3.10/dist-packages (from datasets) (0.26.2)
Requirement already satisfied: packaging in /usr/local/lib/python3.10/dist-packages (from datasets) (24.2)
Requirement already satisfied: pyyaml>=5.1 in /usr/local/lib/python3.10/dist-packages (from datasets) (6.0.2)
Requirement already satisfied: aiohappyeyeballs>=2.3.0 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets) (2.4.3)
Requirement already satisfied: aiosignal>=1.1.2 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets) (1.3.1)
Requirement already satisfied: attrs>=17.3.0 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets) (24.2.0)
Requirement already satisfied: frozenlist>=1.1.1 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets) (1.5.0)
Requirement already satisfied: multidict<7.0,>=4.5 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets) (6.1.0)
Requirement already satisfied: propcache>=0.2.0 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets) (0.2.0)
Requirement already satisfied: yarl<2.0,>=1.17.0 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets) (1.17.2)
Requirement already satisfied: async-timeout<6.0,>=4.0 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets) (4.0.3)
Requirement already satisfied: typing-extensions>=3.7.4.3 in /usr/local/lib/python3.10/dist-packages (from huggingface-hub>=0.23.0->datasets) (4.12.2)
Requirement already satisfied: charset-normalizer<4,>=2 in /usr/local/lib/python3.10/dist-packages (from requests>=2.32.2->datasets) (3.4.0)
Requirement already satisfied: idna<4,>=2.5 in /usr/local/lib/python3.10/dist-packages (from requests>=2.32.2->datasets) (3.10)
Requirement already satisfied: urllib3<3,>=1.21.1 in /usr/local/lib/python3.10/dist-packages (from requests>=2.32.2->datasets) (2.2.3)
Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.10/dist-packages (from requests>=2.32.2->datasets) (2024.8.30)
Requirement already satisfied: python-dateutil>=2.8.2 in /usr/local/lib/python3.10/dist-packages (from pandas->datasets) (2.8.2)
Requirement already satisfied: pytz>=2020.1 in /usr/local/lib/python3.10/dist-packages (from pandas->datasets) (2024.2)
Requirement already satisfied: tzdata>=2022.7 in /usr/local/lib/python3.10/dist-packages (from pandas->datasets) (2024.2)
Requirement already satisfied: six>=1.5 in /usr/local/lib/python3.10/dist-packages (from python-dateutil>=2.8.2->pandas->datasets) (1.16.0)
from datasets import load_dataset
Load the MNIST dataset from Hugging Face and format it as numpy arrays for quick access, or as jax to get JAX arrays.
mnist_dataset = load_dataset("mnist", cache_dir=data_dir).with_format("numpy")
Extract images and labels#
Get image shape and flatten for model input.
train_images = mnist_dataset["train"]["image"]
train_labels = mnist_dataset["train"]["label"]
test_images = mnist_dataset["test"]["image"]
test_labels = mnist_dataset["test"]["label"]
# Extract image shape
image_shape = train_images.shape[1:]
num_features = image_shape[0] * image_shape[1]
# Flatten the images
train_images = train_images.reshape(-1, num_features)
test_images = test_images.reshape(-1, num_features)
# One-hot encode the labels
train_labels = one_hot(train_labels, n_targets)
test_labels = one_hot(test_labels, n_targets)
print('Train:', train_images.shape, train_labels.shape)
print('Test:', test_images.shape, test_labels.shape)
Train: (60000, 784) (60000, 10)
Test: (10000, 784) (10000, 10)
Define Training Generator#
Set up a generator to yield batches of images and labels for training.
def hf_training_generator():
    """Yield batches for training."""
    for batch in mnist_dataset["train"].iter(batch_size):
        x, y = batch["image"], batch["label"]
        yield x, y
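Batches again arrive in a fixed order; Hugging Face datasets can reshuffle between epochs with Dataset.shuffle. A sketch (the seed is illustrative):

def hf_shuffled_training_generator(seed=0):
    """Yield shuffled batches for training."""
    shuffled = mnist_dataset["train"].shuffle(seed=seed)
    for batch in shuffled.iter(batch_size):
        yield batch["image"], batch["label"]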
Training Loop (Hugging Face Datasets)#
Run the training loop using the Hugging Face training generator.
train_model(num_epochs, params, hf_training_generator)
Epoch 1 in 19.06 sec: Train Accuracy: 0.9158, Test Accuracy: 0.9195
Epoch 2 in 8.94 sec: Train Accuracy: 0.9372, Test Accuracy: 0.9385
Epoch 3 in 5.43 sec: Train Accuracy: 0.9492, Test Accuracy: 0.9467
Epoch 4 in 6.41 sec: Train Accuracy: 0.9569, Test Accuracy: 0.9532
Epoch 5 in 5.80 sec: Train Accuracy: 0.9631, Test Accuracy: 0.9577
Epoch 6 in 6.61 sec: Train Accuracy: 0.9675, Test Accuracy: 0.9617
Epoch 7 in 5.49 sec: Train Accuracy: 0.9708, Test Accuracy: 0.9652
Epoch 8 in 6.64 sec: Train Accuracy: 0.9736, Test Accuracy: 0.9671
Summary#
This notebook explored efficient methods for loading data on a GPU with JAX, using libraries such as PyTorch DataLoader, TensorFlow Datasets, Grain, and Hugging Face Datasets. You also learned GPU-specific optimizations, including using device_put for data transfer and managing GPU memory, to enhance training efficiency. Each method offers unique benefits, allowing you to choose the best approach based on your project requirements.