Understanding Axes in Arrays

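In the context of an array, axes refer to its dimensions. A 1D array has a single axis (axis 0); a 2D array has two: axis 0 indexes the rows and axis 1 indexes the columns.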
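A quick way to check how many axes an array has, assuming JAX is installed (NumPy arrays expose the same attributes). The arrays here mirror the 1D and 2D examples used further down:

```python
import jax.numpy as jnp

a = jnp.arange(8)                  # 1D array: a single axis (axis 0)
b = jnp.arange(32).reshape(4, 8)   # 2D array: axis 0 = rows, axis 1 = columns

print(a.ndim, a.shape)  # 1 (8,)
print(b.ndim, b.shape)  # 2 (4, 8)

# Reducing along an axis removes that axis from the result.
print(b.sum(axis=0).shape)  # (8,)  rows collapsed, one value per column
print(b.sum(axis=1).shape)  # (4,)  columns collapsed, one value per row
```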


How PartitionSpec Works

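The PartitionSpec specifies how each axis of the array will be distributed (sharded) across the axes of the mesh. Its entries are positional: entry k names the mesh axis over which array axis k is split, and None leaves that array axis unsharded.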
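As a rough illustration of that mapping, the pure-Python sketch below computes the shape of the piece each device would hold. The helper per_device_shape and the 2x4 mesh with axes 'x' and 'y' are hypothetical, introduced only for this example; they are not part of JAX. The sketch also assumes every sharded dimension divides evenly by the mesh axis size.

```python
def per_device_shape(array_shape, spec, mesh_axis_sizes):
    """Shape of one device's shard (hypothetical helper, not a JAX API).

    spec has one entry per array axis: a mesh axis name, or None for
    "not sharded". Array axes beyond len(spec) are treated as None.
    """
    padded = tuple(spec) + (None,) * (len(array_shape) - len(spec))
    shard = []
    for dim, mesh_axis in zip(array_shape, padded):
        parts = 1 if mesh_axis is None else mesh_axis_sizes[mesh_axis]
        if dim % parts != 0:
            raise ValueError(f"dimension {dim} is not divisible by {parts}")
        shard.append(dim // parts)
    return tuple(shard)


# 1D array of 8 elements, mesh axis 'i' with 8 devices
print(per_device_shape((8,), ("i",), {"i": 8}))                    # (1,)

# 2D array of shape (4, 8) on an assumed 2x4 mesh with axes 'x' and 'y'
print(per_device_shape((4, 8), ("x", "y"), {"x": 2, "y": 4}))      # (2, 2)
print(per_device_shape((4, 8), (None, "y"), {"x": 2, "y": 4}))     # (4, 2)
print(per_device_shape((4, 8), ("x",), {"x": 2, "y": 4}))          # (2, 8)
```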

Example: 1D Array

Let’s say you have a 1D array a with 8 elements and a mesh with a single axis named 'i' that spans 8 devices:

```python
a = [0, 1, 2, 3, 4, 5, 6, 7]
```

The sharding:

```python
from jax.sharding import NamedSharding, PartitionSpec
sharding = NamedSharding(mesh, PartitionSpec('i'))  # Shard axis 0 over mesh axis 'i'
```

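Here, axis 0 of a is split across mesh axis 'i'. With 8 elements and 8 devices, each device holds exactly one element.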
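A runnable sketch of this 1D case, under one loud assumption: the XLA_FLAGS line asks XLA to emulate 8 CPU devices so the example can run on a single machine, and it must be set before JAX is imported. On real hardware with 8 accelerators you would drop that line and use jax.devices() instead.

```python
import os

# Assumption: emulate 8 CPU devices on one host. Must run before importing jax.
os.environ["XLA_FLAGS"] = (
    os.environ.get("XLA_FLAGS", "") + " --xla_force_host_platform_device_count=8"
)

import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec

# A 1D mesh whose single axis is named 'i'.
devices = np.array(jax.devices("cpu")[:8])
mesh = Mesh(devices, axis_names=("i",))

a = jnp.arange(8)
sharding = NamedSharding(mesh, PartitionSpec("i"))  # shard axis 0 over mesh axis 'i'
a_sharded = jax.device_put(a, sharding)

# Each shard reports its device, the slice of the global array it covers,
# and the local data it holds.
for shard in a_sharded.addressable_shards:
    print(shard.device, shard.index, shard.data)
```

If the 8 emulated devices are available, each printed shard covers a single index of a. The sharded array still behaves like the full array of shape (8,) from the caller's point of view.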


Example: 2D Array

Now let’s consider a 2D array b with shape (4, 8) (4 rows, 8 columns):

```python
b = [
    [0, 1, 2, 3, 4, 5, 6, 7],          # Row 0
    [8, 9, 10, 11, 12, 13, 14, 15],    # Row 1
    [16, 17, 18, 19, 20, 21, 22, 23],  # Row 2
    [24, 25, 26, 27, 28, 29, 30, 31],  # Row 3
]
```