Skip to content

Commit 5307ccf

Browse files
authored
Add differentiable likelihood (#4)
* Added jax differentiable likelihood. * Flake8 compliance. * Added jax tests to workflow. * Cobaya compliance, fixed import error in jax like import. * Jax optional reqs changes. * Added jax tests. * Minor fixes & code style. * Path for jax like test. * Add ell cut for jax test. * Fix indexing off-by-one in data load. * Added missing test to default file run. * Added HMC example. * Updated readme. * Typo. * Typo.
1 parent 0e1f38e commit 5307ccf

File tree

9 files changed

+650
-6
lines changed

9 files changed

+650
-6
lines changed

.github/workflows/testing.yml

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -18,15 +18,15 @@ jobs:
1818

1919
steps:
2020
- name: Checkout repository
21-
uses: actions/checkout@v2
21+
uses: actions/checkout@v3
2222

2323
- name: Install Dependencies
2424
shell: bash -el {0}
2525
run: |
2626
python -m pip install --upgrade pip
2727
pip install pytest
2828
pip install -r requirements.txt
29-
pip install -e .
29+
pip install -e .[jax]
3030
3131
- name: Cache testing Dataset
3232
shell: bash -el {0}
@@ -39,4 +39,4 @@ jobs:
3939
- name: Run Tests
4040
shell: bash -el {0}
4141
run: |
42-
pytest -v act_dr6_cmbonly/tests/test_act.py
42+
pytest -v

README.md

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ You can now run the tests with
2828
```
2929
pytest -v --pyargs act_dr6_cmbonly
3030
```
31-
If the tests return without any error (i.e. with only warnings), then the code is probably correctly installed. You can then attempt to run chains with
31+
If the tests return without any error (i.e. with only warnings), then the code is probably correctly installed. You may get some tests which get skipped if you do not install the differentiable likelihood (see below) - you do not need to worry about this. You can then attempt to run chains with
3232
```
3333
cobaya-run yamls/run_act.yaml
3434
```
@@ -46,3 +46,18 @@ If you ever get an error that the likelihood cannot locate the data, then you ca
4646
By default, the likelihood will look for the data in either
4747
- `<pip directory>/act_dr6_cmbonly/data/` if no cobaya packages path is given, or
4848
- `<cobaya packages path>/data/ACTDR6CMBonly/` if a cobaya packages path is given.
49+
50+
## The differentiable likelihood
51+
52+
If you are for whatever reason interested, I created a differentiable likelihood as well. You can install this by installing the package with
53+
```
54+
pip install -e .[jax]
55+
```
56+
Which will also install the `jax` and `cosmopower-jax` prerequisites. The differentiable likelihood can then be imported with
57+
```
58+
import act_dr6_cmbonly
59+
like = act_dr6_cmbonly.ACTDR6jax()
60+
```
61+
I provide an example of how to run a chain with the differentiable likelihood, see the `examples/run_hmc.py` file.
62+
63+
For the most part, I do not expect that a differentiable likelihood adds much to ACT DR6 on its own. However, should people be interested in running joint analyses with other probes that provide differentiable likelihoods, then this simple likelihood should suffice. For the most part, it is simply 100 lines of python that does the same as the cobaya likelihood, but with JAX instead of numpy (as a result, some of the data is stored a bit differently internally to make use of JAX optimizations).

act_dr6_cmbonly/__init__.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,14 @@
11
__author__ = "Hidde T. Jense"
22
__url__ = "https://github.com/ACTCollaboration/dr6-cmbonly"
3-
__version__ = "0.1.2"
3+
__version__ = "0.1.3"
44

55
try:
66
from .act_dr6_cmbonly import ACTDR6CMBonly # noqa: F401
77
from .PlanckActCut import PlanckActCut # noqa: F401
88
except ImportError:
99
pass
10+
11+
try:
12+
from .act_dr6_jaxlike import ACTDR6jax # noqa: F401
13+
except ImportError:
14+
ACTDR6jax = None # noqa: F401

act_dr6_cmbonly/act_dr6_cmbonly.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -131,5 +131,5 @@ def loglike(self, cl):
131131
return self.logp_const + logp
132132

133133
def logp(self, **param_values):
134-
cl = self.theory.get_Cl(ell_factor=True)
134+
cl = self.provider.get_Cl(ell_factor=True)
135135
return self.loglike(cl)

act_dr6_cmbonly/act_dr6_jaxlike.py

Lines changed: 136 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,136 @@
1+
import jax.numpy as np
2+
import os
3+
import sacc
4+
from typing import Sequence
5+
6+
7+
class ACTDR6jax:
8+
"""
9+
A differentiable Likelihood implementation for ACT DR6
10+
foreground-marginalized (cmb-only) data.
11+
12+
To make use of this module, make sure to install the package with
13+
pip install -e .[jax]
14+
to include any JAX dependencies.
15+
16+
I do not expect this class to get used a lot, but for the sake of
17+
"it is easy and doable to write this in 100 lines of code", I wrote this.
18+
19+
Author: Hidde T. Jense
20+
"""
21+
data_folder: str = "ACTDR6CMBonly"
22+
input_filename: str = "act_dr6_cmb_sacc.fits"
23+
polarizations: Sequence[str] = ["tt", "te", "ee"]
24+
tt_lmax: int = 9000
25+
ell_cuts: dict = {
26+
"TT": [600, 6500],
27+
"TE": [600, 6500],
28+
"EE": [500, 6500]
29+
}
30+
31+
def __init__(self, verbose: bool = False) -> None:
32+
self.__verbose = verbose
33+
34+
def load_data(self) -> None:
35+
# load the data
36+
if self.verbose:
37+
print(f"Loading data from {self.input_filename}.")
38+
39+
saccfile = sacc.Sacc.load_fits(
40+
os.path.join(self.data_folder, self.input_filename)
41+
)
42+
43+
idx_max = 0
44+
pol_dt = {"t": "0", "e": "e", "b": "b"}
45+
self.spec_meta = []
46+
self.cull = []
47+
48+
for pol in self.polarizations:
49+
p1, p2 = pol.lower()
50+
t1, t2 = pol_dt[p1], pol_dt[p2]
51+
dt = f"cl_{t1}{t2}"
52+
53+
tracers = saccfile.get_tracer_combinations(dt)
54+
55+
for tr1, tr2 in tracers:
56+
if self.verbose:
57+
print(f"{tr1}x{tr2}")
58+
59+
lmin, lmax = self.ell_cuts.get(pol.upper(), (-np.inf, np.inf))
60+
61+
ls, mu, ind = saccfile.get_ell_cl(dt, tr1, tr2,
62+
return_ind=True)
63+
mask = np.logical_and(ls >= lmin, ls <= lmax)
64+
if not np.all(mask):
65+
if self.verbose:
66+
print(f"Cutting {pol} data to the \
67+
range [{lmin}-{lmax}].")
68+
self.cull.append(ind[~mask])
69+
window = saccfile.get_bandpower_windows(ind[mask])
70+
71+
self.spec_meta.append({
72+
"data_type": dt,
73+
"tracer1": tr1,
74+
"tracer2": tr2,
75+
"pol": pol.lower(),
76+
"ell": ls[mask],
77+
"spec": mu[mask],
78+
"idx": ind[mask],
79+
"window": window
80+
})
81+
82+
idx_max = max(idx_max, max(ind))
83+
84+
self.data_vec = np.zeros((idx_max + 1,))
85+
self.spec_picker = np.zeros((idx_max + 1, len(self.spec_meta)))
86+
self.win_func = np.zeros((idx_max + 1, self.tt_lmax - 1))
87+
88+
"""
89+
Most of the magic happens here - we create some binning and stacking
90+
functions to ensure that everything can be quickly jit-ed by JAX.
91+
92+
(This is not a super clean way to do this, but it is the most
93+
efficient way to do this easily.)
94+
"""
95+
for i, m in enumerate(self.spec_meta):
96+
self.data_vec = self.data_vec.at[m["idx"]].set(m["spec"])
97+
self.spec_picker = self.spec_picker.at[m["idx"], i].set(1)
98+
99+
j1, j2 = m["window"].values.min()-2, m["window"].values.max()-2
100+
if j2 >= self.tt_lmax-1:
101+
j2 = self.tt_lmax-2
102+
imax = j2-j1+1
103+
self.win_func = self.win_func.at[m["idx"], j1:j2+1].set(
104+
m["window"].weight[:imax].astype(float).T
105+
)
106+
107+
self.covmat = np.array(saccfile.covariance.covmat)
108+
109+
for culls in self.cull:
110+
self.covmat = self.covmat.at[culls, :].set(0.0)
111+
self.covmat = self.covmat.at[:, culls].set(0.0)
112+
self.covmat = self.covmat.at[culls, culls].set(1e10)
113+
114+
self.inv_cov = np.linalg.inv(self.covmat)
115+
self.logp_const = -0.5 * np.log(2.0 * np.pi) * len(self.data_vec)
116+
self.logp_const -= 0.5 * np.linalg.slogdet(self.covmat)[1]
117+
118+
def logp(self, dell: np.ndarray) -> float:
119+
"""
120+
Compute the log-likelihood of some cosmology.
121+
It expects an Lx3 array, where L is the number of ell-modes
122+
(where dell[:,0] = TT[2..lmax], dell[:,1] = TE[2..lmax], and
123+
dell[:,2] = EE[2..lmax]).
124+
"""
125+
ps_vec = np.dot(self.win_func, dell[:self.tt_lmax-1])
126+
ps_vec = np.sum(ps_vec * self.spec_picker, axis=1)
127+
128+
self.ps_vec = ps_vec
129+
130+
delta = self.data_vec - ps_vec
131+
logp = -0.5 * (delta @ self.inv_cov @ delta)
132+
return self.logp_const + logp
133+
134+
@property
135+
def verbose(self) -> bool:
136+
return self.__verbose

act_dr6_cmbonly/tests/test_jax.py

Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,71 @@
1+
"""
2+
This test checks that the differentiable DR6
3+
likelihood is working as intended.
4+
"""
5+
import pytest # noqa F401
6+
import sys
7+
8+
try:
9+
import jax # noqa F401
10+
import jax.numpy as np
11+
from cosmopower_jax.cosmopower_jax import CosmoPowerJAX as CPJ
12+
except ImportError:
13+
pass
14+
15+
16+
@pytest.mark.skipif("jax" not in sys.modules, reason="JAX is not installed.")
17+
def test_import_jaxlike():
18+
from act_dr6_cmbonly import ACTDR6jax # noqa F401
19+
20+
21+
@pytest.mark.skipif("jax" not in sys.modules, reason="JAX is not installed.")
22+
def test_jaxlike():
23+
import act_dr6_cmbonly
24+
like = act_dr6_cmbonly.ACTDR6jax() # noqa F401
25+
26+
27+
@pytest.mark.skipif("jax" not in sys.modules, reason="JAX is not installed.")
28+
def test_jax_load_data():
29+
import act_dr6_cmbonly
30+
like = act_dr6_cmbonly.ACTDR6jax()
31+
like.data_folder = "act_dr6_cmbonly/data"
32+
like.load_data()
33+
34+
35+
@pytest.mark.skipif("cosmopower_jax" not in sys.modules,
36+
reason="Cosmopower-JAX is not installed.")
37+
def test_jax_loglike():
38+
import act_dr6_cmbonly
39+
T_CMB = 2.7255e6
40+
cosmo_params = np.array([0.025, 0.12, 0.68, 0.054, 0.97, 3.05])
41+
42+
emu_tt = CPJ(probe='cmb_tt')
43+
emu_te = CPJ(probe='cmb_te')
44+
emu_ee = CPJ(probe='cmb_ee')
45+
ellfac = emu_tt.modes * (emu_tt.modes + 1.0) / (2.0 * np.pi)
46+
cl_tt = (T_CMB) ** 2.0 * emu_tt.predict(cosmo_params) * ellfac
47+
cl_te = (T_CMB) ** 2.0 * emu_te.predict(cosmo_params) * ellfac
48+
cl_ee = (T_CMB) ** 2.0 * emu_ee.predict(cosmo_params) * ellfac
49+
50+
cell = np.stack([cl_tt, cl_te, cl_ee], axis=1)
51+
52+
like = act_dr6_cmbonly.ACTDR6jax()
53+
like.data_folder = "act_dr6_cmbonly/data"
54+
like.ell_cuts = {
55+
"TT": [600, 2508],
56+
"TE": [600, 2508],
57+
"EE": [500, 2508]
58+
}
59+
like.tt_lmax = 2508
60+
like.load_data()
61+
62+
logp = like.logp(cell)
63+
64+
assert np.isclose(logp, -1236.189)
65+
66+
67+
if __name__ == "__main__":
68+
test_import_jaxlike()
69+
test_jax_load_data()
70+
test_jaxlike()
71+
test_jax_loglike()

0 commit comments

Comments
 (0)