55"""Misc utils"""
66
77from collections .abc import Sequence
8- from typing import get_type_hints , overload
8+ from typing import get_type_hints
99
1010import numpy as np
1111
@@ -30,34 +30,27 @@ def get_public_annotations(cls: type):
3030 return {attr : type_ for attr , type_ in class_attributes .items () if not attr .startswith ("_" )}
3131
3232
33- @overload
34- def build_mro_attribute (cls : type , attribute_name : str , attribute_type : type [dict ]) -> dict : ...
35-
36-
37- @overload
38- def build_mro_attribute (cls : type , attribute_name : str , attribute_type : type [set ]) -> set : ...
39-
40-
41- def build_mro_attribute (cls : type , attribute_name : str , attribute_type ) -> dict | set :
33+ def combine_attribute_from_parent_classes [T : (dict , set )](cls : type , attribute_name : str , attribute_type : type [T ]) -> T :
4234 """Combine all versions of an attribute in the Method Resolution Order (mro) of a class into a single attribute
4335
4436 For dicts this means the dict is updated so that child classes override parent classes.
4537 For sets this means the sets are unioned together.
4638
4739 Types other than dict and set are not supported
4840 """
49- attr_value = attribute_type ()
41+ combined_attr = attribute_type ()
5042 for parent in reversed (list (cls .__mro__ )):
43+ parent_attr = getattr (parent , attribute_name , attribute_type ())
5144 if attribute_type is dict :
52- attr_value .update (getattr ( parent , attribute_name , {}) )
45+ combined_attr .update (parent_attr )
5346 elif attribute_type is set :
54- attr_value |= getattr ( parent , attribute_name , set ())
47+ combined_attr |= parent_attr
5548 else :
5649 raise NotImplementedError (
5750 f"Type { attribute_type } cannot combine inherited for attribute { attribute_name } . "
5851 f"Only dict and set are currently supported."
5952 )
60- return attr_value
53+ return combined_attr
6154
6255
6356def array_equal_with_nan (array1 : np .ndarray , array2 : np .ndarray ) -> bool :
0 commit comments