Skip to content

Commit 59ed57f

Browse files
Copilotshaahji
andcommitted
Remove abc.ABC in favor of ConfigBase pattern - use _is_base_class flag
Co-authored-by: shaahji <96227573+shaahji@users.noreply.github.com>
1 parent 228d537 commit 59ed57f

File tree

4 files changed

+89
-39
lines changed

4 files changed

+89
-39
lines changed

olive/evaluator/accuracy.py

Lines changed: 24 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,7 @@
22
# Copyright (c) Microsoft Corporation. All rights reserved.
33
# Licensed under the MIT License.
44
# --------------------------------------------------------------------------
5-
import inspect
65
import logging
7-
from abc import ABC, abstractmethod
86
from inspect import isfunction, signature
97
from typing import Any, Callable, ClassVar, Optional, Union
108

@@ -17,9 +15,12 @@
1715
logger = logging.getLogger(__name__)
1816

1917

20-
class AccuracyBase(ABC):
18+
class AccuracyBase:
19+
"""Base class for accuracy metrics."""
20+
2121
registry: ClassVar[dict[str, type["AccuracyBase"]]] = {}
2222
name: Optional[str] = None
23+
_is_base_class: bool = True
2324
metric_cls_map: ClassVar[dict[str, Union[torchmetrics.Metric, Callable]]] = {
2425
"accuracy_score": torchmetrics.Accuracy,
2526
"f1_score": torchmetrics.F1Score,
@@ -40,7 +41,10 @@ def __init__(self, config: Union[ConfigBase, dict[str, Any]] = None) -> None:
4041
def __init_subclass__(cls, **kwargs) -> None:
4142
"""Register the metric."""
4243
super().__init_subclass__(**kwargs)
43-
if inspect.isabstract(cls):
44+
# Subclasses should not be considered base classes unless explicitly set
45+
if '_is_base_class' not in cls.__dict__:
46+
cls._is_base_class = False
47+
if cls._is_base_class:
4448
return
4549
name = cls.name if cls.name is not None else cls.__name__.lower()
4650
cls.registry[name] = cls
@@ -78,7 +82,10 @@ def _metric_config_from_torch_metrics(cls):
7882
@classmethod
7983
def get_config_class(cls) -> type[ConfigBase]:
8084
"""Get the configuration class."""
81-
assert not inspect.isabstract(cls), "Cannot get config class for abstract class"
85+
if '_is_base_class' not in cls.__dict__:
86+
cls._is_base_class = False
87+
if cls._is_base_class:
88+
raise TypeError(f"Cannot get config class for base class {cls.__name__}")
8289
default_config = cls._metric_config_from_torch_metrics()
8390
return create_config_class(f"{cls.__name__}Config", default_config, ConfigBase, {})
8491

@@ -90,9 +97,19 @@ def prepare_tensors(preds, target, dtypes=torch.int):
9097
target = torch.tensor(target, dtype=dtypes[1]) if not isinstance(target, torch.Tensor) else target.to(dtypes[1])
9198
return preds, target
9299

93-
@abstractmethod
94100
def measure(self, model_output, target):
95-
raise NotImplementedError
101+
"""Measure the metric.
102+
103+
Subclasses must implement this method.
104+
105+
Args:
106+
model_output: Model output
107+
target: Target values
108+
109+
Returns:
110+
Measured metric value
111+
"""
112+
raise NotImplementedError(f"{self.__class__.__name__} must implement measure")
96113

97114

98115
class AccuracyScore(AccuracyBase):

olive/evaluator/metric_backend.py

Lines changed: 21 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,6 @@
22
# Copyright (c) Microsoft Corporation. All rights reserved.
33
# Licensed under the MIT License.
44
# --------------------------------------------------------------------------
5-
import inspect
6-
from abc import ABC, abstractmethod
75
from typing import TYPE_CHECKING, Any, ClassVar, NamedTuple, Optional, Union
86

97
from olive.common.config_utils import ConfigBase
@@ -19,9 +17,12 @@ class MetricBackendConfig(ConfigBase):
1917

2018

2119

22-
class MetricBackend(ABC):
20+
class MetricBackend:
21+
"""Base class for metric backends."""
22+
2323
registry: ClassVar[dict[str, type["MetricBackend"]]] = {}
2424
name: Optional[str] = None
25+
_is_base_class: bool = True
2526

2627
def __init__(self, config: Optional[Union[ConfigBase, dict[str, Any]]] = None) -> None:
2728
config = config or {}
@@ -33,7 +34,10 @@ def __init__(self, config: Optional[Union[ConfigBase, dict[str, Any]]] = None) -
3334
def __init_subclass__(cls, **kwargs) -> None:
3435
"""Register the metric backend."""
3536
super().__init_subclass__(**kwargs)
36-
if inspect.isabstract(cls):
37+
# Subclasses should not be considered base classes unless explicitly set
38+
if '_is_base_class' not in cls.__dict__:
39+
cls._is_base_class = False
40+
if cls._is_base_class:
3741
return
3842
name = cls.name if cls.name is not None else cls.__name__.lower()
3943
cls.registry[name] = cls
@@ -43,12 +47,22 @@ def get_config_class(cls) -> type[ConfigBase]:
4347
"""Get the configuration class."""
4448
return MetricBackendConfig
4549

46-
@abstractmethod
4750
def measure_sub_metric(
4851
self, model_output: Union[tuple, NamedTuple], targets: Any, sub_metric: "SubMetric"
4952
) -> SubMetricResult:
50-
# model_output: (preds, logits)
51-
raise NotImplementedError
53+
"""Measure a sub-metric.
54+
55+
Subclasses must implement this method.
56+
57+
Args:
58+
model_output: (preds, logits)
59+
targets: Target values
60+
sub_metric: Sub-metric to measure
61+
62+
Returns:
63+
SubMetricResult with the measurement
64+
"""
65+
raise NotImplementedError(f"{self.__class__.__name__} must implement measure_sub_metric")
5266

5367
def measure(self, model_output, targets, metrics: "Metric") -> MetricResult:
5468
metric_results_dict = {}

olive/resource_path.py

Lines changed: 27 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -2,11 +2,9 @@
22
# Copyright (c) Microsoft Corporation. All rights reserved.
33
# Licensed under the MIT License.
44
# --------------------------------------------------------------------------
5-
import inspect
65
import logging
76
import shutil
87
import tempfile
9-
from abc import ABC, abstractmethod
108
from copy import deepcopy
119
from pathlib import Path
1210
from typing import Any, Callable, ClassVar, Optional, Union
@@ -37,9 +35,12 @@ class ResourceType(CaseInsensitiveEnum):
3735
LOCAL_RESOURCE_TYPES = (ResourceType.LocalFile, ResourceType.LocalFolder)
3836

3937

40-
class ResourcePath(ABC):
38+
class ResourcePath:
39+
"""Base class for resource paths."""
40+
4141
registry: ClassVar[dict[str, type["ResourcePath"]]] = {}
4242
name: Optional[ResourceType] = None
43+
_is_base_class: bool = True
4344

4445
def __init__(self, config: Union[ConfigBase, dict[str, Any]]) -> None:
4546
if isinstance(config, dict):
@@ -50,16 +51,21 @@ def __init__(self, config: Union[ConfigBase, dict[str, Any]]) -> None:
5051
def __init_subclass__(cls, **kwargs) -> None:
5152
"""Register the resource path."""
5253
super().__init_subclass__(**kwargs)
53-
if inspect.isabstract(cls):
54+
# Subclasses should not be considered base classes unless explicitly set
55+
if '_is_base_class' not in cls.__dict__:
56+
cls._is_base_class = False
57+
if cls._is_base_class:
5458
return
5559
name = cls.name if cls.name is not None else cls.__name__.lower()
5660
cls.registry[name] = cls
5761

5862
@classmethod
59-
@abstractmethod
6063
def _default_config(cls) -> dict[str, ConfigParam]:
61-
"""Get the default configuration for the class."""
62-
raise NotImplementedError
64+
"""Get the default configuration for the class.
65+
66+
Subclasses must implement this method.
67+
"""
68+
raise NotImplementedError(f"{cls.__name__} must implement _default_config")
6369

6470
@classmethod
6571
def _validators(cls) -> dict[str, Callable]:
@@ -69,7 +75,10 @@ def _validators(cls) -> dict[str, Callable]:
6975
@classmethod
7076
def get_config_class(cls) -> type[ConfigBase]:
7177
"""Get the configuration class."""
72-
assert not inspect.isabstract(cls), "Cannot get config class for abstract class"
78+
if '_is_base_class' not in cls.__dict__:
79+
cls._is_base_class = False
80+
if cls._is_base_class:
81+
raise TypeError(f"Cannot get config class for base class {cls.__name__}")
7382
return create_config_class(f"{cls.__name__}Config", cls._default_config(), ConfigBase, cls._validators())
7483

7584
def __repr__(self) -> str:
@@ -79,15 +88,19 @@ def __repr__(self) -> str:
7988
def type(self) -> Optional[ResourceType]:
8089
return self.name
8190

82-
@abstractmethod
8391
def get_path(self) -> str:
84-
"""Return the resource path as a string."""
85-
raise NotImplementedError
92+
"""Return the resource path as a string.
93+
94+
Subclasses must implement this method.
95+
"""
96+
raise NotImplementedError(f"{self.__class__.__name__} must implement get_path")
8697

87-
@abstractmethod
8898
def save_to_dir(self, dir_path: Union[Path, str], name: Optional[str] = None, overwrite: bool = False) -> str:
89-
"""Save the resource to a directory."""
90-
raise NotImplementedError
99+
"""Save the resource to a directory.
100+
101+
Subclasses must implement this method.
102+
"""
103+
raise NotImplementedError(f"{self.__class__.__name__} must implement save_to_dir")
91104

92105
def is_local_resource(self) -> bool:
93106
"""Return True if the resource is a local resource."""

olive/search/samplers/search_sampler.py

Lines changed: 17 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,6 @@
22
# Copyright (c) Microsoft Corporation. All rights reserved.
33
# Licensed under the MIT License.
44
# --------------------------------------------------------------------------
5-
import inspect
6-
from abc import ABC, abstractmethod
75
from collections import OrderedDict
86
from typing import TYPE_CHECKING, Any, ClassVar, Optional, Union
97

@@ -21,11 +19,12 @@ class SearchSamplerConfig(ConfigBase):
2119
max_samples: int = 0
2220

2321

24-
class SearchSampler(ABC):
25-
"""Abstract base class for searchers."""
22+
class SearchSampler:
23+
"""Base class for search samplers."""
2624

2725
registry: ClassVar[dict[str, type["SearchSampler"]]] = {}
2826
name: Optional[str] = None
27+
_is_base_class: bool = True
2928

3029
def __init__(
3130
self,
@@ -52,7 +51,10 @@ def __init__(
5251
def __init_subclass__(cls, **kwargs) -> None:
5352
"""Register the search sampler."""
5453
super().__init_subclass__(**kwargs)
55-
if inspect.isabstract(cls):
54+
# Subclasses should not be considered base classes unless explicitly set
55+
if '_is_base_class' not in cls.__dict__:
56+
cls._is_base_class = False
57+
if cls._is_base_class:
5658
return
5759
name = cls.name if cls.name is not None else cls.__name__.lower()
5860
cls.registry[name] = cls
@@ -63,10 +65,12 @@ def get_config_class(cls) -> type[ConfigBase]:
6365
return SearchSamplerConfig
6466

6567
@property
66-
@abstractmethod
6768
def num_samples_suggested(self) -> int:
68-
"""Returns the number of samples suggested so far."""
69-
return 0
69+
"""Returns the number of samples suggested so far.
70+
71+
Subclasses must implement this property.
72+
"""
73+
raise NotImplementedError(f"{self.__class__.__name__} must implement num_samples_suggested")
7074

7175
@property
7276
def max_samples(self) -> int:
@@ -82,10 +86,12 @@ def should_stop(self) -> bool:
8286
or ((self.max_samples > 0) and (self.num_samples_suggested >= self.max_samples))
8387
)
8488

85-
@abstractmethod
8689
def suggest(self) -> "SearchPoint":
87-
"""Suggest a new configuration to try."""
88-
return None
90+
"""Suggest a new configuration to try.
91+
92+
Subclasses must implement this method.
93+
"""
94+
raise NotImplementedError(f"{self.__class__.__name__} must implement suggest")
8995

9096
def record_feedback_signal(self, search_point_index: int, signal: "MetricResult", should_prune: bool = False):
9197
"""Report the result of a configuration."""

0 commit comments

Comments
 (0)