|
3 | 3 | # Copyright (c) CERN, 2021. # |
4 | 4 | # ########################################### # |
5 | 5 | import cffi |
| 6 | +import numpy as np |
6 | 7 |
|
7 | 8 | import xobjects as xo |
8 | 9 | from xobjects.test_helpers import for_all_test_contexts, requires_context |
@@ -213,6 +214,88 @@ class MyStruct(xo.Struct): |
213 | 214 | assert s2.b[1] == 4 |
214 | 215 |
|
215 | 216 |
|
def test_kernel_namings():
    """Check that kernels with closely related names stay distinct.

    Three C functions (``mul``, ``mult`` and ``mult_four``) are attached to
    a struct and registered under the Python names ``mul``, ``mult`` and
    ``pyname_mul``. Each registered kernel is then called on the same struct
    and must return the dot-like product of two, three and four arrays
    respectively.
    """
    # C sources shipped with the struct; the function names deliberately
    # share the ``mul`` prefix.
    kernel_sources = r"""
double mul(MyStruct stru) {
    int32_t n = MyStruct_get_n(stru);
    double* var_mult_1 = MyStruct_getp1_var_mult_1(stru, 0);
    double* var_mult_2 = MyStruct_getp1_var_mult_2(stru, 0);
    double y = 0;
    for (int32_t tid=0; tid<n; tid++){
        y+= var_mult_1[tid] * var_mult_2[tid];
        }
    return y;
    }

double mult(MyStruct stru) {
    int32_t n = MyStruct_get_n(stru);
    double* var_mult_1 = MyStruct_getp1_var_mult_1(stru, 0);
    double* var_mult_2 = MyStruct_getp1_var_mult_2(stru, 0);
    double* var_mult_3 = MyStruct_getp1_var_mult_3(stru, 0);
    double y = 0;
    for (int32_t tid=0; tid<n; tid++){
        y+= var_mult_1[tid] * var_mult_2[tid] * var_mult_3[tid];
        }
    return y;
    }

double mult_four(MyStruct stru) {
    int32_t n = MyStruct_get_n(stru);
    double* var_mult_1 = MyStruct_getp1_var_mult_1(stru, 0);
    double* var_mult_2 = MyStruct_getp1_var_mult_2(stru, 0);
    double* var_mult_3 = MyStruct_getp1_var_mult_3(stru, 0);
    double* var_mult_4 = MyStruct_getp1_var_mult_4(stru, 0);
    double y = 0;
    for (int32_t tid=0; tid<n; tid++){
        y+= var_mult_1[tid] * var_mult_2[tid] * var_mult_3[tid] * var_mult_4[tid];
        }
    return y;
    }"""

    class MyStruct(xo.Struct):
        n = xo.Int32
        var_mult_1 = xo.Float64[:]
        var_mult_2 = xo.Float64[:]
        var_mult_3 = xo.Float64[:]
        var_mult_4 = xo.Float64[:]

        _extra_c_sources = [kernel_sources]

    def struct_args():
        # Every kernel takes the struct as its single argument; a fresh list
        # is built per kernel, as in a hand-written description.
        return [xo.Arg(MyStruct, name="stru")]

    kernel_descriptions = {
        # No explicit c_name: only the dictionary key names this kernel.
        "mul": xo.Kernel(
            args=struct_args(),
            ret=xo.Arg(xo.Float64),
        ),
        # Python name and C name coincide.
        "mult": xo.Kernel(
            c_name="mult",
            args=struct_args(),
            ret=xo.Arg(xo.Float64),
        ),
        # Python name differs from the C function name.
        "pyname_mul": xo.Kernel(
            c_name="mult_four",
            args=struct_args(),
            ret=xo.Arg(xo.Float64),
        ),
    }

    # Four independent copies of 0.0 .. 9.0, one per array field.
    field_values = {
        f"var_mult_{idx}": np.arange(10.0) for idx in range(1, 5)
    }
    stru = MyStruct(n=10, **field_values)

    context = stru._context
    context.add_kernels(kernels=kernel_descriptions)

    # NOTE: the membership check for "mult" was disabled in the original
    # test and is intentionally left out here; the reason is not visible
    # from this file.
    for registered_name in ("mul", "pyname_mul"):
        assert registered_name in context.kernels

    # Sums of i**2, i**3 and i**4 for i in 0..9.
    expected_results = {
        "mul": 285.0,
        "mult": 2025.0,
        "pyname_mul": 15333.0,
    }
    for py_name, expected in expected_results.items():
        y = getattr(context.kernels, py_name)(stru=stru)
        assert y == expected
| 297 | + |
| 298 | + |
216 | 299 | @requires_context("ContextCpu") |
217 | 300 | def test_compile_kernels_only_if_needed(tmp_path, mocker): |
218 | 301 | """Test the use case of xtrack. |
|
0 commit comments