|
3 | 3 | # Licensed under the MIT License. |
4 | 4 | # -------------------------------------------------------------------------- |
5 | 5 | import logging |
| 6 | +from copy import deepcopy |
6 | 7 | from pathlib import Path |
7 | 8 |
|
8 | 9 | import onnx |
@@ -56,6 +57,14 @@ def _default_config(cls, accelerator_spec: AcceleratorSpec) -> dict[str, PassCon |
56 | 57 | default_value=64, |
57 | 58 | description="Input length of the context model.", |
58 | 59 | ), |
| 60 | + "context_lengths": PassConfigParam( |
| 61 | + type_=list[int], |
| 62 | + default_value=None, |
| 63 | + description=( |
| 64 | + "List of context lengths to generate static models QNN_GPU." |
| 65 | + "If None or empty, falls back to single 'context_length'." |
| 66 | + ), |
| 67 | + ), |
59 | 68 | "group_session_options": PassConfigParam( |
60 | 69 | type_=dict, |
61 | 70 | description=( |
@@ -182,59 +191,143 @@ def process_context_iterator(component_models, llm_pipeline, output_dir): |
182 | 191 | ) |
183 | 192 |
|
def _run_qnn_gpu(self, model: ONNXModelHandler, config: type[BasePassConfig], output_model_path: Path):
    """QNN_GPU path: generate one or more static-shape ONNX models.

    The symbolic batch and sequence dimensions of ``input_ids`` are replaced with
    ``config.batch_size`` and each requested context length, and every variant is
    saved as ``model_ctx<len>.onnx`` (weights in ``model_ctx<len>.onnx.data``)
    inside the output directory.

    Context length selection:

    - ``config.context_lengths`` is None/empty (or only holds ``None`` entries):
      fall back to ``config.context_length`` (single model).
    - ``config.context_lengths`` resolves to one distinct value: single model.
    - ``config.context_lengths`` resolves to several distinct values: one model per
      value, wrapped in a ``CompositeModelHandler`` whose components are named
      ``ctx_<len>`` and ordered by ascending context length.

    Duplicate entries are collapsed, so a list such as ``[128, 128]`` is treated as
    a single context length instead of writing the same file twice and returning a
    one-component composite.

    Args:
        model: Input ONNX model handler; its ``model_path`` is loaded together with
            any external data and its ``model_attributes`` are copied onto the outputs.
        config: Pass configuration providing ``batch_size``, ``context_length`` and,
            optionally, ``context_lengths``.
        output_model_path: Requested output path; its suffix is stripped and the
            result is used as the output directory.

    Returns:
        Whatever ``update_llm_pipeline_genai_config_gpu`` returns for the single
        handler or for the composite handler.

    Raises:
        RuntimeError: If the input ONNX model cannot be loaded.
        ValueError: If the ``input_ids`` dimensions are not both symbolic.
    """
    output_model_dir = Path(output_model_path).with_suffix("")
    model_path = Path(model.model_path)

    # --- Step 1: Load model (handle both single and external data) ---
    try:
        base_model_proto = onnx.load(model_path, load_external_data=True)
    except Exception as e:
        raise RuntimeError(f"Failed to load ONNX model: {e}") from e

    # --- Step 2: Get symbolic batch and sequence dims once ---
    batch_size, sequence_length = OnnxDAG(base_model_proto).get_io_shape("input_ids")
    if not (isinstance(batch_size, str) and isinstance(sequence_length, str)):
        raise ValueError("Input dimensions must be symbolic before static shape fixing.")

    # --- Step 3: Resolve the distinct context lengths, ascending ---
    # De-duplicating here keeps the single/composite decision consistent with the
    # number of models that are actually generated.
    requested_lengths = getattr(config, "context_lengths", None) or []
    ctx_lengths = sorted({int(value) for value in requested_lengths if value is not None})
    if not ctx_lengths:
        # Fall back to the single context_length in config
        ctx_lengths = [int(config.context_length)]

    output_model_dir.mkdir(parents=True, exist_ok=True)

    # --- Step 4: Generate one static model per context length ---
    component_names: list[str] = []
    components: list[ONNXModelHandler] = []
    last_index = len(ctx_lengths) - 1
    for index, ctx_len in enumerate(ctx_lengths):
        if index == last_index:
            # The base proto is not needed after the final variant, so it is
            # consumed directly instead of holding a second full copy in memory.
            model_proto = base_model_proto
        else:
            # fix_shape edits the proto in place, so earlier variants work on a
            # private clone to keep the symbolic dims available for later ones.
            model_proto = onnx.ModelProto()
            model_proto.CopyFrom(base_model_proto)

        self.fix_shape(model_proto, {batch_size: config.batch_size, sequence_length: ctx_len})
        add_version_metadata_to_model_proto(model_proto)

        onnx_file_name = f"model_ctx{ctx_len}.onnx"
        onnx.save(
            model_proto,
            str(output_model_dir / onnx_file_name),
            save_as_external_data=True,
            all_tensors_to_one_file=True,
            location=f"{onnx_file_name}.data",
            convert_attribute=False,
        )

        # Each handler gets its own copy of the attributes so later edits to one
        # handler cannot leak into the input model or sibling components.
        components.append(
            ONNXModelHandler(
                model_path=output_model_dir,
                onnx_file_name=onnx_file_name,
                model_attributes=deepcopy(model.model_attributes) or {},
            )
        )
        component_names.append(f"ctx_{ctx_len}")

    # --- Step 5: Update genai_config.json ---
    # One decoder config is shared by all components; the sliding window uses the
    # largest context length (which is the only one in the single-model case).
    decoder_config_extra = {
        "inputs": {
            "past_sequence_length": "past_seq_len",
            "total_sequence_length": "total_seq_len",
        },
        "sliding_window": {
            "window_size": max(ctx_lengths),
            "pad_value": 0,
            "alignment": "left",
            "slide_key_value_cache": False,
        },
    }

    if len(components) == 1:
        # Single context length -> plain ONNX handler, no composite pipeline
        return update_llm_pipeline_genai_config_gpu(
            model=components[0],
            output_model_dir=output_model_dir,
            decoder_config_extra=decoder_config_extra,
            composite_components=None,
        )

    # Multiple context lengths -> wrap in CompositeModelHandler and create composite pipeline
    composite = CompositeModelHandler(
        model_components=components,
        model_component_names=component_names,
        model_attributes=deepcopy(model.model_attributes) or {},
    )
    return update_llm_pipeline_genai_config_gpu(
        model=composite,
        output_model_dir=output_model_dir,
        decoder_config_extra=decoder_config_extra,
        composite_components=list(zip(component_names, components)),
    )
| 330 | + |
238 | 331 | @staticmethod |
239 | 332 | def fix_shape(model_proto: onnx.ModelProto, param_mapping: dict[str, int]): |
240 | 333 | """Fix the shape of the model based on the param mapping. |
|
0 commit comments