With JAX basics under your belt, it’s time to build a neural network! While you could construct everything from scratch in JAX (as shown in the JAX documentation), why reinvent the wheel? JAX-based libraries like Flax make creating neural networks much simpler and more intuitive. Think of Flax as JAX’s equivalent of PyTorch’s torch.nn.
There’s no shortage of great JAX-based libraries for building neural networks; notable ones include Flax (from Google), Haiku (from DeepMind), and Equinox.
For this tutorial, we’ll use Flax due to its intuitive API, flexibility, and vibrant community. That said, every library has its strengths, so feel free to explore and find your favorite!
To demonstrate Flax in action, we’ll build a classifier for a classic problem: XOR. Given two binary inputs $x_1$ and $x_2$, the label is

$$y = \begin{cases} 1 & \text{if exactly one of } x_1 \text{ and } x_2 \text{ is } 1, \\ 0 & \text{otherwise.} \end{cases}$$

Why XOR? A single linear neuron can’t solve this problem, making it a great example for showcasing a small neural network. To spice things up, we’ll move XOR into continuous space and add Gaussian noise. Here’s what the dataset might look like:
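As a concrete sketch of how such a dataset could be generated (the helper name `generate_xor`, the sample count, and the noise level are illustrative choices, not fixed by anything above):

```python
import jax
import jax.numpy as jnp

def generate_xor(key, num_samples=200, noise_std=0.1):
    key_inp, key_noise = jax.random.split(key)
    # Sample binary inputs uniformly from {0, 1}^2
    x = jax.random.randint(key_inp, (num_samples, 2), minval=0, maxval=2).astype(jnp.float32)
    # The label is 1 iff exactly one of the two inputs is 1 (XOR)
    y = (x.sum(axis=1) == 1).astype(jnp.int32)
    # Move into continuous space by perturbing the inputs with Gaussian noise
    x = x + noise_std * jax.random.normal(key_noise, x.shape)
    return x, y

x, y = generate_xor(jax.random.PRNGKey(42))
```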
In Flax, networks are defined using the nn.Module class. Think of a module as a building block: it can represent a single layer, a larger sub-network, or even the entire model. Here’s a basic template:
```python
from flax import linen as nn

class MyModule(nn.Module):
    # Attributes like hidden dimensions or layer counts go here
    varname: int

    def setup(self):
        # Define submodules or layers (lazy initialization)
        pass

    def __call__(self, x):
        # Define the forward pass
        pass
```
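Putting the template to work, here is a minimal sketch of a module for our XOR classifier (the name `SimpleClassifier`, the layer sizes, and the tanh activation are illustrative assumptions rather than prescribed choices):

```python
import jax
import jax.numpy as jnp
from flax import linen as nn

class SimpleClassifier(nn.Module):
    num_hidden: int   # width of the hidden layer
    num_outputs: int  # number of output logits

    def setup(self):
        # Two fully connected layers; sizes come from the attributes above
        self.linear1 = nn.Dense(features=self.num_hidden)
        self.linear2 = nn.Dense(features=self.num_outputs)

    def __call__(self, x):
        x = self.linear1(x)
        x = nn.tanh(x)  # nonlinearity, so the network can learn XOR
        return self.linear2(x)

model = SimpleClassifier(num_hidden=8, num_outputs=1)
# Flax modules are stateless: parameters are created separately,
# from a PRNG key and a sample input of the right shape
params = model.init(jax.random.PRNGKey(0), jnp.ones((1, 2)))
out = model.apply(params, jnp.array([[0.9, 0.1]]))  # forward pass with explicit params
```

Note the separation of concerns: the module only describes the computation, while the parameters live outside it and are passed explicitly to model.apply on every call.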