When you hear "scaling models," your mind might immediately jump to distributed training across GPUs or even across entire clusters. But there’s a lot you can do to scale on a single GPU—especially when memory is the bottleneck. In this section, we’ll cover a set of techniques to train larger models more efficiently, such as mixed precision training, activation checkpointing, and gradient accumulation. These methods are game-changers, even if you’re working with just one device.

Why Memory Matters

On a single GPU, memory often becomes the limiting factor, not compute power. If your model can’t fit into memory (along with all the activations, gradients, and optimizer states), training becomes impossible. By reducing the memory footprint of the training process, these techniques also pave the way for scaling to multi-GPU setups later. So, let’s first master them on a single device.
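Here is a back-of-the-envelope estimate of how much memory parameters, gradients, and optimizer state take up on their own. It is a rough sketch with some assumptions. Every tensor is float32. The optimizer is Adam-style and keeps two extra tensors per parameter. `training_state_bytes` is a hypothetical helper. Activations are left out on purpose, because they depend on batch size and architecture.

```python
import jax.numpy as jnp

def training_state_bytes(num_params, dtype=jnp.float32, optimizer_slots=2):
    """Rough memory for parameters, gradients and optimizer state (activations excluded).

    Assumes every tensor uses `dtype` and that the optimizer keeps `optimizer_slots`
    extra tensors per parameter (2 for Adam's first and second moments).
    """
    itemsize = jnp.dtype(dtype).itemsize  # bytes per element, e.g. 4 for float32
    params = num_params * itemsize
    grads = num_params * itemsize  # one gradient value per parameter
    opt_state = optimizer_slots * num_params * itemsize
    return params + grads + opt_state

# A hypothetical 1-billion-parameter model trained with Adam in float32:
print(training_state_bytes(1_000_000_000) / 1e9)  # 16.0 (GB), before any activations
```

Activations come on top of this figure, and they often dominate at large batch sizes. Mixed precision, activation checkpointing, and gradient accumulation are aimed at exactly these costs.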


Mixed Precision Training: More Speed, Less Memory

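First up: mixed precision training. It's a simple yet powerful technique that speeds up training and reduces memory usage by using both 16-bit and 32-bit floating-point numbers. The idea is to run the bulk of the work in 16-bit precision: the forward and backward passes, along with their activations and intermediate gradients. Meanwhile, a float32 master copy of the parameters and the optimizer state is kept, so that small weight updates are not rounded away.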

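Why is this useful? Storing activations and gradients in 16 bits halves their memory, and 16-bit math runs faster on modern accelerators. The trade-off is that 16-bit floats have fewer mantissa bits, and float16 also has a much narrower exponent range. As a result, small values can underflow to zero and large ones can overflow to infinity. To mitigate this, we use loss scaling. The loss is multiplied by a scale factor before backpropagation, which scales every gradient by the same factor and lifts small gradients out of the underflow range. The gradients are then divided by that factor again before the optimizer update.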
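Below is a minimal sketch of static loss scaling in JAX. It uses float16 compute, since that is the format where scaling actually matters, and a toy one-weight linear model. `loss_fn` and `scaled_grads` are hypothetical names. The fixed `loss_scale` stands in for the dynamically adjusted scale that production setups usually use. The inputs are chosen so that the true gradient, about -2e-8, is smaller than the smallest positive float16 value (about 6e-8).

```python
import jax
import jax.numpy as jnp

def loss_fn(params, x, y):
    # Forward pass in float16; the float32 master weight is cast down only for compute.
    w16 = params["w"].astype(jnp.float16)
    pred = x.astype(jnp.float16) @ w16
    return jnp.mean((pred - y.astype(jnp.float16)) ** 2)

def scaled_grads(params, x, y, loss_scale):
    """Static loss scaling: scale the loss up, backprop in float16, unscale in float32."""
    def scaled_loss(p):
        # Multiplying the loss by loss_scale multiplies every float16 gradient by it too.
        return loss_fn(p, x, y) * loss_scale
    grads = jax.grad(scaled_loss)(params)  # same dtype as params (float32)
    # Divide by the same factor to recover the true gradient magnitudes.
    return jax.tree_util.tree_map(lambda g: g / loss_scale, grads)

params = {"w": jnp.zeros((1, 1), dtype=jnp.float32)}
x = jnp.full((1, 1), 1e-4, dtype=jnp.float32)
y = jnp.full((1, 1), 1e-4, dtype=jnp.float32)

g_unscaled = scaled_grads(params, x, y, loss_scale=1.0)["w"]
g_scaled = scaled_grads(params, x, y, loss_scale=2.0**12)["w"]
print(g_unscaled, g_scaled)  # unscaled gradient underflows to zero; scaled one is close to -2e-8
```

A fixed scale is the simplest choice. If it is set too high, the scaled loss or gradients can overflow to inf instead. This is why dynamic schemes lower the scale after an overflow and slowly raise it again otherwise.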


Float Formats: Float16 vs. Bfloat16

Not all 16-bit floats are created equal. Let’s quickly compare the two most common formats:

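| Format | Exponent bits | Mantissa bits | Precision | Range | Ideal use |
|---|---|---|---|---|---|
| float16 | 5 | 10 | Higher | Smaller (max ≈ 65,504) | Precision-sensitive tasks |
| bfloat16 | 8 | 7 | Lower | Same as float32 (max ≈ 3.4 × 10^38) | Range-critical computations |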

Bfloat16 stands out because it keeps float32's 8-bit exponent, so it covers the same range and removes the need for loss scaling in most cases. It is also natively supported by TPUs and by NVIDIA GPUs from the Ampere generation (e.g., A100, RTX 30 series) onward. On that hardware it can speed up training by up to 2x without sacrificing much accuracy. In this notebook, we'll use bfloat16 as our default mixed precision format.
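The difference between the two formats is easy to check with `jnp.finfo` and a couple of arithmetic probes. This is an illustrative sketch. The probe values (60,000 and 1e-4) are arbitrary picks, chosen so that the results land just outside float16's range.

```python
import jax.numpy as jnp

# Compare the two 16-bit formats against float32 using their floating-point metadata.
for dtype in (jnp.float16, jnp.bfloat16, jnp.float32):
    info = jnp.finfo(dtype)
    print(f"{info.dtype}: exponent bits={info.nexp}, mantissa bits={info.nmant}, "
          f"max={float(info.max):.4g}, eps={float(info.eps):.4g}")

# Overflow: 120,000 exceeds float16's max (65,504) but is well within bfloat16's range.
big16 = jnp.array(60_000.0, dtype=jnp.float16)
big_bf16 = jnp.array(60_000.0, dtype=jnp.bfloat16)
print(big16 * 2, big_bf16 * 2)  # float16 gives inf; bfloat16 stays finite

# Underflow: 1e-4 * 1e-4 = 1e-8 is below float16's smallest subnormal (~6e-8).
small16 = jnp.array(1e-4, dtype=jnp.float16)
small_bf16 = jnp.array(1e-4, dtype=jnp.bfloat16)
print(small16 * small16, small_bf16 * small_bf16)  # float16 gives 0; bfloat16 keeps ~1e-8
```

The `eps` column shows the flip side. Bfloat16's gap between 1.0 and the next representable value is 8x larger than float16's, which is the precision cost of its extra exponent bits.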



Implementing Mixed Precision Training in JAX

In JAX, implementing mixed precision involves casting activations and intermediate computations to bfloat16, while keeping model parameters and optimizer states in float32. Here’s a streamlined way to do it:
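Here is a minimal sketch of that split. It assumes a hypothetical two-layer MLP (`init_params`, `mlp_apply`) and uses plain SGD in place of a real optimizer, so the example depends on nothing beyond JAX. The float32 parameters act as the master copy. `loss_fn` casts them and the inputs to bfloat16, so the matmuls and activations run in bfloat16. Because the cast happens inside the differentiated function, `jax.grad` returns float32 gradients that update the float32 master weights directly.

```python
import jax
import jax.numpy as jnp

LEARNING_RATE = 1e-2  # hypothetical SGD step size

def init_params(key, d_in, d_hidden, d_out):
    # Master parameters are created (and kept) in float32.
    k1, k2 = jax.random.split(key)
    return {
        "w1": jax.random.normal(k1, (d_in, d_hidden), dtype=jnp.float32) / jnp.sqrt(d_in),
        "b1": jnp.zeros((d_hidden,), dtype=jnp.float32),
        "w2": jax.random.normal(k2, (d_hidden, d_out), dtype=jnp.float32) / jnp.sqrt(d_hidden),
        "b2": jnp.zeros((d_out,), dtype=jnp.float32),
    }

def mlp_apply(params, x):
    h = jax.nn.relu(x @ params["w1"] + params["b1"])
    return h @ params["w2"] + params["b2"]

def loss_fn(params, x, y):
    # bfloat16 compute copy of params and inputs: activations (and their
    # gradients in the backward pass) are therefore bfloat16.
    params_bf16 = jax.tree_util.tree_map(lambda p: p.astype(jnp.bfloat16), params)
    pred = mlp_apply(params_bf16, x.astype(jnp.bfloat16))
    # Final reduction in float32 to limit accumulated rounding error.
    return jnp.mean((pred.astype(jnp.float32) - y) ** 2)

@jax.jit
def train_step(params, x, y):
    loss, grads = jax.value_and_grad(loss_fn)(params, x, y)
    # grads have the dtype of `params` (float32), so the update stays in float32.
    new_params = jax.tree_util.tree_map(lambda p, g: p - LEARNING_RATE * g, params, grads)
    return new_params, loss

# Usage on a toy regression task (target = sum of the input features).
key = jax.random.PRNGKey(0)
params = init_params(key, 4, 16, 1)
x = jax.random.normal(jax.random.PRNGKey(1), (64, 4))
y = x.sum(axis=1, keepdims=True)
initial_loss = loss_fn(params, x, y)
for _ in range(200):
    params, loss = train_step(params, x, y)
```

Swapping SGD for an optimizer with state (such as Adam) follows the same pattern. The optimizer state is initialized from the float32 master parameters, so it stays in float32 as well.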

Setting XLA Flags for Performance

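Before diving into the code, let's set a few XLA flags that improve performance on GPU. XLA reads these flags when JAX initializes its backend, so they should be set before JAX runs its first computation (in practice, before importing jax):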