@@ -13,7 +13,7 @@ usage() {
     echo "  Usage: $0 [OPTIONS]"
     echo ""
     echo "    OPTIONS                       DESCRIPTION"
-    echo "    -a, --additional-args         Additional args to pass to MaxText/train.py"
+    echo "    -a, --additional-args         Additional args to pass to MaxText/train.py. Can be passed multiple times."
     echo "    --mem-fraction                Specify the fraction of memory to preallocate for XLA. Example: 0.90, 0.85, 0.65. Defaults to 0.90, overriding the JAX default of 0.75."
     echo "    --model-name                  Specify the model name to run [Preferred]. If you specify a model name you do not need to specify decoder-block. Currently supported ootb models:
                                             gemma-2b, gemma-7b, gpt3-175b, gpt3-22b, gpt3-52k, gpt3-6b, llama2-13b, llama2-70b, llama2-7b, llama3-70b, llama3-8b, mistral-7b, mixtral-8x7b"
@@ -34,7 +34,7 @@ usage() {
     1. test-maxtext.sh -b 2 --model-name=gpt3-52k
     2. test-maxtext.sh -b 2 --model-name=gemma-2b --dtype=fp8
     3. test-maxtext.sh -n 1 -b 2 --model-name=llama2-7b --attn-type=cudnn_flash_te --remat-policy=minimal_flash --steps=10 --fsdp=8 --output train_output --multiprocess
-    4. test-maxtext.sh -n 1 -b 2 --model-name=llama2-7b --attn-type=cudnn_flash_te --remat-policy=minimal_flash --steps=10 --fsdp=8 --output train_output --multiprocess -a scan_layers=false max_target_length=4096 use_iota_embed=true logits_dot_in_fp32=false
+    4. test-maxtext.sh -n 1 -b 2 --model-name=llama2-7b --attn-type=cudnn_flash_te --remat-policy=minimal_flash --steps=10 --fsdp=8 --output train_output --multiprocess -a "scan_layers=false max_target_length=4096 use_iota_embed=true logits_dot_in_fp32=false"
     5. test-maxtext.sh -n 1 -b 2 --model-name=llama2-7b --attn-type=cudnn_flash_te --remat-policy=minimal_flash --dtype=fp8 --steps=10 --fsdp=8 --output train_output --multiprocess
     6. test-maxtext.sh -n 8 -b 2 --model-name=llama2-7b --attn-type=cudnn_flash_te --remat-policy=minimal_flash --steps=10 --output train_output --fsdp=8 --data-parallel=8 --multiprocess
     7. test-maxtext.sh -n 8 -b 2 --model-name=llama2-7b --attn-type=cudnn_flash_te --remat-policy=minimal_flash --steps=10 --output train_output --fsdp=4 --tensor-parallel=2 --data-parallel=8 --multiprocess
@@ -76,7 +76,7 @@ eval set -- "$args"
 while [ : ]; do
     case "$1" in
         -a | --additional-args)
-            ADDITIONAL_ARGS="$2"
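+            # append rather than overwrite, so repeated -a flags accumulate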
+            ADDITIONAL_ARGS="$ADDITIONAL_ARGS $2"
             shift 2
             ;;
         --mem-fraction)
@@ -245,22 +245,58 @@ RUN_NAME="logdir" ## the RUN_NAME cannot be changed
 if [ -z "$DECODER_BLOCK" ]; then

     # this branch is used to test the different models supported out of the box (ootb)
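     # with --model-name, MaxText is expected to pull in that model's preset config; the flags below only set run-level options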
-    RUN_SETTINGS="MaxText/train.py MaxText/configs/base.yml run_name=${RUN_NAME} model_name=${MODEL} \
-        steps=$STEPS per_device_batch_size=${BATCH_PER_GPU} remat_policy=${REMAT_POLICY} enable_checkpointing=false \
-        base_output_directory=$OUTPUT dataset_path=local dataset_type=synthetic hardware=$HARDWARE \
-        dcn_fsdp_parallelism=$dcn_FSDP ici_fsdp_parallelism=$ici_FSDP \
-        ici_data_parallelism=$ici_DP dcn_data_parallelism=$dcn_DP \
-        ici_tensor_parallelism=$ici_TP dcn_tensor_parallelism=1 ${ADDITIONAL_ARGS}"
-
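+    # ici_* settings control parallelism within a node (inter-chip interconnect) and dcn_* settings
+    # control parallelism across nodes (data-center network); goodput recording/monitoring is
+    # switched off since it is not needed for these synthetic benchmark runs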
+    RUN_SETTINGS="MaxText/train.py \
+        MaxText/configs/base.yml \
+        run_name=${RUN_NAME} \
+        model_name=${MODEL} \
+        steps=${STEPS} \
+        per_device_batch_size=${BATCH_PER_GPU} \
+        remat_policy=${REMAT_POLICY} \
+        enable_checkpointing=false \
+        base_output_directory=${OUTPUT} \
+        dataset_path=local \
+        dataset_type=synthetic \
+        hardware=${HARDWARE} \
+        enable_goodput_recording=false \
+        monitor_goodput=false \
+        dcn_fsdp_parallelism=${dcn_FSDP} \
+        ici_fsdp_parallelism=${ici_FSDP} \
+        ici_data_parallelism=${ici_DP} \
+        dcn_data_parallelism=${dcn_DP} \
+        ici_tensor_parallelism=${ici_TP} \
+        dcn_tensor_parallelism=1 \
+        ${ADDITIONAL_ARGS}"
 else
     # this branch is essentially used for CI runs
-    RUN_SETTINGS="MaxText/train.py MaxText/configs/base.yml run_name=${RUN_NAME} logits_via_embedding=true decoder_block=${DECODER_BLOCK} \
-        steps=$STEPS per_device_batch_size=${BATCH_PER_GPU} base_emb_dim=2560 base_mlp_dim=8192 remat_policy=${REMAT_POLICY} attention=${ATTN_TYPE} \
-        base_num_query_heads=8 base_num_kv_heads=8 base_num_decoder_layers=8 head_dim=128 enable_checkpointing=false \
-        base_output_directory=$OUTPUT dataset_path=local dataset_type=synthetic hardware=$HARDWARE \
-        dcn_fsdp_parallelism=$dcn_FSDP ici_fsdp_parallelism=$ici_FSDP \
-        ici_data_parallelism=$ici_DP dcn_data_parallelism=$dcn_DP \
-        ici_tensor_parallelism=$ici_TP dcn_tensor_parallelism=1 ${ADDITIONAL_ARGS}"
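+    # decoder_block plus the base_* / head_dim values below define a small synthetic
+    # decoder-only model so the CI run stays lightweight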
+    RUN_SETTINGS="MaxText/train.py \
+        MaxText/configs/base.yml \
+        run_name=${RUN_NAME} \
+        decoder_block=${DECODER_BLOCK} \
+        steps=${STEPS} \
+        per_device_batch_size=${BATCH_PER_GPU} \
+        base_emb_dim=2560 \
+        base_mlp_dim=8192 \
+        remat_policy=${REMAT_POLICY} \
+        attention=${ATTN_TYPE} \
+        base_num_query_heads=8 \
+        base_num_kv_heads=8 \
+        base_num_decoder_layers=8 \
+        head_dim=128 \
+        logits_via_embedding=true \
+        enable_checkpointing=false \
+        base_output_directory=${OUTPUT} \
+        dataset_path=local \
+        dataset_type=synthetic \
+        hardware=${HARDWARE} \
+        enable_goodput_recording=false \
+        monitor_goodput=false \
+        dcn_fsdp_parallelism=${dcn_FSDP} \
+        ici_fsdp_parallelism=${ici_FSDP} \
+        ici_data_parallelism=${ici_DP} \
+        dcn_data_parallelism=${dcn_DP} \
+        ici_tensor_parallelism=${ici_TP} \
+        dcn_tensor_parallelism=1 \
+        ${ADDITIONAL_ARGS}"
 fi

 echo "Command: python3 $RUN_SETTINGS"
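 # for example, "test-maxtext.sh -b 2 --model-name=gpt3-52k" prints something along the lines of:
 #   Command: python3 MaxText/train.py MaxText/configs/base.yml run_name=logdir model_name=gpt3-52k steps=... per_device_batch_size=2 ... dataset_type=synthetic ... dcn_tensor_parallelism=1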