Skip to content

Commit 0263f19

Browse files
mauro-bisazrael417
authored andcommitted
Increased max no. of element per thread to 20 to both fwd and bwd and
changes bwd to limit the unroll length in processCSR.
1 parent 0cb7a49 commit 0263f19

File tree

2 files changed

+44
-8
lines changed

2 files changed

+44
-8
lines changed

torch_harmonics/disco/csrc/disco_cuda_bwd.cu

Lines changed: 30 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -457,6 +457,7 @@ static __device__ void processCSR_Kpow2_reg_d(const int wi,
457457
unsigned int subwarp_id = threadIdx.y % (WARP_SIZE/BDIM_X);
458458
subwarp_mask = MASK << (subwarp_id*BDIM_X);
459459
}
460+
constexpr int MAX_POW2_K = (BDIM_X < WARP_SIZE) ? BDIM_X : WARP_SIZE;
460461

461462
// K is a power of two <= BDIM_X
462463
const int log2_K = __popc(K-1);
@@ -500,13 +501,13 @@ static __device__ void processCSR_Kpow2_reg_d(const int wi,
500501

501502
// K is a power of two <= 32
502503
#pragma unroll
503-
for(int j = 1; j < BDIM_X; j *= 2) {
504+
for(int j = 1; j < MAX_POW2_K; j *= 2) {
504505

505506
if (j >= K) break;
506507

507508
#pragma unroll
508509
for(int i = 0; i < NLOC; i++) {
509-
locy[i] += __shfl_xor_sync(subwarp_mask, locy[i], j, BDIM_X);
510+
locy[i] += __shfl_xor_sync(subwarp_mask, locy[i], j, MAX_POW2_K);
510511
}
511512
}
512513

@@ -678,8 +679,9 @@ void s2_disco_bwd_special_vec_k(int nchans, // no. of input float (not FLOATV
678679
__shared__ float *shYOffAll[BDIM_Y][BDIM_X+PAD];
679680

680681
// check if BDIM_X is a multiple of K; since BDIM_X is a power of 2, check if K is also a power of two
681-
if (!(K & K-1) && K <= BDIM_X) { processCSR_Kpow2_reg_d<BDIM_X, PAD, NLOC>(wi, rlen, nchans, nlon_out, pscale, K, locx, col_idx, val_pck, shYOffAll[tidy], NULL, y); }
682-
else { processCSR_Kanyv_reg_d<BDIM_X, PAD, NLOC>(wi, rlen, nchans, nlon_out, pscale, K, locx, col_idx, val_pck, shYOffAll[tidy], shy, y); }
682+
constexpr int MAX_POW2_K = (BDIM_X < WARP_SIZE) ? BDIM_X : WARP_SIZE;
683+
if (!(K & K-1) && K <= MAX_POW2_K) { processCSR_Kpow2_reg_d<BDIM_X, PAD, NLOC>(wi, rlen, nchans, nlon_out, pscale, K, locx, col_idx, val_pck, shYOffAll[tidy], NULL, y); }
684+
else { processCSR_Kanyv_reg_d<BDIM_X, PAD, NLOC>(wi, rlen, nchans, nlon_out, pscale, K, locx, col_idx, val_pck, shYOffAll[tidy], shy, y); }
683685

684686
return;
685687

@@ -709,7 +711,18 @@ void launch_gen_disco_bwd(int64_t batch_size,
709711
size_t shsize = (sizeof(FLOATV_T)*(nchans*K) + sizeof(float)*nchans)*block.y;
710712

711713
const int pscale = nlon_out / nlon_in;
712-
714+
#if 0
715+
printf("Launching s2_disco_bwd_generic_vec_k<%d, float%s><<<(%d,%d), (%d,%d)..., ..., %zu, ...>>> with:\n"
716+
"\tnchan_out: %ld\n"
717+
"\tK: %ld\n"
718+
"\tpscale: %d\n"
719+
"\tnlat_in: %ld\n"
720+
"\tnlon_in: %ld\n"
721+
"\tnlat_out: %ld\n"
722+
"\tnlon_out: %ld\n\n",
723+
THREADS, sizeof(FLOATV_T)==16?"4":"", grid.x, grid.y, block.x, block.y, shsize, nchans, K, pscale,
724+
nlat_in, nlon_in, nlat_out, nlon_out);
725+
#endif
713726
// will use only the first 1/K-th of the CSR, i.e. only the first nlat_out rows
714727
s2_disco_bwd_generic_vec_k<THREADS>
715728
<<<grid, block, shsize, stream>>>(nchans, nlat_in, nlon_in, nlat_out, nlon_out, pscale, K,
@@ -752,7 +765,18 @@ void launch_spc_disco_bwd(int nloc, // "BDIM_X*nloc" >= nchans
752765
size_t shsize = (K & (K-1)) ? sizeof(float)*DIV_UP(nchans, BDIM_X)*BDIM_X*block.y : 0;
753766

754767
const int pscale = nlon_out / nlon_in;
755-
768+
#if 0
769+
printf("Launching s2_disco_bwd_special_vec_k<%d, %d, %d, float%s><<<(%d, %d), (%d, %d), ..., %zu, ...>>> with:\n"
770+
"\tnchans: %ld\n"
771+
"\tK: %ld\n"
772+
"\tpscale: %d\n"
773+
"\tnlat_in: %ld\n"
774+
"\tnlon_in: %ld\n"
775+
"\tnlat_out: %ld\n"
776+
"\tnlon_in: %ld\n\n",
777+
BDIM_X, BDIM_Y, CUR_LOC_SIZE, sizeof(FLOATV_T)==16?"4":"", grid.x, grid.y, block.x, block.y, shsize, nchans, K, pscale,
778+
nlat_in, nlon_in, nlat_out, nlon_out);
779+
#endif
756780
s2_disco_bwd_special_vec_k<BDIM_X, BDIM_Y, CUR_LOC_SIZE>
757781
<<<grid, block, shsize, stream>>>(nchans, nlat_in, nlon_in, nlat_out, nlon_out, pscale, K,
758782
_xp, nrow, _row_sort, _row_off, _row_idx, _col_idx, _val_pck, _yp);

torch_harmonics/disco/csrc/disco_cuda_fwd.cu

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -659,7 +659,13 @@ void launch_gen_disco_fwd(int64_t batch_size,
659659
size_t shsize = sizeof(FLOATV_T)*(nchan_in*K)*block.y;
660660

661661
const int pscale = nlon_in / nlon_out;
662-
662+
#if 0
663+
printf("Launching s2_disco_fwd_generic_vec_k<%d, float%s><<<..., ..., %zu, ...>>> with:\n"
664+
"\tnchan_in: %ld\n"
665+
"\tK: %ld\n"
666+
"\tpscale: %d\n\n",
667+
THREADS, sizeof(FLOATV_T)==16?"4":"", shsize, nchan_in, K, pscale);
668+
#endif
663669
// will use only the first 1/K-th of the CSR, i.e. only the first nlat_out rows
664670
s2_disco_fwd_generic_vec_k<THREADS>
665671
<<<grid, block, shsize, stream>>>(nchan_in, nlat_in, nlon_in, nlat_out, nlon_out, pscale, K,
@@ -701,7 +707,13 @@ void launch_spc_disco_fwd(int nloc, // "BDIM_X*nloc" >= nchans
701707
size_t shsize = 0; //sizeof(float)*chxgrp_out * block.y;
702708

703709
const int pscale = nlon_in / nlon_out;
704-
710+
#if 0
711+
printf("Launching s2_disco_fwd_special_vec_k<%d, %d, %d, float%s><<<(%d, %d), (%d, %d), ..., %zu, ...>>> with:\n"
712+
"\tnchan_in: %ld\n"
713+
"\tK: %ld\n"
714+
"\tpscale: %d\n\n",
715+
BDIM_X, BDIM_Y, CUR_LOC_SIZE, sizeof(FLOATV_T)==16?"4":"", grid.x, grid.y, block.x, block.y, shsize, nchan_in, K, pscale);
716+
#endif
705717
s2_disco_fwd_special_vec_k<BDIM_X, BDIM_Y, CUR_LOC_SIZE>
706718
<<<grid, block, shsize, stream>>>(nchan_in, nlat_in, nlon_in, nlat_out, nlon_out, pscale, K,
707719
_xp, nrow, _row_sort, _row_off, _row_idx, _col_idx, _val_pck, _yp);

0 commit comments

Comments
 (0)