Scaling Up Model Training
Introduction to JAX
Implementing a Neural Network with Flax
Part A.1: Scaling Models on a Single GPU – A Practical Guide
Part A.2: Scaling Up Transformers on a Single GPU
Extra
Part 2.1: Introduction to Distributed Computing in JAX
Part 2.2: (Fully-Sharded) Data Parallelism