Skip to content

Commit 52a981f

Browse files
fixes for gpu, built in kernels cannot be passed to add_kernels on gpu
1 parent bc0e6b2 commit 52a981f

File tree

1 file changed

+61
-43
lines changed

1 file changed

+61
-43
lines changed

tests/test_capi.py

Lines changed: 61 additions & 43 deletions
Original file line number | Diff line number | Diff line change
@@ -567,47 +567,47 @@ def test_getp1_dyn_length_dyn_type_string_array():
567567
assert ord(ffi.cast("char *", s2)[8 + ii]) == ch
568568

569569

570-
def test_gpu_api():
571-
for ctx in xo.context.get_test_contexts():
572-
src_code = """
573-
/*gpufun*/
574-
void myfun(double x, double y,
575-
double* z){
576-
z[0] = x * y;
577-
}
570+
@for_all_test_contexts
571+
def test_gpu_api(test_context):
572+
src_code = """
573+
/*gpufun*/
574+
void myfun(double x, double y,
575+
double* z){
576+
z[0] = x * y;
577+
}
578578
579-
/*gpukern*/
580-
void my_mul(const int n,
581-
/*gpuglmem*/ const double* x1,
582-
/*gpuglmem*/ const double* x2,
583-
/*gpuglmem*/ double* y) {
584-
int tid = 0 //vectorize_over tid n
585-
double z;
586-
myfun(x1[tid], x2[tid], &z);
587-
y[tid] = z;
588-
//end_vectorize
589-
}
590-
"""
591-
592-
kernel_descriptions = {
593-
"my_mul": xo.Kernel(
594-
args=[
595-
xo.Arg(xo.Int32, name="n"),
596-
xo.Arg(xo.Float64, pointer=True, const=True, name="x1"),
597-
xo.Arg(xo.Float64, pointer=True, const=True, name="x2"),
598-
xo.Arg(xo.Float64, pointer=True, const=False, name="y"),
599-
],
600-
n_threads="n",
601-
),
579+
/*gpukern*/
580+
void my_mul(const int n,
581+
/*gpuglmem*/ const double* x1,
582+
/*gpuglmem*/ const double* x2,
583+
/*gpuglmem*/ double* y) {
584+
int tid = 0 //vectorize_over tid n
585+
double z;
586+
myfun(x1[tid], x2[tid], &z);
587+
y[tid] = z;
588+
//end_vectorize
602589
}
590+
"""
603591

604-
ctx.add_kernels(
605-
sources=[src_code],
606-
kernels=kernel_descriptions,
607-
save_source_as=None,
608-
compile=True,
609-
extra_classes=[xo.String[:]],
610-
)
592+
kernel_descriptions = {
593+
"my_mul": xo.Kernel(
594+
args=[
595+
xo.Arg(xo.Int32, name="n"),
596+
xo.Arg(xo.Float64, pointer=True, const=True, name="x1"),
597+
xo.Arg(xo.Float64, pointer=True, const=True, name="x2"),
598+
xo.Arg(xo.Float64, pointer=True, const=False, name="y"),
599+
],
600+
n_threads="n",
601+
),
602+
}
603+
604+
test_context.add_kernels(
605+
sources=[src_code],
606+
kernels=kernel_descriptions,
607+
save_source_as=None,
608+
compile=True,
609+
extra_classes=[xo.String[:]],
610+
)
611611

612612

613613
@for_all_test_contexts
@@ -662,7 +662,7 @@ class Cells(xo.Struct):
662662
Cells cells,
663663
GPUGLMEM uint64_t* out_counts,
664664
GPUGLMEM uint64_t* out_vals,
665-
GPUGLMEM uint8_t* success,
665+
GPUGLMEM uint8_t* success
666666
)
667667
{
668668
int64_t num_cells = Cells_len_ids(cells);
@@ -693,6 +693,15 @@ class Cells(xo.Struct):
693693
END_VECTORIZE;
694694
}
695695
}
696+
697+
GPUKERN void kernel_Cells_get_particles(
698+
Cells obj,
699+
int64_t i0,
700+
int64_t i1,
701+
GPUGLMEM int64_t* out
702+
) {
703+
*out = Cells_get_particles(obj, i0, i1);
704+
}
696705
"""
697706

698707
kernels = {
@@ -704,9 +713,17 @@ class Cells(xo.Struct):
704713
xo.Arg(xo.UInt8, pointer=True, name="success"),
705714
],
706715
n_threads=3,
707-
)
716+
),
717+
"kernel_Cells_get_particles": xo.Kernel(
718+
args=[
719+
xo.Arg(Cells, name="cells"),
720+
xo.Arg(xo.Int64, pointer=True, name="i0"),
721+
xo.Arg(xo.Int64, pointer=True, name="i1"),
722+
xo.Arg(xo.Int64, pointer=True, name="out"),
723+
],
724+
n_threads=3,
725+
),
708726
}
709-
kernels.update(Cells._gen_kernels())
710727

711728
test_context.add_kernels(
712729
sources=[src],
@@ -719,8 +736,9 @@ class Cells(xo.Struct):
719736

720737
for i, _ in enumerate(particle_per_cell):
721738
for j, expected in enumerate(particle_per_cell[i]):
722-
result = test_context.kernels.Cells_get_particles(
723-
obj=cells, i0=i, i1=j
739+
result = test_context.zeros(shape=(1,), dtype=np.int64)
740+
test_context.kernels.kernel_Cells_get_particles(
741+
obj=cells, i0=i, i1=j, out=result
724742
)
725743
assert result[0] == expected  # result is a shape-(1,) array: compare its single element, not the whole array
726744

0 commit comments

Comments
 (0)