@@ -15,7 +15,8 @@ usage() {
     echo "  -a, --additional-args    Additional fiddle args to pass to paxml/main.py"
     echo "  -b, --batch-per-gpu      Batch size per GPU, defaults to 4."
     echo "  --dtype                  Data type, defaults to bfloat16."
-    echo "  --enable-te              If set, will run with env var ENABLE_TE=1."
+    echo "  --enable-te              If set, will run with env var ENABLE_TE=1."
+    echo "  --enable-cudnn-fa        If set, will use cuDNN flash attention."
     echo "  --enable-dropout         If set, will set DROPOUT_PROB to 0.1."
     echo "  --disable-fused-attn     Whether to disable TE fused attention."
     echo "  --model-type             One of 126M, 5B, LLaMA70BProxy. Defaults to 126M."
@@ -26,13 +27,13 @@ usage() {
     echo "  --data-parallel          Data parallelism to use. Defaults to 1."
     echo "  --fsdp                   Fully-sharded data parallelism to use. Defaults to 1."
     echo "  --tensor-parallel        Tensor parallelism to use. Defaults to 1."
-    echo "  --pipeline-parallel      Pipeline parallelism to use. Defaults to 1 for no pipelining."
+    echo "  --pipeline-parallel      Pipeline parallelism to use. Defaults to 1 for no pipelining."
     echo "  -n, --nodes              Number of nodes."
     echo "  -h, --help               Print usage."
     exit $1
 }

-args=$(getopt -o a:b:s:o:n:h --long additional-args:,batch-per-gpu:,dtype:,enable-te,enable-dropout,disable-fused-attn,model-type:,evaluate,steps:,help,multiprocess,output:,data-parallel:,fsdp:,tensor-parallel:,pipeline-parallel:,nodes: -- "$@")
+args=$(getopt -o a:b:s:o:n:h --long additional-args:,batch-per-gpu:,dtype:,enable-te,enable-cudnn-fa,enable-dropout,disable-fused-attn,model-type:,evaluate,steps:,help,multiprocess,output:,data-parallel:,fsdp:,tensor-parallel:,pipeline-parallel:,nodes: -- "$@")
 if [[ $? -ne 0 ]]; then
     exit $1
 fi
@@ -50,6 +51,7 @@
 PP=1
 NODES=1
 ENABLE_TE=0
+ENABLE_CUDNN_FA=0
 MODEL_TYPE=126M
 NVTE_FUSED_ATTN=1
 DROPOUT=0
@@ -75,6 +77,10 @@ while [ : ]; do
             ENABLE_TE=1
             shift 1
             ;;
+        --enable-cudnn-fa)
+            ENABLE_CUDNN_FA=1
+            shift 1
+            ;;
         --enable-dropout)
             DROPOUT='0.1'
             shift 1
@@ -128,7 +134,7 @@ while [ : ]; do
             ;;
         --)
             shift;
-            break
+            break
             ;;
         *)
             echo "UNKNOWN OPTION $1"
@@ -149,6 +155,7 @@ print_var NGPUS
 print_var OUTPUT
 print_var MULTIPROCESS
 print_var ENABLE_TE
+print_var ENABLE_CUDNN_FA
 print_var NVTE_FUSED_ATTN
 print_var EVALUATE
 print_var DROPOUT
@@ -196,10 +203,10 @@ if dcn_factor > 1:
     if dp % dcn_factor == 0:
         dcn_dp = dcn_factor
         dp = int(dp / dcn_factor)
-    elif fsdp % dcn_factor == 0:
+    elif fsdp % dcn_factor == 0:
         dcn_fsdp = dcn_factor
         fsdp = int(fsdp / dcn_factor)
-    elif pp % dcn_factor == 0:
+    elif pp % dcn_factor == 0:
         dcn_pp = dcn_factor
         pp = int(pp / dcn_factor)

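To make the factoring above concrete: assuming dcn_factor reflects the node count (e.g., 2 nodes gives dcn_factor = 2), a requested dp of 4 becomes dcn_dp = 2 across nodes and dp = 2 within each node; only if dp is not divisible by dcn_factor does the split fall through to fsdp, and then to pp.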
@@ -209,12 +216,12 @@ class GPT126MPP(TransformerLmSpmdPipelineAdam):
     USE_REPEATED_LAYER = False
     ICI_MESH_SHAPE = [64,1,1]
     MAX_STEPS = 600000
-
+
     MAX_SEQ_LEN = 2048
     VOCAB_SIZE = 50304
     PACKED_INPUT = True
     PERCORE_BATCH_SIZE = 4
-
+
     NUM_LAYERS = 12
     NUM_HEADS = 12
     MODEL_DIMS = 768
@@ -223,14 +230,14 @@ class GPT126MPP(TransformerLmSpmdPipelineAdam):

     TRAINABLE_POSITION_EMB = True
     TRAINABLE_PE_MAX_SEQ_LEN = MAX_SEQ_LEN
-
+
     USE_BIAS = True
     LAYERNORM_EPSILON = 1e-5
     ATTEN_LOGIT_CAP = -1.0
     INIT_STD = 0.023
     SOFTMAX_INIT_STD = 0.023
     ACTIVATION_CLS = layers.GELU
-
+
     ## optimizer-related
     ADAM_BETA1 = 0.9
     ADAM_BETA2 = 0.95
@@ -255,15 +262,15 @@ class GPT126MPP(TransformerLmSpmdPipelineAdam):
     ## disable eval to avoid including eval
     ## in steps/sec calculation
     EVAL_INTERVAL_STEPS = 100000
-
+
     def task(self):
         task_p = super().task()
         task_p = configure_gpt3_task(self, task_p)

         task_p.train.num_train_steps = self.MAX_STEPS

         model_p = task_p.model
-
+
         ### compute layernorm reductions in fp32. Needed for stable training on GPUs
         stacked_p = model_p.lm_tpl.stacked_transformer_tpl
         if stacked_p.cls == layers.PipelinedTransformer:
@@ -274,13 +281,13 @@ class GPT126MPP(TransformerLmSpmdPipelineAdam):
         transformer_layer_p.ln_tpl.reductions_in_fp32 = True
         transformer_layer_p.tr_fflayer_tpl.ln_tpl.reductions_in_fp32 = True
         task_p.model.lm_tpl.final_ln_tpl.reductions_in_fp32 = True
-
+
         model_p.params_init = WeightInit.Gaussian(self.INIT_STD)
         softmax_init = WeightInit.Gaussian(self.SOFTMAX_INIT_STD)
         model_p.lm_tpl.softmax_tpl.params_init = softmax_init
-
+
         model_p.apply_eval_sample_weights = True
-
+
         ## set input, residual, attention dropout to DROPOUT_PROB, remaining dropout to 0
         stacked_p.dropout_prob = 0.0
         stacked_p.input_dropout_prob = self.DROPOUT_PROB
@@ -316,14 +323,14 @@ class LLaMA70BSyntheticSmall(BaseLLaMA, SyntheticDataset):
 if pp > 1:
     @experiment_registry.register
     class Synthetic126MCI(GPT126MPP, SyntheticDataset):
-
+
         ICI_MESH_SHAPE = [pp, dp, fsdp, tp]
         DCN_MESH_SHAPE = [dcn_pp, dcn_dp, dcn_fsdp, 1]
         MICROBATCH_SIZE = 2
         NUM_STAGES = pp
         PERCORE_BATCH_SIZE = percore_batch_size
         FPROP_DTYPE = dtype
-
+
         def task(self):
             task_p = super().task()
             task_p.train.always_use_train_for_model_init=False
@@ -333,7 +340,7 @@ if pp > 1:
 else:
     @experiment_registry.register
     class Synthetic126MCI(Synthetic126M):
-
+
         ICI_MESH_SHAPE = [dp, fsdp, tp]
         DCN_MESH_SHAPE = [dcn_dp, dcn_fsdp, 1]
         PERCORE_BATCH_SIZE = percore_batch_size
@@ -343,7 +350,7 @@ else:

         ## disable eval
         EVAL_INTERVAL_STEPS = 100000
-
+
         def task(self):
             task_p = super().task()

@@ -374,6 +381,10 @@ export ENABLE_TE=$ENABLE_TE
 export NVTE_FUSED_ATTN=$NVTE_FUSED_ATTN
 export VOCAB_PATH=${VOCAB_PATH:-gs://t5-data/vocabs/cc_all.32000.100extra/sentencepiece.model}

+if [[ ${ENABLE_CUDNN_FA} -ne 0 ]]; then
+    ADDITIONAL_ARGS="${ADDITIONAL_ARGS} --fdl.USE_CUDNN_FLASH_ATTENTION=True"
+fi
+
 if [[ ${MODEL_TYPE} == "126M" ]]; then
     CONFIG=ci_configs.Synthetic126MCI
 elif [[ ${MODEL_TYPE} == "5B" ]]; then
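As a usage sketch (the script name, output path, and flag values below are illustrative assumptions, not taken from this diff), the new option is passed like any other flag and simply appends --fdl.USE_CUDNN_FLASH_ATTENTION=True to ADDITIONAL_ARGS before paxml/main.py is launched:

    # hypothetical invocation; script and output names are placeholders
    bash test-pax.sh --model-type 126M --batch-per-gpu 4 \
        --enable-te --enable-cudnn-fa --nodes 1 --output /tmp/pax-ci-output

The fiddle override only takes effect when the selected experiment config defines a USE_CUDNN_FLASH_ATTENTION field.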