1616
1717from __future__ import annotations
1818
19+ import sys
1920from typing import Any , Dict , Optional , cast
2021
2122
@@ -32,6 +33,46 @@ def _to_enum(enum_cls, value_or_name):
3233 return enum_cls (value_or_name )
3334
3435
def _validate_nextn(nextn: int, nextn_accept_rates: Optional[list[float]]) -> list[float]:
    """Return the accept rates to use, insisting on them when ``nextn`` is positive.

    Args:
        nextn: value compared against zero to decide whether accept rates
            are mandatory.
        nextn_accept_rates: caller-supplied accept rates, or ``None``.

    Returns:
        ``nextn_accept_rates`` itself when it is non-empty, otherwise a new
        empty list.

    Raises:
        ValueError: if ``nextn > 0`` and ``nextn_accept_rates`` is ``None``.
    """
    if nextn > 0:
        # Accept rates are only mandatory when nextn is positive.
        if nextn_accept_rates is None:
            raise ValueError("nextn_accept_rates must be provided when nextn > 0")

    if nextn_accept_rates:
        return nextn_accept_rates
    # None and empty inputs both normalize to a fresh empty list.
    return []
40+
41+
def _ensure_aiconfigurator_available(*, need_inference_session: bool) -> dict[str, Any]:
    """
    Import required aiconfigurator symbols or raise a consistent ModuleNotFoundError.

    Args:
        need_inference_session: when true, ``InferenceSession`` is imported as
            well; when false, the ``"InferenceSession"`` entry is ``None``.

    Returns:
        A dict of imported symbols so call sites can stay concise. Keys are
        ``"common"``, ``"get_backend"``, ``"ModelConfig"``, ``"RuntimeConfig"``,
        ``"get_model"``, ``"get_database"`` and ``"InferenceSession"``.

    Raises:
        ModuleNotFoundError: if any of the imports fails because a module is
            missing. The message names the module that is actually missing and
            the interpreter in use; the original error is chained as the cause.
    """
    try:
        from aiconfigurator.sdk import common
        from aiconfigurator.sdk.backends.factory import get_backend
        from aiconfigurator.sdk.config import ModelConfig, RuntimeConfig
        from aiconfigurator.sdk.models import get_model
        from aiconfigurator.sdk.perf_database import get_database

        if need_inference_session:
            from aiconfigurator.sdk.inference_session import InferenceSession
        else:
            InferenceSession = None  # type: ignore[assignment]
    except ModuleNotFoundError as e:
        env_hint = (
            "Install it in the Python environment used for this test. "
            f"(python={sys.executable})"
        )
        # e.name is the module that could not be found; it may be None when the
        # error was raised without one, in which case keep the generic message.
        missing = e.name or "aiconfigurator"
        if missing == "aiconfigurator" or missing.startswith("aiconfigurator."):
            raise ModuleNotFoundError(f"Missing dependency 'aiconfigurator'. {env_hint}") from e
        # A different module failed while importing aiconfigurator: report that
        # module instead of blaming aiconfigurator itself.
        raise ModuleNotFoundError(
            f"Missing dependency {missing!r} (needed while importing 'aiconfigurator'). {env_hint}",
            name=missing,
        ) from e

    return {
        "common": common,
        "get_backend": get_backend,
        "ModelConfig": ModelConfig,
        "RuntimeConfig": RuntimeConfig,
        "get_model": get_model,
        "get_database": get_database,
        "InferenceSession": InferenceSession,
    }
74+
75+
3576def predict_ifb_single (
3677 * ,
3778 model_name : str ,
@@ -62,28 +103,20 @@ def predict_ifb_single(
62103 overwrite_num_layers : int = 0 ,
63104) -> Dict [str , Any ]:
64105 """Predict metrics for a single IFB configuration using the aiconfigurator SDK primitives."""
65- try :
66- from aiconfigurator .sdk import common
67- from aiconfigurator .sdk .backends .factory import get_backend
68- from aiconfigurator .sdk .config import ModelConfig , RuntimeConfig
69- from aiconfigurator .sdk .models import get_model
70- from aiconfigurator .sdk .perf_database import get_database
71- except ModuleNotFoundError as e :
72- import sys as _sys
73-
74- raise ModuleNotFoundError (
75- "Missing dependency 'aiconfigurator'. Install it in the Python environment used for this test. "
76- f"(python={ _sys .executable } )"
77- ) from e
106+ syms = _ensure_aiconfigurator_available (need_inference_session = False )
107+ common = syms ["common" ]
108+ get_backend = syms ["get_backend" ]
109+ ModelConfig = syms ["ModelConfig" ]
110+ RuntimeConfig = syms ["RuntimeConfig" ]
111+ get_model = syms ["get_model" ]
112+ get_database = syms ["get_database" ]
78113
79114 database = get_database (system = system , backend = backend , version = version )
80115 if database is None :
81116 raise ValueError (f"No perf database found for system={ system } backend={ backend } version={ version } " )
82117 backend_impl = cast (Any , get_backend (backend ))
83118
84- if nextn > 0 and nextn_accept_rates is None :
85- raise ValueError ("nextn_accept_rates must be provided when nextn > 0" )
86- accept_rates = cast (Any , nextn_accept_rates or [])
119+ accept_rates = _validate_nextn (nextn , nextn_accept_rates )
87120
88121 mc = ModelConfig (
89122 tp_size = tp ,
@@ -164,20 +197,14 @@ def predict_disagg_single(
164197 decode_correction_scale : float = 1.0 ,
165198) -> Dict [str , Any ]:
166199 """Predict metrics for a single disaggregated configuration (explicit prefill/decode workers)."""
167- try :
168- from aiconfigurator .sdk import common
169- from aiconfigurator .sdk .backends .factory import get_backend
170- from aiconfigurator .sdk .config import ModelConfig , RuntimeConfig
171- from aiconfigurator .sdk .inference_session import InferenceSession
172- from aiconfigurator .sdk .models import get_model
173- from aiconfigurator .sdk .perf_database import get_database
174- except ModuleNotFoundError as e :
175- import sys as _sys
176-
177- raise ModuleNotFoundError (
178- "Missing dependency 'aiconfigurator'. Install it in the Python environment used for this test. "
179- f"(python={ _sys .executable } )"
180- ) from e
200+ syms = _ensure_aiconfigurator_available (need_inference_session = True )
201+ common = syms ["common" ]
202+ get_backend = syms ["get_backend" ]
203+ ModelConfig = syms ["ModelConfig" ]
204+ RuntimeConfig = syms ["RuntimeConfig" ]
205+ get_model = syms ["get_model" ]
206+ get_database = syms ["get_database" ]
207+ InferenceSession = syms ["InferenceSession" ]
181208
182209 prefill_db = get_database (system = system , backend = backend , version = version )
183210 decode_db = get_database (system = system , backend = backend , version = version )
@@ -187,9 +214,7 @@ def predict_disagg_single(
187214 prefill_backend = cast (Any , get_backend (backend ))
188215 decode_backend = cast (Any , get_backend (backend ))
189216
190- if nextn > 0 and nextn_accept_rates is None :
191- raise ValueError ("nextn_accept_rates must be provided when nextn > 0" )
192- accept_rates = cast (Any , nextn_accept_rates or [])
217+ accept_rates = _validate_nextn (nextn , nextn_accept_rates )
193218
194219 p_mc = ModelConfig (
195220 tp_size = p_tp ,
0 commit comments