-
Notifications
You must be signed in to change notification settings - Fork 2.3k
Expand file tree
/
Copy pathgeneral.py
More file actions
executable file
·201 lines (166 loc) · 7.28 KB
/
general.py
File metadata and controls
executable file
·201 lines (166 loc) · 7.28 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
from __future__ import annotations
import json
from importlib.metadata import version
from pathlib import Path
from random import choices, shuffle
from typing import Dict, List, Tuple, Union
import yaml
from tensorrt_llm._torch.pyexecutor.model_engine import \
validate_and_set_kv_cache_quant
from tensorrt_llm.bench.build.build import (get_benchmark_engine_settings,
get_model_config)
from tensorrt_llm.bench.build.dataclasses import NemotronHybridConfig
from tensorrt_llm.bench.dataclasses.general import (DatasetMetadata,
InferenceRequest)
from tensorrt_llm.logger import logger
from tensorrt_llm.quantization.mode import QuantAlgo
# Maps a model's quantization algorithm (by its string value) to the KV cache
# dtype used when none is specified. Both FP8 and NVFP4 models map to "fp8".
_KV_CACHE_MAP = dict.fromkeys(
    (QuantAlgo.FP8.value, QuantAlgo.NVFP4.value),
    "fp8",
)

# Backend names accepted by the benchmark.
ALL_SUPPORTED_BACKENDS = [
    "pytorch",
    "_autodeploy",
    "tensorrt",
]
def get_settings_from_engine(
    engine_path: Path
) -> Tuple[Dict[str, Union[str, int]], Dict[str, Union[str, int]]]:
    """Read runtime settings out of an engine directory's ``config.json``.

    Args:
        engine_path (Path): Path to a TRT-LLM engine directory containing a
            ``config.json`` file.

    Returns:
        Tuple[Dict[str, Union[str, int]], Dict[str, Union[str, int]]]: A pair
            ``(runtime_config, build_config)``. ``runtime_config`` collects
            the software version, absolute engine directory, executor limits,
            world/parallelism layout and decoding mode; ``build_config`` is
            the raw ``build_config`` section of the engine config.

    Raises:
        KeyError: If an expected section or field is missing from the config.
    """
    with open(engine_path / "config.json", "r") as handle:
        engine_cfg = json.load(handle)

    # Sections are looked up in a fixed order so that a malformed config
    # always reports the same missing key.
    mapping = engine_cfg["pretrained_config"]["mapping"]
    build_cfg = engine_cfg["build_config"]
    parallel_cfg = build_cfg["auto_parallel_config"]

    # Parallelism layout comes from the pretrained mapping, except for
    # gpus_per_node which lives in the auto-parallel section.
    world = {key: mapping[key] for key in ("pp_size", "tp_size", "world_size")}
    world["gpus_per_node"] = parallel_cfg["gpus_per_node"]

    limits = {
        key: build_cfg[key]
        for key in ("max_batch_size", "max_num_tokens")
    }

    runtime_config = {
        "sw_version": engine_cfg["version"],
        "engine_dir": str(engine_path.absolute()),
        "settings_config": limits,
        "world_config": world,
        "performance_options": {},
        "decoding_config": {
            "decoding_mode": build_cfg["speculative_decoding_mode"],
        },
    }

    return runtime_config, build_cfg
