Skip to content

Commit 03bef3c

Browse files
committed
Use DiffusersModelVariant
1 parent 3fa6a6c commit 03bef3c

File tree

4 files changed

+16
-7
lines changed

4 files changed

+16
-7
lines changed

olive/common/hf/io_config/input_generators.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,8 @@
1919
if TYPE_CHECKING:
2020
from transformers import PretrainedConfig
2121

22+
from olive.constants import DiffusersModelVariant
23+
2224

2325
class DummyInputGenerator(ABC):
2426
"""Generate dummy inputs for ONNX export."""
@@ -70,7 +72,7 @@ class DiffusersDummyInputGenerator(DummyInputGenerator):
7072
Reads input specifications from diffusers.yaml.
7173
"""
7274

73-
def __init__(self, component_name: str, config: PretrainedConfig, pipeline: str | None = None):
75+
def __init__(self, component_name: str, config: PretrainedConfig, pipeline: DiffusersModelVariant | None = None):
7476
self.component_name = component_name
7577
self.config = config
7678

@@ -145,7 +147,7 @@ def generate(self, input_name: str):
145147
def generate_diffusers_dummy_inputs(
146148
component_name: str,
147149
config: PretrainedConfig,
148-
pipeline: str | None = None,
150+
pipeline: DiffusersModelVariant | None = None,
149151
) -> dict[str, Any]:
150152
"""Create all dummy inputs for a diffusers component.
151153

olive/common/hf/io_config/io_resolver.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,10 @@
77
import functools
88
import logging
99
from importlib.resources import files
10-
from typing import Any
10+
from typing import TYPE_CHECKING, Any
11+
12+
if TYPE_CHECKING:
13+
from olive.constants import DiffusersModelVariant
1114

1215
import yaml
1316

@@ -54,7 +57,9 @@ def get_task_template(task: str) -> dict[str, Any] | None:
5457
return tasks.get(task)
5558

5659

57-
def get_diffusers_component_config(component_name: str, pipeline: str | None = None) -> dict[str, Any] | None:
60+
def get_diffusers_component_config(
61+
component_name: str, pipeline: DiffusersModelVariant | None = None
62+
) -> dict[str, Any] | None:
5863
"""Get diffusers component configuration.
5964
6065
Args:

olive/common/hf/io_config/task_config.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,8 @@
1717
if TYPE_CHECKING:
1818
from transformers import PretrainedConfig, PreTrainedModel
1919

20+
from olive.constants import DiffusersModelVariant
21+
2022
logger = logging.getLogger(__name__)
2123

2224

@@ -365,7 +367,7 @@ def generate_dummy_inputs(
365367
def get_diffusers_io_config(
366368
component_name: str,
367369
config: PretrainedConfig,
368-
pipeline: str | None = None,
370+
pipeline: DiffusersModelVariant | None = None,
369371
**kwargs,
370372
) -> dict[str, Any]:
371373
"""Get IO configuration for a diffusers component.

olive/passes/onnx/conversion.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -673,14 +673,14 @@ def _convert_diffusers_model(
673673
dummy_inputs = generate_diffusers_dummy_inputs(
674674
component_name=component_name,
675675
config=component_config,
676-
pipeline=str(pipeline_type),
676+
pipeline=pipeline_type,
677677
)
678678

679679
# Get IO config using new task-driven API
680680
io_config = get_diffusers_io_config(
681681
component_name=component_name,
682682
config=component_config,
683-
pipeline=str(pipeline_type),
683+
pipeline=pipeline_type,
684684
)
685685

686686
# Create output directory for this component

0 commit comments

Comments
 (0)