Skip to content

Use pvalue for resolver #941

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
May 13, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -73,8 +73,8 @@ class greedy_ambiguity_resolution_algorithm
///
/// There is no (track_id) in this algorithm, only (track_index).

/// Associates each track_index with the track's chi2 value
std::vector<traccc::scalar> track_chi2;
/// Associates each track_index with the track's p-value
std::vector<traccc::scalar> track_pval;

/// Associates each track_index to the track's (measurement_id)s list
std::vector<std::vector<std::size_t>> measurements_per_track;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,14 +43,14 @@ greedy_ambiguity_resolution_algorithm::operator()(
std::vector<unsigned int> accepted_ids(n_tracks);
std::iota(accepted_ids.begin(), accepted_ids.end(), 0);

// Make measurement ID, chi2 and n_measurement vector
// Make measurement ID, pval and n_measurement vector
std::vector<std::vector<std::size_t>> meas_ids(n_tracks);
std::vector<traccc::scalar> chi_squares(n_tracks);
std::vector<traccc::scalar> pvals(n_tracks);
std::vector<std::size_t> n_meas(n_tracks);

for (unsigned int i = 0; i < n_tracks; i++) {
// Fill the chi-squares vectors
chi_squares[i] = track_candidates.at(i).header.trk_quality.chi2;
// Fill the pval vectors
pvals[i] = track_candidates.at(i).header.trk_quality.pval;

const auto& candidates = track_candidates.at(i).items;
const unsigned int n_cands = candidates.size();
Expand Down Expand Up @@ -122,15 +122,15 @@ greedy_ambiguity_resolution_algorithm::operator()(
static_cast<traccc::scalar>(n_meas[i]);
}

// Sort the track id with rel_shared and chi2 to find the worst track fast
// Sort the track id with rel_shared and pval to find the worst track fast
std::vector<unsigned int> sorted_ids = accepted_ids;

auto track_comparator = [&rel_shared, &chi_squares](unsigned int a,
unsigned int b) {
auto track_comparator = [&rel_shared, &pvals](unsigned int a,
unsigned int b) {
if (rel_shared[a] != rel_shared[b]) {
return rel_shared[a] < rel_shared[b];
}
return chi_squares[a] < chi_squares[b];
return pvals[a] > pvals[b];
};
std::sort(sorted_ids.begin(), sorted_ids.end(), track_comparator);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -115,8 +115,8 @@ void greedy_ambiguity_resolution_algorithm::compute_initial_state(
}
}

// Add this track chi2 value
state.track_chi2.push_back(fit_res.trk_quality.chi2);
// Add this track pval value
state.track_pval.push_back(fit_res.trk_quality.pval);
// Add all the (measurement_id)s of this track
state.measurements_per_track.push_back(std::move(measurements));
// Initially, every track is in the selected_track list. They will later
Expand Down Expand Up @@ -177,7 +177,7 @@ void greedy_ambiguity_resolution_algorithm::resolve(state_t& state) const {

/// Compares two tracks in order to find the one which should be evicted.
/// First we compare the relative amount of shared measurements. If that is
/// indecisive we use the chi2.
/// indecisive we use the pval.
auto track_comperator = [&state](std::pair<std::size_t, std::size_t> a,
std::pair<std::size_t, std::size_t> b) {
/// Helper to calculate the relative amount of shared measurements.
Expand All @@ -191,7 +191,7 @@ void greedy_ambiguity_resolution_algorithm::resolve(state_t& state) const {
return relative_shared_measurements(a.second) <
relative_shared_measurements(b.second);
}
return state.track_chi2[a.second] < state.track_chi2[b.second];
return state.track_pval[a.second] > state.track_pval[b.second];
};

std::size_t iteration_count = 0;
Expand Down
38 changes: 19 additions & 19 deletions tests/cpu/test_ambiguity_resolution.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -28,10 +28,10 @@ traccc::memory_resource mr{host_mr, &host_mr};
} // namespace

void fill_pattern(track_candidate_container_types::host& track_candidates,
const std::size_t idx, const traccc::scalar chi2,
const std::size_t idx, const traccc::scalar pval,
const std::vector<std::size_t>& pattern) {

track_candidates.at(idx).header.trk_quality.chi2 = chi2;
track_candidates.at(idx).header.trk_quality.pval = pval;

auto& cands = track_candidates.at(idx).items;
for (const auto& meas_id : pattern) {
Expand All @@ -57,9 +57,9 @@ TEST(AmbiguitySolverTests, GreedyResolverTest0) {
track_candidate_container_types::host trk_cands;

trk_cands.resize(3u);
fill_pattern(trk_cands, 0, 8.3f, {1, 3, 5, 11});
fill_pattern(trk_cands, 1, 2.2f, {6, 7, 8, 9, 10, 12});
fill_pattern(trk_cands, 2, 3.7f, {2, 4, 13});
fill_pattern(trk_cands, 0, 0.23f, {1, 3, 5, 11});
fill_pattern(trk_cands, 1, 0.85f, {6, 7, 8, 9, 10, 12});
fill_pattern(trk_cands, 2, 0.42f, {2, 4, 13});

traccc::host::greedy_ambiguity_resolution_algorithm::config_type
resolution_config;
Expand Down Expand Up @@ -122,8 +122,8 @@ TEST(AmbiguitySolverTests, GreedyResolverTest1) {
track_candidate_container_types::host trk_cands;

trk_cands.resize(2u);
fill_pattern(trk_cands, 0, 4.3f, {1, 3, 5, 11, 14, 16, 18});
fill_pattern(trk_cands, 1, 1.2f, {3, 5, 6, 13});
fill_pattern(trk_cands, 0, 0.12f, {1, 3, 5, 11, 14, 16, 18});
fill_pattern(trk_cands, 1, 0.53f, {3, 5, 6, 13});

traccc::host::greedy_ambiguity_resolution_algorithm::config_type
resolution_config;
Expand Down Expand Up @@ -166,8 +166,8 @@ TEST(AmbiguitySolverTests, GreedyResolverTest2) {
track_candidate_container_types::host trk_cands;

trk_cands.resize(2u);
fill_pattern(trk_cands, 0, 4.3f, {1, 3, 5, 11});
fill_pattern(trk_cands, 1, 1.2f, {3, 5, 6, 13});
fill_pattern(trk_cands, 0, 0.8f, {1, 3, 5, 11});
fill_pattern(trk_cands, 1, 0.9f, {3, 5, 6, 13});

traccc::host::greedy_ambiguity_resolution_algorithm::config_type
resolution_config;
Expand All @@ -178,7 +178,7 @@ TEST(AmbiguitySolverTests, GreedyResolverTest2) {
ASSERT_EQ(res_trk_cands.size(), 1u);

// The second track is selected over the first one as their relative
// shared measurement (2/4) is the same but its chi square is smaller
// shared measurement (2/4) is the same but its p-value is higher
ASSERT_EQ(get_pattern(res_trk_cands, 0),
std::vector<std::size_t>({3, 5, 6, 13}));
}
Expand All @@ -188,12 +188,12 @@ TEST(AmbiguitySolverTests, GreedyResolverTest3) {

track_candidate_container_types::host trk_cands;
trk_cands.resize(6u);
fill_pattern(trk_cands, 0, 5.3f, {1, 3, 5, 11});
fill_pattern(trk_cands, 1, 2.4f, {2, 6});
fill_pattern(trk_cands, 2, 2.5f, {3, 6, 12, 14, 19, 21});
fill_pattern(trk_cands, 3, 13.3f, {2, 7, 11, 13, 16});
fill_pattern(trk_cands, 4, 4.1f, {1, 7, 8});
fill_pattern(trk_cands, 5, 1.1f, {1, 3, 11, 22});
fill_pattern(trk_cands, 0, 0.2f, {1, 3, 5, 11});
fill_pattern(trk_cands, 1, 0.5f, {2, 6});
fill_pattern(trk_cands, 2, 0.4f, {3, 6, 12, 14, 19, 21});
fill_pattern(trk_cands, 3, 0.1f, {2, 7, 11, 13, 16});
fill_pattern(trk_cands, 4, 0.3f, {1, 7, 8});
fill_pattern(trk_cands, 5, 0.6f, {1, 3, 11, 22});

traccc::host::greedy_ambiguity_resolution_algorithm::config_type
resolution_config;
Expand Down Expand Up @@ -239,10 +239,10 @@ TEST(AmbiguitySolverTests, GreedyResolverTest4) {

std::uniform_int_distribution<std::size_t> track_length_dist(1, 20);
std::uniform_int_distribution<std::size_t> meas_id_dist(0, 10000);
std::uniform_real_distribution<traccc::scalar> chi2_dist(0.0f, 10.0f);
std::uniform_real_distribution<traccc::scalar> pval_dist(0.0f, 1.0f);

const std::size_t track_length = track_length_dist(gen);
const traccc::scalar chi2 = chi2_dist(gen);
const traccc::scalar pval = pval_dist(gen);
std::vector<std::size_t> pattern;
while (pattern.size() < track_length) {

Expand All @@ -264,7 +264,7 @@ TEST(AmbiguitySolverTests, GreedyResolverTest4) {
ASSERT_EQ(pattern.size(), track_length);

// Fill the pattern
fill_pattern(trk_cands, i, chi2, pattern);
fill_pattern(trk_cands, i, pval, pattern);
}

traccc::host::greedy_ambiguity_resolution_algorithm::config_type
Expand Down
Loading