Variational autoencoder (VAE) and debugging in JAX#
This tutorial explores a simplified version of a generative model called the Variational Autoencoder (VAE), trained on the scikit-learn digits dataset, and expands on what we learned in Getting started with JAX. Along the way, you’ll learn more about how JAX’s JIT compilation (jax.jit) actually works, and what this means for debugging JAX programs, as we learn how to identify what can go wrong during model training.
If you are new to JAX for AI, check out the first tutorial, which explains how to build a simple neural network with Flax and Optax, and covers JAX’s key features, including the NumPy-style interface with jax.numpy, JAX transformations for JIT compilation with jax.jit, automatic vectorization with jax.vmap, and automatic differentiation with jax.grad.
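As a quick refresher, here is a minimal sketch of those transformations applied to a toy function (this snippet is illustrative only and is not part of this tutorial’s model):

import jax
import jax.numpy as jnp

def f(x):
    return jnp.tanh(x).sum()

grad_f = jax.grad(f)               # automatic differentiation
fast_f = jax.jit(f)                # just-in-time compilation
batched_tanh = jax.vmap(jnp.tanh)  # automatic vectorization over a leading axis

print(grad_f(jnp.ones(3)))                    # gradient of f at x
print(fast_f(jnp.ones(3)))                    # same result as f, compiled
print(batched_tanh(jnp.ones((2, 3))).shape)   # (2, 3)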
Loading the data#
As before, this example uses the well-known, small and self-contained scikit-learn digits dataset:
from sklearn.datasets import load_digits
from sklearn.model_selection import train_test_split
import jax.numpy as jnp
digits = load_digits()
splits = train_test_split(digits.images, random_state=0)
images_train, images_test = map(jnp.asarray, splits)
print(f"{images_train.shape=}")
print(f"{images_test.shape=}")
images_train.shape=(1347, 8, 8)
images_test.shape=(450, 8, 8)
The dataset comprises 1,797 images of hand-written digits, each represented by an 8x8 pixel grid, along with their corresponding labels. For a visualization of this data, refer to loading the data in the previous tutorial.
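If you want a quick look at the raw data here as well, a short matplotlib sketch such as the following works (assuming matplotlib is installed):

import matplotlib.pyplot as plt

fig, axes = plt.subplots(1, 8, figsize=(8, 1),
                         subplot_kw={'xticks': [], 'yticks': []})
for i, ax in enumerate(axes):
    ax.imshow(images_train[i], cmap='binary')  # one 8x8 training digit per panel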
Defining the VAE with Flax#
Previously, we learned how to use Flax NNX to create a simple feed-forward neural network trained for classification with an architecture that looked roughly like this:
import jax
import jax.numpy as jnp
from flax import nnx
class SimpleNN(nnx.Module):
def __init__(self, n_features=64, n_hidden=100, n_targets=10, *, rngs: nnx.Rngs):
self.layer1 = nnx.Linear(n_features, n_hidden, rngs=rngs)
self.layer2 = nnx.Linear(n_hidden, n_hidden, rngs=rngs)
self.layer3 = nnx.Linear(n_hidden, n_targets, rngs=rngs)
def __call__(self, x: jax.Array) -> jax.Array:
x = nnx.selu(self.layer1(x))
x = nnx.selu(self.layer2(x))
return self.layer3(x)
This kind of network has one output per class, and the loss function is designed so that, once the model is trained, the output corresponding to the correct class returns the strongest signal, predicting the correct label in upwards of 95% of cases.
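For example, with such a network the predicted label for a batch of inputs is simply the index of the strongest output. Here is a sketch of the call pattern, using an untrained instance purely for illustration:

model_sketch = SimpleNN(rngs=nnx.Rngs(0))   # untrained; illustrates the call pattern only
x_batch = jnp.zeros((4, 64))                # four flattened 8x8 images
predicted = jnp.argmax(model_sketch(x_batch), axis=-1)  # strongest output = predicted digit
print(predicted.shape)  # (4,)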
To create a VAE with Flax NNX, we will use similar building blocks: subclassing flax.nnx.Module, stacking flax.nnx.Linear layers, and adding a rectified linear unit activation function (flax.nnx.relu). A VAE maps the input data into the parameters of a probability distribution (mean, std), and the output is a small probabilistic model representing the data.
Note that while the classic VAE is generally based on convolutional layers, this example uses linear layers for simplicity.
The sub-network that produces this probabilistic encoding is the Encoder:
class Encoder(nnx.Module):
def __init__(self, input_size: int, intermediate_size: int, output_size: int,
*, rngs: nnx.Rngs):
self.rngs = rngs
self.linear = nnx.Linear(input_size, intermediate_size, rngs=rngs)
self.linear_mean = nnx.Linear(intermediate_size, output_size, rngs=rngs)
self.linear_std = nnx.Linear(intermediate_size, output_size, rngs=rngs)
def __call__(self, x: jax.Array) -> tuple[jax.Array, jax.Array, jax.Array]:
x = self.linear(x)
x = jax.nn.relu(x)
mean = self.linear_mean(x)
std = jnp.exp(self.linear_std(x))
key = self.rngs.noise()
z = mean + std * jax.random.normal(key, mean.shape)
return z, mean, std
The idea here is that mean and std define a low-dimensional probability distribution over a latent space, and that z is a draw from this distribution that represents the training data.
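This shift-and-scale construction is the reparameterization trick: it expresses a sample from N(mean, std²) as a deterministic function of mean, std, and standard normal noise, so gradients can flow through the sampling step. A standalone sketch of the same computation:

key = jax.random.key(42)
mean = jnp.zeros(8)
std = jnp.ones(8)
# Equivalent to drawing z ~ N(mean, std**2), but differentiable in mean and std:
z = mean + std * jax.random.normal(key, mean.shape)
print(z.shape)  # (8,)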
To ensure that this latent distribution faithfully represents the actual data, define a Decoder that maps back to the input space as follows:
class Decoder(nnx.Module):
def __init__(self, input_size: int, intermediate_size: int, output_size: int,
*, rngs: nnx.Rngs):
self.linear1 = nnx.Linear(input_size, intermediate_size, rngs=rngs)
self.linear2 = nnx.Linear(intermediate_size, output_size, rngs=rngs)
def __call__(self, z: jax.Array) -> jax.Array:
z = self.linear1(z)
z = jax.nn.relu(z)
logits = self.linear2(z)
return logits
Now, define the VAE model (again by subclassing flax.nnx.Module) by combining the Encoder and Decoder in a single network (VAE).
The model returns both the reconstructed image and the internal latent space model:
class VAE(nnx.Module):
def __init__(
self,
image_shape: tuple[int, int],
hidden_size: int,
latent_size: int,
*,
rngs: nnx.Rngs
):
self.image_shape = image_shape
self.latent_size = latent_size
input_size = image_shape[0] * image_shape[1]
self.encoder = Encoder(input_size, hidden_size, latent_size, rngs=rngs)
self.decoder = Decoder(latent_size, hidden_size, input_size, rngs=rngs)
def __call__(self, x: jax.Array) -> tuple[jax.Array, jax.Array, jax.Array]:
x = jax.vmap(jax.numpy.ravel)(x) # flatten
z, mean, std = self.encoder(x)
logits = self.decoder(z)
logits = jnp.reshape(logits, (-1, *self.image_shape))
return logits, mean, std
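Before training, it can be worth sanity-checking the output shapes on a dummy batch; a quick sketch:

vae_check = VAE(image_shape=(8, 8), hidden_size=32, latent_size=8,
                rngs=nnx.Rngs(0, noise=1))
logits, mean, std = vae_check(jnp.zeros((4, 8, 8)))  # dummy batch of four images
print(logits.shape, mean.shape, std.shape)           # (4, 8, 8) (4, 8) (4, 8)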
Next, we need to define the loss function. There are two properties of the model that we want to ensure:

1. The logits output faithfully reconstructs the input image.
2. The model represented by mean and std faithfully represents the “true” latent distribution.

Note that the VAE uses a loss function based on the evidence lower bound (ELBO) to quantify these two goals in a single loss value:
import optax

def vae_loss(model: VAE, x: jax.Array):
logits, mean, std = model(x)
kl_loss = jnp.mean(0.5 * jnp.mean(
-jnp.log(std ** 2) - 1.0 + std ** 2 + mean ** 2, axis=-1))
reconstruction_loss = jnp.mean(
optax.sigmoid_binary_cross_entropy(logits, x)
)
return reconstruction_loss + 0.1 * kl_loss
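For reference, the kl_loss term above is the closed-form KL divergence between the encoder’s diagonal Gaussian N(mean, std²) and a standard normal prior, averaged over the batch:

$$D_{\mathrm{KL}}\big(\mathcal{N}(\mu, \sigma^2)\,\|\,\mathcal{N}(0, 1)\big) = \frac{1}{2}\sum_{j}\left(\sigma_j^2 + \mu_j^2 - 1 - \log \sigma_j^2\right)$$

(The code takes a mean rather than a sum over the latent dimensions, which only rescales the term, and the 0.1 factor down-weights it relative to the reconstruction loss.)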
Now all that’s left is to:

1. Instantiate the VAE model.
2. Select optax.adam (the Adam optimizer in our example), and instantiate the optimizer with flax.nnx.Optimizer for setting up the train step.
3. Define the train_step using flax.nnx.value_and_grad to compute the gradients and update the model’s parameters using the optimizer.
4. Use the flax.nnx.jit transformation decorator to trace the train_step function for just-in-time compilation.
5. Run the training loop.
import optax
model = VAE(
image_shape=(8, 8),
hidden_size=32,
latent_size=8,
rngs=nnx.Rngs(0, noise=1),
)
optimizer = nnx.Optimizer(model, optax.adam(1e-3))
@nnx.jit
def train_step(model: VAE, optimizer: nnx.Optimizer, x: jax.Array):
loss, grads = nnx.value_and_grad(vae_loss)(model, x)
optimizer.update(grads)
return loss
for epoch in range(2001):
loss = train_step(model, optimizer, images_train)
if epoch % 500 == 0:
print(f'Epoch {epoch} loss: {loss}')
Epoch 0 loss: 10175.9921875
Epoch 500 loss: nan
Epoch 1000 loss: nan
Epoch 1500 loss: nan
Epoch 2000 loss: nan
Notice in the output that something has gone wrong - the loss value has become NaN after some number of iterations.
Debugging NaNs in JAX#
Despite our best efforts, the VAE model is producing NaNs. What can we do?
JAX offers a number of debugging approaches for situations like this, outlined in JAX’s Debugging runtime values guide. (There is also the Introduction to debugging tutorial you may find useful.)
In this case, we can use the jax.debug_nans configuration to check where the NaN value is arising.
model = VAE(
image_shape=(8, 8),
hidden_size=32,
latent_size=8,
rngs=nnx.Rngs(0, noise=1),
)
optimizer = nnx.Optimizer(model, optax.adam(1e-3))
with jax.debug_nans(True):
for epoch in range(2001):
train_step(model, optimizer, images_train)
Invalid nan value encountered in the output of a jax.jit function. Calling the de-optimized version.
---------------------------------------------------------------------------
FloatingPointError Traceback (most recent call last)
Cell In[8], line 12
10 with jax.debug_nans(True):
11 for epoch in range(2001):
---> 12 train_step(model, optimizer, images_train)
File ~/checkouts/readthedocs.org/user_builds/jax-ai-stack/envs/latest/lib/python3.12/site-packages/flax/nnx/transforms/compilation.py:350, in jit.<locals>.jit_wrapper(*args, **kwargs)
340 with graph.update_context(jit_wrapper):
341 pure_args, pure_kwargs = extract.to_tree(
342 (args, kwargs),
343 prefix=(in_shardings, kwarg_shardings)
(...) 348 ctxtag=jit_wrapper,
349 )
--> 350 pure_args_out, pure_kwargs_out, pure_out = jitted_fn(
351 *pure_args, **pure_kwargs
352 )
353 _args_out, _kwargs_out, out = extract.from_tree(
354 (pure_args_out, pure_kwargs_out, pure_out),
355 merge_fn=_jit_merge_fn,
356 is_inner=False,
357 ctxtag=jit_wrapper,
358 )
359 return out
[... skipping hidden 4 frame]
File ~/checkouts/readthedocs.org/user_builds/jax-ai-stack/envs/latest/lib/python3.12/site-packages/flax/nnx/transforms/compilation.py:129, in JitFn.__call__(self, *pure_args, **pure_kwargs)
121 def __call__(self, *pure_args, **pure_kwargs):
122 args, kwargs = extract.from_tree(
123 (pure_args, pure_kwargs),
124 merge_fn=_jit_merge_fn,
125 ctxtag=self.ctxtag,
126 is_inner=True,
127 )
--> 129 out = self.f(*args, **kwargs)
131 args_out, kwargs_out = extract.clear_non_graph_nodes((args, kwargs))
132 pure_args_out, pure_kwargs_out, pure_out = extract.to_tree(
133 (args_out, kwargs_out, out),
134 prefix=(self.in_shardings, self.kwarg_shardings, self.out_shardings),
135 ctxtag=self.ctxtag,
136 split_fn=_jit_split_fn,
137 )
Cell In[7], line 14, in train_step(model, optimizer, x)
12 @nnx.jit
13 def train_step(model: VAE, optimizer: nnx.Optimizer, x: jax.Array):
---> 14 loss, grads = nnx.value_and_grad(vae_loss)(model, x)
15 optimizer.update(grads)
16 return loss
File ~/checkouts/readthedocs.org/user_builds/jax-ai-stack/envs/latest/lib/python3.12/site-packages/flax/nnx/graph.py:1817, in UpdateContextManager.__call__.<locals>.update_context_manager_wrapper(*args, **kwargs)
1814 @functools.wraps(f)
1815 def update_context_manager_wrapper(*args, **kwargs):
1816 with self:
-> 1817 return f(*args, **kwargs)
File ~/checkouts/readthedocs.org/user_builds/jax-ai-stack/envs/latest/lib/python3.12/site-packages/flax/nnx/transforms/autodiff.py:163, in _grad_general.<locals>.grad_wrapper(***failed resolving arguments***)
151 pure_args = extract.to_tree(
152 args, prefix=arg_filters, split_fn=_grad_split_fn, ctxtag='grad'
153 )
155 gradded_fn = transform(
156 GradFn(f, has_aux, nondiff_states),
157 argnums=jax_argnums,
(...) 160 allow_int=allow_int,
161 )
--> 163 fn_out = gradded_fn(*pure_args)
165 def process_grads(grads):
166 return jax.tree.map(
167 lambda x: x.state if isinstance(x, extract.NodeStates) else x,
168 grads,
169 is_leaf=lambda x: isinstance(x, extract.NodeStates),
170 )
[... skipping hidden 20 frame]
File ~/checkouts/readthedocs.org/user_builds/jax-ai-stack/envs/latest/lib/python3.12/site-packages/jax/_src/pjit.py:219, in _python_pjit_helper(fun, jit_info, *args, **kwargs)
217 except dispatch.InternalFloatingPointError as e:
218 if getattr(fun, '_apply_primitive', False):
--> 219 raise FloatingPointError(f"invalid value ({e.ty}) encountered in {fun.__qualname__}") from None
220 dispatch.maybe_recursive_nan_check(e, fun, args, kwargs)
222 if p.attrs_tracked:
FloatingPointError: invalid value (nan) encountered in dot_general
When differentiating the code at the top of the callstack:
/home/docs/checkouts/readthedocs.org/user_builds/jax-ai-stack/envs/latest/lib/python3.12/site-packages/flax/nnx/nn/linear.py:383:8 (Linear.__call__)
The output here is complicated, because the function we’re evaluating is complicated. The key to “deciphering” this traceback is to look for the places where the traceback touches our implementation.
In particular here, the output above indicates that NaN values arise during the gradient update:
<ipython-input-9-b5b28eeeadf6> in train_step()
14 loss, grads = nnx.value_and_grad(vae_loss)(model, x)
---> 15 optimizer.update(grads)
16 return loss
and further down from this, the details of the gradient update step where the NaN is arising:
/usr/local/lib/python3.10/dist-packages/optax/tree_utils/_tree_math.py in <lambda>()
280 lambda g, t: (
--> 281 (1 - decay) * (g**order) + decay * t if g is not None else None
282 ),
This suggests that the gradient is returning values that lead to NaN during the model update. Typically, this would come about when the gradient itself is for some reason diverging.
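One way to corroborate this is to monitor the gradient magnitude directly. Here is a sketch using optax.global_norm, which computes the overall l2 norm of a pytree of gradients (this helper function is not part of the original tutorial code):

@nnx.jit
def grad_norm(model: VAE, x: jax.Array):
    _, grads = nnx.value_and_grad(vae_loss)(model, x)
    return optax.global_norm(grads)  # l2 norm across all gradient leaves

A value that grows rapidly from step to step would confirm a diverging gradient.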
A diverging gradient means that something may be amiss with the loss function. Previously, we had loss=NaN at iteration 500. Let’s print the progress up to this point:
model = VAE(
image_shape=(8, 8),
hidden_size=32,
latent_size=8,
rngs=nnx.Rngs(0, noise=1),
)
optimizer = nnx.Optimizer(model, optax.adam(1e-3))
for epoch in range(501):
loss = train_step(model, optimizer, images_train)
if epoch % 50 == 0:
print(f'Epoch {epoch} loss: {loss}')
Epoch 0 loss: 10175.9921875
Epoch 50 loss: -15.788647651672363
Epoch 100 loss: -134.56039428710938
Epoch 150 loss: -1761.32275390625
Epoch 200 loss: nan
Epoch 250 loss: nan
Epoch 300 loss: nan
Epoch 350 loss: nan
Epoch 400 loss: nan
Epoch 450 loss: nan
Epoch 500 loss: nan
It looks like the loss value is decreasing toward negative infinity until the point where the values are no longer well-represented by floating point math.
At this point, we may wish to inspect the values within the loss function itself to see where the diverging loss might be coming from.
In typical Python programs we can do this by inserting either a print statement or a breakpoint in the loss function. That might look something like this:
def vae_loss(model: VAE, x: jax.Array):
logits, mean, std = model(x)
kl_loss = jnp.mean(0.5 * jnp.mean(
-jnp.log(std ** 2) - 1.0 + std ** 2 + mean ** 2, axis=-1))
reconstruction_loss = jnp.mean(
optax.sigmoid_binary_cross_entropy(logits, x)
)
print("kl loss", kl_loss)
print("reconstruction loss", reconstruction_loss)
return reconstruction_loss + 0.1 * kl_loss
model = VAE(
image_shape=(8, 8),
hidden_size=32,
latent_size=8,
rngs=nnx.Rngs(0, noise=1),
)
optimizer = nnx.Optimizer(model, optax.adam(1e-3))
train_step(model, optimizer, images_train)
kl loss Traced<ShapedArray(float32[])>with<JVPTrace> with
primal = Traced<ShapedArray(float32[])>with<DynamicJaxprTrace>
tangent = Traced<ShapedArray(float32[])>with<JaxprTrace> with
pval = (ShapedArray(float32[]), None)
recipe = JaxprEqnRecipe(eqn_id=<object object at 0x7d8874a688c0>, in_tracers=(Traced<ShapedArray(float32[1347]):JaxprTrace>,), out_tracer_refs=[<weakref at 0x7d8874a81530; to 'JaxprTracer' at 0x7d8874a81450>], out_avals=[ShapedArray(float32[])], primitive=pjit, params={'jaxpr': { lambda ; a:f32[1347]. let
b:f32[] = reduce_sum[axes=(0,)] a
c:f32[] = div b 1347.0
in (c,) }, 'in_shardings': (UnspecifiedValue,), 'out_shardings': (UnspecifiedValue,), 'in_layouts': (None,), 'out_layouts': (None,), 'donated_invars': (False,), 'ctx_mesh': None, 'name': '_mean', 'keep_unused': False, 'inline': True, 'compiler_options_kvs': ()}, effects=set(), source_info=<jax._src.source_info_util.SourceInfo object at 0x7d88766b19f0>, ctx=JaxprEqnContext(compute_type=None, threefry_partitionable=True, cur_abstract_mesh=AbstractMesh((), axis_types=()), xla_metadata=None))
reconstruction loss Traced<ShapedArray(float32[])>with<JVPTrace> with
primal = Traced<ShapedArray(float32[])>with<DynamicJaxprTrace>
tangent = Traced<ShapedArray(float32[])>with<JaxprTrace> with
pval = (ShapedArray(float32[]), None)
recipe = JaxprEqnRecipe(eqn_id=<object object at 0x7d8874a68b40>, in_tracers=(Traced<ShapedArray(float32[1347,8,8]):JaxprTrace>,), out_tracer_refs=[<weakref at 0x7d8874a800e0; to 'JaxprTracer' at 0x7d8874a80280>], out_avals=[ShapedArray(float32[])], primitive=pjit, params={'jaxpr': { lambda ; a:f32[1347,8,8]. let
b:f32[] = reduce_sum[axes=(0, 1, 2)] a
c:f32[] = div b 86208.0
in (c,) }, 'in_shardings': (UnspecifiedValue,), 'out_shardings': (UnspecifiedValue,), 'in_layouts': (None,), 'out_layouts': (None,), 'donated_invars': (False,), 'ctx_mesh': None, 'name': '_mean', 'keep_unused': False, 'inline': True, 'compiler_options_kvs': ()}, effects=set(), source_info=<jax._src.source_info_util.SourceInfo object at 0x7d88766b3940>, ctx=JaxprEqnContext(compute_type=None, threefry_partitionable=True, cur_abstract_mesh=AbstractMesh((), axis_types=()), xla_metadata=None))
Array(10175.992, dtype=float32)
But here, rather than printing the value, we get some kind of Traced object. You’ll encounter this frequently when inspecting the progress of JAX programs: tracers are the mechanism that JAX uses to implement transformations like jax.jit and jax.grad, and you can read more about them in JAX Key Concepts: Tracing.
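As a minimal illustration, separate from our model: inside a jitted function, Python’s print runs once at trace time and sees an abstract tracer rather than a runtime value:

@jax.jit
def traced_fn(x):
    print("traced:", x)  # prints something like Traced<ShapedArray(float32[3])>
    return x + 1

traced_fn(jnp.arange(3.0))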
In this example, the workaround is to use another tool from the Debugging runtime values guide: namely jax.debug.print, which allows us to print runtime values even when they are traced:
def vae_loss(model: VAE, x: jax.Array):
logits, mean, std = model(x)
kl_loss = jnp.mean(0.5 * jnp.mean(
-jnp.log(std ** 2) - 1.0 + std ** 2 + mean ** 2, axis=-1))
reconstruction_loss = jnp.mean(
optax.sigmoid_binary_cross_entropy(logits, x)
)
jax.debug.print("kl_loss: {}", kl_loss)
jax.debug.print("reconstruction_loss: {}", reconstruction_loss)
return reconstruction_loss + 0.1 * kl_loss
model = VAE(
image_shape=(8, 8),
hidden_size=32,
latent_size=8,
rngs=nnx.Rngs(0, noise=1),
)
optimizer = nnx.Optimizer(model, optax.adam(1e-3))
for i in range(5):
train_step(model, optimizer, images_train)
kl_loss: 101859.0
reconstruction_loss: -9.908414840698242
kl_loss: 41515.171875
reconstruction_loss: -17.777271270751953
kl_loss: 19582.40234375
reconstruction_loss: -14.064595222473145
kl_loss: 10484.7431640625
reconstruction_loss: -17.528074264526367
kl_loss: 6111.37744140625
reconstruction_loss: -14.794851303100586
Let’s iterate a few hundred more times (we’ll use the IPython %%capture magic to avoid printing the output of the first several hundred iterations) and then do one more run to print these intermediate values:
%%capture
for i in range(300):
train_step(model, optimizer, images_train)
loss = train_step(model, optimizer, images_train)
kl_loss: nan
reconstruction_loss: nan
The output above suggests that the large negative value is coming from the reconstruction_loss term. Let’s return to this and inspect what it’s actually doing:
reconstruction_loss = jnp.mean(
optax.sigmoid_binary_cross_entropy(logits, x)
)
This is the binary cross entropy described at optax.sigmoid_binary_cross_entropy. Based on the Optax documentation, the first input should be a logit, and the second input is assumed to be a binary label (i.e. a 0 or 1), but in the current implementation x is associated with images_train, which is not a binary label!
print(images_train[0])
[[ 0. 3. 13. 16. 9. 0. 0. 0.]
[ 0. 10. 15. 13. 15. 2. 0. 0.]
[ 0. 15. 4. 4. 16. 1. 0. 0.]
[ 0. 0. 0. 5. 16. 2. 0. 0.]
[ 0. 0. 1. 14. 13. 0. 0. 0.]
[ 0. 0. 10. 16. 5. 0. 0. 0.]
[ 0. 4. 16. 13. 8. 10. 9. 1.]
[ 0. 2. 16. 16. 14. 12. 9. 1.]]
This is likely the source of the issue: we forgot to normalize the input images to the range (0, 1)!
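To see why a non-binary label is a problem here, recall the standard form of sigmoid binary cross entropy for a logit x and label z (Optax implements a numerically stable equivalent):

$$\ell(x, z) = \log\left(1 + e^{x}\right) - x\,z$$

For z in {0, 1} this is bounded below by zero, but with pixel values as large as 16 standing in for z, the -xz term dominates as the logits grow, so the loss is unbounded below and the optimizer can drive it toward negative infinity, which is exactly the divergence we observed.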
Let’s fix this by binarizing the inputs, and then run the training loop again (redefining the loss function to remove the debug statements):
images_normed = (digits.images / 16) > 0.5
splits = train_test_split(images_normed, random_state=0)
images_train, images_test = map(jnp.asarray, splits)
def vae_loss(model: VAE, x: jax.Array):
logits, mean, std = model(x)
kl_loss = jnp.mean(0.5 * jnp.mean(
-jnp.log(std ** 2) - 1.0 + std ** 2 + mean ** 2, axis=-1))
reconstruction_loss = jnp.mean(
optax.sigmoid_binary_cross_entropy(logits, x)
)
return reconstruction_loss + 0.1 * kl_loss
model = VAE(
image_shape=(8, 8),
hidden_size=32,
latent_size=8,
rngs=nnx.Rngs(0, noise=1),
)
optimizer = nnx.Optimizer(model, optax.adam(1e-3))
for epoch in range(2001):
loss = train_step(model, optimizer, images_train)
if epoch % 500 == 0:
print(f'Epoch {epoch} loss: {loss}')
Epoch 0 loss: 0.7713417410850525
Epoch 500 loss: 0.31561601161956787
Epoch 1000 loss: 0.2794969081878662
Epoch 1500 loss: 0.2690196633338928
Epoch 2000 loss: 0.2631562352180481
The loss values are now “behaving” without showing NaNs.
We have successfully debugged the initial NaN problem, which was not in the VAE model but rather in the input data.
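As a side note, instead of the context manager, NaN checking can also be enabled globally for a whole script or session; a brief sketch:

jax.config.update("jax_debug_nans", True)   # check every jitted computation for NaNs
# ... run training ...
jax.config.update("jax_debug_nans", False)  # turn it off again; the checks add overhead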
Exploring the VAE model results#
Now that we have a trained VAE model, let’s explore what it can be used for.
First, let’s pass the test data through the model; for each input, it returns a reconstruction based on the associated latent space representation.
Pass the logits through a sigmoid function to recover predicted images in the input space:
logits, mean, std = model(images_test)
images_pred = jax.nn.sigmoid(logits)
Let’s visualize several of these inputs and outputs:
import matplotlib.pyplot as plt
fig, ax = plt.subplots(2, 10, figsize=(6, 1.5),
subplot_kw={'xticks':[], 'yticks':[]},
gridspec_kw=dict(hspace=0.1, wspace=0.1))
for i in range(10):
ax[0, i].imshow(images_test[i], cmap='binary', interpolation='gaussian')
ax[1, i].imshow(images_pred[i], cmap='binary', interpolation='gaussian')
[Output figure: original test digits (top row) and their VAE reconstructions (bottom row)]
The top row shows the input images, and the bottom row shows what the model “thinks” these images look like, given their latent space representation. There’s not perfect fidelity, but the essential features are recovered.
We can go a step further and generate a set of new images from scratch by sampling randomly from the latent space. Let’s generate 36 new digits this way:
import numpy as np
# generate new images by sampling the latent space
z = np.random.normal(scale=1.5, size=(36, model.latent_size))
logits = model.decoder(z).reshape(-1, 8, 8)
images_gen = nnx.sigmoid(logits)
fig, ax = plt.subplots(6, 6, figsize=(4, 4),
subplot_kw={'xticks':[], 'yticks':[]},
gridspec_kw=dict(hspace=0.1, wspace=0.1))
for i in range(36):
ax.flat[i].imshow(images_gen[i], cmap='binary', interpolation='gaussian')
[Output figure: a 6x6 grid of new digits generated by sampling the latent space]
Another possibility here is to use the latent model to interpolate between two entries in the training set through the latent space.
Here’s an interpolation between a digit 9 and a digit 3:
z, _, _ = model.encoder(images_train.reshape(-1, 64))
zrange = jnp.linspace(z[2], z[9], 10)
logits = model.decoder(zrange).reshape(-1, 8, 8)
images_gen = nnx.sigmoid(logits)
fig, ax = plt.subplots(1, 10, figsize=(8, 1),
subplot_kw={'xticks':[], 'yticks':[]},
gridspec_kw=dict(hspace=0.1, wspace=0.1))
for i in range(10):
ax.flat[i].imshow(images_gen[i], cmap='binary', interpolation='gaussian')
[Output figure: a ten-step interpolation between two digits through the latent space]
Summary#
This tutorial offered an example of defining and training a generative model - a simplified VAE - and approaches to debugging JAX programs using the jax.debug_nans configuration and the jax.debug.print function.
You can learn more about debugging on the JAX documentation site in Debugging runtime values and Introduction to debugging.