Skip to content

Commit 414e8a1

Browse files
fix test_array_get_shape
1 parent 7944487 commit 414e8a1

File tree

1 file changed

+40
-11
lines changed

1 file changed

+40
-11
lines changed

tests/test_capi.py

Lines changed: 40 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -168,22 +168,51 @@ def test_array_dynamic_type_init_get_set(array_cls, example_shape):
168168
],
169169
)
170170
def test_array_get_shape(test_context, array_type):
171-
kernels = array_type._gen_kernels()
172-
test_context.add_kernels(kernels=kernels)
171+
source = """
172+
GPUKERN void get_nd_and_shape(
173+
ARRAY_TYPE arr,
174+
GPUGLMEM int64_t* out_nd,
175+
GPUGLMEM int64_t* out_shape
176+
) {
177+
*out_nd = ARRAY_TYPE_nd(arr);
178+
ARRAY_TYPE_shape(arr, out_shape);
179+
}
180+
""".replace(
181+
"ARRAY_TYPE", array_type.__name__
182+
)
183+
184+
kernels = {
185+
"get_nd_and_shape": xo.Kernel(
186+
c_name="get_nd_and_shape",
187+
args=[
188+
xo.Arg(array_type, name="arr"),
189+
xo.Arg(xo.Int64, pointer=True, name="out_nd"),
190+
xo.Arg(xo.Int64, pointer=True, name="out_shape"),
191+
],
192+
),
193+
}
194+
195+
test_context.add_kernels(
196+
sources=[src],
197+
kernels=kernels,
198+
)
173199

174200
instance = array_type(np.array(range(3 * 5 * 7)).reshape((3, 5, 7)))
175201

176-
nd_function = test_context.kernels[f"{array_type.__name__}_nd"]
177-
nd = nd_function(obj=instance)
178-
assert nd == 3
202+
expected_nd = 3
203+
result_nd = test_context.zeros((1,), dtype=np.int64)
204+
205+
expected_shape = [3, 5, 7]
206+
result_shape = test_context.zeros((expected_nd,), dtype=np.int64)
179207

180-
shape = np.zeros(nd, dtype=np.int64)
181-
shape_function = test_context.kernels[f"{array_type.__name__}_shape"]
182-
shape_function(obj=instance, out_shape=shape)
208+
test_context.kernels.get_nd_and_shape(
209+
arr=instance,
210+
out_nd=result_nd,
211+
out_shape=result_shape,
212+
)
183213

184-
assert shape[0] == 3
185-
assert shape[1] == 5
186-
assert shape[2] == 7
214+
assert result_nd[0] == expected_nd
215+
assert np.all(result_shape == expected_shape)
187216

188217

189218
def test_struct1():

0 commit comments

Comments
 (0)