@@ -168,22 +168,51 @@ def test_array_dynamic_type_init_get_set(array_cls, example_shape):
168168 ],
169169)
170170def test_array_get_shape (test_context , array_type ):
171- kernels = array_type ._gen_kernels ()
172- test_context .add_kernels (kernels = kernels )
171+ source = """
172+ GPUKERN void get_nd_and_shape(
173+ ARRAY_TYPE arr,
174+ GPUGLMEM int64_t* out_nd,
175+ GPUGLMEM int64_t* out_shape
176+ ) {
177+ *out_nd = ARRAY_TYPE_nd(arr);
178+ ARRAY_TYPE_shape(arr, out_shape);
179+ }
180+ """ .replace (
181+ "ARRAY_TYPE" , array_type .__name__
182+ )
183+
184+ kernels = {
185+ "get_nd_and_shape" : xo .Kernel (
186+ c_name = "get_nd_and_shape" ,
187+ args = [
188+ xo .Arg (array_type , name = "arr" ),
189+ xo .Arg (xo .Int64 , pointer = True , name = "out_nd" ),
190+ xo .Arg (xo .Int64 , pointer = True , name = "out_shape" ),
191+ ],
192+ ),
193+ }
194+
195+ test_context .add_kernels (
196+ sources = [src ],
197+ kernels = kernels ,
198+ )
173199
174200 instance = array_type (np .array (range (3 * 5 * 7 )).reshape ((3 , 5 , 7 )))
175201
176- nd_function = test_context .kernels [f"{ array_type .__name__ } _nd" ]
177- nd = nd_function (obj = instance )
178- assert nd == 3
202+ expected_nd = 3
203+ result_nd = test_context .zeros ((1 ,), dtype = np .int64 )
204+
205+ expected_shape = [3 , 5 , 7 ]
206+ result_shape = test_context .zeros ((expected_nd ,), dtype = np .int64 )
179207
180- shape = np .zeros (nd , dtype = np .int64 )
181- shape_function = test_context .kernels [f"{ array_type .__name__ } _shape" ]
182- shape_function (obj = instance , out_shape = shape )
208+ test_context .kernels .get_nd_and_shape (
209+ arr = instance ,
210+ out_nd = result_nd ,
211+ out_shape = result_shape ,
212+ )
183213
184- assert shape [0 ] == 3
185- assert shape [1 ] == 5
186- assert shape [2 ] == 7
214+ assert result_nd [0 ] == expected_nd
215+ assert np .all (result_shape == expected_shape )
187216
188217
189218def test_struct1 ():
0 commit comments