1313# limitations under the License.
1414
1515import abc
16+ import enum
1617import functools as ft
1718import math
1819import warnings
@@ -1942,14 +1943,35 @@ def _(operator):
19421943 def _ (operator , check = check ):
19431944 return check (operator .primal )
19441945
1946+ @check .register (AuxLinearOperator )
1947+ def _ (operator , check = check ):
1948+ return check (operator .operator )
1949+
1950+
1951+ # Scaling/negating preserves these structural properties
1952+ for check in (
1953+ is_symmetric ,
1954+ is_diagonal ,
1955+ is_lower_triangular ,
1956+ is_upper_triangular ,
1957+ is_tridiagonal ,
1958+ ):
1959+
19451960 @check .register (MulLinearOperator )
19461961 @check .register (NegLinearOperator )
19471962 @check .register (DivLinearOperator )
1948- @check .register (AuxLinearOperator )
19491963 def _ (operator , check = check ):
19501964 return check (operator .operator )
19511965
19521966
1967+ # has_unit_diagonal is NOT preserved by scaling or negation
1968+ @has_unit_diagonal .register (MulLinearOperator )
1969+ @has_unit_diagonal .register (NegLinearOperator )
1970+ @has_unit_diagonal .register (DivLinearOperator )
1971+ def _ (operator ):
1972+ return False
1973+
1974+
19531975for check in (is_positive_semidefinite , is_negative_semidefinite ):
19541976
19551977 @check .register (TangentLinearOperator )
@@ -1960,20 +1982,91 @@ def _(operator):
19601982 "Please open a GitHub issue: https://github.com/google/lineax"
19611983 )
19621984
1963- @check .register (MulLinearOperator )
1964- @check .register (DivLinearOperator )
1965- def _ (operator ):
1966- return False # play it safe, no way to tell.
1967-
1968- @check .register (NegLinearOperator )
1969- def _ (operator , check = check ):
1970- return not check (operator .operator )
1971-
19721985 @check .register (AuxLinearOperator )
19731986 def _ (operator , check = check ):
19741987 return check (operator .operator )
19751988
19761989
1990+ class _ScalarSign (enum .Enum ):
1991+ positive = enum .auto ()
1992+ negative = enum .auto ()
1993+ zero = enum .auto ()
1994+ unknown = enum .auto ()
1995+
1996+
1997+ def _scalar_sign (scalar ) -> _ScalarSign :
1998+ """Returns the sign of a scalar, or unknown for JAX tracers."""
1999+ if isinstance (scalar , (int , float , np .ndarray , np .generic )):
2000+ scalar = float (scalar )
2001+ if scalar > 0 :
2002+ return _ScalarSign .positive
2003+ elif scalar < 0 :
2004+ return _ScalarSign .negative
2005+ else :
2006+ return _ScalarSign .zero
2007+ else :
2008+ return _ScalarSign .unknown
2009+
2010+
2011+ # PSD/NSD for MulLinearOperator: depends on sign of scalar
2012+ # Zero scalar gives zero matrix which is both PSD and NSD
2013+ @is_positive_semidefinite .register (MulLinearOperator )
2014+ def _ (operator ):
2015+ sign = _scalar_sign (operator .scalar )
2016+ if sign is _ScalarSign .positive :
2017+ return is_positive_semidefinite (operator .operator )
2018+ elif sign is _ScalarSign .negative :
2019+ return is_negative_semidefinite (operator .operator )
2020+ elif sign is _ScalarSign .zero :
2021+ return True # zero matrix is PSD
2022+ return False
2023+
2024+
2025+ @is_negative_semidefinite .register (MulLinearOperator )
2026+ def _ (operator ):
2027+ sign = _scalar_sign (operator .scalar )
2028+ if sign is _ScalarSign .positive :
2029+ return is_negative_semidefinite (operator .operator )
2030+ elif sign is _ScalarSign .negative :
2031+ return is_positive_semidefinite (operator .operator )
2032+ elif sign is _ScalarSign .zero :
2033+ return True # zero matrix is NSD
2034+ return False
2035+
2036+
2037+ # PSD/NSD for DivLinearOperator: depends on sign of scalar
2038+ # Zero scalar is division by zero - return False (conservative)
2039+ @is_positive_semidefinite .register (DivLinearOperator )
2040+ def _ (operator ):
2041+ sign = _scalar_sign (operator .scalar )
2042+ if sign is _ScalarSign .positive :
2043+ return is_positive_semidefinite (operator .operator )
2044+ elif sign is _ScalarSign .negative :
2045+ return is_negative_semidefinite (operator .operator )
2046+ return False
2047+
2048+
2049+ @is_negative_semidefinite .register (DivLinearOperator )
2050+ def _ (operator ):
2051+ sign = _scalar_sign (operator .scalar )
2052+ if sign is _ScalarSign .positive :
2053+ return is_negative_semidefinite (operator .operator )
2054+ elif sign is _ScalarSign .negative :
2055+ return is_positive_semidefinite (operator .operator )
2056+ return False
2057+
2058+
2059+ # PSD/NSD for NegLinearOperator: negation swaps PSD <-> NSD
2060+ @is_positive_semidefinite .register (NegLinearOperator )
2061+ def _ (operator ):
2062+ return is_negative_semidefinite (operator .operator )
2063+
2064+
2065+ @is_negative_semidefinite .register (NegLinearOperator )
2066+ def _ (operator ):
2067+ return is_positive_semidefinite (operator .operator )
2068+
2069+
19772070for check , tag in (
19782071 (is_symmetric , symmetric_tag ),
19792072 (is_diagonal , diagonal_tag ),
@@ -2010,21 +2103,42 @@ def _(operator):
20102103 return False
20112104
20122105
2106+ # These properties ARE preserved under composition
20132107for check in (
2014- is_symmetric ,
20152108 is_diagonal ,
20162109 is_lower_triangular ,
20172110 is_upper_triangular ,
2018- is_positive_semidefinite ,
2019- is_negative_semidefinite ,
2020- is_tridiagonal ,
20212111):
20222112
20232113 @check .register (ComposedLinearOperator )
20242114 def _ (operator , check = check ):
20252115 return check (operator .operator1 ) and check (operator .operator2 )
20262116
20272117
2118+ # is_symmetric: A@B is symmetric only if A and B commute. Diagonal matrices commute.
2119+ @is_symmetric .register (ComposedLinearOperator )
2120+ def _ (operator ):
2121+ return is_diagonal (operator .operator1 ) and is_diagonal (operator .operator2 )
2122+
2123+
2124+ # is_tridiagonal: tridiagonal @ tridiagonal = pentadiagonal, but
2125+ # tridiagonal @ diagonal = tridiagonal and diagonal @ tridiagonal = tridiagonal
2126+ @is_tridiagonal .register (ComposedLinearOperator )
2127+ def _ (operator ):
2128+ if is_diagonal (operator .operator1 ):
2129+ return is_tridiagonal (operator .operator2 )
2130+ if is_diagonal (operator .operator2 ):
2131+ return is_tridiagonal (operator .operator1 )
2132+ return False
2133+
2134+
2135+ # PSD/NSD: not preserved under composition in general.
2136+ @is_positive_semidefinite .register (ComposedLinearOperator )
2137+ @is_negative_semidefinite .register (ComposedLinearOperator )
2138+ def _ (operator ):
2139+ return False
2140+
2141+
20282142@has_unit_diagonal .register (ComposedLinearOperator )
20292143def _ (operator ):
20302144 a = is_diagonal (operator )
0 commit comments