In the context of an array, axes refer to its dimensions:
A 1D array has 1 axis. For example:
```python
a = [0, 1, 2, 3]
```
Here, the array has a single axis (axis 0) of length 4.
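A quick way to check the number of axes is to ask the array itself. This sketch uses NumPy as a stand-in (an assumption; `jax.numpy` arrays report `ndim` and `shape` the same way):

```python
import numpy as np

a = np.array([0, 1, 2, 3])
print(a.ndim)   # 1: one axis
print(a.shape)  # (4,): axis 0 has length 4
print(a[2])     # 2: a single index selects an element
```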
A 2D array has 2 axes. For example:
```python
b = [
    [0, 1, 2],
    [3, 4, 5],
    [6, 7, 8]
]
```
Here, axis 0 runs down the rows and axis 1 runs across the columns; both have length 3.
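Reducing along one axis makes the distinction concrete: summing over axis 0 collapses the rows (one result per column), while summing over axis 1 collapses the columns (one result per row). A minimal NumPy sketch of the same 3 x 3 array:

```python
import numpy as np

b = np.array([
    [0, 1, 2],
    [3, 4, 5],
    [6, 7, 8],
])
print(b.shape)        # (3, 3): axis 0 = rows, axis 1 = columns
print(b.sum(axis=0))  # column sums 9, 12, 15 (axis 0 collapsed)
print(b.sum(axis=1))  # row sums 3, 12, 21 (axis 1 collapsed)
print(b[1, 2])        # 5: row 1 (axis 0), column 2 (axis 1)
```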
A 3D array has 3 axes. For example:
```python
c = [
    [[0, 1], [2, 3]],
    [[4, 5], [6, 7]]
]
```
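Here, the shape is (2, 2, 2): axis 0 selects one of the two 2 x 2 blocks, axis 1 selects a row within that block, and axis 2 selects an element within that row.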
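Indexing a 3D array takes one index per axis, applied from the outermost axis inward. A short NumPy sketch (again standing in for `jax.numpy`):

```python
import numpy as np

c = np.array([
    [[0, 1], [2, 3]],
    [[4, 5], [6, 7]],
])
print(c.ndim)      # 3
print(c.shape)     # (2, 2, 2)
# Block 1 (axis 0) -> row 0 (axis 1) -> element 1 (axis 2)
print(c[1, 0, 1])  # 5
```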
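How PartitionSpec Works
The PartitionSpec specifies how each axis of the array is distributed (sharded) across the axes of the mesh: entry k of the spec names the mesh axis that array axis k is split over, or None to leave that array axis unsplit.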
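To make the mapping concrete, here is a simplified pure-Python model of the rule. `shard_shape` below is a hypothetical helper written for illustration, not a JAX function (JAX's own sharding objects expose a `shard_shape` method, but this is not that API). It assumes even block sharding: an array axis split over mesh axes of total size n must be divisible by n, each entry may also be a tuple of mesh-axis names, and missing trailing entries mean "not split".

```python
from math import prod

def shard_shape(global_shape, spec, mesh_shape):
    """Hypothetical model: per-device block shape for an evenly divisible sharding.

    global_shape: array dimensions, e.g. (8,)
    spec: one entry per array axis: a mesh-axis name, a tuple of names, or None
    mesh_shape: dict mapping mesh-axis name -> number of devices along that axis
    """
    if len(spec) > len(global_shape):
        raise ValueError("spec has more entries than the array has axes")
    # Pad with None: trailing array axes that the spec omits are not split.
    spec = tuple(spec) + (None,) * (len(global_shape) - len(spec))
    out = []
    for dim, entry in zip(global_shape, spec):
        if entry is None:
            out.append(dim)  # not split: each device holds this axis in full
            continue
        names = entry if isinstance(entry, tuple) else (entry,)
        n = prod(mesh_shape[name] for name in names)  # devices this axis is split over
        if dim % n != 0:
            raise ValueError(f"axis of size {dim} is not divisible by {n} devices")
        out.append(dim // n)
    return tuple(out)

print(shard_shape((8,), ("i",), {"i": 8}))                         # (1,)
print(shard_shape((4, 8), ("x", "y"), {"x": 2, "y": 4}))           # (2, 2)
print(shard_shape((4, 8), ("x",), {"x": 2, "y": 4}))               # (2, 8)
print(shard_shape((4, 8), (None, ("x", "y")), {"x": 2, "y": 4}))   # (4, 1)
```

The model deliberately ignores uneven splits and device placement; it only answers "how big is each device's block".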
Let’s say you have a 1D array a with 8 elements and a mesh with a single axis named 'i' that spans 8 devices:
```python
a = [0, 1, 2, 3, 4, 5, 6, 7]
```
The sharding:
```python
sharding = NamedSharding(mesh, PartitionSpec('i',))  # Shard axis 0 over mesh axis 'i'
```
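The one-liner above assumes `mesh` already exists. A self-contained sketch follows; it assumes a CPU-only run in which XLA is asked for 8 virtual host devices through the `--xla_force_host_platform_device_count` flag, a common way to try sharding on a single machine. The flag must be set before JAX is imported.

```python
import os

# Assumption: simulate 8 CPU devices; must happen before `import jax`.
os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=8"

import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec

devices = np.array(jax.devices("cpu")[:8])  # 8 devices laid out along one axis
mesh = Mesh(devices, axis_names=("i",))

a = jnp.arange(8)  # 0, 1, ..., 7
sharding = NamedSharding(mesh, PartitionSpec("i"))  # shard axis 0 over mesh axis 'i'
a_sharded = jax.device_put(a, sharding)

print(sharding.shard_shape(a.shape))  # (1,): one element per device
for shard in a_sharded.addressable_shards:
    # Device holding the shard, the slice of the global array it covers, its data
    print(shard.device, shard.index, shard.data)
```

Printing `jax.debug.visualize_array_sharding(a_sharded)` is another way to see the same layout.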
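Here, PartitionSpec('i') maps axis 0 of a, its only axis, onto the 'i' mesh axis. The 8 elements are split evenly across the 8 devices, so each device holds one element.

Now let’s consider a 2D array b with shape (4, 8), that is, 4 rows and 8 columns: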
```python
b = [
    [0, 1, 2, 3, 4, 5, 6, 7],         # Row 0
    [8, 9, 10, 11, 12, 13, 14, 15],   # Row 1
    [16, 17, 18, 19, 20, 21, 22, 23], # Row 2
    [24, 25, 26, 27, 28, 29, 30, 31]  # Row 3
]
```
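The entry in row r, column c of this array is 8 * r + c, so the same values can be built without writing them out. A NumPy sketch (stand-in for `jax.numpy`, which has the same `arange` and `reshape`):

```python
import numpy as np

# Same values as the nested list above: entry (r, c) equals 8 * r + c.
b = np.arange(32).reshape(4, 8)
print(b.shape)  # (4, 8): axis 0 has 4 rows, axis 1 has 8 columns
print(b[2, 5])  # 21: row 2, column 5
```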