There are two flavors of data parallelism: standard data parallelism and fully sharded data parallelism (FSDP)
training large deep learning models is hard, but parallelism makes it easier. this series will break down three main strategies: data parallelism, pipeline parallelism, and tensor parallelism. let’s start with data parallelism—it’s the simplest and most widely used.
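the idea: if you’ve got a large batch of data, split it into smaller chunks and give each chunk to a different device. every device holds a full copy of the model, runs the forward and backward pass on its own chunk, and computes local gradients. those gradients are then averaged across devices (an all-reduce), so every copy applies the same update and the replicas stay in sync. this approach is supported by most frameworks, like pytorch and tensorflow, which is why it’s the first one we’ll dive into.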
For example, if the batch size is 12 and you have 3 gpus, each device gets 4 consecutive examples along the batch axis (the remaining axes stay whole):
Device 1: [0:4, …, …]
Device 2: [4:8, …, …]
Device 3: [8:12, …, …]
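To make the split concrete, here’s a small NumPy sketch (a toy illustration, not code from this series) that simulates 3 devices on a linear-regression problem. it assumes a mean-squared-error loss and equal-sized shards; under those assumptions, averaging the per-device gradients reproduces the full-batch gradient exactly, which is why data parallelism can simply average gradients before the update. `mse_grad` is a hypothetical helper.

```python
import numpy as np

num_devices = 3
batch_size, num_features = 12, 5
rng = np.random.default_rng(0)
x = rng.normal(size=(batch_size, num_features))
y = rng.normal(size=(batch_size,))
w = rng.normal(size=(num_features,))  # every "device" holds a full copy of w

# Split along the batch axis: shard i (Device i+1 above) gets rows [4*i : 4*(i+1)].
per_device = batch_size // num_devices
x_shards = x.reshape(num_devices, per_device, num_features)
y_shards = y.reshape(num_devices, per_device)


def mse_grad(w, x, y):
    """Hypothetical helper: gradient of mean((x @ w - y) ** 2) with respect to w."""
    residual = x @ w - y
    return 2.0 * x.T @ residual / x.shape[0]


# Each simulated device computes a gradient on its own shard only.
local_grads = [mse_grad(w, xs, ys) for xs, ys in zip(x_shards, y_shards)]

# Averaging the local gradients stands in for the all-reduce (mean) across devices.
dp_grad = np.mean(local_grads, axis=0)

# Reference: the gradient a single device would compute on the whole batch.
full_grad = mse_grad(w, x, y)
```

Because each shard has the same size, the mean of the per-shard means equals the global mean; with uneven shards you would need a weighted average instead.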
later, we’ll get into pipeline and tensor parallelism, which focus on splitting up the model’s computations instead of just the data.
for now, the focus is on data parallelism in jax, but the concepts transfer well to other frameworks. building on the distributed computing basics from part 2.1, we’ll set up a basic data parallel strategy to train a small neural network across multiple devices. then, we’ll explore fully sharded data parallelism (fsdp), which reduces memory usage by splitting the model parameters across devices instead of replicating them on every device. think of it as the ZeRO optimizer, but made simple.
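A quick back-of-the-envelope calculation shows why fsdp matters. the sketch below is an illustration, not code from this series: it assumes fp32 values, an Adam-style optimizer with two state slots per parameter, and that parameters, gradients, and optimizer state are all sharded evenly (as in ZeRO stage 3). activations and the temporarily all-gathered parameters are ignored, and `per_device_bytes` is a hypothetical helper.

```python
import numpy as np


def per_device_bytes(num_params, num_devices, fsdp, bytes_per_value=4, optimizer_slots=2):
    """Hypothetical helper: rough per-device memory for params, grads and optimizer state.

    Assumes fp32 (4 bytes per value) and two optimizer slots per parameter
    (e.g. Adam's first and second moments). Activations are ignored.
    """
    values_per_param = 1 + 1 + optimizer_slots  # parameter + gradient + optimizer state
    total = num_params * values_per_param * bytes_per_value
    # Plain data parallelism replicates everything; fsdp keeps 1/num_devices of it.
    return total // num_devices if fsdp else total


num_params, num_devices = 1_000_000_000, 8
dp_gb = per_device_bytes(num_params, num_devices, fsdp=False) / 1e9
fsdp_gb = per_device_bytes(num_params, num_devices, fsdp=True) / 1e9
print(f"data parallel: {dp_gb:.1f} GB per device, FSDP: {fsdp_gb:.1f} GB per device")
# prints: data parallel: 16.0 GB per device, FSDP: 2.0 GB per device

# fsdp stores only a slice of each parameter tensor on every device ...
params = np.arange(16, dtype=np.float32)
shards = np.split(params, num_devices)  # 8 shards of 2 values each
# ... and all-gathers the full tensor right before it is needed in a layer.
gathered = np.concatenate(shards)
```

The trade-off is communication: sharded parameters must be gathered before use and gradients reduced back to their shards, so fsdp spends extra collectives to get a much smaller per-device footprint.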
stick around; this stuff is the foundation for scaling modern ai.

getting started with scaling: first, let’s set up the environment. this part might seem boring, but it’s critical to make sure everything works smoothly later. here’s the deal:
if you’re running this on colab, you’ll need to grab a couple of python scripts (single_gpu.py and utils.py) from github. if you’re running locally from a clone of the repository, the files are already there and the loop below skips the download.
import os
import urllib.request
from urllib.error import HTTPError
# Raw GitHub location of the helper scripts used in this series.
base_url = "https://raw.githubusercontent.com/phlippe/uvadlc_notebooks/master/docs/tutorial_notebooks/scaling/JAX/"
python_files = ["single_gpu.py", "utils.py"]
for file_name in python_files:
    # Only download files that are not already present locally.
    if not os.path.isfile(file_name):
        file_url = base_url + file_name
        print(f"Downloading {file_url}...")
        try:
            urllib.request.urlretrieve(file_url, file_name)
        except HTTPError as e:
            print("Download error. Try grabbing the file directly from GitHub, or contact the author with this error:", e)
next, we’re going to simulate 8 devices on a single cpu, so no fancy hardware is needed. colab users? no gpu required. running locally with multiple gpus? you can skip the simulation call and run on your real devices instead.
from utils import simulate_CPU_devices
simulate_CPU_devices()
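If you’re curious what the simulation does, here’s a minimal sketch under the assumption that the helper works by setting XLA’s `--xla_force_host_platform_device_count` flag, which makes the cpu backend expose several virtual devices. `simulate_cpu_devices_sketch` is a hypothetical stand-in, not the actual code in utils.py.

```python
import os


def simulate_cpu_devices_sketch(device_count: int = 8) -> None:
    """Hypothetical stand-in for utils.simulate_CPU_devices.

    Appends the XLA flag that splits the host CPU into `device_count` virtual
    devices. It only takes effect if set before JAX initializes its backends,
    so the safest place to call it is before importing jax.
    """
    flags = os.environ.get("XLA_FLAGS", "")
    flags += f" --xla_force_host_platform_device_count={device_count}"
    os.environ["XLA_FLAGS"] = flags.strip()


simulate_cpu_devices_sketch(8)
```

After this, importing jax and calling `len(jax.devices())` should report 8 cpu devices. that ordering is why the call sits above the library imports.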
once the environment is good to go, let’s bring in the libraries.
import functools
from pprint import pprint
from typing import Any, Callable, Dict, Sequence, Tuple
import flax.linen as nn
import jax
import jax.numpy as jnp
import numpy as np
import optax
from absl import logging
from jax import lax
from jax.experimental.shard_map import shard_map
from jax.sharding import Mesh
from jax.sharding import PartitionSpec as P
from ml_collections import ConfigDict
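With these imports, the core data-parallel pattern fits in a few lines. the sketch below is a toy example, not the training step we’ll build: it creates a 1D `Mesh` over all available devices with an assumed axis name `"data"`, shards a batch along its first dimension with `P("data")`, and uses `lax.pmean` inside `shard_map` to average per-device results, the same collective a data-parallel step uses to average gradients. it tries the newer `jax.shard_map` first and falls back to the `jax.experimental.shard_map` import used above on older jax versions.

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax import lax
from jax.sharding import Mesh, PartitionSpec as P

try:
    from jax import shard_map  # newer JAX releases
except ImportError:
    from jax.experimental.shard_map import shard_map  # older releases, as imported above

devices = np.array(jax.devices())
num_devices = devices.size
mesh = Mesh(devices, axis_names=("data",))  # "data" is an assumed axis name


def per_device_mean(x):
    # x is this device's slice of the batch (4 rows here).
    local = jnp.mean(x, axis=0)
    # Average across all devices on the "data" axis (an all-reduce mean).
    return lax.pmean(local, axis_name="data")


dp_mean_fn = jax.jit(
    shard_map(
        per_device_mean,
        mesh=mesh,
        in_specs=P("data"),  # split the batch axis across devices
        out_specs=P(),       # result is replicated on every device
    )
)

# 4 examples per device, 3 features each.
global_batch = jnp.asarray(
    np.arange(num_devices * 4 * 3, dtype=np.float32).reshape(num_devices * 4, 3)
)
dp_mean = dp_mean_fn(global_batch)
```

Since every device receives the same number of rows, the mean of the per-device means matches the mean over the whole batch; swapping the toy mean for a loss gradient gives the basic data-parallel update.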
also, we’ll reuse some utilities from earlier parts—no need to reinvent the wheel.
from single_gpu import Batch, TrainState, accumulate_gradients, print_metrics
one last thing: random number generation. in part 2.1, we built a utility that folds the rng key over one or more mesh axes, so each device along those axes gets a different random key. here’s the function for reference: