# Getting started with JAX for ML

[JAX](http://jax.readthedocs.io) is a Python package for hardware accelerator-oriented array computation and program transformation, and is the engine behind cutting-edge AI research and production models at Google, Google DeepMind, and beyond.
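To make "program transformation" more concrete, here is a minimal sketch that uses only core JAX. `jax.grad`, `jax.jit`, and `jax.vmap` each take a Python function and return a new function. The function `sum_of_squares` is a hypothetical example chosen for illustration.

```python
import jax
import jax.numpy as jnp

# A plain Python function written NumPy-style, operating on JAX arrays.
def sum_of_squares(x):
    return jnp.sum(x ** 2)

# Each transformation takes a function and returns a new function.
grad_fn = jax.grad(sum_of_squares)      # gradient of sum(x**2), which is 2 * x
fast_grad_fn = jax.jit(grad_fn)         # the same gradient function, compiled with XLA
batched_fn = jax.vmap(sum_of_squares)   # maps the function over a leading batch axis

x = jnp.array([1.0, 2.0, 3.0])
print(sum_of_squares(x))                   # 14.0
print(fast_grad_fn(x))                     # [2. 4. 6.]
print(batched_fn(jnp.stack([x, 2 * x])))   # [14. 56.]
```

Because transformations return ordinary functions, you can compose them. For example, `jax.jit(jax.grad(f))` is itself just another function.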

## Who is this tutorial for?

This tutorial is for those who want to get started using JAX and JAX-based AI libraries (the JAX AI stack) to build and train a simple neural network model. It assumes some familiarity with numerical computing in Python using [NumPy](http://numpy.org), as well as a conceptual understanding of how machine learning models are defined, trained, and evaluated.

## What does this tutorial cover?

JAX focuses on [array-based](https://jax.readthedocs.io/en/latest/key-concepts.html#jax-arrays-jax-array) computation and sits at the core of a growing ecosystem of domain-specific tools. This tutorial introduces the part of that ecosystem designed for AI tasks, including:

- [Flax NNX](http://flax.readthedocs.io): A machine learning library designed for defining and building scalable neural networks using JAX.
- [Optax](http://optax.readthedocs.io): A gradient processing and optimization library for JAX that comes with built-in optimizers and loss functions.

After working through this content, you may wish to visit the [JAX documentation site](http://jax.readthedocs.io/) for a deeper dive into the core JAX concepts.

## Let's get started!

```{toctree}
:maxdepth: 1

neural_net_basics
digits_vae
digits_diffusion_model
```
