@@ -136,6 +136,8 @@ class CCAMean(BaseMetric):
136136 def _compute_raw (
137137 self , X : torch .Tensor , Y : torch .Tensor , config : MetricConfig
138138 ) -> float :
139+ # L/R factorisation; the algebraically-equivalent Cxx^{-1/2} @ Cxy @ Cyy^{-1/2}
140+ # is float32-unstable when d > n.
139141 device = config .device
140142 reg = 1e-6
141143 Xc = (X - X .mean (0 , keepdim = True )).to (device = device , dtype = torch .float32 )
@@ -145,14 +147,15 @@ def _compute_raw(
145147 eye_y = torch .eye (dy , device = device , dtype = torch .float32 )
146148 Cxx = (Xc .T @ Xc ) / (n - 1 ) + reg * eye_x
147149 Cyy = (Yc .T @ Yc ) / (n - 1 ) + reg * eye_y
148- Cxy = (Xc .T @ Yc ) / (n - 1 )
149150 Sx , Ux = torch .linalg .eigh (Cxx )
150151 Sy , Uy = torch .linalg .eigh (Cyy )
151152 Sx = torch .clamp (Sx , min = reg )
152153 Sy = torch .clamp (Sy , min = reg )
153154 Cxx_inv_sqrt = Ux @ torch .diag (1.0 / torch .sqrt (Sx )) @ Ux .T
154155 Cyy_inv_sqrt = Uy @ torch .diag (1.0 / torch .sqrt (Sy )) @ Uy .T
155- T = Cxx_inv_sqrt @ Cxy @ Cyy_inv_sqrt
156+ L = Cxx_inv_sqrt @ Xc .T
157+ R = Yc @ Cyy_inv_sqrt
158+ T = (L @ R ) / (n - 1 )
156159 try :
157160 eigvals = torch .linalg .eigvalsh (T @ T .T )
158161 svals = torch .sqrt (torch .clamp (eigvals , min = 0.0 ))
@@ -354,6 +357,8 @@ class PWCCA(BaseMetric):
354357 def _compute_raw (
355358 self , X : torch .Tensor , Y : torch .Tensor , config : MetricConfig
356359 ) -> float :
360+ # L/R factorisation; the algebraically-equivalent Cxx^{-1/2} @ Cxy @ Cyy^{-1/2}
361+ # is float32-unstable when d > n.
357362 device = config .device
358363 reg = 1e-6
359364 Xc = (X - X .mean (0 , keepdim = True )).to (device = device , dtype = torch .float32 )
@@ -363,14 +368,15 @@ def _compute_raw(
363368 eye_y = torch .eye (dy , device = device , dtype = torch .float32 )
364369 Cxx = (Xc .T @ Xc ) / (n - 1 ) + reg * eye_x
365370 Cyy = (Yc .T @ Yc ) / (n - 1 ) + reg * eye_y
366- Cxy = (Xc .T @ Yc ) / (n - 1 )
367371 Sx , Ux = torch .linalg .eigh (Cxx )
368372 Sy , Uy = torch .linalg .eigh (Cyy )
369373 Sx = torch .clamp (Sx , min = reg )
370374 Sy = torch .clamp (Sy , min = reg )
371375 Cxx_inv_sqrt = Ux @ torch .diag (1.0 / torch .sqrt (Sx )) @ Ux .T
372376 Cyy_inv_sqrt = Uy @ torch .diag (1.0 / torch .sqrt (Sy )) @ Uy .T
373- T = Cxx_inv_sqrt @ Cxy @ Cyy_inv_sqrt
377+ L = Cxx_inv_sqrt @ Xc .T
378+ R = Yc @ Cyy_inv_sqrt
379+ T = (L @ R ) / (n - 1 )
374380 eigvals , U = torch .linalg .eigh (T @ T .T )
375381 # eigh ascending → reverse for SVD descending convention
376382 svals = torch .sqrt (torch .clamp (eigvals .flip (- 1 ), min = 0.0 ))
0 commit comments