Machine Translation with encoder-decoder transformer model

This tutorial is adapted from Keras’ documentation on English-to-Spanish translation with a sequence-to-sequence Transformer, which is itself an adaptation from the book Deep Learning with Python, Second Edition, by François Chollet.

We step through an encoder-decoder transformer in JAX, built with Flax NNX, and train it for English-to-Spanish translation.

import pathlib
import random
import string
import re
import numpy as np

import jax.numpy as jnp
import optax

from flax import nnx

import tiktoken
import grain.python as grain
import tqdm

Pull down data to temp and extract into memory

There are lots of ways to get this done, but for simplicity and clear visibility into what’s happening, the archive is downloaded to a temporary directory, extracted there, and read into a Python list of (English, Spanish) pairs, with [start] and [end] markers added around each Spanish sentence.

import requests
import zipfile
import tempfile

url = "http://storage.googleapis.com/download.tensorflow.org/data/spa-eng.zip"

with tempfile.TemporaryDirectory() as temp_dir:
    temp_path = pathlib.Path(temp_dir)
    zip_file_path = temp_path / "spa-eng.zip"

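    response = requests.get(url)
    response.raise_for_status()  # stop here on a failed download instead of unzipping an error page
    zip_file_path.write_bytes(response.content)

    with zipfile.ZipFile(zip_file_path, "r") as zip_ref:
        zip_ref.extractall(temp_path)

    text_file = temp_path / "spa-eng" / "spa.txt"

    # Read while the temporary directory still exists; [:-1] drops the empty string after the final newline.
    with open(text_file, encoding="utf-8") as f:
        lines = f.read().split("\n")[:-1]
    # Each line is "english<TAB>spanish"; wrap the Spanish side in [start] / [end] markers.
    text_pairs = []
    for line in lines:
        eng, spa = line.split("\t")
        spa = "[start] " + spa + " [end]"
        text_pairs.append((eng, spa))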

Build train/validate/test pair sets

We’ll stay close to the original tutorial so it’s easy to follow what’s the same and what’s different. One early difference is the choice of an off-the-shelf tokenizer from tiktoken, specifically “cl100k_base”: its byte-level BPE vocabulary of about 100k tokens can encode any Spanish text, accents included, and it’s fast.

random.shuffle(text_pairs)
num_val_samples = int(0.15 * len(text_pairs))
num_train_samples = len(text_pairs) - 2 * num_val_samples
train_pairs = text_pairs[:num_train_samples]
val_pairs = text_pairs[num_train_samples : num_train_samples + num_val_samples]
test_pairs = text_pairs[num_train_samples + num_val_samples :]

print(f"{len(text_pairs)} total pairs")
print(f"{len(train_pairs)} training pairs")
print(f"{len(val_pairs)} validation pairs")
print(f"{len(test_pairs)} test pairs")
118964 total pairs
83276 training pairs
17844 validation pairs
17844 test pairs
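The split sizes above follow from the slicing arithmetic: int() truncates 0.15 × 118964 = 17844.6 down to 17844, the training set gets everything except two such chunks, and the test slice receives the remainder, which by construction is the same size as the validation set. A quick check in plain Python, with the total hard-coded from the printed output:

```python
total_pairs = 118964  # from the output above

num_val_samples = int(0.15 * total_pairs)              # truncates 17844.6 -> 17844
num_train_samples = total_pairs - 2 * num_val_samples  # 118964 - 35688
num_test_samples = total_pairs - num_train_samples - num_val_samples

print(num_train_samples, num_val_samples, num_test_samples)  # 83276 17844 17844
```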
tokenizer = tiktoken.get_encoding("cl100k_base")

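We strip out punctuation to keep things simple and in line with the original tutorial. The inverted question mark ¿ is added because it isn't part of string.punctuation, while the square brackets are removed from the strip list so that our [start] and [end] markers are preserved.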

strip_chars = string.punctuation + "¿"
strip_chars = strip_chars.replace("[", "")
strip_chars = strip_chars.replace("]", "")

vocab_size = tokenizer.n_vocab
sequence_length = 20

def custom_standardization(input_string):
    # Lowercase, then delete every character listed in strip_chars.
    lowercase = input_string.lower()
    return re.sub(f"[{re.escape(strip_chars)}]", "", lowercase)

def tokenize_and_pad(text, tokenizer, max_length):
    # tiktoken's encode() returns a plain list of ints
    # (https://github.com/openai/tiktoken/blob/main/tiktoken/core.py#L81): truncate, then right-pad with 0.
    # Id 0 is cl100k_base's "!" token, which custom_standardization has already stripped.
    tokens = tokenizer.encode(text)[:max_length]
    return tokens + [0] * (max_length - len(tokens))

def format_dataset(eng, spa, tokenizer, sequence_length):
    eng = custom_standardization(eng)
    spa = custom_standardization(spa)
    eng = tokenize_and_pad(eng, tokenizer, sequence_length)
    spa = tokenize_and_pad(spa, tokenizer, sequence_length)
    return {
        "encoder_inputs": eng,       # sequence_length English tokens
        "decoder_inputs": spa[:-1],  # Spanish tokens fed to the decoder
        "target_output": spa[1:],    # the same tokens shifted one step ahead
    }

train_data = [format_dataset(eng, spa, tokenizer, sequence_length) for eng, spa in train_pairs]
val_data = [format_dataset(eng, spa, tokenizer, sequence_length) for eng, spa in val_pairs]
test_data = [format_dataset(eng, spa, tokenizer, sequence_length) for eng, spa in test_pairs]
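To see what the standardization, padding, and decoder shift do without downloading tiktoken's vocabulary, here is a minimal sketch that repeats the same steps with a hypothetical whitespace tokenizer (toy_encode, standing in for tokenizer.encode) and a short max_length of 8. The ids it produces are illustrative only, not cl100k_base ids.

```python
import re
import string

# Same strip list as above: ASCII punctuation plus "¿", minus the square brackets.
strip_chars = (string.punctuation + "¿").replace("[", "").replace("]", "")

def custom_standardization(input_string):
    return re.sub(f"[{re.escape(strip_chars)}]", "", input_string.lower())

# Hypothetical stand-in for tokenizer.encode: one new id per unseen word, starting at 1.
toy_vocab = {}
def toy_encode(text):
    return [toy_vocab.setdefault(word, len(toy_vocab) + 1) for word in text.split()]

def tokenize_and_pad(text, encode, max_length):
    tokens = encode(text)[:max_length]
    return tokens + [0] * (max_length - len(tokens))  # 0 is the padding id

spa = custom_standardization("[start] ¿Qué edad tienes? [end]")
print(spa)  # [start] qué edad tienes [end]

spa_ids = tokenize_and_pad(spa, toy_encode, 8)
decoder_inputs, target_output = spa_ids[:-1], spa_ids[1:]
print(decoder_inputs)  # [1, 2, 3, 4, 5, 0, 0]
print(target_output)   # [2, 3, 4, 5, 0, 0, 0]
```

At every position, the target is the token that follows the decoder input at that position, which is what lets the decoder learn next-token prediction.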

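At this point we’ve extracted the data, applied formatting, and tokenized the phrases with padding. The data is kept in train/validation/test lists with one dictionary per sentence pair: sequence_length (20) English tokens for the encoder, plus 19 Spanish decoder-input tokens and 19 target tokens, where the target is the decoder input shifted one position ahead. An entry looks like the following: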

## data selection example
print(train_data[135])
{'encoder_inputs': [72, 1390, 311, 617, 264, 3137, 449, 1461, 922, 856, 3938, 0, 0, 0, 0, 0, 0, 0, 0, 0], 'decoder_inputs': [29563, 60, 92820, 7669, 277, 390, 33013, 1645, 78993, 409, 9686, 65744, 510, 408, 60, 0, 0, 0, 0], 'target_output': [60, 92820, 7669, 277, 390, 33013, 1645, 78993, 409, 9686, 65744, 510, 408, 60, 0, 0, 0, 0, 0]}

Because the pairs are shuffled, the entry you see will differ, but it should look something like:

{'encoder_inputs': [9514, 265, 3339, 264, 2466, 16930, 1618, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], 'decoder_inputs': [29563, 60, 1826, 7206, 71086, 37116, 653, 16109, 1493, 54189, 510, 408, 60, 0, 0, 0, 0, 0, 0], 'target_output': [60, 1826, 7206, 71086, 37116, 653, 16109, 1493, 54189, 510, 408, 60, 0, 0, 0, 0, 0, 0, 0]}
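The relationship between the fields can be checked directly on the train_data[135] entry printed earlier. This sketch copies those lists by hand and confirms the lengths and the one-position shift; it also counts how many of the 19 target positions are padding, which matters later because the loss and accuracy are averaged over every position.

```python
# Lists copied from the train_data[135] output printed above.
encoder_inputs = [72, 1390, 311, 617, 264, 3137, 449, 1461, 922, 856, 3938] + [0] * 9
decoder_inputs = [29563, 60, 92820, 7669, 277, 390, 33013, 1645, 78993, 409, 9686, 65744, 510, 408, 60] + [0] * 4
target_output = [60, 92820, 7669, 277, 390, 33013, 1645, 78993, 409, 9686, 65744, 510, 408, 60] + [0] * 5

assert len(encoder_inputs) == 20                         # sequence_length
assert len(decoder_inputs) == len(target_output) == 19   # sequence_length - 1
# Teacher forcing: the target at position t is the decoder input at position t + 1.
assert decoder_inputs[1:] == target_output[:-1]

num_pad_targets = target_output.count(0)
print(f"{num_pad_targets} of {len(target_output)} target positions are padding")
# 5 of 19 target positions are padding
```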

Define Transformer components: Encoder, Decoder, Positional Embed

In many ways this is very similar to the original source: keras.ops calls become jnp, and Keras layers become nnx modules. Some module-specific arguments come and go, such as the rngs passed to most modules in the NNX version and decode=False in the MultiHeadAttention calls.

class TransformerEncoder(nnx.Module):
    def __init__(self, embed_dim: int, dense_dim: int, num_heads: int, rngs: nnx.Rngs, **kwargs):
        self.embed_dim = embed_dim
        self.dense_dim = dense_dim
        self.num_heads = num_heads

        self.attention = nnx.MultiHeadAttention(num_heads=num_heads,
                                          in_features=embed_dim,
                                          decode=False,
                                          rngs=rngs)
        self.dense_proj = nnx.Sequential(
                nnx.Linear(embed_dim, dense_dim, rngs=rngs),
                nnx.relu,
                nnx.Linear(dense_dim, embed_dim, rngs=rngs),
        )

        self.layernorm_1 = nnx.LayerNorm(embed_dim, rngs=rngs)
        self.layernorm_2 = nnx.LayerNorm(embed_dim, rngs=rngs)

    def __call__(self, inputs, mask=None):
        if mask is not None:
            # (batch, length) -> (batch, 1, 1, length), broadcastable to
            # (batch, num_heads, query_length, key_length): padded keys are masked out.
            padding_mask = mask[:, None, None, :].astype(jnp.int32)
        else:
            padding_mask = None

        attention_output = self.attention(
            inputs_q=inputs, inputs_k=inputs, inputs_v=inputs, mask=padding_mask, decode=False
        )
        proj_input = self.layernorm_1(inputs + attention_output)
        proj_output = self.dense_proj(proj_input)
        return self.layernorm_2(proj_input + proj_output)


class PositionalEmbedding(nnx.Module):
    def __init__(self, sequence_length: int, vocab_size: int, embed_dim: int, rngs: nnx.Rngs, **kwargs):
        self.token_embeddings = nnx.Embed(num_embeddings=vocab_size, features=embed_dim, rngs=rngs)
        self.position_embeddings = nnx.Embed(num_embeddings=sequence_length, features=embed_dim, rngs=rngs)
        self.sequence_length = sequence_length
        self.vocab_size = vocab_size
        self.embed_dim = embed_dim

    def __call__(self, inputs):
        length = inputs.shape[1]
        positions = jnp.arange(0, length)[None, :]
        embedded_tokens = self.token_embeddings(inputs)
        embedded_positions = self.position_embeddings(positions)
        return embedded_tokens + embedded_positions

    def compute_mask(self, inputs, mask=None):
        if mask is None:
            return None
        else:
            return jnp.not_equal(inputs, 0)

class TransformerDecoder(nnx.Module):
    def __init__(self, embed_dim: int, latent_dim: int, num_heads: int, rngs: nnx.Rngs, **kwargs):
        self.embed_dim = embed_dim
        self.latent_dim = latent_dim
        self.num_heads = num_heads
        self.attention_1 = nnx.MultiHeadAttention(num_heads=num_heads,
                                  in_features=embed_dim,
                                  decode=False,
                                  rngs=rngs)
        self.attention_2 = nnx.MultiHeadAttention(num_heads=num_heads,
                                  in_features=embed_dim,
                                  decode=False,
                                  rngs=rngs)

        self.dense_proj = nnx.Sequential(
                nnx.Linear(embed_dim, latent_dim, rngs=rngs),
                nnx.relu,
                nnx.Linear(latent_dim, embed_dim, rngs=rngs),
        )
        self.layernorm_1 = nnx.LayerNorm(embed_dim, rngs=rngs)
        self.layernorm_2 = nnx.LayerNorm(embed_dim, rngs=rngs)
        self.layernorm_3 = nnx.LayerNorm(embed_dim, rngs=rngs)

    def __call__(self, inputs, encoder_outputs, mask=None):
        causal_mask = self.get_causal_attention_mask(inputs.shape[1])
        if mask is not None:
            padding_mask = jnp.expand_dims(mask, axis=1).astype(jnp.int32)
            padding_mask = jnp.minimum(padding_mask, causal_mask)
        else:
            padding_mask = None
        attention_output_1 = self.attention_1(
            inputs_q=inputs, inputs_v=inputs, inputs_k=inputs,  mask=causal_mask
        )
        out_1 = self.layernorm_1(inputs + attention_output_1)

        attention_output_2 = self.attention_2( ## https://github.com/google/flax/blob/main/flax/nnx/nn/attention.py#L403-L405
            inputs_q=out_1,
            inputs_v=encoder_outputs,
            inputs_k=encoder_outputs,
            mask=padding_mask,
        )
        out_2 = self.layernorm_2(out_1 + attention_output_2)

        proj_output = self.dense_proj(out_2)
        return self.layernorm_3(out_2 + proj_output)

    def get_causal_attention_mask(self, sequence_length):
        i = jnp.arange(sequence_length)[:, None]
        j = jnp.arange(sequence_length)
        mask = (i >= j).astype(jnp.int32)
        mask = jnp.reshape(mask, (1, 1, sequence_length, sequence_length))
        return mask
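nnx.MultiHeadAttention expects a mask broadcastable to (batch, num_heads, query_length, key_length), where a 0/False entry blocks attention. The sketch below rebuilds the decoder's causal mask in plain jax.numpy and shows how a padding mask reshaped to (batch, 1, 1, length) would combine with it for self-attention. padding_mask is a hypothetical helper; the training loop in this tutorial never passes a mask, so in practice only the causal mask is applied.

```python
import jax.numpy as jnp

def causal_mask(length):
    # Same construction as TransformerDecoder.get_causal_attention_mask:
    # query position i may attend to key positions j <= i.
    i = jnp.arange(length)[:, None]
    j = jnp.arange(length)
    return (i >= j).astype(jnp.int32)[None, None, :, :]  # (1, 1, L, L)

def padding_mask(token_ids, pad_id=0):
    # Hypothetical helper: (batch, L) -> (batch, 1, 1, L), masking padded keys
    # for every head and every query position.
    return (token_ids != pad_id).astype(jnp.int32)[:, None, None, :]

tokens = jnp.array([[5, 7, 9, 0]])  # one sequence whose last position is padding
combined = jnp.minimum(padding_mask(tokens), causal_mask(4))  # broadcasts to (1, 1, 4, 4)
print(combined[0, 0])
```

Each row is one query position: the lower triangle comes from the causal mask, and the last column is zeroed because that key is padding.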

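Here we finally use our earlier encoder, decoder, and positional embedding classes to construct the model that we’ll train and later use for inference. A single PositionalEmbedding is shared by the encoder and decoder inputs, which is possible because English and Spanish go through the same tiktoken vocabulary.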

class TransformerModel(nnx.Module):
    def __init__(self, sequence_length: int, vocab_size: int, embed_dim: int, latent_dim: int, num_heads: int, dropout_rate: float, rngs: nnx.Rngs):
        self.sequence_length = sequence_length
        self.vocab_size = vocab_size
        self.embed_dim = embed_dim
        self.latent_dim = latent_dim
        self.num_heads = num_heads
        self.dropout_rate = dropout_rate

        self.encoder = TransformerEncoder(embed_dim, latent_dim, num_heads, rngs=rngs)
        self.positional_embedding = PositionalEmbedding(sequence_length, vocab_size, embed_dim, rngs=rngs)
        self.decoder = TransformerDecoder(embed_dim, latent_dim, num_heads, rngs=rngs)
        self.dropout = nnx.Dropout(rate=dropout_rate, rngs=rngs)
        self.dense = nnx.Linear(embed_dim, vocab_size, rngs=rngs)

    def __call__(self, encoder_inputs: jnp.ndarray, decoder_inputs: jnp.ndarray, mask: jnp.ndarray | None = None, deterministic: bool | None = None):
        x = self.positional_embedding(encoder_inputs)
        encoder_outputs = self.encoder(x, mask=mask)

        x = self.positional_embedding(decoder_inputs)
        decoder_outputs = self.decoder(x, encoder_outputs, mask=mask)
        # nnx.Dropout gives an explicit `deterministic` argument precedence over the mode set by
        # model.train()/model.eval(), so the default is None: dropout then follows the model's mode
        # (active after model.train(), disabled after model.eval()). Pass True/False only to override.
        decoder_outputs = self.dropout(decoder_outputs, deterministic=deterministic)

        logits = self.dense(decoder_outputs)
        return logits

Build out Data Loader and Training Definitions

Grain handles shuffling, batching, and parallel workers here. Tokenization could also be moved into a Grain transform, which may be more efficient, but doing it up front keeps it abundantly clear what’s happening: the pre-tokenized dictionaries go in, and batches of NumPy arrays come out, keyed by 'encoder_inputs', 'decoder_inputs', and 'target_output' just like the original dictionaries.

batch_size = 512 #set here for the loader and model train later on

class CustomPreprocessing(grain.MapTransform):
    def __init__(self):
        pass

    def map(self, data):
        return {
            "encoder_inputs": np.array(data["encoder_inputs"]),
            "decoder_inputs": np.array(data["decoder_inputs"]),
            "target_output": np.array(data["target_output"]),
        }

train_sampler = grain.IndexSampler(
    len(train_data),
    shuffle=True,
    seed=12,                        # Seed for reproducibility
    shard_options=grain.NoSharding(), # No sharding since it's a single-device setup
    num_epochs=1,                    # Iterate over the dataset for one epoch
)

val_sampler = grain.IndexSampler(
    len(val_data),
    shuffle=False,
    seed=12,
    shard_options=grain.NoSharding(),
    num_epochs=1,
)

train_loader = grain.DataLoader(
    data_source=train_data,
    sampler=train_sampler,                 # Sampler to determine how to access the data
    worker_count=4,                        # Number of child processes launched to parallelize the transformations
    worker_buffer_size=2,                  # Count of output batches to produce in advance per worker
    operations=[
        CustomPreprocessing(),
        grain.Batch(batch_size=batch_size, drop_remainder=True),
    ]
)

val_loader = grain.DataLoader(
    data_source=val_data,
    sampler=val_sampler,
    worker_count=4,
    worker_buffer_size=2,
    operations=[
        CustomPreprocessing(),
        grain.Batch(batch_size=batch_size),
    ]
)
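The batch counts follow from the dataset sizes printed earlier. With drop_remainder=True, the training loader yields only full batches, which is where train_total_steps = 162 comes from. val_loader keeps a smaller final batch, and because jitted functions are compiled per input shape, that last batch costs one extra compilation of eval_step the first time it is seen. A quick check:

```python
num_train, num_val, batch_size = 83276, 17844, 512  # sizes from the outputs above

# train_loader uses drop_remainder=True, so the partial final batch is discarded.
train_batches, train_dropped = divmod(num_train, batch_size)
# val_loader keeps its final partial batch.
val_full_batches, val_tail = divmod(num_val, batch_size)

print(train_batches, train_dropped)    # 162 332
print(val_full_batches + 1, val_tail)  # 35 436
```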

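Optax doesn’t have an identical counterpart to the loss function the source tutorial uses, but softmax cross-entropy works well here. optax.softmax_cross_entropy_with_integer_labels takes the raw logits from the model’s final Linear layer and integer token ids as labels; if you use optax.softmax_cross_entropy instead, one-hot encode the labels first, e.g. with jax.nn.one_hot(labels, vocab_size).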

def compute_loss(logits, labels):
    loss = optax.softmax_cross_entropy_with_integer_labels(logits=logits, labels=labels)
    return jnp.mean(loss)
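compute_loss above, and nnx.metrics.Accuracy in eval_step, average over all 19 target positions, including the padding at the end of short sentences, so the model is also scored on predicting 0 after the sentence ends. One option, not used in this tutorial, is to mask those positions out. A minimal sketch in jax.numpy; masked_loss_and_accuracy is a hypothetical helper, and it assumes 0 only ever appears as padding, which holds here because standardization strips "!", the character behind cl100k_base token 0.

```python
import jax
import jax.numpy as jnp


def masked_loss_and_accuracy(logits, labels, pad_id=0):
    """Cross-entropy and accuracy averaged over non-padding target positions only.

    logits: (batch, length, vocab) float array; labels: (batch, length) int array.
    """
    log_probs = jax.nn.log_softmax(logits, axis=-1)
    # Negative log-probability of the correct token at every position.
    nll = -jnp.take_along_axis(log_probs, labels[..., None], axis=-1)[..., 0]
    mask = (labels != pad_id).astype(logits.dtype)  # 1.0 for real tokens, 0.0 for padding
    num_real = jnp.maximum(mask.sum(), 1.0)        # guard against an all-padding batch
    loss = (nll * mask).sum() / num_real
    correct = (jnp.argmax(logits, axis=-1) == labels).astype(logits.dtype)
    accuracy = (correct * mask).sum() / num_real
    return loss, accuracy


# Toy batch: two real targets (ids 2 and 1) followed by one padding position.
logits = jnp.zeros((1, 3, 4))  # uniform scores, so argmax predicts id 0 everywhere
labels = jnp.array([[2, 1, 0]])
loss, accuracy = masked_loss_and_accuracy(logits, labels)
unmasked_accuracy = (jnp.argmax(logits, axis=-1) == labels).mean()
print(float(loss), float(accuracy), float(unmasked_accuracy))
```

On the toy batch the masked loss is log 4 (uniform over 4 ids) and the masked accuracy is 0, while the unmasked accuracy is 1/3, earned entirely by the padding position. With real data, a masked accuracy would likely come out lower than the figures reported below, since the easy padding predictions no longer count.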

While the original tutorial leaves most of the model and training details to Keras, we make them explicit here in our step functions, which are later used in train_one_epoch and evaluate_model.

@nnx.jit
def train_step(model, optimizer, batch):
    def loss_fn(model, train_encoder_input, train_decoder_input, train_target_input):
        logits = model(train_encoder_input, train_decoder_input)
        loss = compute_loss(logits, train_target_input)
        return loss

    grad_fn = nnx.value_and_grad(loss_fn)
    loss, grads = grad_fn(model, jnp.array(batch["encoder_inputs"]), jnp.array(batch["decoder_inputs"]), jnp.array(batch["target_output"]))
    optimizer.update(grads)
    return loss

@nnx.jit
def eval_step(model, batch, eval_metrics):
    logits = model(jnp.array(batch["encoder_inputs"]), jnp.array(batch["decoder_inputs"]))
    loss = compute_loss(logits, jnp.array(batch["target_output"]))
    labels = jnp.array(batch["target_output"])

    eval_metrics.update(
        loss=loss,
        logits=logits,
        labels=labels,
    )
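Stripped of NNX, train_step follows the usual JAX pattern: compute the loss and its gradient in one value_and_grad pass, then hand the gradient to an optimizer. A minimal plain-JAX sketch of that pattern on a hypothetical one-weight regression problem, with hand-written gradient descent standing in for optax.adamw and a returned weight standing in for NNX's in-place model update:

```python
import jax
import jax.numpy as jnp

LEARNING_RATE = 0.1

def loss_fn(w, x, y):
    # Mean squared error of a one-weight linear model.
    return jnp.mean((w * x - y) ** 2)

@jax.jit
def sgd_step(w, x, y):
    # One pass returns both the loss and d(loss)/dw, like nnx.value_and_grad in train_step.
    loss, grad = jax.value_and_grad(loss_fn)(w, x, y)
    # Plain gradient descent stands in for optimizer.update(grads).
    return w - LEARNING_RATE * grad, loss

x = jnp.array([1.0, 2.0, 3.0])
y = 3.0 * x  # hypothetical target: the true weight is 3
w = jnp.array(0.0)
losses = []
for _ in range(20):
    w, loss = sgd_step(w, x, y)
    losses.append(float(loss))
```

The design difference is that NNX mutates the model and optimizer state in place, while plain JAX returns the new parameters explicitly.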

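Here, nnx.MultiMetric accumulates the evaluation loss and accuracy over each pass through the validation set, while plain dictionaries hold the history: the per-step training loss and the per-epoch evaluation results (stored under test_* keys, although they are computed on the validation set).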

eval_metrics = nnx.MultiMetric(
    loss=nnx.metrics.Average('loss'),
    accuracy=nnx.metrics.Accuracy(),
)

train_metrics_history = {
    "train_loss": [],
}

eval_metrics_history = {
    "test_loss": [],
    "test_accuracy": [],
}
## Hyperparameters
rng = nnx.Rngs(0)
embed_dim = 256
latent_dim = 2048
num_heads = 8
dropout_rate = 0.5
vocab_size = tokenizer.n_vocab
sequence_length = 20
learning_rate = 1.5e-3
num_epochs = 10
bar_format = "{desc}[{n_fmt}/{total_fmt}]{postfix} [{elapsed}<{remaining}]"
train_total_steps = len(train_data) // batch_size

def train_one_epoch(epoch):
    model.train()  # Set model to training mode (sets deterministic=False on Dropout layers)
    with tqdm.tqdm(
        desc=f"[train] epoch: {epoch}/{num_epochs}, ",
        total=train_total_steps,
        bar_format=bar_format,
        leave=True,
    ) as pbar:
        for batch in train_loader:
            loss = train_step(model, optimizer, batch)
            train_metrics_history["train_loss"].append(loss.item())
            pbar.set_postfix({"loss": loss.item()})
            pbar.update(1)


def evaluate_model(epoch):
    # Compute the metrics on the validation set after each training epoch.
    model.eval()  # Set model to evaluation mode (sets deterministic=True on Dropout layers)

    eval_metrics.reset()  # Reset the eval metrics
    for val_batch in val_loader:
        eval_step(model, val_batch, eval_metrics)

    for metric, value in eval_metrics.compute().items():
        eval_metrics_history[f'test_{metric}'].append(value)

    print(f"[test] epoch: {epoch + 1}/{num_epochs}")
    print(f"- total loss: {eval_metrics_history['test_loss'][-1]:0.4f}")
    print(f"- Accuracy: {eval_metrics_history['test_accuracy'][-1]:0.4f}")
model = TransformerModel(sequence_length, vocab_size, embed_dim, latent_dim, num_heads, dropout_rate, rngs=rng)
optimizer = nnx.Optimizer(model, optax.adamw(learning_rate))

Start the Training!

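With our data loaders, model, optimizer, and per-epoch train/eval functions set up, it’s time to finally press go. On an RTX 3090 with batch_size set to 512, this uses roughly 19 GB of VRAM, and each epoch after the first takes roughly 18 seconds (the first, at about 27 seconds in the run below, likely includes JIT compilation).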
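When reading a VRAM figure, keep in mind that on GPU JAX preallocates 75% of the device's memory by default when the first JAX operation runs, which on a 24 GB card is about 18 GB regardless of what the model actually needs. To measure the real footprint, preallocation can be turned off with JAX's documented environment variables at the top of a fresh session, as in this sketch:

```python
import os

# Set before JAX touches the GPU; in practice, before `import jax` in a fresh session.
# Allocate GPU memory on demand instead of reserving a large block up front:
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"

# Alternative: keep preallocation but change the reserved fraction (here 50%).
# os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = ".50"
```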

for epoch in range(num_epochs):
    train_one_epoch(epoch)
    evaluate_model(epoch)
[test] epoch: 1/10
- total loss: 1.9655
- Accuracy: 0.6774
[test] epoch: 2/10
- total loss: 1.1961
- Accuracy: 0.7903
[test] epoch: 3/10
- total loss: 1.0054
- Accuracy: 0.8167
[test] epoch: 4/10
- total loss: 0.9351
- Accuracy: 0.8289
[test] epoch: 5/10
- total loss: 0.8976
- Accuracy: 0.8369
[test] epoch: 6/10
- total loss: 0.8876
- Accuracy: 0.8396
[test] epoch: 7/10
- total loss: 0.8857
- Accuracy: 0.8426
[test] epoch: 8/10
- total loss: 0.8959
- Accuracy: 0.8427
[test] epoch: 9/10
- total loss: 0.9128
- Accuracy: 0.8434
[test] epoch: 10/10
- total loss: 0.9227
- Accuracy: 0.8452
[train] epoch: 0/10, [160/162], loss=1.98 [00:27<00:00]
[train] epoch: 1/10, [160/162], loss=1.16 [00:18<00:00]
[train] epoch: 2/10, [160/162], loss=0.846 [00:18<00:00]
[train] epoch: 3/10, [160/162], loss=0.695 [00:18<00:00]
[train] epoch: 4/10, [160/162], loss=0.593 [00:18<00:00]
[train] epoch: 5/10, [160/162], loss=0.511 [00:18<00:00]
[train] epoch: 6/10, [160/162], loss=0.454 [00:18<00:00]
[train] epoch: 7/10, [160/162], loss=0.421 [00:18<00:00]
[train] epoch: 8/10, [160/162], loss=0.371 [00:18<00:00]
[train] epoch: 9/10, [160/162], loss=0.341 [00:18<00:00]

We can then plot the training loss over the training steps. The log scale on the y-axis comes in handy here; otherwise it’s hard to appreciate the progress after 1000 steps or so.

import matplotlib.pyplot as plt

plt.plot(train_metrics_history["train_loss"], label="Loss value during the training")
plt.yscale('log')
plt.legend()
[Figure: training loss per step, log-scale y-axis (legend: "Loss value during the training")]

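And the eval-set loss and accuracy. Accuracy continues to rise, though it’s hard-earned progress after about the 5th epoch (0.8369 at epoch 5, 0.8452 at epoch 10). The eval loss bottoms out at 0.8857 in epoch 7 and then climbs to 0.9227 by epoch 10 while the training loss keeps falling, so it’s fair to say the model starts overfitting somewhere around epochs 5 to 7.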

fig, axs = plt.subplots(1, 2, figsize=(10, 10))
axs[0].set_title("Loss value on eval set")
axs[0].plot(eval_metrics_history["test_loss"])
axs[1].set_title("Accuracy on eval set")
axs[1].plot(eval_metrics_history["test_accuracy"])
[Figure: two panels, "Loss value on eval set" and "Accuracy on eval set", each plotted per epoch]
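The overfitting argument rests on comparing training and eval statistics, but train_metrics_history holds one loss per step while the eval history holds one value per epoch. A hypothetical helper that averages the per-step losses per epoch puts the two on the same x-axis (keep in mind that the training loss is measured with dropout active):

```python
def epoch_means(step_losses, steps_per_epoch):
    """Average a flat list of per-step losses into one value per epoch."""
    means = []
    for start in range(0, len(step_losses), steps_per_epoch):
        chunk = step_losses[start:start + steps_per_epoch]
        means.append(sum(chunk) / len(chunk))
    return means

print(epoch_means([4.0, 2.0, 3.0, 1.0, 0.5], steps_per_epoch=2))  # [3.0, 2.0, 0.5]
```

With the tutorial's names, epoch_means(train_metrics_history["train_loss"], train_total_steps) should give ten values that can be plotted next to eval_metrics_history["test_loss"], assuming each epoch yields exactly train_total_steps batches.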

Use Model for Inference

After all that, we have what we were working for: a trained model we can save and load for inference. For anyone who has been using LLMs, this pattern may look familiar: the input sentence is tokenized into an array, and the output is generated token by token, each step feeding the tokens produced so far back into the decoder. While many recent LLMs are decoder-only, this is an encoder-decoder architecture with the English-to-Spanish task baked in.

We’ve changed a couple of things from the source decode function. Because of the tokenizer used, [start] and [end] are no longer single tokens: [start] is [29563, 60] = "[start" + "]" and [end] is [58308, 60] = "[end" + "]". So decoding starts from the single token "[start", and the loop stops once the decoded string ends with "[end]" rather than testing whether the last sampled token is an end token. Otherwise, the main change is that the input is assumed to be a single sentence rather than a batch.
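Because "[end]" spans several tokens (in the train_data[135] entry shown earlier, the trailing " [end]" appears as the three ids 510, 408, 60), no single sampled token equals "[end]", but the suffix of the accumulated string does. A minimal sketch of that stopping rule; append_until_end and the decoded pieces are hypothetical, and the exact split of " [end]" into " [", "end", "]" is assumed for illustration.

```python
def append_until_end(pieces, start="[start"):
    """Concatenate decoded token strings, stopping once the text ends with "[end]"."""
    decoded = start
    for piece in pieces:
        decoded += piece
        if decoded.endswith("[end]"):
            break
    return decoded

# Hypothetical decoded pieces; anything after "[end]" is never appended.
pieces = ["]", " estamos", " felices", " [", "end", "]", " extra"]
result = append_until_end(pieces)
print(result)  # [start] estamos felices [end]
```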

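def decode_sequence(input_sentence):
    model.eval()  # evaluation mode (sets deterministic=True on Dropout layers)

    input_sentence = custom_standardization(input_sentence)
    tokenized_input_sentence = tokenize_and_pad(input_sentence, tokenizer, sequence_length)
    encoder_input = jnp.array([tokenized_input_sentence])

    # Decoder inputs are sequence_length - 1 tokens long (see format_dataset).
    max_decoder_length = sequence_length - 1
    # Keep the generated token ids directly instead of re-encoding the decoded string
    # each step, so the prediction we read always sits at the last filled position.
    target_tokens = tokenizer.encode("[start")
    decoded_sentence = "[start"
    for _ in range(max_decoder_length - len(target_tokens) + 1):
        padded_target = target_tokens + [0] * (max_decoder_length - len(target_tokens))
        predictions = model(encoder_input, jnp.array([padded_target]))

        sampled_token_index = int(jnp.argmax(predictions[0, len(target_tokens) - 1, :]))
        target_tokens.append(sampled_token_index)
        decoded_sentence += tokenizer.decode([sampled_token_index])

        if decoded_sentence.endswith("[end]"):
            break
    return decoded_sentence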
test_eng_texts = [pair[0] for pair in test_pairs]
test_result_pairs = []
for _ in range(10):
    input_sentence = random.choice(test_eng_texts)
    translated = decode_sequence(input_sentence)

    test_result_pairs.append(f"[Input]: {input_sentence} [Translation]: {translated}")

Test Results

For the model and the data, not too shabby: it’s definitely Spanish-ish. Though when ‘making’ friends, please don’t confuse ‘hacer’ (to make) with ‘comer’ (to eat).

for i in test_result_pairs:
    print(i)
[Input]: We're both way too busy to help you right now. [Translation]: [start] los dos estamos demasiado para ayudar esta mañana [end]
[Input]: Have you eaten dinner? [Translation]: [start] has comido la cena [end]
[Input]: That is the poet I met in Paris. [Translation]: [start] ese es el poeta que conocí en parís [end]
[Input]: It doesn't make sense to me. [Translation]: [start] no me hace falta sentido [end]
[Input]: We're happy. [Translation]: [start] estamos felices [end]
[Input]: What about me? [Translation]: [start] de qué me [end]
[Input]: Make a decision and make it with the confidence that you are right. [Translation]: [start] haz una decisión y tomará la confianza en el confian [end]
[Input]: Put some salt on your meat. [Translation]: [start] ponte algo de sal [end]
[Input]: Tom's deaf. [Translation]: [start] tom es sordo [end]
[Input]: How old are your brothers and sisters? [Translation]: [start] qué edad son tus hermanos [end]

Example output from the above cell:

  [Input]: We're going to have a baby. [Translation]: [start] nosotros vamos a tener un bebé [end]
  [Input]: You drive too fast. [Translation]: [start] conducís demasiado rápido [end]
  [Input]: Let me know if there's anything I can do. [Translation]: [start] déjame saber si hay cualquier cosa que yo pueda hacer [end]
  [Input]: Let's go to the kitchen. [Translation]: [start] vayamos a la cocina [end]
  [Input]: Tom gasped. [Translation]: [start] tom se quedó sin aliento [end]
  [Input]: I was just hanging out with some of my friends. [Translation]: [start] estaba escquieto con algunos de mi amigos [end]
  [Input]: Tom is in the bathroom. [Translation]: [start] tom está en el cuarto de baño [end]
  [Input]: I feel safe here. [Translation]: [start] me siento segura [end]
  [Input]: I'm going to need you later. [Translation]: [start] me voy a necesitar después [end]
  [Input]: A party is a good place to make friends with other people. [Translation]: [start] una fiesta es un buen lugar de comer amigos con otras personas [end]