When you hear "scaling models," your mind might immediately jump to distributed training across GPUs or even across entire clusters. But there’s a lot you can do to scale on a single GPU—especially when memory is the bottleneck. In this section, we’ll cover a set of techniques to train larger models more efficiently, such as mixed precision training, activation checkpointing, and gradient accumulation. These methods are game-changers, even if you’re working with just one device.
On a single GPU, memory often becomes the limiting factor, not compute power. If your model can’t fit into memory (along with all the activations, gradients, and optimizer states), training becomes impossible. By reducing the memory footprint of the training process, these techniques also pave the way for scaling to multi-GPU setups later. So, let’s first master them on a single device.
First up: mixed precision training. It’s a simple yet powerful technique that speeds up training and reduces memory usage by combining 16-bit and 32-bit floating-point numbers. Here’s the idea:

- Keep the model parameters (and optimizer state) in float32 so that weight updates remain accurate.
- Run the forward and backward passes, i.e. the activations and most matrix multiplications, in 16-bit, where the memory and compute savings are largest.
Why is this useful? Memory savings and faster computations! But there’s a trade-off: 16-bit floats have less precision and, in float16’s case, a smaller exponent range, which can lead to numerical instability (e.g., gradient underflow or overflow). To mitigate this, float16 training typically uses loss scaling: the loss is multiplied by a constant factor before backpropagation so that small gradients stay representable, and the resulting gradients are divided by the same factor before the optimizer update.
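As a concrete illustration, here is a minimal sketch of static loss scaling in JAX, using a toy linear model and MSE loss invented for this example (real setups often use dynamic loss scaling, e.g. via the jmp library, which adjusts the factor automatically):

```python
import jax
import jax.numpy as jnp

# Toy model/loss, purely illustrative: the matmul runs in float16 to
# mimic a mixed precision forward pass.
def loss_fn(w, x, y):
    pred = (x.astype(jnp.float16) @ w.astype(jnp.float16)).astype(jnp.float32)
    return jnp.mean((pred - y) ** 2)

LOSS_SCALE = 2.0 ** 12  # illustrative constant scaling factor

def grads_with_loss_scaling(w, x, y):
    # Scale the loss so small gradients stay representable in float16,
    # then unscale the gradients before handing them to the optimizer.
    scaled = lambda w_: loss_fn(w_, x, y) * LOSS_SCALE
    grads = jax.grad(scaled)(w)
    return jax.tree_util.tree_map(lambda g: g / LOSS_SCALE, grads)

w = jnp.zeros((4, 1), dtype=jnp.float32)
x = jnp.ones((8, 4), dtype=jnp.float32)
y = jnp.ones((8, 1), dtype=jnp.float32)
print(grads_with_loss_scaling(w, x, y).shape)  # (4, 1)
```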
Not all 16-bit floats are created equal. Let’s quickly compare the two most common formats:
| Format | Precision | Exponent Range | Ideal Use |
|---|---|---|---|
| float16 | High | Smaller range | Precision-sensitive tasks |
| bfloat16 | Lower | Matches float32 | Range-critical computations |
Bfloat16 stands out because it matches `float32` in range, eliminating the need for loss scaling in most cases. It’s also natively supported by many accelerators, such as TPUs and NVIDIA GPUs, offering up to 2x speedups without sacrificing much accuracy. In this notebook, we’ll use `bfloat16` as our default mixed precision format.
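You can check the trade-off from the table directly with `jnp.finfo` (the values in the comments are the standard properties of these formats):

```python
import jax.numpy as jnp

# bfloat16 covers roughly the same range as float32, while float16
# overflows beyond 65504.
print(jnp.finfo(jnp.float16).max)   # 65504.0
print(jnp.finfo(jnp.bfloat16).max)  # ~3.39e38
print(jnp.finfo(jnp.float32).max)   # ~3.40e38

# float16 has more mantissa bits, so it is more precise around 1.0.
print(jnp.finfo(jnp.float16).eps)   # ~0.00098
print(jnp.finfo(jnp.bfloat16).eps)  # 0.0078125
```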
In JAX, implementing mixed precision involves casting activations and intermediate computations to `bfloat16`, while keeping model parameters and optimizer states in `float32`. Here’s a streamlined way to do it:
Before diving into the code, let’s enable some JAX flags for GPU performance:
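One common way to do this is through the `XLA_FLAGS` environment variable, set before JAX is initialized. Which flags exist and actually help depends on your JAX/XLA version, so treat the following as an illustrative sketch rather than a definitive configuration:

```python
import os

# XLA flags must be set before JAX initializes its backends.
# Flag availability varies across XLA versions.
os.environ["XLA_FLAGS"] = (
    "--xla_gpu_enable_latency_hiding_scheduler=true "  # overlap compute with communication
    "--xla_gpu_enable_highest_priority_async_stream=true"
)
# Optional: cap the preallocated GPU memory pool to leave some headroom
# for other processes while experimenting.
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "0.9"

import jax
print(jax.devices())
```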