@@ -135,9 +135,9 @@ class BatchedDblBufGemm {
       std::conditional_t<std::is_same<ArgAlphaFmaTag, AlphaTag::Yes>::value, AlphaTag::No, AlphaTag::Yes>;
 
   HandleType *const __handle;
-  AViewType __A;
-  BViewType __B;
-  CViewType __C;
+  AViewType A_;
+  BViewType B_;
+  CViewType C_;
   ScalarType __alpha, __beta;
   ArgTransA __transA_tag;
   ArgTransB __transB_tag;
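A note on the rename itself: C++ reserves every identifier containing a double underscore for the implementation ([lex.name]), so member names like __A were formally non-conforming; the trailing-underscore style (A_) adopted here is the usual standard-safe convention. The remaining double-underscore members visible in this hunk (__handle, __alpha, __beta, the tag members) keep the old style, presumably left for a follow-up rename.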
@@ -156,7 +156,7 @@ class BatchedDblBufGemm {
 
  public:
   BatchedDblBufGemm(HandleType *const handle, ScalarType alpha, AViewType A, BViewType B, ScalarType beta, CViewType C)
-      : __handle(handle), __A(A), __B(B), __C(C), __alpha(alpha), __beta(beta) {}
+      : __handle(handle), A_(A), B_(B), C_(C), __alpha(alpha), __beta(beta) {}
 
   int invoke() {
     __run();
@@ -181,7 +181,7 @@ class BatchedDblBufGemm {
     constexpr int stride_n = TILE_N / reg_n;
     using functor_type = Functor<member_type, reg_m, reg_n, stride_m, stride_n>;
 
-    functor_type functor(*this, __A, __B, __C);
+    functor_type functor(*this, A_, B_, C_);
 
     if (__handle->enableDebug) {
      std::cout << "algo_type:" << __handle->get_kernel_algo_type() << std::endl
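For orientation, stride_n above reads as the column stride between the register blocks a single thread owns, derived as TILE_N / reg_n (stride_m is presumably defined the same way just above this hunk). A worked example of that arithmetic with illustrative numbers, not values taken from this file:

constexpr int TILE_N = 32, reg_n = 4;     // hypothetical tile width and per-thread register count
constexpr int stride_n = TILE_N / reg_n;  // == 8: thread t touches columns t, t+8, t+16, t+24
static_assert(stride_n * reg_n == TILE_N, "reg_n must evenly divide TILE_N");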
@@ -244,37 +244,37 @@ class BatchedDblBufGemm {
   class Functor {
    private:
     BatchedDblBufGemm &__ei;
-    AViewType __A;
-    BViewType __B;
-    CViewType __C;
+    AViewType A_;
+    BViewType B_;
+    CViewType C_;
     ScalarType __alpha, __beta;
     int __k;
     size_t __n_sub_tiles, __tiles_per_col, __tiles_per_row;
 
    public:
     size_t get_n_sub_tiles() { return __n_sub_tiles; }
 
-    // NOTE: We cannot use __ei.{__A,__B,__C,__beta,__alpha,__k} in the operator
+    // NOTE: We cannot use __ei.{A_,B_,C_,__beta,__alpha,__k} in the operator
     // below. If those are used, we get an invalid memory error from CUDA. I
     // suspect this is due to the values not being copied to the device and the
     // runtime resolution of the host address &__ei.
-    Functor(BatchedDblBufGemm &ei, AViewType A, BViewType B, CViewType C) : __ei(ei), __A(A), __B(B), __C(C) {
+    Functor(BatchedDblBufGemm &ei, AViewType A, BViewType B, CViewType C) : __ei(ei), A_(A), B_(B), C_(C) {
       if (std::is_same<ArgBatchSzDim, BatchLayout::Left>::value) {
-        ei.__c_batch_size = ei.__C.extent_int(0);
-        ei.__c_m = ei.__C.extent_int(1);
-        ei.__c_n = ei.__C.extent_int(2);
+        ei.__c_batch_size = ei.C_.extent_int(0);
+        ei.__c_m = ei.C_.extent_int(1);
+        ei.__c_n = ei.C_.extent_int(2);
         if (std::is_same<ArgTransA, Trans::Transpose>::value)
-          __k = ei.__A.extent_int(1);
+          __k = ei.A_.extent_int(1);
         else
-          __k = ei.__A.extent_int(2);
+          __k = ei.A_.extent_int(2);
       } else {
-        ei.__c_batch_size = ei.__C.extent_int(2);
-        ei.__c_m = ei.__C.extent_int(0);
-        ei.__c_n = ei.__C.extent_int(1);
+        ei.__c_batch_size = ei.C_.extent_int(2);
+        ei.__c_m = ei.C_.extent_int(0);
+        ei.__c_n = ei.C_.extent_int(1);
         if (std::is_same<ArgTransA, Trans::Transpose>::value)
-          __k = ei.__A.extent_int(0);
+          __k = ei.A_.extent_int(0);
         else
-          __k = ei.__A.extent_int(1);
+          __k = ei.A_.extent_int(1);
       }
       __beta = ei.__beta;    // Copy to device
       __alpha = ei.__alpha;  // Copy to device
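The NOTE and the two "Copy to device" comments describe the same constraint: the functor object is copied by value to the device when the kernel launches, so everything operator() reads must live in the functor itself. The reference member __ei still encodes a host address, which is why reaching through it on device faults, while the by-value view and scalar members travel with the functor (Kokkos views are reference-counted handles designed for exactly this kind of capture). A minimal standalone sketch of the pattern, using hypothetical types rather than this file's functor:

#include <Kokkos_Core.hpp>

struct HostSide {
  double alpha;
};

struct DeviceFunctor {
  HostSide &ei_;  // host reference: not safe to dereference inside the kernel
  double alpha_;  // by-value copy made on the host: travels with the functor

  DeviceFunctor(HostSide &ei) : ei_(ei), alpha_(ei.alpha) {}  // "copy to device"

  KOKKOS_INLINE_FUNCTION
  void operator()(const int /*i*/) const {
    double a = alpha_;        // OK: reads the functor's own copy
    // double b = ei_.alpha;  // would chase the host address behind ei_ on device
    (void)a;
  }
};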
@@ -381,10 +381,10 @@ class BatchedDblBufGemm {
 
       // Fetch entire 2-rank sub-matrix
       auto svA =
-          subview_wrapper(__A, batch_idx, Kokkos::ALL(), Kokkos::ALL(), __ei.__batch_layout_tag, __ei.__transA_tag);
+          subview_wrapper(A_, batch_idx, Kokkos::ALL(), Kokkos::ALL(), __ei.__batch_layout_tag, __ei.__transA_tag);
       auto svB =
-          subview_wrapper(__B, batch_idx, Kokkos::ALL(), Kokkos::ALL(), __ei.__batch_layout_tag, __ei.__transB_tag);
-      auto svC = subview_wrapper(__C, batch_idx, Kokkos::ALL(), Kokkos::ALL(), __ei.__batch_layout_tag);
+          subview_wrapper(B_, batch_idx, Kokkos::ALL(), Kokkos::ALL(), __ei.__batch_layout_tag, __ei.__transB_tag);
+      auto svC = subview_wrapper(C_, batch_idx, Kokkos::ALL(), Kokkos::ALL(), __ei.__batch_layout_tag);
 
       // Allocate scratch memory buffers used for prefetching
       view_type_2d_scratch svA_scr(member.team_scratch(0), TILE_M, TILE_K);
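subview_wrapper here is the KokkosBatched utility that slices one rank-2 matrix out of the rank-3 batch for this team's batch_idx, taking the batch index from the left or right extent according to the batch-layout tag (consistent with the extent_int logic in the Functor constructor above). A simplified sketch of that slicing, ignoring the transpose tags; slice_batch is a hypothetical stand-in, not the real helper:

template <class ViewType>
KOKKOS_INLINE_FUNCTION auto slice_batch(const ViewType &v, int batch_idx, const BatchLayout::Left &) {
  // Batch index is the leftmost extent: v is (batch, m, n).
  return Kokkos::subview(v, batch_idx, Kokkos::ALL(), Kokkos::ALL());
}

template <class ViewType>
KOKKOS_INLINE_FUNCTION auto slice_batch(const ViewType &v, int batch_idx, const BatchLayout::Right &) {
  // Batch index is the rightmost extent: v is (m, n, batch).
  return Kokkos::subview(v, Kokkos::ALL(), Kokkos::ALL(), batch_idx);
}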
@@ -524,10 +524,10 @@ class BatchedDblBufGemm {
 
       // Fetch entire 2-rank sub-matrix
       auto svA =
-          subview_wrapper(__A, batch_idx, Kokkos::ALL(), Kokkos::ALL(), __ei.__batch_layout_tag, __ei.__transA_tag);
+          subview_wrapper(A_, batch_idx, Kokkos::ALL(), Kokkos::ALL(), __ei.__batch_layout_tag, __ei.__transA_tag);
       auto svB =
-          subview_wrapper(__B, batch_idx, Kokkos::ALL(), Kokkos::ALL(), __ei.__batch_layout_tag, __ei.__transB_tag);
-      auto svC = subview_wrapper(__C, batch_idx, Kokkos::ALL(), Kokkos::ALL(), __ei.__batch_layout_tag);
+          subview_wrapper(B_, batch_idx, Kokkos::ALL(), Kokkos::ALL(), __ei.__batch_layout_tag, __ei.__transB_tag);
+      auto svC = subview_wrapper(C_, batch_idx, Kokkos::ALL(), Kokkos::ALL(), __ei.__batch_layout_tag);
 
       // Allocate scratch memory buffers used for prefetching
       view_type_2d_scratch svA_scr(member.team_scratch(0), TILE_K, TILE_M);
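Worth flagging: this hunk is otherwise identical to the previous one, but here the A tile's scratch extents are swapped (TILE_K by TILE_M instead of TILE_M by TILE_K), i.e. this specialization appears to stage the prefetched A tile transposed. For readers unfamiliar with the scratch API: a view constructed from member.team_scratch(0) only has memory behind it if the TeamPolicy requested that many bytes at level 0 before launch. A minimal sketch of that contract, with illustrative names and sizes:

#include <Kokkos_Core.hpp>

using exec_space  = Kokkos::DefaultExecutionSpace;
using member_type = Kokkos::TeamPolicy<exec_space>::member_type;
// Unmanaged 2D view over team scratch, in the spirit of view_type_2d_scratch above.
using scratch_2d = Kokkos::View<double **, exec_space::scratch_memory_space,
                                Kokkos::MemoryTraits<Kokkos::Unmanaged>>;

void launch(int n_teams, int tile_m, int tile_k) {
  // Two tiles per team: the "double" in double buffering.
  const size_t bytes = 2 * scratch_2d::shmem_size(tile_m, tile_k);
  auto policy = Kokkos::TeamPolicy<exec_space>(n_teams, Kokkos::AUTO)
                    .set_scratch_size(0, Kokkos::PerTeam(bytes));

  Kokkos::parallel_for(policy, KOKKOS_LAMBDA(const member_type &member) {
    // Each constructor carves the next chunk out of the team's scratch pad.
    scratch_2d buf_compute(member.team_scratch(0), tile_m, tile_k);
    scratch_2d buf_prefetch(member.team_scratch(0), tile_m, tile_k);
    // ... compute from buf_compute while filling buf_prefetch, then swap ...
  });
}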