@@ -63,16 +63,16 @@ def mahalanobis_norm_relative(self, u, /, rv):
6363 return np .reshape (np .abs (mahalanobis ) / np .sqrt (rv .mean .size ), ())
6464
6565 def logpdf (self , u , / , rv ):
66- # The cholesky factor is triangular, so we compute a cheap slogdet.
67- diagonal = linalg .diagonal_along_axis (rv . cholesky , axis1 = - 1 , axis2 = - 2 )
66+ cholesky = linalg . qr_r ( rv . cholesky . T ). T
67+ diagonal = linalg .diagonal_along_axis (cholesky , axis1 = - 1 , axis2 = - 2 )
6868 slogdet = np .sum (np .log (np .abs (diagonal )))
6969
7070 dx = u - rv .mean
71- residual_white = linalg .solve_triangular (rv . cholesky . T , dx , trans = "T" )
72- x1 = linalg .vector_dot (residual_white , residual_white )
73- x2 = 2.0 * slogdet
74- x3 = u . size * np .log (np .pi () * 2 )
75- return - 0.5 * ( x1 + x2 + x3 )
71+ residual_white = linalg .solve_triangular (cholesky , dx , lower = True , trans = 0 )
72+ sqrnorm = linalg .vector_dot (residual_white , residual_white )
73+
74+ const = np .log (np .pi () * 2 )
75+ return - 1 / 2 * sqrnorm - u . size / 2 * const - slogdet
7676
7777 def mean (self , rv ):
7878 return rv .mean
@@ -128,12 +128,14 @@ def logpdf(self, u, /, rv):
128128 u = u [None , :]
129129
130130 def logpdf_scalar (x , r ):
131+ cholesky = linalg .qr_r (r .cholesky .T ).T
132+
131133 dx = x - r .mean
132- w = linalg .solve_triangular (r . cholesky .T , dx , trans = "T" )
134+ w = linalg .solve_triangular (cholesky .T , dx , trans = "T" )
133135
134136 maha_term = linalg .vector_dot (w , w )
135137
136- diagonal = linalg .diagonal_along_axis (r . cholesky , axis1 = - 1 , axis2 = - 2 )
138+ diagonal = linalg .diagonal_along_axis (cholesky , axis1 = - 1 , axis2 = - 2 )
137139 slogdet = np .sum (np .log (np .abs (diagonal )))
138140 logdet_term = 2.0 * slogdet
139141 return - 0.5 * (logdet_term + maha_term + x .size * np .log (np .pi () * 2 ))
@@ -195,12 +197,14 @@ def mahalanobis_norm_relative(self, u, /, rv):
195197
196198 def logpdf (self , u , / , rv ):
197199 def logpdf_scalar (x , r ):
200+ cholesky = linalg .qr_r (r .cholesky .T ).T
201+
198202 dx = x - r .mean
199- w = linalg .solve_triangular (r . cholesky .T , dx , trans = "T" )
203+ w = linalg .solve_triangular (cholesky .T , dx , trans = "T" )
200204
201205 maha_term = linalg .vector_dot (w , w )
202206
203- diagonal = linalg .diagonal_along_axis (r . cholesky , axis1 = - 1 , axis2 = - 2 )
207+ diagonal = linalg .diagonal_along_axis (cholesky , axis1 = - 1 , axis2 = - 2 )
204208 slogdet = np .sum (np .log (np .abs (diagonal )))
205209 logdet_term = 2.0 * slogdet
206210 return - 0.5 * (logdet_term + maha_term + x .size * np .log (np .pi () * 2 ))
0 commit comments