Why bother with distributed computing?

Deep learning has been making serious strides, thanks to big datasets and even bigger models. But there's a catch: training these models is expensive—both in time and hardware. A single GPU often isn't enough. The solution? Parallelism. We split up the work across multiple devices to make training faster and more scalable. As datasets and models grow, this becomes less of a luxury and more of a necessity.

Before jumping into different parallelism strategies, let’s first get a handle on multi-device processing in JAX. This section lays the groundwork, and in Part 2.2, we’ll dive into data parallelism to train a small neural network across multiple devices. If you’re already comfortable with distributed computing in JAX, feel free to skip ahead.

It’s fine if you don’t have multiple GPUs, or even any GPU at all: JAX can easily simulate multiple CPU devices with the flag XLA_FLAGS=--xla_force_host_platform_device_count=8.

Setting things up

If you're using Google Colab, no need to switch to a GPU runtime—stick with the CPU, since we’ll be simulating multiple devices anyway. Running this locally? If you have multiple GPUs, set USE_CPU_ONLY=False to use them.

import os

# Set this to True to run the model on CPU only.
USE_CPU_ONLY = True

flags = os.environ.get("XLA_FLAGS", "")
if USE_CPU_ONLY:
    flags += " --xla_force_host_platform_device_count=8"  # Simulate 8 devices
    # Enforce CPU-only execution
    os.environ["CUDA_VISIBLE_DEVICES"] = ""
else:
    # GPU flags
    flags += (
        "--xla_gpu_enable_triton_softmax_fusion=true "
        "--xla_gpu_triton_gemm_any=false "
        "--xla_gpu_enable_async_collectives=true "
        "--xla_gpu_enable_latency_hiding_scheduler=true "
        "--xla_gpu_enable_highest_priority_async_stream=true "
    )
os.environ["XLA_FLAGS"] = flags

With our environment configured, let’s import the libraries we need:

import functools
from typing import Any, Dict, Tuple

import flax.linen as nn
import jax
import jax.numpy as jnp
import numpy as np
from jax.experimental.shard_map import shard_map
from jax.sharding import Mesh, NamedSharding
from jax.sharding import PartitionSpec as P

PyTree = Any
Metrics = Dict[str, Tuple[jax.Array, ...]]
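
One detail worth flagging: XLA only reads XLA_FLAGS when its backend is first initialized, so the environment cell above has to run before the first JAX call in the process. As a quick sanity check (a minimal sketch, assuming the cells of this section were run in order in a fresh process):

# With the host-platform flag applied, JAX should report 8 simulated CPU devices.
if USE_CPU_ONLY:
    n_devices = jax.device_count()
    assert n_devices == 8, f"Expected 8 simulated devices, got {n_devices}"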

Distributed Computing in JAX

Before we dive into advanced parallelism strategies, let’s break down the basics of distributed computing in JAX. This section is all about the building blocks—understanding how JAX handles multiple devices, sharding, and parallel execution. If you're already familiar with concepts like shard_map, feel free to skip ahead to the next part. Otherwise, let’s roll.

Checking Available Devices

JAX makes it easy to check what devices are available:

jax.devices()

If you've set XLA_FLAGS=--xla_force_host_platform_device_count=8, you should see 8 CPU devices listed:

[CpuDevice(id=0), CpuDevice(id=1), ..., CpuDevice(id=7)]
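
Besides listing the devices, you can count them: jax.device_count() returns the total number of devices across all processes, while jax.local_device_count() returns only those attached to the current process. In our single-process setup the two are the same:

print(jax.device_count())        # total devices across all processes (8 here)
print(jax.local_device_count())  # devices attached to this process (also 8 here)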

If you actually have GPUs or TPUs, JAX will detect them too. But since we’re in a tutorial setting, we’ll mostly focus on parallelizing within a single process/machine. The same concepts scale up to distributed setups across multiple hosts—JAX handles that, but we’ll keep things simple for now.
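
Just for reference (nothing in this tutorial needs it), a multi-host run is typically set up by having every process call jax.distributed.initialize() once at startup. The sketch below is illustrative only; the address and process counts are hypothetical placeholders:

MULTI_HOST = False  # flip only in a real multi-process launch
if MULTI_HOST:
    jax.distributed.initialize(
        coordinator_address="10.0.0.1:1234",  # hypothetical address of the coordinator process
        num_processes=2,                      # total number of processes in the job
        process_id=0,                         # this process's unique index
    )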

Where Does JAX Put Arrays?

By default, when you create an array, JAX assigns it to a single device: