@@ -158,6 +158,70 @@ def test_array_dynamic_type_init_get_set(array_cls, example_shape):
158158 assert arr [ii ].field1 [idx_in_field ] == 13 * vv
159159
160160
@for_all_test_contexts
@pytest.mark.parametrize(
    "array_type",
    [
        xo.UInt64[3, 5, 7],
        xo.UInt64[:, :, :],
        xo.UInt64[:, 5, :],
    ],
)
def test_array_get_shape(test_context, array_type):
    """Check the C-side ``<ArrayType>_nd`` and ``<ArrayType>_shape`` accessors.

    The same 3 x 5 x 7 data set is stored in an array type whose shape is
    fully static, fully dynamic, or mixed (see the parametrization). A small
    kernel then writes the number of dimensions and the per-axis extents into
    two output buffers, which are compared against the known shape.
    """
    expected_shape = (3, 5, 7)
    expected_nd = len(expected_shape)

    # The kernel is written against the placeholder ARRAY_TYPE, which is
    # substituted with the C name of the array type under test.
    kernel_template = """
    #include "xobjects/headers/common.h"

    GPUKERN void get_nd_and_shape(
        ARRAY_TYPE arr,
        GPUGLMEM int64_t* out_nd,
        GPUGLMEM int64_t* out_shape
    ) {
        *out_nd = ARRAY_TYPE_nd(arr);
        ARRAY_TYPE_shape(arr, out_shape);
    }
    """
    source = kernel_template.replace("ARRAY_TYPE", array_type.__name__)

    test_context.add_kernels(
        sources=[source],
        kernels={
            "get_nd_and_shape": xo.Kernel(
                c_name="get_nd_and_shape",
                args=[
                    xo.Arg(array_type, name="arr"),
                    xo.Arg(xo.Int64, pointer=True, name="out_nd"),
                    xo.Arg(xo.Int64, pointer=True, name="out_shape"),
                ],
            ),
        },
    )

    # Fill the array with 0..104 laid out as 3 x 5 x 7.
    n_items = 3 * 5 * 7
    instance = array_type(
        np.arange(n_items).reshape(expected_shape),
        _context=test_context,
    )

    # Output buffers are allocated through the context under test.
    result_nd = test_context.zeros((1,), dtype=np.int64)
    result_shape = test_context.zeros((expected_nd,), dtype=np.int64)

    test_context.kernels.get_nd_and_shape(
        arr=instance,
        out_nd=result_nd,
        out_shape=result_shape,
    )

    assert result_nd[0] == expected_nd
    for axis, extent in enumerate(expected_shape):
        assert result_shape[axis] == extent
223+
224+
161225def test_struct1 ():
162226 kernels = Struct1 ._gen_kernels ()
163227 ctx = xo .ContextCpu ()
@@ -539,47 +603,47 @@ def test_getp1_dyn_length_dyn_type_string_array():
539603 assert ord (ffi .cast ("char *" , s2 )[8 + ii ]) == ch
540604
541605
@for_all_test_contexts
def test_gpu_api(test_context):
    """Compile a kernel written with comment-style annotations.

    The C source marks its functions and arguments with ``/*gpufun*/``,
    ``/*gpukern*/`` and ``/*gpuglmem*/`` and wraps the loop body between
    ``//vectorize_over`` and ``//end_vectorize``. The test makes no
    assertions: it passes when ``add_kernels`` completes without raising on
    the given context.
    """
    kernel_source = """
    /*gpufun*/
    void myfun(double x, double y,
        double* z){
        z[0] = x * y;
        }

    /*gpukern*/
    void my_mul(const int n,
        /*gpuglmem*/ const double* x1,
        /*gpuglmem*/ const double* x2,
        /*gpuglmem*/ double* y) {
        int tid = 0 //vectorize_over tid n
        double z;
        myfun(x1[tid], x2[tid], &z);
        y[tid] = z;
        //end_vectorize
    }
    """

    # Argument order mirrors the C signature of my_mul.
    my_mul_args = [
        xo.Arg(xo.Int32, name="n"),
        xo.Arg(xo.Float64, pointer=True, const=True, name="x1"),
        xo.Arg(xo.Float64, pointer=True, const=True, name="x2"),
        xo.Arg(xo.Float64, pointer=True, const=False, name="y"),
    ]
    kernel_specs = {
        "my_mul": xo.Kernel(args=my_mul_args, n_threads="n"),
    }

    test_context.add_kernels(
        sources=[kernel_source],
        kernels=kernel_specs,
        save_source_as=None,
        compile=True,
        extra_classes=[xo.String[:]],
    )
