Skip to content

Commit 079bc4c

Browse files
refactor duplicated code
1 parent 31084bd commit 079bc4c

File tree

1 file changed

+58
-33
lines changed

1 file changed

+58
-33
lines changed

src/cloudai/workloads/aiconfig/predictor.py

Lines changed: 58 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616

1717
from __future__ import annotations
1818

19+
import sys
1920
from typing import Any, Dict, Optional, cast
2021

2122

@@ -32,6 +33,46 @@ def _to_enum(enum_cls, value_or_name):
3233
return enum_cls(value_or_name)
3334

3435

36+
def _validate_nextn(nextn: int, nextn_accept_rates: Optional[list[float]]) -> list[float]:
    """
    Check that accept rates accompany a positive `nextn` and normalise them to a list.

    Args:
        nextn: Compared against zero; a positive value makes `nextn_accept_rates` mandatory.
        nextn_accept_rates: Accept rates supplied by the caller, or None when not given.

    Returns:
        `nextn_accept_rates` itself when it is a non-empty list, otherwise an empty list.

    Raises:
        ValueError: If `nextn > 0` and `nextn_accept_rates` is None.
    """
    if nextn > 0 and nextn_accept_rates is None:
        raise ValueError("nextn_accept_rates must be provided when nextn > 0")
    # Both None and an empty list collapse to [] so callers always receive a list.
    return nextn_accept_rates or []
40+
41+
42+
def _ensure_aiconfigurator_available(*, need_inference_session: bool) -> dict[str, Any]:
    """
    Import required aiconfigurator symbols or raise a consistent ModuleNotFoundError.

    Returns a dict of imported symbols so call sites can stay concise.

    Args:
        need_inference_session: When True, `InferenceSession` is imported as well; when False
            the import is skipped and the "InferenceSession" entry of the result is None.

    Returns:
        A dict with the keys "common", "get_backend", "ModelConfig", "RuntimeConfig",
        "get_model", "get_database" and "InferenceSession".

    Raises:
        ModuleNotFoundError: If any of the imports raises ModuleNotFoundError. The new error
            names the current interpreter path and is chained to the original exception.
    """
    try:
        from aiconfigurator.sdk import common
        from aiconfigurator.sdk.backends.factory import get_backend
        from aiconfigurator.sdk.config import ModelConfig, RuntimeConfig
        from aiconfigurator.sdk.models import get_model
        from aiconfigurator.sdk.perf_database import get_database

        if need_inference_session:
            from aiconfigurator.sdk.inference_session import InferenceSession
        else:
            # Keep the key present in the result even when the import is skipped.
            InferenceSession = None  # type: ignore[assignment]
    except ModuleNotFoundError as e:
        # Report the interpreter in use so the package can be installed into the right environment.
        raise ModuleNotFoundError(
            "Missing dependency 'aiconfigurator'. Install it in the Python environment used for this test. "
            f"(python={sys.executable})"
        ) from e

    return {
        "common": common,
        "get_backend": get_backend,
        "ModelConfig": ModelConfig,
        "RuntimeConfig": RuntimeConfig,
        "get_model": get_model,
        "get_database": get_database,
        "InferenceSession": InferenceSession,
    }
74+
75+
3576
def predict_ifb_single(
3677
*,
3778
model_name: str,
@@ -62,28 +103,20 @@ def predict_ifb_single(
62103
overwrite_num_layers: int = 0,
63104
) -> Dict[str, Any]:
64105
"""Predict metrics for a single IFB configuration using the aiconfigurator SDK primitives."""
65-
try:
66-
from aiconfigurator.sdk import common
67-
from aiconfigurator.sdk.backends.factory import get_backend
68-
from aiconfigurator.sdk.config import ModelConfig, RuntimeConfig
69-
from aiconfigurator.sdk.models import get_model
70-
from aiconfigurator.sdk.perf_database import get_database
71-
except ModuleNotFoundError as e:
72-
import sys as _sys
73-
74-
raise ModuleNotFoundError(
75-
"Missing dependency 'aiconfigurator'. Install it in the Python environment used for this test. "
76-
f"(python={_sys.executable})"
77-
) from e
106+
syms = _ensure_aiconfigurator_available(need_inference_session=False)
107+
common = syms["common"]
108+
get_backend = syms["get_backend"]
109+
ModelConfig = syms["ModelConfig"]
110+
RuntimeConfig = syms["RuntimeConfig"]
111+
get_model = syms["get_model"]
112+
get_database = syms["get_database"]
78113

79114
database = get_database(system=system, backend=backend, version=version)
80115
if database is None:
81116
raise ValueError(f"No perf database found for system={system} backend={backend} version={version}")
82117
backend_impl = cast(Any, get_backend(backend))
83118

84-
if nextn > 0 and nextn_accept_rates is None:
85-
raise ValueError("nextn_accept_rates must be provided when nextn > 0")
86-
accept_rates = cast(Any, nextn_accept_rates or [])
119+
accept_rates = _validate_nextn(nextn, nextn_accept_rates)
87120

88121
mc = ModelConfig(
89122
tp_size=tp,
@@ -164,20 +197,14 @@ def predict_disagg_single(
164197
decode_correction_scale: float = 1.0,
165198
) -> Dict[str, Any]:
166199
"""Predict metrics for a single disaggregated configuration (explicit prefill/decode workers)."""
167-
try:
168-
from aiconfigurator.sdk import common
169-
from aiconfigurator.sdk.backends.factory import get_backend
170-
from aiconfigurator.sdk.config import ModelConfig, RuntimeConfig
171-
from aiconfigurator.sdk.inference_session import InferenceSession
172-
from aiconfigurator.sdk.models import get_model
173-
from aiconfigurator.sdk.perf_database import get_database
174-
except ModuleNotFoundError as e:
175-
import sys as _sys
176-
177-
raise ModuleNotFoundError(
178-
"Missing dependency 'aiconfigurator'. Install it in the Python environment used for this test. "
179-
f"(python={_sys.executable})"
180-
) from e
200+
syms = _ensure_aiconfigurator_available(need_inference_session=True)
201+
common = syms["common"]
202+
get_backend = syms["get_backend"]
203+
ModelConfig = syms["ModelConfig"]
204+
RuntimeConfig = syms["RuntimeConfig"]
205+
get_model = syms["get_model"]
206+
get_database = syms["get_database"]
207+
InferenceSession = syms["InferenceSession"]
181208

182209
prefill_db = get_database(system=system, backend=backend, version=version)
183210
decode_db = get_database(system=system, backend=backend, version=version)
@@ -187,9 +214,7 @@ def predict_disagg_single(
187214
prefill_backend = cast(Any, get_backend(backend))
188215
decode_backend = cast(Any, get_backend(backend))
189216

190-
if nextn > 0 and nextn_accept_rates is None:
191-
raise ValueError("nextn_accept_rates must be provided when nextn > 0")
192-
accept_rates = cast(Any, nextn_accept_rates or [])
217+
accept_rates = _validate_nextn(nextn, nextn_accept_rates)
193218

194219
p_mc = ModelConfig(
195220
tp_size=p_tp,

0 commit comments

Comments
 (0)