Skip to content

Commit a92b250

Browse files
vitentimarcpaterno
andauthored
Introducing BinPairSelector class system for bin pair selection. (#552)
* Introduce BinPairSelector class system for bin combinations. * Reorganize all tutorials. * New tutorial for Selectors. --------- Co-authored-by: Marc Paterno <[email protected]>
1 parent 855ac88 commit a92b250

27 files changed

+4566
-882
lines changed

firecrown/generators/_inferred_galaxy_zdist.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -546,10 +546,10 @@ def get_lsst_y1_lens_harmonic_bin_collection() -> ZDistLSSTSRDBinCollection:
546546
zpu=zpu,
547547
sigma_z=y1_lens_bins["sigma_z"],
548548
z=RawGrid1D(values=[0.0, 3.5]),
549-
bin_name=f"lens_{zpl:.1f}_{zpu:.1f}_y1",
549+
bin_name=f"lsst_y1_lens{i}",
550550
measurements={Galaxies.COUNTS},
551551
)
552-
for zpl, zpu in pairwise(y1_lens_bins["edges"])
552+
for i, (zpl, zpu) in enumerate(pairwise(y1_lens_bins["edges"]))
553553
],
554554
)
555555

@@ -568,10 +568,10 @@ def get_lsst_y1_source_harmonic_bin_collection() -> ZDistLSSTSRDBinCollection:
568568
zpu=zpu,
569569
sigma_z=y1_source_bins["sigma_z"],
570570
z=RawGrid1D(values=[0.0, 3.5]),
571-
bin_name=f"source_{zpl:.1f}_{zpu:.1f}_y1",
571+
bin_name=f"lsst_y1_source{i}",
572572
measurements={Galaxies.SHEAR_E},
573573
)
574-
for zpl, zpu in pairwise(y1_source_bins["edges"])
574+
for i, (zpl, zpu) in enumerate(pairwise(y1_source_bins["edges"]))
575575
],
576576
)
577577

@@ -590,10 +590,10 @@ def get_lsst_y10_lens_harmonic_bin_collection() -> ZDistLSSTSRDBinCollection:
590590
zpu=zpu,
591591
sigma_z=y10_lens_bins["sigma_z"],
592592
z=RawGrid1D(values=[0.0, 3.5]),
593-
bin_name=f"lens_{zpl:.1f}_{zpu:.1f}_y10",
593+
bin_name=f"lsst_y10_lens{i}",
594594
measurements={Galaxies.COUNTS},
595595
)
596-
for zpl, zpu in pairwise(y10_lens_bins["edges"])
596+
for i, (zpl, zpu) in enumerate(pairwise(y10_lens_bins["edges"]))
597597
],
598598
)
599599

@@ -612,10 +612,10 @@ def get_lsst_y10_source_harmonic_bin_collection() -> ZDistLSSTSRDBinCollection:
612612
zpu=zpu,
613613
sigma_z=y10_source_bins["sigma_z"],
614614
z=RawGrid1D(values=[0.0, 3.5]),
615-
bin_name=f"source_{zpl:.1f}_{zpu:.1f}_y10",
615+
bin_name=f"lsst_y10_source{i}",
616616
measurements={Galaxies.SHEAR_E},
617617
)
618-
for zpl, zpu in pairwise(y10_source_bins["edges"])
618+
for i, (zpl, zpu) in enumerate(pairwise(y10_source_bins["edges"]))
619619
],
620620
)
621621

