Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 12 additions & 3 deletions src/cache_dit/_utils/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -322,6 +322,14 @@ def get_args(parse: bool = True, ) -> argparse.ArgumentParser | argparse.Namespa
default=6,
help="DMD snapshot-history window length",
)
parser.add_argument(
"--dmd-svd-precision",
"--dmd-svd",
type=str,
default="medium",
choices=["low", "medium", "high"],
help="DMD SVD precision mode: medium (default, balanced), low (randomised), high (accurate)",
)
parser.add_argument(
"--steps-mask",
action="store_true",
Expand Down Expand Up @@ -2295,9 +2303,10 @@ def _prepare_distributed_size():
force_refresh_step_hint=kwargs.get("force_refresh_step_hint", None),
force_refresh_step_policy=kwargs.get("force_refresh_step_policy", "once"),
) if cache_config is None and args.cache else cache_config),
calibrator_config=(DMDCalibratorConfig(
dmd_history=args.dmd_history) if args.dmd else TaylorSeerCalibratorConfig(
taylorseer_order=args.taylorseer_order) if args.taylorseer else None),
calibrator_config=(DMDCalibratorConfig(dmd_history=args.dmd_history,
dmd_svd_precision=args.dmd_svd_precision)
if args.dmd else TaylorSeerCalibratorConfig(
taylorseer_order=args.taylorseer_order) if args.taylorseer else None),
params_modifiers=kwargs.get("params_modifiers", None),
parallelism_config=(ParallelismConfig(
ulysses_size=ulysses_size,
Expand Down
14 changes: 11 additions & 3 deletions src/cache_dit/caching/cache_contexts/calibrators/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,17 +193,24 @@ class DMDCalibratorConfig(CalibratorConfig):
# dmd_ridge (`float`, *optional*, defaults to 1e-8):
# Tikhonov term added to the inverted singular values.
dmd_ridge: float = 1e-8
# dmd_svd_precision (`str`, *optional*, defaults to 'medium'):
# SVD precision mode for the DMD snapshot matrix. ``"medium"`` uses default
# ``torch.linalg.svd`` (gesdd, balanced), ``"low"`` uses randomised
# ``torch.svd_lowrank`` (niter=1, deterministic seed), ``"high"`` uses
# ``driver="gesvd"``.
dmd_svd_precision: str = "medium"

def strify(self, **kwargs) -> str:
    """Return a compact tag that includes the snapshot-history length and SVD precision.

    :param kwargs: Optional flags. Only ``details`` is read here: when truthy, the
        verbose form with the full precision name is returned.
    :returns: A compact DMD tag for logs, summaries, or filenames.
    """

    # First letter of the precision mode ("l", "m" or "h") is the compact suffix.
    prec = self.dmd_svd_precision[0]
    if kwargs.get("details", False):
        return f"DMD_H({self.dmd_history}, {self.dmd_svd_precision})"
    # "medium" is the default, so it adds no suffix to the compact tag.
    return f"DMDH{self.dmd_history}{prec}" if prec != "m" else f"DMDH{self.dmd_history}"

def to_kwargs(self) -> Dict:
"""Translate config fields into `DMDCalibrator` init kwargs.
Expand All @@ -215,6 +222,7 @@ def to_kwargs(self) -> Dict:
kwargs["history"] = self.dmd_history
kwargs["rank"] = self.dmd_rank
kwargs["ridge"] = self.dmd_ridge
kwargs["svd_precision"] = self.dmd_svd_precision
return kwargs


Expand Down
72 changes: 69 additions & 3 deletions src/cache_dit/caching/cache_contexts/calibrators/dmd.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,57 @@

logger = init_logger(__name__)

# Accepted values for ``svd_precision``.
_SVD_PRECISIONS = ("low", "medium", "high")
# ``niter`` handed to ``torch.svd_lowrank`` in "low" mode.
_SVD_LOWRANK_NITER = 1
# Fixed seed so the randomised "low" mode is repeatable for a given input.
_SVD_LOWRANK_SEED = 0


def _svd_rng_devices(X: torch.Tensor) -> list[int]:
    """Return the CUDA device indices whose RNG state must be forked for ``X``.

    :param X: Tensor that is about to be decomposed.
    :returns: ``[]`` for non-CUDA tensors; otherwise a one-element list with the
        index of ``X``'s device (the current CUDA device when the tensor's device
        carries no explicit index).
    """

    if X.device.type != "cuda":
        return []
    idx = X.device.index
    if idx is None:
        idx = torch.cuda.current_device()
    return [idx]


def _dmd_svd(
    X: torch.Tensor,
    svd_precision: str,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Economy SVD for a tall-skinny ``[d, n]`` snapshot matrix.

    Three precision levels:
    - ``"low"``: randomised ``torch.svd_lowrank`` (fastest, niter=1, deterministic seed).
    - ``"medium"``: standard ``torch.linalg.svd`` with the default driver.
    - ``"high"``: ``torch.linalg.svd`` with ``driver="gesvd"`` on CUDA inputs. The
      ``driver`` keyword is rejected for non-CUDA tensors, so those fall back to
      the default driver instead of raising.

    :param X: Tall-skinny input matrix of shape ``[d, n]``.
    :param svd_precision: One of ``("low", "medium", "high")``.
    :returns: ``(U, S, Vh)`` — economy decomposition where ``U`` is ``[d, n]``,
        ``S`` is ``[n]``, and ``Vh`` is ``[n, n]``.
    :raises ValueError: If ``svd_precision`` is not a supported mode.
    """

    # Explicit raise rather than ``assert`` so the check survives ``python -O``.
    if svd_precision not in _SVD_PRECISIONS:
        raise ValueError(
            f"svd_precision must be one of {_SVD_PRECISIONS}, got {svd_precision!r}")

    if svd_precision == "low":
        n = X.shape[-1]
        # Fork the RNG so seeding for the random projection does not disturb the
        # caller's random stream on the CPU / the tensor's own device.
        # NOTE(review): ``torch.manual_seed`` is documented to seed all devices,
        # while only the device listed here is restored — confirm that is
        # acceptable for multi-GPU processes.
        with torch.random.fork_rng(devices=_svd_rng_devices(X)):
            torch.manual_seed(_SVD_LOWRANK_SEED)
            U, S, V = torch.svd_lowrank(X, q=n, niter=_SVD_LOWRANK_NITER)
        # ``svd_lowrank`` returns V, not its conjugate transpose.
        Vh = V.mH
    elif svd_precision == "high" and X.is_cuda:
        U, S, Vh = torch.linalg.svd(X, full_matrices=False, driver="gesvd")
    else:  # "medium", or "high" on a device where ``driver=`` is unsupported
        U, S, Vh = torch.linalg.svd(X, full_matrices=False)
    return U, S, Vh


def _dmd_fit_one(
traj: torch.Tensor,
rank: int = 0,
ridge: float = 1e-8,
svd_precision: str = "medium",
):
"""Fit the DMD eigendecomposition for ONE ``[d, n]`` trajectory (a single batch item's snapshot
history, columns OLDEST..NEWEST).
Expand All @@ -37,13 +83,16 @@ def _dmd_fit_one(
:param rank: SVD truncation rank; 0 selects it from the spectrum (drop modes
below 1e-4 of the leading singular value — this is what rejects noise).
:param ridge: Tikhonov term added to the inverted singular values.
:param svd_precision: SVD precision mode; ``"low"`` uses randomised
``torch.svd_lowrank`` (fastest), ``"medium"`` uses default ``gesdd``,
``"high"`` uses ``driver="gesvd"``.
:returns: ``(Phi, evals, b)`` or ``None`` on a degenerate fit (caller reuses
the last value).
"""

X, Xp = traj[:, :-1], traj[:, 1:]
try:
U, S, Vh = torch.linalg.svd(X, full_matrices=False)
U, S, Vh = _dmd_svd(X, svd_precision)
except Exception: # noqa: BLE001 — degenerate fit: caller falls back to reuse
return None
r = rank
Expand All @@ -66,6 +115,7 @@ def _dmd_fit(
snapshots: List[torch.Tensor],
rank: int = 0,
ridge: float = 1e-8,
svd_precision: str = "medium",
):
"""Fit DMD once for a window of >= 4 same-shape snapshots, INDEPENDENTLY per batch item (axis 0).

Expand All @@ -87,7 +137,9 @@ def _dmd_fit(
bsz = shp[0] if newest.dim() > 1 else 1
# (B, d, n): per-item trajectories; B == 1 reproduces the un-batched fit.
V = torch.stack([s.reshape(bsz, -1) for s in snapshots], dim=-1).to(torch.float64)
fits = [_dmd_fit_one(V[i], rank=rank, ridge=ridge) for i in range(bsz)]
fits = [
_dmd_fit_one(V[i], rank=rank, ridge=ridge, svd_precision=svd_precision) for i in range(bsz)
]
return (fits, shp, dt)


Expand Down Expand Up @@ -125,6 +177,7 @@ def __init__(
rank: int = 0,
ridge: float = 1e-8,
n_derivatives: int = 1,
svd_precision: str = "medium",
):
"""Initialize snapshot buffers and the polynomial-fallback ladder.

Expand All @@ -135,12 +188,16 @@ def __init__(
:param rank: SVD truncation rank for the DMD fit (0 = automatic).
:param ridge: Tikhonov regulariser for the DMD fit.
:param n_derivatives: Taylor orders kept for the warm-up fallback.
:param svd_precision: SVD precision mode; ``"low"`` uses randomised
``torch.svd_lowrank`` (fastest), ``"medium"`` uses default ``gesdd``,
``"high"`` uses ``driver="gesvd"``.
"""

self.history = history
self.rank = rank
self.ridge = ridge
self.n_derivatives = n_derivatives
self.svd_precision = svd_precision
self.order = n_derivatives + 1
self.current_step = -1
self.last_non_approximated_step = -1
Expand Down Expand Up @@ -267,7 +324,10 @@ def approximate(self) -> torch.Tensor:
k = (self.current_step - self.snapshots[-1][0]) / spacing
key = (self.snapshots[-1][0], len(vels), spacing)
if self._fit_key != key:
self._fit = _dmd_fit(vels, rank=self.rank, ridge=self.ridge)
self._fit = _dmd_fit(vels,
rank=self.rank,
ridge=self.ridge,
svd_precision=self.svd_precision)
self._fit_key = key
if self._fit is not None:
pred = _dmd_eval(self._fit, k)
Expand Down Expand Up @@ -300,6 +360,7 @@ def __init__(
rank: int = 0,
ridge: float = 1e-8,
n_derivatives: int = 1,
svd_precision: str = "medium",
**kwargs,
):
"""Create a calibrator whose states are keyed by logical tensor names.
Expand All @@ -308,13 +369,17 @@ def __init__(
:param rank: SVD truncation rank for the DMD fit (0 = automatic).
:param ridge: Tikhonov regulariser for the DMD fit.
:param n_derivatives: Taylor orders for the warm-up fallback ladder.
:param svd_precision: SVD precision mode; ``"low"`` uses randomised
``torch.svd_lowrank`` (fastest), ``"medium"`` uses default ``gesdd``,
``"high"`` uses ``driver="gesvd"``.
:param kwargs: Additional keyword arguments forwarded to the underlying implementation.
"""

self.history = history
self.rank = rank
self.ridge = ridge
self.n_derivatives = n_derivatives
self.svd_precision = svd_precision
self.states: Dict[str, DMDState] = {}
self.reset_cache()

Expand All @@ -340,6 +405,7 @@ def maybe_init_state(
rank=self.rank,
ridge=self.ridge,
n_derivatives=self.n_derivatives,
svd_precision=self.svd_precision,
)

def mark_step_begin(self, *args, **kwargs):
Expand Down
Loading