2121logger = init_logger (__name__ )
2222
2323
24- def _dmd_forecast (
25- snapshots : List [torch .Tensor ],
26- k : float ,
24+ def _dmd_fit_one (
25+ traj : torch .Tensor ,
2726 rank : int = 0 ,
2827 ridge : float = 1e-8 ,
29- ) -> torch . Tensor :
30- """Forecast a feature ``k `` snapshot-spacings past the newest snapshot via
31- Dynamic Mode Decomposition (Prony ).
28+ ):
29+ """Fit the DMD eigendecomposition for ONE ``[d, n] `` trajectory (a single batch item's snapshot
30+ history, columns OLDEST..NEWEST ).
3231
3332 Identify the linear propagator ``A`` from the snapshot pairs
34- (``Y_{t+1} ~= A Y_t``) through one economy SVD, eigendecompose it once, and
35- advance by (possibly fractional) eigenvalue powers::
36-
37- Y_{t+k} ~= Phi @ (lambda**k * b), b = pinv(Phi) @ Y_t
33+ (``Y_{t+1} ~= A Y_t``) through one economy SVD and eigendecompose it once. The
34+ fit is horizon-free, so the caller caches it and advances it cheaply (see
35+ :func:`_dmd_eval`).
3836
39- :param snapshots: >= 3 same-shape tensors, OLDEST..NEWEST, the fully computed
40- features at recent compute steps (uniformly spaced).
41- :param k: Forecast horizon in snapshot-spacing units (fractional allowed).
4237 :param rank: SVD truncation rank; 0 selects it from the spectrum (drop modes
4338 below 1e-4 of the leading singular value — this is what rejects noise).
4439 :param ridge: Tikhonov term added to the inverted singular values.
45- :returns: Forecast tensor of the snapshot shape; falls back to last-value
46- reuse when the history is too short or the fit is degenerate .
40+ :returns: ``(Phi, evals, b)`` or ``None`` on a degenerate fit (caller reuses
41+ the last value) .
4742 """
4843
49- shp , dt = snapshots [- 1 ].shape , snapshots [- 1 ].dtype
50- if len (snapshots ) < 3 :
51- return snapshots [- 1 ].clone ()
52- V = torch .stack ([s .reshape (- 1 ) for s in snapshots ], dim = 1 ).to (torch .float64 )
53- X , Xp = V [:, :- 1 ], V [:, 1 :]
44+ X , Xp = traj [:, :- 1 ], traj [:, 1 :]
5445 try :
5546 U , S , Vh = torch .linalg .svd (X , full_matrices = False )
56- except Exception : # noqa: BLE001 — degenerate fit: fall back to last-value reuse
57- return snapshots [- 1 ].clone ()
58- if rank <= 0 :
59- rank = int ((S > S [0 ] * 1e-4 ).sum ().clamp (min = 1 ).item ())
60- rank = max (1 , min (rank , S .numel ()))
61- Ur , Sr , Vr = U [:, :rank ], S [:rank ], Vh [:rank ].mH
47+ except Exception : # noqa: BLE001 — degenerate fit: caller falls back to reuse
48+ return None
49+ r = rank
50+ if r <= 0 :
51+ r = int ((S > S [0 ] * 1e-4 ).sum ().clamp (min = 1 ).item ())
52+ r = max (1 , min (r , S .numel ()))
53+ Ur , Sr , Vr = U [:, :r ], S [:r ], Vh [:r ].mH
6254 Sinv = (1.0 / (Sr + ridge )).to (torch .complex128 )
6355 Atil = (Ur .mH @ Xp @ Vr ).to (torch .complex128 ) * Sinv .unsqueeze (0 )
6456 try :
6557 evals , W = torch .linalg .eig (Atil )
6658 Phi = ((Xp @ Vr ).to (torch .complex128 ) * Sinv .unsqueeze (0 )) @ W
67- b = torch .linalg .lstsq (Phi , V [:, - 1 ].to (torch .complex128 ).unsqueeze (1 )).solution .squeeze (1 )
68- except Exception : # noqa: BLE001 — degenerate fit: fall back to last-value reuse
69- return snapshots [- 1 ].clone ()
70- pred = (Phi @ (evals .pow (float (k )) * b )).real
71- if not torch .isfinite (pred ).all ():
72- return snapshots [- 1 ].clone ()
73- return pred .to (dt ).reshape (shp )
59+ b = torch .linalg .lstsq (Phi , traj [:, - 1 ].to (torch .complex128 ).unsqueeze (1 )).solution .squeeze (1 )
60+ except Exception : # noqa: BLE001 — degenerate fit: caller falls back to reuse
61+ return None
62+ return (Phi , evals , b )
63+
64+
65+ def _dmd_fit (
66+ snapshots : List [torch .Tensor ],
67+ rank : int = 0 ,
68+ ridge : float = 1e-8 ,
69+ ):
70+ """Fit DMD once for a window of >= 4 same-shape snapshots, INDEPENDENTLY per batch item (axis 0).
71+
72+ Flattening the whole tensor into a single state (the pre-fix behaviour) folds
73+ the batch dimension into one DMD fit, so one prompt's forecast would depend on
74+ the other prompts in the same batch — unlike the elementwise Taylor path.
75+ Fitting per batch item keeps them independent. The fit is horizon-free, so
76+ :class:`DMDState` caches the returned object and reuses it for every skip step
77+ until a new snapshot arrives (one SVD/eig per window, not per skipped step).
78+
79+ :returns: ``(per_item_fits, shape, dtype)``, or ``None`` when the window is
80+ too short. ``per_item_fits[i]`` is ``None`` for a degenerate batch item.
81+ """
82+
83+ if len (snapshots ) < 4 :
84+ return None
85+ newest = snapshots [- 1 ]
86+ shp , dt = newest .shape , newest .dtype
87+ bsz = shp [0 ] if newest .dim () > 1 else 1
88+ # (B, d, n): per-item trajectories; B == 1 reproduces the un-batched fit.
89+ V = torch .stack ([s .reshape (bsz , - 1 ) for s in snapshots ], dim = - 1 ).to (torch .float64 )
90+ fits = [_dmd_fit_one (V [i ], rank = rank , ridge = ridge ) for i in range (bsz )]
91+ return (fits , shp , dt )
92+
93+
94+ def _dmd_eval (fit , k : float ):
95+ """Advance a cached :func:`_dmd_fit` to (fractional) horizon ``k`` by eigenvalue powers —
96+ ``Y_{t+k} ~= Phi @ (lambda**k * b)`` — one cheap evaluation per batch item, no re-decomposition.
97+
98+ :returns: The forecast tensor of the original snapshot shape, or ``None`` when
99+ any batch item is degenerate, or when the result is non-finite AFTER the
100+ output-dtype cast. The finite check is deliberately post-cast: a finite
101+ float64 forecast can still overflow to ``inf`` in fp16, which the caller
102+ must catch and fall back from rather than feed downstream.
103+ """
104+
105+ fits , shp , dt = fit
106+ rows = []
107+ for f in fits :
108+ if f is None :
109+ return None
110+ Phi , evals , b = f
111+ rows .append ((Phi @ (evals .pow (float (k )) * b )).real )
112+ pred = torch .stack (rows , dim = 0 ) if len (shp ) > 1 else rows [0 ]
113+ out = pred .to (dt ).reshape (shp )
114+ if not torch .isfinite (out ).all ():
115+ return None
116+ return out
74117
75118
76119class DMDState :
@@ -102,6 +145,10 @@ def __init__(
102145 self .current_step = - 1
103146 self .last_non_approximated_step = - 1
104147 self .snapshots : List [Tuple [int , torch .Tensor ]] = []
148+ # Cached horizon-free DMD fit + the window key it was fitted on, so skip
149+ # steps reuse one SVD/eig instead of recomputing it every step.
150+ self ._fit = None
151+ self ._fit_key = None
105152 self .state : Dict [str , List [torch .Tensor ]] = {
106153 "dY_prev" : [None ] * self .order ,
107154 "dY_current" : [None ] * self .order ,
@@ -113,6 +160,8 @@ def reset(self):
113160 self .current_step = - 1
114161 self .last_non_approximated_step = - 1
115162 self .snapshots = []
163+ self ._fit = None
164+ self ._fit_key = None
116165 self .state = {
117166 "dY_prev" : [None ] * self .order ,
118167 "dY_current" : [None ] * self .order ,
@@ -166,8 +215,10 @@ def update(self, Y: torch.Tensor):
166215 # silently overwrite the snapshot history in place.
167216 self .snapshots .append ((self .current_step , Y .detach ().clone ()))
168217 if len (self .snapshots ) > self .history :
169- del self .snapshots [: len (self .snapshots ) - self .history ]
218+ del self .snapshots [:len (self .snapshots ) - self .history ]
170219 self .last_non_approximated_step = self .current_step
220+ # A new snapshot changes the window, so the cached fit is stale.
221+ self ._fit_key = None
171222
172223 def _uniform_tail (self ) -> Tuple [List [torch .Tensor ], int ]:
173224 """Longest uniformly spaced suffix of the snapshot history.
@@ -203,13 +254,26 @@ def approximate(self) -> torch.Tensor:
203254 pairs) it falls back to the Taylor expansion — DMD acts only where it is
204255 valid and the polynomial path covers warm-up.
205256
257+ The eigendecomposition depends only on the snapshot window, which cannot
258+ change between two skip steps, so it is fitted once per window and cached;
259+ each skip step only re-advances the cheap ``lambda**k`` horizon. A
260+ degenerate fit or non-finite forecast reuses the newest snapshot.
261+
206262 :returns: The forecast tensor for the current logical step.
207263 """
208264
209265 vels , spacing = self ._uniform_tail ()
210266 if len (vels ) >= 4 :
211267 k = (self .current_step - self .snapshots [- 1 ][0 ]) / spacing
212- return _dmd_forecast (vels , k , rank = self .rank , ridge = self .ridge )
268+ key = (self .snapshots [- 1 ][0 ], len (vels ), spacing )
269+ if self ._fit_key != key :
270+ self ._fit = _dmd_fit (vels , rank = self .rank , ridge = self .ridge )
271+ self ._fit_key = key
272+ if self ._fit is not None :
273+ pred = _dmd_eval (self ._fit , k )
274+ if pred is not None :
275+ return pred
276+ return self .snapshots [- 1 ][1 ].clone ()
213277 return self ._approximate_taylor ()
214278
215279 def step (self , Y : torch .Tensor ):
@@ -227,8 +291,8 @@ def step(self, Y: torch.Tensor):
227291
228292
229293class DMDCalibrator (CalibratorBase ):
230- """Calibrator that forecasts tensors with a Dynamic Mode Decomposition
231- (Prony) exponential basis — drop-in alternative to `TaylorSeerCalibrator`."""
294+ """Calibrator that forecasts tensors with a Dynamic Mode Decomposition (Prony) exponential basis —
295+ drop-in alternative to `TaylorSeerCalibrator`."""
232296
233297 def __init__ (
234298 self ,
0 commit comments