@@ -16,6 +16,7 @@
 
 #include "paddle/fluid/distributed/collective/Common.h"
 #include "paddle/fluid/distributed/collective/NCCLTools.h"
+#include "paddle/fluid/distributed/collective/static_check.h"
 #include "paddle/fluid/distributed/collective/utils.h"
 #include "paddle/fluid/platform/device/gpu/nccl_helper.h"
 #include "paddle/fluid/platform/place.h"
@@ -138,8 +139,11 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::AllGather(
   // numel > 0 indicates the tensor needs to be sliced
   const phi::DenseTensor& in_tensor_maybe_partial =
       numel > 0 ? GetPartialTensor(in_tensor, offset, numel) : in_tensor;
-  StaticCheckTensorsGatherLikeShape(
-      *out_tensor, in_tensor_maybe_partial, rank_, size_);
+  CommStaticCheck::GatherLikeShape(*out_tensor,
+                                   in_tensor_maybe_partial,
+                                   /*dst_rank*/ rank_,
+                                   /*cur_rank*/ rank_,
+                                   size_);
   return RunFnInNCCLEnv(
       [&](ncclComm_t comm, gpuStream_t stream) {
         NCCL_CHECK(platform::dynload::ncclAllGather(
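The gather-like relation follows from the collective's semantics: every rank contributes its (possibly partial) input, and the output buffer must hold world_size of those contributions. A minimal self-contained sketch of that numel relation (my own illustration of what a gather-like check plausibly enforces, not the body of static_check.h, which also validates things like place and dtype):

#include <cassert>
#include <cstdint>

// Hypothetical stand-in for the shape relation behind a gather-like check:
// the output must hold exactly world_size input-sized chunks.
void GatherLikeShapeSketch(int64_t out_numel,
                           int64_t in_numel,
                           int world_size) {
  assert(out_numel == in_numel * world_size);
}

int main() {
  // 4 ranks, each contributing 8 elements -> the output holds 32.
  GatherLikeShapeSketch(/*out_numel*/ 32, /*in_numel*/ 8, /*world_size*/ 4);
  return 0;
}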
@@ -162,7 +166,11 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::AllReduce(
     const AllreduceOptions& opts,
     bool sync_op,
     bool use_calc_stream) {
-  StaticCheckTensorsSameShape(*out_tensor, in_tensor, rank_, size_);
+  CommStaticCheck::SameShape(*out_tensor,
+                             in_tensor,
+                             /*dst_rank*/ rank_,
+                             /*cur_rank*/ rank_,
+                             size_);
   return RunFnInNCCLEnv(
       [&](ncclComm_t comm, gpuStream_t stream) {
         NCCL_CHECK(platform::dynload::ncclAllReduce(
@@ -214,12 +222,13 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::AllToAll(
   // NOTE: Since `all_to_all` needs other processes' participation, it cannot
   // simply be covered by static checks. Factors are set to 0 here to skip the
   // shape check. Its shape check will be done by dynamic checks in debug mode.
-  StaticCheckTensors(*out_tensor,
-                     in_tensor,
-                     rank_,
-                     size_,
-                     /*out_size_factor*/ 0,
-                     /*in_size_factor*/ 0);
+  CommStaticCheck::CheckShape(*out_tensor,
+                              in_tensor,
+                              /*dst_rank*/ rank_,
+                              /*cur_rank*/ rank_,
+                              size_,
+                              /*out_size_factor*/ 0,
+                              /*in_size_factor*/ 0);
   return RunFnInNCCLEnv(
       [&](ncclComm_t comm, gpuStream_t stream) {
         int64_t in_row_size = in_tensor.numel() / in_dim[0],
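The NOTE above carries the reasoning for the zero factors. One formulation consistent with the call sites in this diff (an assumption about the check's shape test, not a quote of its implementation) is out.numel() * out_size_factor == in.numel() * in_size_factor; with both factors 0 the equality is 0 == 0, so the shape comparison passes trivially and is effectively skipped:

#include <cassert>
#include <cstdint>

// Assumed generic form of the factor-based shape test. Consistent with the
// NOTE: setting both factors to 0 makes the equality hold for any shapes,
// which skips the static check for all_to_all.
bool ShapesConsistent(int64_t out_numel, int64_t in_numel,
                      int out_size_factor, int in_size_factor) {
  return out_numel * out_size_factor == in_numel * in_size_factor;
}

int main() {
  // SameShape: factors 1/1 -> out and in must match exactly.
  assert(ShapesConsistent(16, 16, /*out*/ 1, /*in*/ 1));
  // Gather-like on 4 ranks: out == 4 * in  <=>  out * 1 == in * 4.
  assert(ShapesConsistent(32, 8, /*out*/ 1, /*in*/ 4));
  // AllToAll: factors 0/0 -> 0 == 0, check skipped regardless of shapes.
  assert(ShapesConsistent(7, 1000, /*out*/ 0, /*in*/ 0));
  return 0;
}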
@@ -287,7 +296,11 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::Broadcast(
     const BroadcastOptions& opts,
     bool sync_op,
     bool use_calc_stream) {
-  StaticCheckTensorsSameShape(*out_tensor, in_tensor, rank_, size_);
+  CommStaticCheck::SameShape(*out_tensor,
+                             in_tensor,
+                             /*dst_rank*/ rank_,
+                             /*cur_rank*/ rank_,
+                             size_);
   return RunFnInNCCLEnv(
       [&](ncclComm_t comm, gpuStream_t stream) {
         int root = opts.source_rank + opts.source_root;
@@ -312,7 +325,11 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::Reduce(
     const ReduceOptions& opts,
     bool sync_op,
     bool use_calc_stream) {
-  StaticCheckTensorsSameShape(*out_tensor, in_tensor, rank_, size_);
+  CommStaticCheck::SameShape(*out_tensor,
+                             in_tensor,
+                             /*dst_rank*/ opts.root_rank,
+                             /*cur_rank*/ rank_,
+                             size_);
   return RunFnInNCCLEnv(
       [&](ncclComm_t comm, gpuStream_t stream) {
         NCCL_CHECK(platform::dynload::ncclReduce(
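Note the one semantic difference in this hunk versus AllReduce and Broadcast: dst_rank is opts.root_rank rather than the local rank. For a reduce, only the root receives the result, so the out/in shape constraint is only meaningful there. A hedged sketch of how passing both ranks can make the check rank-aware (my inference from the call sites; the actual logic lives in static_check.h):

#include <cassert>
#include <cstdint>

// Illustration of why the check takes both dst_rank and cur_rank: the
// out/in shape relation only has to hold on the rank that actually
// receives the result. Other ranks may pass an unused output buffer.
void RankAwareSameShapeSketch(int64_t out_numel, int64_t in_numel,
                              int dst_rank, int cur_rank) {
  if (cur_rank == dst_rank) {
    assert(out_numel == in_numel);  // result lands here: shapes must match
  }
  // On non-destination ranks, out_tensor is not written by the reduce,
  // so (in this sketch) no shape constraint is applied.
}

int main() {
  RankAwareSameShapeSketch(16, 16, /*dst_rank*/ 0, /*cur_rank*/ 0);  // root
  RankAwareSameShapeSketch(0, 16, /*dst_rank*/ 0, /*cur_rank*/ 3);   // non-root
  return 0;
}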
@@ -337,7 +354,11 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::ReduceScatter(
     const ReduceScatterOptions& opts,
     bool sync_op,
     bool use_calc_stream) {
-  StaticCheckTensorsScatterLikeShape(*out_tensor, in_tensor, rank_, size_);
+  CommStaticCheck::ScatterLikeShape(*out_tensor,
+                                    in_tensor,
+                                    /*dst_rank*/ rank_,
+                                    /*cur_rank*/ rank_,
+                                    size_);
   return RunFnInNCCLEnv(
       [&](ncclComm_t comm, gpuStream_t stream) {
         NCCL_CHECK(platform::dynload::ncclReduceScatter(
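ScatterLikeShape is the mirror of the gather-like relation sketched earlier: the input is the full buffer and each rank's output receives one world_size-th of it. Again an illustration of the presumed numel relation, not library code:

#include <cassert>
#include <cstdint>

// Mirror of the gather-like sketch: the input holds world_size
// output-sized chunks, so in == out * world_size.
void ScatterLikeShapeSketch(int64_t out_numel, int64_t in_numel,
                            int world_size) {
  assert(in_numel == out_numel * world_size);
}

int main() {
  // 4 ranks: a 32-element input reduce-scatters to 8 elements per rank.
  ScatterLikeShapeSketch(/*out_numel*/ 8, /*in_numel*/ 32, /*world_size*/ 4);
  return 0;
}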
@@ -361,7 +382,11 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::Scatter(
     const ScatterOptions& opts,
     bool sync_op,
     bool use_calc_stream) {
-  StaticCheckTensorsScatterLikeShape(*out_tensor, in_tensor, rank_, size_);
+  CommStaticCheck::ScatterLikeShape(*out_tensor,
+                                    in_tensor,
+                                    /*dst_rank*/ opts.root_rank,
+                                    /*cur_rank*/ rank_,
+                                    size_);
   return RunFnInNCCLEnv(
       [&](ncclComm_t comm, gpuStream_t stream) {
         int64_t numel = in_tensor.numel() / size_;
@@ -418,7 +443,7 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::Recv(
     tensor = &partial_tensor;
   }
 
-  StaticCheckTensor(*tensor, rank_, size_);
+  CommStaticCheck::SingleTensor(*tensor, rank_, size_);
   return RunFnInNCCLEnv(
       [&](ncclComm_t comm, gpuStream_t stream) {
         NCCL_CHECK(platform::dynload::ncclRecv(
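Send and Recv involve a single tensor and no out/in shape relation, hence the separate SingleTensor entry point. The call sites only tell us it receives the tensor, the local rank, and the world size; here is a plausible sketch of the kind of sanity checks such a function performs (assumptions throughout, not the actual body):

#include <cassert>
#include <cstdint>

// Plausible single-tensor sanity checks for a point-to-point op; the real
// implementation in static_check.h may check more, e.g. that the tensor
// lives on the expected GPU place.
void SingleTensorSketch(int64_t numel, int rank, int world_size) {
  assert(world_size > 0);
  assert(rank >= 0 && rank < world_size);  // rank must be a valid member
  assert(numel > 0);                       // nothing to send/recv otherwise
}

int main() {
  SingleTensorSketch(/*numel*/ 1024, /*rank*/ 2, /*world_size*/ 8);
  return 0;
}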
@@ -446,7 +471,7 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::Send(
   const phi::DenseTensor& tensor_maybe_partial =
       numel > 0 ? GetPartialTensor(tensor, offset, numel) : tensor;
 
-  StaticCheckTensor(tensor_maybe_partial, rank_, size_);
+  CommStaticCheck::SingleTensor(tensor_maybe_partial, rank_, size_);
   return RunFnInNCCLEnv(
       [&](ncclComm_t comm, gpuStream_t stream) {
         NCCL_CHECK(platform::dynload::ncclSend(