From 948e665d12fccb9cdab3fdc01d1c5827934c13a0 Mon Sep 17 00:00:00 2001 From: DefTruth Date: Mon, 15 Jun 2026 03:51:38 +0000 Subject: [PATCH 1/2] feat: fast svd decompose for dmd calibrator --- src/cache_dit/_utils/utils.py | 15 +++- .../cache_contexts/calibrators/__init__.py | 13 +++- .../caching/cache_contexts/calibrators/dmd.py | 72 ++++++++++++++++++- 3 files changed, 91 insertions(+), 9 deletions(-) diff --git a/src/cache_dit/_utils/utils.py b/src/cache_dit/_utils/utils.py index de47be6a..2baf1c38 100644 --- a/src/cache_dit/_utils/utils.py +++ b/src/cache_dit/_utils/utils.py @@ -322,6 +322,14 @@ def get_args(parse: bool = True, ) -> argparse.ArgumentParser | argparse.Namespa default=6, help="DMD snapshot-history window length", ) + parser.add_argument( + "--dmd-svd-precision", + "--dmd-svd", + type=str, + default="low", + choices=["low", "medium", "high"], + help="DMD SVD precision mode: low (fastest), medium (balanced), high (accurate)", + ) parser.add_argument( "--steps-mask", action="store_true", @@ -2295,9 +2303,10 @@ def _prepare_distributed_size(): force_refresh_step_hint=kwargs.get("force_refresh_step_hint", None), force_refresh_step_policy=kwargs.get("force_refresh_step_policy", "once"), ) if cache_config is None and args.cache else cache_config), - calibrator_config=(DMDCalibratorConfig( - dmd_history=args.dmd_history) if args.dmd else TaylorSeerCalibratorConfig( - taylorseer_order=args.taylorseer_order) if args.taylorseer else None), + calibrator_config=(DMDCalibratorConfig(dmd_history=args.dmd_history, + dmd_svd_precision=args.dmd_svd_precision) + if args.dmd else TaylorSeerCalibratorConfig( + taylorseer_order=args.taylorseer_order) if args.taylorseer else None), params_modifiers=kwargs.get("params_modifiers", None), parallelism_config=(ParallelismConfig( ulysses_size=ulysses_size, diff --git a/src/cache_dit/caching/cache_contexts/calibrators/__init__.py b/src/cache_dit/caching/cache_contexts/calibrators/__init__.py index 7c868ade..09c6c393 100644 --- a/src/cache_dit/caching/cache_contexts/calibrators/__init__.py +++ b/src/cache_dit/caching/cache_contexts/calibrators/__init__.py @@ -193,17 +193,23 @@ class DMDCalibratorConfig(CalibratorConfig): # dmd_ridge (`float`, *optional*, defaults to 1e-8): # Tikhonov term added to the inverted singular values. dmd_ridge: float = 1e-8 + # dmd_svd_precision (`str`, *optional*, defaults to 'low'): + # SVD precision mode for the DMD snapshot matrix. ``"low"`` uses randomised + # ``torch.svd_lowrank`` (fastest, niter=4, deterministic seed), ``"medium"`` + # uses default ``torch.linalg.svd`` (gesdd), ``"high"`` uses ``driver="gesvd"``. + dmd_svd_precision: str = "low" def strify(self, **kwargs) -> str: - """Return a compact tag that includes the snapshot-history length. + """Return a compact tag that includes the snapshot-history length and SVD precision. :param kwargs: Additional keyword arguments forwarded to the underlying implementation. :returns: A compact DMD tag for logs, summaries, or filenames. """ + prec = self.dmd_svd_precision[0] if kwargs.get("details", False): - return f"DMD_H({self.dmd_history})" - return f"DMDH{self.dmd_history}" + return f"DMD_H({self.dmd_history}, {self.dmd_svd_precision})" + return f"DMDH{self.dmd_history}{prec}" if prec != "l" else f"DMDH{self.dmd_history}" def to_kwargs(self) -> Dict: """Translate config fields into `DMDCalibrator` init kwargs. @@ -215,6 +221,7 @@ def to_kwargs(self) -> Dict: kwargs["history"] = self.dmd_history kwargs["rank"] = self.dmd_rank kwargs["ridge"] = self.dmd_ridge + kwargs["svd_precision"] = self.dmd_svd_precision return kwargs diff --git a/src/cache_dit/caching/cache_contexts/calibrators/dmd.py b/src/cache_dit/caching/cache_contexts/calibrators/dmd.py index 5baca5da..c4b0f083 100644 --- a/src/cache_dit/caching/cache_contexts/calibrators/dmd.py +++ b/src/cache_dit/caching/cache_contexts/calibrators/dmd.py @@ -20,11 +20,57 @@ logger = init_logger(__name__) +_SVD_PRECISIONS = ("low", "medium", "high") +_SVD_LOWRANK_NITER = 4 +_SVD_LOWRANK_SEED = 0 + + +def _svd_rng_devices(X: torch.Tensor) -> list[int]: + if X.device.type != "cuda": + return [] + idx = X.device.index + if idx is None: + idx = torch.cuda.current_device() + return [idx] + + +def _dmd_svd( + X: torch.Tensor, + svd_precision: str, +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """Economy SVD for a tall-skinny ``[d, n]`` snapshot matrix. + + Three precision levels: + - ``"low"``: randomised ``torch.svd_lowrank`` (fastest, niter=4, deterministic seed). + - ``"medium"``: standard ``torch.linalg.svd`` with default ``gesdd``. + - ``"high"``: ``torch.linalg.svd`` with ``driver="gesvd"`` for maximum accuracy. + + :param X: Tall-skinny input matrix of shape ``[d, n]``. + :param svd_precision: One of ``("low", "medium", "high")``. + :returns: ``(U, S, Vh)`` — economy decomposition where ``U`` is ``[d, n]``, + ``S`` is ``[n]``, and ``Vh`` is ``[n, n]``. + """ + + assert svd_precision in _SVD_PRECISIONS, ( + f"svd_precision must be one of {_SVD_PRECISIONS}, got {svd_precision!r}") + n = X.shape[-1] + if svd_precision == "low": + with torch.random.fork_rng(devices=_svd_rng_devices(X)): + torch.manual_seed(_SVD_LOWRANK_SEED) + U, S, V = torch.svd_lowrank(X, q=n, niter=_SVD_LOWRANK_NITER) + Vh = V.mH + elif svd_precision == "high": + U, S, Vh = torch.linalg.svd(X, full_matrices=False, driver="gesvd") + else: # medium + U, S, Vh = torch.linalg.svd(X, full_matrices=False) + return U, S, Vh + def _dmd_fit_one( traj: torch.Tensor, rank: int = 0, ridge: float = 1e-8, + svd_precision: str = "low", ): """Fit the DMD eigendecomposition for ONE ``[d, n]`` trajectory (a single batch item's snapshot history, columns OLDEST..NEWEST). @@ -37,13 +83,16 @@ def _dmd_fit_one( :param rank: SVD truncation rank; 0 selects it from the spectrum (drop modes below 1e-4 of the leading singular value — this is what rejects noise). :param ridge: Tikhonov term added to the inverted singular values. + :param svd_precision: SVD precision mode; ``"low"`` uses randomised + ``torch.svd_lowrank`` (fastest), ``"medium"`` uses default ``gesdd``, + ``"high"`` uses ``driver="gesvd"``. :returns: ``(Phi, evals, b)`` or ``None`` on a degenerate fit (caller reuses the last value). """ X, Xp = traj[:, :-1], traj[:, 1:] try: - U, S, Vh = torch.linalg.svd(X, full_matrices=False) + U, S, Vh = _dmd_svd(X, svd_precision) except Exception: # noqa: BLE001 — degenerate fit: caller falls back to reuse return None r = rank @@ -66,6 +115,7 @@ def _dmd_fit( snapshots: List[torch.Tensor], rank: int = 0, ridge: float = 1e-8, + svd_precision: str = "low", ): """Fit DMD once for a window of >= 4 same-shape snapshots, INDEPENDENTLY per batch item (axis 0). @@ -87,7 +137,9 @@ def _dmd_fit( bsz = shp[0] if newest.dim() > 1 else 1 # (B, d, n): per-item trajectories; B == 1 reproduces the un-batched fit. V = torch.stack([s.reshape(bsz, -1) for s in snapshots], dim=-1).to(torch.float64) - fits = [_dmd_fit_one(V[i], rank=rank, ridge=ridge) for i in range(bsz)] + fits = [ + _dmd_fit_one(V[i], rank=rank, ridge=ridge, svd_precision=svd_precision) for i in range(bsz) + ] return (fits, shp, dt) @@ -125,6 +177,7 @@ def __init__( rank: int = 0, ridge: float = 1e-8, n_derivatives: int = 1, + svd_precision: str = "low", ): """Initialize snapshot buffers and the polynomial-fallback ladder. @@ -135,12 +188,16 @@ def __init__( :param rank: SVD truncation rank for the DMD fit (0 = automatic). :param ridge: Tikhonov regulariser for the DMD fit. :param n_derivatives: Taylor orders kept for the warm-up fallback. + :param svd_precision: SVD precision mode; ``"low"`` uses randomised + ``torch.svd_lowrank`` (fastest), ``"medium"`` uses default ``gesdd``, + ``"high"`` uses ``driver="gesvd"``. """ self.history = history self.rank = rank self.ridge = ridge self.n_derivatives = n_derivatives + self.svd_precision = svd_precision self.order = n_derivatives + 1 self.current_step = -1 self.last_non_approximated_step = -1 @@ -267,7 +324,10 @@ def approximate(self) -> torch.Tensor: k = (self.current_step - self.snapshots[-1][0]) / spacing key = (self.snapshots[-1][0], len(vels), spacing) if self._fit_key != key: - self._fit = _dmd_fit(vels, rank=self.rank, ridge=self.ridge) + self._fit = _dmd_fit(vels, + rank=self.rank, + ridge=self.ridge, + svd_precision=self.svd_precision) self._fit_key = key if self._fit is not None: pred = _dmd_eval(self._fit, k) @@ -300,6 +360,7 @@ def __init__( rank: int = 0, ridge: float = 1e-8, n_derivatives: int = 1, + svd_precision: str = "low", **kwargs, ): """Create a calibrator whose states are keyed by logical tensor names. @@ -308,6 +369,9 @@ def __init__( :param rank: SVD truncation rank for the DMD fit (0 = automatic). :param ridge: Tikhonov regulariser for the DMD fit. :param n_derivatives: Taylor orders for the warm-up fallback ladder. + :param svd_precision: SVD precision mode; ``"low"`` uses randomised + ``torch.svd_lowrank`` (fastest), ``"medium"`` uses default ``gesdd``, + ``"high"`` uses ``driver="gesvd"``. :param kwargs: Additional keyword arguments forwarded to the underlying implementation. """ @@ -315,6 +379,7 @@ def __init__( self.rank = rank self.ridge = ridge self.n_derivatives = n_derivatives + self.svd_precision = svd_precision self.states: Dict[str, DMDState] = {} self.reset_cache() @@ -340,6 +405,7 @@ def maybe_init_state( rank=self.rank, ridge=self.ridge, n_derivatives=self.n_derivatives, + svd_precision=self.svd_precision, ) def mark_step_begin(self, *args, **kwargs): From b7a0b547790f0fac633166d1b0b9f07245500c53 Mon Sep 17 00:00:00 2001 From: DefTruth Date: Mon, 15 Jun 2026 03:59:25 +0000 Subject: [PATCH 2/2] feat: fast svd decompose for dmd calibrator --- src/cache_dit/_utils/utils.py | 4 ++-- .../caching/cache_contexts/calibrators/__init__.py | 13 +++++++------ .../caching/cache_contexts/calibrators/dmd.py | 12 ++++++------ 3 files changed, 15 insertions(+), 14 deletions(-) diff --git a/src/cache_dit/_utils/utils.py b/src/cache_dit/_utils/utils.py index 2baf1c38..71749b6d 100644 --- a/src/cache_dit/_utils/utils.py +++ b/src/cache_dit/_utils/utils.py @@ -326,9 +326,9 @@ def get_args(parse: bool = True, ) -> argparse.ArgumentParser | argparse.Namespa "--dmd-svd-precision", "--dmd-svd", type=str, - default="low", + default="medium", choices=["low", "medium", "high"], - help="DMD SVD precision mode: low (fastest), medium (balanced), high (accurate)", + help="DMD SVD precision mode: medium (default, balanced), low (randomised), high (accurate)", ) parser.add_argument( "--steps-mask", diff --git a/src/cache_dit/caching/cache_contexts/calibrators/__init__.py b/src/cache_dit/caching/cache_contexts/calibrators/__init__.py index 09c6c393..1d70012b 100644 --- a/src/cache_dit/caching/cache_contexts/calibrators/__init__.py +++ b/src/cache_dit/caching/cache_contexts/calibrators/__init__.py @@ -193,11 +193,12 @@ class DMDCalibratorConfig(CalibratorConfig): # dmd_ridge (`float`, *optional*, defaults to 1e-8): # Tikhonov term added to the inverted singular values. dmd_ridge: float = 1e-8 - # dmd_svd_precision (`str`, *optional*, defaults to 'low'): - # SVD precision mode for the DMD snapshot matrix. ``"low"`` uses randomised - # ``torch.svd_lowrank`` (fastest, niter=4, deterministic seed), ``"medium"`` - # uses default ``torch.linalg.svd`` (gesdd), ``"high"`` uses ``driver="gesvd"``. - dmd_svd_precision: str = "low" + # dmd_svd_precision (`str`, *optional*, defaults to 'medium'): + # SVD precision mode for the DMD snapshot matrix. ``"medium"`` uses default + # ``torch.linalg.svd`` (gesdd, balanced), ``"low"`` uses randomised + # ``torch.svd_lowrank`` (niter=1, deterministic seed), ``"high"`` uses + # ``driver="gesvd"``. + dmd_svd_precision: str = "medium" def strify(self, **kwargs) -> str: """Return a compact tag that includes the snapshot-history length and SVD precision. @@ -209,7 +210,7 @@ def strify(self, **kwargs) -> str: prec = self.dmd_svd_precision[0] if kwargs.get("details", False): return f"DMD_H({self.dmd_history}, {self.dmd_svd_precision})" - return f"DMDH{self.dmd_history}{prec}" if prec != "l" else f"DMDH{self.dmd_history}" + return f"DMDH{self.dmd_history}{prec}" if prec != "m" else f"DMDH{self.dmd_history}" def to_kwargs(self) -> Dict: """Translate config fields into `DMDCalibrator` init kwargs. diff --git a/src/cache_dit/caching/cache_contexts/calibrators/dmd.py b/src/cache_dit/caching/cache_contexts/calibrators/dmd.py index c4b0f083..0e231a7c 100644 --- a/src/cache_dit/caching/cache_contexts/calibrators/dmd.py +++ b/src/cache_dit/caching/cache_contexts/calibrators/dmd.py @@ -21,7 +21,7 @@ logger = init_logger(__name__) _SVD_PRECISIONS = ("low", "medium", "high") -_SVD_LOWRANK_NITER = 4 +_SVD_LOWRANK_NITER = 1 _SVD_LOWRANK_SEED = 0 @@ -41,7 +41,7 @@ def _dmd_svd( """Economy SVD for a tall-skinny ``[d, n]`` snapshot matrix. Three precision levels: - - ``"low"``: randomised ``torch.svd_lowrank`` (fastest, niter=4, deterministic seed). + - ``"low"``: randomised ``torch.svd_lowrank`` (fastest, niter=1, deterministic seed). - ``"medium"``: standard ``torch.linalg.svd`` with default ``gesdd``. - ``"high"``: ``torch.linalg.svd`` with ``driver="gesvd"`` for maximum accuracy. @@ -70,7 +70,7 @@ def _dmd_fit_one( traj: torch.Tensor, rank: int = 0, ridge: float = 1e-8, - svd_precision: str = "low", + svd_precision: str = "medium", ): """Fit the DMD eigendecomposition for ONE ``[d, n]`` trajectory (a single batch item's snapshot history, columns OLDEST..NEWEST). @@ -115,7 +115,7 @@ def _dmd_fit( snapshots: List[torch.Tensor], rank: int = 0, ridge: float = 1e-8, - svd_precision: str = "low", + svd_precision: str = "medium", ): """Fit DMD once for a window of >= 4 same-shape snapshots, INDEPENDENTLY per batch item (axis 0). @@ -177,7 +177,7 @@ def __init__( rank: int = 0, ridge: float = 1e-8, n_derivatives: int = 1, - svd_precision: str = "low", + svd_precision: str = "medium", ): """Initialize snapshot buffers and the polynomial-fallback ladder. @@ -360,7 +360,7 @@ def __init__( rank: int = 0, ridge: float = 1e-8, n_derivatives: int = 1, - svd_precision: str = "low", + svd_precision: str = "medium", **kwargs, ): """Create a calibrator whose states are keyed by logical tensor names.