Skip to content

Commit 6cd559c

Browse files
authored
feat: support dynamic mode decomposition calibrator (#1053)
* Add a Dynamic Mode Decomposition (Prony) exponential-basis calibrator (`calibrator_type="dmd"`) * Address review: cache the DMD fit, decouple batch items, guard fp16 overflow - Cache the horizon-free DMD eigendecomposition per snapshot window (DMDState._fit / _fit_key, invalidated when a new snapshot arrives). Skip steps now reuse one SVD/eig instead of recomputing it every step, which is what restores the intended cache speedup at large fresh intervals. - Fit DMD independently per batch item (axis 0). Flattening folded the batch into one state, so a prompt's forecast depended on the other prompts in the batch; per-item fitting keeps them independent like the Taylor path. - Move the finite check after the output-dtype cast: a finite float64 forecast can still overflow to inf in fp16, so the cast result is what gets guarded. - yapf / docformatter clean (fixes the failing pre-commit CI check). * Add 'dmd' to the example generate CLI (--dmd / --dmd-history) Per review: enable the DMD calibrator from `python -m cache_dit.generate` exactly like --taylorseer. --dmd selects DMDCalibratorConfig (history via --dmd-history, default 6); --taylorseer is unchanged. Verified end-to-end: python -m cache_dit.generate flux --cache --dmd --cpu-offload generates with the DMD calibrator active (optimization tag ...DMDH6_S12, image saved).
1 parent 33eacf7 commit 6cd559c

6 files changed

Lines changed: 489 additions & 2 deletions

File tree

src/cache_dit/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@
4444
from .caching import CalibratorConfig
4545
from .caching import TaylorSeerCalibratorConfig
4646
from .caching import FoCaCalibratorConfig
47+
from .caching import DMDCalibratorConfig
4748
from .caching import supported_pipelines
4849
from .caching import get_adapter
4950
from .caching import BlockAdapterRegister

src/cache_dit/_utils/utils.py

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
from ..caching import (
1919
BlockAdapter,
2020
DBCacheConfig,
21+
DMDCalibratorConfig,
2122
TaylorSeerCalibratorConfig,
2223
load_configs,
2324
load_parallelism_config,
@@ -309,6 +310,18 @@ def get_args(parse: bool = True, ) -> argparse.ArgumentParser | argparse.Namespa
309310
default=1,
310311
help="TaylorSeer order",
311312
)
313+
parser.add_argument(
314+
"--dmd",
315+
action="store_true",
316+
default=False,
317+
help="Enable DMD (Dynamic Mode Decomposition / Prony) exponential-basis calibrator for CacheDiT",
318+
)
319+
parser.add_argument(
320+
"--dmd-history",
321+
type=int,
322+
default=6,
323+
help="DMD snapshot-history window length",
324+
)
312325
parser.add_argument(
313326
"--steps-mask",
314327
action="store_true",
@@ -2282,8 +2295,9 @@ def _prepare_distributed_size():
22822295
force_refresh_step_hint=kwargs.get("force_refresh_step_hint", None),
22832296
force_refresh_step_policy=kwargs.get("force_refresh_step_policy", "once"),
22842297
) if cache_config is None and args.cache else cache_config),
2285-
calibrator_config=(TaylorSeerCalibratorConfig(taylorseer_order=args.taylorseer_order, )
2286-
if args.taylorseer else None),
2298+
calibrator_config=(DMDCalibratorConfig(
2299+
dmd_history=args.dmd_history) if args.dmd else TaylorSeerCalibratorConfig(
2300+
taylorseer_order=args.taylorseer_order) if args.taylorseer else None),
22872301
params_modifiers=kwargs.get("params_modifiers", None),
22882302
parallelism_config=(ParallelismConfig(
22892303
ulysses_size=ulysses_size,

src/cache_dit/caching/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
from .cache_contexts import CalibratorConfig
2222
from .cache_contexts import TaylorSeerCalibratorConfig
2323
from .cache_contexts import FoCaCalibratorConfig
24+
from .cache_contexts import DMDCalibratorConfig
2425

2526
from .cache_blocks import CachedBlocks
2627
from .cache_blocks import PrunedBlocks

src/cache_dit/caching/cache_contexts/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
CalibratorConfig,
55
TaylorSeerCalibratorConfig,
66
FoCaCalibratorConfig,
7+
DMDCalibratorConfig,
78
)
89
from .cache_config import (
910
BasicCacheConfig,

src/cache_dit/caching/cache_contexts/calibrators/__init__.py

Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from .base import CalibratorBase
22
from .taylorseer import TaylorSeerCalibrator
33
from .foca import FoCaCalibrator
4+
from .dmd import DMDCalibrator
45

56
import dataclasses
67
from typing import Any, Dict
@@ -153,6 +154,70 @@ def to_kwargs(self) -> Dict:
153154
return kwargs
154155

155156

157+
@dataclasses.dataclass
158+
class DMDCalibratorConfig(CalibratorConfig):
159+
"""Config for the Dynamic Mode Decomposition (Prony) forecasting calibrator.
160+
161+
An EXPONENTIAL-basis alternative to TaylorSeer's polynomial forecast: DMD (Schmid 2010; the SVD-
162+
regularised generalisation of Prony's method) identifies the linear propagator of the cached
163+
feature stream from recent compute-step snapshots and extrapolates by eigenvalue powers — exact on
164+
the (locally) exponential trajectories diffusion features follow, where a polynomial diverges with
165+
the cache interval. NOT Distribution Matching Distillation.
166+
"""
167+
168+
# enable_calibrator (`bool`, *required*, defaults to True):
169+
# Whether to enable the calibrator. If True, the user wants to use DBCache
170+
# with a specific calibrator for hidden_states (or hidden_states residual),
171+
# such as taylorseer, foca, dmd, and so on.
172+
enable_calibrator: bool = True
173+
# enable_encoder_calibrator (`bool`, *required*, defaults to True):
174+
# Whether to enable the encoder calibrator. If True, the user wants to use DBCache
175+
# with a specific calibrator for encoder_hidden_states (or encoder_hidden_states
176+
# residual), such as taylorseer, foca, dmd, and so on.
177+
enable_encoder_calibrator: bool = True
178+
# calibrator_type (`str`, *required*, defaults to 'dmd'):
179+
# The specific type for calibrator, taylorseer, foca or dmd, etc.
180+
calibrator_type: str = "dmd"
181+
# dmd_history (`int`, *required*, defaults to 6):
182+
# Number of recent compute-step snapshots retained per stream. >= 4 uniformly
183+
# spaced snapshots are needed before the exponential fit engages (one complex
184+
# pole costs two real degrees of freedom); below the floor the calibrator
185+
# falls back to the Taylor expansion automatically. 5-6 is the sweet spot —
186+
# the feature dynamics drift across timesteps, so longer windows hurt.
187+
dmd_history: int = 6
188+
# dmd_rank (`int`, *optional*, defaults to 0):
189+
# SVD truncation rank of the snapshot matrix; 0 selects it from the spectrum
190+
# (drop modes below 1e-4 of the leading singular value). The truncation is
191+
# what rejects the noise subspace.
192+
dmd_rank: int = 0
193+
# dmd_ridge (`float`, *optional*, defaults to 1e-8):
194+
# Tikhonov term added to the inverted singular values.
195+
dmd_ridge: float = 1e-8
196+
197+
def strify(self, **kwargs) -> str:
198+
"""Return a compact tag that includes the snapshot-history length.
199+
200+
:param kwargs: Additional keyword arguments forwarded to the underlying implementation.
201+
:returns: A compact DMD tag for logs, summaries, or filenames.
202+
"""
203+
204+
if kwargs.get("details", False):
205+
return f"DMD_H({self.dmd_history})"
206+
return f"DMDH{self.dmd_history}"
207+
208+
def to_kwargs(self) -> Dict:
209+
"""Translate config fields into `DMDCalibrator` init kwargs.
210+
211+
:returns: Keyword arguments expected by `DMDCalibrator`.
212+
"""
213+
214+
kwargs = self.calibrator_kwargs.copy()
215+
kwargs["history"] = self.dmd_history
216+
kwargs["rank"] = self.dmd_rank
217+
kwargs["ridge"] = self.dmd_ridge
218+
return kwargs
219+
220+
156221
@dataclasses.dataclass
157222
class FoCaCalibratorConfig(CalibratorConfig):
158223
"""Config placeholder for the future FoCa calibrator backend."""
@@ -183,6 +248,7 @@ class Calibrator:
183248

184249
_supported_calibrators = [
185250
"taylorseer",
251+
"dmd",
186252
# TODO: FoCa
187253
]
188254

@@ -201,5 +267,7 @@ def __new__(
201267

202268
if calibrator_config.calibrator_type.lower() == "taylorseer":
203269
return TaylorSeerCalibrator(**calibrator_config.to_kwargs())
270+
elif calibrator_config.calibrator_type.lower() == "dmd":
271+
return DMDCalibrator(**calibrator_config.to_kwargs())
204272
else:
205273
raise ValueError(f"Calibrator {calibrator_config.calibrator_type} is not supported now!")

0 commit comments

Comments
 (0)