def get_settings(params: dict, dataset_metadata: DatasetMetadata, model: str,
                 model_path: Union[Path, None]) -> Dict[str, Union[str, int]]:
    """Retrieve basic runtime config for pytorch backend path.

    Args:
        params (dict): Configuration parameters.
        dataset_metadata (DatasetMetadata): Dataset statistics; ``avg_isl``,
            ``avg_osl`` and ``max_isl`` are read here.
        model (str): Model name.
        model_path (Union[Path, None]): Path to the model. When set, it is
            used in place of ``model`` to load the model configuration.

    Returns:
        Dict[str, Union[str, int]]: Properties for runtime config.
    """
    extra_llm_api_options = params.get("extra_llm_api_options")
    enable_chunked_prefill = params.get("enable_chunked_prefill", False)

    kv_cache_dtype = "auto"
    mamba_ssm_cache_dtype = params.get("mamba_ssm_cache_dtype", "auto")
    kv_cache_config = {}
    if extra_llm_api_options:
        with open(extra_llm_api_options, 'r') as f:
            # safe_load returns None for an empty document; treat that as
            # "no overrides" instead of failing on None.get below.
            llm_args_dict = yaml.safe_load(f) or {}

        # A missing or null kv_cache_config section means defaults apply;
        # "dtype" is always written back further down.
        kv_cache_config = llm_args_dict.get("kv_cache_config") or {}
        kv_cache_dtype = kv_cache_config.get("dtype", "auto")
        mamba_ssm_cache_dtype = kv_cache_config.get("mamba_ssm_cache_dtype",
                                                    mamba_ssm_cache_dtype)

        # A value in the options file takes precedence over params.
        enable_chunked_prefill = llm_args_dict.get("enable_chunked_prefill",
                                                   enable_chunked_prefill)

    world_config = {
        "pp_size": params.get("pp"),
        "tp_size": params.get("tp"),
        "world_size": params.get("pp") * params.get("tp"),
        "ep_size": params.get("ep"),
        "cluster_size": params.get("cluster_size"),
    }

    if params.get("max_batch_size") and params.get("max_num_tokens"):
        logger.info("Use user-provided max batch size and max num tokens.")
        max_batch_size, max_num_tokens = params.get(
            "max_batch_size"), params.get("max_num_tokens")
    else:
        model_config = get_model_config(model, model_path)

        if isinstance(model_config, NemotronHybridConfig):
            model_config.set_mamba_ssm_cache_dtype(mamba_ssm_cache_dtype)

        from tensorrt_llm._torch.model_config import ModelConfig
        model = model_path or model
        tllm_model_config = ModelConfig.from_pretrained(model,
                                                        trust_remote_code=True)

        # NOTE(review): kv_cache_dtype defaults to "auto", so this branch is
        # only taken when the options file sets `dtype: null` explicitly.
        if (kv_cache_dtype is None
                and tllm_model_config.quant_config.kv_cache_quant_algo is None):
            kv_cache_dtype = _KV_CACHE_MAP.get(
                tllm_model_config.quant_config.quant_algo, "auto")

        validate_and_set_kv_cache_quant(tllm_model_config, kv_cache_dtype)

        max_batch_size, max_num_tokens = get_benchmark_engine_settings(
            model_config,
            tllm_model_config.quant_config,
            params.get("tp"),
            params.get("pp"),
            dataset_metadata.avg_isl,
            dataset_metadata.avg_osl,
            params.get("kv_cache_free_gpu_mem_fraction"),
        )

        logger.info(
            "Max batch size and max num tokens not provided. "
            f"Using heuristics or pre-defined settings: max_batch_size={max_batch_size}, max_num_tokens={max_num_tokens}."
        )

    # If chunked prefill is disabled, max_num_tokens must be at least
    # max_isl + max_batch_size.
    if not enable_chunked_prefill:
        min_num_tokens = dataset_metadata.max_isl + max_batch_size
        # Only warn when the value is actually raised; otherwise the warning
        # would claim a shortfall that does not exist.
        if max_num_tokens < min_num_tokens:
            logger.warning(
                f"Chunked prefill is disabled, but max_num_tokens ({max_num_tokens}) is less than the max ISL ({dataset_metadata.max_isl}) plus max batch size ({max_batch_size}). "
                f"Forcing max_num_tokens to {min_num_tokens}.")
            max_num_tokens = min_num_tokens
    else:
        # TODO: Figure out how to handle chunked block size.
        # Expecting this to be the max of chunk block and max_num_tokens.
        pass

    cuda_graph_config = {
        "enable_padding": True,
        "max_batch_size": max_batch_size
    }

    # The resolved dtypes are written back so the emitted kv_cache_config
    # always carries both keys, whether or not an options file was given.
    kv_cache_config["dtype"] = kv_cache_dtype
    kv_cache_config["mamba_ssm_cache_dtype"] = mamba_ssm_cache_dtype

    pyt_options = {
        "cuda_graph_config": cuda_graph_config,
        "kv_cache_config": kv_cache_config,
    }

    backend = params.get("backend", "pytorch")
    return {
        "sw_version": version("tensorrt_llm"),
        "model_path": model_path,
        "settings_config": {
            "max_batch_size": int(max_batch_size),
            "max_num_tokens": int(max_num_tokens),
            "chunking": enable_chunked_prefill,
        },
        "world_config": world_config,
        "backend": backend,
        "decoding_config": {},
        "performance_options": {
            "cuda_graphs": True,
            "pytorch_config": pyt_options,
        }
    }
def generate_warmup_dataset(requests, steps) -> List[InferenceRequest]:
    """Build a randomized list of requests used to warm up the benchmarker.

    Args:
        requests: Pool of requests to draw from.
        steps: Number of warm-up requests to produce.

    Returns:
        List[InferenceRequest]: ``steps`` requests drawn at random, with
            replacement, from ``requests`` and then shuffled.
    """
    sampled_requests = choices(requests, k=steps)
    # shuffle works in place and returns None, so it cannot be chained.
    shuffle(sampled_requests)
    return sampled_requests