@@ -457,6 +457,7 @@ static __device__ void processCSR_Kpow2_reg_d(const int wi,
457457 unsigned int subwarp_id = threadIdx .y % (WARP_SIZE/BDIM_X);
458458 subwarp_mask = MASK << (subwarp_id*BDIM_X);
459459 }
460+ constexpr int MAX_POW2_K = (BDIM_X < WARP_SIZE) ? BDIM_X : WARP_SIZE;
460461
461462 // K is a power of two <= BDIM_X
462463 const int log2_K = __popc (K-1 );
@@ -500,13 +501,13 @@ static __device__ void processCSR_Kpow2_reg_d(const int wi,
500501
501502 // K is a power of two <= 32
502503 #pragma unroll
503- for (int j = 1 ; j < BDIM_X ; j *= 2 ) {
504+ for (int j = 1 ; j < MAX_POW2_K ; j *= 2 ) {
504505
505506 if (j >= K) break ;
506507
507508 #pragma unroll
508509 for (int i = 0 ; i < NLOC; i++) {
509- locy[i] += __shfl_xor_sync (subwarp_mask, locy[i], j, BDIM_X );
510+ locy[i] += __shfl_xor_sync (subwarp_mask, locy[i], j, MAX_POW2_K );
510511 }
511512 }
512513
@@ -678,8 +679,9 @@ void s2_disco_bwd_special_vec_k(int nchans, // no. of input float (not FLOATV
678679 __shared__ float *shYOffAll[BDIM_Y][BDIM_X+PAD];
679680
680681 // check if BDIM_X is a multiple of K; since BDIM_X is a power of 2, check if K is also a power of two
681- if (!(K & K-1 ) && K <= BDIM_X) { processCSR_Kpow2_reg_d<BDIM_X, PAD, NLOC>(wi, rlen, nchans, nlon_out, pscale, K, locx, col_idx, val_pck, shYOffAll[tidy], NULL , y); }
682- else { processCSR_Kanyv_reg_d<BDIM_X, PAD, NLOC>(wi, rlen, nchans, nlon_out, pscale, K, locx, col_idx, val_pck, shYOffAll[tidy], shy, y); }
682+ constexpr int MAX_POW2_K = (BDIM_X < WARP_SIZE) ? BDIM_X : WARP_SIZE;
683+ if (!(K & K-1 ) && K <= MAX_POW2_K) { processCSR_Kpow2_reg_d<BDIM_X, PAD, NLOC>(wi, rlen, nchans, nlon_out, pscale, K, locx, col_idx, val_pck, shYOffAll[tidy], NULL , y); }
684+ else { processCSR_Kanyv_reg_d<BDIM_X, PAD, NLOC>(wi, rlen, nchans, nlon_out, pscale, K, locx, col_idx, val_pck, shYOffAll[tidy], shy, y); }
683685
684686 return ;
685687
@@ -709,7 +711,18 @@ void launch_gen_disco_bwd(int64_t batch_size,
709711 size_t shsize = (sizeof (FLOATV_T)*(nchans*K) + sizeof (float )*nchans)*block.y ;
710712
711713 const int pscale = nlon_out / nlon_in;
712-
714+ #if 0
715+ printf("Launching s2_disco_bwd_generic_vec_k<%d, float%s><<<(%d,%d), (%d,%d)..., ..., %zu, ...>>> with:\n"
716+ "\tnchan_out: %ld\n"
717+ "\tK: %ld\n"
718+ "\tpscale: %d\n"
719+ "\tnlat_in: %ld\n"
720+ "\tnlon_in: %ld\n"
721+ "\tnlat_out: %ld\n"
722+ "\tnlon_out: %ld\n\n",
723+ THREADS, sizeof(FLOATV_T)==16?"4":"", grid.x, grid.y, block.x, block.y, shsize, nchans, K, pscale,
724+ nlat_in, nlon_in, nlat_out, nlon_out);
725+ #endif
713726 // will use only the first 1/K-th of the CSR, i.e. only the first nlat_out rows
714727 s2_disco_bwd_generic_vec_k<THREADS>
715728 <<<grid, block, shsize, stream>>> (nchans, nlat_in, nlon_in, nlat_out, nlon_out, pscale, K,
@@ -752,7 +765,18 @@ void launch_spc_disco_bwd(int nloc, // "BDIM_X*nloc" >= nchans
752765 size_t shsize = (K & (K-1 )) ? sizeof (float )*DIV_UP (nchans, BDIM_X)*BDIM_X*block.y : 0 ;
753766
754767 const int pscale = nlon_out / nlon_in;
755-
768+ #if 0
769+ printf("Launching s2_disco_bwd_special_vec_k<%d, %d, %d, float%s><<<(%d, %d), (%d, %d), ..., %zu, ...>>> with:\n"
770+ "\tnchans: %ld\n"
771+ "\tK: %ld\n"
772+ "\tpscale: %d\n"
773+ "\tnlat_in: %ld\n"
774+ "\tnlon_in: %ld\n"
775+ "\tnlat_out: %ld\n"
776+ "\tnlon_in: %ld\n\n",
777+ BDIM_X, BDIM_Y, CUR_LOC_SIZE, sizeof(FLOATV_T)==16?"4":"", grid.x, grid.y, block.x, block.y, shsize, nchans, K, pscale,
778+ nlat_in, nlon_in, nlat_out, nlon_out);
779+ #endif
756780 s2_disco_bwd_special_vec_k<BDIM_X, BDIM_Y, CUR_LOC_SIZE>
757781 <<<grid, block, shsize, stream>>> (nchans, nlat_in, nlon_in, nlat_out, nlon_out, pscale, K,
758782 _xp, nrow, _row_sort, _row_off, _row_idx, _col_idx, _val_pck, _yp);
0 commit comments