|
10 | 10 | from functools import partial |
11 | 11 | from numbers import Number |
12 | 12 | from pathlib import Path |
13 | | -from typing import TYPE_CHECKING, Any, NamedTuple, Optional, Union |
| 13 | +from typing import TYPE_CHECKING, Any, ClassVar, NamedTuple, Optional, Union |
14 | 14 |
|
15 | 15 | import numpy as np |
16 | 16 | import torch |
@@ -121,15 +121,16 @@ def generate_metric_user_config_with_model_io(metric: Metric, model: "OliveModel |
121 | 121 |
|
122 | 122 | @staticmethod |
123 | 123 | def get_user_config(framework: Framework, metric: Metric): |
124 | | - assert metric.user_config, "user_config is not specified in the metric config" |
125 | | - |
126 | 124 | dataloader = None |
127 | 125 | eval_func = None |
128 | 126 | post_func = None |
129 | 127 |
|
130 | 128 | # load the evaluate function |
131 | 129 | # priority: evaluate_func > metric_func |
132 | 130 | if metric.type == MetricType.CUSTOM: |
| 131 | + if not metric.user_config: |
| 132 | + raise ValueError("user_config is required for CUSTOM metric type") |
| 133 | + |
133 | 134 | evaluate_func = getattr(metric.user_config, "evaluate_func", None) |
134 | 135 | kwargs = getattr(metric.user_config, "evaluate_func_kwargs", None) or {} |
135 | 136 | if not evaluate_func: |
@@ -217,7 +218,7 @@ def device_string_to_torch_device(device: Device): |
217 | 218 |
|
218 | 219 | @classmethod |
219 | 220 | def io_bind_enabled(cls, metric: Metric, inference_settings: dict) -> bool: |
220 | | - if metric.user_config.io_bind: |
| 221 | + if metric.user_config and metric.user_config.io_bind: |
221 | 222 | return True |
222 | 223 |
|
223 | 224 | return inference_settings and inference_settings.get("io_bind") |
@@ -307,7 +308,7 @@ def _evaluate_custom( |
307 | 308 | execution_providers=None, |
308 | 309 | ) -> MetricResult: |
309 | 310 | raw_res = None |
310 | | - if metric.user_config.evaluate_func: |
| 311 | + if metric.user_config and metric.user_config.evaluate_func: |
311 | 312 | raw_res = eval_func(model, device, execution_providers) |
312 | 313 | else: |
313 | 314 | inference_output, targets = self._inference( |
@@ -645,15 +646,15 @@ def _evaluate_distributed_latency_worker(config) -> list[float]: |
645 | 646 | batch = next(iter(dataloader)) |
646 | 647 | input_data = OliveEvaluator.extract_input_data(batch) |
647 | 648 | input_feed = format_data(input_data, io_config) |
648 | | - kv_cache_ortvalues = {} if metric.user_config.shared_kv_buffer else None |
| 649 | + kv_cache_ortvalues = {} if metric.user_config and metric.user_config.shared_kv_buffer else None
649 | 650 |
|
650 | 651 | io_bind = OnnxEvaluator.io_bind_enabled(metric, model.inference_settings) |
651 | 652 | if io_bind: |
652 | 653 | io_bind_op = prepare_io_bindings( |
653 | 654 | session, |
654 | 655 | input_feed, |
655 | 656 | Device.GPU, |
656 | | - shared_kv_buffer=metric.user_config.shared_kv_buffer, |
| 657 | + shared_kv_buffer=metric.user_config.shared_kv_buffer if metric.user_config else None,
657 | 658 | kv_cache_ortvalues=kv_cache_ortvalues, |
658 | 659 | ) |
659 | 660 | latencies = [] |
@@ -1123,7 +1124,7 @@ def evaluate( |
1123 | 1124 |
|
1124 | 1125 |
|
1125 | 1126 | class OliveEvaluatorConfig(NestedConfig): |
1126 | | - _nested_field_name = "type_args" |
| 1127 | + _nested_field_name: ClassVar[str] = "type_args" |
1127 | 1128 |
|
1128 | 1129 | name: Optional[str] = None |
1129 | 1130 | type: Optional[str] = None |
|
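A minimal, self-contained sketch of the guard pattern this diff introduces: `user_config` may now be `None`, so every access short-circuits on it before dereferencing. The `_FakeUserConfig`/`_FakeMetric` dataclasses below are hypothetical stand-ins for Olive's pydantic config models, used only to make the snippet runnable; only the guard logic mirrors the patched `io_bind_enabled` classmethod.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class _FakeUserConfig:
    # hypothetical stand-in for the metric's user_config (illustration only)
    io_bind: bool = False
    shared_kv_buffer: bool = False


@dataclass
class _FakeMetric:
    # hypothetical stand-in for olive's Metric; user_config is now optional
    user_config: Optional[_FakeUserConfig] = None


def io_bind_enabled(metric: _FakeMetric, inference_settings: Optional[dict]) -> bool:
    # mirrors the patched classmethod: check user_config before dereferencing it
    if metric.user_config and metric.user_config.io_bind:
        return True
    return bool(inference_settings and inference_settings.get("io_bind"))


# no user_config and no inference settings -> io binding stays off
assert io_bind_enabled(_FakeMetric(), None) is False
# explicit opt-in through user_config still works
assert io_bind_enabled(_FakeMetric(_FakeUserConfig(io_bind=True)), None) is True
# opt-in through inference settings works without any user_config
assert io_bind_enabled(_FakeMetric(), {"io_bind": True}) is True
```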