Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions olive/assets/io_configs/defaults.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ visual_seq_length: 16
text_batch_size: 2
image_batch_size: 2
projection_dim: 512
text_encoder_projection_dim: 1280

# Audio
feature_size: 80
Expand Down
4 changes: 2 additions & 2 deletions olive/cli/capture_onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,8 +100,8 @@ def register_subcommand(parser: ArgumentParser):
pte_group.add_argument(
"--target_opset",
type=int,
default=17,
help="The target opset version for the ONNX model. Default is 17.",
default=20,
help="The target opset version for the ONNX model. Default is 20.",
)

# Model Builder options
Expand Down
20 changes: 16 additions & 4 deletions olive/common/hf/io_config/input_generators.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@
if TYPE_CHECKING:
from transformers import PretrainedConfig

from olive.constants import DiffusersModelVariant


class DummyInputGenerator(ABC):
"""Generate dummy inputs for ONNX export."""
Expand Down Expand Up @@ -70,11 +72,11 @@ class DiffusersDummyInputGenerator(DummyInputGenerator):
Reads input specifications from diffusers.yaml.
"""

def __init__(self, component_name: str, config: PretrainedConfig):
def __init__(self, component_name: str, config: PretrainedConfig, pipeline: DiffusersModelVariant | None = None):
self.component_name = component_name
self.config = config

self.component_spec = get_diffusers_component_config(component_name)
self.component_spec = get_diffusers_component_config(component_name, pipeline=pipeline)
if self.component_spec is None:
raise ValueError(f"Unknown diffusers component: {component_name}")

Expand Down Expand Up @@ -145,20 +147,22 @@ def generate(self, input_name: str):
def generate_diffusers_dummy_inputs(
component_name: str,
config: PretrainedConfig,
pipeline: DiffusersModelVariant | None = None,
) -> dict[str, Any]:
"""Create all dummy inputs for a diffusers component.

Args:
component_name: Name of the diffusers component (e.g., "unet", "vae_encoder")
config: The component's config object
pipeline: Pipeline variant (e.g., "sdxl", "flux")

Returns:
Dict of input_name -> tensor

"""
generator = DiffusersDummyInputGenerator(component_name, config)
generator = DiffusersDummyInputGenerator(component_name, config, pipeline=pipeline)

component_spec = get_diffusers_component_config(component_name)
component_spec = get_diffusers_component_config(component_name, pipeline=pipeline)
if component_spec is None:
raise ValueError(f"Unknown diffusers component: {component_name}")

Expand All @@ -168,6 +172,14 @@ def generate_diffusers_dummy_inputs(
for input_name in component_spec.get("inputs", {}):
dummy_inputs[input_name] = generator.generate(input_name)

# Generate SDXL-specific inputs (e.g., text_embeds, time_ids for UNet)
is_sdxl = getattr(config, "addition_embed_type", None) == "text_time"
if is_sdxl and "sdxl_inputs" in component_spec:
sdxl_inputs = {}
for input_name in component_spec["sdxl_inputs"]:
sdxl_inputs[input_name] = generator.generate(input_name)
dummy_inputs["added_cond_kwargs"] = sdxl_inputs

# Generate optional inputs
for input_name, spec in component_spec.get("optional_inputs", {}).items():
if "condition" in spec and getattr(config, spec["condition"], False):
Expand Down
28 changes: 24 additions & 4 deletions olive/common/hf/io_config/io_resolver.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,13 @@
import functools
import logging
from importlib.resources import files
from typing import Any
from typing import TYPE_CHECKING, Any

import yaml

if TYPE_CHECKING:
from olive.constants import DiffusersModelVariant

logger = logging.getLogger(__name__)


Expand Down Expand Up @@ -54,19 +57,36 @@ def get_task_template(task: str) -> dict[str, Any] | None:
return tasks.get(task)


def get_diffusers_component_config(
    component_name: str, pipeline: DiffusersModelVariant | None = None
) -> dict[str, Any] | None:
    """Get diffusers component configuration.

    Lookup is attempted in two steps: first a direct match of ``component_name``
    against the ``components`` section, then (only when ``pipeline`` is given) a
    scan of that pipeline's entries for one written as ``"<config_type>:<name>"``
    whose ``<name>`` equals ``component_name``.

    Args:
        component_name: Pipeline component name (e.g., "text_encoder_2", "transformer").
        pipeline: Pipeline variant (e.g., "sdxl", "flux") to resolve names like
            "text_encoder_2" that map to different configs per pipeline.

    Returns:
        Component configuration dict, or None if not found.

    """
    diffusers = _load_diffusers()
    components = diffusers.get("components", {})

    # Direct match
    if component_name in components:
        return components[component_name]

    # Resolve via pipeline definition (e.g., "text_encoder_with_projection:text_encoder_2")
    if pipeline:
        for entry in diffusers.get("pipelines", {}).get(pipeline, []):
            # Entries without ":" carry no alias, so they cannot match here.
            if ":" in entry:
                config_type, pipe_name = entry.split(":", 1)
                if pipe_name == component_name:
                    return components.get(config_type)

    return None


def get_default_shapes() -> dict[str, int]:
Expand Down
6 changes: 5 additions & 1 deletion olive/common/hf/io_config/task_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@
if TYPE_CHECKING:
from transformers import PretrainedConfig, PreTrainedModel

from olive.constants import DiffusersModelVariant

logger = logging.getLogger(__name__)


Expand Down Expand Up @@ -365,20 +367,22 @@ def generate_dummy_inputs(
def get_diffusers_io_config(
component_name: str,
config: PretrainedConfig,
pipeline: DiffusersModelVariant | None = None,
**kwargs,
) -> dict[str, Any]:
"""Get IO configuration for a diffusers component.

Args:
component_name: Component name (e.g., "text_encoder", "unet").
config: Component's config.
pipeline: Pipeline variant (e.g., "sdxl", "flux").
**kwargs: Additional arguments (e.g., is_sdxl).

Returns:
Dict containing input_names, output_names, dynamic_axes.

"""
component_config = get_diffusers_component_config(component_name)
component_config = get_diffusers_component_config(component_name, pipeline=pipeline)
if component_config is None:
raise ValueError(f"Unknown diffusers component: {component_name}")

Expand Down
4 changes: 3 additions & 1 deletion olive/passes/onnx/conversion.py
Original file line number Diff line number Diff line change
Expand Up @@ -673,12 +673,14 @@ def _convert_diffusers_model(
dummy_inputs = generate_diffusers_dummy_inputs(
component_name=component_name,
config=component_config,
pipeline=pipeline_type,
)

# Get IO config using new task-driven API
io_config = get_diffusers_io_config(
component_name=component_name,
config=component_config,
pipeline=pipeline_type,
)

# Create output directory for this component
Expand All @@ -693,7 +695,7 @@ def _convert_diffusers_model(
io_config=io_config,
config=config,
device=device,
dynamo=config.use_dynamo_exporter,
dynamo=True,
torch_dtype=torch_dtype,
)

Expand Down
Loading