Commit 4552be4

Refactor collective communication static check (#48646)
* refactor: classify static check
* refactor: rename to static_check & use forward decl
* refactor: switch to unary & binary funcs
1 parent f9815bf commit 4552be4

6 files changed (+273, -147 lines)
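
Note: only four of the six changed files are rendered below. The two new files, presumably static_check.h and static_check.cc given the SRCS and #include changes in the diffs, are not shown on this page. For orientation, here is a hypothetical sketch of the interface static_check.h would need to declare, inferred purely from the call sites in the ProcessGroupNCCL.cc diff below; parameter names and the choice of a struct with static members (rather than a namespace) are assumptions, not the actual file contents.

// Hypothetical reconstruction of static_check.h, inferred from call sites
// in this commit; not the actual new file.
#pragma once

namespace phi {
class DenseTensor;  // forward declaration, per the commit message
}  // namespace phi

namespace paddle {
namespace distributed {

struct CommStaticCheck {
  // unary check, used by the point-to-point ops (Send/Recv)
  static void SingleTensor(const phi::DenseTensor& tensor,
                           int rank,
                           int world_size);

  // binary check with configurable size factors; a factor of 0 skips the
  // shape check (see the AllToAll call site)
  static void CheckShape(const phi::DenseTensor& out_tensor,
                         const phi::DenseTensor& in_tensor,
                         int dst_rank,
                         int cur_rank,
                         int world_size,
                         int out_size_factor,
                         int in_size_factor);

  // convenience wrappers over CheckShape
  static void SameShape(const phi::DenseTensor& out_tensor,
                        const phi::DenseTensor& in_tensor,
                        int dst_rank,
                        int cur_rank,
                        int world_size);

  static void ScatterLikeShape(const phi::DenseTensor& out_tensor,
                               const phi::DenseTensor& in_tensor,
                               int dst_rank,
                               int cur_rank,
                               int world_size);

  static void GatherLikeShape(const phi::DenseTensor& out_tensor,
                              const phi::DenseTensor& in_tensor,
                              int dst_rank,
                              int cur_rank,
                              int world_size);
};

}  // namespace distributed
}  // namespace paddle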

paddle/fluid/distributed/collective/CMakeLists.txt (+1, -1)

@@ -21,7 +21,7 @@ endif()
 if(WITH_NCCL OR WITH_RCCL)
   cc_library(
     processgroup_nccl
-    SRCS ProcessGroupNCCL.cc NCCLTools.cc Common.cc
+    SRCS ProcessGroupNCCL.cc NCCLTools.cc Common.cc static_check.cc
     DEPS processgroup
          processgroup_stream
          place

paddle/fluid/distributed/collective/NCCLTools.cc (-104)

@@ -44,109 +44,5 @@ std::string SerializeNCCLUniqueId(const ncclUniqueId& ncclID) {
   return oss.str();
 }
 
-void StaticCheckTensor(const phi::DenseTensor& tensor,
-                       int rank,
-                       int world_size) {
-  // place check
-  PADDLE_ENFORCE_EQ(
-      platform::is_gpu_place(tensor.place()),
-      true,
-      platform::errors::InvalidArgument("Tensor should be in GPU place."));
-  // rank check
-  PADDLE_ENFORCE_GE(rank,
-                    0,
-                    platform::errors::InvalidArgument(
-                        "Rank should be greater than or equal to 0."));
-  PADDLE_ENFORCE_LT(
-      rank,
-      world_size,
-      platform::errors::InvalidArgument("Rank is out of the process group."));
-}
-
-// static check for collective
-void StaticCheckTensors(const phi::DenseTensor& out_tensor,
-                        const phi::DenseTensor& in_tensor,
-                        int rank,
-                        int world_size,
-                        int out_size_factor,
-                        int in_size_factor) {
-  // place check
-  PADDLE_ENFORCE_EQ(platform::is_gpu_place(out_tensor.place()),
-                    true,
-                    platform::errors::InvalidArgument(
-                        "Output tensor should be in GPU place."));
-  PADDLE_ENFORCE_EQ(platform::is_gpu_place(in_tensor.place()),
-                    true,
-                    platform::errors::InvalidArgument(
-                        "Input tensor should be in GPU place."));
-  // rank check
-  PADDLE_ENFORCE_GE(rank,
-                    0,
-                    platform::errors::InvalidArgument(
-                        "Rank should be greater than or equal to 0."));
-  PADDLE_ENFORCE_LT(
-      rank,
-      world_size,
-      platform::errors::InvalidArgument("Rank is out of the process group."));
-  // shape check
-  int64_t out_size = out_tensor.numel();
-  PADDLE_ENFORCE_GT(out_size,
-                    0,
-                    platform::errors::InvalidArgument(
-                        "Size of output tensor should be greater than 0."));
-  int64_t in_size = in_tensor.numel();
-  PADDLE_ENFORCE_GT(in_size,
-                    0,
-                    platform::errors::InvalidArgument(
-                        "Size of input tensor should be greater than 0."));
-  PADDLE_ENFORCE_EQ(
-      out_size * out_size_factor,
-      in_size * in_size_factor,
-      platform::errors::InvalidArgument(
-          "Input and output tensors should have matching sizes."));
-  // dtype check
-  PADDLE_ENFORCE_EQ(
-      out_tensor.dtype(),
-      in_tensor.dtype(),
-      platform::errors::InvalidArgument(
-          "Input and output tensors should have the same data type."));
-}
-
-void StaticCheckTensorsSameShape(const phi::DenseTensor& out_tensor,
-                                 const phi::DenseTensor& in_tensor,
-                                 int rank,
-                                 int world_size) {
-  StaticCheckTensors(out_tensor,
-                     in_tensor,
-                     rank,
-                     world_size,
-                     /*out_size_factor*/ 1,
-                     /*in_size_factor*/ 1);
-}
-
-void StaticCheckTensorsScatterLikeShape(const phi::DenseTensor& out_tensor,
-                                        const phi::DenseTensor& in_tensor,
-                                        int rank,
-                                        int world_size) {
-  StaticCheckTensors(out_tensor,
-                     in_tensor,
-                     rank,
-                     world_size,
-                     /*out_size_factor*/ world_size,
-                     /*in_size_factor*/ 1);
-}
-
-void StaticCheckTensorsGatherLikeShape(const phi::DenseTensor& out_tensor,
-                                       const phi::DenseTensor& in_tensor,
-                                       int rank,
-                                       int world_size) {
-  StaticCheckTensors(out_tensor,
-                     in_tensor,
-                     rank,
-                     world_size,
-                     /*out_size_factor*/ 1,
-                     /*in_size_factor*/ world_size);
-}
-
 } // namespace distributed
 } // namespace paddle
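
The removed helpers expressed every shape relation through a single invariant: out_size * out_size_factor == in_size * in_size_factor. A minimal, self-contained illustration of that invariant follows; the SizesMatch helper is for exposition only and is not part of the Paddle codebase.

#include <cassert>
#include <cstdint>

// Exposition-only helper: the invariant enforced by the removed
// StaticCheckTensors (and, per the call sites in this commit, presumably
// preserved by CommStaticCheck::CheckShape).
bool SizesMatch(int64_t out_size, int64_t in_size,
                int64_t out_size_factor, int64_t in_size_factor) {
  return out_size * out_size_factor == in_size * in_size_factor;
}

int main() {
  const int64_t world_size = 4;
  // SameShape (all_reduce, broadcast, reduce): factors (1, 1).
  assert(SizesMatch(/*out*/ 8, /*in*/ 8, 1, 1));
  // ScatterLikeShape (reduce_scatter, scatter): factors (world_size, 1),
  // i.e. each rank's output is 1/world_size of the input.
  assert(SizesMatch(/*out*/ 2, /*in*/ 8, world_size, 1));
  // GatherLikeShape (all_gather): factors (1, world_size),
  // i.e. the output is world_size concatenated input-sized chunks.
  assert(SizesMatch(/*out*/ 8, /*in*/ 2, 1, world_size));
  return 0;
}

With both factors set to 0 the equality holds trivially, which is exactly how the AllToAll call site in the next file opts out of the static shape check.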

paddle/fluid/distributed/collective/NCCLTools.h (-27)

@@ -63,32 +63,5 @@ ncclRedOp_t ToNCCLRedType(ReduceOp reduction);
 
 std::string SerializeNCCLUniqueId(const ncclUniqueId& ncclID);
 
-// static check for p2p
-void StaticCheckTensor(const phi::DenseTensor& tensor,
-                       int rank,
-                       int world_size);
-
-// static check for collective
-void StaticCheckTensors(const phi::DenseTensor& out_tensor,
-                        const phi::DenseTensor& in_tensor,
-                        int rank,
-                        int world_size,
-                        int out_size_factor,
-                        int in_size_factor);
-
-void StaticCheckTensorsSameShape(const phi::DenseTensor& out_tensor,
-                                 const phi::DenseTensor& in_tensor,
-                                 int rank,
-                                 int world_size);
-
-void StaticCheckTensorsScatterLikeShape(const phi::DenseTensor& out_tensor,
-                                        const phi::DenseTensor& in_tensor,
-                                        int rank,
-                                        int world_size);
-
-void StaticCheckTensorsGatherLikeShape(const phi::DenseTensor& out_tensor,
-                                       const phi::DenseTensor& in_tensor,
-                                       int rank,
-                                       int world_size);
 } // namespace distributed
 } // namespace paddle

paddle/fluid/distributed/collective/ProcessGroupNCCL.cc (+40, -15)

@@ -16,6 +16,7 @@
 
 #include "paddle/fluid/distributed/collective/Common.h"
 #include "paddle/fluid/distributed/collective/NCCLTools.h"
+#include "paddle/fluid/distributed/collective/static_check.h"
 #include "paddle/fluid/distributed/collective/utils.h"
 #include "paddle/fluid/platform/device/gpu/nccl_helper.h"
 #include "paddle/fluid/platform/place.h"
@@ -138,8 +139,11 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::AllGather(
   // numel > 0 indicates the tensor need to be sliced
   const phi::DenseTensor& in_tensor_maybe_partial =
       numel > 0 ? GetPartialTensor(in_tensor, offset, numel) : in_tensor;
-  StaticCheckTensorsGatherLikeShape(
-      *out_tensor, in_tensor_maybe_partial, rank_, size_);
+  CommStaticCheck::GatherLikeShape(*out_tensor,
+                                   in_tensor_maybe_partial,
+                                   /*dst_rank*/ rank_,
+                                   /*cur_rank*/ rank_,
+                                   size_);
   return RunFnInNCCLEnv(
       [&](ncclComm_t comm, gpuStream_t stream) {
         NCCL_CHECK(platform::dynload::ncclAllGather(
@@ -162,7 +166,11 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::AllReduce(
     const AllreduceOptions& opts,
     bool sync_op,
     bool use_calc_stream) {
-  StaticCheckTensorsSameShape(*out_tensor, in_tensor, rank_, size_);
+  CommStaticCheck::SameShape(*out_tensor,
+                             in_tensor,
+                             /*dst_rank*/ rank_,
+                             /*cur_rank*/ rank_,
+                             size_);
   return RunFnInNCCLEnv(
       [&](ncclComm_t comm, gpuStream_t stream) {
         NCCL_CHECK(platform::dynload::ncclAllReduce(
@@ -214,12 +222,13 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::AllToAll(
   // NOTE: Since `all_to_all` needs other processes's participation, it cannot
   // simply be covered by static checks. Factors are set to 0 here to skip the
   // shape check. Its shape check will be done by dynamic checks in debug mode.
-  StaticCheckTensors(*out_tensor,
-                     in_tensor,
-                     rank_,
-                     size_,
-                     /*out_size_factor*/ 0,
-                     /*in_size_factor*/ 0);
+  CommStaticCheck::CheckShape(*out_tensor,
+                              in_tensor,
+                              /*dst_rank*/ rank_,
+                              /*cur_rank*/ rank_,
+                              size_,
+                              /*out_size_factor*/ 0,
+                              /*in_size_factor*/ 0);
   return RunFnInNCCLEnv(
       [&](ncclComm_t comm, gpuStream_t stream) {
         int64_t in_row_size = in_tensor.numel() / in_dim[0],
@@ -287,7 +296,11 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::Broadcast(
     const BroadcastOptions& opts,
     bool sync_op,
     bool use_calc_stream) {
-  StaticCheckTensorsSameShape(*out_tensor, in_tensor, rank_, size_);
+  CommStaticCheck::SameShape(*out_tensor,
+                             in_tensor,
+                             /*dst_rank*/ rank_,
+                             /*cur_rank*/ rank_,
+                             size_);
   return RunFnInNCCLEnv(
       [&](ncclComm_t comm, gpuStream_t stream) {
         int root = opts.source_rank + opts.source_root;
@@ -312,7 +325,11 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::Reduce(
     const ReduceOptions& opts,
     bool sync_op,
    bool use_calc_stream) {
-  StaticCheckTensorsSameShape(*out_tensor, in_tensor, rank_, size_);
+  CommStaticCheck::SameShape(*out_tensor,
+                             in_tensor,
+                             /*dst_rank*/ opts.root_rank,
+                             /*cur_rank*/ rank_,
+                             size_);
   return RunFnInNCCLEnv(
       [&](ncclComm_t comm, gpuStream_t stream) {
         NCCL_CHECK(platform::dynload::ncclReduce(
@@ -337,7 +354,11 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::ReduceScatter(
     const ReduceScatterOptions& opts,
     bool sync_op,
     bool use_calc_stream) {
-  StaticCheckTensorsScatterLikeShape(*out_tensor, in_tensor, rank_, size_);
+  CommStaticCheck::ScatterLikeShape(*out_tensor,
+                                    in_tensor,
+                                    /*dst_rank*/ rank_,
+                                    /*cur_rank*/ rank_,
+                                    size_);
   return RunFnInNCCLEnv(
       [&](ncclComm_t comm, gpuStream_t stream) {
         NCCL_CHECK(platform::dynload::ncclReduceScatter(
@@ -361,7 +382,11 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::Scatter(
     const ScatterOptions& opts,
     bool sync_op,
     bool use_calc_stream) {
-  StaticCheckTensorsScatterLikeShape(*out_tensor, in_tensor, rank_, size_);
+  CommStaticCheck::ScatterLikeShape(*out_tensor,
+                                    in_tensor,
+                                    /*dst_rank*/ opts.root_rank,
+                                    /*cur_rank*/ rank_,
+                                    size_);
   return RunFnInNCCLEnv(
       [&](ncclComm_t comm, gpuStream_t stream) {
         int64_t numel = in_tensor.numel() / size_;
@@ -418,7 +443,7 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::Recv(
     tensor = &partial_tensor;
   }
 
-  StaticCheckTensor(*tensor, rank_, size_);
+  CommStaticCheck::SingleTensor(*tensor, rank_, size_);
   return RunFnInNCCLEnv(
       [&](ncclComm_t comm, gpuStream_t stream) {
         NCCL_CHECK(platform::dynload::ncclRecv(
@@ -446,7 +471,7 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::Send(
   const phi::DenseTensor& tensor_maybe_partial =
       numel > 0 ? GetPartialTensor(tensor, offset, numel) : tensor;
 
-  StaticCheckTensor(tensor_maybe_partial, rank_, size_);
+  CommStaticCheck::SingleTensor(tensor_maybe_partial, rank_, size_);
   return RunFnInNCCLEnv(
       [&](ncclComm_t comm, gpuStream_t stream) {
         NCCL_CHECK(platform::dynload::ncclSend(
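
The substantive change at these call sites is the split of the old single rank argument into dst_rank and cur_rank: the symmetric collectives (AllGather, AllReduce, Broadcast, ReduceScatter, AllToAll) pass rank_ for both, while the rooted ops Reduce and Scatter pass opts.root_rank as the destination. Since static_check.cc is not rendered on this page, the sketch below is only a guess at how the distinction could be used; ToyTensor and this CheckShape are made-up stand-ins, not Paddle code.

#include <cassert>
#include <cstdint>

// Exposition-only sketch: one plausible use of the dst_rank/cur_rank split
// is to apply the strict factor-based size comparison only on the
// destination rank, since for rooted ops like Reduce/Scatter the other
// ranks' buffers need not satisfy the full shape relation.
struct ToyTensor {
  int64_t numel;
};

void CheckShape(const ToyTensor& out, const ToyTensor& in,
                int dst_rank, int cur_rank, int world_size,
                int64_t out_size_factor, int64_t in_size_factor) {
  assert(dst_rank >= 0 && dst_rank < world_size);  // rank checks
  assert(cur_rank >= 0 && cur_rank < world_size);
  assert(out.numel > 0 && in.numel > 0);           // non-empty everywhere
  if (cur_rank == dst_rank) {
    // Only the destination rank must satisfy the full size relation;
    // factors of 0 make it hold trivially, as in the AllToAll call site.
    assert(out.numel * out_size_factor == in.numel * in_size_factor);
  }
}

int main() {
  const int world_size = 4, root = 0;
  ToyTensor in{8}, out{8};
  // Reduce: SameShape with dst_rank = root. On the root (cur_rank == 0)
  // the strict comparison runs; on other ranks only basic checks apply.
  CheckShape(out, in, /*dst_rank*/ root, /*cur_rank*/ 0, world_size, 1, 1);
  CheckShape(out, in, /*dst_rank*/ root, /*cur_rank*/ 2, world_size, 1, 1);
  return 0;
}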
