JAX AI Stack
News

  • Launched Tunix: A JAX-Native Library for LLM Post-Training

By JAX team

© Copyright 2024, JAX team.