Variational autoencoder (VAE) and debugging in JAX#

This tutorial explores a simplified version of a generative model called the Variational Autoencoder (VAE), using the scikit-learn digits dataset, and expands on what we learned in Getting started with JAX. Along the way, you’ll learn more about how JAX’s JIT compilation (jax.jit) actually works, and what this means for debugging JAX programs, as we identify what can go wrong during model training.

If you are new to JAX for AI, check out the first tutorial, which explains how to build a simple neural network with Flax and Optax, and covers JAX’s key features, including the NumPy-style interface with jax.numpy, JIT compilation with jax.jit, automatic vectorization with jax.vmap, and automatic differentiation with jax.grad.

Loading the data#

As before, this example uses the well-known, small and self-contained scikit-learn digits dataset:

from sklearn.datasets import load_digits
from sklearn.model_selection import train_test_split
import jax.numpy as jnp

digits = load_digits()

splits = train_test_split(digits.images, random_state=0)

images_train, images_test = map(jnp.asarray, splits)

print(f"{images_train.shape=}")
print(f"{images_test.shape=}")
images_train.shape=(1347, 8, 8)
images_test.shape=(450, 8, 8)

The dataset comprises 1,797 images of handwritten digits, each represented by an 8x8 pixel grid, along with their corresponding labels. For a visualization of this data, refer to loading the data in the previous tutorial.

Defining the VAE with Flax#

Previously, we learned how to use Flax NNX to create a simple feed-forward neural network trained for classification with an architecture that looked roughly like this:

import jax
import jax.numpy as jnp
from flax import nnx

class SimpleNN(nnx.Module):

  def __init__(self, n_features=64, n_hidden=100, n_targets=10, *, rngs: nnx.Rngs):
    self.layer1 = nnx.Linear(n_features, n_hidden, rngs=rngs)
    self.layer2 = nnx.Linear(n_hidden, n_hidden, rngs=rngs)
    self.layer3 = nnx.Linear(n_hidden, n_targets, rngs=rngs)

  def __call__(self, x: jax.Array) -> jax.Array:
    x = nnx.selu(self.layer1(x))
    x = nnx.selu(self.layer2(x))
    return self.layer3(x)

This kind of network has one output per class, and the loss function is designed such that once the model is trained, the output corresponding to the correct class would return the strongest signal, thus predicting the correct label in upwards of 95% of cases.
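
For instance, the predicted label is simply the index of the strongest output. A minimal sketch (not in the original tutorial), using the SimpleNN defined above, untrained here, so only the shapes are meaningful:

clf = SimpleNN(rngs=nnx.Rngs(0))
x = jnp.zeros((1, 64))                          # one flattened 8x8 image
predicted_label = jnp.argmax(clf(x), axis=-1)   # index of the strongest of the 10 outputs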

To create a VAE with Flax NNX, we will use similar building blocks - subclassing flax.nnx.Module, stacking flax.nnx.Linear layers, and adding a rectified linear unit activation function (flax.nnx.relu). A VAE maps the input data into the parameters of a probability distribution (mean, std), and the output is a small probabilistic model representing the data.

Note that while the classic VAE is generally based on convolutional layers, this example uses linear layers for simplicity.

The sub-network that produces this probabilistic encoding is the Encoder:

class Encoder(nnx.Module):
  def __init__(self, input_size: int, intermediate_size: int, output_size: int,
               *, rngs: nnx.Rngs):
    self.rngs = rngs
    self.linear = nnx.Linear(input_size, intermediate_size, rngs=rngs)
    self.linear_mean = nnx.Linear(intermediate_size, output_size, rngs=rngs)
    self.linear_std = nnx.Linear(intermediate_size, output_size, rngs=rngs)

  def __call__(self, x: jax.Array) -> tuple[jax.Array, jax.Array, jax.Array]:
    x = self.linear(x)
    x = jax.nn.relu(x)

    mean = self.linear_mean(x)
    # Exponentiate so the predicted standard deviation is always positive.
    std = jnp.exp(self.linear_std(x))

    # Draw z using the reparameterization trick (see below).
    key = self.rngs.noise()
    z = mean + std * jax.random.normal(key, mean.shape)
    return z, mean, std

The idea here is that mean and std define a low-dimensional probability distribution over a latent space, and that z is a draw from this latent space that represents the training data.
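
This sampling step is the reparameterization trick: rather than drawing z directly from N(mean, std²), the encoder draws noise ε ~ N(0, 1) and computes z = mean + std * ε. This keeps z differentiable with respect to mean and std, so gradients can flow back through the sampling step during training.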

To ensure that this latent distribution faithfully represents the actual data, define a Decoder that maps back to the input space as follows:

class Decoder(nnx.Module):
  def __init__(self, input_size: int, intermediate_size: int, output_size: int,
               *, rngs: nnx.Rngs):
    self.linear1 = nnx.Linear(input_size, intermediate_size, rngs=rngs)
    self.linear2 = nnx.Linear(intermediate_size, output_size, rngs=rngs)

  def __call__(self, z: jax.Array) -> jax.Array:
    z = self.linear1(z)
    z = jax.nn.relu(z)
    logits = self.linear2(z)
    return logits
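
Note that the Decoder returns raw logits rather than pixel values: the loss function below consumes logits directly, and we will apply a sigmoid to map them back to pixel space whenever we want actual images.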

Now, define the VAE model (again by subclassing flax.nnx.Module) by combining Encoder and Decoder in a single network (VAE).

The model returns both the reconstructed image and the internal latent space model:

class VAE(nnx.Module):
  def __init__(
    self,
    image_shape: tuple[int, int],
    hidden_size: int,
    latent_size: int,
    *,
    rngs: nnx.Rngs
  ):
    self.image_shape = image_shape
    self.latent_size = latent_size
    input_size = image_shape[0] * image_shape[1]
    self.encoder = Encoder(input_size, hidden_size, latent_size, rngs=rngs)
    self.decoder = Decoder(latent_size, hidden_size, input_size, rngs=rngs)

  def __call__(self, x: jax.Array) -> tuple[jax.Array, jax.Array, jax.Array]:
    x = jax.vmap(jax.numpy.ravel)(x)  # flatten
    z, mean, std = self.encoder(x)
    logits = self.decoder(z)
    logits = jnp.reshape(logits, (-1, *self.image_shape))
    return logits, mean, std
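
As a quick sanity check (illustrative, not part of the original tutorial), we can instantiate the model and confirm its output shapes: a batch of 8x8 images goes in, and reconstruction logits plus per-example latent parameters come out.

# Illustrative shape check for the VAE defined above.
vae = VAE(image_shape=(8, 8), hidden_size=32, latent_size=8, rngs=nnx.Rngs(0, noise=1))
logits, mean, std = vae(jnp.zeros((4, 8, 8)))
print(logits.shape, mean.shape, std.shape)  # (4, 8, 8) (4, 8) (4, 8)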

Next, we need to define the loss function. There are two properties of the model that we want to ensure:

  1. The logits output faithfully reconstructs the input image.

  2. The model represented by mean and std faithfully represents the “true” latent distribution.

The VAE uses a loss function based on the evidence lower bound (ELBO) to quantify these two goals in a single loss value:

def vae_loss(model: VAE, x: jax.Array):
  logits, mean, std = model(x)
  kl_loss = jnp.mean(0.5 * jnp.mean(
      -jnp.log(std ** 2) - 1.0 + std ** 2 + mean ** 2, axis=-1))
  reconstruction_loss = jnp.mean(
    optax.sigmoid_binary_cross_entropy(logits, x)
  )
  return reconstruction_loss + 0.1 * kl_loss
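
The kl_loss term here is the closed-form KL divergence between the diagonal Gaussian produced by the encoder and a standard normal prior, which for each example is

$$\mathrm{KL}\left(\mathcal{N}(\mu, \sigma^2) \,\|\, \mathcal{N}(0, I)\right) = \frac{1}{2} \sum_i \left(\sigma_i^2 + \mu_i^2 - 1 - \log \sigma_i^2\right)$$

(the code averages rather than sums over the latent dimensions, and then averages over the batch). The reconstruction_loss term is a sigmoid binary cross entropy between the decoder logits and the input pixels.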

Now all that’s left is to:

  • Instantiate the VAE model.

  • Select optax.adam (the Adam optimizer), and instantiate it with flax.nnx.Optimizer, which manages the training state.

  • Define the train_step, using flax.nnx.value_and_grad to compute the loss and gradients, and updating the model’s parameters with the optimizer.

  • Use the flax.nnx.jit transformation decorator to trace the train_step function for just-in-time compilation.

  • Run the training loop.

import optax

model = VAE(
  image_shape=(8, 8),
  hidden_size=32,
  latent_size=8,
  rngs=nnx.Rngs(0, noise=1),
)

optimizer = nnx.Optimizer(model, optax.adam(1e-3))

@nnx.jit
def train_step(model: VAE, optimizer: nnx.Optimizer, x: jax.Array):
  loss, grads = nnx.value_and_grad(vae_loss)(model, x)
  optimizer.update(grads)
  return loss

for epoch in range(2001):
  loss = train_step(model, optimizer, images_train)
  if epoch % 500 == 0:
    print(f'Epoch {epoch} loss: {loss}')
Epoch 0 loss: 10175.9921875
Epoch 500 loss: nan
Epoch 1000 loss: nan
Epoch 1500 loss: nan
Epoch 2000 loss: nan

Notice in the output that something has gone wrong - the loss value has become NaN after some number of iterations.

Debugging NaNs in JAX#

Despite our best efforts, the VAE model is producing NaNs. What can we do?

JAX offers a number of debugging approaches for situations like this, outlined in JAX’s Debugging runtime values guide. (There is also the Introduction to debugging tutorial you may find useful.)

In this case, we can use the jax.debug_nans configuration to check where the NaN value is arising.
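
(This check can also be enabled globally with jax.config.update("jax_debug_nans", True); here we use jax.debug_nans as a context manager so that it applies only within the block.)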

model = VAE(
  image_shape=(8, 8),
  hidden_size=32,
  latent_size=8,
  rngs=nnx.Rngs(0, noise=1),
)

optimizer = nnx.Optimizer(model, optax.adam(1e-3))

with jax.debug_nans(True):
  for epoch in range(2001):
    train_step(model, optimizer, images_train)
Invalid nan value encountered in the output of a jax.jit function. Calling the de-optimized version.
---------------------------------------------------------------------------
FloatingPointError                        Traceback (most recent call last)
Cell In[8], line 12
     10 with jax.debug_nans(True):
     11   for epoch in range(2001):
---> 12     train_step(model, optimizer, images_train)

File ~/checkouts/readthedocs.org/user_builds/jax-ai-stack/envs/latest/lib/python3.12/site-packages/flax/nnx/transforms/compilation.py:350, in jit.<locals>.jit_wrapper(*args, **kwargs)
    340 with graph.update_context(jit_wrapper):
    341   pure_args, pure_kwargs = extract.to_tree(
    342     (args, kwargs),
    343     prefix=(in_shardings, kwarg_shardings)
   (...)    348     ctxtag=jit_wrapper,
    349   )
--> 350   pure_args_out, pure_kwargs_out, pure_out = jitted_fn(
    351     *pure_args, **pure_kwargs
    352   )
    353   _args_out, _kwargs_out, out = extract.from_tree(
    354     (pure_args_out, pure_kwargs_out, pure_out),
    355     merge_fn=_jit_merge_fn,
    356     is_inner=False,
    357     ctxtag=jit_wrapper,
    358   )
    359 return out

    [... skipping hidden 4 frame]

File ~/checkouts/readthedocs.org/user_builds/jax-ai-stack/envs/latest/lib/python3.12/site-packages/flax/nnx/transforms/compilation.py:129, in JitFn.__call__(self, *pure_args, **pure_kwargs)
    121 def __call__(self, *pure_args, **pure_kwargs):
    122   args, kwargs = extract.from_tree(
    123     (pure_args, pure_kwargs),
    124     merge_fn=_jit_merge_fn,
    125     ctxtag=self.ctxtag,
    126     is_inner=True,
    127   )
--> 129   out = self.f(*args, **kwargs)
    131   args_out, kwargs_out = extract.clear_non_graph_nodes((args, kwargs))
    132   pure_args_out, pure_kwargs_out, pure_out = extract.to_tree(
    133     (args_out, kwargs_out, out),
    134     prefix=(self.in_shardings, self.kwarg_shardings, self.out_shardings),
    135     ctxtag=self.ctxtag,
    136     split_fn=_jit_split_fn,
    137   )

Cell In[7], line 14, in train_step(model, optimizer, x)
     12 @nnx.jit
     13 def train_step(model: VAE, optimizer: nnx.Optimizer, x: jax.Array):
---> 14   loss, grads = nnx.value_and_grad(vae_loss)(model, x)
     15   optimizer.update(grads)
     16   return loss

File ~/checkouts/readthedocs.org/user_builds/jax-ai-stack/envs/latest/lib/python3.12/site-packages/flax/nnx/graph.py:1817, in UpdateContextManager.__call__.<locals>.update_context_manager_wrapper(*args, **kwargs)
   1814 @functools.wraps(f)
   1815 def update_context_manager_wrapper(*args, **kwargs):
   1816   with self:
-> 1817     return f(*args, **kwargs)

File ~/checkouts/readthedocs.org/user_builds/jax-ai-stack/envs/latest/lib/python3.12/site-packages/flax/nnx/transforms/autodiff.py:163, in _grad_general.<locals>.grad_wrapper(***failed resolving arguments***)
    151 pure_args = extract.to_tree(
    152   args, prefix=arg_filters, split_fn=_grad_split_fn, ctxtag='grad'
    153 )
    155 gradded_fn = transform(
    156   GradFn(f, has_aux, nondiff_states),
    157   argnums=jax_argnums,
   (...)    160   allow_int=allow_int,
    161 )
--> 163 fn_out = gradded_fn(*pure_args)
    165 def process_grads(grads):
    166   return jax.tree.map(
    167     lambda x: x.state if isinstance(x, extract.NodeStates) else x,
    168     grads,
    169     is_leaf=lambda x: isinstance(x, extract.NodeStates),
    170   )

    [... skipping hidden 20 frame]

File ~/checkouts/readthedocs.org/user_builds/jax-ai-stack/envs/latest/lib/python3.12/site-packages/jax/_src/pjit.py:219, in _python_pjit_helper(fun, jit_info, *args, **kwargs)
    217 except dispatch.InternalFloatingPointError as e:
    218   if getattr(fun, '_apply_primitive', False):
--> 219     raise FloatingPointError(f"invalid value ({e.ty}) encountered in {fun.__qualname__}") from None
    220   dispatch.maybe_recursive_nan_check(e, fun, args, kwargs)
    222 if p.attrs_tracked:

FloatingPointError: invalid value (nan) encountered in dot_general
When differentiating the code at the top of the callstack:
/home/docs/checkouts/readthedocs.org/user_builds/jax-ai-stack/envs/latest/lib/python3.12/site-packages/flax/nnx/nn/linear.py:383:8 (Linear.__call__)

The output here is complicated, because the function we’re evaluating is complicated. The key to “deciphering” this traceback is to look for the places where the traceback touches our implementation.

In particular, the key frames are the ones that touch our train_step implementation, at the gradient computation:

Cell In[7], line 14, in train_step(model, optimizer, x)
---> 14   loss, grads = nnx.value_and_grad(vae_loss)(model, x)

and the final error message, which identifies the operation that produced the invalid value during differentiation:

FloatingPointError: invalid value (nan) encountered in dot_general
When differentiating the code at the top of the callstack:
.../flax/nnx/nn/linear.py:383:8 (Linear.__call__)

This suggests that computing the gradient of the loss is producing NaN values. Typically, this would come about when the gradient itself is for some reason diverging.
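
As an aside: if the gradients themselves were genuinely diverging, a common mitigation is gradient clipping. Below is a minimal sketch chaining optax.clip_by_global_norm with the optimizer; as we’ll see, the root cause here is elsewhere, so we won’t actually need it:

# Hypothetical mitigation (sketch only): clip the global gradient norm before Adam.
tx = optax.chain(optax.clip_by_global_norm(1.0), optax.adam(1e-3))
optimizer = nnx.Optimizer(model, tx)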

A diverging gradient suggests that something in the loss function may be amiss. Previously, we had loss=NaN by iteration 500. Let’s print the progress up to this point:

model = VAE(
  image_shape=(8, 8),
  hidden_size=32,
  latent_size=8,
  rngs=nnx.Rngs(0, noise=1),
)

optimizer = nnx.Optimizer(model, optax.adam(1e-3))

for epoch in range(501):
  loss = train_step(model, optimizer, images_train)
  if epoch % 50 == 0:
    print(f'Epoch {epoch} loss: {loss}')
Epoch 0 loss: 10175.9921875
Epoch 50 loss: -15.788647651672363
Epoch 100 loss: -134.56039428710938
Epoch 150 loss: -1761.32275390625
Epoch 200 loss: nan
Epoch 250 loss: nan
Epoch 300 loss: nan
Epoch 350 loss: nan
Epoch 400 loss: nan
Epoch 450 loss: nan
Epoch 500 loss: nan

It looks like the loss value is decreasing toward negative infinity until the point where the values are no longer well-represented by floating point math.

At this point, we may wish to inspect the values within the loss function itself to see where the diverging loss might be coming from.

In a typical Python program, we could do this by inserting a print statement or a breakpoint in the loss function. That might look something like this:

def vae_loss(model: VAE, x: jax.Array):
  logits, mean, std = model(x)
  kl_loss = jnp.mean(0.5 * jnp.mean(
      -jnp.log(std ** 2) - 1.0 + std ** 2 + mean ** 2, axis=-1))
  reconstruction_loss = jnp.mean(
    optax.sigmoid_binary_cross_entropy(logits, x)
  )
  print("kl loss", kl_loss)
  print("reconstruction loss", reconstruction_loss)
  return reconstruction_loss + 0.1 * kl_loss

model = VAE(
  image_shape=(8, 8),
  hidden_size=32,
  latent_size=8,
  rngs=nnx.Rngs(0, noise=1),
)

optimizer = nnx.Optimizer(model, optax.adam(1e-3))
train_step(model, optimizer, images_train)
kl loss Traced<ShapedArray(float32[])>with<JVPTrace> with
  primal = Traced<ShapedArray(float32[])>with<DynamicJaxprTrace>
  tangent = Traced<ShapedArray(float32[])>with<JaxprTrace> with
    pval = (ShapedArray(float32[]), None)
    recipe = JaxprEqnRecipe(eqn_id=<object object at 0x7d8874a688c0>, in_tracers=(Traced<ShapedArray(float32[1347]):JaxprTrace>,), out_tracer_refs=[<weakref at 0x7d8874a81530; to 'JaxprTracer' at 0x7d8874a81450>], out_avals=[ShapedArray(float32[])], primitive=pjit, params={'jaxpr': { lambda ; a:f32[1347]. let
    b:f32[] = reduce_sum[axes=(0,)] a
    c:f32[] = div b 1347.0
  in (c,) }, 'in_shardings': (UnspecifiedValue,), 'out_shardings': (UnspecifiedValue,), 'in_layouts': (None,), 'out_layouts': (None,), 'donated_invars': (False,), 'ctx_mesh': None, 'name': '_mean', 'keep_unused': False, 'inline': True, 'compiler_options_kvs': ()}, effects=set(), source_info=<jax._src.source_info_util.SourceInfo object at 0x7d88766b19f0>, ctx=JaxprEqnContext(compute_type=None, threefry_partitionable=True, cur_abstract_mesh=AbstractMesh((), axis_types=()), xla_metadata=None))
reconstruction loss Traced<ShapedArray(float32[])>with<JVPTrace> with
  primal = Traced<ShapedArray(float32[])>with<DynamicJaxprTrace>
  tangent = Traced<ShapedArray(float32[])>with<JaxprTrace> with
    pval = (ShapedArray(float32[]), None)
    recipe = JaxprEqnRecipe(eqn_id=<object object at 0x7d8874a68b40>, in_tracers=(Traced<ShapedArray(float32[1347,8,8]):JaxprTrace>,), out_tracer_refs=[<weakref at 0x7d8874a800e0; to 'JaxprTracer' at 0x7d8874a80280>], out_avals=[ShapedArray(float32[])], primitive=pjit, params={'jaxpr': { lambda ; a:f32[1347,8,8]. let
    b:f32[] = reduce_sum[axes=(0, 1, 2)] a
    c:f32[] = div b 86208.0
  in (c,) }, 'in_shardings': (UnspecifiedValue,), 'out_shardings': (UnspecifiedValue,), 'in_layouts': (None,), 'out_layouts': (None,), 'donated_invars': (False,), 'ctx_mesh': None, 'name': '_mean', 'keep_unused': False, 'inline': True, 'compiler_options_kvs': ()}, effects=set(), source_info=<jax._src.source_info_util.SourceInfo object at 0x7d88766b3940>, ctx=JaxprEqnContext(compute_type=None, threefry_partitionable=True, cur_abstract_mesh=AbstractMesh((), axis_types=()), xla_metadata=None))
Array(10175.992, dtype=float32)

But here, rather than printed values, we get some kind of Traced object. You’ll encounter this frequently when inspecting the progress of JAX programs: tracers are the mechanism that JAX uses to implement transformations like jax.jit and jax.grad, and you can read more about them in JAX Key Concepts: Tracing.
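
(The Debugging runtime values guide also covers jax.debug.breakpoint, a tracing-compatible alternative to Python’s built-in breakpoint().)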

In this example, the workaround is to use another tool from the Debugging runtime values guide: namely jax.debug.print, which allows us to print runtime values even when they’re traced:

def vae_loss(model: VAE, x: jax.Array):
  logits, mean, std = model(x)

  kl_loss = jnp.mean(0.5 * jnp.mean(
      -jnp.log(std ** 2) - 1.0 + std ** 2 + mean ** 2, axis=-1))
  reconstruction_loss = jnp.mean(
    optax.sigmoid_binary_cross_entropy(logits, x)
  )
  jax.debug.print("kl_loss: {}", kl_loss)
  jax.debug.print("reconstruction_loss: {}", reconstruction_loss)
  return reconstruction_loss + 0.1 * kl_loss

model = VAE(
  image_shape=(8, 8),
  hidden_size=32,
  latent_size=8,
  rngs=nnx.Rngs(0, noise=1),
)

optimizer = nnx.Optimizer(model, optax.adam(1e-3))

for i in range(5):
  train_step(model, optimizer, images_train)
kl_loss: 101859.0
reconstruction_loss: -9.908414840698242
kl_loss: 41515.171875
reconstruction_loss: -17.777271270751953
kl_loss: 19582.40234375
reconstruction_loss: -14.064595222473145
kl_loss: 10484.7431640625
reconstruction_loss: -17.528074264526367
kl_loss: 6111.37744140625
reconstruction_loss: -14.794851303100586

Let’s iterate a few hundred more times (we’ll use the IPython %%capture magic to avoid printing all the output on the first several hundred iterations) and then do one more run to print these intermediate values:

%%capture
for i in range(300):
  train_step(model, optimizer, images_train)
loss = train_step(model, optimizer, images_train)
kl_loss: nan
reconstruction_loss: nan

The output above suggests that the large negative value is coming from the reconstruction_loss term. Let’s return to this and inspect what it’s actually doing:

reconstruction_loss = jnp.mean(
  optax.sigmoid_binary_cross_entropy(logits, x)
)

This is the binary cross entropy described in the optax.sigmoid_binary_cross_entropy documentation: the first input should be a logit, and the second input is assumed to be a binary label (i.e. a 0 or 1). But in the current implementation, x comes from images_train, which is not binary!

print(images_train[0])
[[ 0.  3. 13. 16.  9.  0.  0.  0.]
 [ 0. 10. 15. 13. 15.  2.  0.  0.]
 [ 0. 15.  4.  4. 16.  1.  0.  0.]
 [ 0.  0.  0.  5. 16.  2.  0.  0.]
 [ 0.  0.  1. 14. 13.  0.  0.  0.]
 [ 0.  0. 10. 16.  5.  0.  0.  0.]
 [ 0.  4. 16. 13.  8. 10.  9.  1.]
 [ 0.  2. 16. 16. 14. 12.  9.  1.]]

This is likely the source of the issue: we forgot to normalize the input images to the range (0, 1)!
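
To see why non-binary labels cause divergence, recall that the binary cross entropy for logit l and label y is -y * log(sigmoid(l)) - (1 - y) * log(1 - sigmoid(l)). For labels in [0, 1] this is bounded below by zero, but for a “label” like 16 it becomes unboundedly negative as the logits grow, so the model can drive the loss toward negative infinity. A quick illustrative check (not part of the original tutorial):

logits = jnp.array([10.0])
print(optax.sigmoid_binary_cross_entropy(logits, jnp.array([1.0])))   # ~4.5e-05: confident, correct prediction
print(optax.sigmoid_binary_cross_entropy(logits, jnp.array([16.0])))  # ~-150: unbounded below for labels > 1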

Let’s fix this by binarizing the inputs, and then run the training loop again (redefining the loss function to remove the debug statements):

images_normed = (digits.images / 16) > 0.5
splits = train_test_split(images_normed, random_state=0)
images_train, images_test = map(jnp.asarray, splits)

def vae_loss(model: VAE, x: jax.Array):
  logits, mean, std = model(x)

  kl_loss = jnp.mean(0.5 * jnp.mean(
      -jnp.log(std ** 2) - 1.0 + std ** 2 + mean ** 2, axis=-1))
  reconstruction_loss = jnp.mean(
    optax.sigmoid_binary_cross_entropy(logits, x)
  )
  return reconstruction_loss + 0.1 * kl_loss

model = VAE(
  image_shape=(8, 8),
  hidden_size=32,
  latent_size=8,
  rngs=nnx.Rngs(0, noise=1),
)

optimizer = nnx.Optimizer(model, optax.adam(1e-3))

for epoch in range(2001):
  loss = train_step(model, optimizer, images_train)
  if epoch % 500 == 0:
    print(f'Epoch {epoch} loss: {loss}')
Epoch 0 loss: 0.7713417410850525
Epoch 500 loss: 0.31561601161956787
Epoch 1000 loss: 0.2794969081878662
Epoch 1500 loss: 0.2690196633338928
Epoch 2000 loss: 0.2631562352180481

The loss values are now well-behaved, with no NaNs in sight.

We have successfully debugged the initial NaN problem, which was not in the VAE model but rather in the input data.

Exploring the VAE model results#

Now that we have a trained VAE model, let’s explore what it can be used for.

First, let’s pass the test data through the model, which returns the reconstruction logits along with the latent distribution parameters for each input.

Passing the logits through a sigmoid function recovers the predicted images in the input space:

logits, mean, std = model(images_test)
images_pred = jax.nn.sigmoid(logits)

Let’s visualize several of these inputs and outputs:

import matplotlib.pyplot as plt

fig, ax = plt.subplots(2, 10, figsize=(6, 1.5),
                       subplot_kw={'xticks':[], 'yticks':[]},
                       gridspec_kw=dict(hspace=0.1, wspace=0.1))
for i in range(10):
  ax[0, i].imshow(images_test[i], cmap='binary', interpolation='gaussian')
  ax[1, i].imshow(images_pred[i], cmap='binary', interpolation='gaussian')
[Output figure: ten test digits (top row) and their VAE reconstructions (bottom row).]

The top row shows the input images, and the bottom row shows what the model “thinks” these images look like, given their latent space representation. The fidelity isn’t perfect, but the essential features are recovered.

We can go a step further and generate a set of new images from scratch by sampling randomly from the latent space. Let’s generate 36 new digits this way:

import numpy as np

# generate new images by sampling the latent space
z = np.random.normal(scale=1.5, size=(36, model.latent_size))
logits = model.decoder(z).reshape(-1, 8, 8)
images_gen = nnx.sigmoid(logits)

fig, ax = plt.subplots(6, 6, figsize=(4, 4),
                       subplot_kw={'xticks':[], 'yticks':[]},
                       gridspec_kw=dict(hspace=0.1, wspace=0.1))
for i in range(36):
  ax.flat[i].imshow(images_gen[i], cmap='binary', interpolation='gaussian')
[Output figure: a 6x6 grid of newly generated digits.]

Another possibility is to interpolate between two entries in the training set through the latent space. Here’s an interpolation between a digit 9 and a digit 3:

# Encode the training images, then linearly interpolate between two latent vectors.
z, _, _ = model.encoder(images_train.reshape(-1, 64))
zrange = jnp.linspace(z[2], z[9], 10)

logits = model.decoder(zrange).reshape(-1, 8, 8)
images_gen = nnx.sigmoid(logits)

fig, ax = plt.subplots(1, 10, figsize=(8, 1),
                       subplot_kw={'xticks':[], 'yticks':[]},
                       gridspec_kw=dict(hspace=0.1, wspace=0.1))
for i in range(10):
  ax.flat[i].imshow(images_gen[i], cmap='binary', interpolation='gaussian')
[Output figure: ten steps interpolating from a 9 to a 3.]

Summary#

This tutorial offered an example of defining and training a generative model - a simplified VAE - and approaches to debugging JAX programs using the jax.debug_nans configuration and the jax.debug.print function.

You can learn more about debugging on the JAX documentation site in Debugging runtime values and Introduction to debugging.