Introduction to JAX

Implementing a Neural Network with Flax

Part A.1: Scaling Models on a Single GPU – A Practical Guide

Part A.2: Scaling Up Transformers on a Single GPU

Extra

Part 2.1: Introduction to Distributed Computing in JAX

Part 2.2: (Fully-Sharded) Data Parallelism