Skip to content

Commit a61917e

Browse files
authored
fix derived tag check rules (#192)
* fix derived tag check rules * simpler docstring for _scalar_sign * incorporate review comments on scalar sign
1 parent b16b03e commit a61917e

File tree

1 file changed

+128
-14
lines changed

1 file changed

+128
-14
lines changed

lineax/_operator.py

Lines changed: 128 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
# limitations under the License.
1414

1515
import abc
16+
import enum
1617
import functools as ft
1718
import math
1819
import warnings
@@ -1942,14 +1943,35 @@ def _(operator):
19421943
def _(operator, check=check):
19431944
return check(operator.primal)
19441945

1946+
@check.register(AuxLinearOperator)
1947+
def _(operator, check=check):
1948+
return check(operator.operator)
1949+
1950+
1951+
# Scaling/negating preserves these structural properties
1952+
for check in (
1953+
is_symmetric,
1954+
is_diagonal,
1955+
is_lower_triangular,
1956+
is_upper_triangular,
1957+
is_tridiagonal,
1958+
):
1959+
19451960
@check.register(MulLinearOperator)
19461961
@check.register(NegLinearOperator)
19471962
@check.register(DivLinearOperator)
1948-
@check.register(AuxLinearOperator)
19491963
def _(operator, check=check):
19501964
return check(operator.operator)
19511965

19521966

1967+
# has_unit_diagonal is NOT preserved by scaling or negation
1968+
@has_unit_diagonal.register(MulLinearOperator)
1969+
@has_unit_diagonal.register(NegLinearOperator)
1970+
@has_unit_diagonal.register(DivLinearOperator)
1971+
def _(operator):
1972+
return False
1973+
1974+
19531975
for check in (is_positive_semidefinite, is_negative_semidefinite):
19541976

19551977
@check.register(TangentLinearOperator)
@@ -1960,20 +1982,91 @@ def _(operator):
19601982
"Please open a GitHub issue: https://github.com/google/lineax"
19611983
)
19621984

1963-
@check.register(MulLinearOperator)
1964-
@check.register(DivLinearOperator)
1965-
def _(operator):
1966-
return False # play it safe, no way to tell.
1967-
1968-
@check.register(NegLinearOperator)
1969-
def _(operator, check=check):
1970-
return not check(operator.operator)
1971-
19721985
@check.register(AuxLinearOperator)
19731986
def _(operator, check=check):
19741987
return check(operator.operator)
19751988

19761989

