Commit 0c0236a: Support weight sharing in QNN GPU

Parent: 6e65e35
File tree (3 files changed, +173 -44 lines):

- olive/passes/onnx/common.py
- olive/passes/onnx/context_binary.py
- olive/passes/onnx/static_llm.py

olive/passes/onnx/common.py

Lines changed: 53 additions & 19 deletions
@@ -5,6 +5,7 @@
 import json
 import logging
 import re
+from collections.abc import Iterable
 from copy import deepcopy
 from pathlib import Path
 from typing import Any, Callable, Optional, Union
@@ -776,40 +777,46 @@ def update_llm_pipeline_genai_config(
 
 
 def update_llm_pipeline_genai_config_gpu(
-    model: ONNXModelHandler,
+    model: Union[ONNXModelHandler, CompositeModelHandler],
     output_model_dir: Union[str, Path],
-    input_model_path: Union[str, Path],
     decoder_config_extra: Optional[dict[str, Any]] = None,
-) -> ONNXModelHandler:
+    composite_components: Optional[Iterable[tuple[str, ONNXModelHandler]]] = None,
+) -> Union[ONNXModelHandler, CompositeModelHandler]:
     """Update the LLM pipeline in the model's genai_config.json file.
 
-    :param model: The model to update.
+    :param model: The model (single or composite) to update.
+    :param output_model_dir: Directory where the updated genai_config.json should be written.
     :param decoder_config_extra: Extra configuration for the decoder.
+    :param composite_components: Optional iterable of (component_name, ONNXModelHandler)
+        used to build a multi-component pipeline.
+    :return: The same `model` object (with its directory now having updated genai_config.json).
     """
     output_model_dir = Path(output_model_dir)
 
-    # update genai_config if it exists
+    additional_files = model.model_attributes["additional_files"]
     genai_config_path = None
-    genai_config_path = Path(input_model_path).parent / "genai_config.json"
+    for file_path in additional_files:
+        if Path(file_path).name == "genai_config.json":
+            genai_config_path = file_path
+            break
 
-    if genai_config_path.exists():
-        genai_config_path = str(genai_config_path.resolve())
-    else:
+    if not genai_config_path:
         return model
 
     with open(genai_config_path) as f:
         genai_config = json.load(f)
-
     # update model_type
     genai_config["model"]["type"] = "decoder-pipeline"
 
-    # Update the provider_options list
     provider_option = {"qnn": {"backend_type": "gpu"}}
-    genai_config["model"]["decoder"]["session_options"]["provider_options"] = [provider_option]
+    decoder = genai_config["model"].setdefault("decoder", {})
+    session_opts = decoder.setdefault("session_options", {})
+    session_opts["provider_options"] = [provider_option]
 
     # update decoder config
     decoder_config = genai_config["model"]["decoder"]
     decoder_config.get("sliding_window", {}).pop("slide_inputs", None)
+
     for key, value in (decoder_config_extra or {}).items():
         exisiting_value = decoder_config.get(key)
         if isinstance(exisiting_value, dict):
@@ -819,20 +826,47 @@ def update_llm_pipeline_genai_config_gpu(
         else:
             decoder_config[key] = value
 
-    pipeline_config = {}
-    component_io_config = model.io_config
-    pipeline_config["model_onnx"] = {
-        "filename": Path(model.model_path).name,
-        "inputs": component_io_config["input_names"],
-        "outputs": component_io_config["output_names"],
-    }
+    # --- Build pipeline_config ---
+    pipeline_config: dict[str, Any] = {}
+
+    if composite_components is None:
+        if not isinstance(model, ONNXModelHandler):
+            handlers = list(model.get_model_components())
+            if not handlers:
+                return model
+            _, single_handler = handlers[0]
+        else:
+            single_handler = model
+
+        component_io_config = single_handler.io_config
+        pipeline_config["model_onnx"] = {
+            "filename": Path(single_handler.model_path).name,
+            "inputs": component_io_config["input_names"],
+            "outputs": component_io_config["output_names"],
+        }
+
+    else:
+        # Composite case: one entry per component
+        for comp_name, comp_handler in composite_components:
+            component_io_config = comp_handler.io_config
+            pipeline_config[comp_name] = {
+                "filename": Path(comp_handler.model_path).name,
+                "inputs": component_io_config["input_names"],
+                "outputs": component_io_config["output_names"],
+            }
+            if comp_name == "ctx_1":
+                pipeline_config[comp_name]["run_on_prompt"] = False
+            else:
+                pipeline_config[comp_name]["run_on_token_gen"] = False
 
     decoder_config["pipeline"] = [pipeline_config]
 
     # save the updated genai_config
     new_genai_config_path = output_model_dir / "genai_config.json"
     with new_genai_config_path.open("w") as f:
         json.dump(genai_config, f, indent=4)
+    additional_files.remove(genai_config_path)
+    additional_files.append(str(new_genai_config_path))
 
     return model
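
For reference, a rough sketch of the pipeline entry this function builds for the composite (two-component) case, written as a Python literal. The component names, file names, and input/output names below are illustrative assumptions; the real values come from each component's io_config.

# Illustrative only; names are assumed, not taken from the commit.
pipeline_entry = {
    "ctx_128": {
        "filename": "model_ctx128.onnx",
        "inputs": ["input_ids", "past_seq_len", "total_seq_len"],
        "outputs": ["logits"],
        "run_on_token_gen": False,  # prompt-processing component
    },
    "ctx_1": {
        "filename": "model_ctx1.onnx",
        "inputs": ["input_ids", "past_seq_len", "total_seq_len"],
        "outputs": ["logits"],
        "run_on_prompt": False,  # token-generation component
    },
}
# The function stores this as genai_config["model"]["decoder"]["pipeline"] = [pipeline_entry]
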

olive/passes/onnx/context_binary.py

Lines changed: 2 additions & 0 deletions
@@ -243,6 +243,8 @@ def _generate_context_binary(
         if execution_provider == ExecutionProvider.QNNExecutionProvider:
             if str(device).lower() == "gpu":
                 provider_options["backend_path"] = "libQnnGpu.so" if platform.system() == "Linux" else "QnnGpu.dll"
+                if share_ep_contexts:
+                    provider_options["enable_gpu_weight_sharing"] = "1"
                 update_llm_pipeline_genai_config_gpu_ctxbin(model_path)
             else:
                 if version.parse(OrtVersion).release < version.parse("1.22.0").release:
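
A minimal sketch of how the resulting QNN EP options could be consumed when creating an ONNX Runtime session on a Windows host. The model file name is an assumption; the pass itself only populates provider_options.

# Illustrative only; mirrors the options set above for the GPU backend.
import onnxruntime as ort

provider_options = {
    "backend_path": "QnnGpu.dll",        # libQnnGpu.so on Linux
    "enable_gpu_weight_sharing": "1",    # set only when share_ep_contexts is enabled
}
session = ort.InferenceSession(
    "model_ctx128.onnx",                 # assumed file name
    providers=[("QNNExecutionProvider", provider_options)],
)
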

olive/passes/onnx/static_llm.py

Lines changed: 118 additions & 25 deletions
@@ -3,6 +3,7 @@
 # Licensed under the MIT License.
 # --------------------------------------------------------------------------
 import logging
+from copy import deepcopy
 from pathlib import Path
 
 import onnx
@@ -56,6 +57,14 @@ def _default_config(cls, accelerator_spec: AcceleratorSpec) -> dict[str, PassCon
             default_value=64,
             description="Input length of the context model.",
         ),
+        "context_lengths": PassConfigParam(
+            type_=list[int],
+            default_value=None,
+            description=(
+                "List of context lengths to generate static models for on QNN GPU. "
+                "If None or empty, falls back to the single 'context_length'."
+            ),
+        ),
         "group_session_options": PassConfigParam(
             type_=dict,
             description=(
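
As a hedged example, a pass configuration exercising the new parameter might look like the following Python dict; the pass type name "StaticLLM" and the surrounding Olive workflow structure are assumptions, not taken from this commit.

# Hypothetical pass entry; only "context_lengths" is introduced by this commit.
passes = {
    "static_llm": {
        "type": "StaticLLM",
        "batch_size": 1,
        "context_lengths": [1, 128],  # omit or leave empty to fall back to "context_length"
    }
}
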
@@ -182,59 +191,143 @@ def process_context_iterator(component_models, llm_pipeline, output_dir):
         )
 
     def _run_qnn_gpu(self, model: ONNXModelHandler, config: type[BasePassConfig], output_model_path: Path):
+        """QNN_GPU path: generate one or more static ONNX models for different context lengths.
+
+        - If config.context_lengths is None/empty: use config.context_length (single model).
+        - If config.context_lengths has 1 value: use that context length (single model).
+        - If config.context_lengths has >1 values: generate multiple models and return CompositeModelHandler.
+        """
         output_model_dir = Path(output_model_path).with_suffix("")
         model_path = Path(model.model_path)
 
         # --- Step 1: Load model (handle both single and external data) ---
         try:
-            model_proto = onnx.load(model_path, load_external_data=True)
+            base_model_proto = onnx.load(model_path, load_external_data=True)
         except Exception as e:
             raise RuntimeError(f"Failed to load ONNX model: {e}") from e
 
-        # --- Step 2: Fix symbolic dimensions ---
-        batch_size, sequence_length = OnnxDAG(model_proto).get_io_shape("input_ids")
+        # --- Step 2: Get symbolic batch and sequence dims once ---
+        batch_size, sequence_length = OnnxDAG(base_model_proto).get_io_shape("input_ids")
         if not (isinstance(batch_size, str) and isinstance(sequence_length, str)):
             raise ValueError("Input dimensions must be symbolic before static shape fixing.")
 
-        param_mapping = {batch_size: config.batch_size, sequence_length: config.context_length}
-        self.fix_shape(model_proto, param_mapping)
+        # --- Determine which context lengths to use ---
+        cfg_ctx_lengths = getattr(config, "context_lengths", None) or []
+        ctx_lengths_list = [int(x) for x in cfg_ctx_lengths if x is not None]
+
+        if not ctx_lengths_list:
+            # Fall back to single context_length in config
+            ctx_lengths_list = [int(config.context_length)]
+
+        # If only one context length, we still treat it uniformly but return a single handler.
+        multiple = len(ctx_lengths_list) > 1
 
-        # --- Step 3: Save model as external-data format ---
-        output_model_file = Path(output_model_dir) / "model.onnx"
-        external_data_file = Path(output_model_dir) / "model.onnx.data"
+        generated_handlers: dict[int, ONNXModelHandler] = {}
+        generated_names: dict[int, str] = {}
+
+        for ctx_len in ctx_lengths_list:
+            # --- Clone base model proto for this variant ---
+            model_proto = onnx.ModelProto()
+            model_proto.CopyFrom(base_model_proto)
+
+            # --- Step 3: Fix symbolic dimensions for this context length ---
+            param_mapping = {batch_size: config.batch_size, sequence_length: ctx_len}
+            self.fix_shape(model_proto, param_mapping)
+
+            add_version_metadata_to_model_proto(model_proto)
+
+            # --- Step 4: Save as external-data ONNX ---
+            onnx_file_name = f"model_ctx{ctx_len}.onnx"
+            output_model_file = Path(output_model_dir) / onnx_file_name
+            external_data_file = Path(output_model_dir) / f"{onnx_file_name}.data"
+
+            output_model_dir.mkdir(parents=True, exist_ok=True)
+            onnx.save(
+                model_proto,
+                str(output_model_file),
+                save_as_external_data=True,
+                all_tensors_to_one_file=True,
+                location=external_data_file.name,
+                convert_attribute=False,
+            )
 
-        onnx.save(
-            model_proto,
-            str(output_model_file),
-            save_as_external_data=True,
-            all_tensors_to_one_file=True,
-            location=external_data_file.name,
-            convert_attribute=False,
+            # Build handler for this static model
+            new_model_attributes = deepcopy(model.model_attributes) or {}
+            handler = ONNXModelHandler(
+                model_path=output_model_dir,
+                onnx_file_name=output_model_file.name,
+                model_attributes=new_model_attributes,
+            )
+
+            # Store handler + a logical component name (e.g., ctx_128)
+            generated_handlers[ctx_len] = handler
+            generated_names[ctx_len] = f"ctx_{ctx_len}"
+
+        # --- Step 5: Update genai_config.json ---
+        # For single model: pipeline with one component.
+        # For multiple models: pipeline with multiple components (composite).
+        if not multiple:
+            # Single context length
+            ctx_len = ctx_lengths_list[0]
+            handler = generated_handlers[ctx_len]
+
+            decoder_config_extra = {
+                "inputs": {
+                    "past_sequence_length": "past_seq_len",
+                    "total_sequence_length": "total_seq_len",
+                },
+                "sliding_window": {
+                    "window_size": ctx_len,
+                    "pad_value": 0,
+                    "alignment": "left",
+                    "slide_key_value_cache": False,
+                },
+            }
+
+            handler = update_llm_pipeline_genai_config_gpu(
+                model=handler,
+                output_model_dir=output_model_dir,
+                decoder_config_extra=decoder_config_extra,
+                composite_components=None,
+            )
+            return handler
+
+        # Multiple context lengths -> wrap in CompositeModelHandler and create composite pipeline
+        components = []
+        component_names = []
+        for ctx_len, handler in sorted(generated_handlers.items(), key=lambda kv: kv[0]):
+            components.append(handler)
+            component_names.append(generated_names[ctx_len])
+
+        new_model_attributes = deepcopy(model.model_attributes) or {}
+
+        composite = CompositeModelHandler(
+            model_components=components, model_component_names=component_names, model_attributes=new_model_attributes
         )
 
-        decoder_config_extra = {
+        # Shared decoder extras; size the sliding window to the largest context length
+        composite_decoder_extra = {
             "inputs": {
                 "past_sequence_length": "past_seq_len",
                 "total_sequence_length": "total_seq_len",
             },
             "sliding_window": {
-                "window_size": config.context_length,
+                "window_size": max(ctx_lengths_list),
                 "pad_value": 0,
                 "alignment": "left",
                 "slide_key_value_cache": False,
             },
         }
 
-        input_model_path = model.model_path
-        model_static = ONNXModelHandler(model_path=output_model_dir, onnx_file_name=output_model_file.name)
-
-        return update_llm_pipeline_genai_config_gpu(
-            model_static,
-            output_model_dir,
-            input_model_path,
-            decoder_config_extra,
+        composite = update_llm_pipeline_genai_config_gpu(
+            model=composite,
+            output_model_dir=output_model_dir,
+            decoder_config_extra=composite_decoder_extra,
+            composite_components=list(zip(component_names, components)),
         )
 
+        return composite
+
     @staticmethod
     def fix_shape(model_proto: onnx.ModelProto, param_mapping: dict[str, int]):
         """Fix the shape of the model based on the param mapping.
