Skip to content

Commit e5bad86

Browse files
Bugfix: capi get/set methods broken for nested arrays
1 parent 1fb28f2 commit e5bad86

File tree

2 files changed

+141
-14
lines changed

2 files changed

+141
-14
lines changed

tests/test_capi.py

Lines changed: 127 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
import cffi
1212

1313
import xobjects as xo
14-
14+
from xobjects.test_helpers import for_all_test_contexts
1515

1616
ffi = cffi.FFI()
1717

@@ -576,8 +576,133 @@ def test_gpu_api():
576576
ctx.add_kernels(
577577
sources=[src_code],
578578
kernels=kernel_descriptions,
579-
# save_src_as=f'_test_{name}.c')
580579
save_source_as=None,
581580
compile=True,
582581
extra_classes=[xo.String[:]],
583582
)
583+
584+
585+
@for_all_test_contexts
586+
def test_array_of_arrays(test_context):
587+
cell_ids = [3, 5, 7]
588+
particle_per_cell = [
589+
[1, 8],
590+
[9, 3, 2],
591+
[4, 5, 6, 7],
592+
]
593+
594+
class Cells(xo.Struct):
595+
ids = xo.Int64[:]
596+
particles = xo.Int64[:][:]
597+
598+
cells = Cells(ids=cell_ids, particles=particle_per_cell)
599+
600+
# Data layout (displayed as uint64):
601+
#
602+
# [0] 216 (cells size)
603+
# [8] 56 (offset field 2 -- particles field)
604+
# [16] cell_ids data:
605+
# [0] 40 (cell_ids size)
606+
# [8] 3 (cell_ids length)
607+
# [16] {3, 5, 7} (cell_ids elements)
608+
# [56] particles data:
609+
# [0] 160 (particles size)
610+
# [8] 3 (particles length)
611+
# [16] 40 (offset particles[0])
612+
# [24] 72 (offset particles[1])
613+
# [32] 112 (offset particles[2])
614+
# [40] particles[0] data:
615+
# [0] 32 (particles[0] size)
616+
# [8] 2 (particles[0] length)
617+
# [16] {1, 8} (particles[0] elements)
618+
# [72] particles[1] data:
619+
# [0] 40 (particles[1] size)
620+
# [8] 3 (particles[1] length)
621+
# [16] {9, 3, 2} (particles[1
622+
# [112] particles[2] data:
623+
# [0] 48 (particles[2] size)
624+
# [8] 4 (particles[2] length)
625+
# [16] {4, 5, 6, 7} (particles[2] elements)
626+
627+
src = r"""
628+
#include "xobjects/headers/common.h"
629+
630+
int MAX_PARTICLES = 4;
631+
int MAX_CELLS = 3;
632+
633+
GPUKERN
634+
uint8_t loop_over(Cells cells, uint64_t* out_counts, uint64_t* out_vals)
635+
{
636+
uint8_t success = 1;
637+
int64_t num_cells = Cells_len_ids(cells);
638+
639+
for (int64_t i = 0; i < num_cells; i++) {
640+
int64_t id = Cells_get_ids(cells, i);
641+
int64_t count = Cells_len1_particles(cells, i);
642+
643+
printf("Cell ID: %lld\n Particles (count %lld): ", id, count);
644+
645+
if (i >= MAX_CELLS) {
646+
success = 0;
647+
continue;
648+
}
649+
650+
out_counts[i] = count;
651+
652+
ArrNInt64 particles = Cells_getp1_particles(cells, i);
653+
uint32_t num_particles = ArrNInt64_len(particles);
654+
655+
VECTORIZE_OVER(j, num_particles);
656+
int64_t val = ArrNInt64_get(particles, j);
657+
printf("%lld ", val);
658+
659+
if (j >= MAX_PARTICLES) {
660+
success = 0;
661+
continue;
662+
}
663+
664+
out_vals[i * MAX_PARTICLES + j] = val;
665+
END_VECTORIZE;
666+
printf("\n");
667+
}
668+
fflush(stdout);
669+
return success;
670+
}
671+
"""
672+
673+
kernels = {
674+
"loop_over": xo.Kernel(
675+
args=[
676+
xo.Arg(Cells, name="cells"),
677+
xo.Arg(xo.UInt64, pointer=True, name="out_counts"),
678+
xo.Arg(xo.UInt64, pointer=True, name="out_vals"),
679+
],
680+
n_threads="n",
681+
ret=xo.Arg(xo.UInt8),
682+
)
683+
}
684+
kernels.update(Cells._gen_kernels())
685+
686+
test_context.add_kernels(
687+
sources=[src],
688+
kernels=kernels,
689+
)
690+
691+
counts = np.zeros(len(cell_ids), dtype=np.uint64)
692+
vals = np.zeros(12, dtype=np.uint64)
693+
694+
for i, _ in enumerate(particle_per_cell):
695+
for j, expected in enumerate(particle_per_cell[i]):
696+
result = test_context.kernels.Cells_get_particles(
697+
obj=cells, i0=i, i1=j
698+
)
699+
assert result == expected
700+
701+
ret = test_context.kernels.loop_over(
702+
cells=cells,
703+
out_counts=counts,
704+
out_vals=vals,
705+
)
706+
assert ret == 1
707+
assert np.all(counts == [2, 3, 4])
708+
assert np.all(vals == [1, 8, 0, 0, 9, 3, 2, 0, 4, 5, 6, 7])

