Skip to content

Commit e1067af

Browse files
committed
Address review: cache the DMD fit, decouple batch items, guard fp16 overflow
- Cache the horizon-free DMD eigendecomposition per snapshot window (DMDState._fit / _fit_key, invalidated when a new snapshot arrives). Skip steps now reuse one SVD/eig instead of recomputing it every step, which is what restores the intended cache speedup at large fresh intervals.
- Fit DMD independently per batch item (axis 0). Flattening folded the batch into one state, so a prompt's forecast depended on the other prompts in the batch; per-item fitting keeps the fits independent like the Taylor path. (The fallback is still batch-wide: if any item's fit is degenerate or any output element is non-finite, the whole batch reuses the newest snapshot.)
- Move the finite check after the output-dtype cast: a finite float64 forecast can still overflow to inf in fp16, so the cast result is what gets guarded.
- yapf / docformatter clean (fixes the failing pre-commit CI check).
1 parent c9a6b50 commit e1067af

2 files changed

Lines changed: 106 additions & 43 deletions

File tree

src/cache_dit/caching/cache_contexts/calibrators/__init__.py

Lines changed: 5 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -158,12 +158,11 @@ def to_kwargs(self) -> Dict:
158158
class DMDCalibratorConfig(CalibratorConfig):
159159
"""Config for the Dynamic Mode Decomposition (Prony) forecasting calibrator.
160160
161-
An EXPONENTIAL-basis alternative to TaylorSeer's polynomial forecast: DMD
162-
(Schmid 2010; the SVD-regularised generalisation of Prony's method) identifies
163-
the linear propagator of the cached feature stream from recent compute-step
164-
snapshots and extrapolates by eigenvalue powers — exact on the (locally)
165-
exponential trajectories diffusion features follow, where a polynomial
166-
diverges with the cache interval. NOT Distribution Matching Distillation.
161+
An EXPONENTIAL-basis alternative to TaylorSeer's polynomial forecast: DMD (Schmid 2010; the SVD-
162+
regularised generalisation of Prony's method) identifies the linear propagator of the cached
163+
feature stream from recent compute-step snapshots and extrapolates by eigenvalue powers — exact on
164+
the (locally) exponential trajectories diffusion features follow, where a polynomial diverges with
165+
the cache interval. NOT Distribution Matching Distillation.
167166
"""
168167

169168
# enable_calibrator (`bool`, *required*, defaults to True):

src/cache_dit/caching/cache_contexts/calibrators/dmd.py

Lines changed: 101 additions & 37 deletions
Original file line number | Diff line number | Diff line change
@@ -21,56 +21,99 @@
2121
logger = init_logger(__name__)
2222

2323

