
Commit 056a3b0

Add an option to test-pax.sh to enable XLA cuDNN flash attention (#1045)
Provide an option to run XLA cuDNN flash attention as an alternative to TE cuDNN flash attention.
1 parent f116054 commit 056a3b0
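With this change, the XLA cuDNN flash-attention path can be exercised directly from the test script. A minimal sketch of an invocation (the new flag comes from this commit; the other options already exist in test-pax.sh, and the values shown are illustrative only):

    # Pre-existing TE path: runs with ENABLE_TE=1 (TE cuDNN flash attention)
    bash .github/container/test-pax.sh --enable-te --model-type 126M --steps 500

    # New XLA path: appends --fdl.USE_CUDNN_FLASH_ATTENTION=True to the fiddle args
    bash .github/container/test-pax.sh --enable-cudnn-fa --model-type 126M --steps 500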

File tree

1 file changed: .github/container/test-pax.sh (+30 / -19)

Diff of .github/container/test-pax.sh (30 additions, 19 deletions):
@@ -15,7 +15,8 @@ usage() {
     echo " -a, --additional-args Additional fiddle args to pass to paxml/main.py"
     echo " -b, --batch-per-gpu Batch size per GPU, defaults to 4."
     echo " --dtype Batch size, defaults to bfloat16."
-    echo " --enable-te If set, will run with env var ENABLE_TE=1."
+    echo " --enable-te If set, will run with env var ENABLE_TE=1."
+    echo " --enable-cudnn-fa If set, will use cudnn fa."
     echo " --enable-dropout If set, will set DROPOUT_PROB to 0.1."
     echo " --disable-fused-attn Whether disable TE fused attention."
     echo " --model-type One of 126M, 5B, LLaMA70BProxy. Defaults to 126M"
@@ -26,13 +27,13 @@ usage() {
     echo " --data-parallel Data parallelism to use. Defaults to 1."
     echo " --fsdp Fully-sharded data parallelism to use. Defaults to 1."
     echo " --tensor-parallel Tensor parallelism to use. Defaults to 1."
-    echo " --pipeline-parallel Pipeline parallelism to use. Defaults to 1 for no pipelining."
+    echo " --pipeline-parallel Pipeline parallelism to use. Defaults to 1 for no pipelining."
     echo " -n, --nodes Number of nodes."
     echo " -h, --help Print usage."
     exit $1
 }
 
-args=$(getopt -o a:b:s:o:n:h --long additional-args:,batch-per-gpu:,dtype:,enable-te,enable-dropout,disable-fused-attn,model-type:,evaluate,steps:,help,multiprocess,output:,data-parallel:,fsdp:,tensor-parallel:,pipeline-parallel:,nodes: -- "$@")
+args=$(getopt -o a:b:s:o:n:h --long additional-args:,batch-per-gpu:,dtype:,enable-te,enable-cudnn-fa,enable-dropout,disable-fused-attn,model-type:,evaluate,steps:,help,multiprocess,output:,data-parallel:,fsdp:,tensor-parallel:,pipeline-parallel:,nodes: -- "$@")
 if [[ $? -ne 0 ]]; then
     exit $1
 fi
@@ -50,6 +51,7 @@ TP=1
 PP=1
 NODES=1
 ENABLE_TE=0
+ENABLE_CUDNN_FA=0
 MODEL_TYPE=126M
 NVTE_FUSED_ATTN=1
 DROPOUT=0
@@ -75,6 +77,10 @@ while [ : ]; do
         ENABLE_TE=1
         shift 1
         ;;
+    --enable-cudnn-fa)
+        ENABLE_CUDNN_FA=1
+        shift 1
+        ;;
     --enable-dropout)
         DROPOUT='0.1'
         shift 1
@@ -128,7 +134,7 @@ while [ : ]; do
         ;;
     --)
         shift;
-        break
+        break
         ;;
     *)
         echo "UNKNOWN OPTION $1"
@@ -149,6 +155,7 @@ print_var NGPUS
 print_var OUTPUT
 print_var MULTIPROCESS
 print_var ENABLE_TE
+print_var ENABLE_CUDNN_FA
 print_var NVTE_FUSED_ATTN
 print_var EVALUATE
 print_var DROPOUT
@@ -196,10 +203,10 @@ if dcn_factor > 1:
     if dp % dcn_factor == 0:
         dcn_dp = dcn_factor
         dp = int(dp / dcn_factor)
-    elif fsdp % dcn_factor == 0:
+    elif fsdp % dcn_factor == 0:
         dcn_fsdp = dcn_factor
         fsdp = int(fsdp / dcn_factor)
-    elif pp % dcn_factor == 0:
+    elif pp % dcn_factor == 0:
         dcn_pp = dcn_factor
         pp = int(pp / dcn_factor)
 
@@ -209,12 +216,12 @@ class GPT126MPP(TransformerLmSpmdPipelineAdam):
     USE_REPEATED_LAYER = False
     ICI_MESH_SHAPE = [64,1,1]
     MAX_STEPS = 600000
-
+
     MAX_SEQ_LEN = 2048
     VOCAB_SIZE = 50304
     PACKED_INPUT = True
     PERCORE_BATCH_SIZE = 4
-
+
     NUM_LAYERS = 12
     NUM_HEADS = 12
     MODEL_DIMS = 768
@@ -223,14 +230,14 @@ class GPT126MPP(TransformerLmSpmdPipelineAdam):
 
     TRAINABLE_POSITION_EMB = True
     TRAINABLE_PE_MAX_SEQ_LEN = MAX_SEQ_LEN
-
+
     USE_BIAS = True
     LAYERNORM_EPSILON = 1e-5
     ATTEN_LOGIT_CAP = -1.0
     INIT_STD = 0.023
     SOFTMAX_INIT_STD = 0.023
     ACTIVATION_CLS = layers.GELU
-
+
     ## optimizer-related
     ADAM_BETA1 = 0.9
     ADAM_BETA2 = 0.95
@@ -255,15 +262,15 @@ class GPT126MPP(TransformerLmSpmdPipelineAdam):
     ## disable eval to avoid including eval
     ## in steps/sec calculation
     EVAL_INTERVAL_STEPS = 100000
-
+
     def task(self):
         task_p = super().task()
         task_p = configure_gpt3_task(self, task_p)
 
         task_p.train.num_train_steps = self.MAX_STEPS
 
         model_p = task_p.model
-
+
         ### compute layernorm reductions in fp32. Needed for stable training on GPUs
         stacked_p = model_p.lm_tpl.stacked_transformer_tpl
         if stacked_p.cls == layers.PipelinedTransformer:
@@ -274,13 +281,13 @@ class GPT126MPP(TransformerLmSpmdPipelineAdam):
             transformer_layer_p.ln_tpl.reductions_in_fp32 = True
             transformer_layer_p.tr_fflayer_tpl.ln_tpl.reductions_in_fp32 = True
         task_p.model.lm_tpl.final_ln_tpl.reductions_in_fp32 = True
-
+
         model_p.params_init = WeightInit.Gaussian(self.INIT_STD)
         softmax_init = WeightInit.Gaussian(self.SOFTMAX_INIT_STD)
         model_p.lm_tpl.softmax_tpl.params_init = softmax_init
-
+
         model_p.apply_eval_sample_weights = True
-
+
         ## set input, residual, attention dropout to DROPOUT_PROB, remaining dropout to 0
         stacked_p.dropout_prob = 0.0
         stacked_p.input_dropout_prob = self.DROPOUT_PROB
@@ -316,14 +323,14 @@ class LLaMA70BSyntheticSmall(BaseLLaMA, SyntheticDataset):
 if pp > 1:
     @experiment_registry.register
     class Synthetic126MCI(GPT126MPP, SyntheticDataset):
-
+
         ICI_MESH_SHAPE = [pp, dp, fsdp, tp]
         DCN_MESH_SHAPE = [dcn_pp, dcn_dp, dcn_fsdp, 1]
         MICROBATCH_SIZE = 2
         NUM_STAGES = pp
         PERCORE_BATCH_SIZE = percore_batch_size
         FPROP_DTYPE = dtype
-
+
         def task(self):
             task_p = super().task()
             task_p.train.always_use_train_for_model_init=False
@@ -333,7 +340,7 @@ if pp > 1:
 else:
     @experiment_registry.register
     class Synthetic126MCI(Synthetic126M):
-
+
         ICI_MESH_SHAPE = [dp, fsdp, tp]
         DCN_MESH_SHAPE = [dcn_dp, dcn_fsdp, 1]
         PERCORE_BATCH_SIZE = percore_batch_size
@@ -343,7 +350,7 @@ else:
 
         ## disable eval
         EVAL_INTERVAL_STEPS = 100000
-
+
         def task(self):
             task_p = super().task()
@@ -374,6 +381,10 @@ export ENABLE_TE=$ENABLE_TE
 export NVTE_FUSED_ATTN=$NVTE_FUSED_ATTN
 export VOCAB_PATH=${VOCAB_PATH:-gs://t5-data/vocabs/cc_all.32000.100extra/sentencepiece.model}
 
+if [[ ${ENABLE_CUDNN_FA} -ne 0 ]]; then
+    ADDITIONAL_ARGS="${ADDITIONAL_ARGS} --fdl.USE_CUDNN_FLASH_ATTENTION=True"
+fi
+
 if [[ ${MODEL_TYPE} == "126M" ]]; then
     CONFIG=ci_configs.Synthetic126MCI
 elif [[ ${MODEL_TYPE} == "5B" ]]; then
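For reference, a rough sketch of where the appended argument ends up. The launch command itself is outside this diff, but the usage text above documents ADDITIONAL_ARGS as "Additional fiddle args to pass to paxml/main.py", so the override is expected to reach a command along these lines (the exact flags and structure of the launch line are assumptions, not taken from this commit):

    # hypothetical tail of test-pax.sh (sketch only)
    python -m paxml.main \
        --fdl_config=${CONFIG} \
        --job_log_dir=${OUTPUT} \
        ${ADDITIONAL_ARGS}   # includes --fdl.USE_CUDNN_FLASH_ATTENTION=True when --enable-cudnn-fa is set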
