@@ -124,15 +124,37 @@ void ucc_tl_ucp_reduce_knomial_progress(ucc_coll_task_t *coll_task)
124124
125125ucc_status_t ucc_tl_ucp_reduce_knomial_start (ucc_coll_task_t * coll_task )
126126{
127- ucc_tl_ucp_task_t * task = ucc_derived_of (coll_task , ucc_tl_ucp_task_t );
128- ucc_coll_args_t * args = & TASK_ARGS (task );
129- ucc_tl_ucp_team_t * team = TASK_TEAM (task );
130- uint32_t radix = task -> reduce_kn .radix ;
131- ucc_rank_t root = (ucc_rank_t )args -> root ;
132- ucc_rank_t rank = UCC_TL_TEAM_RANK (team );
133- ucc_rank_t size = UCC_TL_TEAM_SIZE (team );
134- ucc_rank_t vrank = (rank - root + size ) % size ;
135- int isleaf = ((vrank % radix != 0 ) || (vrank == size - 1 ));
127+ ucc_tl_ucp_task_t * task =
128+ ucc_derived_of (coll_task , ucc_tl_ucp_task_t );
129+ ucc_coll_args_t * args = & TASK_ARGS (task );
130+ ucc_tl_ucp_team_t * team = TASK_TEAM (task );
131+ uint32_t radix = task -> reduce_kn .radix ;
132+ ucc_rank_t root = (ucc_rank_t )args -> root ;
133+ ucc_rank_t rank = UCC_TL_TEAM_RANK (team );
134+ ucc_rank_t size = UCC_TL_TEAM_SIZE (team );
135+ ucc_rank_t vrank = (rank - root + size ) % size ;
136+ int isleaf =
137+ (vrank % radix != 0 || vrank == size - 1 );
138+ int avg_pre_op =
139+ UCC_TL_UCP_TEAM_LIB (team )-> cfg .reduce_avg_pre_op ;
140+ int self_avg = (args -> op == UCC_OP_AVG &&
141+ avg_pre_op && vrank % radix == 0 );
142+ size_t data_size , count ;
143+ ucc_memory_type_t mtype ;
144+ ucc_datatype_t dt ;
145+ ucc_status_t status ;
146+
147+ if (root == rank ) {
148+ count = args -> dst .info .count ;
149+ dt = args -> dst .info .datatype ;
150+ mtype = args -> dst .info .mem_type ;
151+ } else {
152+ count = args -> src .info .count ;
153+ dt = args -> src .info .datatype ;
154+ mtype = args -> src .info .mem_type ;
155+ }
156+ data_size = count * ucc_dt_size (dt );
157+
136158
137159 UCC_TL_UCP_PROFILE_REQUEST_EVENT (coll_task , "ucp_reduce_kn_start" , 0 );
138160 ucc_tl_ucp_task_reset (task , UCC_INPROGRESS );
@@ -141,10 +163,26 @@ ucc_status_t ucc_tl_ucp_reduce_knomial_start(ucc_coll_task_t *coll_task)
141163 args -> src .info .buffer = args -> dst .info .buffer ;
142164 }
143165
144- if (isleaf ) {
166+ if (isleaf && ! self_avg ) {
145167 task -> reduce_kn .scratch = args -> src .info .buffer ;
146168 }
147169
170+ if (isleaf && self_avg ) {
171+ /* In case of avg_pre_op, single leaf process which does not take part
172+ in first iteration reduction must divide itself by team_size */
173+ status = ucc_dt_reduce_multi_alpha (args -> src .info .buffer ,
174+ args -> src .info .buffer , task -> reduce_kn .scratch , 1 , count ,
175+ data_size , dt , UCC_OP_PROD ,
176+ (double )1 / (double )(UCC_TL_TEAM_SIZE (TASK_TEAM (task )) * 2 ),
177+ mtype , args );
178+ if (ucc_unlikely (UCC_OK != status )) {
179+ tl_error (UCC_TASK_LIB (task ),
180+ "failed to perform dt reduction" );
181+ task -> super .super .status = status ;
182+ return status ;
183+ }
184+ }
185+
148186 task -> reduce_kn .dist = 1 ;
149187 task -> reduce_kn .phase = UCC_REDUCE_KN_PHASE_INIT ;
150188
0 commit comments