run domino example on amd
Signed-off-by: Hongwei Chen <[email protected]>
Hongwei Chen authored and hwchen2017 committed Feb 14, 2025
1 parent 83757d9 commit 25daf29
Showing 4 changed files with 5 additions and 25 deletions.
6 changes: 5 additions & 1 deletion training/DeepSpeed-Domino/domino/arguments.py
@@ -294,7 +294,11 @@ def parse_args():
           'bfloat16 data type.', flush=True)
 
     args.async_tensor_model_parallel_allreduce = True
-    args.gradient_accumulation_fusion = True
+    if torch.cuda.is_available() and torch.version.hip:
+        args.gradient_accumulation_fusion = False
+    elif torch.cuda.is_available() and torch.version.cuda:
+        args.gradient_accumulation_fusion = True
 
     args.padded_vocab_size = 0 # tokenizer.py
     args.model_type = 1
     args.data_parallel_size = 1
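For context: torch.version.hip is non-None only on ROCm builds of PyTorch and torch.version.cuda only on CUDA builds, so the new branch keeps gradient-accumulation fusion enabled on NVIDIA GPUs and turns it off on AMD GPUs. A minimal standalone sketch of the same backend check (the helper name is illustrative, not part of this commit):

import torch

def gradient_accumulation_fusion_supported() -> bool:
    # Illustrative helper: enable fused gradient accumulation only on
    # CUDA builds of PyTorch; ROCm (HIP) builds fall back to the
    # unfused path, mirroring the change to parse_args() above.
    if not torch.cuda.is_available():
        return False
    if torch.version.hip is not None:       # ROCm build
        return False
    return torch.version.cuda is not None   # CUDA build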
22 changes: 0 additions & 22 deletions training/DeepSpeed-Domino/domino/initialize.py
@@ -13,7 +13,6 @@
 from domino.modules.fused_func import bias_dropout_add_fused_train
 from domino.modules.fused_bias_gelu import bias_gelu
 
-from megatron import fused_kernels
 
 
 def initialize_domino():
@@ -111,27 +110,6 @@ def _compile_dependencies():
             flush=True,
         )
 
-    # Always build on rank zero first.
-    if torch.distributed.get_rank() == 0:
-        start_time = time.time()
-        print("> compiling and loading fused kernels ...", flush=True)
-        fused_kernels.load(args)
-        torch.distributed.barrier()
-    else:
-        torch.distributed.barrier()
-        fused_kernels.load(args)
-    # Simple barrier to make sure all ranks have passed the
-    # compilation phase successfully before moving on to the
-    # rest of the program. We think this might ensure that
-    # the lock is released.
-    torch.distributed.barrier()
-    if torch.distributed.get_rank() == 0:
-        print(
-            ">>> done with compiling and loading fused kernels. "
-            "Compilation time: {:.3f} seconds".format(time.time() - start_time),
-            flush=True,
-        )
-
 
 def set_jit_fusion_options():
     """Set PyTorch JIT layer fusion options."""
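The deleted block was the usual rank-0-first JIT build pattern: rank 0 compiles the fused kernels while every other rank waits at a barrier, then the remaining ranks load the already-built extension, and a final barrier keeps all ranks in step before training continues. A generic sketch of that pattern, with build_extension standing in for fused_kernels.load(args) (the helper is a placeholder, not a Megatron API, and torch.distributed is assumed to be initialized):

import time
import torch

def compile_on_rank_zero_first(build_extension):
    # build_extension: callable that JIT-compiles (or loads a cached
    # build of) a C++/CUDA extension.
    if torch.distributed.get_rank() == 0:
        start = time.time()
        build_extension()            # rank 0 performs the actual build
        torch.distributed.barrier()  # release the waiting ranks
        print(f"fused-kernel build took {time.time() - start:.3f} s", flush=True)
    else:
        torch.distributed.barrier()  # wait for rank 0 to finish building
        build_extension()            # now just loads the cached build
    # Final barrier: no rank proceeds until all have loaded the extension.
    torch.distributed.barrier()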
1 change: 0 additions & 1 deletion training/DeepSpeed-Domino/pretrain_gpt3_6.7b.sh
@@ -45,7 +45,6 @@ GPT_ARGS="
 --weight-decay 1e-2 \
 --lr-warmup-fraction .01 \
 --clip-grad 1.0 \
---no-gradient-accumulation-fusion \
 --fp16 \
 --tensor-model-parallel-size $WORLD_SIZE
 "
1 change: 0 additions & 1 deletion training/DeepSpeed-Domino/pretrain_llama_13b.sh
@@ -53,7 +53,6 @@ LLAMA_ARGS="
 --weight-decay 1e-2 \
 --lr-warmup-fraction .01 \
 --clip-grad 1.0 \
---no-gradient-accumulation-fusion \
 --fp16 \
 --tensor-model-parallel-size $WORLD_SIZE \
 --seed 3407 \