xobjects/capi.py

Lines changed: 14 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -103,9 +103,10 @@ def get_layers(parts):
103103

104104

105105
def int_from_obj(offset, conf):
106-
inttype = gen_pointer(conf.get("inttype", "int64_t") + "*", conf)
107-
chartype = gen_pointer(conf.get("chartype", "char") + "*", conf)
108-
return f"*({inttype})(({chartype}) obj+{offset})"
106+
"""Generate code to read the integer at location obj + offset (bytes)."""
107+
int_pointer_type = gen_pointer(conf.get("inttype", "int64_t") + "*", conf)
108+
char_pointer_type = gen_pointer(conf.get("chartype", "char") + "*", conf)
109+
return f"*({int_pointer_type})(({char_pointer_type}) obj+{offset})"
109110

110111

111112
def Field_get_c_offset(self, conf):
@@ -146,7 +147,7 @@ def Index_get_c_offset(part, conf, icount):
146147
out.append(f" offset+={soffset};")
147148
else:
148149
lookup_field_offset = f"offset+{soffset}"
149-
out.append(f" offset={int_from_obj(lookup_field_offset, conf)};")
150+
out.append(f" offset+={int_from_obj(lookup_field_offset, conf)};")
150151
return out
151152

152153

@@ -206,16 +207,17 @@ def gen_fun_kernel(cls, path, action, const, extra, ret, add_nindex=False):
206207
def gen_c_pointed(target: Arg, conf):
207208
size = gen_c_size_from_arg(target, conf)
208209
ret = gen_c_type_from_arg(target, conf)
210+
209211
if target.pointer or is_compound(target.atype) or is_string(target.atype):
210212
chartype = gen_pointer(conf.get("chartype", "char") + "*", conf)
211213
return f"({ret})(({chartype}) obj+offset)"
212-
else:
213-
rettype = gen_pointer(ret + "*", conf)
214-
if size == 1:
215-
return f"*(({rettype}) obj+offset)"
216-
else:
217-
chartype = gen_pointer(conf.get("chartype", "char") + "*", conf)
218-
return f"*({rettype})(({chartype}) obj+offset)"
214+
215+
rettype = gen_pointer(ret + "*", conf)
216+
if size == 1:
217+
return f"*(({rettype}) obj+offset)"
218+
219+
chartype = gen_pointer(conf.get("chartype", "char") + "*", conf)
220+
return f"*({rettype})(({chartype}) obj+offset)"
219221

220222

221223
def gen_method_get(cls, path, conf):
@@ -507,7 +509,7 @@ def methods_from_path(cls, path, conf):
507509

508510
if is_array(lasttype):
509511
out.append(gen_method_len(cls, path, conf))
510-
# out.append(gen_method_shape(cls, path, conf))
512+
out.append(gen_method_shape(cls, path, conf))
511513
# out.append(gen_method_nd(cls, path, conf))
512514
# out.append(gen_method_strides(cls, path, conf))
513515
# out.append(gen_method_getpos(cls, path, conf))

0 commit comments

Comments
 (0)