@@ -25,7 +25,14 @@
 from olive.data.config import DataConfig
 from olive.data.container.dummy_data_container import TRANSFORMER_DUMMY_DATA_CONTAINER
 from olive.data.template import dummy_data_config_template
-from olive.evaluator.metric import LatencySubType, Metric, MetricType, ThroughputSubType, get_latency_config_from_metric
+from olive.evaluator.metric import (
+    LatencySubType,
+    Metric,
+    MetricType,
+    SizeOnDiskSubType,
+    ThroughputSubType,
+    get_latency_config_from_metric,
+)
 from olive.evaluator.metric_backend import MetricBackend
 from olive.evaluator.metric_result import MetricResult, SubMetricResult, flatten_metric_result, joint_metric_key
 from olive.evaluator.registry import Registry
@@ -276,6 +283,19 @@ def _evaluate_throughput(
         latencies = self._evaluate_raw_latency(model, metric, dataloader, post_func, device, execution_providers)
         return OliveEvaluator.compute_throughput(metric, latencies)
 
+    def _evaluate_size_on_disk(
+        self,
+        model: "OliveModelHandler",
+        metric: Metric,
+        dataloader: "DataLoader",
+        post_func=None,
+        device: Device = Device.CPU,
+        execution_providers: Union[str, list[str]] = None,
+    ) -> MetricResult:
+        return MetricResult.parse_obj(
+            {SizeOnDiskSubType.BYTES.value: {"value": model.size_on_disk, "priority": -1, "higher_is_better": False}}
+        )
+
     def _evaluate_custom(
         self,
         model: "OliveModelHandler",
@@ -335,6 +355,10 @@ def evaluate(
             metrics_res[metric.name] = self._evaluate_throughput(
                 model, metric, dataloader, post_func, device, execution_providers
             )
+        elif metric.type == MetricType.SIZE_ON_DISK:
+            metrics_res[metric.name] = self._evaluate_size_on_disk(
+                model, metric, dataloader, post_func, device, execution_providers
+            )
         elif metric.type == MetricType.CUSTOM:
             metrics_res[metric.name] = self._evaluate_custom(
                 model, metric, dataloader, eval_func, post_func, device, execution_providers
@@ -1056,30 +1080,44 @@ def evaluate(
             self.model_class,
             {k: v for k, v in init_args.items() if k in ["device", "ep", "ep_options"]},
         )
-        lmmodel = get_model(self.model_class)(**init_args, batch_size=self.batch_size, max_length=self.max_length)
-
-        results = simple_evaluate(
-            model=lmmodel,
-            tasks=self.tasks,
-            task_manager=TaskManager(),
-            log_samples=False,
-            batch_size=self.batch_size,
-            device=device,
-            limit=self.limit,
-        )
 
         metrics = {}
-        for task_name in sorted(results["results"].keys()):
-            metric_items = sorted(results["results"][task_name].items())
+        if MetricType.SIZE_ON_DISK.value in self.tasks:
+            self.tasks.remove(MetricType.SIZE_ON_DISK.value)
+            metrics[MetricType.SIZE_ON_DISK.value] = MetricResult.parse_obj(
+                {
+                    SizeOnDiskSubType.BYTES.value: {
+                        "value": model.size_on_disk,
+                        "priority": -1,
+                        "higher_is_better": False,
+                    }
+                }
+            )
+
+        if self.tasks:
+            lmmodel = get_model(self.model_class)(**init_args, batch_size=self.batch_size, max_length=self.max_length)
+
+            results = simple_evaluate(
+                model=lmmodel,
+                tasks=self.tasks,
+                task_manager=TaskManager(),
+                log_samples=False,
+                batch_size=self.batch_size,
+                device=device,
+                limit=self.limit,
+            )
+
+            for task_name in sorted(results["results"].keys()):
+                metric_items = sorted(results["results"][task_name].items())
 
-            task_metrics = {}
-            for mf, v in metric_items:
-                if mf != "alias":
-                    m, _ = mf.split(",", 1)
-                    if not m.endswith("_stderr"):
-                        task_metrics[m] = SubMetricResult(value=v, priority=-1, higher_is_better=True)
+                task_metrics = {}
+                for mf, v in metric_items:
+                    if mf != "alias":
+                        m, _ = mf.split(",", 1)
+                        if not m.endswith("_stderr"):
+                            task_metrics[m] = SubMetricResult(value=v, priority=-1, higher_is_better=True)
 
-            metrics[task_name] = MetricResult.parse_obj(task_metrics)
+                metrics[task_name] = MetricResult.parse_obj(task_metrics)
 
         return flatten_metric_result(metrics)
 
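For reference, a minimal sketch (not part of this commit) of how the new size-on-disk result is assembled and flattened, using only identifiers the diff introduces; the byte count is a made-up stand-in for `model.size_on_disk`:

```python
# Sketch only: mirrors the new _evaluate_size_on_disk and the lm-eval path above.
from olive.evaluator.metric import MetricType, SizeOnDiskSubType
from olive.evaluator.metric_result import MetricResult, flatten_metric_result

size_bytes = 1_234_567  # hypothetical placeholder for model.size_on_disk

# Same shape as the dict passed to MetricResult.parse_obj in the diff:
# one sub-metric (SizeOnDiskSubType.BYTES) with priority -1 and higher_is_better=False.
result = MetricResult.parse_obj(
    {SizeOnDiskSubType.BYTES.value: {"value": size_bytes, "priority": -1, "higher_is_better": False}}
)

# evaluate() collects results keyed by metric/task name and flattens them,
# exactly as the final `return flatten_metric_result(metrics)` does.
flat = flatten_metric_result({MetricType.SIZE_ON_DISK.value: result})
print(flat)
```

The flattened result exposes the sub-metric under a combined metric/sub-metric key (see `joint_metric_key` in `olive.evaluator.metric_result`), which is how downstream passes read it.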