
Commit 9f1820b

Merge pull request #1905 from AI-Hypercomputer:carlosbus/v6e_small_cluster_recipes_llama_3_1
PiperOrigin-RevId: 781565589
2 parents 3816409 + db60173 commit 9f1820b

File tree

3 files changed: +241 -2 lines changed

    benchmarks/maxtext_trillium_model_configs.py
    maxtext_jax_ai_image.Dockerfile
    requirements_with_jax_stable_stack_0_6_1_pipreqs.txt

benchmarks/maxtext_trillium_model_configs.py

Lines changed: 176 additions & 0 deletions
@@ -867,6 +867,51 @@
     ),
 )
 
+# Config for v6e-64
+llama3_1_8b_8192_bs5 = _add_to_model_dictionary(
+    trillium_model_dict,
+    MaxTextModel(
+        model_name="llama3_1-8b-8192-bs5",
+        model_type="llama3.1-8b",
+        tuning_params={
+            "per_device_batch_size": 5,
+            "ici_fsdp_parallelism": -1,
+            "remat_policy": "custom",
+            "decoder_layer_input": "offload",
+            "out_proj": "offload",
+            "query_proj": "offload",
+            "key_proj": "offload",
+            "value_proj": "offload",
+            "max_target_length": 8192,
+            "attention": "flash",
+            "use_iota_embed": True,
+            "dataset_path": "gs://max-datasets-rogue",
+            "dataset_type": "synthetic",
+            "enable_checkpointing": False,
+            "sa_block_q": 2048,
+            "sa_block_kv": 2048,
+            "sa_block_kv_compute": 2048,
+            "sa_block_q_dkv": 2048,
+            "sa_block_kv_dkv": 2048,
+            "sa_block_kv_dkv_compute": 2048,
+            "sa_block_q_dq": 2048,
+            "sa_block_kv_dq": 2048,
+            "sa_use_fused_bwd_kernel": True,
+            "profiler": "xplane",
+            "skip_first_n_steps_for_profiler": 10,
+            "profiler_steps": 5,
+        },
+        xla_flags=(
+            xla_flags_library.DENSE_VMEM_LIMIT_FLAG
+            + xla_flags_library.LAYOUT_FOR_ALL_REDUCE_SCATTER
+            + xla_flags_library.DATA_PARALLEL_OVERLAP
+            + xla_flags_library.CF_FOR_ALL_GATHER
+            + xla_flags_library.ENABLE_SPARSECORE_OFFLOADING_FOR_ALL_REDUCE
+            + xla_flags_library.HOST_OFFLOAD_FLAGS
+        ),
+    ),
+)
+
 
 llama3_1_8b_8192_no_collective_matmul = _add_to_model_dictionary(
     trillium_model_dict,
@@ -956,6 +1001,137 @@
     ),
 )
 
+# Config for v6e-64
+llama3_1_70b_8192_bs2 = _add_to_model_dictionary(
+    trillium_model_dict,
+    MaxTextModel(
+        model_name="llama3_1-70b-8192-bs2",
+        model_type="llama3.1-70b",
+        tuning_params={
+            "per_device_batch_size": 2,
+            "ici_fsdp_parallelism": -1,
+            "remat_policy": "custom",
+            "decoder_layer_input": "offload",
+            "query_proj": "offload",
+            "key_proj": "offload",
+            "value_proj": "offload",
+            "max_target_length": 8192,
+            "attention": "flash",
+            "use_iota_embed": True,
+            "dataset_path": "gs://max-datasets-rogue",
+            "dataset_type": "synthetic",
+            "enable_checkpointing": False,
+            "sa_block_q": 2048,
+            "sa_block_kv": 2048,
+            "sa_block_kv_compute": 2048,
+            "sa_block_q_dkv": 2048,
+            "sa_block_kv_dkv": 2048,
+            "sa_block_kv_dkv_compute": 2048,
+            "sa_block_q_dq": 2048,
+            "sa_block_kv_dq": 2048,
+            "sa_use_fused_bwd_kernel": True,
+            "profiler": "xplane",
+            "skip_first_n_steps_for_profiler": 10,
+            "profiler_steps": 5,
+        },
+        xla_flags=(
+            xla_flags_library.DENSE_VMEM_LIMIT_FLAG
+            + xla_flags_library.LAYOUT_FOR_ALL_REDUCE_SCATTER
+            + xla_flags_library.DATA_PARALLEL_OVERLAP
+            + xla_flags_library.CF_FOR_ALL_GATHER
+            + xla_flags_library.HOST_OFFLOAD_FLAGS
+        ),
+    ),
+)
+
+# Config for v6e-32
+llama3_1_70b_8192_bs2_bfloat16_no_collective_matmul = _add_to_model_dictionary(
+    trillium_model_dict,
+    MaxTextModel(
+        model_name="llama3_1-70b-8192-bs2-bfloat16-no-collective-matmul",
+        model_type="llama3.1-70b",
+        tuning_params={
+            "per_device_batch_size": 2,
+            "ici_fsdp_parallelism": -1,
+            "remat_policy": "custom",
+            "decoder_layer_input": "offload",
+            "query_proj": "offload",
+            "key_proj": "offload",
+            "value_proj": "offload",
+            "max_target_length": 8192,
+            "attention": "flash",
+            "use_iota_embed": True,
+            "dataset_path": "gs://max-datasets-rogue",
+            "dataset_type": "synthetic",
+            "enable_checkpointing": False,
+            "sa_block_q": 2048,
+            "sa_block_kv": 2048,
+            "sa_block_kv_compute": 2048,
+            "sa_block_q_dkv": 2048,
+            "sa_block_kv_dkv": 2048,
+            "sa_block_kv_dkv_compute": 2048,
+            "sa_block_q_dq": 2048,
+            "sa_block_kv_dq": 2048,
+            "sa_use_fused_bwd_kernel": True,
+            "profiler": "xplane",
+            "skip_first_n_steps_for_profiler": 10,
+            "profiler_steps": 5,
+            "weight_dtype": "bfloat16",
+        },
+        xla_flags=(
+            xla_flags_library.DENSE_VMEM_LIMIT_FLAG
+            + xla_flags_library.LAYOUT_FOR_ALL_REDUCE_SCATTER
+            + xla_flags_library.DATA_PARALLEL_OVERLAP
+            + xla_flags_library.CF_FOR_ALL_GATHER
+            + xla_flags_library.HOST_OFFLOAD_FLAGS
+            + xla_flags_library.DISABLE_COLLECTIVE_MATMUL
+        ),
+    ),
+)
+
+# Config for v6e-128
+llama3_1_70b_8192_bs4 = _add_to_model_dictionary(
+    trillium_model_dict,
+    MaxTextModel(
+        model_name="llama3_1-70b-8192-bs4",
+        model_type="llama3.1-70b",
+        tuning_params={
+            "per_device_batch_size": 4,
+            "ici_fsdp_parallelism": -1,
+            "remat_policy": "custom",
+            "decoder_layer_input": "offload",
+            "query_proj": "offload",
+            "key_proj": "offload",
+            "value_proj": "offload",
+            "max_target_length": 8192,
+            "attention": "flash",
+            "use_iota_embed": True,
+            "dataset_path": "gs://max-datasets-rogue",
+            "dataset_type": "synthetic",
+            "enable_checkpointing": False,
+            "sa_block_q": 2048,
+            "sa_block_kv": 2048,
+            "sa_block_kv_compute": 2048,
+            "sa_block_q_dkv": 2048,
+            "sa_block_kv_dkv": 2048,
+            "sa_block_kv_dkv_compute": 2048,
+            "sa_block_q_dq": 2048,
+            "sa_block_kv_dq": 2048,
+            "sa_use_fused_bwd_kernel": True,
+            "profiler": "xplane",
+            "skip_first_n_steps_for_profiler": 10,
+            "profiler_steps": 5,
+        },
+        xla_flags=(
+            xla_flags_library.DENSE_VMEM_LIMIT_FLAG
+            + xla_flags_library.LAYOUT_FOR_ALL_REDUCE_SCATTER
+            + xla_flags_library.DATA_PARALLEL_OVERLAP
+            + xla_flags_library.CF_FOR_ALL_GATHER
+            + xla_flags_library.HOST_OFFLOAD_FLAGS
+        ),
+    ),
+)
+
 llama3_1_70b_8192_iter_synthetic = _add_to_model_dictionary(
     trillium_model_dict,
     MaxTextModel(
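
The four new recipes above follow the existing registration pattern in this file: each is a MaxTextModel added to trillium_model_dict via _add_to_model_dictionary, with the per-device batch size, remat/offload choices, and XLA flags tuned for the v6e-32/-64/-128 cluster size named in its comment. As a rough illustration only (not part of this commit), the sketch below shows how one of these registered recipes could be looked up and its tuning_params rendered as key=value overrides; the import path and the assumption that MaxTextModel exposes its constructor arguments as attributes are unverified here.

# Illustrative sketch, not part of this commit. Assumes it runs from the MaxText
# repo root, that trillium_model_dict is a plain dict of registered recipes, and
# that MaxTextModel exposes model_name/model_type/tuning_params as attributes.
from benchmarks import maxtext_trillium_model_configs as configs

def find_recipe(name: str):
    """Return the registered MaxTextModel whose model_name matches `name`."""
    for recipe in configs.trillium_model_dict.values():
        if recipe.model_name == name:
            return recipe
    raise KeyError(name)

recipe = find_recipe("llama3_1-70b-8192-bs2")  # the new v6e-64 70B recipe
# Render tuning_params as the key=value overrides a MaxText launch would pass.
overrides = [f"{key}={value}" for key, value in recipe.tuning_params.items()]
print(recipe.model_type, " ".join(overrides))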

maxtext_jax_ai_image.Dockerfile

Lines changed: 10 additions & 2 deletions
@@ -14,7 +14,7 @@ WORKDIR /deps
 
 # Copy setup files and dependency files separately for better caching
 COPY setup.sh ./
-COPY requirements.txt requirements_with_jax_ai_image.txt ./
+COPY requirements.txt requirements_with_jax_ai_image.txt requirements_with_jax_stable_stack_0_6_1_pipreqs.txt ./
 
 
 # For JAX AI tpu training images 0.4.37 AND 0.4.35
@@ -34,7 +34,15 @@ RUN if [ "$DEVICE" = "tpu" ] && ([ "$JAX_AI_IMAGE_BASEIMAGE" = "us-docker.pkg.de
 RUN apt-get update && apt-get install --yes && apt-get install --yes dnsutils
 # TODO(bvandermoon, parambole): Remove this when it's added to JAX AI Image
 RUN pip install google-cloud-monitoring
-RUN python3 -m pip install -r /deps/requirements_with_jax_ai_image.txt
+
+# Install requirements file that was generated with pipreqs for JSS 0.6.1 using:
+#   pipreqs --savepath requirements_with_jax_stable_stack_0_6_1_pipreqs.txt
+# Otherwise use general requirements_with_jax_ai_image.txt
+RUN if [ "$DEVICE" = "tpu" ] && [ "$JAX_STABLE_STACK_BASEIMAGE" = "us-docker.pkg.dev/cloud-tpu-images/jax-ai-image/tpu:jax0.6.1-rev1" ]; then \
+        python3 -m pip install -r /deps/requirements_with_jax_stable_stack_0_6_1_pipreqs.txt; \
+    else \
+        python3 -m pip install -r /deps/requirements_with_jax_ai_image.txt; \
+    fi
 
 # Now copy the remaining code (source files that may change frequently)
 COPY . .
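
The new RUN step picks its pin file from the DEVICE and JAX_STABLE_STACK_BASEIMAGE values: the pipreqs-generated pins are installed only for the TPU image built on the JAX Stable Stack 0.6.1 base, and every other build keeps the general requirements_with_jax_ai_image.txt. A minimal sketch of that same mapping as a plain function follows (illustrative only, not part of this commit; only the image tag is taken from the Dockerfile).

# Illustrative sketch, not part of this commit: the requirements-file selection the
# new RUN step performs, written as a plain function so the mapping is easy to test
# outside Docker. The image tag is the one hard-coded in the Dockerfile.
JSS_0_6_1_TAG = "us-docker.pkg.dev/cloud-tpu-images/jax-ai-image/tpu:jax0.6.1-rev1"

def pick_requirements(device: str, base_image: str) -> str:
    """Return the requirements file the conditional RUN step would pip-install."""
    if device == "tpu" and base_image == JSS_0_6_1_TAG:
        return "requirements_with_jax_stable_stack_0_6_1_pipreqs.txt"
    return "requirements_with_jax_ai_image.txt"

assert pick_requirements("tpu", JSS_0_6_1_TAG).endswith("_pipreqs.txt")
assert pick_requirements("gpu", JSS_0_6_1_TAG) == "requirements_with_jax_ai_image.txt"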

requirements_with_jax_stable_stack_0_6_1_pipreqs.txt

Lines changed: 55 additions & 0 deletions
@@ -0,0 +1,55 @@
+absl_py==2.2.2
+aqt==25.2.7
+benchmark_db_writer==1.0.0.dev20250610
+benchmark_db_writer.egg==info
+cloud_accelerator_diagnostics==0.1.1
+cloud_tpu_diagnostics==0.1.5
+datasets==3.6.0
+etils==1.12.2
+evaluate==0.4.4
+flax==0.10.6
+grain==0.2.10
+grpcio==1.72.0rc1
+huggingface_hub==0.33.0
+jax==0.6.0
+jaxlib==0.6.0 # Manually adding to ensure consistency in future
+jaxtyping==0.3.2
+jetstream==0.1.0
+jsonlines==4.0.0
+libtpu==0.0.15 # Manually adding to ensure consistency in future
+matplotlib==3.10.3
+ml_collections==1.1.0
+ml_dtypes==0.5.1
+ml_goodput_measurement==0.0.11
+nltk==3.9.1
+numpy==2.3.1
+omegaconf==2.3.0
+optax==0.2.5
+orbax==0.1.9
+pandas==2.3.0
+pathwaysutils==0.1.1
+Pillow==11.2.1
+protobuf==6.31.1
+psutil==7.0.0
+pytest==8.4.1
+PyYAML==6.0.2
+PyYAML==6.0.2
+Requests==2.32.4
+safetensors==0.5.3
+sentencepiece==0.2.0
+setuptools==80.9.0
+tabulate==0.9.0
+tensorboard_plugin_profile==2.13.0
+tensorboardX==2.6.2.2
+tensorboardX==2.6.4
+tensorflow==2.19.0
+tensorflow_datasets==4.9.9
+tensorflow_text==2.19.0
+tensorstore==0.1.75
+tiktoken==0.9.0
+torch==2.7.1
+tqdm==4.67.1
+transformer_engine==2.4.0
+transformers==4.52.4
+trl==0.19.0
+urllib3==2.5.0
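
These pins were generated with pipreqs (see the Dockerfile comment above). A small sanity check like the sketch below (illustrative only, not part of this commit) can compare them against what is actually installed in a built image; it assumes the pin file sits in the working directory and relies only on the standard library.

# Illustrative sketch, not part of this commit: report any pin that does not match
# the installed version. Assumes the pin file is in the working directory.
from importlib.metadata import PackageNotFoundError, version

with open("requirements_with_jax_stable_stack_0_6_1_pipreqs.txt") as pins:
    for line in pins:
        spec = line.split("#", 1)[0].strip()  # drop trailing comments
        if not spec or "==" not in spec:
            continue
        name, pinned = spec.split("==", 1)
        try:
            installed = version(name)
        except PackageNotFoundError:
            installed = "not installed"
        if installed != pinned:
            print(f"{name}: pinned {pinned}, found {installed}")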
