Commit 7aaff6d

Author: Hongwei Chen
Commit message: run domino example on amd
Parent: 83757d9

File tree: 4 files changed (+5, -25 lines)

training/DeepSpeed-Domino/domino/arguments.py (+5, -1)

@@ -294,7 +294,11 @@ def parse_args():
           'bfloat16 data type.', flush=True)
 
     args.async_tensor_model_parallel_allreduce = True
-    args.gradient_accumulation_fusion = True
+    if torch.cuda.is_available() and torch.version.hip:
+        args.gradient_accumulation_fusion = False
+    elif torch.cuda.is_available() and torch.version.cuda:
+        args.gradient_accumulation_fusion = True
+
     args.padded_vocab_size = 0 # tokenizer.py
     args.model_type = 1
     args.data_parallel_size = 1
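
The new default keys off PyTorch's build metadata: on ROCm builds the HIP backend is exposed through the torch.cuda API, so torch.cuda.is_available() can still return True, but torch.version.hip is a version string and torch.version.cuda is None; on CUDA builds the reverse holds. A minimal standalone sketch of the same idiom (the helper name is illustrative, not part of Domino):

import torch

def fusion_supported() -> bool:
    # ROCm build: torch.version.hip is set, torch.version.cuda is None.
    if torch.cuda.is_available() and torch.version.hip:
        return False  # this commit turns fusion off on AMD GPUs
    # CUDA build: torch.version.cuda is set, torch.version.hip is None.
    return bool(torch.cuda.is_available() and torch.version.cuda)

print(fusion_supported())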

training/DeepSpeed-Domino/domino/initialize.py (-22)

@@ -13,7 +13,6 @@
 from domino.modules.fused_func import bias_dropout_add_fused_train
 from domino.modules.fused_bias_gelu import bias_gelu
 
-from megatron import fused_kernels
 
 
 def initialize_domino():
@@ -111,27 +110,6 @@ def _compile_dependencies():
             flush=True,
         )
 
-    # Always build on rank zero first.
-    if torch.distributed.get_rank() == 0:
-        start_time = time.time()
-        print("> compiling and loading fused kernels ...", flush=True)
-        fused_kernels.load(args)
-        torch.distributed.barrier()
-    else:
-        torch.distributed.barrier()
-        fused_kernels.load(args)
-    # Simple barrier to make sure all ranks have passed the
-    # compilation phase successfully before moving on to the
-    # rest of the program. We think this might ensure that
-    # the lock is released.
-    torch.distributed.barrier()
-    if torch.distributed.get_rank() == 0:
-        print(
-            ">>> done with compiling and loading fused kernels. "
-            "Compilation time: {:.3f} seconds".format(time.time() - start_time),
-            flush=True,
-        )
-
 
 def set_jit_fusion_options():
     """Set PyTorch JIT layer fusion options."""

training/DeepSpeed-Domino/pretrain_gpt3_6.7b.sh (-1)

@@ -45,7 +45,6 @@ GPT_ARGS="
     --weight-decay 1e-2 \
     --lr-warmup-fraction .01 \
     --clip-grad 1.0 \
-    --no-gradient-accumulation-fusion \
     --fp16 \
     --tensor-model-parallel-size $WORLD_SIZE
 "

training/DeepSpeed-Domino/pretrain_llama_13b.sh (-1)

@@ -53,7 +53,6 @@ LLAMA_ARGS="
     --weight-decay 1e-2 \
     --lr-warmup-fraction .01 \
     --clip-grad 1.0 \
-    --no-gradient-accumulation-fusion \
     --fp16 \
     --tensor-model-parallel-size $WORLD_SIZE \
     --seed 3407 \
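
Both launch scripts drop --no-gradient-accumulation-fusion for the same reason: parse_args() now hard-assigns args.gradient_accumulation_fusion per platform after parsing, so forcing it off from the command line is redundant. For reference, a sketch of the usual Megatron-style argparse wiring for such a flag (an assumption, not verified against Domino's parser):

import argparse

parser = argparse.ArgumentParser()
# --no-... flags typically use store_false, so the attribute defaults to True.
parser.add_argument('--no-gradient-accumulation-fusion',
                    action='store_false',
                    dest='gradient_accumulation_fusion',
                    help='Disable fusing gradient accumulation into the '
                         'weight-gradient GEMM.')

args = parser.parse_args([])  # flag omitted -> defaults to True
print(args.gradient_accumulation_fusion)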
