From b480ba756c55b73a6d6056923cec86fa2e5e76ca Mon Sep 17 00:00:00 2001 From: flaport Date: Wed, 11 Mar 2026 10:34:42 +0100 Subject: [PATCH 1/2] Fix EME r2l cascade and lossy-unitarity projection --- src/meow/eme/common.py | 6 ++-- src/meow/eme/propagate.py | 3 +- src/tests/test_eme_fixes.py | 62 +++++++++++++++++++++++++++++++++++++ 3 files changed, 68 insertions(+), 3 deletions(-) create mode 100644 src/tests/test_eme_fixes.py diff --git a/src/meow/eme/common.py b/src/meow/eme/common.py index 285df1ba..206d178c 100644 --- a/src/meow/eme/common.py +++ b/src/meow/eme/common.py @@ -110,8 +110,10 @@ def compute_interface_s_matrix( # enforce S@S.H is diagonal: HACK! if enforce_lossy_unitarity: - U, s, V = np.linalg.svd(S) - S = np.diag(s) @ U @ V + # Project to a contractive matrix by clipping singular values to <= 1. + U, s, Vh = np.linalg.svd(S, full_matrices=False) + s_clipped = np.minimum(s, 1.0) + S = U @ np.diag(s_clipped) @ Vh # ensure reciprocity: HACK? if enforce_reciprocity: diff --git a/src/meow/eme/propagate.py b/src/meow/eme/propagate.py index 429466f4..31372957 100644 --- a/src/meow/eme/propagate.py +++ b/src/meow/eme/propagate.py @@ -81,7 +81,8 @@ def r2l_matrices( """Right to left S-matrices.""" Ss = [pairs[-1]] - for p in pairs[-1::-1]: + # Already seeded with the rightmost pair; only prepend remaining pairs. + for p in pairs[-2::-1]: Ss.append(_connect_two(p, Ss[-1], sax_backend)) return Ss[::-1] diff --git a/src/tests/test_eme_fixes.py b/src/tests/test_eme_fixes.py new file mode 100644 index 00000000..98ce8de2 --- /dev/null +++ b/src/tests/test_eme_fixes.py @@ -0,0 +1,62 @@ +import numpy as np + +import meow.eme.common as eme_common +import meow.eme.propagate as eme_propagate + + +def test_r2l_matrices_does_not_duplicate_last_pair(monkeypatch): + calls: list[tuple[str, str]] = [] + + def fake_connect_two(l, r, sax_backend): # noqa: ANN001,ARG001 + calls.append((l, r)) + return f"({l}>{r})" + + monkeypatch.setattr(eme_propagate, "_connect_two", fake_connect_two) + + pairs = ["p0", "p1", "p2"] + matrices = eme_propagate.r2l_matrices(pairs, sax_backend="default") + + assert calls == [("p1", "p2"), ("p0", "(p1>p2)")] + assert matrices == ["(p0>(p1>p2))", "(p1>p2)", "p2"] + + +def test_enforce_lossy_unitarity_projects_to_contractive_matrix(monkeypatch): + left = object() + right = object() + + def fake_inner_product_conj(a, b): # noqa: ANN001 + if a is left and b is left: + return 1.0 + if a is right and b is right: + return 1.0 + if a is left and b is right: + return 0.01 + if a is right and b is left: + return 10.0 + msg = "unexpected mode pair" + raise AssertionError(msg) + + monkeypatch.setattr(eme_common, "inner_product_conj", fake_inner_product_conj) + + S_no, _ = eme_common.compute_interface_s_matrix( + [left], + [right], + enforce_lossy_unitarity=False, + ignore_warnings=False, + ) + S_yes, _ = eme_common.compute_interface_s_matrix( + [left], + [right], + enforce_lossy_unitarity=True, + ignore_warnings=False, + ) + + s_no = np.linalg.svd(S_no, compute_uv=False) + assert float(s_no.max()) > 1.0 + + U, s, Vh = np.linalg.svd(S_no, full_matrices=False) + expected = U @ np.diag(np.minimum(s, 1.0)) @ Vh + assert np.allclose(S_yes, expected) + + s_yes = np.linalg.svd(S_yes, compute_uv=False) + assert float(s_yes.max()) <= 1.0 + 1e-12 From 40c3b25bedbee9935435257f6f9cffefe03c2bc0 Mon Sep 17 00:00:00 2001 From: flaport Date: Wed, 11 Mar 2026 10:39:03 +0100 Subject: [PATCH 2/2] Fix test typing for pre-commit checks --- src/tests/test_eme_fixes.py | 25 +++++++++++++++++-------- 1 file changed, 17 insertions(+), 8 deletions(-) diff --git a/src/tests/test_eme_fixes.py b/src/tests/test_eme_fixes.py index 98ce8de2..a514f947 100644 --- a/src/tests/test_eme_fixes.py +++ b/src/tests/test_eme_fixes.py @@ -1,30 +1,39 @@ +from typing import Any, cast + import numpy as np +import pytest +import sax import meow.eme.common as eme_common import meow.eme.propagate as eme_propagate +from meow.mode import Mode -def test_r2l_matrices_does_not_duplicate_last_pair(monkeypatch): +def test_r2l_matrices_does_not_duplicate_last_pair( + monkeypatch: pytest.MonkeyPatch, +) -> None: calls: list[tuple[str, str]] = [] - def fake_connect_two(l, r, sax_backend): # noqa: ANN001,ARG001 + def fake_connect_two(l: Any, r: Any, sax_backend: sax.Backend) -> str: # noqa: ARG001 calls.append((l, r)) return f"({l}>{r})" monkeypatch.setattr(eme_propagate, "_connect_two", fake_connect_two) - pairs = ["p0", "p1", "p2"] - matrices = eme_propagate.r2l_matrices(pairs, sax_backend="default") + pairs = cast(list[sax.STypeMM], ["p0", "p1", "p2"]) + matrices = eme_propagate.r2l_matrices(pairs, sax_backend="klu") assert calls == [("p1", "p2"), ("p0", "(p1>p2)")] assert matrices == ["(p0>(p1>p2))", "(p1>p2)", "p2"] -def test_enforce_lossy_unitarity_projects_to_contractive_matrix(monkeypatch): - left = object() - right = object() +def test_enforce_lossy_unitarity_projects_to_contractive_matrix( + monkeypatch: pytest.MonkeyPatch, +) -> None: + left = cast(Mode, object()) + right = cast(Mode, object()) - def fake_inner_product_conj(a, b): # noqa: ANN001 + def fake_inner_product_conj(a: Any, b: Any) -> float: if a is left and b is left: return 1.0 if a is right and b is right: