With JAX basics under your belt, it’s time to build a neural network! While you could construct everything from scratch in JAX (as shown in the JAX documentation), why reinvent the wheel? JAX-based libraries like Flax make creating neural networks much simpler and more intuitive. Think of Flax as JAX’s equivalent of PyTorch’s torch.nn.

Why Flax?

There’s no shortage of great JAX-based libraries for building neural networks, and Flax is only one of several options.

For this tutorial, we’ll use Flax due to its intuitive API, flexibility, and vibrant community. That said, every library has its strengths, so feel free to explore and find your favorite!

The XOR Problem: A Simple Challenge

To demonstrate Flax in action, we’ll build a classifier for a classic problem: XOR. Given two binary inputs $x_1$ and $x_2$, the label is $y = x_1 \oplus x_2$. It is 1 if exactly one of the inputs is 1, and 0 otherwise.
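Laid out as a table, the four input combinations also show why a single linear neuron cannot reproduce XOR. The following is a short sketch of the standard argument, assuming the neuron predicts $y = 1$ exactly when $w_1 x_1 + w_2 x_2 + b > 0$:

$$
\begin{array}{cc|c}
x_1 & x_2 & y \\ \hline
0 & 0 & 0 \\
0 & 1 & 1 \\
1 & 0 & 1 \\
1 & 1 & 0
\end{array}
\qquad\Longrightarrow\qquad
\begin{aligned}
(0,0):&\quad b \le 0 \\
(0,1):&\quad w_2 + b > 0 \\
(1,0):&\quad w_1 + b > 0 \\
(1,1):&\quad w_1 + w_2 + b \le 0
\end{aligned}
$$

Adding the two middle conditions gives $w_1 + w_2 + 2b > 0$, that is, $w_1 + w_2 + b > -b \ge 0$. This contradicts the last condition, so no choice of $w_1, w_2, b$ classifies all four points correctly.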

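Why XOR? Its two classes are not linearly separable. No straight line in the $(x_1, x_2)$ plane puts $(0,1)$ and $(1,0)$ on one side and $(0,0)$ and $(1,1)$ on the other. A single linear neuron therefore can’t solve it, which makes XOR a great minimal example for showcasing a small neural network with a hidden layer. To spice things up, we’ll move XOR into continuous space and add Gaussian noise to the inputs. Here’s what the dataset might look like: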

[Figure: sample of the continuous XOR dataset with Gaussian noise]
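The text doesn't show how the noisy dataset is generated, so here is one possible sketch using `jax.random`. The helper `make_continuous_xor` and its defaults (200 points, noise standard deviation 0.1) are hypothetical choices. Each point picks a random corner of the unit square and takes that corner's XOR label. It is then perturbed with Gaussian noise.

```python
import jax
import jax.numpy as jnp

def make_continuous_xor(key, size=200, std=0.1):
    """Hypothetical helper: sample a noisy, continuous XOR dataset.

    Each point starts at a corner (x1, x2 in {0, 1}), gets that corner's
    XOR label, and is then shifted by Gaussian noise with the given std.
    """
    corner_key, noise_key = jax.random.split(key)
    # Random binary corners; maxval is exclusive, so values are 0 or 1.
    corners = jax.random.randint(corner_key, (size, 2), minval=0, maxval=2)
    # Label is 1 iff exactly one coordinate is 1 (i.e. x1 XOR x2).
    labels = (corners.sum(axis=1) == 1).astype(jnp.int32)
    data = corners.astype(jnp.float32)
    data = data + std * jax.random.normal(noise_key, data.shape)
    return data, labels

data, labels = make_continuous_xor(jax.random.PRNGKey(42), size=200, std=0.1)
```

Setting `std=0.0` recovers the exact binary XOR points, which is a handy sanity check before adding noise.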


Defining Neural Networks in Flax

Flax Basics

In Flax, networks are defined using the nn.Module class. Think of a module as a building block: it can represent a single layer, a larger sub-network, or even the entire model. Here’s a basic template:

```python
from flax import linen as nn

class MyModule(nn.Module):
    # Attributes like hidden dimensions or layer counts go here
    varname: int

    def setup(self):
        # Define submodules or layers (lazy initialization)
        pass

    def __call__(self, x):
        # Define the forward pass
        pass
```