Skip to content

Commit 3df302e

Browse files
authored
[quantization] Fix internal test (#722)
This commit fixes internal tests. TICO-DCO-1.0-Signed-off-by: seongwoo <mhs4670go@naver.com>
1 parent 806fb41 commit 3df302e

3 files changed

Lines changed: 68 additions & 8 deletions

File tree

test/quantization/algorithm/test_gptq.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -391,7 +391,7 @@ def test_model(self):
391391
)
392392

393393
# Load data
394-
dataset = load_dataset("wikiText", "wikitext-2-raw-v1", split="train")
394+
dataset = load_dataset("wikitext", "wikitext-2-raw-v1", split="train")
395395
sample_input = tokenizer(dataset[0]["text"], return_tensors="pt").input_ids
396396

397397
# base

test/quantization/algorithm/test_smooth_quant.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ def test_value(self):
4949
)
5050

5151
# Load data
52-
dataset = load_dataset("wikiText", "wikitext-2-raw-v1", split="train")
52+
dataset = load_dataset("wikitext", "wikitext-2-raw-v1", split="train")
5353
sample_input = tokenizer(dataset[0]["text"], return_tensors="pt").input_ids
5454

5555
# base

test/quantization/recipes/optional_dependency_stubs.py

Lines changed: 66 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -21,29 +21,69 @@
2121
make the tests heavier than necessary.
2222
"""
2323

24+
import importlib
2425
import sys
2526
import types
2627

28+
_STUB_MARKER = "__tico_optional_dependency_stub__"
29+
2730

2831
def install_optional_dependency_stubs() -> None:
2932
"""Install lightweight stubs for optional recipe dependencies."""
3033
_install_datasets_stub()
3134
_install_lm_eval_stub()
3235

3336

37+
def _has_attrs(module: types.ModuleType, names: tuple[str, ...]) -> bool:
38+
"""Return True if a module has all required attributes."""
39+
return all(hasattr(module, name) for name in names)
40+
41+
42+
def _is_our_stub(module: types.ModuleType | None) -> bool:
43+
"""Return True if a module was installed by this stub helper."""
44+
return bool(module is not None and getattr(module, _STUB_MARKER, False))
45+
46+
47+
def _try_import_optional_module(module_name: str) -> types.ModuleType | None:
48+
"""
49+
Import an optional module if it is available.
50+
51+
Missing optional modules are tolerated. Import failures caused by missing
52+
transitive dependencies are re-raised so broken real installations are not
53+
silently hidden by test stubs.
54+
"""
55+
try:
56+
return importlib.import_module(module_name)
57+
except ModuleNotFoundError as exc:
58+
top_level_name = module_name.partition(".")[0]
59+
if exc.name in {module_name, top_level_name}:
60+
return None
61+
raise
62+
63+
3464
def _install_datasets_stub() -> None:
3565
"""Install a minimal datasets module when the real package is unavailable."""
36-
if "datasets" in sys.modules and all(
37-
hasattr(sys.modules["datasets"], name)
38-
for name in ("Dataset", "IterableDataset", "load_dataset")
66+
required_attrs = ("Dataset", "IterableDataset", "load_dataset")
67+
existing_module = sys.modules.get("datasets")
68+
69+
if (
70+
existing_module is not None
71+
and not _is_our_stub(existing_module)
72+
and _has_attrs(existing_module, required_attrs)
3973
):
4074
return
4175

76+
real_module = _try_import_optional_module("datasets")
77+
if real_module is not None and _has_attrs(real_module, required_attrs):
78+
return
79+
4280
module = sys.modules.get("datasets")
43-
if module is None:
81+
if module is None or not _is_our_stub(module):
4482
module = types.ModuleType("datasets")
4583
sys.modules["datasets"] = module
4684

85+
setattr(module, _STUB_MARKER, True)
86+
4787
class Dataset:
4888
"""Minimal datasets.Dataset stub for import-time compatibility."""
4989

@@ -74,15 +114,31 @@ def load_dataset(*args, **kwargs):
74114

75115
def _install_lm_eval_stub() -> None:
76116
"""Install minimal lm_eval modules when the real package is unavailable."""
77-
if "lm_eval" in sys.modules and hasattr(sys.modules["lm_eval"], "evaluator"):
117+
existing_module = sys.modules.get("lm_eval")
118+
119+
if (
120+
existing_module is not None
121+
and not _is_our_stub(existing_module)
122+
and hasattr(existing_module, "evaluator")
123+
):
78124
return
79125

126+
real_module = _try_import_optional_module("lm_eval")
127+
if real_module is not None and not _is_our_stub(real_module):
128+
real_evaluator_module = _try_import_optional_module("lm_eval.evaluator")
129+
if real_evaluator_module is not None:
130+
setattr(real_module, "evaluator", real_evaluator_module)
131+
return
132+
80133
lm_eval_module = sys.modules.get("lm_eval")
81-
if lm_eval_module is None:
134+
if lm_eval_module is None or not _is_our_stub(lm_eval_module):
82135
lm_eval_module = types.ModuleType("lm_eval")
83136
sys.modules["lm_eval"] = lm_eval_module
84137

138+
setattr(lm_eval_module, _STUB_MARKER, True)
139+
85140
evaluator_module = types.ModuleType("lm_eval.evaluator")
141+
setattr(evaluator_module, _STUB_MARKER, True)
86142

87143
def simple_evaluate(*args, **kwargs):
88144
"""Fail clearly if a test accidentally runs real lm-eval."""
@@ -96,6 +152,7 @@ def simple_evaluate(*args, **kwargs):
96152
setattr(lm_eval_module, "evaluator", evaluator_module)
97153

98154
utils_module = types.ModuleType("lm_eval.utils")
155+
setattr(utils_module, _STUB_MARKER, True)
99156

100157
def make_table(results):
101158
"""Return a stable string representation for patched evaluation results."""
@@ -106,7 +163,10 @@ def make_table(results):
106163
setattr(lm_eval_module, "utils", utils_module)
107164

108165
models_module = types.ModuleType("lm_eval.models")
166+
setattr(models_module, _STUB_MARKER, True)
167+
109168
huggingface_module = types.ModuleType("lm_eval.models.huggingface")
169+
setattr(huggingface_module, _STUB_MARKER, True)
110170

111171
class HFLM:
112172
"""Minimal HFLM stub for import-time compatibility."""

0 commit comments

Comments
 (0)