Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 8 additions & 5 deletions heracles/twopoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,12 @@
from .progress import NoProgress, Progress
from .result import Result, binned, truncated

try:
from copy import replace
except ImportError:
# Python < 3.13
from dataclasses import replace

if TYPE_CHECKING:
from collections.abc import Mapping, MutableMapping

Expand Down Expand Up @@ -425,7 +431,6 @@ def invert_mixing_matrix(
_M = value.array
s1, s2 = value.spin
*_, _n, _m = _M.shape
new_ell = np.arange(_m)

with progress.task(f"invert {key}"):
if (s1 != 0) and (s2 != 0):
Expand All @@ -440,8 +445,7 @@ def invert_mixing_matrix(
else:
_inv_M = np.linalg.pinv(_M, rcond=rtol)

inv_M[key] = Result(_inv_M, axis=value.axis, ell=new_ell)

inv_M[key] = replace(M[key], array=_inv_M)
return inv_M


Expand All @@ -460,7 +464,6 @@ def apply_mixing_matrix(d, M, lmax=None):
*_, lmax = d[key].shape
dtype = d[key].array.dtype
ell_mask = M[key].ell
axis = d[key].axis
s1, s2 = d[key].spin
_d = np.atleast_2d(d[key].array)
_M = M[key].array
Expand All @@ -476,7 +479,7 @@ def apply_mixing_matrix(d, M, lmax=None):
_corr_d.append(_M @ cl)
_corr_d = np.squeeze(_corr_d)
_corr_d = np.array(list(_corr_d), dtype=dtype)
corr_d[key] = Result(_corr_d, axis=axis, ell=ell_mask)
corr_d[key] = replace(d[key], array=_corr_d, ell=ell_mask)
# truncate
corr_d = truncated(corr_d, lmax)
return corr_d
9 changes: 9 additions & 0 deletions tests/test_twopoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -374,6 +374,10 @@ def test_inverting_mixing_matrices():
for key in mms.keys():
*_, n, m = mms[key].shape
*_, _n, _m = inv_mms[key].shape
s1, s2 = mms[key].spin
_s1, _s2 = inv_mms[key].spin
assert s1 == _s1
assert s2 == _s2
assert n == _m
assert m == _n

Expand All @@ -396,6 +400,11 @@ def test_inverting_mixing_matrices():

# test application of mixing matrices
mixed_cls = apply_mixing_matrix(cls, inv_mms)
for key in cls:
s1, s2 = cls[key].spin
_s1, _s2 = mixed_cls[key].spin
assert s1 == _s1
assert s2 == _s2
assert mixed_cls.keys() == cls.keys()
for key in mixed_cls:
(n,) = mixed_cls[key].axis
Expand Down