firecrown/metadata_functions/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
maybe_enforce_window,
2828
)
2929
from firecrown.metadata_functions._combination_utils import (
30+
make_binned_two_point_filtered,
3031
make_all_photoz_bin_combinations,
3132
make_all_photoz_bin_combinations_with_cmb,
3233
make_cmb_galaxy_combinations_only,
@@ -53,6 +54,7 @@
5354
"extract_all_photoz_bin_combinations",
5455
"extract_window_function",
5556
"maybe_enforce_window",
57+
"make_binned_two_point_filtered",
5658
"make_all_photoz_bin_combinations",
5759
"make_all_photoz_bin_combinations_with_cmb",
5860
"make_cmb_galaxy_combinations_only",

firecrown/metadata_functions/_combination_utils.py

Lines changed: 131 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,11 @@
1-
"""Utilities for creating combinations of photo-z bins and measurements."""
1+
"""Utilities for creating combinations of tomographic bins and measurements.
2+
3+
This module provides functions to generate two-point correlation combinations from
4+
tomographic redshift bins. It supports:
5+
- All possible galaxy-galaxy correlations
6+
- CMB-galaxy cross-correlations
7+
- Filtered combinations based on bin pair selectors
8+
"""
29

310
from itertools import product, chain
411

@@ -7,37 +14,95 @@
714
import firecrown.metadata_types as mdt
815

916

17+
def _validate_list_of_inferred_galaxy_zdists(
18+
inferred_galaxy_zdists: list[mdt.InferredGalaxyZDist],
19+
) -> None:
20+
"""Validate that tomographic bin names are unique.
21+
22+
:param inferred_galaxy_zdists: List of tomographic bins to validate.
23+
24+
:raises ValueError: If any bin names appear more than once in the list.
25+
"""
26+
bin_names_set = set()
27+
# Produce a list of duplicates
28+
bin_names = []
29+
for igz in inferred_galaxy_zdists:
30+
if igz.bin_name in bin_names_set:
31+
bin_names.append(igz.bin_name)
32+
else:
33+
bin_names_set.add(igz.bin_name)
34+
35+
if bin_names:
36+
raise ValueError(
37+
f"Duplicate inferred galaxy z distribution bin names found: {bin_names}"
38+
)
39+
40+
1041
def make_all_photoz_bin_combinations(
1142
inferred_galaxy_zdists: list[mdt.InferredGalaxyZDist],
1243
) -> list[mdt.TwoPointXY]:
13-
"""Extract the two-point function metadata from a sacc file."""
14-
bin_combinations = [
15-
mdt.TwoPointXY(
16-
x=igz1, y=igz2, x_measurement=x_measurement, y_measurement=y_measurement
17-
)
18-
for igz1, igz2 in product(inferred_galaxy_zdists, repeat=2)
19-
for x_measurement, y_measurement in product(
20-
igz1.measurements, igz2.measurements
21-
)
22-
if mdt.measurement_is_compatible(x_measurement, y_measurement)
44+
"""Create all possible two-point correlation combinations for galaxy bins.
45+
46+
This function generates all possible pairs of (bin, measurement) combinations,
47+
keeping only those where the measurements are compatible. For auto-correlations
48+
(same measurement type), only unique pairs are kept to avoid duplicates
49+
(e.g., only bin0-bin1, not both bin0-bin1 and bin1-bin0).
50+
51+
:param inferred_galaxy_zdists: List of tomographic redshift bins with their
52+
associated measurement types.
53+
54+
:return: List of all valid TwoPointXY combinations.
55+
56+
:raises ValueError: If duplicate bin names are found in inferred_galaxy_zdists.
57+
"""
58+
_validate_list_of_inferred_galaxy_zdists(inferred_galaxy_zdists)
59+
expanded = [
60+
(igz, m) for igz in inferred_galaxy_zdists for m in igz.measurement_list
61+
]
62+
63+
# Create all combinations of the expanded list, keeping only compatible ones
64+
# and avoiding duplicates in the case of correlations of the same type
65+
all_xy = [
66+
mdt.TwoPointXY(x=igz1, y=igz2, x_measurement=m1, y_measurement=m2)
67+
for (igz1, m1), (igz2, m2) in product(expanded, repeat=2)
68+
if mdt.measurement_is_compatible(m1, m2)
69+
and ((m1 != m2) or (igz2.bin_name >= igz1.bin_name))
2370
]
2471

25-
return bin_combinations
72+
# Reorder expanded to have alphabetical order considering first measurements, then
73+
# bin names.
74+
return sorted(
75+
all_xy,
76+
key=lambda xy: (
77+
xy.x_measurement,
78+
xy.y_measurement,
79+
xy.x.bin_name,
80+
xy.y.bin_name,
81+
),
82+
)
2683

2784

2885
def make_all_photoz_bin_combinations_with_cmb(
2986
inferred_galaxy_zdists: list[mdt.InferredGalaxyZDist],
3087
cmb_tracer_name: str = "cmb_convergence",
3188
include_cmb_auto: bool = False,
3289
) -> list[mdt.TwoPointXY]:
33-
"""Create all galaxy combinations plus mdt.CMB-galaxy cross-correlations.
90+
"""Create all galaxy-galaxy and CMB-galaxy correlation combinations.
3491
35-
:param inferred_galaxy_zdists: List of galaxy redshift bins
36-
:param cmb_tracer_name: Name of the mdt.CMB tracer
37-
:param include_cmb_auto: Whether to include mdt.CMB auto-correlation
38-
(default: False)
39-
:return: List of all XY combinations including mdt.CMB-galaxy crosses
92+
This function generates all possible two-point correlations including both
93+
galaxy-galaxy auto/cross-correlations and CMB-galaxy cross-correlations.
94+
95+
:param inferred_galaxy_zdists: List of galaxy redshift bins with their associated
96+
measurement types.
97+
:param cmb_tracer_name: Name to assign to the CMB tracer (default:
98+
"cmb_convergence").
99+
:param include_cmb_auto: Whether to include CMB auto-correlation (default: False).
100+
101+
:return: Combined list of galaxy-galaxy and CMB-galaxy correlation combinations.
102+
103+
:raises ValueError: If duplicate bin names are found in inferred_galaxy_zdists.
40104
"""
105+
_validate_list_of_inferred_galaxy_zdists(inferred_galaxy_zdists)
41106
# Get all galaxy-galaxy combinations first
42107
galaxy_combinations = make_all_photoz_bin_combinations(inferred_galaxy_zdists)
43108
all_combinations = galaxy_combinations + make_cmb_galaxy_combinations_only(
@@ -52,12 +117,24 @@ def make_cmb_galaxy_combinations_only(
52117
cmb_tracer_name: str = "cmb_convergence",
53118
include_cmb_auto: bool = False,
54119
) -> list[mdt.TwoPointXY]:
55-
"""Create only mdt.CMB-galaxy cross-correlations.
120+
"""Create only CMB-galaxy cross-correlations.
121+
122+
This function generates cross-correlations between CMB convergence and galaxy
123+
measurements, optionally including the CMB auto-correlation. It does NOT include
124+
any galaxy-galaxy correlations.
125+
126+
:param inferred_galaxy_zdists: List of galaxy redshift bins with their
127+
associated measurement types.
128+
:param cmb_tracer_name: Name to assign to the CMB tracer (default:
129+
"cmb_convergence").
130+
:param include_cmb_auto: Whether to include CMB auto-correlation (default: False).
56131
57-
:param inferred_galaxy_zdists: List of galaxy redshift bins
58-
:param cmb_tracer_name: Name of the mdt.CMB tracer
59-
:return: List of mdt.CMB-galaxy cross-correlation XY combinations only
132+
:return: List of CMB-galaxy cross-correlation combinations (and optionally CMB
133+
auto).
134+
135+
:raises ValueError: If duplicate bin names are found in inferred_galaxy_zdists.
60136
"""
137+
_validate_list_of_inferred_galaxy_zdists(inferred_galaxy_zdists)
61138
# Create a mock mdt.CMB "bin"
62139
cmb_bin = mdt.InferredGalaxyZDist(
63140
bin_name=cmb_tracer_name,
@@ -94,3 +171,35 @@ def make_cmb_galaxy_combinations_only(
94171
)
95172

96173
return cmb_galaxy_combinations
174+
175+
176+
def make_binned_two_point_filtered(
177+
inferred_galaxy_zdists: list[mdt.InferredGalaxyZDist],
178+
bin_pair_selector: mdt.BinPairSelector,
179+
) -> list[mdt.TwoPointXY]:
180+
"""Create two-point correlations filtered by a bin pair selector.
181+
182+
This function generates all possible bin combinations and then filters them using
183+
the provided selector, keeping only pairs that satisfy the selection criteria
184+
(e.g., auto-correlations only, specific measurements, neighboring bins).
185+
186+
:param inferred_galaxy_zdists: List of tomographic redshift bins with their
187+
associated measurement types.
188+
:param bin_pair_selector: Selector defining which bin pairs to include.
189+
190+
:return: List of TwoPointXY combinations that pass the selector's criteria.
191+
192+
:raises ValueError: If duplicate bin names are found in inferred_galaxy_zdists.
193+
194+
Example:
195+
# Get only auto-correlations of source measurements
196+
selector = AutoNameBinPairSelector() & SourceBinPairSelector()
197+
combinations = make_binned_two_point_filtered(bins, selector)
198+
"""
199+
_validate_list_of_inferred_galaxy_zdists(inferred_galaxy_zdists)
200+
all_bin_combinations = make_all_photoz_bin_combinations(inferred_galaxy_zdists)
201+
return [
202+
xy
203+
for xy in all_bin_combinations
204+
if bin_pair_selector.keep((xy.x, xy.y), (xy.x_measurement, xy.y_measurement))
205+
]

firecrown/metadata_functions/_extraction.py

Lines changed: 55 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -458,10 +458,19 @@ def extract_all_harmonic_metadata_indices(
458458

459459
def extract_all_harmonic_metadata(
460460
sacc_data: sacc.Sacc,
461-
allow_mixed_types: bool = False,
462461
allowed_data_type: None | list[str] = None,
462+
allow_mixed_types: bool = False,
463+
bin_pair_selector: None | mdt.BinPairSelector = None,
463464
) -> list[mdt.TwoPointHarmonic]:
464-
"""Extract the two-point function metadata and data from a sacc file."""
465+
"""Extract two-point harmonic-space metadata and data from a SACC file.
466+
467+
:param sacc_data: The SACC object containing tracers and data points.
468+
:param allowed_data_type: Optional list of SACC data type strings to include.
469+
If None, all harmonic-space data types are extracted.
470+
:param bin_pair_selector: Optional selector to filter which bin pairs to include.
471+
If None, all valid bin pairs are returned.
472+
:return: List of TwoPointHarmonic objects with metadata and ell values.
473+
"""
465474
inferred_galaxy_zdists_dict = {
466475
igz.bin_name: igz
467476
for igz in extract_all_tracers_inferred_galaxy_zdists(
@@ -478,6 +487,13 @@ def extract_all_harmonic_metadata(
478487

479488
XY = make_two_point_xy(inferred_galaxy_zdists_dict, tracer_names, dt)
480489

490+
# Apply bin pair selector if provided
491+
if bin_pair_selector is not None:
492+
if not bin_pair_selector.keep(
493+
(XY.x, XY.y), (XY.x_measurement, XY.y_measurement)
494+
):
495+
continue
496+
481497
ells, _, indices = sacc_data.get_ell_cl(
482498
data_type=dt,
483499
tracer1=tracer_names[0],
@@ -501,10 +517,19 @@ def extract_all_harmonic_metadata(
501517

502518
def extract_all_real_metadata(
503519
sacc_data: sacc.Sacc,
504-
allow_mixed_types: bool = False,
505520
allowed_data_type: None | list[str] = None,
521+
allow_mixed_types: bool = False,
522+
bin_pair_selector: None | mdt.BinPairSelector = None,
506523
) -> list[mdt.TwoPointReal]:
507-
"""Extract the two-point function metadata and data from a sacc file."""
524+
"""Extract two-point real-space metadata and data from a SACC file.
525+
526+
:param sacc_data: The SACC object containing tracers and data points.
527+
:param allowed_data_type: Optional list of SACC data type strings to include.
528+
If None, all real-space data types are extracted.
529+
:param bin_pair_selector: Optional selector to filter which bin pairs to include.
530+
If None, all valid bin pairs are returned.
531+
:return: List of TwoPointReal objects with metadata and theta values.
532+
"""
508533
inferred_galaxy_zdists_dict = {
509534
igz.bin_name: igz
510535
for igz in extract_all_tracers_inferred_galaxy_zdists(
@@ -521,6 +546,13 @@ def extract_all_real_metadata(
521546

522547
XY = make_two_point_xy(inferred_galaxy_zdists_dict, tracer_names, dt)
523548

549+
# Apply bin pair selector if provided
550+
if bin_pair_selector is not None:
551+
if not bin_pair_selector.keep(
552+
(XY.x, XY.y), (XY.x_measurement, XY.y_measurement)
553+
):
554+
continue
555+
524556
t1, t2 = tracer_names
525557
thetas, _, _ = sacc_data.get_theta_xi(
526558
data_type=dt, tracer1=t1, tracer2=t2, return_cov=False, return_ind=True
@@ -532,14 +564,31 @@ def extract_all_real_metadata(
532564

533565

534566
def extract_all_photoz_bin_combinations(
535-
sacc_data: sacc.Sacc, allow_mixed_types: bool = False
567+
sacc_data: sacc.Sacc,
568+
allow_mixed_types: bool = False,
569+
bin_pair_selector: None | mdt.BinPairSelector = None,
536570
) -> list[mdt.TwoPointXY]:
537-
"""Extracts the two-point function metadata from a sacc file."""
571+
"""Extract all two-point bin pair combinations from a SACC file.
572+
573+
:param sacc_data: The SACC object containing tracers and data points.
574+
:param bin_pair_selector: Optional selector to filter which bin pairs to include.
575+
If None, all valid bin pairs are returned.
576+
:return: List of TwoPointXY objects representing valid bin pair combinations.
577+
"""
538578
inferred_galaxy_zdists = extract_all_tracers_inferred_galaxy_zdists(
539579
sacc_data, allow_mixed_types
540580
)
541581
bin_combinations = make_all_photoz_bin_combinations(inferred_galaxy_zdists)
542582

583+
if bin_pair_selector is not None:
584+
bin_combinations = [
585+
xy
586+
for xy in bin_combinations
587+
if bin_pair_selector.keep(
588+
(xy.x, xy.y), (xy.x_measurement, xy.y_measurement)
589+
)
590+
]
591+
543592
return bin_combinations
544593

545594

0 commit comments

Comments
 (0)