Skip to content

Commit b603fd7

Browse files
reworked cca computation for fp32
1 parent efd5fa7 commit b603fd7

1 file changed

Lines changed: 10 additions & 4 deletions

File tree

aristotelian/metrics/cca.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -136,6 +136,8 @@ class CCAMean(BaseMetric):
136136
def _compute_raw(
137137
self, X: torch.Tensor, Y: torch.Tensor, config: MetricConfig
138138
) -> float:
139+
# L/R factorisation; the algebraically-equivalent Cxx^{-1/2} @ Cxy @ Cyy^{-1/2}
140+
# is float32-unstable when d > n.
139141
device = config.device
140142
reg = 1e-6
141143
Xc = (X - X.mean(0, keepdim=True)).to(device=device, dtype=torch.float32)
@@ -145,14 +147,15 @@ def _compute_raw(
145147
eye_y = torch.eye(dy, device=device, dtype=torch.float32)
146148
Cxx = (Xc.T @ Xc) / (n - 1) + reg * eye_x
147149
Cyy = (Yc.T @ Yc) / (n - 1) + reg * eye_y
148-
Cxy = (Xc.T @ Yc) / (n - 1)
149150
Sx, Ux = torch.linalg.eigh(Cxx)
150151
Sy, Uy = torch.linalg.eigh(Cyy)
151152
Sx = torch.clamp(Sx, min=reg)
152153
Sy = torch.clamp(Sy, min=reg)
153154
Cxx_inv_sqrt = Ux @ torch.diag(1.0 / torch.sqrt(Sx)) @ Ux.T
154155
Cyy_inv_sqrt = Uy @ torch.diag(1.0 / torch.sqrt(Sy)) @ Uy.T
155-
T = Cxx_inv_sqrt @ Cxy @ Cyy_inv_sqrt
156+
L = Cxx_inv_sqrt @ Xc.T
157+
R = Yc @ Cyy_inv_sqrt
158+
T = (L @ R) / (n - 1)
156159
try:
157160
eigvals = torch.linalg.eigvalsh(T @ T.T)
158161
svals = torch.sqrt(torch.clamp(eigvals, min=0.0))
@@ -354,6 +357,8 @@ class PWCCA(BaseMetric):
354357
def _compute_raw(
355358
self, X: torch.Tensor, Y: torch.Tensor, config: MetricConfig
356359
) -> float:
360+
# L/R factorisation; the algebraically-equivalent Cxx^{-1/2} @ Cxy @ Cyy^{-1/2}
361+
# is float32-unstable when d > n.
357362
device = config.device
358363
reg = 1e-6
359364
Xc = (X - X.mean(0, keepdim=True)).to(device=device, dtype=torch.float32)
@@ -363,14 +368,15 @@ def _compute_raw(
363368
eye_y = torch.eye(dy, device=device, dtype=torch.float32)
364369
Cxx = (Xc.T @ Xc) / (n - 1) + reg * eye_x
365370
Cyy = (Yc.T @ Yc) / (n - 1) + reg * eye_y
366-
Cxy = (Xc.T @ Yc) / (n - 1)
367371
Sx, Ux = torch.linalg.eigh(Cxx)
368372
Sy, Uy = torch.linalg.eigh(Cyy)
369373
Sx = torch.clamp(Sx, min=reg)
370374
Sy = torch.clamp(Sy, min=reg)
371375
Cxx_inv_sqrt = Ux @ torch.diag(1.0 / torch.sqrt(Sx)) @ Ux.T
372376
Cyy_inv_sqrt = Uy @ torch.diag(1.0 / torch.sqrt(Sy)) @ Uy.T
373-
T = Cxx_inv_sqrt @ Cxy @ Cyy_inv_sqrt
377+
L = Cxx_inv_sqrt @ Xc.T
378+
R = Yc @ Cyy_inv_sqrt
379+
T = (L @ R) / (n - 1)
374380
eigvals, U = torch.linalg.eigh(T @ T.T)
375381
# eigh ascending → reverse for SVD descending convention
376382
svals = torch.sqrt(torch.clamp(eigvals.flip(-1), min=0.0))

0 commit comments

Comments
 (0)