Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion firecrown/app/examples/_des_y1_3x2pt_pt_template.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ def _build_two_point_statistics(
TwoPoint(source0=src0, source1=src0, sacc_data_type="galaxy_shear_xi_plus"),
TwoPoint(source0=src0, source1=src0, sacc_data_type="galaxy_shear_xi_minus"),
TwoPoint(
source0=lens0, source1=src0, sacc_data_type="galaxy_shearDensity_xi_t"
source0=src0, source1=lens0, sacc_data_type="galaxy_shearDensity_xi_t"
),
TwoPoint(source0=lens0, source1=lens0, sacc_data_type="galaxy_density_xi"),
]
Expand Down
4 changes: 2 additions & 2 deletions firecrown/app/examples/_des_y1_3x2pt_template.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,8 +78,8 @@ def build_likelihood(params: NamedParameters) -> tuple[ConstGaussian, ModelingTo
for j in range(5):
for i in range(4):
stats[f"gammat_lens{j}_src{i}"] = TwoPoint(
source0=sources[f"lens{j}"],
source1=sources[f"src{i}"],
source0=sources[f"src{i}"],
source1=sources[f"lens{j}"],
sacc_data_type="galaxy_shearDensity_xi_t",
)

Expand Down
12 changes: 11 additions & 1 deletion firecrown/app/sacc/_load.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
"""Load command for SACC files."""

import dataclasses
import contextlib
import io
import warnings
from pathlib import Path
from typing import Annotated

Expand Down Expand Up @@ -41,7 +44,14 @@ def _load_sacc_file(self) -> None:
try:
if not self.sacc_file.exists():
raise typer.BadParameter(f"SACC file not found: {self.sacc_file}")
self.sacc_data = factories.load_sacc_data(self.sacc_file.as_posix())

with (
contextlib.redirect_stdout(io.StringIO()),
contextlib.redirect_stderr(io.StringIO()),
warnings.catch_warnings(),
):
warnings.simplefilter("ignore")
self.sacc_data = factories.load_sacc_data(self.sacc_file.as_posix())
except Exception as e:
self.console.print(f"[bold red]Failed to load SACC file:[/bold red] {e}")
raise
Binary file modified tests/bug_398.sacc.gz
Binary file not shown.
Binary file modified tests/legacy_sacc_data.fits
Binary file not shown.
94 changes: 71 additions & 23 deletions tests/likelihood/gauss_family/test_const_gaussianPM.py
Original file line number Diff line number Diff line change
Expand Up @@ -214,9 +214,11 @@ def test_compute_chisq_fails_before_read(trivial_stats):
"""Note that the error message from the direct call to compute_chisq notes
that update() must be called; this can only be called after read()."""
likelihood = ConstGaussianPM(statistics=trivial_stats)
with pytest.raises(
AssertionError,
match=re.escape("update() must be called before compute_chisq()"),
with (
pytest.raises(
AssertionError,
match=re.escape("update() must be called before compute_chisq()"),
),
):
_ = likelihood.compute_chisq(ModelingTools())

Expand All @@ -225,7 +227,11 @@ def test_compute_chisq(trivial_stats, sacc_data_for_trivial_stat, trivial_params
likelihood = ConstGaussianPM(statistics=trivial_stats)
likelihood.read(sacc_data_for_trivial_stat)
likelihood.update(trivial_params)
assert likelihood.compute_chisq(ModelingTools()) == 2.0
with pytest.warns(
UserWarning,
match=re.escape("inverse covariance correction has not yet been computed."),
):
assert likelihood.compute_chisq(ModelingTools()) == 2.0


def test_deprecated_compute(trivial_stats, sacc_data_for_trivial_stat, trivial_params):
Expand Down Expand Up @@ -256,7 +262,11 @@ def compute(
likelihood.compute_theory_vector = compute_theory_vector # type: ignore
likelihood.compute = compute # type: ignore

assert likelihood.compute_chisq(ModelingTools()) == 2.0
with pytest.warns(
UserWarning,
match=re.escape("inverse covariance correction has not yet been computed."),
):
assert likelihood.compute_chisq(ModelingTools()) == 2.0


def test_required_parameters(trivial_stats, sacc_data_for_trivial_stat, trivial_params):
Expand All @@ -282,7 +292,11 @@ def test_reset(trivial_stats, sacc_data_for_trivial_stat, trivial_params):
likelihood.read(sacc_data_for_trivial_stat)
likelihood.update(trivial_params)
assert not trivial_stats[0].computed_theory_vector
assert likelihood.compute_loglike(ModelingTools()) == -1.0
with pytest.warns(
UserWarning,
match=re.escape("inverse covariance correction has not yet been computed."),
):
assert likelihood.compute_loglike(ModelingTools()) == -1.0
assert trivial_stats[0].computed_theory_vector
likelihood.reset()
assert not trivial_stats[0].computed_theory_vector
Expand All @@ -307,7 +321,11 @@ def test_using_good_sacc(
likelihood.read(sacc_data_for_trivial_stat)
params = firecrown.parameters.ParamsMap(mean=10.5)
likelihood.update(params)
chisq = likelihood.compute_chisq(tools_with_vanilla_cosmology)
with pytest.warns(
UserWarning,
match=re.escape("inverse covariance correction has not yet been computed."),
):
chisq = likelihood.compute_chisq(tools_with_vanilla_cosmology)
assert isinstance(chisq, float)
assert chisq > 0.0

Expand All @@ -331,15 +349,23 @@ def test_make_realization_chisq(
likelihood.read(sacc_data_for_trivial_stat)
params = firecrown.parameters.ParamsMap(mean=10.5)
likelihood.update(params)
likelihood.compute_chisq(tools_with_vanilla_cosmology)
with pytest.warns(
UserWarning,
match=re.escape("inverse covariance correction has not yet been computed."),
):
likelihood.compute_chisq(tools_with_vanilla_cosmology)

new_sacc = likelihood.make_realization(sacc_data_for_trivial_stat)

new_likelihood = ConstGaussianPM(statistics=[TrivialStatistic()])
new_likelihood.read(new_sacc)
params = firecrown.parameters.ParamsMap(mean=10.5)
new_likelihood.update(params)
chisq = new_likelihood.compute_chisq(tools_with_vanilla_cosmology)
with pytest.warns(
UserWarning,
match=re.escape("inverse covariance correction has not yet been computed."),
):
chisq = new_likelihood.compute_chisq(tools_with_vanilla_cosmology)

# The new likelihood chisq is distributed as a chi-squared with 3 degrees of
# freedom. We want to check that the new chisq is within the 1-10^-6 quantile
Expand All @@ -358,7 +384,11 @@ def test_make_realization_chisq_mean(
likelihood.read(sacc_data_for_trivial_stat)
params = firecrown.parameters.ParamsMap(mean=10.5)
likelihood.update(params)
likelihood.compute_chisq(tools_with_vanilla_cosmology)
with pytest.warns(
UserWarning,
match=re.escape("inverse covariance correction has not yet been computed."),
):
likelihood.compute_chisq(tools_with_vanilla_cosmology)

chisq_list = []
for _ in range(1000):
Expand All @@ -368,7 +398,11 @@ def test_make_realization_chisq_mean(
new_likelihood.read(new_sacc)
params = firecrown.parameters.ParamsMap(mean=10.5)
new_likelihood.update(params)
chisq = new_likelihood.compute_chisq(tools_with_vanilla_cosmology)
with pytest.warns(
UserWarning,
match=re.escape("inverse covariance correction has not yet been computed."),
):
chisq = new_likelihood.compute_chisq(tools_with_vanilla_cosmology)
chisq_list.append(chisq)

# The new likelihood chisq is distributed as a chi-squared with 3 degrees of
Expand All @@ -387,7 +421,11 @@ def test_make_realization_data_vector(
likelihood.read(sacc_data_for_trivial_stat)
params = firecrown.parameters.ParamsMap(mean=10.5)
likelihood.update(params)
likelihood.compute_chisq(tools_with_vanilla_cosmology)
with pytest.warns(
UserWarning,
match=re.escape("inverse covariance correction has not yet been computed."),
):
likelihood.compute_chisq(tools_with_vanilla_cosmology)

data_vector_list = []
for _ in range(1000):
Expand Down Expand Up @@ -430,7 +468,11 @@ def test_make_realization_no_noise(
likelihood.read(sacc_data_for_trivial_stat)
params = firecrown.parameters.ParamsMap(mean=10.5)
likelihood.update(params)
likelihood.compute_chisq(tools_with_vanilla_cosmology)
with pytest.warns(
UserWarning,
match=re.escape("inverse covariance correction has not yet been computed."),
):
likelihood.compute_chisq(tools_with_vanilla_cosmology)

new_sacc = likelihood.make_realization(sacc_data_for_trivial_stat, add_noise=False)

Expand Down Expand Up @@ -561,8 +603,8 @@ def __init__(self, statistic):
self.statistic = statistic


@pytest.fixture
def minimal_const_gaussian_PM():
@pytest.fixture(name="minimal_const_gaussian_PM")
def fixture_minimal_const_gaussian_PM() -> ConstGaussianPM:
# Create minimal valid statistics for the class.
z = np.array([0.1, 0.2, 0.3])
dndz = np.array([1.0, 2.0, 3.0])
Expand Down Expand Up @@ -601,6 +643,9 @@ def minimal_const_gaussian_PM():
return likelihood


# pylint: disable=protected-access


def test_precomputed_warning(minimal_const_gaussian_PM):
# Check that running the precomputation twice gives a warning.
minimal_const_gaussian_PM._generate_maps()
Expand Down Expand Up @@ -705,8 +750,8 @@ def test_PM_correction_matrix(sacc_data):
sacc_data_type="galaxy_shear_xi_plus",
)
stats["gammat_lens0_src0"] = TwoPoint(
source0=nc.NumberCounts(sacc_tracer="lens0"),
source1=wl.WeakLensing(sacc_tracer="src0"),
source0=wl.WeakLensing(sacc_tracer="src0"),
source1=nc.NumberCounts(sacc_tracer="lens0"),
sacc_data_type="galaxy_shearDensity_xi_t",
)
stats["wtheta_lens0_lens0"] = TwoPoint(
Expand All @@ -731,8 +776,8 @@ def test_compute_chisq_with_correction(sacc_data):
# This tests the truthy branch of the if statement at line 50
stats = {}
stats["gammat_lens0_src0"] = TwoPoint(
source0=nc.NumberCounts(sacc_tracer="lens0"),
source1=wl.WeakLensing(sacc_tracer="src0"),
source0=wl.WeakLensing(sacc_tracer="src0"),
source1=nc.NumberCounts(sacc_tracer="lens0"),
sacc_data_type="galaxy_shearDensity_xi_t",
)
likelihood = ConstGaussianPM(statistics=list(stats.values()))
Expand All @@ -759,8 +804,8 @@ def test_get_lens_statistic_not_found(sacc_data):
# Test that _get_lens_statistic raises StopIteration when lens tracer not found
stats = {}
stats["gammat_lens0_src0"] = TwoPoint(
source0=nc.NumberCounts(sacc_tracer="lens0"),
source1=wl.WeakLensing(sacc_tracer="src0"),
source0=wl.WeakLensing(sacc_tracer="src0"),
source1=nc.NumberCounts(sacc_tracer="lens0"),
sacc_data_type="galaxy_shearDensity_xi_t",
)
likelihood = ConstGaussianPM(statistics=list(stats.values()))
Expand All @@ -775,8 +820,8 @@ def test_get_src_statistic_not_found(sacc_data):
# Test that _get_src_statistic raises StopIteration when source tracer not found
stats = {}
stats["gammat_lens0_src0"] = TwoPoint(
source0=nc.NumberCounts(sacc_tracer="lens0"),
source1=wl.WeakLensing(sacc_tracer="src0"),
source0=wl.WeakLensing(sacc_tracer="src0"),
source1=nc.NumberCounts(sacc_tracer="lens0"),
sacc_data_type="galaxy_shearDensity_xi_t",
)
likelihood = ConstGaussianPM(statistics=list(stats.values()))
Expand Down Expand Up @@ -818,3 +863,6 @@ def __init__(self):

with pytest.raises(StopIteration, match="missing attributes"):
likelihood._collect_data_vectors()


# pylint: enable=protected-access
32 changes: 16 additions & 16 deletions tests/test_pt_systematics.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ def test_pt_systematics(weak_lensing_source, number_counts_source, sacc_data):
stats = [
TwoPoint("galaxy_shear_xi_plus", weak_lensing_source, weak_lensing_source),
TwoPoint("galaxy_shear_xi_minus", weak_lensing_source, weak_lensing_source),
TwoPoint("galaxy_shearDensity_xi_t", number_counts_source, weak_lensing_source),
TwoPoint("galaxy_shearDensity_xi_t", weak_lensing_source, number_counts_source),
TwoPoint("galaxy_density_xi", number_counts_source, number_counts_source),
]

Expand Down Expand Up @@ -179,9 +179,9 @@ def test_pt_systematics(weak_lensing_source, number_counts_source, sacc_data):
# print(list(likelihood.statistics[2].cells.keys()))
s2 = likelihood.statistics[2].statistic
assert isinstance(s2, TwoPoint)
cells_gG = s2.cells[TracerNames("galaxies", "shear")]
cells_gI = s2.cells[TracerNames("galaxies", "intrinsic_pt")]
cells_mI = s2.cells[TracerNames("magnification+rsd", "intrinsic_pt")]
cells_gG = s2.cells[TracerNames("shear", "galaxies")]
cells_gI = s2.cells[TracerNames("intrinsic_pt", "galaxies")]
cells_mI = s2.cells[TracerNames("intrinsic_pt", "magnification+rsd")]

# print(list(likelihood.statistics[3].cells.keys()))
s3 = likelihood.statistics[3].statistic
Expand Down Expand Up @@ -267,8 +267,8 @@ def test_pt_mixed_systematics(sacc_data):
)

stat = TwoPoint(
source0=nc_source,
source1=wl_source,
source0=wl_source,
source1=nc_source,
sacc_data_type="galaxy_shearDensity_xi_t",
)

Expand Down Expand Up @@ -354,8 +354,8 @@ def test_pt_mixed_systematics(sacc_data):
ells = s0.ells_for_xi

# print(list(likelihood.statistics[2].cells.keys()))
cells_gG = s0.cells[TracerNames("galaxies+magnification+rsd", "shear")]
cells_gI = s0.cells[TracerNames("galaxies+magnification+rsd", "intrinsic_pt")]
cells_gG = s0.cells[TracerNames("shear", "galaxies+magnification+rsd")]
cells_gI = s0.cells[TracerNames("intrinsic_pt", "galaxies+magnification+rsd")]
# pylint: enable=no-member

# Code that computes effect from IA using that Pk2D object
Expand Down Expand Up @@ -399,8 +399,8 @@ def test_pt_mixed_systematics_zdep(sacc_data):
)

stat = TwoPoint(
source0=nc_source,
source1=wl_source,
source0=wl_source,
source1=nc_source,
sacc_data_type="galaxy_shearDensity_xi_t",
)

Expand Down Expand Up @@ -494,8 +494,8 @@ def test_pt_mixed_systematics_zdep(sacc_data):
ells = s0.ells_for_xi

# print(list(likelihood.statistics[2].cells.keys()))
cells_gG = s0.cells[TracerNames("galaxies+magnification+rsd", "shear")]
cells_gI = s0.cells[TracerNames("galaxies+magnification+rsd", "intrinsic_pt")]
cells_gG = s0.cells[TracerNames("shear", "galaxies+magnification+rsd")]
cells_gI = s0.cells[TracerNames("intrinsic_pt", "galaxies+magnification+rsd")]
# pylint: enable=no-member

# Code that computes effect from IA using that Pk2D object
Expand Down Expand Up @@ -533,7 +533,7 @@ def test_pt_systematics_zdep(weak_lensing_source, number_counts_source, sacc_dat
stats = [
TwoPoint("galaxy_shear_xi_plus", weak_lensing_source, weak_lensing_source),
TwoPoint("galaxy_shear_xi_minus", weak_lensing_source, weak_lensing_source),
TwoPoint("galaxy_shearDensity_xi_t", number_counts_source, weak_lensing_source),
TwoPoint("galaxy_shearDensity_xi_t", weak_lensing_source, number_counts_source),
TwoPoint("galaxy_density_xi", number_counts_source, number_counts_source),
]

Expand Down Expand Up @@ -656,9 +656,9 @@ def test_pt_systematics_zdep(weak_lensing_source, number_counts_source, sacc_dat
# print(list(likelihood.statistics[2].cells.keys()))
s2 = likelihood.statistics[2].statistic
assert isinstance(s2, TwoPoint)
cells_gG = s2.cells[TracerNames("galaxies", "shear")]
cells_gI = s2.cells[TracerNames("galaxies", "intrinsic_pt")]
cells_mI = s2.cells[TracerNames("magnification+rsd", "intrinsic_pt")]
cells_gG = s2.cells[TracerNames("shear", "galaxies")]
cells_gI = s2.cells[TracerNames("intrinsic_pt", "galaxies")]
cells_mI = s2.cells[TracerNames("intrinsic_pt", "magnification+rsd")]

# print(list(likelihood.statistics[3].cells.keys()))
s3 = likelihood.statistics[3].statistic
Expand Down
Loading