1 change: 1 addition & 0 deletions pyproject.toml
@@ -41,6 +41,7 @@ license-files = ["LICENSE"]
requires-python = ">=3.9"

dependencies = [
"asteval>=1.0.5",
"jinja2>=3.1.0",
"matplotlib>=3.9.4",
"numpy~=1.26.4",
64 changes: 38 additions & 26 deletions src/aiconfigurator/generator/rendering/rule_engine.py
@@ -6,9 +6,9 @@
from pathlib import Path
from typing import Any, Optional

from jinja2 import Environment
from asteval import Interpreter
from munch import DefaultMunch

_ENV = Environment()
logger = logging.getLogger(__name__)
_BASE_DIR = Path(__file__).resolve().parent
_RULES_DIR = (_BASE_DIR.parent / "rule_plugin").resolve()
@@ -28,49 +28,61 @@ def _get_scope(pv: dict[str, Any], scope: str) -> Optional[dict[str, Any]]:


def _eval(expr: str, scope: str, pv: dict[str, Any]) -> Any:
ctx: dict[str, Any] = {}
ctx.update(pv)
service_cfg = pv.get("ServiceConfig", {})
"""
Safely evaluate a DSL expression from a .rule file within a scoped parameter context.
Supports comprehensive Python syntax via asteval and dot notation access to nested dictionaries via DefaultMunch.

Args:
expr (str): A Python/DSL expression to evaluate.
scope (str): The configuration scope (e.g., 'prefill', 'decode', or 'agg').
pv (dict[str, Any]): The full dictionary of generator parameters with dot notation support.
"""
ctx = DefaultMunch.fromDict(pv, None)
service_cfg = pv.get("ServiceConfig", DefaultMunch(None))
ctx.update(service_cfg)
if isinstance(service_cfg, dict):
ctx.setdefault("ServiceConfig", service_cfg)
k8s_cfg = pv.get("K8sConfig", {})
k8s_cfg = pv.get("K8sConfig", DefaultMunch(None))
ctx.update(k8s_cfg)
if isinstance(k8s_cfg, dict):
ctx.setdefault("K8sConfig", k8s_cfg)
node_cfg = pv.get("NodeConfig", {})
node_cfg = pv.get("NodeConfig", DefaultMunch(None))
if isinstance(node_cfg, dict):
ctx.update(node_cfg)
ctx.setdefault("NodeConfig", node_cfg)
dyn_cfg = pv.get("DynConfig")
if isinstance(dyn_cfg, dict):
ctx.setdefault("DynConfig", dyn_cfg)

# Provide structured aliases for DSL compatibility
if "SlaConfig" not in ctx:
sla_cfg = pv.get("SlaConfig")
if isinstance(sla_cfg, dict):
ctx["SlaConfig"] = sla_cfg
sc = pv.get("params", {}).get(scope, {})
sc = pv.get("params", {}).get(scope, DefaultMunch(None))
ctx.update(sc)

# Alias ModelConfig.is_moe -> is_moe for convenience
mc = pv.get("ModelConfig") or pv.get("model") or pv.get("model_config") or {}
if not mc and "ServiceConfig" in pv and isinstance(pv["ServiceConfig"], dict):
svc = pv["ServiceConfig"]
modeled: dict[str, Any] = {}
if "is_moe" in svc:
modeled["is_moe"] = svc.get("is_moe")
if modeled:
mc = modeled
if isinstance(mc, dict):
ctx.setdefault("ModelConfig", mc)
if "is_moe" in mc and "is_moe" not in ctx:
ctx["is_moe"] = mc.get("is_moe")
isl = sc.get("max_seq_len", pv.get("max_seq_len"))
bs = sc.get("max_batch_size", pv.get("max_batch_size"))
ctx["isl"] = isl if isl is not None else 0
ctx["bs"] = bs if bs is not None else 1
fn = _ENV.compile_expression(expr.strip())
return fn(**ctx)
ctx.ModelConfig = ctx.get("ModelConfig") or ctx.get("model") or ctx.get("model_config") or DefaultMunch(None)
if ctx.ModelConfig.is_moe is None and isinstance(ctx.get("service", {}).get("is_moe"), bool):
ctx.ModelConfig.is_moe = ctx.get("service", {}).get("is_moe")
if ctx.is_moe is None and ctx.ModelConfig.is_moe is not None:
ctx.is_moe = ctx.ModelConfig.is_moe

ctx.isl = sc.get("max_seq_len") or pv.get("max_seq_len") or 0
ctx.bs = sc.get("max_batch_size") or pv.get("max_batch_size") or 1

# DSL compatibility: lowercase booleans
ctx.true = True
ctx.false = False

# Evaluate expression safely with asteval
aeval = Interpreter(user_symbols=ctx)
result = aeval(expr.strip())
if aeval.error:
error_msg = "\n".join(str(e) for e in aeval.error)
raise ValueError(f"Rule engine evaluation failed: {error_msg}")
return result


def _parse_assign(line: str) -> Optional[tuple[Optional[str], str, str]]:
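For context, here is a minimal sketch (not part of the patch) of the new evaluation path: DefaultMunch wraps the parameter dictionary so rule expressions can use dot notation, with missing nested keys resolving to None, and asteval evaluates the expression against those symbols, mirroring the `Interpreter(user_symbols=...)` call in the diff above. The parameter values below are made up for illustration.

```python
from asteval import Interpreter
from munch import DefaultMunch

# Hypothetical parameter dictionary shaped like the generator's inputs.
pv = {"SlaConfig": {"isl": 4000, "osl": 500}, "params": {"prefill": {"max_batch_size": 1}}}

# Dot-notation access; missing nested attributes resolve to None instead of raising.
ctx = DefaultMunch.fromDict(pv, None)
ctx.true, ctx.false = True, False  # DSL-style lowercase booleans

aeval = Interpreter(user_symbols=ctx)
result = aeval("SlaConfig.isl + 1500")
if aeval.error:
    raise ValueError("\n".join(str(e) for e in aeval.error))
print(result)  # 5500
```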
2 changes: 1 addition & 1 deletion src/aiconfigurator/generator/rule_plugin/sglang.rule
@@ -4,7 +4,7 @@ agg_decode max_batch_size = (max_batch_size if max_batch_size else 128)
agg_prefill_decode max_prefill_tokens = SlaConfig.isl + 1500
agg enable_mixed_chunk = true

agg_prefill_decode cuda_graph_batch_sizes = ((range(1, max_batch_size + 1) | list) if max_batch_size else [])
agg_prefill_decode cuda_graph_batch_sizes = ([x for x in [1,2,4,8,16,32,64,128,256,512,1024] if x < max_batch_size] + [max_batch_size] if max_batch_size else [])
Contributor comment:
This part is still pending discussion in PR #245. We need a meeting to finalize this fix.


# GPUs per worker follow the same TP/PP/DP product that SGLang expects
agg_prefill_decode gpus_per_worker = (tensor_parallel_size or 1) * (pipeline_parallel_size or 1) * (data_parallel_size or 1)
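As a quick worked example (not part of the patch), the new cuda_graph_batch_sizes expression keeps the power-of-two ladder strictly below the configured batch size and then appends the batch size itself, so the list is capped exactly at max_batch_size; the value below is made up for illustration.

```python
# Illustrative evaluation of the cuda_graph_batch_sizes rule above.
max_batch_size = 96
ladder = [1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024]
sizes = [x for x in ladder if x < max_batch_size] + [max_batch_size] if max_batch_size else []
print(sizes)  # [1, 2, 4, 8, 16, 32, 64, 96]
```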
2 changes: 1 addition & 1 deletion src/aiconfigurator/generator/rule_plugin/trtllm.rule
@@ -9,7 +9,7 @@ prefill max_num_tokens = SlaConfig.isl + 1500
decode max_num_tokens = max_batch_size
agg max_num_tokens = max_batch_size + SlaConfig.isl + 1500

agg_prefill_decode cuda_graph_batch_sizes = ((range(1, max_batch_size + 1) | list) if max_batch_size else [])
agg_prefill_decode cuda_graph_batch_sizes = ([x for x in [1,2,4,8,16,32,64,128,256,512,1024] if x < max_batch_size] + [max_batch_size] if max_batch_size else [])

# Enforce TensorRT-LLM MoE parallelism: moe_tp × moe_ep = tp
when ModelConfig.is_moe and (moe_tensor_parallel_size and moe_expert_parallel_size):
2 changes: 2 additions & 0 deletions src/aiconfigurator/generator/rule_plugin/vllm.rule
@@ -10,6 +10,8 @@ decode max_num_tokens = max_batch_size
agg max_num_tokens = (max_batch_size or 0) + (SlaConfig.isl or 0) + 1500
agg max_seq_len = (SlaConfig.isl or 0) + (SlaConfig.osl or 0) + 1500

agg_prefill_decode cuda_graph_batch_sizes = ([x for x in [1,2,4,8,16,32,64,128,256,512,1024] if x < max_batch_size] + [max_batch_size] if max_batch_size else [])

when (ModelConfig.prefix or 0) > 0:
disable_prefix_cache = false
DynConfig.enable_router = true