Installing the stack
jax-ai-stack is a metapackage that can be installed with the following command:
pip install jax-ai-stack
It pins specific versions of its component projects that the integration tests in this repository have verified to work together. Packages include:
JAX: the core JAX package, which includes array operations and program transformations like jit, vmap, grad, etc.
flax: build neural networks with JAX.
ml_dtypes: NumPy dtype extensions for machine learning.
optax: gradient processing and optimization in JAX.
orbax: checkpointing and persistence utilities for JAX.
chex: utilities for writing reliable JAX code.
grain: data loading.
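To confirm which component versions ended up in your environment, a small check like the following may help. It is a sketch, not part of jax-ai-stack: the helper name show_stack_versions is hypothetical, it assumes a python3 interpreter on your PATH that matches the pip you installed with, and the distribution names (in particular orbax-checkpoint for orbax) are assumptions to adjust against the Change log.

```shell
# Hypothetical helper: print "<name>: <version>" for each component,
# or "<name>: not installed" if the distribution cannot be found.
# The distribution names are assumptions (e.g. orbax as "orbax-checkpoint").
show_stack_versions() {
  python3 - <<'EOF'
from importlib.metadata import PackageNotFoundError, version

names = ["jax-ai-stack", "jax", "flax", "ml_dtypes", "optax",
         "orbax-checkpoint", "chex", "grain"]
for name in names:
    try:
        print(f"{name}: {version(name)}")
    except PackageNotFoundError:
        print(f"{name}: not installed")
EOF
}

show_stack_versions
```

Comparing this output with the versions listed for your release in the Change log is a quick way to spot a component that another install has upgraded or downgraded.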
Optional packages
Some optional packages can be installed with pip extras. For example, the following command installs compatible versions of tensorflow and tensorflow-datasets:
pip install "jax-ai-stack[tfds]"
Quoting the requirement keeps shells such as zsh from treating the square brackets as a glob pattern.
Pinned versions
The jax-ai-stack metapackage is released periodically, with date-based version strings. For example, to pin the set of packages from November 2024, use this installation command:
pip install jax-ai-stack==2024.11.1
For the full list of released versions and the pinned packages, refer to the Change log.
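To record a pin for a reproducible environment, the same specifier can go in a requirements file, and an extra combines with a version pin using standard pip requirement syntax (name[extra]==version). The sketch below assumes that the chosen release (2024.11.1 here) defines the tfds extra; check the Change log before relying on it. The install step is left as a comment because it needs network access.

```shell
# Write a requirements file pinning a dated release plus the optional tfds extra.
# Assumption: the 2024.11.1 release defines the [tfds] extra.
# Note: this overwrites any existing requirements.txt in the current directory.
cat > requirements.txt <<'EOF'
jax-ai-stack[tfds]==2024.11.1
EOF

# Then, with network access, install everything from the file:
#   pip install -r requirements.txt
```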
Hardware support
To install jax-ai-stack with hardware-specific JAX support, include the hardware-specific JAX requirement in the same pip install invocation, so that pip resolves it together with the JAX version that jax-ai-stack pins. For example:
pip install jax-ai-stack "jax[cuda]" # JAX + AI stack with GPU/CUDA support
pip install jax-ai-stack "jax[tpu]" # JAX + AI stack with TPU support
For more information on available options for hardware-specific JAX installation, refer to JAX installation.
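After a hardware-specific install, one way to check which backend JAX picked up is to query jax.default_backend() and jax.devices(). The snippet below is an illustrative sketch: check_jax_backend is a hypothetical helper name, and it assumes python3 is the interpreter that pip installed into.

```shell
# Hypothetical smoke test: report the backend and devices JAX can see.
check_jax_backend() {
  python3 - <<'EOF'
try:
    import jax
except ImportError:
    print("jax is not installed")
else:
    # default_backend() returns the platform name, e.g. "cpu", "gpu" or "tpu".
    print("backend:", jax.default_backend())
    print("devices:", jax.devices())
EOF
}

check_jax_backend
```

If it reports cpu on a machine that has an accelerator, the hardware-specific JAX install most likely did not take effect.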