1990+
class _ScalarSign(enum.Enum):
1991+
positive = enum.auto()
1992+
negative = enum.auto()
1993+
zero = enum.auto()
1994+
unknown = enum.auto()
1995+
1996+
1997+
def _scalar_sign(scalar) -> _ScalarSign:
1998+
"""Returns the sign of a scalar, or unknown for JAX tracers."""
1999+
if isinstance(scalar, (int, float, np.ndarray, np.generic)):
2000+
scalar = float(scalar)
2001+
if scalar > 0:
2002+
return _ScalarSign.positive
2003+
elif scalar < 0:
2004+
return _ScalarSign.negative
2005+
else:
2006+
return _ScalarSign.zero
2007+
else:
2008+
return _ScalarSign.unknown
2009+
2010+
2011+
# PSD/NSD for MulLinearOperator: depends on sign of scalar
2012+
# Zero scalar gives zero matrix which is both PSD and NSD
2013+
@is_positive_semidefinite.register(MulLinearOperator)
2014+
def _(operator):
2015+
sign = _scalar_sign(operator.scalar)
2016+
if sign is _ScalarSign.positive:
2017+
return is_positive_semidefinite(operator.operator)
2018+
elif sign is _ScalarSign.negative:
2019+
return is_negative_semidefinite(operator.operator)
2020+
elif sign is _ScalarSign.zero:
2021+
return True # zero matrix is PSD
2022+
return False
2023+
2024+
2025+
@is_negative_semidefinite.register(MulLinearOperator)
2026+
def _(operator):
2027+
sign = _scalar_sign(operator.scalar)
2028+
if sign is _ScalarSign.positive:
2029+
return is_negative_semidefinite(operator.operator)
2030+
elif sign is _ScalarSign.negative:
2031+
return is_positive_semidefinite(operator.operator)
2032+
elif sign is _ScalarSign.zero:
2033+
return True # zero matrix is NSD
2034+
return False
2035+
2036+
2037+
# PSD/NSD for DivLinearOperator: depends on sign of scalar
2038+
# Zero scalar is division by zero - return False (conservative)
2039+
@is_positive_semidefinite.register(DivLinearOperator)
2040+
def _(operator):
2041+
sign = _scalar_sign(operator.scalar)
2042+
if sign is _ScalarSign.positive:
2043+
return is_positive_semidefinite(operator.operator)
2044+
elif sign is _ScalarSign.negative:
2045+
return is_negative_semidefinite(operator.operator)
2046+
return False
2047+
2048+
2049+
@is_negative_semidefinite.register(DivLinearOperator)
2050+
def _(operator):
2051+
sign = _scalar_sign(operator.scalar)
2052+
if sign is _ScalarSign.positive:
2053+
return is_negative_semidefinite(operator.operator)
2054+
elif sign is _ScalarSign.negative:
2055+
return is_positive_semidefinite(operator.operator)
2056+
return False
2057+
2058+
2059+
# PSD/NSD for NegLinearOperator: negation swaps PSD <-> NSD
2060+
@is_positive_semidefinite.register(NegLinearOperator)
2061+
def _(operator):
2062+
return is_negative_semidefinite(operator.operator)
2063+
2064+
2065+
@is_negative_semidefinite.register(NegLinearOperator)
2066+
def _(operator):
2067+
return is_positive_semidefinite(operator.operator)
2068+
2069+
19772070
for check, tag in (
19782071
(is_symmetric, symmetric_tag),
19792072
(is_diagonal, diagonal_tag),
@@ -2010,21 +2103,42 @@ def _(operator):
20102103
return False
20112104

20122105

2106+
# These properties ARE preserved under composition
20132107
for check in (
2014-
is_symmetric,
20152108
is_diagonal,
20162109
is_lower_triangular,
20172110
is_upper_triangular,
2018-
is_positive_semidefinite,
2019-
is_negative_semidefinite,
2020-
is_tridiagonal,
20212111
):
20222112

20232113
@check.register(ComposedLinearOperator)
20242114
def _(operator, check=check):
20252115
return check(operator.operator1) and check(operator.operator2)
20262116

20272117

2118+
# is_symmetric: A@B is symmetric only if A and B commute. Diagonal matrices commute.
2119+
@is_symmetric.register(ComposedLinearOperator)
2120+
def _(operator):
2121+
return is_diagonal(operator.operator1) and is_diagonal(operator.operator2)
2122+
2123+
2124+
# is_tridiagonal: tridiagonal @ tridiagonal = pentadiagonal, but
2125+
# tridiagonal @ diagonal = tridiagonal and diagonal @ tridiagonal = tridiagonal
2126+
@is_tridiagonal.register(ComposedLinearOperator)
2127+
def _(operator):
2128+
if is_diagonal(operator.operator1):
2129+
return is_tridiagonal(operator.operator2)
2130+
if is_diagonal(operator.operator2):
2131+
return is_tridiagonal(operator.operator1)
2132+
return False
2133+
2134+
2135+
# PSD/NSD: not preserved under composition in general.
2136+
@is_positive_semidefinite.register(ComposedLinearOperator)
2137+
@is_negative_semidefinite.register(ComposedLinearOperator)
2138+
def _(operator):
2139+
return False
2140+
2141+
20282142
@has_unit_diagonal.register(ComposedLinearOperator)
20292143
def _(operator):
20302144
a = is_diagonal(operator)

0 commit comments

Comments
 (0)