Skip to content

Commit 9f6e8fd

Browse files
authored
Fix nested use_cache disabling for calibration (#1704)
### What does this PR do? Type of change: Bug fix Extends the calibration/memory-probe `use_cache` guard to Step 3.7-style nested text configs. Step 3.7 remote code reads the language config under `model.config.text_config` directly and raises `AttributeError` when `use_cache` is absent during PTQ calibration with Transformers >5. This keeps the existing Step 3.5 behavior and applies the same temporary set/restore logic to the nested text config. ### Usage No API change. PTQ calibration continues to use the existing forward-loop path. ### Testing - `pre-commit run ruff-format --files modelopt/torch/utils/dataset_utils.py tests/unit/torch/utils/test_dataset_utils.py` - `pre-commit run ruff-check --files modelopt/torch/utils/dataset_utils.py tests/unit/torch/utils/test_dataset_utils.py` - `python -m py_compile modelopt/torch/utils/dataset_utils.py tests/unit/torch/utils/test_dataset_utils.py` - `python -m pytest tests/unit/torch/utils/test_dataset_utils.py -k "disable_use_cache or iter_use_cache_configs or forward_loop_runs_under_disabled" -vv` ### Before your PR is "*Ready for review*" - Is this change backward compatible?: ✅ - If you copied code from any other sources or added a new PIP dependency, did you follow guidance in `CONTRIBUTING.md`: N/A - Did you write any new necessary tests?: ✅ - Did you update [Changelog](https://github.com/NVIDIA/Model-Optimizer/blob/main/CHANGELOG.rst)?: N/A - Did you get Claude approval on this PR?: N/A ### Additional Information This is separate from PR #1693. Step 3.7 needs both fixes if both failure paths are exercised: this PR fixes PTQ calibration-time `use_cache` handling, while PR #1693 fixes exported config `layer_types` metadata for deployment config loading. <!-- This is an auto-generated comment: release notes by coderabbit.ai --> ## Summary by CodeRabbit * **Bug Fixes** * Improved handling of cache flags stored in nested model configuration objects: cache is reliably disabled during dataset operations and restored or removed afterward. * **Tests** * Added unit tests covering nested-config disabling, restoration/removal of cache flags post-operation, and deduplication when nested configs reference the same object. <!-- end of auto-generated comment: release notes by coderabbit.ai --> Signed-off-by: weimingc <17592131+meenchen@users.noreply.github.com>
1 parent c4f39bd commit 9f6e8fd

2 files changed

Lines changed: 77 additions & 14 deletions

File tree

modelopt/torch/utils/dataset_utils.py

Lines changed: 37 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -920,9 +920,29 @@ def get_supported_datasets() -> list[str]:
920920
return list(SUPPORTED_DATASET_CONFIG.keys()) + list(DATASET_COMBOS.keys())
921921

922922

923+
_NESTED_USE_CACHE_CONFIG_ATTRS = ("text_config",)
924+
925+
926+
def _iter_use_cache_configs(model: torch.nn.Module) -> Iterator[Any]:
927+
"""Yield the top-level config and Step3.7-style nested text config."""
928+
seen: set[int] = set()
929+
config = getattr(model, "config", None)
930+
if config is None:
931+
return
932+
933+
for candidate in (
934+
config,
935+
*(getattr(config, attr, None) for attr in _NESTED_USE_CACHE_CONFIG_ATTRS),
936+
):
937+
if candidate is None or id(candidate) in seen:
938+
continue
939+
seen.add(id(candidate))
940+
yield candidate
941+
942+
923943
@contextmanager
924944
def _disable_use_cache(model: torch.nn.Module) -> Iterator[None]:
925-
"""Set ``model.config.use_cache = False`` for the duration of the block.
945+
"""Set model config ``use_cache`` flags to ``False`` for the duration of the block.
926946
927947
KV caching is unwanted during calibration / memory-probe forward passes:
928948
it wastes memory, and for hybrid Mamba/attention models (e.g., NemotronH)
@@ -931,23 +951,26 @@ def _disable_use_cache(model: torch.nn.Module) -> Iterator[None]:
931951
present) also sidesteps configs that never assign the attribute at all
932952
— e.g., ``Step3p5Config`` from stepfun-ai/Step-3.5-Flash — where forward
933953
code that reads ``self.config.use_cache`` would otherwise raise
934-
``AttributeError``. The prior value is restored on exit if one existed.
954+
``AttributeError``. Step3.7 keeps the relevant language config nested
955+
under ``text_config``; that config object is handled the same way. The
956+
prior value is restored on exit if one existed.
935957
"""
936-
config = getattr(model, "config", None)
937-
if config is None:
938-
yield
939-
return
940-
had_attr = hasattr(config, "use_cache")
941-
prev = config.use_cache if had_attr else None
942-
config.use_cache = False
958+
states = []
959+
for config in _iter_use_cache_configs(model):
960+
had_attr = hasattr(config, "use_cache")
961+
prev = config.use_cache if had_attr else None
962+
config.use_cache = False
963+
states.append((config, had_attr, prev))
964+
943965
try:
944966
yield
945967
finally:
946-
if had_attr:
947-
config.use_cache = prev
948-
else:
949-
with suppress(AttributeError):
950-
delattr(config, "use_cache")
968+
for config, had_attr, prev in reversed(states):
969+
if had_attr:
970+
config.use_cache = prev
971+
else:
972+
with suppress(AttributeError):
973+
delattr(config, "use_cache")
951974

952975

953976
def get_max_batch_size(

tests/unit/torch/utils/test_dataset_utils.py

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
DATASET_COMBOS,
2626
_disable_use_cache,
2727
_forward_loop,
28+
_iter_use_cache_configs,
2829
_pack_documents_into_rows,
2930
_process_batch,
3031
get_dataset_dataloader,
@@ -222,6 +223,45 @@ def test_disable_use_cache_without_existing_attr():
222223
assert not hasattr(model.config, "use_cache")
223224

224225

226+
@pytest.mark.parametrize("prev_value", [True, False])
227+
def test_disable_use_cache_with_nested_text_config_existing_attr(prev_value):
228+
"""Nested text config `use_cache` is disabled and restored."""
229+
model = torch.nn.Linear(4, 4)
230+
model.config = _Config()
231+
model.config.text_config = _Config()
232+
model.config.text_config.use_cache = prev_value
233+
234+
with _disable_use_cache(model):
235+
assert model.config.use_cache is False
236+
assert model.config.text_config.use_cache is False
237+
238+
assert not hasattr(model.config, "use_cache")
239+
assert model.config.text_config.use_cache is prev_value
240+
241+
242+
def test_disable_use_cache_with_nested_text_config_without_existing_attr():
243+
"""Nested text config `use_cache` is removed if it was added by the context."""
244+
model = torch.nn.Linear(4, 4)
245+
model.config = _Config()
246+
model.config.text_config = _Config()
247+
248+
with _disable_use_cache(model):
249+
assert model.config.use_cache is False
250+
assert model.config.text_config.use_cache is False
251+
252+
assert not hasattr(model.config, "use_cache")
253+
assert not hasattr(model.config.text_config, "use_cache")
254+
255+
256+
def test_iter_use_cache_configs_deduplicates_text_config_alias():
257+
"""The same config object is patched once if `config.text_config is config`."""
258+
model = torch.nn.Linear(4, 4)
259+
model.config = _Config()
260+
model.config.text_config = model.config
261+
262+
assert list(_iter_use_cache_configs(model)) == [model.config]
263+
264+
225265
def test_forward_loop_runs_under_disabled_use_cache():
226266
"""`_forward_loop` runs forward on every batch and restores `use_cache` on exit."""
227267
seen_use_cache: list[bool] = []

0 commit comments

Comments
 (0)