Skip to content

Commit 1646c83

Browse files
Copilot and shaahji committed
Fix test_olive_evaluator.py - ClassVar for _nested_field_name and safe user_config access
Co-authored-by: shaahji <96227573+shaahji@users.noreply.github.com>
1 parent cd61881 commit 1646c83

File tree

7 files changed

+26
-20
lines changed

7 files changed

+26
-20
lines changed

olive/data/config.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
import re
77
from copy import deepcopy
88
from pathlib import Path
9-
from typing import TYPE_CHECKING, Optional, Union
9+
from typing import TYPE_CHECKING, ClassVar, Optional, Union
1010

1111
from olive.common.config_utils import ConfigBase, NestedConfig, validate_lowercase
1212
from olive.common.import_lib import import_user_module
@@ -21,7 +21,7 @@
2121

2222

2323
class DataComponentConfig(NestedConfig):
24-
_nested_field_name = "params"
24+
_nested_field_name: ClassVar[str] = "params"
2525

2626
type: str = None
2727
params: dict = Field(default_factory=dict)

olive/evaluator/metric.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
# Licensed under the MIT License.
44
# --------------------------------------------------------------------------
55
import logging
6-
from typing import Any, Optional, Union
6+
from typing import Any, Callable, ClassVar, Optional, Union
77

88
from olive.common.config_utils import ConfigBase, NestedConfig, validate_config
99
from olive.common.pydantic_v1 import field_validator
@@ -105,13 +105,13 @@ def validate_goal(cls, v, info):
105105

106106

107107
class Metric(NestedConfig):
108-
_nested_field_name = "user_config"
108+
_nested_field_name: ClassVar[str] = "user_config"
109109

110110
name: str
111111
type: MetricType
112112
backend: Optional[str] = "torch_metrics"
113113
sub_types: list[SubMetric]
114-
user_config: ConfigBase = None
114+
user_config: Optional[ConfigBase] = None
115115
data_config: Optional[DataConfig] = None
116116

117117
def get_inference_settings(self, framework):

olive/evaluator/olive_evaluator.py

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
from functools import partial
1111
from numbers import Number
1212
from pathlib import Path
13-
from typing import TYPE_CHECKING, Any, NamedTuple, Optional, Union
13+
from typing import TYPE_CHECKING, Any, ClassVar, NamedTuple, Optional, Union
1414

