Installing the stack

jax-ai-stack is a metapackage that can be installed with the following command:

pip install jax-ai-stack

This metapackage pins specific versions of its component projects, which the integration tests in this repository verify work correctly together. Packages include:

  • JAX: the core JAX package, which includes array operations and program transformations such as jit, vmap, and grad.

  • flax: building neural networks with JAX.

  • ml_dtypes: NumPy dtype extensions for machine learning.

  • optax: gradient processing and optimization in JAX.

  • orbax: checkpointing and persistence utilities for JAX.

  • chex: utilities for writing reliable JAX code.

  • grain: data loading.
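To see exactly which component versions your installed copy of jax-ai-stack pins, you can read its package metadata. The following is a minimal sketch using only the standard library's importlib.metadata. It assumes the metapackage declares its pins as exact == requirements and that extras are marked with the standard extra == "..." environment marker; the helper names parse_pins and stack_pins are hypothetical.

```python
import re
from importlib import metadata

# Matches "name==version" or "name[extra]==version"; anything after ";" is a marker.
# Assumption: the metapackage pins its components with exact "==" requirements.
_PIN = re.compile(r"^\s*([A-Za-z0-9][A-Za-z0-9._-]*)\s*(?:\[[^\]]*\])?\s*==\s*([^\s;,]+)")


def parse_pins(requirements, include_extras=False):
    """Return {name: version} for exact '==' pins in a list of requirement strings."""
    pins = {}
    for req in requirements or []:
        spec, _, marker = req.partition(";")
        # Requirements belonging to an optional extra carry an 'extra == "..."' marker.
        if "extra" in marker and not include_extras:
            continue
        match = _PIN.match(spec)
        if match:
            pins[match.group(1)] = match.group(2)
    return pins


def stack_pins(dist="jax-ai-stack", include_extras=False):
    """Read the pins declared by an installed distribution, or None if it is absent."""
    try:
        return parse_pins(metadata.requires(dist), include_extras)
    except metadata.PackageNotFoundError:
        return None


if __name__ == "__main__":
    pins = stack_pins()
    if pins is None:
        print("jax-ai-stack is not installed")
    else:
        for name, version in sorted(pins.items()):
            print(f"{name}=={version}")
```

Run as a script after pip install jax-ai-stack, it prints one name==version line per pinned component; passing include_extras=True also lists the pins behind extras such as tfds.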

Optional packages

Some optional packages are available as pip extras. For example, the following command:

pip install "jax-ai-stack[tfds]"

installs compatible versions of tensorflow and tensorflow-datasets.
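After installing the tfds extra, one quick way to confirm that both packages are importable is to look them up with importlib.util.find_spec, which locates a top-level module without importing it (importing tensorflow can be slow). This is a sketch: tensorflow and tensorflow_datasets are the import names of the tensorflow and tensorflow-datasets packages, and missing_modules is a hypothetical helper.

```python
import importlib.util

# Import names for the packages pulled in by the tfds extra
# (the tensorflow-datasets package is imported as tensorflow_datasets).
TFDS_EXTRA_MODULES = ("tensorflow", "tensorflow_datasets")


def missing_modules(modules=TFDS_EXTRA_MODULES):
    """Return the top-level modules that cannot be found, without importing them."""
    return [name for name in modules if importlib.util.find_spec(name) is None]


if __name__ == "__main__":
    missing = missing_modules()
    if missing:
        print("missing:", ", ".join(missing))
    else:
        print("tfds extra dependencies are available")
```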

Pinned versions

The jax-ai-stack metapackage is released periodically, with date-based version strings. For example, to pin the set of packages from November 2024, use this installation command:

pip install jax-ai-stack==2024.11.1

For the full list of released versions and the packages each one pins, refer to the Change log.
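If a project depends on a specific release, a small check can catch an environment that has drifted from the intended pin. The sketch below assumes the distribution name jax-ai-stack; for release_month it also assumes the first two version components are the year and month (as in 2024.11.1 for November 2024) and does not interpret any further components. Both helper names are hypothetical.

```python
from importlib import metadata


def check_stack_pin(expected="2024.11.1", dist="jax-ai-stack"):
    """Return (matches, installed_version); installed_version is None if dist is absent."""
    try:
        installed = metadata.version(dist)
    except metadata.PackageNotFoundError:
        return False, None
    return installed == expected, installed


def release_month(version):
    """Split a date-based version such as '2024.11.1' into (year, month).

    Assumption: the first two components are year and month; later
    components are ignored rather than interpreted.
    """
    year, month = version.split(".")[:2]
    return int(year), int(month)


if __name__ == "__main__":
    ok, installed = check_stack_pin()
    print("installed:", installed, "| matches pin:", ok)
```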

Hardware support

To install jax-ai-stack with hardware-specific JAX support, include the hardware-specific JAX requirement in the same pip install invocation. For example:

pip install jax-ai-stack "jax[cuda]"  # JAX + AI stack with GPU/CUDA support
pip install jax-ai-stack "jax[tpu]"  # JAX + AI stack with TPU support
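To confirm that a hardware-specific install took effect, you can ask JAX which backend it selected. jax.default_backend() and jax.devices() are part of JAX's public API; the wrapper jax_backend_summary is a hypothetical helper that returns None when JAX is not installed, so the sketch runs in any environment.

```python
import importlib.util


def jax_backend_summary():
    """Report the default JAX backend and visible devices, or None if JAX is absent."""
    if importlib.util.find_spec("jax") is None:
        return None
    import jax  # imported lazily so this helper also runs where JAX is missing

    return {
        "backend": jax.default_backend(),  # platform name, e.g. "cpu", "gpu", or "tpu"
        "devices": [str(d) for d in jax.devices()],
    }


if __name__ == "__main__":
    print(jax_backend_summary())
```

On a machine with working CUDA support, the backend would be expected to report "gpu"; a "cpu" result usually suggests the CUDA-enabled JAX packages were not installed or no GPU was found.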

For more information on the available hardware-specific JAX installation options, refer to JAX installation.