In Part 1.1, I implemented mixed precision training, gradient accumulation, and gradient checkpointing on a simple MLP. Now it’s time to take things up a notch. In this part, we’ll apply these techniques to a Transformer model and show how they make it possible to train larger models on resource-constrained hardware. Along the way, we’ll use profiling techniques to pinpoint bottlenecks and optimize performance.

Before jumping in, I recommend reviewing Part 1.1, since this part builds on its concepts and code. This blog also assumes familiarity with the Transformer architecture and its components. If you’re new to Transformers, you can refer to the original paper by Vaswani et al. (“Attention Is All You Need”) or an excellent blog post by my friend @hafedh: https://huggingface.co/blog/not-lain/tensor-dims.


Running the Notebook

This notebook is designed to run on accelerators like GPUs or TPUs. If you’re on Google Colab, enable a GPU runtime (Runtime > Change runtime type) before running any cells.
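To confirm that JAX can actually see an accelerator, you can run a quick check like the one below. It is a small sketch of my own, not part of the Part 1.1 scripts. One design choice matters here: the query runs in a separate Python process. Calling `jax.devices()` directly in the notebook would initialize JAX’s backend immediately, and XLA flags that are set later (as `set_XLA_flags_gpu` does further down) may then have no effect.

```python
import subprocess
import sys

# Ask JAX which backend it sees, but in a *separate* Python process: calling
# jax.devices() in this notebook would initialize the backend right away, and
# XLA flags set afterwards (e.g. by set_XLA_flags_gpu) might then be ignored.
check = subprocess.run(
    [sys.executable, "-c", "import jax; print(jax.default_backend()); print(jax.devices())"],
    capture_output=True,
    text=True,
)

backend = None
if check.returncode != 0 or not check.stdout.strip():
    # Show only the tail of the error output to keep the cell readable.
    print("Could not query JAX:", check.stderr.strip()[-500:])
else:
    lines = check.stdout.strip().splitlines()
    backend = lines[0]  # typically "gpu", "tpu" or "cpu"
    print("Backend:", backend)
    print("Devices:", lines[1] if len(lines) > 1 else "?")
    if backend == "cpu":
        print("No accelerator visible: enable a GPU/TPU runtime, or training will be slow.")
```

If this reports `cpu` on Colab, the runtime type was most likely not switched before the session started.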


Prerequisites

To avoid duplicating code across notebooks, we’ve moved the key functions from Part 1.1 into Python scripts, which we import here. If you’re running this on Google Colab, you’ll need to download these scripts first:

import os
import urllib.request
from urllib.error import HTTPError

# Base GitHub URL for Python scripts
base_url = "https://raw.githubusercontent.com/phlippe/uvadlc_notebooks/master/docs/tutorial_notebooks/scaling/JAX/"
# Files to download
python_files = ["single_gpu.py", "utils.py"]
for file_name in python_files:
    if not os.path.isfile(file_name):
        file_url = base_url + file_name
        print(f"Downloading {file_url}...")
        try:
            urllib.request.urlretrieve(file_url, file_name)
        except HTTPError as e:
            print(
                f"Failed to download {file_name}. Please try downloading it directly from the GitHub repository. Error:\n{e}"
            )

The utils.py script contains utility functions, including one for setting XLA flags, which we’ll apply now:

from utils import install_package, set_XLA_flags_gpu

set_XLA_flags_gpu()
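The exact flags that `set_XLA_flags_gpu` sets are defined in `utils.py`. As a rough illustration of the mechanism only, here is a hypothetical helper, `merge_xla_flags`, that merges extra flags into the `XLA_FLAGS` environment variable. XLA reads this variable when JAX initializes its backend, so it has to be set before the first JAX computation. Importing `jax` on its own is fine. The flags in the usage example are XLA’s dump options, chosen only because they are harmless to demonstrate. They are not the ones `utils.py` uses.

```python
import os


def merge_xla_flags(flags, env=None):
    """Hypothetical helper: merge flags into XLA_FLAGS (later flags win on name clashes).

    Must run before JAX initializes its backend, i.e. before the first JAX computation.
    """
    env = os.environ if env is None else env
    merged = {}
    for flag in env.get("XLA_FLAGS", "").split() + list(flags):
        name = flag.split("=", 1)[0]  # "--xla_dump_to=/tmp/x" -> "--xla_dump_to"
        merged[name] = flag  # dicts keep insertion order; re-setting a name updates its value
    env["XLA_FLAGS"] = " ".join(merged.values())
    return env["XLA_FLAGS"]


# Demonstrate on a plain dict so the real environment of this notebook is untouched.
demo_env = {"XLA_FLAGS": "--xla_dump_to=/tmp/old_dump"}
result = merge_xla_flags(["--xla_dump_to=/tmp/xla_dump", "--xla_dump_hlo_as_text"], env=demo_env)
print(result)  # --xla_dump_to=/tmp/xla_dump --xla_dump_hlo_as_text
```

Keying flags by name means that re-running the cell, or overriding a flag that was already set, does not produce duplicate or conflicting entries.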

Standard Libraries and Modules

Next, we’ll import the usual suspects:

import functools
from typing import Any, Dict, Tuple

import flax.linen as nn
import jax
import jax.numpy as jnp
import numpy as np
import optax
from tqdm.auto import tqdm

# Install ml_collections on Colab if not already available
try:
    from ml_collections import ConfigDict
except ModuleNotFoundError:
    install_package("ml_collections")
    from ml_collections import ConfigDict

# Type aliases for clarity
PyTree = Any
Metrics = Dict[str, Tuple[jax.Array, ...]]
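To make the two aliases concrete: a `Metrics` value is itself a PyTree, a dict of tuples of arrays, so metrics from several steps can be combined leaf by leaf with `jax.tree_util.tree_map`. The sketch below assumes that each entry stores a `(sum, count)` pair, which is a convenient convention for averaging. The metric names and numbers are made up for illustration.

```python
from typing import Any, Dict, Tuple

import jax
import jax.numpy as jnp

PyTree = Any
Metrics = Dict[str, Tuple[jax.Array, ...]]

# Hypothetical metrics from two minibatches of 4 examples: (sum over examples, count).
step1: Metrics = {"loss": (jnp.array(4.0), jnp.array(4)), "accuracy": (jnp.array(3.0), jnp.array(4))}
step2: Metrics = {"loss": (jnp.array(2.0), jnp.array(4)), "accuracy": (jnp.array(4.0), jnp.array(4))}

# A dict of tuples of arrays is a PyTree, so tree_map adds matching leaves together.
total: Metrics = jax.tree_util.tree_map(jnp.add, step1, step2)

# Convert the accumulated (sum, count) pairs into averages for logging.
averages = {name: float(s / c) for name, (s, c) in total.items()}
print(averages)  # {'loss': 0.75, 'accuracy': 0.875}
```

Storing sums and counts instead of running means keeps the aggregation exact, even when minibatches have different sizes.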

Bringing in Part 1.1 Functions

Finally, we import functions and modules implemented in Part 1.1. If anything here looks unfamiliar, be sure to revisit the earlier part before proceeding.

from single_gpu import Batch, TrainState, accumulate_gradients, print_metrics
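As a quick refresher on what gradient accumulation computes, here is a minimal, self-contained sketch. It is not the `accumulate_gradients` function imported above, which has its own signature. The names `loss_fn` and `accumulate_gradients_sketch` and the toy linear model are hypothetical. The sketch splits a batch into equal minibatches, sums the per-minibatch gradients, and divides by the number of minibatches. For a mean loss, this reproduces the full-batch gradient while only holding one minibatch’s activations at a time.

```python
import jax
import jax.numpy as jnp


def loss_fn(params, x, y):
    # Hypothetical toy model: linear regression with mean squared error.
    pred = x @ params["w"] + params["b"]
    return jnp.mean((pred - y) ** 2)


def accumulate_gradients_sketch(params, x, y, num_minibatches):
    """Average loss and gradients over equal-sized minibatches (illustrative only)."""
    batch_size = x.shape[0]
    assert batch_size % num_minibatches == 0, "batch must split evenly into minibatches"
    mb = batch_size // num_minibatches
    grad_fn = jax.value_and_grad(loss_fn)
    grads = jax.tree_util.tree_map(jnp.zeros_like, params)  # running gradient sum
    loss = 0.0
    for i in range(num_minibatches):
        sl = slice(i * mb, (i + 1) * mb)
        mb_loss, mb_grads = grad_fn(params, x[sl], y[sl])
        grads = jax.tree_util.tree_map(jnp.add, grads, mb_grads)
        loss += mb_loss
    # Equal minibatch sizes: averaging the minibatch means gives the full-batch mean.
    grads = jax.tree_util.tree_map(lambda g: g / num_minibatches, grads)
    return loss / num_minibatches, grads


# Usage: compare against the gradient of the whole batch computed in one go.
kx, ky, kw = jax.random.split(jax.random.PRNGKey(0), 3)
x = jax.random.normal(kx, (8, 3))
y = jax.random.normal(ky, (8,))
params = {"w": jax.random.normal(kw, (3,)), "b": jnp.zeros(())}

acc_loss, acc_grads = accumulate_gradients_sketch(params, x, y, num_minibatches=4)
full_loss, full_grads = jax.value_and_grad(loss_fn)(params, x, y)
print(float(acc_loss), float(full_loss))  # equal up to float32 rounding
```

Inside a jitted training step, a Python `for` loop like this gets unrolled into the compiled program. With many minibatches, `jax.lax.scan` is a common alternative that keeps compile times down.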