1515
import numpy as np
1616
import torch
@@ -121,15 +121,16 @@ def generate_metric_user_config_with_model_io(metric: Metric, model: "OliveModel
121121

122122
@staticmethod
123123
def get_user_config(framework: Framework, metric: Metric):
124-
assert metric.user_config, "user_config is not specified in the metric config"
125-
126124
dataloader = None
127125
eval_func = None
128126
post_func = None
129127

130128
# load the evaluate function
131129
# priority: evaluate_func > metric_func
132130
if metric.type == MetricType.CUSTOM:
131+
if not metric.user_config:
132+
raise ValueError("user_config is required for CUSTOM metric type")
133+
133134
evaluate_func = getattr(metric.user_config, "evaluate_func", None)
134135
kwargs = getattr(metric.user_config, "evaluate_func_kwargs", None) or {}
135136
if not evaluate_func:
@@ -217,7 +218,7 @@ def device_string_to_torch_device(device: Device):
217218

218219
@classmethod
219220
def io_bind_enabled(cls, metric: Metric, inference_settings: dict) -> bool:
220-
if metric.user_config.io_bind:
221+
if metric.user_config and metric.user_config.io_bind:
221222
return True
222223

223224
return inference_settings and inference_settings.get("io_bind")
@@ -307,7 +308,7 @@ def _evaluate_custom(
307308
execution_providers=None,
308309
) -> MetricResult:
309310
raw_res = None
310-
if metric.user_config.evaluate_func:
311+
if metric.user_config and metric.user_config.evaluate_func:
311312
raw_res = eval_func(model, device, execution_providers)
312313
else:
313314
inference_output, targets = self._inference(
@@ -645,15 +646,15 @@ def _evaluate_distributed_latency_worker(config) -> list[float]:
645646
batch = next(iter(dataloader))
646647
input_data = OliveEvaluator.extract_input_data(batch)
647648
input_feed = format_data(input_data, io_config)
648-
kv_cache_ortvalues = {} if metric.user_config.shared_kv_buffer else None
649+
kv_cache_ortvalues = {} if (metric.user_config and getattr(metric.user_config, 'shared_kv_buffer', None)) else None
649650

650651
io_bind = OnnxEvaluator.io_bind_enabled(metric, model.inference_settings)
651652
if io_bind:
652653
io_bind_op = prepare_io_bindings(
653654
session,
654655
input_feed,
655656
Device.GPU,
656-
shared_kv_buffer=metric.user_config.shared_kv_buffer,
657+
shared_kv_buffer=getattr(metric.user_config, 'shared_kv_buffer', None) if metric.user_config else None,
657658
kv_cache_ortvalues=kv_cache_ortvalues,
658659
)
659660
latencies = []
@@ -1123,7 +1124,7 @@ def evaluate(
11231124

11241125

11251126
class OliveEvaluatorConfig(NestedConfig):
1126-
_nested_field_name = "type_args"
1127+
_nested_field_name: ClassVar[str] = "type_args"
11271128

11281129
name: Optional[str] = None
11291130
type: Optional[str] = None

olive/model/config/hf_config.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
# --------------------------------------------------------------------------
55
import logging
66
from copy import deepcopy
7-
from typing import Any, Optional, Union
7+
from typing import Any, ClassVar, Optional, Union
88

99
import torch
1010
import transformers
@@ -22,7 +22,7 @@ class HfLoadKwargs(NestedConfig):
2222
Refer to https://huggingface.co/docs/transformers/main_classes/model#transformers.PreTrainedModel.from_pretrained
2323
"""
2424

25-
_nested_field_name = "extra_args"
25+
_nested_field_name: ClassVar[str] = "extra_args"
2626

2727
torch_dtype: str = Field(
2828
None,

olive/passes/pytorch/train_utils.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
import dataclasses
66
import logging
77
from copy import deepcopy
8-
from typing import TYPE_CHECKING, Any, Optional, Union
8+
from typing import TYPE_CHECKING, Any, Callable, ClassVar, Optional, Union
99

1010
import torch
1111
import transformers
@@ -33,7 +33,7 @@
3333
class BaseHFTrainingArguments(NestedConfig):
3434
"""Training arguments for transformers.Trainer."""
3535

36-
_nested_field_name = "extra_args"
36+
_nested_field_name: ClassVar[str] = "extra_args"
3737

3838
gradient_checkpointing: bool = Field(True, description="Use gradient checkpointing. Recommended.")
3939
report_to: Union[str, list[str]] = Field(

olive/workflows/run/config.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
# --------------------------------------------------------------------------
55
import shutil
66
from pathlib import Path
7-
from typing import Any, Union
7+
from typing import Any, ClassVar, Optional, Union
88

99
from olive.cache import CacheConfig
1010
from olive.common.config_utils import NestedConfig, validate_config
@@ -109,7 +109,7 @@ class RunConfig(NestedConfig):
109109
evaluators, engine, passes, and auto optimizer.
110110
"""
111111

112-
_nested_field_name = "engine"
112+
_nested_field_name: ClassVar[str] = "engine"
113113

114114
workflow_id: str = Field(
115115
DEFAULT_WORKFLOW_ID, description="Workflow ID. If not provided, use the default ID 'default_workflow'."

test/evaluator/test_olive_evaluator.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -379,7 +379,12 @@ def test_evaluator_get_inference_session(self, metric_inference_settings, model_
379379
"""
380380
metric = get_latency_metric(LatencySubType.AVG)
381381
if metric_inference_settings:
382-
metric.user_config.inference_settings = {"onnx": metric_inference_settings.copy()}
382+
# Initialize user_config if needed to set inference_settings
383+
if metric.user_config is None:
384+
from olive.common.config_utils import ConfigBase
385+
metric.user_config = ConfigBase()
386+
# Use object.__setattr__ to set dynamic attributes on ConfigBase
387+
object.__setattr__(metric.user_config, 'inference_settings', {"onnx": metric_inference_settings.copy()})
383388
model = get_onnx_model()
384389
model.inference_settings = model_inference_settings.copy() if model_inference_settings else None
385390
inference_settings = OnnxEvaluator.get_inference_settings(metric, model)

0 commit comments

Comments (0)