1616
1717from __future__ import annotations
1818
19- import sys
2019from typing import Any , Dict , Optional , cast
2120
21+ from aiconfigurator .sdk import common
22+ from aiconfigurator .sdk import config as aic_config
23+ from aiconfigurator .sdk import inference_session as aic_inference_session
24+ from aiconfigurator .sdk import models as aic_models
25+ from aiconfigurator .sdk import perf_database as aic_perf_database
26+ from aiconfigurator .sdk .backends import factory as aic_backends_factory
27+
2228
2329def _to_enum (enum_cls : Any , value_or_name : Any ) -> Any :
2430 """
@@ -39,40 +45,6 @@ def _validate_nextn(nextn: int, nextn_accept_rates: Optional[list[float]]) -> li
3945 return nextn_accept_rates or []
4046
4147
42- def _ensure_aiconfigurator_available (* , need_inference_session : bool ) -> dict [str , Any ]:
43- """
44- Import required aiconfigurator symbols or raise a consistent ModuleNotFoundError.
45-
46- Returns a dict of imported symbols so call sites can stay concise.
47- """
48- try :
49- from aiconfigurator .sdk import common
50- from aiconfigurator .sdk .backends .factory import get_backend
51- from aiconfigurator .sdk .config import ModelConfig , RuntimeConfig
52- from aiconfigurator .sdk .models import get_model
53- from aiconfigurator .sdk .perf_database import get_database
54-
55- if need_inference_session :
56- from aiconfigurator .sdk .inference_session import InferenceSession
57- else :
58- InferenceSession = None # type: ignore[assignment]
59- except ModuleNotFoundError as e :
60- raise ModuleNotFoundError (
61- "Missing dependency 'aiconfigurator'. Install it in the Python environment used for this test. "
62- f"(python={ sys .executable } )"
63- ) from e
64-
65- return {
66- "common" : common ,
67- "get_backend" : get_backend ,
68- "ModelConfig" : ModelConfig ,
69- "RuntimeConfig" : RuntimeConfig ,
70- "get_model" : get_model ,
71- "get_database" : get_database ,
72- "InferenceSession" : InferenceSession ,
73- }
74-
75-
7648def predict_ifb_single (
7749 * ,
7850 model_name : str ,
@@ -103,22 +75,14 @@ def predict_ifb_single(
10375 overwrite_num_layers : int = 0 ,
10476) -> Dict [str , Any ]:
10577 """Predict metrics for a single IFB configuration using the aiconfigurator SDK primitives."""
106- syms = _ensure_aiconfigurator_available (need_inference_session = False )
107- common = syms ["common" ]
108- get_backend = syms ["get_backend" ]
109- ModelConfig = syms ["ModelConfig" ]
110- RuntimeConfig = syms ["RuntimeConfig" ]
111- get_model = syms ["get_model" ]
112- get_database = syms ["get_database" ]
113-
114- database = get_database (system = system , backend = backend , version = version )
78+ database = aic_perf_database .get_database (system = system , backend = backend , version = version )
11579 if database is None :
11680 raise ValueError (f"No perf database found for system={ system } backend={ backend } version={ version } " )
117- backend_impl = cast (Any , get_backend (backend ))
81+ backend_impl = cast (Any , aic_backends_factory . get_backend (backend ))
11882
11983 accept_rates = _validate_nextn (nextn , nextn_accept_rates )
12084
121- mc = ModelConfig (
85+ mc = aic_config . ModelConfig (
12286 tp_size = tp ,
12387 pp_size = pp ,
12488 attention_dp_size = dp ,
@@ -133,9 +97,9 @@ def predict_ifb_single(
13397 nextn_accept_rates = accept_rates ,
13498 overwrite_num_layers = overwrite_num_layers ,
13599 )
136- model = get_model (model_name , mc , backend )
100+ model = aic_models . get_model (model_name , mc , backend )
137101
138- rc = RuntimeConfig (batch_size = batch_size , isl = isl , osl = osl )
102+ rc = aic_config . RuntimeConfig (batch_size = batch_size , isl = isl , osl = osl )
139103 summary = backend_impl .run_ifb (model = model , database = database , runtime_config = rc , ctx_tokens = ctx_tokens )
140104 df = summary .get_summary_df ()
141105 if df is None or df .empty :
@@ -197,24 +161,15 @@ def predict_disagg_single(
197161 decode_correction_scale : float = 1.0 ,
198162) -> Dict [str , Any ]:
199163 """Predict metrics for a single disaggregated configuration (explicit prefill/decode workers)."""
200- syms = _ensure_aiconfigurator_available (need_inference_session = True )
201- common = syms ["common" ]
202- get_backend = syms ["get_backend" ]
203- ModelConfig = syms ["ModelConfig" ]
204- RuntimeConfig = syms ["RuntimeConfig" ]
205- get_model = syms ["get_model" ]
206- get_database = syms ["get_database" ]
207- InferenceSession = syms ["InferenceSession" ]
208-
209- perf_db = get_database (system = system , backend = backend , version = version )
164+ perf_db = aic_perf_database .get_database (system = system , backend = backend , version = version )
210165 if perf_db is None :
211166 raise ValueError (f"No perf database found for system={ system } backend={ backend } version={ version } " )
212167
213- perf_backend = cast (Any , get_backend (backend ))
168+ perf_backend = cast (Any , aic_backends_factory . get_backend (backend ))
214169
215170 accept_rates = _validate_nextn (nextn , nextn_accept_rates )
216171
217- p_mc = ModelConfig (
172+ p_mc = aic_config . ModelConfig (
218173 tp_size = p_tp ,
219174 pp_size = p_pp ,
220175 attention_dp_size = p_dp ,
@@ -229,7 +184,7 @@ def predict_disagg_single(
229184 nextn_accept_rates = accept_rates ,
230185 overwrite_num_layers = overwrite_num_layers ,
231186 )
232- d_mc = ModelConfig (
187+ d_mc = aic_config . ModelConfig (
233188 tp_size = d_tp ,
234189 pp_size = d_pp ,
235190 attention_dp_size = d_dp ,
@@ -245,14 +200,14 @@ def predict_disagg_single(
245200 overwrite_num_layers = overwrite_num_layers ,
246201 )
247202
248- rc_prefill = RuntimeConfig (batch_size = p_bs , isl = isl , osl = osl )
249- rc_decode = RuntimeConfig (batch_size = d_bs , isl = isl , osl = osl )
203+ rc_prefill = aic_config . RuntimeConfig (batch_size = p_bs , isl = isl , osl = osl )
204+ rc_decode = aic_config . RuntimeConfig (batch_size = d_bs , isl = isl , osl = osl )
250205
251- prefill_model = get_model (model_name , p_mc , backend )
252- decode_model = get_model (model_name , d_mc , backend )
206+ prefill_model = aic_models . get_model (model_name , p_mc , backend )
207+ decode_model = aic_models . get_model (model_name , d_mc , backend )
253208
254- prefill_sess = InferenceSession (prefill_model , perf_db , perf_backend )
255- decode_sess = InferenceSession (decode_model , perf_db , perf_backend )
209+ prefill_sess = aic_inference_session . InferenceSession (prefill_model , perf_db , perf_backend )
210+ decode_sess = aic_inference_session . InferenceSession (decode_model , perf_db , perf_backend )
256211
257212 prefill_summary = prefill_sess .run_static (mode = "static_ctx" , runtime_config = rc_prefill )
258213 decode_summary = decode_sess .run_static (mode = "static_gen" , runtime_config = rc_decode )
0 commit comments