@@ -567,47 +567,47 @@ def test_getp1_dyn_length_dyn_type_string_array():
567567 assert ord (ffi .cast ("char *" , s2 )[8 + ii ]) == ch
568568
569569
570- def test_gpu_api ():
571- for ctx in xo . context . get_test_contexts ( ):
572- src_code = """
573- /*gpufun*/
574- void myfun(double x, double y,
575- double* z){
576- z[0] = x * y;
577- }
570+ @ for_all_test_contexts  # takes the place of the removed explicit loop over xo.context.get_test_contexts(); presumably supplies test_context - confirm
571+ def test_gpu_api ( test_context ):  # compiles a small annotated C source on the given context; no kernel is called and nothing is asserted, so the test passes if add_kernels does not raise
572+ src_code = """
573+ /*gpufun*/
574+ void myfun(double x, double y,
575+ double* z){
576+ z[0] = x * y;
577+ }
578578
579- /*gpukern*/
580- void my_mul(const int n,
581- /*gpuglmem*/ const double* x1,
582- /*gpuglmem*/ const double* x2,
583- /*gpuglmem*/ double* y) {
584- int tid = 0 //vectorize_over tid n
585- double z;
586- myfun(x1[tid], x2[tid], &z);
587- y[tid] = z;
588- //end_vectorize
589- }
590- """
591-
592- kernel_descriptions = {
593- "my_mul" : xo .Kernel (
594- args = [
595- xo .Arg (xo .Int32 , name = "n" ),
596- xo .Arg (xo .Float64 , pointer = True , const = True , name = "x1" ),
597- xo .Arg (xo .Float64 , pointer = True , const = True , name = "x2" ),
598- xo .Arg (xo .Float64 , pointer = True , const = False , name = "y" ),
599- ],
600- n_threads = "n" ,
601- ),
579+ /*gpukern*/
580+ void my_mul(const int n,
581+ /*gpuglmem*/ const double* x1,
582+ /*gpuglmem*/ const double* x2,
583+ /*gpuglmem*/ double* y) {
584+ int tid = 0 //vectorize_over tid n
585+ double z;
586+ myfun(x1[tid], x2[tid], &z);
587+ y[tid] = z;
588+ //end_vectorize
602589 }
590+ """
603591
604- ctx .add_kernels (
605- sources = [src_code ],
606- kernels = kernel_descriptions ,
607- save_source_as = None ,
608- compile = True ,
609- extra_classes = [xo .String [:]],
610- )
592+ kernel_descriptions = {  # Python-side signature of my_mul; argument order and kinds mirror the C prototype in src_code
593+ "my_mul" : xo .Kernel (
594+ args = [
595+ xo .Arg (xo .Int32 , name = "n" ),  # C side: const int n
596+ xo .Arg (xo .Float64 , pointer = True , const = True , name = "x1" ),  # C side: const double* x1
597+ xo .Arg (xo .Float64 , pointer = True , const = True , name = "x2" ),  # C side: const double* x2
598+ xo .Arg (xo .Float64 , pointer = True , const = False , name = "y" ),  # C side: double* y (written by the kernel)
599+ ],
600+ n_threads = "n" ,  # NOTE(review): presumably the thread count is read from the runtime value of argument "n" - confirm
601+ ),
602+ }
603+
604+ test_context .add_kernels (  # called once on the supplied context instead of once per context inside a loop
605+ sources = [src_code ],
606+ kernels = kernel_descriptions ,
607+ save_source_as = None ,
608+ compile = True ,
609+ extra_classes = [xo .String [:]],  # NOTE(review): xo.String[:] is not referenced by src_code; presumably exercises API generation for an extra class - confirm
610+ )
611611
612612
613613@for_all_test_contexts
@@ -662,7 +662,7 @@ class Cells(xo.Struct):
662662 Cells cells,
663663 GPUGLMEM uint64_t* out_counts,
664664 GPUGLMEM uint64_t* out_vals,
665- GPUGLMEM uint8_t* success,
665+ GPUGLMEM uint8_t* success
666666 )
667667 {
668668 int64_t num_cells = Cells_len_ids(cells);
@@ -693,6 +693,15 @@ class Cells(xo.Struct):
693693 END_VECTORIZE;
694694 }
695695 }
696+
697+ GPUKERN void kernel_Cells_get_particles(
698+ Cells obj,
699+ int64_t i0,
700+ int64_t i1,
701+ GPUGLMEM int64_t* out
702+ ) {
703+ *out = Cells_get_particles(obj, i0, i1);
704+ }
696705 """
697706
698707 kernels = {
@@ -704,9 +713,17 @@ class Cells(xo.Struct):
704713 xo .Arg (xo .UInt8 , pointer = True , name = "success" ),
705714 ],
706715 n_threads = 3 ,
707- )
716+ ),
717+ "kernel_Cells_get_particles" : xo .Kernel (
718+ args = [
719+ xo .Arg (Cells , name = "cells" ),
720+ xo .Arg (xo .Int64 , pointer = True , name = "i0" ),
721+ xo .Arg (xo .Int64 , pointer = True , name = "i1" ),
722+ xo .Arg (xo .Int64 , pointer = True , name = "out" ),
723+ ],
724+ n_threads = 3 ,
725+ ),
708726 }
709- kernels .update (Cells ._gen_kernels ())
710727
711728 test_context .add_kernels (
712729 sources = [src ],
@@ -719,8 +736,9 @@ class Cells(xo.Struct):
719736
720737 for i , _ in enumerate (particle_per_cell ):
721738 for j , expected in enumerate (particle_per_cell [i ]):
722- result = test_context .kernels .Cells_get_particles (
723- obj = cells , i0 = i , i1 = j
739+ result = test_context .zeros (shape = (1 ,), dtype = np .int64 )
740+ test_context .kernels .kernel_Cells_get_particles (
741+ obj = cells , i0 = i , i1 = j , out = result
724742 )
725743 assert result == expected
726744
0 commit comments