Skip to content

Commit 1cbaaba

Browse files
committed
test updates, fix init bug
1 parent f274155 commit 1cbaaba

2 files changed

Lines changed: 79 additions & 2 deletions

File tree

Plugins/Arrow/src/ArrowUtil.cpp

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
#include "ActsPlugins/Arrow/ArrowUtil.hpp"
1010

1111
#include <algorithm>
12+
#include <mutex>
1213
#include <stdexcept>
1314
#include <vector>
1415

@@ -39,6 +40,21 @@ T unwrap(arrow::Result<T> result, const std::string& what) {
3940
return std::move(result).ValueOrDie();
4041
}
4142

43+
// Arrow's compute kernels (e.g. `equal` used for the event_id filter
44+
// pushdown below) are registered lazily. With the linker-isolated arrow
45+
// island the registry is empty until we ask for it explicitly — without
46+
// this call, scan-time filtering fails with "No function registered with
47+
// name: equal".
48+
void ensureComputeInitialized() {
49+
static std::once_flag flag;
50+
std::call_once(flag, [] {
51+
auto status = arrow::compute::Initialize();
52+
if (!status.ok()) {
53+
throwArrow("arrow compute init", status);
54+
}
55+
});
56+
}
57+
4258
} // namespace
4359

4460
std::shared_ptr<arrow::Field> eventIdField() {
@@ -133,6 +149,8 @@ class ParquetDatasetReader::Impl {
133149
public:
134150
explicit Impl(std::filesystem::path directory)
135151
: m_directory(std::move(directory)) {
152+
ensureComputeInitialized();
153+
136154
if (!std::filesystem::exists(m_directory) ||
137155
!std::filesystem::is_directory(m_directory)) {
138156
throw std::invalid_argument("ParquetDatasetReader: not a directory: " +

Python/Examples/tests/test_arrow.py

Lines changed: 61 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
from acts import UnitConstants as u
99
from acts.examples import Sequencer
1010

11-
from helpers import arrowEnabled, isCI, pythia8Enabled
11+
from helpers import AssertCollectionExistsAlg, arrowEnabled, isCI, pythia8Enabled
1212

1313

1414
pytestmark = pytest.mark.skipif(
@@ -152,6 +152,7 @@ def _add_arrow_writer(
152152
s: Sequencer,
153153
outputDir: Path,
154154
inputs_to_tables: dict[str, str],
155+
eventsPerShard: int = 2,
155156
) -> None:
156157
"""Wire one ArrowParticleOutputConverter per (input, table) pair, and one
157158
ParquetWriter picking up all the resulting tables.
@@ -185,7 +186,7 @@ def _add_arrow_writer(
185186
table_key: table_key
186187
for table_key in inputs_to_tables.values()
187188
},
188-
eventsPerShard=2,
189+
eventsPerShard=eventsPerShard,
189190
)
190191
)
191192

@@ -206,6 +207,64 @@ def test_particle_gun_generated(tmp_path, ptcl_gun):
206207
)
207208

208209

210+
def test_particle_gun_roundtrip(tmp_path, ptcl_gun):
211+
"""Write sharded Parquet, then drive a second Sequencer off ParquetReader
212+
and check the reader exposes — and processes — the same number of events
213+
that were written."""
214+
from acts.examples.arrow import ParquetReader
215+
216+
# nevents/eventsPerShard chosen so the write phase produces multiple
217+
# shards with a non-full final shard — exercises shard discovery, the
218+
# multi-fragment scan, and the partial-shard edge case in one test.
219+
nevents = 5
220+
events_per_shard = 2
221+
expected_shards = (nevents + events_per_shard - 1) // events_per_shard
222+
223+
s_write = Sequencer(numThreads=1, events=nevents)
224+
ptcl_gun(s_write)
225+
_add_arrow_writer(
226+
s_write,
227+
tmp_path,
228+
{"particles_generated": "particles_generated_arrow"},
229+
eventsPerShard=events_per_shard,
230+
)
231+
s_write.run()
232+
233+
out_dir = tmp_path / "particles_generated_arrow"
234+
_assert_particles_parquet(out_dir, nevents)
235+
236+
shards = sorted(out_dir.glob("*.parquet"))
237+
assert len(shards) == expected_shards, (
238+
f"expected {expected_shards} shards for {nevents} events at "
239+
f"{events_per_shard} events/shard, got {len(shards)}: "
240+
f"{[s.name for s in shards]}"
241+
)
242+
243+
reader = ParquetReader(
244+
level=acts.logging.INFO,
245+
inputDir=str(tmp_path),
246+
collections={"particles_generated_arrow": "particles_generated_arrow"},
247+
)
248+
assert reader.availableEvents() == (0, nevents)
249+
250+
# No `events=` — the sequencer derives the event range from the reader,
251+
# so a wrong count here would surface as a mismatch with `nevents`.
252+
s_read = Sequencer(numThreads=1)
253+
s_read.addReader(reader)
254+
counter = AssertCollectionExistsAlg(
255+
collections="particles_generated_arrow",
256+
name="roundtrip_check",
257+
level=acts.logging.INFO,
258+
)
259+
s_read.addAlgorithm(counter)
260+
s_read.run()
261+
262+
assert counter.events_seen == nevents, (
263+
f"reader-driven sequencer processed {counter.events_seen} events, "
264+
f"expected {nevents}"
265+
)
266+
267+
209268
def test_particle_gun_fatras(tmp_path, fatras):
210269
"""Particle gun + Fatras → both generated and simulated particles → Parquet."""
211270
nevents = 5

0 commit comments

Comments
 (0)