583647
584648
585649@for_all_test_contexts
@@ -595,7 +659,9 @@ class Cells(xo.Struct):
595659 ids = xo .Int64 [:]
596660 particles = xo .Int64 [:][:]
597661
598- cells = Cells (ids = cell_ids , particles = particle_per_cell )
662+ cells = Cells (
663+ ids = cell_ids , particles = particle_per_cell , _context = test_context
664+ )
599665
600666 # Data layout (displayed as uint64):
601667 #
@@ -627,23 +693,24 @@ class Cells(xo.Struct):
627693 src = r"""
628694 #include "xobjects/headers/common.h"
629695
630- int MAX_PARTICLES = 4;
631- int MAX_CELLS = 3;
696+ static const int MAX_PARTICLES = 4;
697+ static const int MAX_CELLS = 3;
632698
633- GPUKERN
634- uint8_t loop_over(Cells cells, uint64_t* out_counts, uint64_t* out_vals)
699+ GPUKERN void loop_over(
700+ Cells cells,
701+ GPUGLMEM uint64_t* out_counts,
702+ GPUGLMEM uint64_t* out_vals,
703+ GPUGLMEM uint8_t* success
704+ )
635705 {
636- uint8_t success = 1;
637706 int64_t num_cells = Cells_len_ids(cells);
638707
639708 for (int64_t i = 0; i < num_cells; i++) {
640709 int64_t id = Cells_get_ids(cells, i);
641710 int64_t count = Cells_len1_particles(cells, i);
642711
643- printf("Cell ID: %lld\n Particles (count %lld): ", id, count);
644-
645712 if (i >= MAX_CELLS) {
646- success = 0;
713+ * success = 0;
647714 continue;
648715 }
649716
@@ -654,19 +721,23 @@ class Cells(xo.Struct):
654721
655722 VECTORIZE_OVER(j, num_particles);
656723 int64_t val = ArrNInt64_get(particles, j);
657- printf("%lld ", val);
658724
659725 if (j >= MAX_PARTICLES) {
660- success = 0;
661- continue;
726+ *success = 0;
727+ } else {
728+ out_vals[i * MAX_PARTICLES + j] = val;
662729 }
663-
664- out_vals[i * MAX_PARTICLES + j] = val;
665730 END_VECTORIZE;
666- printf("\n");
667731 }
668- fflush(stdout);
669- return success;
732+ }
733+
734+ GPUKERN void kernel_Cells_get_particles(
735+ Cells obj,
736+ int64_t i0,
737+ int64_t i1,
738+ GPUGLMEM int64_t* out
739+ ) {
740+ *out = Cells_get_particles(obj, i0, i1);
670741 }
671742 """
672743
@@ -676,33 +747,46 @@ class Cells(xo.Struct):
676747 xo .Arg (Cells , name = "cells" ),
677748 xo .Arg (xo .UInt64 , pointer = True , name = "out_counts" ),
678749 xo .Arg (xo .UInt64 , pointer = True , name = "out_vals" ),
750+ xo .Arg (xo .UInt8 , pointer = True , name = "success" ),
679751 ],
680- n_threads = "n" ,
681- ret = xo .Arg (xo .UInt8 ),
682- )
752+ n_threads = 4 ,
753+ ),
754+ "kernel_Cells_get_particles" : xo .Kernel (
755+ args = [
756+ xo .Arg (Cells , name = "obj" ),
757+ xo .Arg (xo .Int64 , name = "i0" ),
758+ xo .Arg (xo .Int64 , name = "i1" ),
759+ xo .Arg (xo .Int64 , pointer = True , name = "out" ),
760+ ],
761+ ),
683762 }
684- kernels .update (Cells ._gen_kernels ())
685763
686764 test_context .add_kernels (
687765 sources = [src ],
688766 kernels = kernels ,
689767 )
690768
691- counts = np .zeros (len (cell_ids ), dtype = np .uint64 )
692- vals = np .zeros (12 , dtype = np .uint64 )
769+ counts = test_context .zeros (len (cell_ids ), dtype = np .uint64 )
770+ vals = test_context .zeros (12 , dtype = np .uint64 )
771+ success = test_context .zeros ((1 ,), dtype = np .uint8 ) + 1
693772
694773 for i , _ in enumerate (particle_per_cell ):
695774 for j , expected in enumerate (particle_per_cell [i ]):
696- result = test_context .kernels .Cells_get_particles (
697- obj = cells , i0 = i , i1 = j
775+ result = test_context .zeros (shape = (1 ,), dtype = np .int64 )
776+ test_context .kernels .kernel_Cells_get_particles (
777+ obj = cells , i0 = i , i1 = j , out = result
698778 )
699- assert result == expected
779+ assert result [ 0 ] == expected
700780
701- ret = test_context .kernels .loop_over (
781+ test_context .kernels .loop_over (
702782 cells = cells ,
703783 out_counts = counts ,
704784 out_vals = vals ,
785+ success = success ,
705786 )
706- assert ret == 1
787+ counts = test_context .nparray_from_context_array (counts )
788+ vals = test_context .nparray_from_context_array (vals )
789+
790+ assert success [0 ] == 1
707791 assert np .all (counts == [2 , 3 , 4 ])
708792 assert np .all (vals == [1 , 8 , 0 , 0 , 9 , 3 , 2 , 0 , 4 , 5 , 6 , 7 ])
0 commit comments