Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 35 additions & 18 deletions examples/unmixing.ipynb

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion heracles/dices/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@
shrinkage_factor,
gaussian_covariance,
)
from .utils import (
from ..utils import (
impose_correlation,
get_cl,
flatten,
Expand Down
6 changes: 3 additions & 3 deletions heracles/dices/jackknife.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
import itertools
from copy import deepcopy
from itertools import combinations
from .utils import add_to_Cls, sub_to_Cls
from ..utils import add_to_Cls, sub_to_Cls
from ..core import update_metadata
from ..result import Result, get_result_array
from ..mapping import transform
Expand Down Expand Up @@ -58,7 +58,7 @@ def jackknife_cls(data_maps, vis_maps, jk_maps, fields, nd=1):
_cls_mm = get_cls(vis_maps, jk_maps, fields, *regions)
# Mask correction
alphas = mask_correction(_cls_mm, mls0)
_cls = _natural_unmixing(_cls, alphas)
_cls = _natural_unmixing(_cls, alphas, fields)
# Bias correction
_cls = correct_bias(_cls, jk_maps, fields, *regions)
cls[regions] = _cls
Expand Down Expand Up @@ -226,7 +226,7 @@ def mask_correction(Mljk, Mls0):
# Compute alpha
alpha = wmljk / wmls0
alpha *= logistic(np.log10(abs(wmljk)))
alphas[key] = alpha
alphas[key] = replace(Mls0[key], array=alpha)
return alphas


Expand Down
4 changes: 2 additions & 2 deletions heracles/dices/shrinkage.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
# License along with DICES. If not, see <https://www.gnu.org/licenses/>.
import numpy as np
import itertools
from .utils import (
from ..utils import (
expand_spin0_dims,
squeeze_spin0_dims,
)
Expand All @@ -29,7 +29,7 @@
from .jackknife import (
bias,
)
from .utils import (
from ..utils import (
add_to_Cls,
impose_correlation,
get_cl,
Expand Down
51 changes: 34 additions & 17 deletions heracles/unmixing.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import numpy as np
from .result import truncated
from .transforms import cl2corr, corr2cl
from .utils import get_cl

try:
from copy import replace
Expand All @@ -27,45 +28,61 @@
from dataclasses import replace


def natural_unmixing(d, m, x0=-2, k=50, patch_hole=True, lmax=None):
def natural_unmixing(d, m, fields, x0=-2, k=50, patch_hole=True, lmax=None):
"""
Natural unmixing of the data Cl.
Args:
d: Data Cl
m: mask Cl
fields: list of fields
patch_hole: If True, apply the patch hole correction
Returns:
corr_d: Corrected Cl
"""
wm = {}
m_keys = list(m.keys())
for m_key in m_keys:
_m = m[m_key].array
_wm = cl2corr(_m).T[0]
if patch_hole:
_wm *= logistic(np.log10(abs(_wm)), x0=x0, k=k)
wm[m_key] = _wm
return _natural_unmixing(d, wm, lmax=lmax)
wm[m_key] = replace(m[m_key], array=_wm)
return _natural_unmixing(d, wm, fields, lmax=lmax)


def _natural_unmixing(d, wm, lmax=None):
def _natural_unmixing(d, wm, fields, lmax=None):
"""
Natural unmixing of the data Cl.
Args:
d: Data Cl
m: mask cls
wm: mask correlation function
fields: list of fields
patch_hole: If True, apply the patch hole correction
Returns:
corr_d: Corrected Cl
"""
corr_d = {}
d_keys = list(d.keys())
wm_keys = list(wm.keys())
for d_key, wm_key in zip(d_keys, wm_keys):
a, b, i, j = d_key
masks = {}
for key, field in fields.items():
if field.mask is not None:
masks[key] = field.mask

for key in d.keys():
a, b, i, j = key
m_key = (masks[a], masks[b], i, j)
_wm = get_cl(m_key, wm)
_d = d[key]
s1, s2 = _d.spin
if lmax is None:
*_, lmax = d[d_key].shape
s1, s2 = d[d_key].spin
_d = np.atleast_2d(d[d_key])
_wm = wm[wm_key]
lmax_mask = len(wm[wm_key])
*_, lmax = _d.shape
lmax_mask = len(_wm.array)
# Grab metadata
dtype = _d.array.dtype
# pad cls
_d = np.atleast_2d(_d.array)
pad_width = [(0, 0)] * _d.ndim # no padding for other dims
pad_width[-1] = (0, lmax_mask - lmax) # pad only last dim
_d = np.pad(_d, pad_width, mode="constant", constant_values=0)
# Grab metadata
dtype = d[d_key].array.dtype
if (s1 != 0) and (s2 != 0):
__d = np.array(
[
Expand Down Expand Up @@ -109,7 +126,7 @@ def _natural_unmixing(d, wm, lmax=None):
_corr_d = np.squeeze(_corr_d)
# Add metadata back
_corr_d = np.array(list(_corr_d), dtype=dtype)
corr_d[d_key] = replace(d[d_key], array=_corr_d)
corr_d[key] = replace(d[key], array=_corr_d)
# truncate to lmax
corr_d = truncated(corr_d, lmax)
return corr_d
Expand Down
File renamed without changes.
12 changes: 6 additions & 6 deletions tests/test_dices.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,17 +91,17 @@ def test_get_delete2_fsky(jk_maps, njk):
assert alpha == pytest.approx(_alpha, rel=1e-1)


def test_mask_correction(cls0, mls0):
def test_mask_correction(cls0, mls0, fields):
alphas = dices.mask_correction(mls0, mls0)
_cls = heracles.unmixing._natural_unmixing(cls0, alphas)
_cls = heracles.unmixing._natural_unmixing(cls0, alphas, fields)
for key in list(cls0.keys()):
cl = cls0[key].array
_cl = _cls[key].array
assert np.isclose(cl[2:], _cl[2:]).all()


def test_polspice(cls0):
from heracles.dices.utils import get_cl
from heracles.utils import get_cl

cls = np.array(
[
Expand Down Expand Up @@ -254,7 +254,7 @@ def test_shrinkage(cov_jk):


def test_flatten_cls(nside, cls0):
from heracles.dices.utils import _flatten, flatten
from heracles.utils import _flatten, flatten

# Check that the individual blocks are flattened correctly
for key in cls0.keys():
Expand All @@ -272,7 +272,7 @@ def test_flatten_cls(nside, cls0):


def test_flatten_cov(nside, cov_jk):
from heracles.dices.utils import _flatten, flatten
from heracles.utils import _flatten, flatten

# Check that the individual blocks are flattened correctly
for key in cov_jk.keys():
Expand Down Expand Up @@ -307,7 +307,7 @@ def test_gauss_cov(cls0, cov_jk):
# We want to undo the bias that we will add later
# for an easy check
bias = dices.jackknife.bias(_cls0)
_cls0 = dices.utils.sub_to_Cls(_cls0, bias)
_cls0 = heracles.utils.sub_to_Cls(_cls0, bias)

# Compute Gaussian covariance
gauss_cov = dices.gaussian_covariance(_cls0)
Expand Down
18 changes: 9 additions & 9 deletions tests/test_utils.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,15 @@
import numpy as np
import heracles
import heracles.dices as dices
from heracles import utils


def test_add_to_cls():
cls = {}
cls[("P", "P", 1, 1)] = heracles.Result(np.ones(10))
x = {}
x[("P", "P", 1, 1)] = -1.0
_cls = dices.utils.add_to_Cls(cls, x)
__cls = dices.utils.sub_to_Cls(_cls, x)
_cls = utils.add_to_Cls(cls, x)
__cls = utils.sub_to_Cls(_cls, x)
for key in list(cls.keys()):
assert np.all(_cls[key] == np.zeros(10))
assert np.all(cls[key].__array__() == __cls[key].__array__())
Expand All @@ -25,9 +25,9 @@ def test_get_cl():
np.array([[a, ab], [ba, a]]), spin=(2, 2)
)

cl = dices.utils.get_cl(("POS", "SHE", 1, 1), cls)
cl = utils.get_cl(("POS", "SHE", 1, 1), cls)
assert np.all(cl == np.array([a, a]))
cl = dices.utils.get_cl(("SHE", "SHE", 1, 2), cls)
cl = utils.get_cl(("SHE", "SHE", 1, 2), cls)
assert np.all(cl == np.array([[a, ba], [ab, a]]))


Expand All @@ -37,11 +37,11 @@ def test_expand_squeeze_spin0_dims(cls0, cov_jk):
s1, s2 = cl.spin
dof1 = 1 if s1 == 0 else 2
dof2 = 1 if s2 == 0 else 2
_cl = dices.utils.expand_spin0_dims(cl)
_cl = utils.expand_spin0_dims(cl)
(_ax,) = _cl.axis
assert _cl.shape == (dof1, dof2, cl.shape[-1])
assert _ax == 2
__cl = dices.utils.squeeze_spin0_dims(_cl)
__cl = utils.squeeze_spin0_dims(_cl)
assert np.all(cl.__array__() == __cl.__array__())
assert cl.axis == __cl.axis

Expand All @@ -52,7 +52,7 @@ def test_expand_squeeze_spin0_dims(cls0, cov_jk):
dof_b1 = 1 if sb1 == 0 else 2
dof_a2 = 1 if sa2 == 0 else 2
dof_b2 = 1 if sb2 == 0 else 2
_cov = dices.utils.expand_spin0_dims(cov)
_cov = utils.expand_spin0_dims(cov)
_ax1, _ax2 = _cov.axis
assert _cov.shape == (
dof_a1,
Expand All @@ -63,6 +63,6 @@ def test_expand_squeeze_spin0_dims(cls0, cov_jk):
cov.shape[-1],
)
assert (_ax1, _ax2) == (4, 5)
__cov = dices.utils.squeeze_spin0_dims(_cov)
__cov = utils.squeeze_spin0_dims(_cov)
assert np.all(cov.__array__() == __cov.__array__())
assert cov.axis == __cov.axis