Skip to content

Commit 907c3eb

Browse files
committed
Apply suggestions from code review
Signed-off-by: Thijs Baaijen <13253091+Thijss@users.noreply.github.com>
1 parent e718172 commit 907c3eb

3 files changed

Lines changed: 18 additions & 21 deletions

File tree

src/power_grid_model_ds/_core/model/arrays/base/array.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,11 @@
1919
from power_grid_model_ds._core.model.arrays.base._string import convert_array_to_string
2020
from power_grid_model_ds._core.model.arrays.base.errors import ArrayDefinitionError
2121
from power_grid_model_ds._core.model.constants import EMPTY_ID, empty
22-
from power_grid_model_ds._core.utils.misc import array_equal_with_nan, build_mro_attribute, get_public_annotations
22+
from power_grid_model_ds._core.utils.misc import (
23+
array_equal_with_nan,
24+
combine_attribute_from_parent_classes,
25+
get_public_annotations,
26+
)
2327

2428
# pylint: disable=missing-function-docstring, too-many-public-methods
2529

@@ -80,13 +84,13 @@ def data(self) -> NDArray:
8084
@classmethod
8185
@lru_cache
8286
def get_defaults(cls) -> dict[str, Any]:
83-
return build_mro_attribute(cls, "_defaults", attribute_type=dict)
87+
return combine_attribute_from_parent_classes(cls, "_defaults", attribute_type=dict)
8488

8589
@classmethod
8690
@lru_cache
8791
def get_dtype(cls):
8892
annotations = get_public_annotations(cls)
89-
str_lengths = build_mro_attribute(cls, "_str_lengths", dict)
93+
str_lengths = combine_attribute_from_parent_classes(cls, "_str_lengths", dict)
9094
dtypes = {}
9195
for name, dtype in annotations.items():
9296
if len(dtype.__args__) > 1:

src/power_grid_model_ds/_core/utils/misc.py

Lines changed: 7 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
"""Misc utils"""
66

77
from collections.abc import Sequence
8-
from typing import get_type_hints, overload
8+
from typing import get_type_hints
99

1010
import numpy as np
1111

@@ -30,34 +30,27 @@ def get_public_annotations(cls: type):
3030
return {attr: type_ for attr, type_ in class_attributes.items() if not attr.startswith("_")}
3131

3232

33-
@overload
34-
def build_mro_attribute(cls: type, attribute_name: str, attribute_type: type[dict]) -> dict: ...
35-
36-
37-
@overload
38-
def build_mro_attribute(cls: type, attribute_name: str, attribute_type: type[set]) -> set: ...
39-
40-
41-
def build_mro_attribute(cls: type, attribute_name: str, attribute_type) -> dict | set:
33+
def combine_attribute_from_parent_classes[T: (dict, set)](cls: type, attribute_name: str, attribute_type: type[T]) -> T:
4234
"""Combine all versions of an attribute in the Method Resolution Order (mro) of a class into a single attribute
4335
4436
For dicts this means the dict is updated so that child classes override parent classes.
4537
For sets this means the sets are unioned together.
4638
4739
Types other than dict and set are not supported
4840
"""
49-
attr_value = attribute_type()
41+
combined_attr = attribute_type()
5042
for parent in reversed(list(cls.__mro__)):
43+
parent_attr = getattr(parent, attribute_name, attribute_type())
5144
if attribute_type is dict:
52-
attr_value.update(getattr(parent, attribute_name, {}))
45+
combined_attr.update(parent_attr)
5346
elif attribute_type is set:
54-
attr_value |= getattr(parent, attribute_name, set())
47+
combined_attr |= parent_attr
5548
else:
5649
raise NotImplementedError(
5750
f"Type {attribute_type} cannot combine inherited for attribute {attribute_name}. "
5851
f"Only dict and set are currently supported."
5952
)
60-
return attr_value
53+
return combined_attr
6154

6255

6356
def array_equal_with_nan(array1: np.ndarray, array2: np.ndarray) -> bool:

tests/unit/utils/test_misc.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88

99
from power_grid_model_ds._core.utils.misc import (
1010
array_equal_with_nan,
11-
build_mro_attribute,
11+
combine_attribute_from_parent_classes,
1212
get_public_annotations,
1313
is_sequence,
1414
)
@@ -67,16 +67,16 @@ def test_get_public_class_attrs():
6767

6868

6969
def test_build_mro_attribute_set():
70-
attr = build_mro_attribute(_ChildClass, attribute_name="a", attribute_type=set)
70+
attr = combine_attribute_from_parent_classes(_ChildClass, attribute_name="a", attribute_type=set)
7171
assert attr == {1, 2, 3, 4, 5}
7272

7373

7474
def test_build_mro_attribute_dict():
7575
assert _ChildClass.b == {2: "b", 3: "ccc"}
76-
attr = build_mro_attribute(_ChildClass, attribute_name="b", attribute_type=dict)
76+
attr = combine_attribute_from_parent_classes(_ChildClass, attribute_name="b", attribute_type=dict)
7777
assert attr == {1: "a", 2: "b", 3: "ccc"}
7878

7979

8080
def test_build_mro_attribute_list():
8181
with pytest.raises(NotImplementedError):
82-
build_mro_attribute(_ChildClass, attribute_name="c", attribute_type=list)
82+
combine_attribute_from_parent_classes(_ChildClass, attribute_name="c", attribute_type=list)

0 commit comments

Comments
 (0)