@@ -679,7 +679,7 @@ def _em(self, X, W=None, aff=None):
             self._update_bias(XB, Z, W, vx=vx)
             XB = torch.exp(self.beta, out=XB).mul_(X)
             plot_mode = 'bias'
-            
+
             if n_iter_bias > 1 and lb - olb < self.tol * nW:
                 break
 
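Side note on this hunk: `torch.exp(self.beta, out=XB).mul_(X)` writes the exponentiated log bias field into a preallocated buffer and then scales the image in place, so no new tensor is allocated per EM iteration. A minimal sketch of the same pattern, with hypothetical shapes (`X`, `beta`, `XB` stand in for the attributes used above):

```python
import torch

X = torch.rand(1, 64, 64)      # observed image (hypothetical shape)
beta = torch.zeros_like(X)     # log bias field
XB = torch.empty_like(X)       # preallocated buffer, reused across iterations

XB = torch.exp(beta, out=XB).mul_(X)   # XB <- exp(beta) * X, no new allocation
assert torch.allclose(XB, X)           # exp(0) == 1, so the image is unchanged
```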
@@ -1297,15 +1297,15 @@ def _affine_gradient(self):
         return linalg._expm(self.eta, self.affine_basis, grad_X=True)[1]
 
     def _full_affine_gradient(self, aff):
-        """Derivative of the full affine (aff_prior \ (aff_align @ aff_dat))
+        r"""Derivative of the full affine (aff_prior \ (aff_align @ aff_dat))
         with respect to the Lie parameters of aff_align."""
         g_aff = self._affine_gradient
         g_aff = torch.matmul(g_aff, aff)
         g_aff = torch.matmul(self.affine_prior.inverse().to(aff), g_aff)
         return g_aff
 
     def _full_affine(self, aff):
-        """Full affine matrix: aff_prior \ (aff_align @ aff_dat)"""
+        r"""Full affine matrix: aff_prior \ (aff_align @ aff_dat)"""
         if self.eta is not None:
             aff = torch.matmul(self.affine.to(aff), aff)
         aff = torch.matmul(self.affine_prior.inverse().to(aff), aff)
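The `aff_prior \ (aff_align @ aff_dat)` notation in these docstrings is MATLAB-style left division, i.e. `aff_prior.inverse() @ (aff_align @ aff_dat)`, which is why the raw-string prefix is needed here: the backslash is not a valid Python escape sequence, and non-raw docstrings containing it trigger a warning. A minimal sketch with hypothetical 4x4 homogeneous affines; `torch.linalg.solve` is used as a numerically friendlier stand-in for the explicit inverse taken in `_full_affine`:

```python
import torch

aff_prior = torch.eye(4, dtype=torch.double)   # prior/template affine (hypothetical)
aff_align = torch.eye(4, dtype=torch.double)   # optimized alignment transform
aff_dat = torch.eye(4, dtype=torch.double)     # voxel-to-world affine of the data

full = torch.matmul(aff_align, aff_dat)        # aff_align @ aff_dat
full = torch.linalg.solve(aff_prior, full)     # aff_prior \ (aff_align @ aff_dat)
```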
@@ -1583,12 +1583,12 @@ def _update_lb_wishart(self):
         sigma0, df0 = self.wishart
 
         # Kullback-Leibler divergence between inverse-Wishart distributions
-        # 2*KL(q||p) = N0 * (logdet(S1) - logdet(S0))
+        # 2*KL(q||p) = N0 * (logdet(S1) - logdet(S0))
         #            + N1 * tr(S1\S0)
         #            + 2 * (gammal(N0/2) - gammal(N1/2))
-        #            + (N1 - N0) * digamma(N1/2)
+        #            + (N1 - N0) * digamma(N1/2)
         #            - N1 * C
-        #
+        #
         # If we use Sigma1 = S1/N1 and Sigma0 = S0/N0, the first term becomes
         # N0 * (logdet(Sigma1) - logdet(Sigma0)) + N0 * C * (log(N1) - log(N0))
         # and the second term becomes
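For reference, the identity in this comment block can be spelled out directly. Below is a self-contained sketch of `2*KL(q||p)` between inverse-Wishart densities `q = iW(S1, N1)` and `p = iW(S0, N0)`, assuming `gammal`/`digamma` in the comment denote the multivariate log-gamma and digamma with `C` the matrix dimension (the function and helper names here are hypothetical):

```python
import torch

def kl_inverse_wishart_x2(S1, N1, S0, N0):
    """2*KL(iW(S1, N1) || iW(S0, N0)), term by term as in the comment."""
    C = S1.shape[-1]
    # multivariate digamma: sum_j digamma(a + (1 - j)/2), j = 1..C
    mvdigamma = lambda a: sum(torch.digamma(a + (1 - j) / 2)
                              for j in range(1, C + 1))
    kl2 = N0 * (torch.logdet(S1) - torch.logdet(S0))
    kl2 += N1 * torch.linalg.solve(S1, S0).diagonal(dim1=-2, dim2=-1).sum(-1)
    kl2 += 2 * (torch.mvlgamma(torch.as_tensor(N0 / 2, dtype=torch.double), C)
                - torch.mvlgamma(torch.as_tensor(N1 / 2, dtype=torch.double), C))
    kl2 += (N1 - N0) * mvdigamma(torch.as_tensor(N1 / 2, dtype=torch.double))
    kl2 -= N1 * C
    return kl2

S = torch.eye(3, dtype=torch.double)
print(kl_inverse_wishart_x2(S, 10., S, 10.))   # 0 when q == p
```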
@@ -1821,7 +1821,7 @@ def digamma(x):
         d1 = -0.5772156649015328606065121  # = digamma(1)
         d2 = (pymath.pi * pymath.pi) / 6
         return d1 - 1/x + d2 * x
-    # --- not large: reduce to digamma(x + n) where (x + n) is large
+    # --- not large: reduce to digamma(x + n) where (x + n) is large
     large = 9.5
     y = 0
     while x < large:
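The comment touched here documents the classic argument-shift trick: since `digamma(x) = digamma(x + 1) - 1/x`, the loop below it accumulates `-1/x` terms until `x` reaches the region (`x >= 9.5`) where an asymptotic series converges quickly. A compact sketch of the same reduction; the series is truncated here, which is an assumption about how the elided tail of `digamma` finishes:

```python
import math as pymath

def digamma_sketch(x, large=9.5):
    y = 0.0
    while x < large:
        y -= 1.0 / x          # digamma(x) = digamma(x + 1) - 1/x
        x += 1.0
    # truncated asymptotic series: ln(x) - 1/(2x) - 1/(12x^2) + 1/(120x^4)
    return y + pymath.log(x) - 1/(2*x) - 1/(12*x**2) + 1/(120*x**4)

print(digamma_sketch(0.3))   # ~ -3.50252, close to the reference digamma(0.3)
```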
@@ -1852,5 +1852,3 @@ def mvlgamma(x, order=1):
     for p in range(1, order + 1):
         y += pymath.lgamma(x + (1 - p) / 2)
     return y
-
-
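The deleted trailing blank lines end the file after `mvlgamma`. For completeness, a self-contained sketch of the same multivariate log-gamma, checked against `torch.mvlgamma`; the pi-dependent initialization of `y` sits above the visible hunk, so its form here is an assumption:

```python
import math as pymath
import torch

def mvlgamma_sketch(x, order=1):
    # log Gamma_p(x) = p(p-1)/4 * log(pi) + sum_{j=1..p} lgamma(x + (1-j)/2)
    y = order * (order - 1) / 4 * pymath.log(pymath.pi)
    for p in range(1, order + 1):
        y += pymath.lgamma(x + (1 - p) / 2)
    return y

print(mvlgamma_sketch(5., 3))                        # 9.14065...
print(torch.mvlgamma(torch.tensor(5.), 3).item())    # same value
```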