|
11 | 11 | import cffi |
12 | 12 |
|
13 | 13 | import xobjects as xo |
14 | | - |
| 14 | +from xobjects.test_helpers import for_all_test_contexts |
15 | 15 |
|
16 | 16 | ffi = cffi.FFI() |
17 | 17 |
|
@@ -576,8 +576,133 @@ def test_gpu_api(): |
576 | 576 | ctx.add_kernels( |
577 | 577 | sources=[src_code], |
578 | 578 | kernels=kernel_descriptions, |
579 | | - # save_src_as=f'_test_{name}.c') |
580 | 579 | save_source_as=None, |
581 | 580 | compile=True, |
582 | 581 | extra_classes=[xo.String[:]], |
583 | 582 | ) |
| 583 | + |
| 584 | + |
| 585 | +@for_all_test_contexts |
| 586 | +def test_array_of_arrays(test_context): |
| 587 | + cell_ids = [3, 5, 7] |
| 588 | + particle_per_cell = [ |
| 589 | + [1, 8], |
| 590 | + [9, 3, 2], |
| 591 | + [4, 5, 6, 7], |
| 592 | + ] |
| 593 | + |
| 594 | + class Cells(xo.Struct): |
| 595 | + ids = xo.Int64[:] |
| 596 | + particles = xo.Int64[:][:] |
| 597 | + |
| 598 | + cells = Cells(ids=cell_ids, particles=particle_per_cell) |
| 599 | + |
| 600 | + # Data layout (displayed as uint64): |
| 601 | + # |
| 602 | + # [0] 216 (cells size) |
| 603 | + # [8] 56 (offset field 2 -- particles field) |
| 604 | + # [16] cell_ids data: |
| 605 | + # [0] 40 (cell_ids size) |
| 606 | + # [8] 3 (cell_ids length) |
| 607 | + # [16] {3, 5, 7} (cell_ids elements) |
| 608 | + # [56] particles data: |
| 609 | + # [0] 160 (particles size) |
| 610 | + # [8] 3 (particles length) |
| 611 | + # [16] 40 (offset particles[0]) |
| 612 | + # [24] 72 (offset particles[1]) |
| 613 | + # [32] 112 (offset particles[2]) |
| 614 | + # [40] particles[0] data: |
| 615 | + # [0] 32 (particles[0] size) |
| 616 | + # [8] 2 (particles[0] length) |
| 617 | + # [16] {1, 8} (particles[0] elements) |
| 618 | + # [72] particles[1] data: |
| 619 | + # [0] 40 (particles[1] size) |
| 620 | + # [8] 3 (particles[1] length) |
| 621 | + # [16] {9, 3, 2} (particles[1] elements) |
| 622 | + # [112] particles[2] data: |
| 623 | + # [0] 48 (particles[2] size) |
| 624 | + # [8] 4 (particles[2] length) |
| 625 | + # [16] {4, 5, 6, 7} (particles[2] elements) |
| 626 | + |
| 627 | + src = r""" |
| 628 | + #include "xobjects/headers/common.h" |
| 629 | +
|
| 630 | + int MAX_PARTICLES = 4; |
| 631 | + int MAX_CELLS = 3; |
| 632 | +
|
| 633 | + GPUKERN |
| 634 | + uint8_t loop_over(Cells cells, uint64_t* out_counts, uint64_t* out_vals) |
| 635 | + { |
| 636 | + uint8_t success = 1; |
| 637 | + int64_t num_cells = Cells_len_ids(cells); |
| 638 | +
|
| 639 | + for (int64_t i = 0; i < num_cells; i++) { |
| 640 | + int64_t id = Cells_get_ids(cells, i); |
| 641 | + int64_t count = Cells_len1_particles(cells, i); |
| 642 | +
|
| 643 | + printf("Cell ID: %lld\n Particles (count %lld): ", id, count); |
| 644 | +
|
| 645 | + if (i >= MAX_CELLS) { |
| 646 | + success = 0; |
| 647 | + continue; |
| 648 | + } |
| 649 | +
|
| 650 | + out_counts[i] = count; |
| 651 | +
|
| 652 | + ArrNInt64 particles = Cells_getp1_particles(cells, i); |
| 653 | + uint32_t num_particles = ArrNInt64_len(particles); |
| 654 | +
|
| 655 | + VECTORIZE_OVER(j, num_particles); |
| 656 | + int64_t val = ArrNInt64_get(particles, j); |
| 657 | + printf("%lld ", val); |
| 658 | +
|
| 659 | + if (j >= MAX_PARTICLES) { |
| 660 | + success = 0; |
| 661 | + continue; |
| 662 | + } |
| 663 | +
|
| 664 | + out_vals[i * MAX_PARTICLES + j] = val; |
| 665 | + END_VECTORIZE; |
| 666 | + printf("\n"); |
| 667 | + } |
| 668 | + fflush(stdout); |
| 669 | + return success; |
| 670 | + } |
| 671 | + """ |
| 672 | + |
| 673 | + kernels = { |
| 674 | + "loop_over": xo.Kernel( |
| 675 | + args=[ |
| 676 | + xo.Arg(Cells, name="cells"), |
| 677 | + xo.Arg(xo.UInt64, pointer=True, name="out_counts"), |
| 678 | + xo.Arg(xo.UInt64, pointer=True, name="out_vals"), |
| 679 | + ], |
| 680 | + n_threads="n", |
| 681 | + ret=xo.Arg(xo.UInt8), |
| 682 | + ) |
| 683 | + } |
| 684 | + kernels.update(Cells._gen_kernels()) |
| 685 | + |
| 686 | + test_context.add_kernels( |
| 687 | + sources=[src], |
| 688 | + kernels=kernels, |
| 689 | + ) |
| 690 | + |
| 691 | + counts = np.zeros(len(cell_ids), dtype=np.uint64) |
| 692 | + vals = np.zeros(12, dtype=np.uint64) |
| 693 | + |
| 694 | + for i, _ in enumerate(particle_per_cell): |
| 695 | + for j, expected in enumerate(particle_per_cell[i]): |
| 696 | + result = test_context.kernels.Cells_get_particles( |
| 697 | + obj=cells, i0=i, i1=j |
| 698 | + ) |
| 699 | + assert result == expected |
| 700 | + |
| 701 | + ret = test_context.kernels.loop_over( |
| 702 | + cells=cells, |
| 703 | + out_counts=counts, |
| 704 | + out_vals=vals, |
| 705 | + ) |
| 706 | + assert ret == 1 |
| 707 | + assert np.all(counts == [2, 3, 4]) |
| 708 | + assert np.all(vals == [1, 8, 0, 0, 9, 3, 2, 0, 4, 5, 6, 7]) |
0 commit comments