Train a miniGPT language model with JAX#


This tutorial demonstrates how to use JAX, Flax NNX and Optax for language model (pre)training using SPMD (Single-Program Multi-Data) data and tensor parallelism. It was originally inspired by the Keras miniGPT tutorial.

Here, you will learn how to:

  • Define the miniGPT model with Flax and JAX automatic parallelism

  • Load and preprocess the dataset

  • Create the loss and training step functions

  • Train the model on Google Colab’s Cloud TPU v2

  • Profile for hyperparameter tuning

If you are new to JAX for AI, check out the introductory tutorial, which covers neural network building with Flax NNX.

Setup#

JAX installation is covered in this guide on the JAX documentation site. We will use Tiktoken for tokenization and Grain for data loading.

!pip install -Uq tiktoken grain matplotlib
ERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
tensorflow 2.15.0 requires ml-dtypes~=0.2.0, but you have ml-dtypes 0.4.0 which is incompatible.

Note: If you are using Google Colab, select the free Google Cloud TPU v2 as the hardware accelerator.

Check the available JAX devices (jax.Device objects) with jax.devices(). On a Colab TPU v2, the cell below lists 8 devices.

import jax
jax.devices()
[TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0),
 TpuDevice(id=1, process_index=0, coords=(0,0,0), core_on_chip=1),
 TpuDevice(id=2, process_index=0, coords=(1,0,0), core_on_chip=0),
 TpuDevice(id=3, process_index=0, coords=(1,0,0), core_on_chip=1),
 TpuDevice(id=4, process_index=0, coords=(0,1,0), core_on_chip=0),
 TpuDevice(id=5, process_index=0, coords=(0,1,0), core_on_chip=1),
 TpuDevice(id=6, process_index=0, coords=(1,1,0), core_on_chip=0),
 TpuDevice(id=7, process_index=0, coords=(1,1,0), core_on_chip=1)]

Get the TinyStories dataset from Hugging Face. We only use the training split.

!wget https://huggingface.co/datasets/roneneldan/TinyStories/resolve/main/TinyStories-train.txt?download=true -O TinyStories-train.txt
--2024-11-01 02:50:38--  https://huggingface.co/datasets/roneneldan/TinyStories/resolve/main/TinyStories-train.txt?download=true
Resolving huggingface.co (huggingface.co)... 65.8.243.46, 65.8.243.92, 65.8.243.90, ...
Connecting to huggingface.co (huggingface.co)|65.8.243.46|:443... connected.
HTTP request sent, awaiting response... 302 Found
Location: https://cdn-lfs.hf.co/repos/42/7f/427f7497b6c6596c18b46d5a72e61364fcad12aa433c60a0dbd4d344477b9d81/c5cf5e22ff13614e830afbe61a99fbcbe8bcb7dd72252b989fa1117a368d401f?response-content-disposition=attachment%3B+filename*%3DUTF-8%27%27TinyStories-train.txt%3B+filename%3D%22TinyStories-train.txt%22%3B&response-content-type=text%2Fplain&Expires=1730688639&Policy=eyJTdGF0ZW1lbnQiOlt7IkNvbmRpdGlvbiI6eyJEYXRlTGVzc1RoYW4iOnsiQVdTOkVwb2NoVGltZSI6MTczMDY4ODYzOX19LCJSZXNvdXJjZSI6Imh0dHBzOi8vY2RuLWxmcy5oZi5jby9yZXBvcy80Mi83Zi80MjdmNzQ5N2I2YzY1OTZjMThiNDZkNWE3MmU2MTM2NGZjYWQxMmFhNDMzYzYwYTBkYmQ0ZDM0NDQ3N2I5ZDgxL2M1Y2Y1ZTIyZmYxMzYxNGU4MzBhZmJlNjFhOTlmYmNiZThiY2I3ZGQ3MjI1MmI5ODlmYTExMTdhMzY4ZDQwMWY%7EcmVzcG9uc2UtY29udGVudC1kaXNwb3NpdGlvbj0qJnJlc3BvbnNlLWNvbnRlbnQtdHlwZT0qIn1dfQ__&Signature=oQHJBcHVix9N1HnNsJSj7KK-BoqdXdl6NRh%7E1ilGx-ROnLrZxKINfonOtva5e5Xf9KQVNl6QQkx5gNw4iMTmS6JRFB%7EcXdTcFjrHSnBxwLRZkMCBKAv3oHhRnJ6I2rV8iBAZTq%7E-caDCLFvBrgT9pcEFakh3-5mSp%7ER7hnNqE5lcE5n7tzXS0l-8tOShDmR5aUCFPStZHfPbyS3MwCAdc2KoqXdqzRf9M4WvXWB78El7WGxse0DrTQFbGGW1kjpvBOqzljH0Qn6WqsiBockhHDbwE1nQmGfxKrbreXenAKdOsUTN9fuRKl-6srhI2xGKFpfu3IGDEN%7Ebmwg8CnwAfQ__&Key-Pair-Id=K3RPWS32NSSJCE [following]
--2024-11-01 02:50:39--  https://cdn-lfs.hf.co/repos/42/7f/427f7497b6c6596c18b46d5a72e61364fcad12aa433c60a0dbd4d344477b9d81/c5cf5e22ff13614e830afbe61a99fbcbe8bcb7dd72252b989fa1117a368d401f?response-content-disposition=attachment%3B+filename*%3DUTF-8%27%27TinyStories-train.txt%3B+filename%3D%22TinyStories-train.txt%22%3B&response-content-type=text%2Fplain&Expires=1730688639&Policy=eyJTdGF0ZW1lbnQiOlt7IkNvbmRpdGlvbiI6eyJEYXRlTGVzc1RoYW4iOnsiQVdTOkVwb2NoVGltZSI6MTczMDY4ODYzOX19LCJSZXNvdXJjZSI6Imh0dHBzOi8vY2RuLWxmcy5oZi5jby9yZXBvcy80Mi83Zi80MjdmNzQ5N2I2YzY1OTZjMThiNDZkNWE3MmU2MTM2NGZjYWQxMmFhNDMzYzYwYTBkYmQ0ZDM0NDQ3N2I5ZDgxL2M1Y2Y1ZTIyZmYxMzYxNGU4MzBhZmJlNjFhOTlmYmNiZThiY2I3ZGQ3MjI1MmI5ODlmYTExMTdhMzY4ZDQwMWY%7EcmVzcG9uc2UtY29udGVudC1kaXNwb3NpdGlvbj0qJnJlc3BvbnNlLWNvbnRlbnQtdHlwZT0qIn1dfQ__&Signature=oQHJBcHVix9N1HnNsJSj7KK-BoqdXdl6NRh%7E1ilGx-ROnLrZxKINfonOtva5e5Xf9KQVNl6QQkx5gNw4iMTmS6JRFB%7EcXdTcFjrHSnBxwLRZkMCBKAv3oHhRnJ6I2rV8iBAZTq%7E-caDCLFvBrgT9pcEFakh3-5mSp%7ER7hnNqE5lcE5n7tzXS0l-8tOShDmR5aUCFPStZHfPbyS3MwCAdc2KoqXdqzRf9M4WvXWB78El7WGxse0DrTQFbGGW1kjpvBOqzljH0Qn6WqsiBockhHDbwE1nQmGfxKrbreXenAKdOsUTN9fuRKl-6srhI2xGKFpfu3IGDEN%7Ebmwg8CnwAfQ__&Key-Pair-Id=K3RPWS32NSSJCE
Resolving cdn-lfs.hf.co (cdn-lfs.hf.co)... 3.167.152.12, 3.167.152.119, 3.167.152.37, ...
Connecting to cdn-lfs.hf.co (cdn-lfs.hf.co)|3.167.152.12|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 1924281556 (1.8G) [text/plain]
Saving to: ‘TinyStories-train.txt’

TinyStories-train.t 100%[===================>]   1.79G  38.1MB/s    in 45s     

2024-11-01 02:51:24 (40.7 MB/s) - ‘TinyStories-train.txt’ saved [1924281556/1924281556]

Import the necessary modules, including JAX NumPy, Flax NNX, Optax, Grain, pandas, and Tiktoken:

import jax
import jax.numpy as jnp

from jax.sharding import Mesh, PartitionSpec as P, NamedSharding # For data and model parallelism (explained in more detail later)
from jax.experimental import mesh_utils

import flax.nnx as nnx
import optax

from dataclasses import dataclass
import grain.python as pygrain
import pandas as pd
import tiktoken
import time

Define the miniGPT model with Flax and JAX automatic parallelism#

Leveraging JAX’s data and tensor parallelism#

One of the most powerful features of JAX is device parallelism for SPMD.

  • Data parallelism splits the training data into parts (shards of each batch) that are processed simultaneously on different devices, such as GPUs and Google TPUs. This allows you to use larger batch sizes and speed up training.

  • Tensor parallelism allows us to split the model parameter tensors across several devices (sharding model tensors).

  • You can learn more about the basics of JAX parallelism in the Introduction to parallel programming on the JAX documentation site.

In this example, we’ll use a 4-way data-parallel and 2-way tensor-parallel setup. The free Google Cloud TPU v2 on Google Colab offers 4 chips, each with 2 TPU cores, so the TPU v2 architecture aligns with this setup.

jax.sharding.Mesh#

Earlier, we imported jax.sharding.Mesh, which is a multidimensional NumPy array of JAX devices, where each axis of the mesh has a name, such as 'x' or 'y'. It encapsulates how the TPU resources are organized so that computations can be distributed across the devices.

Our Mesh will have two arguments:

  • devices: This takes the value of jax.experimental.mesh_utils.create_device_mesh((4, 2)), which builds a device mesh. It is a NumPy ndarray of JAX devices (taken from the list of devices of the JAX backend, as returned by jax.devices()).

  • axis_names, where:

    • batch: 4 devices along the first axis - i.e. sharded into 4 - for data parallelism; and

    • model: 2 devices along the second axis - i.e. sharded into 2 - for tensor parallelism, mapping to the TPU v2 cores.

This matches the (4, 2) structure of Colab’s TPU v2 setup.

Let’s instantiate Mesh as mesh and declare the TPU configuration to define how data and model parameters are distributed across the devices:

# Create a `Mesh` object representing TPU device arrangement.
mesh = Mesh(mesh_utils.create_device_mesh((4, 2)), ('batch', 'model'))

### Alternatively, we could use the 8-way data parallelism with only one line of code change.
### JAX enables quick experimentation with different partitioning strategies
### like this. We will come back to this point at the end of this tutorial.
# mesh = Mesh(mesh_utils.create_device_mesh((8, 1)), ('batch', 'model'))
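
To make the two layouts concrete, here is a minimal plain-Python sketch (no JAX or TPU required) that arranges eight device IDs into a grid and computes how many examples of a global batch land on each slice of the batch axis. The helpers `device_grid` and `per_slice_batch` are hypothetical names used only for illustration, and the row-major ordering is an assumption: mesh_utils.create_device_mesh may order devices differently to match the physical topology. The global batch of 256 matches the batch_size set later in this tutorial.

```python
def device_grid(num_devices, mesh_shape):
    """Arrange device IDs 0..num_devices-1 into a row-major grid (illustrative only)."""
    rows, cols = mesh_shape
    if rows * cols != num_devices:
        raise ValueError(f"mesh shape {mesh_shape} does not match {num_devices} devices")
    return [[r * cols + c for c in range(cols)] for r in range(rows)]


def per_slice_batch(global_batch, mesh_shape):
    """Number of examples handled by each slice along the first ('batch') mesh axis."""
    batch_axis_size = mesh_shape[0]
    if global_batch % batch_axis_size != 0:
        raise ValueError("global batch must be divisible by the 'batch' axis size")
    return global_batch // batch_axis_size


# 4-way data parallelism x 2-way tensor parallelism.
grid_4x2 = device_grid(8, (4, 2))
# 8-way data parallelism, no tensor parallelism.
grid_8x1 = device_grid(8, (8, 1))

for row in grid_4x2:
    print(row)  # each row shares one batch slice; each column shares one model shard

print(per_slice_batch(256, (4, 2)), "examples per batch slice with a (4, 2) mesh")
print(per_slice_batch(256, (8, 1)), "examples per batch slice with an (8, 1) mesh")
```

With the (4, 2) layout, the two devices in a row see the same examples but hold different halves of each sharded weight; with (8, 1), every device holds the full weights and sees a smaller slice of the batch.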

We will use the GPT-2 tokenizer from the Tiktoken library:

tokenizer = tiktoken.get_encoding("gpt2")

To leverage model parallelism, we need to instruct the JAX compiler how to shard the model tensors across the TPU devices. Earlier, we also imported jax.sharding.PartitionSpec and jax.sharding.NamedSharding:

  • PartitionSpec (using alias P) defines how tensors are sharded across the devices in our Mesh. Each element describes how the corresponding array dimension is partitioned across mesh axes. For example, in PartitionSpec('x', 'y') the first dimension of the data is sharded across the x axis of the mesh, and the second across the y axis.

    • We’ll use PartitionSpec to describe how a tensor is sharded across, for example, the model axis, or replicated along other dimensions (denoted by None).

  • NamedSharding is a (Mesh, PartitionSpec) pair that describes how to shard a model tensor across our mesh.

  • We combine Mesh (the TPU resources) with PartitionSpec and create a NamedSharding, which instructs how to shard each model tensor across the TPU devices.
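
As a rough illustration of what a PartitionSpec implies, the plain-Python sketch below computes the per-device shard shape of an array for a given spec and mesh. It does not call JAX; `shard_shape` is a hypothetical helper, and the rules it encodes (even division, trailing dimensions replicated, no more spec entries than array dimensions) are simplifying assumptions for this sketch.

```python
def shard_shape(array_shape, spec, mesh_axes):
    """Per-device shard shape for an array partitioned according to `spec`.

    `spec` holds one entry per leading array dimension: a mesh axis name,
    or None for "replicated along this dimension". Dimensions not covered
    by `spec` are treated as replicated.
    """
    if len(spec) > len(array_shape):
        raise ValueError("spec has more entries than the array has dimensions")
    padded = tuple(spec) + (None,) * (len(array_shape) - len(spec))
    shape = []
    for dim, axis in zip(array_shape, padded):
        size = 1 if axis is None else mesh_axes[axis]
        if dim % size != 0:
            raise ValueError(f"dimension {dim} is not divisible by mesh axis size {size}")
        shape.append(dim // size)
    return tuple(shape)


# Axis sizes of the (4, 2) mesh used in this tutorial.
mesh_axes = {"batch": 4, "model": 2}

# A 256 x 256 kernel with P(None, 'model'): rows replicated, columns split in two.
kernel_shard = shard_shape((256, 256), (None, "model"), mesh_axes)
# A length-256 bias with P('model'): split in two.
bias_shard = shard_shape((256,), ("model",), mesh_axes)
# A batch of 256 sequences of length 256 with P('batch', None): split four ways.
batch_shard = shard_shape((256, 256), ("batch", None), mesh_axes)

print(kernel_shard, bias_shard, batch_shard)
```

The same reasoning shows why a one-dimensional tensor such as a bias takes a one-entry spec like P('model'), while a two-dimensional kernel takes a two-entry spec like P(None, 'model').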

Additionally, we’ll use Flax NNX’s flax.nnx.with_partitioning to let each model layer know that the model weights or tensors need to be sharded according to our specification. We need to do this for every tensor/layer in the model.

# Define a triangular mask for causal attention with `jax.numpy.tril` and `jax.numpy.ones`.
def causal_attention_mask(seq_len):
    return jnp.tril(jnp.ones((seq_len, seq_len)))

class TransformerBlock(nnx.Module):
    """ A single Transformer block.

    Each Transformer block processes input sequences via self-attention and feed-forward networks.

    Args:
        embed_dim (int): Embedding dimensionality.
        num_heads (int): Number of attention heads.
        ff_dim (int): Dimensionality of the feed-forward network.
        rngs (flax.nnx.Rngs): A Flax NNX stream of JAX PRNG keys.
        rate (float): Dropout rate. Defaults to 0.1.
    """
    def __init__(self, embed_dim: int, num_heads: int, ff_dim: int, *, rngs: nnx.Rngs, rate: float = 0.1):
        # Multi-Head Attention (MHA) with `flax.nnx.MultiHeadAttention`.
        # Specifies tensor sharding (depending on the mesh configuration)
        # where we shard the weights across devices for parallel computation.
        self.mha = nnx.MultiHeadAttention(num_heads=num_heads,
                                          in_features=embed_dim,
                                          kernel_init=nnx.with_partitioning(nnx.initializers.xavier_uniform(), NamedSharding(mesh, P(None, 'model'))),
                                          bias_init=nnx.with_partitioning(nnx.initializers.zeros_init(), NamedSharding(mesh, P('model'))),
                                          rngs=rngs)
        # The first dropout with `flax.nnx.Dropout`.
        self.dropout1 = nnx.Dropout(rate=rate)
        # First layer normalization with `flax.nnx.LayerNorm`.
        self.layer_norm1 = nnx.LayerNorm(epsilon=1e-6,
                                         num_features=embed_dim,
                                         scale_init=nnx.with_partitioning(nnx.initializers.ones_init(), NamedSharding(mesh, P('model'))),
                                         bias_init=nnx.with_partitioning(nnx.initializers.zeros_init(), NamedSharding(mesh, P('model'))),
                                         rngs=rngs)
        # The first linear transformation for the feed-forward network with `flax.nnx.Linear`.
        self.linear1 = nnx.Linear(in_features=embed_dim,
                                  out_features=ff_dim,
                                  kernel_init=nnx.with_partitioning(nnx.initializers.xavier_uniform(), NamedSharding(mesh, P(None, 'model'))),
                                  bias_init=nnx.with_partitioning(nnx.initializers.zeros_init(), NamedSharding(mesh, P('model'))),
                                  rngs=rngs)
        # The second linear transformation for the feed-forward network with `flax.nnx.Linear`.
        self.linear2 = nnx.Linear(in_features=ff_dim,
                                  out_features=embed_dim,
                                  kernel_init=nnx.with_partitioning(nnx.initializers.xavier_uniform(), NamedSharding(mesh, P(None, 'model'))),
                                  bias_init=nnx.with_partitioning(nnx.initializers.zeros_init(), NamedSharding(mesh, P('model'))),
                                  rngs=rngs)
        # The second dropout with `flax.nnx.Dropout`.
        self.dropout2 = nnx.Dropout(rate=rate)
        # Second layer normalization with `flax.nnx.LayerNorm`.
        self.layer_norm2 = nnx.LayerNorm(epsilon=1e-6,
                                         num_features=embed_dim,
                                         scale_init=nnx.with_partitioning(nnx.initializers.ones_init(), NamedSharding(mesh, P('model'))),
                                         bias_init=nnx.with_partitioning(nnx.initializers.zeros_init(), NamedSharding(mesh, P('model'))),
                                         rngs=rngs)


    # Apply the Transformer block to the input sequence.
    def __call__(self, inputs, training: bool = False):
        input_shape = inputs.shape
        _, seq_len, _ = input_shape

        # Instantiate the causal attention mask.
        mask = causal_attention_mask(seq_len)

        # Apply Multi-Head Attention with the causal attention mask.
        attention_output = self.mha(
            inputs_q=inputs,
            mask=mask,
            decode=False
        )
        # Apply the first dropout.
        attention_output = self.dropout1(attention_output, deterministic=not training)
        # Apply the first layer normalization.
        out1 = self.layer_norm1(inputs + attention_output)

        # The feed-forward network.
        # Apply the first linear transformation.
        ffn_output = self.linear1(out1)
        # Apply the ReLU activation with `flax.nnx.relu`.
        ffn_output = nnx.relu(ffn_output)
        # Apply the second linear transformation.
        ffn_output = self.linear2(ffn_output)
        # Apply the second dropout.
        ffn_output = self.dropout2(ffn_output, deterministic=not training)
        # Apply the second layer normalization and return the output of the Transformer block.
        return self.layer_norm2(out1 + ffn_output)

class TokenAndPositionEmbedding(nnx.Module):
    """ Combines token embeddings (words in an input sentence) with
    positional embeddings (the position of each word in a sentence).

    Args:
        maxlen (int): Maximum sequence length.
        vocab_size (int): Vocabulary size.
        embed_dim (int): Embedding dimensionality.
        rngs (flax.nnx.Rngs): A Flax NNX stream of JAX PRNG keys.
    """
    def __init__(self, maxlen: int, vocab_size: int, embed_dim: int, *, rngs: nnx.Rngs):
        # Initialize token embeddings (using `flax.nnx.Embed`).
        # Each unique word has an embedding vector.
        self.token_emb = nnx.Embed(num_embeddings=vocab_size, features=embed_dim, rngs=rngs)
        # Initialize positional embeddings (using `flax.nnx.Embed`).
        self.pos_emb = nnx.Embed(num_embeddings=maxlen, features=embed_dim, rngs=rngs)

    # Takes a token sequence (integers) and returns the combined token and positional embeddings.
    def __call__(self, x):
        # Generate a sequence of positions for the input tokens.
        positions = jnp.arange(0, x.shape[1])[None, :]
        # Look up the positional embeddings for each position in the input sequence.
        position_embedding = self.pos_emb(positions)
        # Look up the token embeddings for each token in the input sequence.
        token_embedding = self.token_emb(x)
        # Combine token and positional embeddings.
        return token_embedding + position_embedding

class MiniGPT(nnx.Module):
    """ A miniGPT transformer model, inherits from `flax.nnx.Module`.

    Args:
        maxlen (int): Maximum sequence length.
        vocab_size (int): Vocabulary size.
        embed_dim (int): Embedding dimensionality.
        num_heads (int): Number of attention heads.
        feed_forward_dim (int): Dimensionality of the feed-forward network.
        num_transformer_blocks (int): Number of transformer blocks. Each block contains attention and feed-forward networks.
        rngs (nnx.Rngs): A Flax NNX stream of JAX PRNG keys.
    """
    # Initialize miniGPT model components.
    def __init__(self, maxlen: int, vocab_size: int, embed_dim: int, num_heads: int, feed_forward_dim: int, num_transformer_blocks: int, rngs: nnx.Rngs):
        # Initialize the `TokenAndPositionEmbedding` that combines token and positional embeddings.
        self.embedding_layer = TokenAndPositionEmbedding(
                    maxlen, vocab_size, embed_dim, rngs=rngs
                )
        # Create a list of `TransformerBlock` instances.
        # Each block processes input sequences using attention and feed-forward networks.
        self.transformer_blocks = [TransformerBlock(
            embed_dim, num_heads, feed_forward_dim, rngs=rngs
        ) for _ in range(num_transformer_blocks)]
        # Initialize the output `flax.nnx.Linear` layer producing logits over the vocabulary for next-token prediction.
        self.output_layer = nnx.Linear(in_features=embed_dim,
                                       out_features=vocab_size,
                                       kernel_init=nnx.with_partitioning(nnx.initializers.xavier_uniform(), NamedSharding(mesh, P(None, 'model'))),
                                       bias_init=nnx.with_partitioning(nnx.initializers.zeros_init(), NamedSharding(mesh, P('model'))),
                                       rngs=rngs)

    def __call__(self, inputs, training: bool = False):
        # Pass the input tokens through the `embedding_layer` to get token embeddings.
        # Apply each transformer block sequentially to the embedded input, use the `training` flag for the behavior of `flax.nnx.Dropout`.
        x = self.embedding_layer(inputs)
        for transformer_block in self.transformer_blocks:
            x = transformer_block(x, training=training)
        # Pass the output of the transformer blocks through the output layer,
        # and obtain logits for each token in the vocabulary (for next token prediction).
        outputs = self.output_layer(x)
        return outputs

    # Text generation.
    def generate_text(self, max_tokens: int, start_tokens: list[int], top_k=10):
        # Sample the next token from a probability distribution based on
        # `logits` and the `top_k` (top-k) sampling strategy.
        def sample_from(logits):
            logits, indices = jax.lax.top_k(logits, k=top_k)
            # Convert logits to probabilities (using `flax.nnx.softmax`).
            logits = nnx.softmax(logits)
            return jax.random.choice(jax.random.PRNGKey(0), indices, p=logits)

        # Generate text one token at a time until the maximum token limit is reached (`maxlen`).
        def generate_step(start_tokens):
            pad_len = maxlen - len(start_tokens)
            # Index of the last token in the current sequence.
            sample_index = len(start_tokens) - 1
            # If the input is longer than `maxlen`, then truncate it.
            if pad_len < 0:
                x = jnp.array(start_tokens[:maxlen])
                sample_index = maxlen - 1
            # If the input is shorter than `maxlen`, then pad it (`pad_len`).
            elif pad_len > 0:
                x = jnp.array(start_tokens + [0] * pad_len)
            else:
                x = jnp.array(start_tokens)

            # Add a batch dimension.
            x = x[None, :]
            logits = self(x)
            next_token = sample_from(logits[0][sample_index])
            return next_token

        # Store generated tokens.
        generated = []
        # Generate tokens until the end-of-text token is encountered or the maximum token limit is reached.
        for _ in range(max_tokens):
            next_token = generate_step(start_tokens + generated)
            # Truncate whatever is after '<|endoftext|>' (stop word)
            if next_token == tokenizer.encode('<|endoftext|>', allowed_special={'<|endoftext|>'})[0]:
              # Stop text generation if the end-of-text token is encountered.
              break
            generated.append(int(next_token))
        # Decode the generated token IDs into text.
        return tokenizer.decode(start_tokens + generated)

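# Create the miniGPT model. Note: the number of transformer blocks is hardcoded to 4 here,
# so the `num_transformer_blocks` hyperparameter defined below is not used.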
def create_model(rngs):
    return MiniGPT(maxlen, vocab_size, embed_dim, num_heads, feed_forward_dim, num_transformer_blocks=4, rngs=rngs)
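
The top-k sampling step inside generate_text can be sketched in plain Python as follows. This is an illustrative stand-in, not the tutorial's implementation: `top_k_sample` is a hypothetical helper, and it uses Python's `random` module instead of jax.random. Note that sample_from above passes the same jax.random.PRNGKey(0) on every call, so its draw is deterministic for a given set of logits, whereas this sketch advances a seeded generator between draws.

```python
import math
import random


def top_k_sample(logits, k, rng):
    """Sample one index from the k largest logits (illustrative sketch)."""
    # Keep the indices of the k largest logits.
    top = sorted(range(len(logits)), key=lambda i: logits[i], reverse=True)[:k]
    # Softmax over the kept logits only (subtract the max for numerical stability).
    max_logit = max(logits[i] for i in top)
    weights = [math.exp(logits[i] - max_logit) for i in top]
    total = sum(weights)
    probs = [w / total for w in weights]
    # Draw one index according to the renormalized probabilities.
    return rng.choices(top, weights=probs, k=1)[0]


rng = random.Random(0)
logits = [2.0, -1.0, 0.5, 3.0, -4.0]

# With k=2 only the two highest-scoring indices (3 and 0) can ever be drawn.
samples = [top_k_sample(logits, k=2, rng=rng) for _ in range(100)]
# With k=1 sampling reduces to greedy decoding.
greedy = top_k_sample(logits, k=1, rng=rng)
print(sorted(set(samples)), greedy)
```

Restricting the draw to the top k candidates prevents very unlikely tokens from being sampled while still allowing some variety among the plausible ones.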

Set some hyperparameters.

vocab_size = tokenizer.n_vocab
num_transformer_blocks = 8
maxlen = 256
embed_dim = 256
num_heads = 8
feed_forward_dim = 256
batch_size = 256 # You can set a bigger batch size if you use Kaggle's Cloud TPU.
num_epochs = 1
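
For a sense of scale, the following back-of-the-envelope sketch estimates the number of trainable parameters these hyperparameters imply. It is an approximation under stated assumptions: a GPT-2 vocabulary of 50,257 tokens (the value we expect from tokenizer.n_vocab), bias terms on every attention projection and linear layer, a scale and bias in each layer norm, and the 4 transformer blocks that create_model passes. `count_params` is a hypothetical helper, not part of Flax.

```python
def count_params(vocab_size, maxlen, embed_dim, ff_dim, num_blocks):
    """Rough parameter count for the miniGPT architecture (illustrative only)."""
    token_emb = vocab_size * embed_dim
    pos_emb = maxlen * embed_dim
    # Attention: query, key, value and output projections, each with a bias.
    attention = 4 * (embed_dim * embed_dim + embed_dim)
    # Two layer norms per block, each with a scale and a bias vector.
    layer_norms = 2 * (2 * embed_dim)
    # Feed-forward network: two linear layers with biases.
    ffn = (embed_dim * ff_dim + ff_dim) + (ff_dim * embed_dim + embed_dim)
    blocks = num_blocks * (attention + layer_norms + ffn)
    # Output projection from the embedding dimension to the vocabulary.
    output = embed_dim * vocab_size + vocab_size
    return {
        "token_emb": token_emb,
        "pos_emb": pos_emb,
        "blocks": blocks,
        "output": output,
        "total": token_emb + pos_emb + blocks + output,
    }


counts = count_params(vocab_size=50257, maxlen=256, embed_dim=256, ff_dim=256, num_blocks=4)
vocab_share = (counts["token_emb"] + counts["output"]) / counts["total"]
for name, value in counts.items():
    print(f"{name:>10}: {value:,}")
print(f"share held by vocabulary-sized layers: {vocab_share:.2f}")
```

Under these assumptions, the token embedding and the output layer hold the large majority of the parameters, because both scale with the vocabulary size rather than with the small embedding dimension.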

Loading and preprocessing the data#

Data loading and preprocessing with Grain.

@dataclass
class TextDataset:
    data: list
    maxlen: int

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx: int):
        # Use Tiktoken for tokenization
        encoding = tokenizer.encode(self.data[idx], allowed_special={'<|endoftext|>'})[:self.maxlen]  # Tokenize and truncate
        return encoding + [0] * (self.maxlen - len(encoding))  # Pad to maxlen

def load_and_preprocess_data(file_path, batch_size, maxlen):

    with open(file_path, 'r') as f:
      text = f.read()

    stories = text.split('<|endoftext|>')
    stories = [story+'<|endoftext|>' for story in stories if story.strip()]
    df = pd.DataFrame({'text': stories})
    data = df['text'].dropna().tolist()
    dataset = TextDataset(data, maxlen)

    sampler = pygrain.IndexSampler(
        len(dataset),
        shuffle=False,
        seed=42,
        shard_options=pygrain.NoSharding(),
        num_epochs=num_epochs,
    )

    dl = pygrain.DataLoader(
        data_source=dataset,
        sampler=sampler,
        operations=[pygrain.Batch(batch_size=batch_size, drop_remainder=True)],
    )

    return dl

text_dl = load_and_preprocess_data('TinyStories-train.txt', batch_size, maxlen)
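
The truncate-then-pad logic in TextDataset.__getitem__ can be checked in isolation with a small plain-Python sketch. `pad_or_truncate` is a hypothetical helper that mirrors the two lines above on ordinary lists of token IDs; the token values are made up for the example.

```python
def pad_or_truncate(token_ids, maxlen, pad_id=0):
    """Clip a token list to `maxlen`, then right-pad it with `pad_id`."""
    clipped = token_ids[:maxlen]
    return clipped + [pad_id] * (maxlen - len(clipped))


short = pad_or_truncate([11, 22, 33], maxlen=5)
exact = pad_or_truncate([1, 2, 3, 4, 5], maxlen=5)
long_ = pad_or_truncate([1, 2, 3, 4, 5, 6, 7], maxlen=5)
print(short, exact, long_)
```

Every example therefore has exactly maxlen tokens, which is what lets Grain stack them into fixed-shape batches. One thing to keep in mind is that ID 0 is also a regular token in the GPT-2 vocabulary (it decodes to "!"), so padded positions are not distinguishable from real tokens and the loss is computed over them as well.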

Defining the loss function and training step function#

# Defines the loss function using `optax.softmax_cross_entropy_with_integer_labels`.
def loss_fn(model, batch):
    logits = model(batch[0])
    loss = optax.softmax_cross_entropy_with_integer_labels(logits=logits, labels=batch[1]).mean()
    return loss, logits

# Define the training step with the `flax.nnx.jit` transformation decorator.
@nnx.jit
def train_step(model: MiniGPT, optimizer: nnx.Optimizer, metrics: nnx.MultiMetric, batch):
    grad_fn = nnx.value_and_grad(loss_fn, has_aux=True)
    (loss, logits), grads = grad_fn(model, batch)
    metrics.update(loss=loss, logits=logits, labels=batch[1])
    optimizer.update(grads)
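
To see what the loss in loss_fn measures, here is a minimal plain-Python sketch of softmax cross-entropy with integer labels, averaged over positions. It is a scalar re-derivation for illustration, not the Optax implementation; the function names are hypothetical.

```python
import math


def softmax_cross_entropy_int_label(logits, label):
    """Negative log-probability of `label` under softmax(logits)."""
    max_logit = max(logits)
    # log-sum-exp, computed stably by subtracting the largest logit first.
    log_sum_exp = max_logit + math.log(sum(math.exp(x - max_logit) for x in logits))
    return log_sum_exp - logits[label]


def mean_loss(batch_logits, batch_labels):
    """Average the per-position losses, like calling `.mean()` on the loss array."""
    losses = [softmax_cross_entropy_int_label(l, y) for l, y in zip(batch_logits, batch_labels)]
    return sum(losses) / len(losses)


# Uniform logits over 4 classes give a loss of log(4), whatever the label is.
uniform_loss = softmax_cross_entropy_int_label([0.0, 0.0, 0.0, 0.0], label=2)
# A confident, correct prediction gives a loss close to 0.
confident_loss = softmax_cross_entropy_int_label([10.0, 0.0, 0.0, 0.0], label=0)
# A model that spreads probability evenly over a 50,257-token vocabulary.
chance_level = math.log(50257)

print(uniform_loss, confident_loss, chance_level)
print(mean_loss([[0.0, 0.0], [5.0, 0.0]], [0, 0]))
```

The chance-level value of roughly 10.8 (assuming a 50,257-token vocabulary) is a useful reference point when reading the training loss: anything well below it means the model has learned something about the token distribution.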

Training the model#

Start training. It takes ~50 minutes on Colab.

Note that for data parallelism, we shard the training data along the batch axis using jax.device_put with NamedSharding.

We are also using the jax.vmap transformation to produce the target sequences faster.
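
The target sequence for next-token prediction is simply the input shifted left by one position, with a padding token appended to keep the length unchanged. The plain-Python sketch below shows this for a single sequence and for a small batch; `make_targets` is a hypothetical helper standing in for the function that the tutorial maps over the batch with jax.vmap, and the token values are made up.

```python
def make_targets(tokens, pad_id=0):
    """Shift a token sequence left by one and pad the end with `pad_id`."""
    return tokens[1:] + [pad_id]


inputs = [5, 8, 13, 21]
targets = make_targets(inputs)

# At position i the model reads inputs[: i + 1] and is trained to predict targets[i].
for i, (seen, target) in enumerate(zip(inputs, targets)):
    print(f"position {i}: last input token {seen} -> target {target}")

# Applying the same function to every row of a batch is what `jax.vmap` vectorizes.
batch = [[1, 2, 3], [4, 5, 6]]
batch_targets = [make_targets(row) for row in batch]
print(batch_targets)
```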

model = create_model(rngs=nnx.Rngs(0))
optimizer = nnx.Optimizer(model, optax.adam(1e-3))
metrics = nnx.MultiMetric(
  loss=nnx.metrics.Average('loss'),
)
rng = jax.random.PRNGKey(0)

start_prompt = "Once upon a time"
start_tokens = tokenizer.encode(start_prompt)[:maxlen]
generated_text = model.generate_text(
    maxlen, start_tokens
)
print(f"Initial generated text:\n{generated_text}\n")


metrics_history = {
  'train_loss': [],
}

prep_target_batch = jax.vmap(lambda tokens: jnp.concatenate((tokens[1:], jnp.array([0]))))

step = 0
for epoch in range(num_epochs):
    start_time = time.time()
    for batch in text_dl:
        if len(batch) % len(jax.devices()) != 0:
          continue  # skip the remaining elements
        input_batch = jnp.array(jnp.array(batch).T)
        target_batch = prep_target_batch(input_batch)
        train_step(model, optimizer, metrics, jax.device_put((input_batch, target_batch), NamedSharding(mesh, P('batch', None))))

        if (step + 1) % 200 == 0:
          for metric, value in metrics.compute().items():
              metrics_history[f'train_{metric}'].append(value)
          metrics.reset()

          elapsed_time = time.time() - start_time
          print(f"Step {step + 1}, Loss: {metrics_history['train_loss'][-1]}, Elapsed Time: {elapsed_time:.2f} seconds")
          start_time = time.time()

          generated_text = model.generate_text(
              maxlen, start_tokens
          )
          print(f"Generated text:\n{generated_text}\n")
        step += 1

# Final text generation
generated_text = model.generate_text(
    maxlen, start_tokens
)
print(f"Final generated text:\n{generated_text}")
Initial generated text:
Once upon a time Christina Raven Liqu Everyday seaw Spl digit mini Hungarian wasteful USC recurrent brawl towers summAvailability manualsidsAvailability Jord staleEarlier 303 Latter soakinginated pierced acquaint propaganda differentlyBesides Splambling Significant processing locals FoundingFlickrverbalSquaresth pixels CON repetitivebass%; dartsKN ushered sim wasteful Qi510 174 (_ Hillaryall hopeddalePref recurrentbassoves AOL ushered Hunt manuals NietzscheidsBY Equ souls correctedresaKN ghamblinguador contest cornerback bannedKN realizedSix summlargest gh fastest req influences cursingosureelse delighted wrecked donors codsedentiallyindaletteogenicAI summ wasteful USCesm shaped Garrett resistance grandchildren souls babyStatementambling fastestirin AWSiden groundedKen%; aboarddogs seaw Sultan Sachs Sonic ArchivesINE darts belts asylumei simette expands targetintergroupon Graveyard Graveyard398Jordan 66 medication Leadership 174?: seaw manuals summ asylumrw slice manualsiries Prometheus� Seat correctedINE denomination summ vastlyKNKN belts?: contest PamelaidiumKN themHI seawKN minions summ squadKN Joker sacredamblingKNuckyKNette 69 Xan 69ourse notificationuku Sitting cosmeticakesGro McAuliffeilles Graveyard differe <-Jordan Archives 180 Puppet cabinetodcast spir305 bannedambling 66 medicationbass victory relatingakespe Rover GarrettPrefppo sim recurrent manualsidsrg eveningsossus asylum Puppet hydra SultanProxy 66 chew Jokeranswer%;Loc Australian awaidiumdale landed Luahangはambling SuddenlyKN victory victory victory victory

Step 200, Loss: 4.541538715362549, Elapsed Time: 119.14 seconds
Generated text:
Once upon a time was a time was was was was very little girl, day was so happy, her little girl, her little little little little little little little little little little little little little little little little little little little little little little little little little little little little little little little little little little little little little little little little little little little little little little little little little little little little little little little little little little little little little little little little little little little little little little little little little little little little little little little little little little little little little little little little little little little little little little little little little little little little little little little little little little little little little little little little little little little little little little little little little little little little little little little little little little little little little little little little little little little little little little little little little little little little little little little little little little little little little little little little little little little little little little little little little little little little little little little little little little little little little little little little little little little little little little little little little little little little little little little little little little little little little little little little little little little little little little little little little little little little little little little little little little little little little little little!!!!

Step 400, Loss: 2.8348119258880615, Elapsed Time: 103.58 seconds
Generated text:
Once upon a time there was a small cat named Jack. Tim was a little girl who lived in the world and wanted to explore the forest. The dog, he had a new toy car and his mommy and daddy. He wanted to be so much. He went to help the ball. He said to be careful to go and said, "What's wrong!"
After they found the other kids said, Timmy was so excited and he could see his friend. Timmy said, "Let's play in the other animals, but I want to help."
Timmy smiled and they played with a while. Timmy said, "It's not listen to play. "Thank you. "You are happy."


Step 600, Loss: 2.3290421962738037, Elapsed Time: 69.78 seconds
Generated text:
Once upon a time, there was a little girl named Lily. She was very curious and loved to play outside in the garden. One day, she was playing with her toys and saw a big box with a shiny red ball. The ball was shiny and wanted to climb it, but she didn't want to be too far. 
"Hi, I can't get my toy!" said, but Lily said, "No, you are not yours." 
Lily's mom smiled and said, "You can help me go to me!"
"I want to get some candy and it," she said. "That's a big dog."
The cat said, "No, we need to go home. I'm going to find some." 
Her mom said, "Okay, you will get to play with me."
Lily smiled and said, "I want to get my favorite toy first!" 
Her mom said, "Okay, but they don't know. They both thought it is so much fun to go on their mom and have fun. They played together, but they did not wait to be friends and they had fun together. 


Step 800, Loss: 2.049470901489258, Elapsed Time: 84.69 seconds
Generated text:
Once upon a time, there was a little girl named Lucy was playing in the garden. She loved to run and play outside and she had lots of fun. She would always look at all day and her friends in the garden.
One day, she was playing outside when she heard a strange noise coming from a loud noise. It was coming from a dark and it started to shake. It was a big dog, so scared.
Her mom came into the room and said, "That's not safe!"
The dog smiled and said, "Don't worry, I will find my friend." Lucy and the dog started to walk in the grass, but soon she heard the sound coming from the garden, and she heard a noise. The dog ran back inside, but the door opened and saw that the shadow had a hole in a hole. Lucy and the dog became friends, and they played together all the things and they never found a beautiful spot to get home.


Step 1000, Loss: 1.8924535512924194, Elapsed Time: 76.64 seconds
Generated text:
Once upon a time, there was a girl named Amy. Amy loved to go outside and play with her friends. One day, Amy saw a big tree and asked her friends, "Can I have some best friends?" The tree was so big that it made a special tree to get in the tree. 
Amy went to the tree and asked, "Can I go on my tree?" The tree replied, "Sure, let's play together!" 
So, Amy and the tree were very excited. Amy started to play, but it was too heavy for the tree. She started to climb and forth, but it was so far away. Amy's mom saw how fast it was. She asked, "Why do you think this, but I want to get to be my friend?" Amy said, "No, it's too late, so we can't play together!" Amy's friend started to fight and said, "I'll be friends. Let's play together!" Amy and Amy laughed, and laughed together, and they became best friends and they had a fun day together.


Step 1200, Loss: 1.7993096113204956, Elapsed Time: 80.74 seconds
Generated text:
Once upon a time, there was a little girl named Amy. She loved to go outside and play in the sunshine. One day, she went outside to play with her friends. They went on a sunny day and found a shiny ball. 
Amy wanted to see the ball and it started to fly it. But then, a little girl came to her friend. Her friend saw her and asked if she could join the bird. The bird said yes and Lily took it to her house.
Amy was very happy and she started to sing and dance with her friend. They sang together and danced until the sun came up. They laughed and danced together, laughing and having fun. They played together all day and the day long.


Step 1400, Loss: 1.7445931434631348, Elapsed Time: 69.74 seconds
Generated text:
Once upon a time, there was a girl named Sue. She was only three years old. Sue loved to explore. One day, she saw a little bird sitting on a branch. Sue was scared. She ran up to her mom and said, "Mom, what's that noise?" Her mom smiled and said, "It's okay, Sue. We don't have any more fun!"
Sue and her mom looked around the house and found a small bird that made her feel better. It was a little bird. Sue and her mom went outside and played in the sun. They had a great time playing and laughing.
At the end of the day, Sue said, "Thank you, Mommy! You're welcome! You are very lucky to be a friend." Her mom smiled and said, "Thank you, mommy. I love your new friend!"


Step 1600, Loss: 1.6877646446228027, Elapsed Time: 74.16 seconds
Generated text:
Once upon a time, a boy named Jack was walking down the street. He was feeling very scared. His mom told him that he had to go outside and get to play. 
One day, Jack noticed a little boy playing in the park. He wanted to play with the equipment too, but he was scared. 
Jack ran over and asked, "Why did you get me?" His mom said, "I don't want to be scared. I'm sorry I can't have the equipment to take the equipment to play." 
Jack thought about the equipment he would make it even more fun with the equipment. So, he started to build the equipment. When he was finished, Jack felt a bit better. 
The equipment was a way to the playground, but he knew that his mom was not alone, she had to help him get better. She asked him to stay in the playground, but Jack was still happy to have the equipment to stay, but Jack couldn't.
Jack learned that being kind was important to help and not be afraid of being ignorant and selfish. He was still happy, but he was still very selfish and always asked the equipment for help.


Step 1800, Loss: 1.6456927061080933, Elapsed Time: 85.76 seconds
Generated text:
Once upon a time, there was a little girl named Mia. Mia loved to play outside in the park. One day, Mia found a shiny coin in the park. She wanted it for her birthday party, but it was too expensive. Mia thought it would not buy the coin for her birthday.
So, Mia decided to ask her mom for the coin to buy the coin. The coin looked for a long time and said yes. Mia put the coin back in the park, but it was gone. Mia was very sad.
So, Mia asked her mom if she could have a special day. Her mom said yes and they both went home with the coin. Mia was so happy! She said yes, but the coin was gone forever.


Step 2000, Loss: 1.6211789846420288, Elapsed Time: 69.42 seconds
Generated text:
Once upon a time, there was a small boy named Tim. Tim was a small boy named Timmy. Timmy loved to play with his toy car and he loved to make noises with it.
One day, Timmy and his toy car went to the park with his toy car. Timmy's car got stuck in the car and it crashed. Timmy's car was too fast and it hit a tree with the wheel.
Timmy was sad because he loved his car so much. He knew he could help make his toy car better. He said, "I will make you happy! Let's play together and make a new car." And Timmy and his toy car had fun together and they became good friends.


Step 2200, Loss: 1.578463077545166, Elapsed Time: 68.18 seconds
Generated text:
Once upon a time, there was a little girl named Lucy. She was only three years old and loved to play with her toys. One day, she found a special treasure, a big, scary toy that was very big and she was so excited to take it home. 
But, one day, Lucy found a shiny jewel that she was very happy to find it. She opened the door and saw a beautiful necklace in the corner of the room. She wanted to see what it was like. She asked her friend, the necklace, and the necklace said, "Can I take it home now?" 
Lucy smiled and said, "I can borrow it with a special diamond. It's very special because it belongs to someone else."
The diamond was so happy to see Lucy and said, "Thank you, Lucy. You are a great friend!"


Step 2400, Loss: 1.5326989889144897, Elapsed Time: 73.51 seconds
Generated text:
Once upon a time, there was a little boy named Timmy. Timmy loved to play outside in the sun. One day, Timmy saw a beautiful flower. It was so pretty!
Timmy asked his mom, "What is this flower?"
His mom replied, "It's a flower. Do you want to play with it?"
Timmy nodded his head and said, "Yes, please."
Timmy was so happy and he played with the flower. They laughed and played until the sun went down.
Suddenly, they heard a loud noise. It was a little kitten. The kitten was scared and didn't know what to do. Timmy's mom told him it was a bad ending.


Step 2600, Loss: 1.5533288717269897, Elapsed Time: 70.30 seconds
Generated text:
Once upon a time, there was a little girl named Lily. She was very happy because she had a pretty blouse. One day, Lily was playing with her friend Jack came over to play. She asked, "Can I play with you?" 
Jack said, "Sure, let's play with your friends!" Lily was excited and said, "Okay, let's play!" They ran around the room and played with their new friends. 
As they were playing, Lily accidentally knocked over the blouse and it fell off the blouse. Jack said, "Ouch!" 
Lily felt embarrassed and said, "It's okay, Lily. I'm sorry. I didn't mean to break the blouse. Let's play with it again!" 
They both started to play and soon the blouse was over. They played together and had a lot of fun. Lily felt so happy that she didn't make her friend laugh again. She realized that sometimes things don't seem too bad and we should be different, but they still have a solution.


Step 2800, Loss: 1.5323652029037476, Elapsed Time: 80.58 seconds
Generated text:
Once upon a time, there was a little boy named Tim. Tim was a good boy. He liked to play with his toys and his dog, Max. One day, Tim went to the park with his mom.
"Hi, Tim! I want to play with you. Do you want to play with me?" Tim asked.
"Yes, Tim. Let's play together!" his mom said.
Tim and Max played together all day. They laughed and had lots of fun. Tim and Max were good friends.
"Can we play with our friends?" Max asked.
"Yes, but we can play together," his mom said.
Tim and Max played together. They laughed and had fun. Tim and Max were happy to play with them. They were best friends.


Step 3000, Loss: 1.5298235416412354, Elapsed Time: 70.76 seconds
Generated text:
Once upon a time, there was a little girl named Lily. She was three years old and loved to explore the world around her. One day, Lily found a shiny jewel on the ground. She picked it up and looked inside the jewel. She picked it up and examined it. She showed it to her mom and they were so proud of her. 
After she was finished playing, Lily went to bed. She had so much fun that she didn't realize it was her mom had made a mistake. Her mom explained that sometimes it's important to be responsible and to listen to others. Lily was very careful and listened to her mom's advice.
From that day on, Lily was careful and she never forgot about the jewel. She always made sure to be careful when she played. And always remember to always be careful when things happened.


Step 3200, Loss: 1.471197485923767, Elapsed Time: 72.75 seconds
Generated text:
Once upon a time, there was a little boy named Tim. Tim was a curious little boy who loved to play. One day, he found a toy car in the garden. It was red and shiny. Tim was very happy and excited.
As Tim went outside to play, he saw a little bird in the grass. The bird had a red car. Tim wanted to help the bird. He said, "Hello, little bird. Can you help me lift my toy?" The bird looked at Tim and said, "Yes, please! It is so nice to be careful. You are smart and smart."
Tim was very happy to help the bird. He thanked the bird and the bird flew away. Tim learned that helping others can make us feel good and smart. He learned that helping others can make us happy too.


Step 3400, Loss: 1.5036591291427612, Elapsed Time: 71.77 seconds
Generated text:
Once upon a time there was a boy named Tim. He loved to go outside and play. One day, he saw a big tree with lots of branches. He wanted to climb it, but his mom said no.
Tim tried to climb the tree, but it was too high. He was getting higher and higher until he could see what was on the top. Then he heard a loud noise coming from outside. He peeked out from the ground and saw a small bird flying above.
Tim was so happy to be able to reach the top. He jumped and jumped up and down the tree with excitement. When he was finally caught the bird, it flew down to the tree. Tim was so proud of his success.
He thanked the bird and continued climbing. He was very happy that he had climbed the tree, even though he couldn't get down from the tree. From then on, he made sure to stay away from the top of the tree, safe and sound.


Step 3600, Loss: 1.4788475036621094, Elapsed Time: 76.91 seconds
Generated text:
Once upon a time, there was a little boy named Timmy. Timmy loved to play with his toys and watch cartoons. One day, Timmy's mommy asked him to help his mommy. Timmy was excited to help, so he got a new toy from his grandma's house. 
After a few days, Timmy's mommy gave him a special gift. Timmy was so excited to receive such a gift! It was a special gift that his grandma gave him a big hug. Timmy was so happy to have such a gift and hugged his mommy. 
After that, Timmy went to bed that night with a special gift from his grandma. He loved his present so much that he gave it a special gift to his grandma. Timmy was so happy to have a present that he had found such a beautiful gift. 
The next day, Timmy's grandma gave him a gift and gave him a gift. Timmy was so excited and couldn't wait to tell his grandma. He showed his grandma a gift for him with a present, and they all said thank you to his grandma. Timmy was so happy to receive the gift from her grandma and her gift. 
The present was that Timmy had!!!!

Step 3800, Loss: 1.4588173627853394, Elapsed Time: 88.57 seconds
Generated text:
Once upon a time, there was a little boy named Timmy. Timmy loved to play with his toys, but he always wanted to share them with others. One day, Timmy was playing with his toy cars and accidentally broke them. His mom said, "Don't worry, I will help you fix your toys. Let's make it together."
So, Timmy went to his room and started to play with his toy cars. He had so much fun playing with it, he didn't even want to lose it. His mom was happy to help, so they took a break and put the toy cars away in the right place. Timmy was very happy that he could share his toys with his mom.


Step 4000, Loss: 1.4644711017608643, Elapsed Time: 69.17 seconds
Generated text:
Once upon a time, there was a little girl named Lily. She loved to play in the garden, and even when she saw a beautiful butterfly. She ran over to the butterfly and tried to catch it, but it flew away. 
Suddenly, a little girl appeared. She looked around and saw her and asked her what she was doing. Lily replied, "I want to play too, please!" Her mom replied, "Okay, but only if you don't like the butterfly, it's better." 
So, Lily decided to go back to the garden to find the butterfly. She walked over to the garden and saw a big, beautiful butterfly. The butterfly flapped its wings and flew away. Lily felt so happy that she ran back to the garden, happy that she could play with the butterfly again.


Step 4200, Loss: 1.4640543460845947, Elapsed Time: 72.24 seconds
Generated text:
Once upon a time, there was a boy named Timmy. Timmy loved to play with his toy cars, trucks and even the cars. One day, Timmy's toy car got stuck in a big mud. Timmy tried to get the wheels, but he was too heavy. He tried to push the wheel but it was too heavy. Timmy tried and tried, but he couldn't do it. Timmy was sad, but then he remembered how his mom told him. He was brave and strong, but he knew how to use his car to get it. He was happy again, and he was able to get the wheel.


Step 4400, Loss: 1.4368994235992432, Elapsed Time: 66.23 seconds
Generated text:
Once upon a time, there was a little boy named Timmy. Timmy loved to play outside and pick flowers. One day, he saw a little boy who was very curious. He went up to him and asked, "Can I help you?" The boy replied, "Sure, you can help me find your way."
Timmy and the boy searched everywhere but they couldn't find the way up to the boy. They searched everywhere but they couldn't find the answer. Suddenly, the boy said, "Don't worry, I'll help you find your way." Timmy said, "I know how to solve this problem."
Timmy's face lit up with excitement and said, "I'll help you solve this problem. I can find your way home." His mom said, "That's right. Let's find your way and see if we can find a way to find my way back." Timmy and the boy looked at each other's hand and said, "Yes, I'll help you."
They found a small tree, and Timmy helped his mom get his help. They worked together to search for the little boy's journey. After a few hours, they finally found the perfect path. Timmy was so happy that he could!!!!

Step 4600, Loss: 1.4162362813949585, Elapsed Time: 88.24 seconds
Generated text:
Once upon a time, there was a little boy named Timmy. Timmy loved to play outside and look at the pretty flowers. One day, Timmy saw a beautiful flower. It was so pretty and he wanted to touch it. 
But when he got close to the flower, he accidentally dropped it. It broke! Timmy was so sad that he cried and his mom couldn't fix it. She said they would get another special thing, but Timmy was okay. 
His mom hugged him and said it was okay. She told him that accidents happen and that it's important to be careful when you touch things that can happen.


Step 4800, Loss: 1.4360578060150146, Elapsed Time: 66.49 seconds
Generated text:
Once upon a time, there was a little girl named Lily. She loved to play outside in the park. One day, she saw a big tree and wanted to climb it. But her mom told her to be careful because the tree might hurt them.
Lily tried to climb the tree, but she was too weak. She fell down and started to cry. Her mom hugged her and told her to go back home. They were very happy and went back to the tree.


Step 5000, Loss: 1.4377806186676025, Elapsed Time: 59.31 seconds
Generated text:
Once upon a time, there was a little girl named Lily. She loved to play outside and collect things. One day, she went to the park with her mom. She saw a boy sitting at the bench and he said, "Hello, doggy!" Lily felt embarrassed because she didn't like it. She asked the boy, "What's wrong?" The boy said, "My dog's dog bit me!" Lily said, "It's too bad. I'm too small." The boy felt sad for her.
Lily said, "I'm sorry, but I don't know it's not good." She went to her mom and asked if she could borrow some of her toys. Her mom said, "Sure, you can borrow your toy car. You can borrow it, but be careful." Lily smiled and went back to playing with her toy car. She had a lot of fun with her toy car and played with it all day.


Step 5200, Loss: 1.3954946994781494, Elapsed Time: 76.86 seconds
Generated text:
Once upon a time, there was a little girl named Lily. She loved to play outside and pick flowers. One day, she saw a beautiful flower in the garden. She picked it up and held it in her hand. She felt so happy that she wanted to show her mom.
Her mom came outside and saw that Lily was very sad. She told Lily that the flower was still a flower. Lily felt sorry for being so pretty. She promised to be more careful with the flower.
The next day, Lily went back outside to pick some flowers. She found some pretty flowers and showed them to her mom. Her mom was very happy and said that Lily's flower was not as pretty as the flowers in the garden. From that day on, Lily loved to play and make pretty flowers in the garden.


Step 5400, Loss: 1.401665449142456, Elapsed Time: 70.89 seconds
Generated text:
Once upon a time, there was a little boy named Timmy. Timmy loved to play outside and explore. One day, Timmy saw a big tree with a lot of branches. He thought it looked interesting and decided to climb the tree. 
Timmy's friend, a wise old owl, said, "Timmy, don't be careless. I don't want to fall. You need to get hurt." 
Timmy didn't understand why the owl was so wise, so he explained to his friends, "It's important to always stay safe." 
Timmy felt better and went back to his tree to admire the branches from branch. He knew that even if something is too small, it's always better to be safe and be careful when climbing tree.


Step 5600, Loss: 1.3960157632827759, Elapsed Time: 71.91 seconds
Generated text:
Once upon a time, a little girl named Lily went on a trip with her mommy. They were having so much fun. When they got to the beach, they saw a crab and wanted to touch it.
Lily said, "Mommy, can I touch the crab?"
"No, it's too dangerous. You should stay in the sand. The crab might be safe," her mommy replied.
Lily was sad and started to cry. "I'm sorry, mommy. I didn't want to be hurt."
Her mommy said, "It's okay, sweetie. I'm glad you listened to your mommy and not touched the crab. Now you can't touch the crab's safety."
Lily smiled and hugged her mommy. From that day on, they became good friends and they played together every day.


Step 5800, Loss: 1.3838456869125366, Elapsed Time: 73.78 seconds
Generated text:
Once upon a time, there was a little boy named Tim. Tim had a big toy car. The car could go very fast. Tim liked to race with his car.
One day, Tim saw a big box. He wanted to open it. Tim went to his friend, Sue. Sue saw the box and said, "I can open this box!" Sue wanted to open the box. They opened the box. Inside, there was a big, soft teddy bear. The bear was happy.
Tim and Sue played with the bear all day. They were not very good at all. They played with the toys and had lots of fun. They learned that being kind was better and to help others. And from that day on, Tim and Sue were the best of friends.


Step 6000, Loss: 1.4077720642089844, Elapsed Time: 70.17 seconds
Generated text:
Once upon a time, a little girl named Sue and her dog, Spot. They lived in a big house with their mom and dad. One day, they found a shiny rock that they had to take home. They took it home and put it in their room.
Sue's mom saw the shiny rock and said, "Oh, my toy! Your toy is very special! You should take it home." Sue was happy to have her toy back and said, "Thank you, Spot!"
But then, a big wind came and blew the rock back. Sue and Spot got scared and started to cry. Their mom came and saw what happened. She told them not to worry and they both went to the toy store to get the shiny rock.
The moral of the story is that it's important to take care of our things, but we can also take care of them. If you take care of them, we might find them and help them find them.


Step 6200, Loss: 1.3794877529144287, Elapsed Time: 77.59 seconds
Generated text:
Once upon a time, there was a little girl named Lily. She loved to play outside and pick flowers. One day, she went to the park with her mommy. She saw a boy crying because he lost his toy car.
Lily said, "I lost my toy car, but I can't find it."
Lily asked the boy, "Have you seen my toy car?" The boy said, "No, I haven't seen it."
Lily felt sad and went to play on the swings. She went to the boy and said, "Thank you, you're my friend." The boy said, "You are very kind. Can you help me find your toy car?"
Lily smiled and said, "I will help you find your toy car. Maybe you can find your toy car."


Step 6400, Loss: 1.3905563354492188, Elapsed Time: 73.45 seconds
Generated text:
Once upon a time, there was a little girl named Lily. She loved to play outside in the sun. One day, she saw a big, round thing on the ground. It was a pretty, shiny rock.
Lily picked up the rock and put it on her foot. She was so happy and showed her mom the rock. Her mom was proud of her and gave her a big hug. Lily felt happy too.
Later that day, Lily and her mom went to the park. Lily saw a big slide. She wanted to slide down it too. She ran up the slide, and slid down fast. The wind blew and Lily fell. She was hurt and sad.
Lily learned to be careful when she slide. She went to the swings, and when she was playing, she felt much better. Her mom gave her a hug and told her to keep the smooth rock. Lily learned that she should always keep the rock safe.


Step 6600, Loss: 1.39736008644104, Elapsed Time: 76.19 seconds
Generated text:
Once upon a time, there was a little girl named Lily. She loved to play with her toys and her favorite toy was a teddy bear. One day, Lily was playing with her teddy bear when she saw a shiny object on the ground. She picked it up and looked at it closely. 
"Wow, look at that object!" said Lily.
"I want it!" said her mom. "It's a special object. It's shiny and it's very special to me."
Lily thought for a moment and then decided to take the object from her mom to her friend. "Look, I found a pretty bracelet!" said Lily. 
Her friend said, "That's great, Lily! Let's play with it." And so, they played with the bracelet all day long. They were having so much fun!


Step 6800, Loss: 1.4118915796279907, Elapsed Time: 72.23 seconds
Generated text:
Once upon a time, there was a little girl named Lily. She loved to draw with her toys and make pictures. One day, she was playing in her backyard when she saw a big bird sitting on a branch. The bird was so pretty! Lily wanted to be friends with the bird, so she tried to sing with it. But the bird didn't fly away. Lily was sad and didn't know what to do.
Then, Lily saw a bird that it was very graceful. She asked the bird what was wrong. The bird said that it had wings to fly away. Lily thought that it was a fun idea and flew away. She was happy again and went back to her house to play. From that day on, Lily knew that if she wanted to be friends with her friend, she could help the bird feel better.


Step 7000, Loss: 1.4061782360076904, Elapsed Time: 72.51 seconds
Generated text:
Once upon a time, a little girl named Sue lived in a big house. Sue had a toy car that she loved to play with. One day, Sue's mom told her not to clean up. Sue listened to her mom and started to clean her room.
After cleaning, Sue felt tired and needed to take a nap. She put on her clean, but her mom was not looking well. Sue went to her room and sat down to rest. Her mom said, "Sue, I'm going to rest and drink a big drink." Sue smiled and felt better.
Sue went to bed with her mom to rest her room. In the end, she went to sleep with her mom. She felt happy that she was able to help her mom. From that day on, Sue knew that her mom loved her because she loved her little sister so much.


Step 7200, Loss: 1.381988286972046, Elapsed Time: 73.38 seconds
Generated text:
Once upon a time, a little girl named Lily went to the park with her mommy. She saw a boy who looked sad. She went to her mommy and asked her if she could have a turn to her.
"Sure, but you have to ask your mommy," her mommy said. "It's okay, but it's important to ask nicely."
Lily nodded and gave the boy a hug. "Can we play on the swings?" she asked.
The boy didn't like it and said, "Sure, but be careful."
Lily felt happy and went to play on the swings. She went back to the boy and asked if he wanted to swing too. The boy said, "I want to swing like you, but it's not safe."
Lily went to the swing and pushed the boy away. She was sad that she lost her favorite toy. The boy was upset and didn't want to swing anymore. They both sat on the swing and watched the boy play on the swings. Lily felt better and said, "Thank you for letting me swing, Timmy. You're my friend."


Step 7400, Loss: 1.3883332014083862, Elapsed Time: 83.07 seconds
Generated text:
Once upon a time, there was a little girl named Lily. She loved to play outside in the garden and feel the warm sunshine on her skin. One day, while she was playing, she noticed that the flowers had bloomed in a beautiful butterfly. She wanted to see the butterfly so she asked her mommy if she could go to the butterfly's home. Her mommy said yes and Lily ran to her mommy's room. 
When she got there, she saw that the butterfly was sad and started to cry. Her mommy told her that she couldn't go back to the garden to get her dress and her mommy was not happy. Lily felt bad and didn't want to go back to the garden to get her dress. 
The next day, Lily saw the butterfly again and ran back to her mommy's house. She was happy to see her mommy again and said she was a great helper. Her mommy was so happy that she could see the butterfly again and that was a beautiful butterfly who was a good girl.


Step 7600, Loss: 1.3650509119033813, Elapsed Time: 80.14 seconds
Generated text:
Once upon a time, there was a little girl named Lily. She loved to play with her toys, but one day she got bored. She went to her friend, Tim, and asked if he could borrow his toy.
Tim said, "No, Lily. This is my toy car. You can play with it, but I want to borrow your toy car."
Lily felt sad because she didn't want to play with Tim. She said, "Please, Tim. We can borrow your toy car. I can borrow it for my birthday."
Tim took the toy car and said, "No, Lily. I can borrow it for my birthday." He played with the toy car all day. He was happy with his new toy car.


Step 7800, Loss: 1.3602105379104614, Elapsed Time: 69.70 seconds
Generated text:
Once upon a time, a little girl named Lily went to a park with her mommy. The park was a pretty green place, with a big pond with lots of ducks. Lily loved the sound of the water, so she decided to play in the water.
Lily saw a butterfly flying in the sky and tried to catch it. But the butterfly was too fast, and she couldn't catch it. Suddenly, the butterfly landed on a flower and Lily started to panic.
Her mommy saw the butterfly and asked, "What happened, Lily?"
"The butterfly fell and hurt my hand," said Lily.
Her mommy came over and saw the butterfly and asked, "Are you okay, Lily?"
"It hurts, sweetie," said her mommy.
Lily's mommy took a big spoon and the butterfly was very pretty and had lots of colors on it. Lily felt much better and was able to catch the butterfly. She was happy to have her friend a little friend, and they played together all day long.


Step 8000, Loss: 1.3417373895645142, Elapsed Time: 79.57 seconds
Generated text:
Once upon a time, there was a little boy named Tim. Tim had a toy named Sam. Sam liked to play with the toy car. One day, Tim's mom said they were going to a new park. Tim wanted to play on the swings, but his mom said no.
Tim started playing with the car. He saw a big dog in the park. The dog barked loudly. Tim felt scared. He wanted to run away. He went to his mom and dad and said they were safe. Tim felt safe.
Later, Tim and Sam went to the park. They saw a big slide. They both wanted to go on the slide. They ran to the slide and slid down the slide. Tim's mom and dad helped them get down. They all laughed and had a great time together.


Step 8200, Loss: 1.3874244689941406, Elapsed Time: 71.68 seconds
Generated text:
Once upon a time, there was a little boy named Timmy. Timmy loved to play with his toy car, but his favorite thing to do was to go to the store. Timmy didn't like the store, so he went to the store.
When they got to the store, Timmy saw a toy car and wanted it. He asked his mom for a toy car and her mom said, "No, it's too expensive." Timmy didn't want to buy a toy car. He thought the toy car wouldn't have a toy car, but his mom said it was okay.
Later that day, Timmy went to the park with his mom. Timmy saw a boy who looked sad because he lost his toy car. Timmy felt sorry for the boy and said, "I can help you find my toy car." But he couldn't find his toy car, and he didn't know what to do.
Timmy asked his mom, "Why are you looking for my toy car, and the boy can have it back?" His mom said, "Don't worry, we can fix it. Let's find your toy car." They found the toy car and it was in the store. Timmy was happy that he found a!!!!

Final generated text:
Once upon a time, there was a little girl named Lily. She loved to play outside and pick flowers. One day, Lily saw a butterfly and tried to catch it, but it flew away. Lily felt sad because she didn't know what was happening.
Suddenly, her friend, a little bird named Timmy, came to visit. "Don't worry, Timmy," said Lily. "I can catch it!"
"I'm sorry, Lily," said Timmy. "I just need to help my mom."
"Thank you, Timmy," said Lily. "You are very kind to me."
From that day on, Lily learned to always help others when they needed it. And every time she saw the butterfly, she felt happy again.

Visualize the training loss.

import matplotlib.pyplot as plt
plt.plot(metrics_history['train_loss'])
plt.title('Training Loss')
plt.xlabel('Step')
plt.ylabel('Loss')
plt.show()
[Plot: training loss]

As the samples above show, the model goes from generating near-random words at the start of training to producing mostly coherent tiny stories by the end. In effect, we have pretrained a small LLM to write tiny stories for us.

Saving the checkpoint#

Save the model checkpoint.

import orbax.checkpoint as orbax

state = nnx.state(model)

checkpointer = orbax.PyTreeCheckpointer()
checkpointer.save('/content/save', state)

# Make sure the files are there
!ls /content/save/
_CHECKPOINT_METADATA  d  manifest.ocdbt  _METADATA  ocdbt.process_0  _sharding

Profiling for hyperparameter tuning#

!pip install -Uq tensorboard-plugin-profile tensorflow tensorboard

Load the TensorBoard Colab extension.

%load_ext tensorboard

As we’re going to be running this model a number of times, we need some scaffolding to more easily compare our work. For a baseline, we’ll need to perform some warmup to guarantee that our code is JIT’d and that our TPUs are warm. For improved comparability, we’ll only start tracing after we’ve finished warmup.

trace_dir = "/tmp/jax-trace/"

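def loop_step(batch, step):
    # Convert the batch to a JAX array and transpose it, as in the training loop.
    input_batch = jnp.array(batch).T
    target_batch = prep_target_batch(input_batch)
    # Shard inputs and targets along the 'batch' mesh axis, then run one training step.
    sharded_batch = jax.device_put((input_batch, target_batch), NamedSharding(mesh, P('batch', None)))
    train_step(model, optimizer, metrics, sharded_batch)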

def generate_trace():
    tracing_steps = 30
    warmup_steps = 5
    for current_step in range(warmup_steps + tracing_steps):
        if current_step == warmup_steps:
            jax.profiler.start_trace(trace_dir)
        with jax.profiler.StepTraceAnnotation("train", step_num=current_step):
            batch = next(text_dl)
            loop_step(batch, current_step)

    jax.profiler.stop_trace()
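
As a lightweight complement to the profiler trace, the same warmup-then-measure pattern can be applied to plain wall-clock timing. The following is a minimal sketch and not part of the original tutorial: `time_steps` and its parameters are hypothetical names, and the dummy step function stands in for a real call such as `loop_step`. With JAX, the step function would also need to wait for the device computation to finish (for example with `jax.block_until_ready` on a returned array), because dispatch is asynchronous.

```python
import time
from typing import Callable, List


def time_steps(step_fn: Callable[[int], None],
               warmup_steps: int = 5,
               timed_steps: int = 30) -> List[float]:
    """Run every step, but record durations (in ms) only after warmup."""
    durations_ms = []
    for current_step in range(warmup_steps + timed_steps):
        start = time.perf_counter()
        step_fn(current_step)
        elapsed_ms = (time.perf_counter() - start) * 1000.0
        # Warmup steps include compilation cost, so they are not recorded.
        if current_step >= warmup_steps:
            durations_ms.append(elapsed_ms)
    return durations_ms


# Dummy step used for illustration only; it just records the step index.
calls = []
durations = time_steps(lambda step: calls.append(step), warmup_steps=2, timed_steps=3)
print(f"ran {len(calls)} steps, timed {len(durations)} of them")
```

Averaging the returned durations gives a rough step time that can be cross-checked against the Average Step Time reported by the profiler.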

Now we’ll perform some traces to compare results of different batch sizes. This will take several minutes as we need to reprocess our input data to prepare new batches each time.

trace_dir = "/tmp/jax-trace-batch-comparison/"

batch_size = 64
text_dl = iter(load_and_preprocess_data('TinyStories-train.txt', batch_size, maxlen))
generate_trace()

batch_size = 256
text_dl = iter(load_and_preprocess_data('TinyStories-train.txt', batch_size, maxlen))
generate_trace()

Run TensorBoard with the Profiler Plugin to compare our runs. Runs are listed in order from newest to oldest, so the top run in the list will have batch_size = 256.

The key metrics to focus on here for this hyperparameter are FLOPS Utilization and Average Step Time.

In general, we want to maximize FLOPS Utilization while minimizing the step time per training example. In this case, increasing the batch size from 64 to 256 achieves both. FLOPS Utilization increases from 16% to 27%. Average Step Time increases from 100 ms to 260 ms, but the batch size grew by 300% (4x), so the time per training example drops from about 1.56 ms (100 ms / 64) to about 1.02 ms (260 ms / 256).
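
The per-example figures come from dividing the average step time by the batch size. A small sketch of that arithmetic, using only the numbers quoted above (the helper name `ms_per_example` is illustrative and not part of the tutorial):

```python
def ms_per_example(step_time_ms: float, batch_size: int) -> float:
    """Average step time divided by the number of examples in one batch."""
    return step_time_ms / batch_size


# (batch_size, average step time in ms), as read from the profiler above.
runs = [(64, 100.0), (256, 260.0)]

for batch_size, step_time_ms in runs:
    per_example = ms_per_example(step_time_ms, batch_size)
    print(f"batch_size={batch_size}: {per_example:.2f} ms per example")

# Ratio of per-example times between the two runs (small batch / large batch).
speedup = ms_per_example(100.0, 64) / ms_per_example(260.0, 256)
print(f"per-example speedup: {speedup:.2f}x")
```

The same comparison applies to any pair of runs: a longer step is still a win as long as the per-example time goes down.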

%tensorboard --logdir=$trace_dir

Next, we can explore alternative parallelism methods. In cell #4, we used 4-way data parallelism combined with 2-way tensor parallelism. Pure 8-way data parallelism is another common choice, so let’s compare the two. To switch to 8-way data parallelism, we’ll replace the Mesh definition with:

mesh = Mesh(mesh_utils.create_device_mesh((8, 1)), ('batch', 'model'))

JAX will automatically figure out how to shard the model and data to use the new partition strategy, and nothing else needs to be done. Re-connect the TPU runtime and run the notebook again to see how it performs.

How simple and powerful is this! And that’s the beauty of JAX automatic parallelism.

trace_dir = "/tmp/jax-trace-parallelism-comparison/"

mesh = Mesh(mesh_utils.create_device_mesh((4, 2)), ('batch', 'model'))
generate_trace()

mesh = Mesh(mesh_utils.create_device_mesh((8, 1)), ('batch', 'model'))
generate_trace()
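
To make the difference between the two mesh shapes concrete, the sketch below computes the per-device shard shape of an input batch under `P('batch', None)`. It is plain Python and does not call JAX: `shard_shape` is a hypothetical helper, the batch shape `(256, 256)` is chosen for illustration, and the sketch assumes each sharded dimension divides evenly by its mesh axis size.

```python
from typing import Dict, Optional, Sequence, Tuple


def shard_shape(global_shape: Sequence[int],
                spec: Sequence[Optional[str]],
                mesh_axes: Dict[str, int]) -> Tuple[int, ...]:
    """Per-device shape when each dimension is split over its named mesh axis.

    A spec entry of None means the dimension is not split (it is replicated).
    """
    shape = []
    for dim, axis in zip(global_shape, spec):
        parts = 1 if axis is None else mesh_axes[axis]
        if dim % parts != 0:
            raise ValueError(f"dimension {dim} is not divisible by {parts}")
        shape.append(dim // parts)
    return tuple(shape)


batch_shape = (256, 256)   # illustrative (batch_size, sequence length)
spec = ('batch', None)     # mirrors P('batch', None) in loop_step

for mesh_axes in ({'batch': 4, 'model': 2}, {'batch': 8, 'model': 1}):
    print(mesh_axes, '->', shard_shape(batch_shape, spec, mesh_axes))
```

With the (4, 2) mesh each device holds a quarter of the batch, and the two devices along the 'model' axis hold the same slice. With the (8, 1) mesh each device holds an eighth of the batch.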

Once again we’ll run tensorboard.

Looking at the results, we see that the step times are nearly the same. However, FLOPS Utilization is at 13% for 8-way data parallelism, compared to 27% for 4-way data parallelism.

In the Trace Viewer tool, under each TPU’s ops, we can see that the TPUs spend a large amount of time idle while waiting for the host, and also spend a good amount of time in reduce_sum operations.

%tensorboard --logdir=$trace_dir

By changing hyperparameters and comparing profiles, we’re able to gain significant insights into our bottlenecks and limitations. These are just two examples of hyperparameters to tune, but plenty more of them will have significant effects on training speed and resource utilization.