Commit 33c68f6

bclevine, vitenti, and marcpaterno authored
Gaussian likelihood with point mass marginalization (#573)
* Introducing Gaussian likelihood with point mass marginalization ConstGaussianPM

* Test improvements:
  - Add test_compute_chisq_with_correction() for inv_cov_correction branch
  - Add test_get_lens_statistic_not_found() for error handling
  - Add test_get_src_statistic_not_found() for error handling
  - Add test_compute_chisq_without_cholesky() to cover GaussFamily direct chi-squared path
  - Use object.__setattr__() in test fixtures to bypass type checking for mock assignments

* Consolidate ConstGaussian classes and remove unnecessary inheritance
  - Update ConstGaussianPM to inherit directly from ConstGaussian
  - Update all imports across the codebase to use firecrown.likelihood.gaussian
  - Fix mypy type-ignore comment in test file

* Consolidate point mass data into a dataclass and improve error handling
  - Add PointMassData dataclass to encapsulate all point mass attributes
  - Replace individual _pm_* attributes with a single _pm_data container
  - Refactor _generate_maps, _cache_precomputed_data, and data access methods
  - Improve _collect_data_vectors error handling with detailed reporting of missing attributes
  - Consolidate _get_lens_statistic and _get_src_statistic into a shared _get_statistic
  - Add default constants for sigma_B and point_mass parameters
  - Update tests to use the new PointMassData structure

* Make _generate_maps return None and inline the caching logic
  - Remove the unused _cache_precomputed_data method
  - Change the _generate_maps return type from PointMassData to None
  - Move the data-caching logic directly into _generate_maps
  - Update the early return to not return cached data
  - Remove the return-value documentation from the docstring

* Refactor point mass marginalization code for better organization and type safety
  - Replace the _pm_maps_ready boolean with a pm_maps_ready property that checks for _pm_data existence
  - Make PointMassData fields optional with None defaults and add an assert_prepared() method
  - Modify _generate_maps() to set instance variables directly instead of returning PointMassData
  - Add typing.cast() calls after assert_prepared() to satisfy the mypy type checker
  - Remove redundant state tracking and improve code maintainability

* Refactor the likelihood module structure
  Rename gaussian_pointmass.py to _gaussian_pointmass.py and update imports to use the new private-module naming convention:
  - Update internal imports to use the _gaussian and _base modules
  - Export ConstGaussianPM and TrivialStatistic from the public firecrown.likelihood API
  - Update all example and test imports to use the public API
  - Update the Makefile pre-commit target to include the docs build
  - Refactor test fixtures to use the public reset() method instead of the protected _reset() for TwoPointTheory
  This change aligns with the module reorganization where implementation modules use an underscore prefix while maintaining a clean public API through __init__.py exports.

---------

Co-authored-by: Sandro Dias Pinto Vitenti <[email protected]>
Co-authored-by: Marc Paterno <[email protected]>
1 parent 604614d commit 33c68f6
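
The commit message above references a PointMassData container with None-defaulted fields and an assert_prepared() guard, but _gaussian_pointmass.py itself is not among the diffs shown below. A minimal sketch of the pattern being described; the field names here (maps, inv_cov_correction) are illustrative assumptions, not the actual attributes:

# Hypothetical sketch only: _gaussian_pointmass.py is not shown on this
# commit page, so the field names below are illustrative assumptions.
from dataclasses import dataclass

import numpy as np
import numpy.typing as npt


@dataclass
class PointMassData:
    """Precomputed point mass marginalization data, filled in by _generate_maps()."""

    # Fields are optional with None defaults, per the commit message, so the
    # container can exist before the maps have been generated.
    maps: None | npt.NDArray[np.float64] = None
    inv_cov_correction: None | npt.NDArray[np.float64] = None

    def assert_prepared(self) -> None:
        """Fail loudly if _generate_maps() has not yet populated the fields."""
        assert self.maps is not None
        assert self.inv_cov_correction is not None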

File tree

11 files changed: +1492 -38 lines

Makefile

Lines changed: 1 addition & 1 deletion
@@ -251,7 +251,7 @@ clean: clean-coverage clean-docs clean-build ## Remove all generated files (can

 ##@ Pre-commit

-pre-commit: format lint test-coverage ## Run all pre-commit checks (format, lint, test with coverage)
+pre-commit: format lint test-coverage docs ## Run all pre-commit checks (format, lint, test with coverage, docs)
 	@echo ""
 	@echo "✅ All pre-commit checks passed!"

Lines changed: 122 additions & 0 deletions
@@ -0,0 +1,122 @@

"""The DES Y1 3x2pt likelihood factory module with point mass marginalization."""

import os

import pyccl
from firecrown.likelihood.factories import load_sacc_data
from firecrown.likelihood import NamedParameters, TwoPoint, ConstGaussianPM
import firecrown.likelihood.weak_lensing as wl
import firecrown.likelihood.number_counts as nc
from firecrown.modeling_tools import ModelingTools
from firecrown.ccl_factory import CCLFactory


# The likelihood used for the DES Y1 3x2pt analysis is a Gaussian likelihood, which
# necessitates providing a list of statistics that are each represented by a two-point
# function. To construct the two-point functions, a list of sources is required, with
# each source being responsible for computing the theoretical prediction for a
# specific segment of the data. These sources are created in the build_likelihood
# function and also contain a list of systematics. The systematics are classes that
# modify the theoretical prediction and are also constructed in build_likelihood.
def build_likelihood(params: NamedParameters) -> tuple[ConstGaussianPM, ModelingTools]:
    """Build the DES Y1 3x2pt likelihood."""
    # Create an LAI systematic. This is a systematic that is applied to
    # all weak-lensing probes. The `sacc_tracer` argument is used to identify the
    # section of the SACC file that this systematic will be applied to. In this
    # case we want to apply it to all weak-lensing probes, so we use the
    # empty string.
    lai_systematic = wl.LinearAlignmentSystematic(sacc_tracer="")

    # Create the sources; each one maps to a specific section of a SACC file. In
    # this case src0, src1, src2 and src3 describe weak-lensing probes. The
    # sources are saved in a dictionary since they will be used by one or more
    # two-point functions.
    sources: dict[str, wl.WeakLensing | nc.NumberCounts] = {}

    for i in range(4):
        # Each weak-lensing section has its own multiplicative bias. Parameters
        # reflect this by using the src{i}_ prefix.
        mbias = wl.MultiplicativeShearBias(sacc_tracer=f"src{i}")

        # We also include a photo-z shift bias (a constant shift in dndz). We
        # also have a different parameter for each bin, so here again we use the
        # src{i}_ prefix.
        wl_pzshift = wl.PhotoZShift(sacc_tracer=f"src{i}")

        # Now we can finally create the weak-lensing source that will compute the
        # theoretical prediction for that section of the data, given the
        # systematics.
        sources[f"src{i}"] = wl.WeakLensing(
            sacc_tracer=f"src{i}", systematics=[lai_systematic, mbias, wl_pzshift]
        )

    # Create the number counts sources. There are five sources, each labeled
    # lens{i}.
    for i in range(5):
        # We also include a photo-z shift for the dndz.
        nc_pzshift = nc.PhotoZShift(sacc_tracer=f"lens{i}")

        # The source is created and saved (temporarily in the sources dict).
        sources[f"lens{i}"] = nc.NumberCounts(
            sacc_tracer=f"lens{i}", systematics=[nc_pzshift], derived_scale=True
        )

    # Now that we have all sources we can instantiate all the two-point
    # functions. The weak-lensing sources have two "data types"; for each one we
    # create a new two-point function.
    stats = {}
    for stat, sacc_stat in [
        ("xip", "galaxy_shear_xi_plus"),
        ("xim", "galaxy_shear_xi_minus"),
    ]:
        # Create all auto/cross-correlation two-point function objects for the
        # weak-lensing probes.
        for i in range(4):
            for j in range(i, 4):
                stats[f"{stat}_src{i}_src{j}"] = TwoPoint(
                    source0=sources[f"src{i}"],
                    source1=sources[f"src{j}"],
                    sacc_data_type=sacc_stat,
                )

    # The following two-point function objects compute the cross-correlations
    # between the weak-lensing and number counts sources.
    for j in range(5):
        for i in range(4):
            stats[f"gammat_lens{j}_src{i}"] = TwoPoint(
                source0=sources[f"lens{j}"],
                source1=sources[f"src{i}"],
                sacc_data_type="galaxy_shearDensity_xi_t",
            )

    # Finally, the instances for the lensing auto-correlations are created.
    for i in range(5):
        stats[f"wtheta_lens{i}_lens{i}"] = TwoPoint(
            source0=sources[f"lens{i}"],
            source1=sources[f"lens{i}"],
            sacc_data_type="galaxy_density_xi",
        )

    # Here we instantiate the actual likelihood. The statistics argument carries
    # the order of the data/theory vector.
    likelihood = ConstGaussianPM(statistics=list(stats.values()))

    # We load the correct SACC file.
    sacc_file = params.get_string("sacc_file")
    # Translate environment variables, if needed.
    sacc_file = os.path.expandvars(sacc_file)
    sacc_data = load_sacc_data(sacc_file)

    # The read method is called with the loaded SACC file; the two-point
    # functions will receive the appropriate sections of the SACC file and the
    # sources their respective dndz.
    likelihood.read(sacc_data)

    modeling_tools = ModelingTools(ccl_factory=CCLFactory(require_nonlinear_pk=True))

    # After reading in the SACC data, we apply the point mass marginalization.
    cosmo = pyccl.CosmologyVanillaLCDM()
    likelihood.compute_pointmass(cosmo)

    # This script will be loaded by the appropriate connector. The framework
    # will call the factory function, which should return a Likelihood instance.
    return likelihood, modeling_tools

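A hedged sketch of how a connector-style caller might consume this factory module; the module path, the SACC file name, and the exact return convention of load_likelihood are assumptions for illustration, not part of this commit:

# Hypothetical usage; the file names below are placeholders.
from firecrown.likelihood import NamedParameters, load_likelihood

params = NamedParameters({"sacc_file": "${FIRECROWN_DIR}/data/des_y1_3x2pt_sacc.fits"})

# load_likelihood imports the factory module and invokes build_likelihood(),
# assumed here to return the (likelihood, tools) pair the factory produces.
likelihood, tools = load_likelihood("des_y1_3x2pt_pm.py", params)
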
firecrown/likelihood/__init__.py

Lines changed: 4 additions & 0 deletions
@@ -39,6 +39,7 @@
     SourceSystematic,
     Statistic,
     Tracer,
+    TrivialStatistic,
 )
 from firecrown.likelihood._likelihood import (
     load_likelihood,
@@ -49,6 +50,7 @@
 from firecrown.likelihood._gaussian import ConstGaussian
 from firecrown.likelihood._gaussfamily import GaussFamily, State
 from firecrown.likelihood._student_t import StudentT
+from firecrown.likelihood._gaussian_pointmass import ConstGaussianPM

 # Two-point statistics
 from firecrown.likelihood._two_point import TwoPoint, TwoPointFactory
@@ -103,6 +105,7 @@
     "GaussFamily",
     "State",
     "StudentT",
+    "ConstGaussianPM",
     # Two-point statistics
     "TwoPoint",
     "TwoPointFactory",
@@ -120,6 +123,7 @@
     "SourceSystematic",
     # Base statistic class
     "Statistic",
+    "TrivialStatistic",
     # Subpackages
     "weak_lensing",
     "number_counts",

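With ConstGaussianPM and TrivialStatistic now re-exported, downstream code can import them from the public package rather than the private underscore modules. A small sketch; constructing TrivialStatistic with no arguments is an assumption based on its role as a testing aid, not something this diff shows:

# Import via the public API rather than the private underscore modules.
from firecrown.likelihood import ConstGaussianPM, TrivialStatistic

# Illustrative only: a point mass marginalized Gaussian likelihood built
# from a trivial placeholder statistic.
likelihood = ConstGaussianPM(statistics=[TrivialStatistic()])
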
firecrown/likelihood/_gaussfamily.py

Lines changed: 67 additions & 15 deletions
@@ -142,10 +142,14 @@ class GaussFamily(Likelihood):
     def __init__(
         self,
         statistics: Sequence[Statistic],
+        use_cholesky: bool = True,
     ) -> None:
         """Initialize the base class parts of a GaussFamily object.

         :param statistics: A list of statistics to be included in chi-squared calculations
+        :param use_cholesky: Whether to use Cholesky decomposition for chi-squared
+            calculation. Set to False if covariance modifications make Cholesky
+            incompatible.
         """
         super().__init__()
         self.state: State = State.INITIALIZED
@@ -168,6 +172,7 @@ def __init__(
         self.cov_index_map: None | dict[int, int] = None
         self.theory_vector: None | npt.NDArray[np.double] = None
         self.data_vector: None | npt.NDArray[np.double] = None
+        self._use_cholesky_for_chisq: bool = use_cholesky

     @classmethod
     def create_ready(
@@ -392,6 +397,66 @@ def compute(
         )
         return self.get_data_vector(), self.compute_theory_vector(tools)

+    def _compute_chisq_cholesky(self, residuals: npt.NDArray[np.float64]) -> float:
+        """Compute chi-squared using Cholesky decomposition.
+
+        This is the numerically stable method for the chi-squared calculation.
+
+        :param residuals: The residuals (data - theory)
+        :return: The chi-squared value
+        """
+        assert self.cholesky is not None
+        x = scipy.linalg.solve_triangular(self.cholesky, residuals, lower=True)
+        chisq = np.dot(x, x)
+        return float(chisq)
+
+    def _compute_chisq_direct(self, residuals: npt.NDArray[np.float64]) -> float:
+        """Compute chi-squared using direct inverse-covariance multiplication.
+
+        This method is less numerically stable but necessary when covariance
+        modifications make Cholesky decomposition incompatible.
+
+        :param residuals: The residuals (data - theory)
+        :return: The chi-squared value
+        """
+        assert self.inv_cov is not None
+        chisq = residuals @ self.inv_cov @ residuals
+        return float(chisq)
+
+    def _get_theory_and_data(
+        self, tools: ModelingTools
+    ) -> tuple[npt.NDArray[np.float64], npt.NDArray[np.float64]]:
+        """Get the theory and data vectors.
+
+        :param tools: The ModelingTools to use for the theory calculation
+        :return: Tuple of (theory_vector, data_vector)
+        """
+        theory_vector: npt.NDArray[np.float64]
+        data_vector: npt.NDArray[np.float64]
+        try:
+            theory_vector = self.compute_theory_vector(tools)
+            data_vector = self.get_data_vector()
+        except NotImplementedError:
+            data_vector, theory_vector = self.compute(tools)
+
+        assert len(data_vector) == len(theory_vector)
+        return theory_vector, data_vector
+
+    def compute_chisq_impl(self, residuals: npt.NDArray[np.float64]) -> float:
+        """Implementation of the chi-squared calculation.
+
+        This method can be overridden by subclasses that need different
+        chi-squared calculation strategies. By default, it uses either
+        Cholesky decomposition (more stable) or direct inverse-covariance
+        multiplication, depending on the _use_cholesky_for_chisq flag.
+
+        :param residuals: The residuals (data - theory)
+        :return: The chi-squared value
+        """
+        if self._use_cholesky_for_chisq:
+            return self._compute_chisq_cholesky(residuals)
+        return self._compute_chisq_direct(residuals)
+
     @final
     @enforce_states(
         initial=[State.UPDATED, State.COMPUTED],
@@ -405,22 +470,9 @@ def compute_chisq(self, tools: ModelingTools) -> float:
             theory vector
         :return: the chi-squared
         """
-        theory_vector: npt.NDArray[np.float64]
-        data_vector: npt.NDArray[np.float64]
-        residuals: npt.NDArray[np.float64]
-        try:
-            theory_vector = self.compute_theory_vector(tools)
-            data_vector = self.get_data_vector()
-        except NotImplementedError:
-            data_vector, theory_vector = self.compute(tools)
-
-        assert len(data_vector) == len(theory_vector)
+        theory_vector, data_vector = self._get_theory_and_data(tools)
         residuals = np.array(data_vector - theory_vector, dtype=np.float64)
-
-        assert self.cholesky is not None
-        x = scipy.linalg.solve_triangular(self.cholesky, residuals, lower=True)
-        chisq = np.dot(x, x)
-        return chisq
+        return self.compute_chisq_impl(residuals)

     @enforce_states(
         initial=[State.READY, State.UPDATED, State.COMPUTED],

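Both strategies above compute the same quantity, chisq = r^T C^{-1} r for residuals r and covariance C: with the Cholesky factor C = L L^T, one solves L x = r and takes chisq = x^T x. A standalone numerical sketch (synthetic data, not Firecrown code) confirming the two paths agree:

# Standalone demo with synthetic data; not part of the commit.
import numpy as np
import scipy.linalg

rng = np.random.default_rng(42)
n = 5
a = rng.normal(size=(n, n))
cov = a @ a.T + n * np.eye(n)  # a symmetric positive definite covariance
residuals = rng.normal(size=n)

# Direct path: chisq = r^T C^{-1} r (less stable for ill-conditioned C).
chisq_direct = residuals @ np.linalg.inv(cov) @ residuals

# Cholesky path: C = L L^T, solve L x = r, then chisq = x^T x.
cholesky = np.linalg.cholesky(cov)  # lower-triangular L
x = scipy.linalg.solve_triangular(cholesky, residuals, lower=True)
chisq_cholesky = x @ x

assert np.isclose(chisq_direct, chisq_cholesky)
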
firecrown/likelihood/_gaussian.py

Lines changed: 5 additions & 1 deletion
@@ -9,7 +9,11 @@

 class ConstGaussian(GaussFamily):
-    """A Gaussian log-likelihood with a constant covariance matrix."""
+    """Base class for constant covariance Gaussian likelihoods.
+
+    Provides shared implementations of compute_loglike and make_realization_vector
+    for all constant covariance Gaussian likelihood variants.
+    """

     def compute_loglike(self, tools: ModelingTools) -> float:
         """Compute the log-likelihood.

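Reading the diff together with the commit message, the intended extension point is the use_cholesky flag plus the compute_chisq_impl hook: a subclass whose covariance modifications invalidate the cached Cholesky factor can opt into the direct path. A hedged sketch of that pattern; the real ConstGaussianPM body is not shown in this diff, so this is an assumption about how it plugs in:

# Hypothetical sketch of a subclass opting out of the Cholesky path; the
# actual ConstGaussianPM implementation is not shown on this commit page.
from collections.abc import Sequence

from firecrown.likelihood import ConstGaussian, Statistic


class CorrectedConstGaussian(ConstGaussian):
    """Constant covariance Gaussian whose inverse covariance gets corrected."""

    def __init__(self, statistics: Sequence[Statistic]) -> None:
        # A later correction to inv_cov would invalidate the cached Cholesky
        # factor, so request the direct r^T C^{-1} r chi-squared path.
        super().__init__(statistics, use_cholesky=False)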