JAX AI Stack

Flexible, scalable components for AI research and development

Flexible and scalable

Iterate quickly, with efficient out-of-the-box scaling

Run anywhere

Execute the same code on CPUs, GPUs, and TPUs

Reliability and compatibility

Tested JAX AI Stack releases provide high reliability by ensuring compatibility across the included libraries
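As a minimal sketch of what running the same code anywhere looks like in practice, the function below is written once with `jax.numpy` and compiled with `jax.jit`; JAX runs it on whichever backend its installation detects (CPU, GPU, or TPU) with no device-specific changes. The name `standardize` is a hypothetical example, not part of any library.

```python
import jax
import jax.numpy as jnp


@jax.jit
def standardize(x):
    # Hypothetical helper: zero-mean, unit-variance scaling.
    # XLA compiles this for whatever backend JAX detects at runtime.
    return (x - x.mean()) / (x.std() + 1e-6)


x = jnp.arange(8.0)
y = standardize(x)

# Report which backend and device(s) the computation ran on.
print("backend:", jax.default_backend())
print("devices:", y.devices())
print(y)
```

On a CPU-only install the first print should report `cpu`; with an accelerator-enabled `jaxlib`, the same script should report and use that accelerator instead.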

JAX AI Stack

The JAX AI Stack is a curated collection of libraries that researchers and engineers, both inside and outside of Google, have found useful for implementing and deploying the models behind generative AI tools like Imagen, Gemini, and more.

  • JAX - Core array operations and program transformations
  • Flax - Building neural networks
  • Orbax - Checkpointing and persistence utilities
  • Optax - Gradient processing and optimization
  • ml_dtypes - NumPy dtype extensions for machine learning
  • Optional data loading libraries (Grain or tf.data)

Powered by JAX

JAX is a Python library for efficient array-oriented computation and program transformation. JAX's flexible and modular approach has encouraged communities across AI, scientific computing, simulation, and more to build ecosystems on top of it.
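A minimal illustration of "program transformation": `jax.grad`, `jax.jit`, and `jax.vmap` each take an ordinary Python function and return a new function, and they compose freely. The function `sum_of_squares` is a made-up example.

```python
import jax
import jax.numpy as jnp


def sum_of_squares(x):
    # An ordinary numerical function written with jax.numpy.
    return jnp.sum(x ** 2)


# Each transformation takes a function and returns a new function.
grad_fn = jax.grad(sum_of_squares)     # derivative, equal to 2 * x
fast_fn = jax.jit(sum_of_squares)      # compiled with XLA
batched_fn = jax.vmap(sum_of_squares)  # maps over a leading batch axis

x = jnp.array([1.0, 2.0, 3.0])
batch = jnp.stack([x, 2 * x])

print(grad_fn(x))         # gradient at x
print(fast_fn(x))         # same value as sum_of_squares(x)
print(batched_fn(batch))  # one result per row of batch

# Transformations compose: a compiled, batched gradient.
batched_grad = jax.jit(jax.vmap(jax.grad(sum_of_squares)))
print(batched_grad(batch))
```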

JAX is often compared to neural network libraries like PyTorch, but the core JAX package contains very little specific to deep learning. Instead, JAX encourages modularity, where domain-specific libraries are developed separately from the core package. This helps drive innovation as researchers and other users explore what's possible.

Part of a wider AI ecosystem

The JAX AI Stack is part of a growing AI community and ecosystem around JAX. Modularity and choice are important principles for JAX, and we are excited to see the development happening around it!