24-
def _dmd_forecast(
25-
snapshots: List[torch.Tensor],
26-
k: float,
24+
def _dmd_fit_one(
25+
traj: torch.Tensor,
2726
rank: int = 0,
2827
ridge: float = 1e-8,
29-
) -> torch.Tensor:
30-
"""Forecast a feature ``k`` snapshot-spacings past the newest snapshot via
31-
Dynamic Mode Decomposition (Prony).
28+
):
29+
"""Fit the DMD eigendecomposition for ONE ``[d, n]`` trajectory (a single batch item's snapshot
30+
history, columns OLDEST..NEWEST).
3231
3332
Identify the linear propagator ``A`` from the snapshot pairs
34-
(``Y_{t+1} ~= A Y_t``) through one economy SVD, eigendecompose it once, and
35-
advance by (possibly fractional) eigenvalue powers::
36-
37-
Y_{t+k} ~= Phi @ (lambda**k * b), b = pinv(Phi) @ Y_t
33+
(``Y_{t+1} ~= A Y_t``) through one economy SVD and eigendecompose it once. The
34+
fit is horizon-free, so the caller caches it and advances it cheaply (see
35+
:func:`_dmd_eval`).
3836
39-
:param snapshots: >= 3 same-shape tensors, OLDEST..NEWEST, the fully computed
40-
features at recent compute steps (uniformly spaced).
41-
:param k: Forecast horizon in snapshot-spacing units (fractional allowed).
4237
:param rank: SVD truncation rank; 0 selects it from the spectrum (drop modes
4338
below 1e-4 of the leading singular value — this is what rejects noise).
4439
:param ridge: Tikhonov term added to the inverted singular values.
45-
:returns: Forecast tensor of the snapshot shape; falls back to last-value
46-
reuse when the history is too short or the fit is degenerate.
40+
:returns: ``(Phi, evals, b)`` or ``None`` on a degenerate fit (caller reuses
41+
the last value).
4742
"""
4843

49-
shp, dt = snapshots[-1].shape, snapshots[-1].dtype
50-
if len(snapshots) < 3:
51-
return snapshots[-1].clone()
52-
V = torch.stack([s.reshape(-1) for s in snapshots], dim=1).to(torch.float64)
53-
X, Xp = V[:, :-1], V[:, 1:]
44+
X, Xp = traj[:, :-1], traj[:, 1:]
5445
try:
5546
U, S, Vh = torch.linalg.svd(X, full_matrices=False)
56-
except Exception: # noqa: BLE001 — degenerate fit: fall back to last-value reuse
57-
return snapshots[-1].clone()
58-
if rank <= 0:
59-
rank = int((S > S[0] * 1e-4).sum().clamp(min=1).item())
60-
rank = max(1, min(rank, S.numel()))
61-
Ur, Sr, Vr = U[:, :rank], S[:rank], Vh[:rank].mH
47+
except Exception: # noqa: BLE001 — degenerate fit: caller falls back to reuse
48+
return None
49+
r = rank
50+
if r <= 0:
51+
r = int((S > S[0] * 1e-4).sum().clamp(min=1).item())
52+
r = max(1, min(r, S.numel()))
53+
Ur, Sr, Vr = U[:, :r], S[:r], Vh[:r].mH
6254
Sinv = (1.0 / (Sr + ridge)).to(torch.complex128)
6355
Atil = (Ur.mH @ Xp @ Vr).to(torch.complex128) * Sinv.unsqueeze(0)
6456
try:
6557
evals, W = torch.linalg.eig(Atil)
6658
Phi = ((Xp @ Vr).to(torch.complex128) * Sinv.unsqueeze(0)) @ W
67-
b = torch.linalg.lstsq(Phi, V[:, -1].to(torch.complex128).unsqueeze(1)).solution.squeeze(1)
68-
except Exception: # noqa: BLE001 — degenerate fit: fall back to last-value reuse
69-
return snapshots[-1].clone()
70-
pred = (Phi @ (evals.pow(float(k)) * b)).real
71-
if not torch.isfinite(pred).all():
72-
return snapshots[-1].clone()
73-
return pred.to(dt).reshape(shp)
59+
b = torch.linalg.lstsq(Phi, traj[:, -1].to(torch.complex128).unsqueeze(1)).solution.squeeze(1)
60+
except Exception: # noqa: BLE001 — degenerate fit: caller falls back to reuse
61+
return None
62+
return (Phi, evals, b)
63+
64+
65+
def _dmd_fit(
66+
snapshots: List[torch.Tensor],
67+
rank: int = 0,
68+
ridge: float = 1e-8,
69+
):
70+
"""Fit DMD once for a window of >= 4 same-shape snapshots, INDEPENDENTLY per batch item (axis 0).
71+
72+
Flattening the whole tensor into a single state (the pre-fix behaviour) folds
73+
the batch dimension into one DMD fit, so one prompt's forecast would depend on
74+
the other prompts in the same batch — unlike the elementwise Taylor path.
75+
Fitting per batch item keeps them independent. The fit is horizon-free, so
76+
:class:`DMDState` caches the returned object and reuses it for every skip step
77+
until a new snapshot arrives (one SVD/eig per window, not per skipped step).
78+
79+
:returns: ``(per_item_fits, shape, dtype)``, or ``None`` when the window is
80+
too short. ``per_item_fits[i]`` is ``None`` for a degenerate batch item.
81+
"""
82+
83+
if len(snapshots) < 4:
84+
return None
85+
newest = snapshots[-1]
86+
shp, dt = newest.shape, newest.dtype
87+
bsz = shp[0] if newest.dim() > 1 else 1
88+
# (B, d, n): per-item trajectories; B == 1 reproduces the un-batched fit.
89+
V = torch.stack([s.reshape(bsz, -1) for s in snapshots], dim=-1).to(torch.float64)
90+
fits = [_dmd_fit_one(V[i], rank=rank, ridge=ridge) for i in range(bsz)]
91+
return (fits, shp, dt)
92+
93+
94+
def _dmd_eval(fit, k: float):
95+
"""Advance a cached :func:`_dmd_fit` to (fractional) horizon ``k`` by eigenvalue powers —
96+
``Y_{t+k} ~= Phi @ (lambda**k * b)`` — one cheap evaluation per batch item, no re-decomposition.
97+
98+
:returns: The forecast tensor of the original snapshot shape, or ``None`` when
99+
any batch item is degenerate, or when the result is non-finite AFTER the
100+
output-dtype cast. The finite check is deliberately post-cast: a finite
101+
float64 forecast can still overflow to ``inf`` in fp16, which the caller
102+
must catch and fall back from rather than feed downstream.
103+
"""
104+
105+
fits, shp, dt = fit
106+
rows = []
107+
for f in fits:
108+
if f is None:
109+
return None
110+
Phi, evals, b = f
111+
rows.append((Phi @ (evals.pow(float(k)) * b)).real)
112+
pred = torch.stack(rows, dim=0) if len(shp) > 1 else rows[0]
113+
out = pred.to(dt).reshape(shp)
114+
if not torch.isfinite(out).all():
115+
return None
116+
return out
74117

75118

76119
class DMDState:
@@ -102,6 +145,10 @@ def __init__(
102145
self.current_step = -1
103146
self.last_non_approximated_step = -1
104147
self.snapshots: List[Tuple[int, torch.Tensor]] = []
148+
# Cached horizon-free DMD fit + the window key it was fitted on, so skip
149+
# steps reuse one SVD/eig instead of recomputing it every step.
150+
self._fit = None
151+
self._fit_key = None
105152
self.state: Dict[str, List[torch.Tensor]] = {
106153
"dY_prev": [None] * self.order,
107154
"dY_current": [None] * self.order,
@@ -113,6 +160,8 @@ def reset(self):
113160
self.current_step = -1
114161
self.last_non_approximated_step = -1
115162
self.snapshots = []
163+
self._fit = None
164+
self._fit_key = None
116165
self.state = {
117166
"dY_prev": [None] * self.order,
118167
"dY_current": [None] * self.order,
@@ -166,8 +215,10 @@ def update(self, Y: torch.Tensor):
166215
# silently overwrite the snapshot history in place.
167216
self.snapshots.append((self.current_step, Y.detach().clone()))
168217
if len(self.snapshots) > self.history:
169-
del self.snapshots[: len(self.snapshots) - self.history]
218+
del self.snapshots[:len(self.snapshots) - self.history]
170219
self.last_non_approximated_step = self.current_step
220+
# A new snapshot changes the window, so the cached fit is stale.
221+
self._fit_key = None
171222

172223
def _uniform_tail(self) -> Tuple[List[torch.Tensor], int]:
173224
"""Longest uniformly spaced suffix of the snapshot history.
@@ -203,13 +254,26 @@ def approximate(self) -> torch.Tensor:
203254
pairs) it falls back to the Taylor expansion — DMD acts only where it is
204255
valid and the polynomial path covers warm-up.
205256
257+
The eigendecomposition depends only on the snapshot window, which cannot
258+
change between two skip steps, so it is fitted once per window and cached;
259+
each skip step only re-advances the cheap ``lambda**k`` horizon. A
260+
degenerate fit or non-finite forecast reuses the newest snapshot.
261+
206262
:returns: The forecast tensor for the current logical step.
207263
"""
208264

209265
vels, spacing = self._uniform_tail()
210266
if len(vels) >= 4:
211267
k = (self.current_step - self.snapshots[-1][0]) / spacing
212-
return _dmd_forecast(vels, k, rank=self.rank, ridge=self.ridge)
268+
key = (self.snapshots[-1][0], len(vels), spacing)
269+
if self._fit_key != key:
270+
self._fit = _dmd_fit(vels, rank=self.rank, ridge=self.ridge)
271+
self._fit_key = key
272+
if self._fit is not None:
273+
pred = _dmd_eval(self._fit, k)
274+
if pred is not None:
275+
return pred
276+
return self.snapshots[-1][1].clone()
213277
return self._approximate_taylor()
214278

215279
def step(self, Y: torch.Tensor):
@@ -227,8 +291,8 @@ def step(self, Y: torch.Tensor):
227291

228292

229293
class DMDCalibrator(CalibratorBase):
230-
"""Calibrator that forecasts tensors with a Dynamic Mode Decomposition
231-
(Prony) exponential basis — drop-in alternative to `TaylorSeerCalibrator`."""
294+
"""Calibrator that forecasts tensors with a Dynamic Mode Decomposition (Prony) exponential basis —
295+
drop-in alternative to `TaylorSeerCalibrator`."""
232296

233297
def __init__(
234298
self,

0 commit comments

Comments
 (0)