Let’s talk about JAX. Why? Because it’s insanely cool and incredibly fast. If you’re coming from frameworks like PyTorch or TensorFlow, you’re in for a treat—JAX offers a new perspective on building and training neural networks. In this tutorial, we’ll walk through the basics of JAX and even dip into using Flax to build and train your own neural nets. Along the way, we’ll uncover what makes JAX special and why you might want to add it to your deep learning toolkit.

Why Bother With JAX?

You’ve probably already asked yourself, “Why JAX? I’m happy with PyTorch or TensorFlow.” Great question. Let me hit you with the short answer: speed and efficiency. For example, when we train a small GoogLeNet model on CIFAR-10, JAX does it roughly three times faster than PyTorch on similar hardware. Yup, three times.

Now, hold up: before you get too hyped, let’s be clear that the speed boost varies. For larger models, bigger batch sizes, or less powerful GPUs, the difference might not be as dramatic. And no, the code we’ll use here isn’t designed for benchmarking, but it highlights the potential. JAX achieves this speed by leveraging just-in-time (JIT) compilation, which optimizes whole numerical programs for accelerators like GPUs and TPUs. PyTorch’s default eager execution, with its dynamic computation graphs, runs operations one by one as they are issued, which leaves far less room for this kind of whole-program optimization.
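
To give a first taste of what that looks like, here is a minimal sketch of compiling a function with jax.jit. We introduce JAX properly below; the toy function, array sizes, and timing code here are purely illustrative:

import time
import jax
import jax.numpy as jnp

def toy_computation(x):
    # A toy numerical function; jit compiles it as one program for the accelerator.
    return jnp.tanh(x @ x.T).sum()

toy_computation_jitted = jax.jit(toy_computation)

x = jnp.ones((1024, 1024))
toy_computation_jitted(x).block_until_ready()  # First call traces and compiles the function

start = time.time()
toy_computation_jitted(x).block_until_ready()  # Later calls reuse the compiled program
print(f"Compiled call took {time.time() - start:.5f}s")

Note that block_until_ready() matters here because JAX dispatches work asynchronously; without it, the timer would only measure how long it takes to enqueue the computation.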

Let’s take a real example. Imagine an Inception block from GoogLeNet. This block applies multiple convolutional layers to the same input in parallel. JAX can compile the entire forward pass for the accelerator, cleverly fusing operations to minimize memory access and maximize execution speed. PyTorch, on the other hand, handles one convolution at a time, sending each operation to the GPU sequentially. The result? JAX gets more out of your hardware.
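
To make the idea concrete, here is a small sketch of handing a whole multi-branch function to the compiler. This is not the GoogLeNet code we build later; the element-wise branches and shapes are stand-ins for the real convolutions, chosen only to show that jit sees and optimizes the function as a whole:

import jax
import jax.numpy as jnp

def multi_branch(x):
    # Three independent branches over the same input, loosely mimicking
    # the parallel paths of an Inception block.
    a = jnp.tanh(x)
    b = jax.nn.relu(x)
    c = jnp.sin(x) * 0.5
    # Concatenate the branch outputs along the feature axis.
    return jnp.concatenate([a, b, c], axis=-1)

# XLA compiles the whole function and is free to fuse the branches
# into fewer kernels instead of launching them one by one.
multi_branch_jitted = jax.jit(multi_branch)

out = multi_branch_jitted(jnp.ones((128, 64)))
print(out.shape)  # (128, 192)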

But There’s a Catch...

Of course, there’s no such thing as a free lunch. JAX’s speed comes with some trade-offs. To make its magic happen, JAX imposes a few rules:

  1. No Side Effects

    Functions in JAX need to be pure, meaning they can’t mess with anything outside their scope. For example, in-place operations that modify variables directly aren’t allowed. Even something as simple as torch.rand(...) would be off-limits, because it silently advances a global random state. Don’t worry: JAX has its own clever way of handling randomness via explicit PRNG keys (we’ll cover that soon, and the sketch after this list gives a first taste).

  2. Static Shapes

    JAX loves knowing what it’s dealing with upfront. It compiles functions based on the anticipated shapes of all arrays or tensors involved, which can lead to issues if shapes or control flow depend on the data itself. For instance, consider y = x[x > 3]: the shape of y depends on how many elements of x are greater than 3, so it can’t be known at compile time. JAX isn’t a fan of that kind of dynamic behavior. But don’t sweat it: most standard deep learning workflows fit neatly into JAX’s constraints, and the sketch after this list shows a jit-friendly workaround.
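
As a first illustration of both rules, the sketch below draws random numbers from an explicit PRNG key instead of a hidden global state, and replaces data-dependent boolean indexing with a shape-preserving jnp.where. The threshold and array size are arbitrary, and the PRNG machinery is covered in detail later:

import jax
import jax.numpy as jnp

# Rule 1: no hidden state. Randomness flows through an explicit PRNG key
# that you create and split yourself, so functions stay pure.
key = jax.random.PRNGKey(42)
key, subkey = jax.random.split(key)
x = jax.random.normal(subkey, (8,))  # The same subkey always gives the same numbers

# Rule 2: static shapes. x[x > 3] would produce an array whose size depends
# on the data, which jit cannot compile. jnp.where keeps the output shape
# equal to the input shape, so it is a jit-friendly alternative.
@jax.jit
def zero_out_small(x):
    return jnp.where(x > 3, x, 0.0)

print(zero_out_small(x))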

What’s Next?

Throughout this notebook, we’ll guide you through writing JAX code that adheres to these rules while building powerful and efficient neural networks.

By the end, you’ll see how JAX transforms the way you think about numerical computation and neural network training. Ready to dive in? Let’s go!

Throughout this tutorial, we’ll often draw comparisons to PyTorch (which we’ll also use for its data-loading library—check out our PyTorch tutorial for a refresher). The goal here isn’t to reinvent the wheel but to demonstrate how JAX can work alongside other frameworks. For example, you can use PyTorch for data loading and TensorFlow for logging in TensorBoard. Meanwhile, we’ll leverage Flax as our neural network library and Optax for optimizers. More on those later—let’s start with the basics.

Standard Libraries

First, let’s get the usual suspects out of the way:

import os
import math
import numpy as np
import time

# Plotting imports
import matplotlib.pyplot as plt
%matplotlib inline
from IPython.display import set_matplotlib_formats
set_matplotlib_formats('svg', 'pdf')  # For high-quality export
from matplotlib.colors import to_rgba
import seaborn as sns
sns.set()

# Progress bar
from tqdm.auto import tqdm

JAX as NumPy on Accelerators