Skip to content

Commit 2192b18

Browse files
JoelWeecopybara-github
authored andcommitted
Introduce types.py helper
PiperOrigin-RevId: 801748248
1 parent b54fab5 commit 2192b18

4 files changed

Lines changed: 727 additions & 21 deletions

File tree

shardy/integrations/python/jax/mpmd/BUILD

Lines changed: 0 additions & 21 deletions
This file was deleted.
Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
# Copyright 2025 The MPMD Authors.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
# ==============================================================================
15+
16+
"""Defines options for MPMD partitioning."""
17+
18+
MPMD_BOOLEAN_OPTIONS = frozenset({
19+
'mpmd_infer_transfers',
20+
'mpmd_infer_cross_mesh_reductions',
21+
'mpmd_merge_inferred_with_cloning_during_import',
22+
'mpmd_gspmd_propagate_sharding_across_meshes',
23+
'mpmd_allow_intra_mesh_transfer',
24+
'mpmd_fragment_remat',
25+
'mpmd_merge_remat_fragments',
26+
'mpmd_split_bwd_fragments',
27+
'mpmd_assume_homogeneous_devices',
28+
'mpmd_absorb_inferred_fragments_on_entry_point_function',
29+
'mpmd_copy_constant_creation_from_producer_to_consumer',
30+
'mpmd_apply_merge_transfers_pass',
31+
'mpmd_merge_after_scheduling',
32+
})
33+
34+
MPMD_PIPELINE_SCHEDULE_OPTION = 'mpmd_pipeline_schedule'
35+
36+
MPMD_PIPELINE_SCHEDULE_VALUES = frozenset({
37+
'None',
38+
'1F1B',
39+
'GPipe',
40+
'Circular',
41+
'CircularWithReversedBackward',
42+
'GPipeBut1F1BForLastMesh',
43+
'ZeroBubbleH1',
44+
'ZeroBubbleH2ZeroTxLatency',
45+
'ZeroBubbleH2HalfTxLatency',
46+
'ZeroBubbleH2FullTxLatency',
47+
'ParallelPipelinesWithWrapAround',
48+
})
49+
50+
MPMD_OPTIONS = MPMD_BOOLEAN_OPTIONS | frozenset({MPMD_PIPELINE_SCHEDULE_OPTION})

0 commit comments

Comments
 (0)