Commit 0dee6e0

support combined 1f1b (THUDM#565)
1 parent: 610cde8

File tree: 2 files changed (+28 −9)


slime/backends/megatron_utils/initialize.py (5 additions, 0 deletions)

@@ -84,6 +84,11 @@ def init(args):
     torch.backends.cudnn.benchmark = False
     torch.use_deterministic_algorithms(True, warn_only=False)
 
+    if args.tp_comm_overlap:
+        from megatron.training.initialize import _initialize_tp_communicators
+
+        _initialize_tp_communicators()
+
     if getattr(args, "custom_megatron_init_path", None):
         from slime.utils.misc import load_function
 
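Note on the initialize.py change: --tp-comm-overlap is an existing Megatron-LM flag, and the new branch simply runs Megatron's own TP-communicator setup during slime's init. A minimal sketch of when the branch fires, assuming Megatron-LM is installed; the SimpleNamespace args object below is an illustrative stand-in for the real parsed arguments, not slime's actual entry point:

    # Sketch only: a stand-in args object for illustration.
    from types import SimpleNamespace

    args = SimpleNamespace(tp_comm_overlap=True)

    if args.tp_comm_overlap:
        # Runs Megatron's tensor-parallel communication-overlap setup
        # before any forward pass is executed.
        from megatron.training.initialize import _initialize_tp_communicators

        _initialize_tp_communicators()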
slime/backends/megatron_utils/model.py (23 additions, 9 deletions)

@@ -230,7 +230,7 @@ def forward_only(
     config = get_model_config(model[0])
 
     def forward_step(
-        data_iterator: DataIterator, model: GPTModel
+        data_iterator: DataIterator, model: GPTModel, return_schedule_plan: bool = False
     ) -> tuple[torch.Tensor, Callable[[torch.Tensor], dict[str, list[torch.Tensor]]]]:
         """Forward step used by Megatron's pipeline engine.
 
@@ -244,6 +244,8 @@ def forward_step(
         to be collected by the engine.
         """
 
+        assert not return_schedule_plan, "forward_only step should never return schedule plan"
+
         # Get the batch.
         batch = get_batch(data_iterator, ["tokens", "total_lengths", "response_lengths"])
         unconcat_tokens = batch["unconcat_tokens"]
@@ -364,7 +366,7 @@ def train_one_step(
         custom_before_train_step_hook = load_function(args.custom_megatron_before_train_step_hook_path)
         custom_before_train_step_hook(args, rollout_id, step_id, model, optimizer, opt_param_scheduler)
 
-    def forward_step(data_iterator: DataIterator, model: GPTModel) -> tuple[
+    def forward_step(data_iterator: DataIterator, model: GPTModel, return_schedule_plan: bool = False) -> tuple[
         torch.Tensor,
         Callable[[torch.Tensor], tuple[torch.Tensor, int, dict[str, torch.Tensor | list[str]]]],
     ]:
@@ -402,13 +404,25 @@ def forward_step(data_iterator: DataIterator, model: GPTModel) -> tuple[
             old_stage = os.environ["ROUTING_REPLAY_STAGE"]
             os.environ["ROUTING_REPLAY_STAGE"] = "replay_forward"
 
-        output_tensor = model(
-            input_ids=batch["tokens"],
-            position_ids=None,
-            attention_mask=None,
-            labels=None,
-            packed_seq_params=batch["packed_seq_params"],
-        )
+        if return_schedule_plan:
+            assert (
+                args.overlap_moe_expert_parallel_comm
+            ), "overlap_moe_expert_parallel_comm must be enabled to return the schedule plan"
+            output_tensor = model.build_schedule_plan(
+                input_ids=batch["tokens"],
+                position_ids=None,
+                attention_mask=None,
+                labels=None,
+                packed_seq_params=batch["packed_seq_params"],
+            )
+        else:
+            output_tensor = model(
+                input_ids=batch["tokens"],
+                position_ids=None,
+                attention_mask=None,
+                labels=None,
+                packed_seq_params=batch["packed_seq_params"],
+            )
 
         if os.environ.get("ENABLE_ROUTING_REPLAY", "0") == "1":
            os.environ["ROUTING_REPLAY_STAGE"] = old_stage
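Why forward_step grows a return_schedule_plan flag: the combined 1F1B schedule that this commit supports asks the forward step for a schedule plan instead of eagerly running the model, so that MoE expert-parallel communication can be overlapped with other compute; the new assert ties this to the overlap_moe_expert_parallel_comm setting. A rough sketch of the caller's side, where run_microbatch and combined_1f1b are hypothetical stand-ins for the scheduler's internals, not Megatron APIs:

    # Hypothetical dispatch, mirroring the two branches in the diff above.
    def run_microbatch(forward_step, data_iterator, model, combined_1f1b: bool):
        if combined_1f1b:
            # Combined 1F1B path: obtain a schedule plan that the engine
            # can interleave with another micro-batch's work.
            plan, loss_func = forward_step(data_iterator, model, return_schedule_plan=True)
            return plan, loss_func
        # Default path: run the forward pass eagerly, as before this commit.
        output_tensor, loss_func = forward_step(data_iterator, model)
        return output_tensor, loss_func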
