Skip to content

Commit abc1bff

Browse files
ChengYaowenxie-amd
authored andcommitted
fix(zero-bubble): fix te-backend fp8 bugs
1 parent 3c68e4f commit abc1bff

File tree

1 file changed

+7
-0
lines changed
  • primus/backends/megatron/core/pipeline_parallel/zerobubble

1 file changed

+7
-0
lines changed

primus/backends/megatron/core/pipeline_parallel/zerobubble/runtime.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
from megatron.core.pipeline_parallel.p2p_communication import P2PCommunicator
2626
from megatron.core.pipeline_parallel.schedules import (
2727
backward_step,
28+
check_first_val_step,
2829
deallocate_output_tensor,
2930
forward_step,
3031
get_tensor_shapes,
@@ -84,6 +85,8 @@ class TrainingIterationConfig:
8485
recv_tensor_shapes: List
8586
send_tensor_shapes: List
8687

88+
first_val_step: Optional[bool] = None
89+
8790

8891
class SpQueue:
8992
"""A queue of a stack"""
@@ -548,6 +551,9 @@ def schedule_f_impl(self, scheduled_node: ScheduledNode):
548551
conf.config,
549552
conf.collect_non_loss_data,
550553
checkpoint_activations_microbatch=None,
554+
is_first_microbatch=check_first_val_step(
555+
conf.first_val_step, conf.forward_only, scheduled_node.microbatch == 0
556+
),
551557
vp_stage=vp_stage,
552558
is_last_stage=is_last_stage,
553559
current_microbatch=scheduled_node.microbatch,
@@ -1142,6 +1148,7 @@ def multi_no_sync():
11421148
tensor_shape=tensor_shape,
11431149
recv_tensor_shapes=recv_tensor_shapes,
11441150
send_tensor_shapes=send_tensor_shapes,
1151+
first_val_step=first_val_step,
11451152
)
11461153
return iteration_config
11471154

0 commit comments

Comments
 (0)