@@ -89,6 +89,8 @@ static const std::vector<std::string> compiler_flags{
     "-I../../Source",
     "-I../../Source/Template",
     "-I../templates",
+    "-I/share/workspace/nvidia_projects/GraphBLAS/CUDA/templates",
+    "-I/share/workspace/nvidia_projects/GraphBLAS/CUDA/",
 //  "-L../../build/CUDA",
     "-I/usr/local/cuda/include",
 };
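
A note on the two added include paths: each entry must be its own comma-terminated string, since adjacent C++ string literals concatenate silently into one bogus flag (the commas are restored above). These flags are ultimately handed to the CUDA runtime compiler one option per element; a minimal sketch of that hand-off, assuming plain NVRTC rather than the jitify layer the project actually uses:

    #include <nvrtc.h>
    #include <string>
    #include <vector>

    // Sketch: compile JIT'd kernel source with the same -I flags, one per option.
    nvrtcProgram compile_with_flags (const char *kernel_source)
    {
        nvrtcProgram prog ;
        nvrtcCreateProgram (&prog, kernel_source, "GB_jit_kernel.cu",
                            0, nullptr, nullptr) ;
        std::vector<const char *> opts ;
        for (const std::string &f : compiler_flags)
            opts.push_back (f.c_str ()) ;
        nvrtcCompileProgram (prog, (int) opts.size (), opts.data ()) ;
        return prog ;
    }

(kernel_source stands in for the string the launcher builds; error checking omitted.)
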
@@ -141,8 +143,10 @@ class phase1launchFactory
     jit::GBJitCache filecache = jit::GBJitCache::Instance () ;
     filecache.getFile (semiring_factory_) ;

+    auto sr_code = std::to_string (semiring_factory_.sr_code) ;
+
     std::stringstream string_to_be_jitted ;
-    std::vector<std::string> template_types = {M->type->name};
+    std::vector<std::string> template_types = {M->type->name, sr_code};

     std::string hashable_name = base_name + "_" + kernel_name;
     string_to_be_jitted << hashable_name << std::endl <<
@@ -155,7 +159,7 @@ class phase1launchFactory
     dim3 grid (get_number_of_blocks (M));
     dim3 block (get_threads_per_block ());

-    jit::launcher ( hashable_name,
+    jit::launcher ( hashable_name + "_" + M->type->name + "_" + sr_code,
                     string_to_be_jitted.str (),
                     header_names,
                     compiler_flags,
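
Why both the template argument list and the launcher name change: the JIT cache is keyed by the kernel's name string, so kernels specialized for different mask types or semirings must map to distinct cache entries, or a stale kernel compiled for one semiring would be silently reused for another. The idea as a standalone sketch (cache_key is a hypothetical helper, not part of this commit):

    #include <cstdint>
    #include <string>

    // Hypothetical: one cache entry per (kernel, mask type, semiring) triple.
    std::string cache_key (const std::string &kernel_name,
                           const std::string &mask_type_name,
                           uint64_t sr_code)
    {
        return kernel_name + "_" + mask_type_name + "_" + std::to_string (sr_code) ;
    }
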
@@ -211,7 +215,7 @@ class phase2launchFactory
     const int64_t mnz = GB_nnz (M) ;
     jit::launcher ( hashable_name,
                     string_to_be_jitted.str (),
-                    header_names,
+                    header_names,
                     compiler_flags,
                     file_callback)
                    .set_kernel_inst ( kernel_name, {})
@@ -228,13 +232,13 @@ class phase2launchFactory
 };

 template<int threads_per_block = 32, int chunk_size = 128>
-class phase2endlaunchFactory
+class phase2endlaunchFactory
 {

   std::string base_name = "GB_jit";
   std::string kernel_name = "AxB_phase2end";

-  public:
+  public:

   int get_threads_per_block () {
     return threads_per_block;
@@ -253,8 +257,8 @@ class phase2endlaunchFactory
     int64_t *bucketp, int64_t *bucket, int64_t *offset,
     GrB_Matrix C, GrB_Matrix M)
   {
-
-    bool result = false ;
+
+    bool result = false ;

     dim3 grid (get_number_of_blocks (M));
     dim3 block (get_threads_per_block ());
@@ -269,7 +273,7 @@ class phase2endlaunchFactory

     jit::launcher ( hashable_name,
                     string_to_be_jitted.str (),
-                    header_names,
+                    header_names,
                     compiler_flags,
                     file_callback)
                    .set_kernel_inst ( kernel_name , {})
@@ -306,8 +310,8 @@ class phase3launchFactory

   bool jitGridBlockLaunch (int64_t start, int64_t end, int64_t *bucketp, int64_t *bucket,
                            GrB_Matrix C, GrB_Matrix M, GrB_Matrix A, GrB_Matrix B) {
-
-    bool result = false ;
+
+    bool result = false ;

     //----------------------------------------------------------------------
     // phase3: do the numerical work
@@ -500,13 +504,9 @@ class reduceFactory
   }

   // Note: this does assume the erased types are compatible w/ the monoid's ztype
-  bool jitGridBlockLaunch (GrB_Matrix A, void *output, unsigned int N,
+  bool jitGridBlockLaunch (GrB_Matrix A, void *output,
                            GrB_Monoid op)
   {
-    int blocksz = get_threads_per_block ();
-    int gridsz = get_number_of_blocks (N);
-    dim3 grid (gridsz);
-    dim3 block (blocksz);

     // TODO: We probably want to "macrofy" the GrB_Monoid and define it in the `string_to_be_jitted`
 //  void GB_stringify_binop
@@ -533,6 +533,14 @@ class reduceFactory
                         hashable_name << std::endl << R"(#include ")" <<
                         hashable_name << R"(.cuh")" << std::endl;

+    bool is_sparse = GB_IS_SPARSE (A);
+    int64_t N = is_sparse ? GB_nnz (A) : GB_NCOLS (A) * GB_NROWS (A);
+
+    int blocksz = get_threads_per_block ();
+    int gridsz = get_number_of_blocks (N);
+    dim3 grid (gridsz);
+    dim3 block (blocksz);
+
     jit::launcher (hashable_name,
                    string_to_be_jitted.str (),
                    header_names,
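
The grid/block computation moves below because N is no longer a caller-supplied parameter: for sparse A the reduction walks the nnz entries, while for a full matrix every cell participates, so N = nrows * ncols. A sketch of the sizing logic, assuming a grid-stride kernel where each block handles about chunk_size items per pass (get_number_of_blocks itself is not shown in this diff, and the cap of 256 blocks is an assumption):

    #include <algorithm>
    #include <cstdint>

    // Assumed shape of get_number_of_blocks: enough blocks to give each
    // one roughly chunk_size items, capped so the grid stays modest.
    int number_of_blocks (int64_t N, int chunk_size)
    {
        int64_t g = (N + chunk_size - 1) / chunk_size ;
        return (int) std::min (g, (int64_t) 256) ;
    }
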
@@ -542,7 +550,7 @@ class reduceFactory
                   .configure (grid, block)

                   // FIXME: GB_ADD is hardcoded into kernel for now
-                  .launch ( A, temp_scalar, N);
+                  .launch ( A, temp_scalar, N, is_sparse);


     checkCudaErrors ( cudaDeviceSynchronize () );
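
Threading is_sparse through .launch implies the device kernel now branches on the storage format. The kernel body is outside this diff; a hedged sketch of the likely shape, where the zombie check and the double type are assumptions, and the plus operator is the hardcoded GB_ADD noted in the FIXME above:

    #include <cstdint>

    // Hypothetical reduce kernel: grid-stride over N values, where N is
    // nnz(A) for sparse A and nrows*ncols for full A (see previous hunk).
    __global__ void reduce_kernel (const int64_t *Ai, const double *Ax,
                                   int64_t N, bool is_sparse, double *out)
    {
        double acc = 0 ;    // assumes a PLUS monoid with identity 0
        for (int64_t k = blockIdx.x * (int64_t) blockDim.x + threadIdx.x ;
             k < N ; k += (int64_t) blockDim.x * gridDim.x)
        {
            if (is_sparse && Ai [k] < 0) continue ;  // skip zombies (assumption)
            acc += Ax [k] ;                          // GB_ADD, hardcoded for now
        }
        atomicAdd (out, acc) ;  // naive; the real kernel would block-reduce first
    }
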
@@ -589,9 +597,9 @@ inline bool GB_cuda_mxm_phase3(GB_cuda_semiring_factory &mysemiringfactory, GB_b
 }


-inline bool GB_cuda_reduce (GrB_Matrix A, void *output, unsigned int N, GrB_Monoid op) {
+inline bool GB_cuda_reduce (GrB_Matrix A, void *output, GrB_Monoid op) {
   reduceFactory rf;
-  return rf.jitGridBlockLaunch (A, output, N, op);
+  return rf.jitGridBlockLaunch (A, output, op);
 }

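Call sites shrink accordingly, since the element count is now derived inside the factory. A before/after sketch (the caller shown is illustrative, not from this commit):

    // before: the caller computed N, easy to get wrong for full matrices
    // GB_cuda_reduce (A, &result, (unsigned int) GB_nnz (A), monoid) ;

    // after: A's sparsity determines the element count internally
    GB_cuda_reduce (A, &result, monoid) ;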