JAX AI Stack
Flexible and scalable
Iterate quickly and scale efficiently out of the box
Run anywhere
Execute the same code on CPUs, GPUs, and TPUs
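Here is a minimal sketch of what "the same code" means in practice, assuming a standard JAX installation. The function below is defined once, compiled with `jax.jit`, and run on whichever backend JAX detects at runtime (`jax.default_backend()` reports which one). The function name `normalize` is illustrative only.

```python
import jax
import jax.numpy as jnp

# One function definition, with no device-specific branches.
@jax.jit
def normalize(x):
    # Scale a vector to unit L2 norm.
    return x / jnp.linalg.norm(x)

# JAX selects the available backend (CPU, GPU, or TPU) automatically.
print("Backend:", jax.default_backend())
print("Devices:", jax.devices())

x = jnp.array([3.0, 4.0])
y = normalize(x)
print(y)
```

The hardware target is chosen by which JAX build is installed, not by changes to the function itself.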
Reliability and compatibility
Tested releases of the JAX AI Stack provide reliability by ensuring compatibility across its libraries
JAX AI Stack
The JAX AI Stack is a curated collection of libraries that researchers and engineers, both inside and outside of Google, have found useful for implementing and deploying the models behind generative AI tools like Imagen, Gemini, and more.
- JAX - Core array operations and program transformations
- Flax - Building neural networks
- Orbax - Checkpointing and persistence utilities
- Optax - Gradient processing and optimization
- ml_dtypes - NumPy dtype extensions for machine learning
- Optional data loading libraries (Grain or tf.data)
Powered by JAX
JAX is a Python library for efficient array-oriented computation and program transformation. JAX's flexible and modular approach has encouraged communities across AI, scientific computing, simulation, and more to build ecosystems on top of it.
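To illustrate "program transformation": `jax.grad`, `jax.jit`, and `jax.vmap` each take a Python function and return a new function, and they compose freely. In this small sketch, the function `f` is an arbitrary example.

```python
import jax
import jax.numpy as jnp

# An ordinary Python function written with NumPy-style operations.
def f(x):
    return jnp.sin(x) * x ** 2

df = jax.grad(f)                # derivative of f for a scalar input
fast_df = jax.jit(df)           # XLA-compiled version of the derivative
batched_df = jax.vmap(fast_df)  # same derivative applied across a batch

xs = jnp.linspace(0.0, 1.0, 5)
print(batched_df(xs))
```

Because each transformation returns an ordinary function, the order of composition is up to the user; here the gradient is compiled first and then vectorized over a batch of inputs.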
JAX is often compared to neural network libraries like PyTorch, but the core JAX package contains very little specific to deep learning. Instead, JAX encourages modularity, where domain-specific libraries are developed separately from the core package. This helps drive innovation as researchers and other users explore what's possible.
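As an illustration of how little deep-learning machinery the core package needs, the sketch below writes a two-layer network using only `jax.numpy` and stores its parameters in a plain dict, which JAX treats as a pytree. The layer sizes and the names `init_params` and `mlp` are hypothetical; a library such as Flax packages this kind of bookkeeping into reusable modules.

```python
import jax
import jax.numpy as jnp

def init_params(key, sizes=(3, 8, 1)):
    # Hypothetical two-layer MLP; parameters live in a plain dict (a pytree).
    k1, k2 = jax.random.split(key)
    return {
        "w1": jax.random.normal(k1, (sizes[0], sizes[1])) * 0.1,
        "b1": jnp.zeros(sizes[1]),
        "w2": jax.random.normal(k2, (sizes[1], sizes[2])) * 0.1,
        "b2": jnp.zeros(sizes[2]),
    }

def mlp(params, x):
    # Hidden layer with ReLU, followed by a linear output layer.
    h = jax.nn.relu(x @ params["w1"] + params["b1"])
    return h @ params["w2"] + params["b2"]

params = init_params(jax.random.PRNGKey(0))
x = jnp.ones((4, 3))  # a batch of 4 inputs with 3 features each
out = mlp(params, x)

# Gradients come back with the same dict structure as params.
grads = jax.grad(lambda p: jnp.sum(mlp(p, x) ** 2))(params)
print(out.shape)
print(jax.tree_util.tree_map(jnp.shape, grads))
```

Since gradients mirror the structure of the parameters, optimizer libraries such as Optax can operate on arbitrary parameter containers without knowing anything about the model.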
Learn more about JAX
Part of a wider AI ecosystem
The JAX AI Stack is part of a growing AI community and ecosystem around JAX. Modularity and choice are important principles for JAX, and we are excited to see the development happening around it!
Learn more about the AI Ecosystem