Getting started with JAX for ML
JAX is a Python library for hardware accelerator-oriented array computation and program transformation. It is the engine behind cutting-edge AI research and production models at Google, Google DeepMind, and beyond.
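To make "array computation and program transformation" concrete, here is a minimal sketch. It uses `jax.numpy`, which mirrors much of the NumPy API, and two core transformations: `jax.grad` for automatic differentiation and `jax.jit` for XLA compilation. The linear-model loss and the toy data are illustrative assumptions. They are not part of this tutorial's model.

```python
import jax
import jax.numpy as jnp


def loss(w, x, y):
    # Mean squared error of a hypothetical linear model: y_hat = x @ w
    pred = x @ w
    return jnp.mean((pred - y) ** 2)


# jax.grad transforms `loss` into a new function that returns d(loss)/d(w)
grad_loss = jax.grad(loss)

# jax.jit compiles that gradient function with XLA for the available backend (CPU, GPU, or TPU)
fast_grad_loss = jax.jit(grad_loss)

# Toy inputs (illustrative only)
x = jnp.array([[1.0, 2.0], [3.0, 4.0]])
y = jnp.array([1.0, 2.0])
w = jnp.zeros(2)

g = fast_grad_loss(w, x, y)
print(g)  # gradient of the loss with respect to w at w = 0
```

Each transformation returns an ordinary Python function, so transformations compose: here `jax.jit` wraps the output of `jax.grad`.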
Who is this tutorial for?
This tutorial is for those who want to get started using JAX and JAX-based AI libraries (the JAX AI stack) to build and train a simple neural network model. It assumes some familiarity with numerical computing in Python with NumPy, and some conceptual familiarity with defining, training, and evaluating machine learning models.
What does this tutorial cover?
JAX focuses on array-based computation, and is at the core of a growing ecosystem of domain-specific tools. This tutorial introduces part of that JAX ecosystem designed for AI-related tasks, including:
Flax NNX: A machine learning library designed for defining and building scalable neural networks using JAX.
Optax: A high-performance function optimization library that comes with built-in optimizers and loss functions.
After working through this content, you may wish to visit the JAX documentation site for a deeper dive into the core JAX concepts.