
Commit 700fa2e

feat(zero-bubble): reorder comm-nodes for batch-p2p (#257)
1 parent 149e668 commit 700fa2e

File tree

3 files changed (+83 -7 lines):

primus/backends/megatron/core/optimizer/zbpp_optimizer.py
primus/backends/megatron/core/pipeline_parallel/zerobubble/runtime.py
primus/backends/megatron/core/pipeline_parallel/zerobubble/scheduler/communication.py

primus/backends/megatron/core/optimizer/zbpp_optimizer.py

Lines changed: 3 additions & 3 deletions

@@ -25,10 +25,10 @@ class ZeroBubblePPChainedOptimizer(ChainedOptimizer):
     def __init__(self, chained_optimizers: List[MegatronOptimizer]):
         super().__init__(chained_optimizers)
 
-        self.partial_reduced_total_norm = torch.FloatTensor([0])
+        self.partial_reduced_total_norm = torch.zeros([0], dtype=torch.float, device="cuda")
         self.local_total_norm = None
-        self.dummy_overflow_buf = torch.cuda.IntTensor([0])
-        self.zero_float_tensor = torch.cuda.FloatTensor([0])
+        self.dummy_overflow_buf = torch.zeros([0], dtype=torch.int, device="cuda")
+        self.zero_float_tensor = torch.zeros([0], dtype=torch.float, device="cuda")
         self.parameters_backup = None
         self.do_prev_step = False
         self.do_this_step = False
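For context, the pattern applied in this hunk is the standard migration from legacy typed constructors (torch.FloatTensor, torch.cuda.FloatTensor, torch.cuda.IntTensor) to factory functions with explicit dtype and device arguments. A minimal sketch of that general pattern, independent of this repository (the variable names below are made up for illustration):

    import torch

    # Legacy constructors bake dtype/device into the constructor name and build
    # the tensor from the given data, e.g. tensor([0.]) of shape (1,).
    legacy_buf = torch.FloatTensor([0])

    # Factory-function style: dtype and device are explicit keyword arguments.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    value_zero = torch.tensor([0.0], dtype=torch.float, device=device)  # shape (1,), holds 0.0
    empty_buf = torch.zeros([0], dtype=torch.float, device=device)      # shape (0,), no elements

    print(legacy_buf.shape, value_zero.shape, empty_buf.shape)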

primus/backends/megatron/core/pipeline_parallel/zerobubble/runtime.py

Lines changed: 31 additions & 4 deletions

@@ -1327,8 +1327,8 @@ def fused_pipeline_ops(
         ops.append(recv_next_op)
     if len(ops) > 0:
         reqs = torch.distributed.batch_isend_irecv(ops)
+
         # batch_isend_irecv only returns 1 handle
-        assert len(reqs) == 1
         r = reqs[0]
         # Keep the returned value consistent with p2p_pipeline_ops
         sp_reqs = [r] * len(tensor_send_prev)
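For context on the API used above: torch.distributed.batch_isend_irecv takes a list of P2POp descriptors and posts them together; with the NCCL backend the batch may be coalesced so that a single work handle covers every op, which is why the code treats reqs[0] as the handle for the whole batch and mirrors it across the per-tensor request lists. A minimal, self-contained usage sketch (the two-neighbor exchange and the names below are illustrative assumptions, not this repository's wiring):

    import torch
    import torch.distributed as dist

    def exchange_with_neighbors(send_prev, send_next, prev_rank, next_rank, group=None):
        """Post all four p2p ops of one pipeline step as a single batch."""
        recv_prev = torch.empty_like(send_prev)
        recv_next = torch.empty_like(send_next)
        ops = [
            dist.P2POp(dist.isend, send_prev, prev_rank, group),
            dist.P2POp(dist.irecv, recv_prev, prev_rank, group),
            dist.P2POp(dist.isend, send_next, next_rank, group),
            dist.P2POp(dist.irecv, recv_next, next_rank, group),
        ]
        reqs = dist.batch_isend_irecv(ops)
        for req in reqs:  # NCCL may coalesce the batch into a single work handle
            req.wait()
        return recv_prev, recv_next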
@@ -1341,6 +1341,12 @@ def fused_pipeline_ops(
     return reqs, (sp_reqs, rp_reqs, sn_reqs, rn_reqs)
 
 
+class HackReq:
+    """Class to hack async p2p request because the async p2p performance bad"""
+
+    def wait(): ...
+
+
 def multi_pipeline_ops(
     tensor_send_prev: List[torch.Tensor],
     tensor_recv_prev: List[torch.Tensor],
@@ -1353,14 +1359,33 @@ def multi_pipeline_ops(
         p2p_func = fused_pipeline_ops
     else:
         p2p_func = p2p_pipeline_ops
-    return p2p_func(
+
+    reqs = p2p_func(
         tensor_send_prev=tensor_send_prev,
         tensor_recv_prev=tensor_recv_prev,
         tensor_send_next=tensor_send_next,
         tensor_recv_next=tensor_recv_next,
         group=group,
     )
 
+    if batch:
+        hack_req = HackReq()
+        hack_reqs = []
+
+        real_reqs, all_tensor_reqs = reqs
+        for req in real_reqs:
+            req.wait()
+            hack_reqs.append(hack_req)
+
+        torch.cuda.synchronize()
+        for tensor_reqs in all_tensor_reqs:
+            for req in tensor_reqs:
+                req = hack_req
+
+        reqs = (hack_reqs, all_tensor_reqs)
+
+    return reqs
+
 
 def bootstrap_and_profile_p2p_communication(config, send_tensor_shapes, recv_tensor_shapes, p2p_communicator):
     # When we fuse some send-recv communication ops in a device and can't fuse on other devices
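The change above trades asynchrony for batching: when batch is set, the real handles returned by the fused path are waited on immediately, a device synchronize makes the received tensors safe to read, and inert request objects are handed back so existing call sites can keep calling .wait() later at no cost. A standalone sketch of that "eager wait + no-op request" idea (the names are illustrative and this is not the repository's code; the sketch rebuilds the per-tensor lists rather than patching them in place):

    import torch

    class NoOpReq:
        """Request placeholder whose wait() does nothing: the real communication
        has already completed, so later wait() calls are free."""

        def wait(self):
            pass

    def eager_wait_and_wrap(real_reqs, per_tensor_reqs):
        """Wait on the real batched handles now, then return no-op requests so
        callers keep their usual 'req.wait() before using the tensor' pattern."""
        noop = NoOpReq()
        for req in real_reqs:
            req.wait()                # block until the batched p2p completes
        torch.cuda.synchronize()      # ensure received tensors are ready to read
        wrapped = tuple([noop] * len(reqs) for reqs in per_tensor_reqs)
        return [noop] * len(real_reqs), wrapped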
@@ -1435,8 +1460,10 @@ def bootstrap_and_profile_p2p_communication(config, send_tensor_shapes, recv_ten
         if not parallel_state.is_pipeline_first_stage(ignore_virtual=True):
             p2p_communicator.send_backward(recv_data, False)
     t.stop()
-    per_communication = torch.cuda.FloatTensor(
-        [t.elapsed() / (parallel_state.get_pipeline_model_parallel_world_size() - 1) / 2 / 10]
+    per_communication = torch.tensor(
+        [t.elapsed() / (parallel_state.get_pipeline_model_parallel_world_size() - 1) / 2 / 10],
+        dtype=torch.float,
+        device="cuda",
     )
     torch.distributed.all_reduce(per_communication, torch.distributed.ReduceOp.MAX)
     ScheduleTimers.comm_time = per_communication.item()
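The profiling hunk above follows a common measurement pattern: time a fixed number of warm-up round trips across the pipeline, divide by the number of hops and iterations, and take the maximum across ranks so every rank plans with the same estimate. A rough sketch of that pattern under assumed names (ping_pong_fn, iters, and the helper itself are hypothetical, not this repository's API):

    import time
    import torch
    import torch.distributed as dist

    def profile_comm_time(ping_pong_fn, pp_world_size, iters=10):
        """Estimate the cost of one p2p hop and agree on it across all ranks."""
        torch.cuda.synchronize()
        start = time.time()
        for _ in range(iters):
            ping_pong_fn()  # one full forward-and-backward round trip through the pipeline
        torch.cuda.synchronize()
        elapsed = time.time() - start
        # A round trip crosses (pp_world_size - 1) links in each direction.
        per_comm = torch.tensor(
            [elapsed / (pp_world_size - 1) / 2 / iters], dtype=torch.float, device="cuda"
        )
        # Use the slowest rank's estimate so ranks do not disagree on the schedule.
        dist.all_reduce(per_comm, op=dist.ReduceOp.MAX)
        return per_comm.item()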

primus/backends/megatron/core/pipeline_parallel/zerobubble/scheduler/communication.py

Lines changed: 49 additions & 0 deletions

@@ -12,6 +12,8 @@
 import math
 from typing import List, Tuple
 
+from megatron.training.global_vars import get_args
+
 from primus.modules.module_utils import log_rank_all
 
 from .graph import BW, B, CommDirection, F, FuncType, GraphConfig, ScheduledNode
@@ -225,20 +227,67 @@ def add_post_validation_nodes_before_deadline(
     return local_order, comm_pairs
 
 
+def reorder_communication_nodes(local_order: List[List[ScheduledNode]]):
+    """reorder communication nodes to combine them with batch"""
+    recordered_w_list = []
+
+    def ismatch(recv_node, node):
+        return (recv_node.type == FuncType.RECV_FORWARD and node.type == FuncType.F) or (
+            recv_node.type == FuncType.RECV_BACKWARD and node.type == FuncType.B
+        )
+
+    for stage in local_order:
+        stage_list = []
+        w_list = []
+        recv_list = []
+        for node in stage:
+            if node.type == FuncType.W:
+                w_list.append(node)
+            elif node.type in (F, B, BW):
+                for i in range(len(recv_list)):
+                    if (
+                        recv_list[i].microbatch == node.microbatch
+                        and recv_list[i].chunk == node.chunk
+                        and ismatch(recv_list[i], node)
+                    ):
+                        recv_i = recv_list[i]
+                        stage_list.append(recv_i)
+                stage_list.extend(w_list)
+                w_list = []
+                stage_list.append(node)
+
+            elif node.type in (FuncType.RECV_FORWARD, FuncType.RECV_BACKWARD):
+                recv_list.append(node)
+            else:  # communication nodes
+                stage_list.append(node)
+
+        stage_list.extend(w_list)
+        recordered_w_list.append(stage_list)
+
+    return recordered_w_list
+
+
 def add_communication_nodes_without_sorting(
     config: GraphConfig,
     local_order: List[List[ScheduledNode]],
     post_validation: bool,
 ) -> List[List[ScheduledNode]]:
+
     local_order, comm_pairs = insert_send_nodes(config, local_order)
+
     if post_validation:
         local_order, post_validation_comm_pairs = add_post_validation_nodes_before_deadline(
             config, local_order
         )
         comm_pairs.extend(post_validation_comm_pairs)
     local_order = insert_recv_nodes(config, local_order, comm_pairs)
+
     if post_validation:
         local_order = tag_rollback_communication(config, local_order)
+
+    if get_args().num_virtual_stages_per_pipeline_rank is None:
+        local_order = reorder_communication_nodes(local_order)
+
     return local_order
 
 
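To make the new pass concrete: within each stage's schedule, RECV_FORWARD/RECV_BACKWARD nodes are held back until the compute node that consumes them is emitted, and pending W nodes are flushed just before that compute node, so each receive ends up adjacent to its consumer and neighbouring communication nodes can be batched. Note the pass only runs when num_virtual_stages_per_pipeline_rank is None, i.e. without interleaved virtual stages. A toy illustration of the same reordering idea on made-up node tuples (the Node type and schedule below are hypothetical, not the scheduler's real data structures):

    from dataclasses import dataclass
    from typing import List

    @dataclass(frozen=True)
    class Node:
        type: str        # "F", "B", "W", "RECV_FORWARD", "RECV_BACKWARD", ...
        microbatch: int

    def reorder(stage: List[Node]) -> List[Node]:
        """Delay W nodes and place each RECV right before the compute node
        that consumes it, mimicking the idea of the new scheduler pass."""
        out, w_pending, recv_pending = [], [], []
        matches = {"RECV_FORWARD": "F", "RECV_BACKWARD": "B"}
        for node in stage:
            if node.type == "W":
                w_pending.append(node)
            elif node.type in ("F", "B"):
                for recv in recv_pending:
                    if matches[recv.type] == node.type and recv.microbatch == node.microbatch:
                        out.append(recv)
                out.extend(w_pending)
                w_pending = []
                out.append(node)
            elif node.type in ("RECV_FORWARD", "RECV_BACKWARD"):
                recv_pending.append(node)
            else:
                out.append(node)
        out.extend(w_pending)
        return out

    # Example: the recv for microbatch 2 originally sits between two W nodes;
    # after reordering it lands directly before F(2), with the W nodes kept together.
    stage = [Node("F", 1), Node("W", 1), Node("RECV_FORWARD", 2), Node("W", 2), Node("F", 2)]
    print([f"{n.type}{n.microbatch}" for n in reorder(stage)])
    # -> ['F1', 'RECV_FORWARD2', 'W1', 'W2', 'F2']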
