Skip to content

Commit 5b17410

Browse files
committed
Build standard mesh for neuron backend
- Switch to a new mesh for neuron-(trn2|trn2n).48xlarge-64 with better scale-out performance.
1 parent 162fdd0 commit 5b17410

File tree

11 files changed

+94
-21
lines changed

11 files changed

+94
-21
lines changed

axlearn/common/utils.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1743,13 +1743,15 @@ def create_device_mesh(
17431743
assert num_devices % num_granules == 0, "Number of devices should divide number of granules."
17441744
num_devices_per_granule = num_devices // num_granules
17451745

1746-
# Fallback to a standard mesh if on GPU with incompatible multi-granule mesh.
1746+
# Fallback to a standard mesh if on GPU or neuron with incompatible multi-granule mesh.
17471747
if (
1748-
device_platform == "gpu"
1748+
device_platform in ("gpu", "neuron")
17491749
and isinstance(mesh_shape, MeshShape)
17501750
and mesh_shape[0] % num_granules != 0
17511751
):
1752-
logging.warning("Falling back to ICI-only mesh on GPU, performance may be reduced.")
1752+
logging.warning(
1753+
"Falling back to ICI-only mesh on %s, performance may be reduced.", device_platform
1754+
)
17531755
return build_standard_mesh(mesh_shape, devices=devices)
17541756

17551757
# Canonicalize to HybridMeshShape. If DCN mesh is not specified, break the first non-singleton

axlearn/common/utils_test.py

Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1858,6 +1858,76 @@ def test_create_device_mesh_gpu(
18581858
device_mesh = create_device_mesh(mesh_shape=logical_mesh, devices=devices)
18591859
self.assertEqual(expected or logical_mesh, device_mesh.shape)
18601860

1861+
@parameterized.parameters(
1862+
{"logical_mesh": (16, 4, 8)},
1863+
{"logical_mesh": (64, 8)},
1864+
# Test fallback to standard mesh.
1865+
{"logical_mesh": (16, 32)},
1866+
# Test a case where we infer -1 in ICI mesh.
1867+
{"logical_mesh": (8, -1, 4), "expected": (8, 16, 4)},
1868+
# Test a case where we infer -1 in DCN mesh.
1869+
{"logical_mesh": (-1, 16, 4), "expected": (8, 16, 4)},
1870+
# Test a basic hybrid mesh case.
1871+
{
1872+
"logical_mesh": HybridMeshShape(ici_mesh_shape=(1, 16, 4), dcn_mesh_shape=(8, 1, 1)),
1873+
"expected": (8, 16, 4),
1874+
},
1875+
# If expressed as a hybrid mesh, fail if DCN mesh is invalid rather than using fallback.
1876+
{
1877+
"logical_mesh": HybridMeshShape(ici_mesh_shape=(4, 16, 4), dcn_mesh_shape=(2, 1, 1)),
1878+
"expected": ValueError("DCN mesh"),
1879+
},
1880+
# Test that ICI mesh should respect the number of devices.
1881+
{
1882+
"logical_mesh": HybridMeshShape(ici_mesh_shape=(8, 1, 16), dcn_mesh_shape=(2, -1, 1)),
1883+
"expected": ValueError("Product of ICI"),
1884+
},
1885+
# Test that DCN mesh should respect the number of slices.
1886+
{
1887+
"logical_mesh": HybridMeshShape(ici_mesh_shape=(4, 1, 16), dcn_mesh_shape=(2, 2, 1)),
1888+
"expected": ValueError("Product of DCN"),
1889+
},
1890+
# Test a case where we infer -1 in ICI mesh.
1891+
{
1892+
"logical_mesh": HybridMeshShape(ici_mesh_shape=(1, 2, -1), dcn_mesh_shape=(8, 1, 1)),
1893+
"expected": (8, 2, 32),
1894+
},
1895+
# Test a case where we infer -1 in DCN mesh.
1896+
{
1897+
"logical_mesh": HybridMeshShape(ici_mesh_shape=(1, 4, 16), dcn_mesh_shape=(-1, 1, 1)),
1898+
"expected": (8, 4, 16),
1899+
},
1900+
# Test a case where we infer -1 in both ICI and DCN mesh.
1901+
{
1902+
"logical_mesh": HybridMeshShape(ici_mesh_shape=(1, -1, 16), dcn_mesh_shape=(-1, 1, 1)),
1903+
"expected": (8, 4, 16),
1904+
},
1905+
)
1906+
def test_create_device_mesh_neuron(
1907+
self,
1908+
logical_mesh: Union[MeshShape, HybridMeshShape],
1909+
expected: Optional[Union[MeshShape, Exception]] = None,
1910+
):
1911+
num_devices_per_process = 64
1912+
num_granules = 8
1913+
devices = [
1914+
DummyDevice(
1915+
platform="neuron",
1916+
device_kind="NC_v3d",
1917+
process_index=(num_devices_per_process * granule_index + ix)
1918+
// num_devices_per_process,
1919+
)
1920+
for ix in range(num_devices_per_process)
1921+
for granule_index in range(num_granules)
1922+
]
1923+
if isinstance(expected, Exception):
1924+
with self.assertRaisesRegex(type(expected), str(expected)):
1925+
create_device_mesh(mesh_shape=logical_mesh, devices=devices)
1926+
else:
1927+
# Check that the constructed mesh has the expected shape.
1928+
device_mesh = create_device_mesh(mesh_shape=logical_mesh, devices=devices)
1929+
self.assertEqual(expected or logical_mesh, device_mesh.shape)
1930+
18611931

18621932
class InferMeshShapeTest(TestCase):
18631933
"""Tests infer_mesh_shape."""

axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v1-flash.txt

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -225,9 +225,9 @@ mesh_rules[6][1].klass: 'axlearn.common.trainer_config_modifier.ChainConfigModif
225225
mesh_rules[7][0]: 'neuron-(trn2|trn2n).48xlarge-64'
226226
mesh_rules[7][1].config_modifiers[0].klass: 'axlearn.common.trainer_config_modifier.MeshShapeModifier'
227227
mesh_rules[7][1].config_modifiers[0].mesh_shape[0]: 1
228-
mesh_rules[7][1].config_modifiers[0].mesh_shape[1]: 1
228+
mesh_rules[7][1].config_modifiers[0].mesh_shape[1]: -1
229229
mesh_rules[7][1].config_modifiers[0].mesh_shape[2]: 1
230-
mesh_rules[7][1].config_modifiers[0].mesh_shape[3]: -1
230+
mesh_rules[7][1].config_modifiers[0].mesh_shape[3]: 128
231231
mesh_rules[7][1].config_modifiers[0].mesh_shape[4]: 1
232232
mesh_rules[7][1].config_modifiers[0].mesh_shape[5]: 4
233233
mesh_rules[7][1].config_modifiers[1].klass: 'axlearn.common.trainer_config_modifier.RematSpecModifier'

axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v1.txt

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -225,9 +225,9 @@ mesh_rules[6][1].klass: 'axlearn.common.trainer_config_modifier.ChainConfigModif
225225
mesh_rules[7][0]: 'neuron-(trn2|trn2n).48xlarge-64'
226226
mesh_rules[7][1].config_modifiers[0].klass: 'axlearn.common.trainer_config_modifier.MeshShapeModifier'
227227
mesh_rules[7][1].config_modifiers[0].mesh_shape[0]: 1
228-
mesh_rules[7][1].config_modifiers[0].mesh_shape[1]: 1
228+
mesh_rules[7][1].config_modifiers[0].mesh_shape[1]: -1
229229
mesh_rules[7][1].config_modifiers[0].mesh_shape[2]: 1
230-
mesh_rules[7][1].config_modifiers[0].mesh_shape[3]: -1
230+
mesh_rules[7][1].config_modifiers[0].mesh_shape[3]: 128
231231
mesh_rules[7][1].config_modifiers[0].mesh_shape[4]: 1
232232
mesh_rules[7][1].config_modifiers[0].mesh_shape[5]: 4
233233
mesh_rules[7][1].config_modifiers[1].klass: 'axlearn.common.trainer_config_modifier.RematSpecModifier'

axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v2-flash.txt

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -225,9 +225,9 @@ mesh_rules[6][1].klass: 'axlearn.common.trainer_config_modifier.ChainConfigModif
225225
mesh_rules[7][0]: 'neuron-(trn2|trn2n).48xlarge-64'
226226
mesh_rules[7][1].config_modifiers[0].klass: 'axlearn.common.trainer_config_modifier.MeshShapeModifier'
227227
mesh_rules[7][1].config_modifiers[0].mesh_shape[0]: 1
228-
mesh_rules[7][1].config_modifiers[0].mesh_shape[1]: 1
228+
mesh_rules[7][1].config_modifiers[0].mesh_shape[1]: -1
229229
mesh_rules[7][1].config_modifiers[0].mesh_shape[2]: 1
230-
mesh_rules[7][1].config_modifiers[0].mesh_shape[3]: -1
230+
mesh_rules[7][1].config_modifiers[0].mesh_shape[3]: 128
231231
mesh_rules[7][1].config_modifiers[0].mesh_shape[4]: 1
232232
mesh_rules[7][1].config_modifiers[0].mesh_shape[5]: 4
233233
mesh_rules[7][1].config_modifiers[1].klass: 'axlearn.common.trainer_config_modifier.RematSpecModifier'

axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v2.txt

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -225,9 +225,9 @@ mesh_rules[6][1].klass: 'axlearn.common.trainer_config_modifier.ChainConfigModif
225225
mesh_rules[7][0]: 'neuron-(trn2|trn2n).48xlarge-64'
226226
mesh_rules[7][1].config_modifiers[0].klass: 'axlearn.common.trainer_config_modifier.MeshShapeModifier'
227227
mesh_rules[7][1].config_modifiers[0].mesh_shape[0]: 1
228-
mesh_rules[7][1].config_modifiers[0].mesh_shape[1]: 1
228+
mesh_rules[7][1].config_modifiers[0].mesh_shape[1]: -1
229229
mesh_rules[7][1].config_modifiers[0].mesh_shape[2]: 1
230-
mesh_rules[7][1].config_modifiers[0].mesh_shape[3]: -1
230+
mesh_rules[7][1].config_modifiers[0].mesh_shape[3]: 128
231231
mesh_rules[7][1].config_modifiers[0].mesh_shape[4]: 1
232232
mesh_rules[7][1].config_modifiers[0].mesh_shape[5]: 4
233233
mesh_rules[7][1].config_modifiers[1].klass: 'axlearn.common.trainer_config_modifier.RematSpecModifier'

axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v3-flash.txt

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -225,9 +225,9 @@ mesh_rules[6][1].klass: 'axlearn.common.trainer_config_modifier.ChainConfigModif
225225
mesh_rules[7][0]: 'neuron-(trn2|trn2n).48xlarge-64'
226226
mesh_rules[7][1].config_modifiers[0].klass: 'axlearn.common.trainer_config_modifier.MeshShapeModifier'
227227
mesh_rules[7][1].config_modifiers[0].mesh_shape[0]: 1
228-
mesh_rules[7][1].config_modifiers[0].mesh_shape[1]: 1
228+
mesh_rules[7][1].config_modifiers[0].mesh_shape[1]: -1
229229
mesh_rules[7][1].config_modifiers[0].mesh_shape[2]: 1
230-
mesh_rules[7][1].config_modifiers[0].mesh_shape[3]: -1
230+
mesh_rules[7][1].config_modifiers[0].mesh_shape[3]: 128
231231
mesh_rules[7][1].config_modifiers[0].mesh_shape[4]: 1
232232
mesh_rules[7][1].config_modifiers[0].mesh_shape[5]: 4
233233
mesh_rules[7][1].config_modifiers[1].klass: 'axlearn.common.trainer_config_modifier.RematSpecModifier'

axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v3-tiktoken-flash.txt

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -225,9 +225,9 @@ mesh_rules[6][1].klass: 'axlearn.common.trainer_config_modifier.ChainConfigModif
225225
mesh_rules[7][0]: 'neuron-(trn2|trn2n).48xlarge-64'
226226
mesh_rules[7][1].config_modifiers[0].klass: 'axlearn.common.trainer_config_modifier.MeshShapeModifier'
227227
mesh_rules[7][1].config_modifiers[0].mesh_shape[0]: 1
228-
mesh_rules[7][1].config_modifiers[0].mesh_shape[1]: 1
228+
mesh_rules[7][1].config_modifiers[0].mesh_shape[1]: -1
229229
mesh_rules[7][1].config_modifiers[0].mesh_shape[2]: 1
230-
mesh_rules[7][1].config_modifiers[0].mesh_shape[3]: -1
230+
mesh_rules[7][1].config_modifiers[0].mesh_shape[3]: 128
231231
mesh_rules[7][1].config_modifiers[0].mesh_shape[4]: 1
232232
mesh_rules[7][1].config_modifiers[0].mesh_shape[5]: 4
233233
mesh_rules[7][1].config_modifiers[1].klass: 'axlearn.common.trainer_config_modifier.RematSpecModifier'

axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v3-tiktoken.txt

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -225,9 +225,9 @@ mesh_rules[6][1].klass: 'axlearn.common.trainer_config_modifier.ChainConfigModif
225225
mesh_rules[7][0]: 'neuron-(trn2|trn2n).48xlarge-64'
226226
mesh_rules[7][1].config_modifiers[0].klass: 'axlearn.common.trainer_config_modifier.MeshShapeModifier'
227227
mesh_rules[7][1].config_modifiers[0].mesh_shape[0]: 1
228-
mesh_rules[7][1].config_modifiers[0].mesh_shape[1]: 1
228+
mesh_rules[7][1].config_modifiers[0].mesh_shape[1]: -1
229229
mesh_rules[7][1].config_modifiers[0].mesh_shape[2]: 1
230-
mesh_rules[7][1].config_modifiers[0].mesh_shape[3]: -1
230+
mesh_rules[7][1].config_modifiers[0].mesh_shape[3]: 128
231231
mesh_rules[7][1].config_modifiers[0].mesh_shape[4]: 1
232232
mesh_rules[7][1].config_modifiers[0].mesh_shape[5]: 4
233233
mesh_rules[7][1].config_modifiers[1].klass: 'axlearn.common.trainer_config_modifier.RematSpecModifier'

axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v3.txt

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -225,9 +225,9 @@ mesh_rules[6][1].klass: 'axlearn.common.trainer_config_modifier.ChainConfigModif
225225
mesh_rules[7][0]: 'neuron-(trn2|trn2n).48xlarge-64'
226226
mesh_rules[7][1].config_modifiers[0].klass: 'axlearn.common.trainer_config_modifier.MeshShapeModifier'
227227
mesh_rules[7][1].config_modifiers[0].mesh_shape[0]: 1
228-
mesh_rules[7][1].config_modifiers[0].mesh_shape[1]: 1
228+
mesh_rules[7][1].config_modifiers[0].mesh_shape[1]: -1
229229
mesh_rules[7][1].config_modifiers[0].mesh_shape[2]: 1
230-
mesh_rules[7][1].config_modifiers[0].mesh_shape[3]: -1
230+
mesh_rules[7][1].config_modifiers[0].mesh_shape[3]: 128
231231
mesh_rules[7][1].config_modifiers[0].mesh_shape[4]: 1
232232
mesh_rules[7][1].config_modifiers[0].mesh_shape[5]: 4
233233
mesh_rules[7][1].config_modifiers[1].klass: 'axlearn.common.trainer_config_modifier.RematSpecModifier'

axlearn/experiments/text/gpt/fuji.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -796,9 +796,10 @@ def get_trainer_kwargs(
796796
ChainConfigModifier.default_config().set(
797797
config_modifiers=[
798798
MeshShapeModifier.default_config().set(
799-
# TP within the chip, FSDP across chips.
799+
# TP within the chip, FSDP across 8 nodes and Data parallel
800+
# replication across replicas.
800801
# Each TRN2 chip has 4 XLA cores.
801-
mesh_shape=mesh_shape_from_axes(fsdp=-1, model=4)
802+
mesh_shape=mesh_shape_from_axes(data=-1, fsdp=128, model=4)
802803
),
803804
RematSpecModifier.default_config().set(
804805
remat_policies={

0 commit comments

Comments
 (0)