Skip to content

Commit 11d6bd4

Browse files
committed
Added a test in struct with kernel namings that would fail in the previous implementation
1 parent 24f0db1 commit 11d6bd4

File tree

1 file changed

+83
-0
lines changed

1 file changed

+83
-0
lines changed

tests/test_struct.py

Lines changed: 83 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
# Copyright (c) CERN, 2021. #
44
# ########################################### #
55
import cffi
6+
import numpy as np
67

78
import xobjects as xo
89
from xobjects.test_helpers import for_all_test_contexts, requires_context
@@ -213,6 +214,88 @@ class MyStruct(xo.Struct):
213214
assert s2.b[1] == 4
214215

215216

217+
def test_kernel_namings():
218+
class MyStruct(xo.Struct):
219+
n = xo.Int32
220+
var_mult_1 = xo.Float64[:]
221+
var_mult_2 = xo.Float64[:]
222+
var_mult_3 = xo.Float64[:]
223+
var_mult_4 = xo.Float64[:]
224+
225+
_extra_c_sources = [r"""
226+
double mul(MyStruct stru) {
227+
int32_t n = MyStruct_get_n(stru);
228+
double* var_mult_1 = MyStruct_getp1_var_mult_1(stru, 0);
229+
double* var_mult_2 = MyStruct_getp1_var_mult_2(stru, 0);
230+
double y = 0;
231+
for (int32_t tid=0; tid<n; tid++){
232+
y+= var_mult_1[tid] * var_mult_2[tid];
233+
}
234+
return y;
235+
}
236+
237+
double mult(MyStruct stru) {
238+
int32_t n = MyStruct_get_n(stru);
239+
double* var_mult_1 = MyStruct_getp1_var_mult_1(stru, 0);
240+
double* var_mult_2 = MyStruct_getp1_var_mult_2(stru, 0);
241+
double* var_mult_3 = MyStruct_getp1_var_mult_3(stru, 0);
242+
double y = 0;
243+
for (int32_t tid=0; tid<n; tid++){
244+
y+= var_mult_1[tid] * var_mult_2[tid] * var_mult_3[tid];
245+
}
246+
return y;
247+
}
248+
249+
double mult_four(MyStruct stru) {
250+
int32_t n = MyStruct_get_n(stru);
251+
double* var_mult_1 = MyStruct_getp1_var_mult_1(stru, 0);
252+
double* var_mult_2 = MyStruct_getp1_var_mult_2(stru, 0);
253+
double* var_mult_3 = MyStruct_getp1_var_mult_3(stru, 0);
254+
double* var_mult_4 = MyStruct_getp1_var_mult_4(stru, 0);
255+
double y = 0;
256+
for (int32_t tid=0; tid<n; tid++){
257+
y+= var_mult_1[tid] * var_mult_2[tid] * var_mult_3[tid] * var_mult_4[tid];
258+
}
259+
return y;
260+
}"""]
261+
262+
kernel_descriptions = {
263+
"mul": xo.Kernel(
264+
args=[xo.Arg(MyStruct, name="stru")],
265+
ret=xo.Arg(xo.Float64),
266+
),
267+
"mult": xo.Kernel(
268+
c_name="mult",
269+
args=[xo.Arg(MyStruct, name="stru")],
270+
ret=xo.Arg(xo.Float64),
271+
),
272+
"pyname_mul": xo.Kernel(
273+
c_name="mult_four",
274+
args=[xo.Arg(MyStruct, name="stru")],
275+
ret=xo.Arg(xo.Float64),
276+
)
277+
}
278+
279+
a1 = np.arange(10.0)
280+
a2 = np.arange(10.0)
281+
a3 = np.arange(10.0)
282+
a4 = np.arange(10.0)
283+
stru = MyStruct(n=10, var_mult_1=a1, var_mult_2=a2, var_mult_3=a3, var_mult_4=a4)
284+
285+
ctx = stru._context
286+
ctx.add_kernels(kernels=kernel_descriptions)
287+
assert "mul" in ctx.kernels
288+
# assert "mult" in ctx.kernels
289+
assert "pyname_mul" in ctx.kernels
290+
291+
y = ctx.kernels.mul(stru=stru)
292+
assert y == 285.0
293+
y = ctx.kernels.mult(stru=stru)
294+
assert y == 2025.0
295+
y = ctx.kernels.pyname_mul(stru=stru)
296+
assert y == 15333.0
297+
298+
216299
@requires_context("ContextCpu")
217300
def test_compile_kernels_only_if_needed(tmp_path, mocker):
218301
"""Test the use case of xtrack.

0 commit comments

Comments
 (0)