Modify domino example to run on AMD GPU #958

Open · wants to merge 1 commit into base: master
6 changes: 5 additions & 1 deletion training/DeepSpeed-Domino/domino/arguments.py
@@ -294,7 +294,11 @@ def parse_args():
              'bfloat16 data type.', flush=True)
 
     args.async_tensor_model_parallel_allreduce = True
-    args.gradient_accumulation_fusion = True
+    if torch.cuda.is_available() and torch.version.hip:
+        args.gradient_accumulation_fusion = False
+    elif torch.cuda.is_available() and torch.version.cuda:
+        args.gradient_accumulation_fusion = True
+
     args.padded_vocab_size = 0 # tokenizer.py
     args.model_type = 1
     args.data_parallel_size = 1
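The new default keys off how PyTorch was built: torch.version.hip is a version string on ROCm (AMD) builds and None otherwise, while torch.version.cuda plays the same role for CUDA builds, so the fused gradient-accumulation path stays enabled only where the CUDA kernel is presumably available. A minimal standalone sketch of the same check (the helper name is illustrative, not part of the PR):

import torch

def gradient_accumulation_fusion_default() -> bool:
    # Mirrors the PR's logic: disable the fused gradient-accumulation
    # path on ROCm builds, enable it on CUDA builds.
    if torch.cuda.is_available() and torch.version.hip:
        return False  # ROCm build detected
    if torch.cuda.is_available() and torch.version.cuda:
        return True   # CUDA build detected
    return False      # CPU-only build: an assumption, not covered by the PR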
22 changes: 0 additions & 22 deletions training/DeepSpeed-Domino/domino/initialize.py
@@ -13,7 +13,6 @@
 from domino.modules.fused_func import bias_dropout_add_fused_train
 from domino.modules.fused_bias_gelu import bias_gelu
 
-from megatron import fused_kernels
 
 
 def initialize_domino():
@@ -111,27 +110,6 @@ def _compile_dependencies():
             flush=True,
         )
 
-    # Always build on rank zero first.
-    if torch.distributed.get_rank() == 0:
-        start_time = time.time()
-        print("> compiling and loading fused kernels ...", flush=True)
-        fused_kernels.load(args)
-        torch.distributed.barrier()
-    else:
-        torch.distributed.barrier()
-        fused_kernels.load(args)
-    # Simple barrier to make sure all ranks have passed the
-    # compilation phase successfully before moving on to the
-    # rest of the program. We think this might ensure that
-    # the lock is released.
-    torch.distributed.barrier()
-    if torch.distributed.get_rank() == 0:
-        print(
-            ">>> done with compiling and loading fused kernels. "
-            "Compilation time: {:.3f} seconds".format(time.time() - start_time),
-            flush=True,
-        )
-
 
 def set_jit_fusion_options():
     """Set PyTorch JIT layer fusion options."""
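The deleted block was Megatron's rank-0-first compilation of its fused kernels: rank 0 builds the extensions while the other ranks wait at a barrier, then every rank loads the cached build. Removing the import and the block skips that build step, presumably because those extensions target CUDA. If the fused kernels were still wanted on NVIDIA systems, one alternative to deleting the block outright would be to guard it on the build type; a rough sketch under that assumption (not what this PR does), reusing the removed calls:

import time
import torch
from megatron import fused_kernels

def maybe_load_fused_kernels(args):
    # Hypothetical helper, not part of this PR: skip fused-kernel
    # compilation on ROCm builds, keep rank-0-first loading on CUDA builds.
    if torch.version.cuda is None:
        return
    if torch.distributed.get_rank() == 0:
        start_time = time.time()
        print("> compiling and loading fused kernels ...", flush=True)
        fused_kernels.load(args)
    # All ranks wait until rank 0 has finished compiling.
    torch.distributed.barrier()
    if torch.distributed.get_rank() != 0:
        fused_kernels.load(args)
    # Final barrier so no rank moves on before the build lock is released.
    torch.distributed.barrier()
    if torch.distributed.get_rank() == 0:
        print(">>> done with compiling and loading fused kernels. "
              "Compilation time: {:.3f} seconds".format(time.time() - start_time),
              flush=True)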
1 change: 0 additions & 1 deletion training/DeepSpeed-Domino/pretrain_gpt3_6.7b.sh
@@ -45,7 +45,6 @@ GPT_ARGS="
     --weight-decay 1e-2 \
     --lr-warmup-fraction .01 \
     --clip-grad 1.0 \
-    --no-gradient-accumulation-fusion \
     --fp16 \
     --tensor-model-parallel-size $WORLD_SIZE
 "
1 change: 0 additions & 1 deletion training/DeepSpeed-Domino/pretrain_llama_13b.sh
@@ -53,7 +53,6 @@ LLAMA_ARGS="
     --weight-decay 1e-2 \
     --lr-warmup-fraction .01 \
     --clip-grad 1.0 \
-    --no-gradient-accumulation-fusion \
     --fp16 \
     --tensor-model-parallel-size $WORLD_SIZE \
     --seed 3407 \