Skip to content

Commit 35433a4

Browse files
committed
fix derived tag check rules
1 parent b16b03e commit 35433a4

File tree

1 file changed

+119
-14
lines changed

1 file changed

+119
-14
lines changed

lineax/_operator.py

Lines changed: 119 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1942,14 +1942,35 @@ def _(operator):
19421942
def _(operator, check=check):
19431943
return check(operator.primal)
19441944

1945+
@check.register(AuxLinearOperator)
1946+
def _(operator, check=check):
1947+
return check(operator.operator)
1948+
1949+
1950+
# Scaling/negating preserves these structural properties
1951+
for check in (
1952+
is_symmetric,
1953+
is_diagonal,
1954+
is_lower_triangular,
1955+
is_upper_triangular,
1956+
is_tridiagonal,
1957+
):
1958+
19451959
@check.register(MulLinearOperator)
19461960
@check.register(NegLinearOperator)
19471961
@check.register(DivLinearOperator)
1948-
@check.register(AuxLinearOperator)
19491962
def _(operator, check=check):
19501963
return check(operator.operator)
19511964

19521965

1966+
# has_unit_diagonal is NOT preserved by scaling or negation
1967+
@has_unit_diagonal.register(MulLinearOperator)
1968+
@has_unit_diagonal.register(NegLinearOperator)
1969+
@has_unit_diagonal.register(DivLinearOperator)
1970+
def _(operator):
1971+
return False
1972+
1973+
19531974
for check in (is_positive_semidefinite, is_negative_semidefinite):
19541975

19551976
@check.register(TangentLinearOperator)
@@ -1960,20 +1981,83 @@ def _(operator):
19601981
"Please open a GitHub issue: https://github.com/google/lineax"
19611982
)
19621983

1963-
@check.register(MulLinearOperator)
1964-
@check.register(DivLinearOperator)
1965-
def _(operator):
1966-
return False # play it safe, no way to tell.
1967-
1968-
@check.register(NegLinearOperator)
1969-
def _(operator, check=check):
1970-
return not check(operator.operator)
1971-
19721984
@check.register(AuxLinearOperator)
19731985
def _(operator, check=check):
19741986
return check(operator.operator)
19751987

19761988

1989+
def _scalar_sign(scalar) -> int | None:
1990+
"""Returns 1 if positive, -1 if negative, 0 if zero, None if unknown (traced)."""
1991+
try:
1992+
if scalar > 0:
1993+
return 1
1994+
elif scalar < 0:
1995+
return -1
1996+
else:
1997+
return 0
1998+
except Exception:
1999+
return None
2000+
2001+
2002+
# PSD/NSD for MulLinearOperator: depends on sign of scalar
2003+
# Zero scalar gives zero matrix which is both PSD and NSD
2004+
@is_positive_semidefinite.register(MulLinearOperator)
2005+
def _(operator):
2006+
sign = _scalar_sign(operator.scalar)
2007+
if sign == 1:
2008+
return is_positive_semidefinite(operator.operator)
2009+
elif sign == -1:
2010+
return is_negative_semidefinite(operator.operator)
2011+
elif sign == 0:
2012+
return True # zero matrix is PSD
2013+
return False
2014+
2015+
2016+
@is_negative_semidefinite.register(MulLinearOperator)
2017+
def _(operator):
2018+
sign = _scalar_sign(operator.scalar)
2019+
if sign == 1:
2020+
return is_negative_semidefinite(operator.operator)
2021+
elif sign == -1:
2022+
return is_positive_semidefinite(operator.operator)
2023+
elif sign == 0:
2024+
return True # zero matrix is NSD
2025+
return False
2026+
2027+
2028+
# PSD/NSD for DivLinearOperator: depends on sign of scalar
2029+
# Zero scalar is division by zero - return False (conservative)
2030+
@is_positive_semidefinite.register(DivLinearOperator)
2031+
def _(operator):
2032+
sign = _scalar_sign(operator.scalar)
2033+
if sign == 1:
2034+
return is_positive_semidefinite(operator.operator)
2035+
elif sign == -1:
2036+
return is_negative_semidefinite(operator.operator)
2037+
return False
2038+
2039+
2040+
@is_negative_semidefinite.register(DivLinearOperator)
2041+
def _(operator):
2042+
sign = _scalar_sign(operator.scalar)
2043+
if sign == 1:
2044+
return is_negative_semidefinite(operator.operator)
2045+
elif sign == -1:
2046+
return is_positive_semidefinite(operator.operator)
2047+
return False
2048+
2049+
2050+
# PSD/NSD for NegLinearOperator: negation swaps PSD <-> NSD
2051+
@is_positive_semidefinite.register(NegLinearOperator)
2052+
def _(operator):
2053+
return is_negative_semidefinite(operator.operator)
2054+
2055+
2056+
@is_negative_semidefinite.register(NegLinearOperator)
2057+
def _(operator):
2058+
return is_positive_semidefinite(operator.operator)
2059+
2060+
19772061
for check, tag in (
19782062
(is_symmetric, symmetric_tag),
19792063
(is_diagonal, diagonal_tag),
@@ -2010,21 +2094,42 @@ def _(operator):
20102094
return False
20112095

20122096

2097+
# These properties ARE preserved under composition
20132098
for check in (
2014-
is_symmetric,
20152099
is_diagonal,
20162100
is_lower_triangular,
20172101
is_upper_triangular,
2018-
is_positive_semidefinite,
2019-
is_negative_semidefinite,
2020-
is_tridiagonal,
20212102
):
20222103

20232104
@check.register(ComposedLinearOperator)
20242105
def _(operator, check=check):
20252106
return check(operator.operator1) and check(operator.operator2)
20262107

20272108

2109+
# is_symmetric: A@B is symmetric only if A and B commute. Diagonal matrices commute.
2110+
@is_symmetric.register(ComposedLinearOperator)
2111+
def _(operator):
2112+
return is_diagonal(operator.operator1) and is_diagonal(operator.operator2)
2113+
2114+
2115+
# is_tridiagonal: tridiagonal @ tridiagonal = pentadiagonal, but
2116+
# tridiagonal @ diagonal = tridiagonal and diagonal @ tridiagonal = tridiagonal
2117+
@is_tridiagonal.register(ComposedLinearOperator)
2118+
def _(operator):
2119+
if is_diagonal(operator.operator1):
2120+
return is_tridiagonal(operator.operator2)
2121+
if is_diagonal(operator.operator2):
2122+
return is_tridiagonal(operator.operator1)
2123+
return False
2124+
2125+
2126+
# PSD/NSD: not preserved under composition in general.
2127+
@is_positive_semidefinite.register(ComposedLinearOperator)
2128+
@is_negative_semidefinite.register(ComposedLinearOperator)
2129+
def _(operator):
2130+
return False
2131+
2132+
20282133
@has_unit_diagonal.register(ComposedLinearOperator)
20292134
def _(operator):
20302135
a = is_diagonal(operator)

0 commit comments

Comments
 (0)