Skip to content

Commit 280e21e

Browse files
committed
fixed edge case with type_safe decorator where it wasn't converting return values
1 parent ca7df93 commit 280e21e

File tree

6 files changed

+150
-16
lines changed

6 files changed

+150
-16
lines changed

osbot_utils/type_safe/type_safe_core/collections/Type_Safe__List.py

Lines changed: 15 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -9,12 +9,23 @@
99
class Type_Safe__List(Type_Safe__Base, list):
1010
expected_type : Type = None # Class-level default
1111

12-
def __init__(self, expected_type=None, *args):
13-
super().__init__(*args)
14-
self.expected_type = expected_type or self.__class__.expected_type
15-
if self.expected_type is None:
12+
def __init__(self, expected_type=None, initial_data=None):
13+
super().__init__() # Initialize empty list first
14+
15+
if isinstance(expected_type, list) and initial_data is None: # Smart detection: if expected_type is a list and initial_data is None
16+
initial_data = expected_type # They're using the pattern: An_List(['a', 'b'])
17+
expected_type = None # Move the list to initial_data
18+
19+
self.expected_type = expected_type or self.__class__.expected_type # Use provided type, or fall back to class-level attribute
20+
21+
if self.expected_type is None: # Validate that we have type set (either from args or class)
1622
raise ValueError(f"{self.__class__.__name__} requires expected_type")
1723

24+
if initial_data is not None: # Process initial data through our type-safe append
25+
for item in initial_data:
26+
self.append(item) # let the .append() check the type_safety
27+
28+
1829
def __contains__(self, item):
1930
if super().__contains__(item): # First try direct lookup
2031
return True

osbot_utils/type_safe/type_safe_core/collections/Type_Safe__Set.py

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7,12 +7,22 @@
77
class Type_Safe__Set(Type_Safe__Base, set):
88
expected_type : Type = None # Class-level default
99

10-
def __init__(self, expected_type=None, *args):
11-
super().__init__(*args)
12-
self.expected_type = expected_type or self.__class__.expected_type
13-
if self.expected_type is None:
10+
def __init__(self, expected_type=None, initial_data=None):
11+
super().__init__() # Initialize empty set first
12+
13+
if isinstance(expected_type, (set, frozenset)) and initial_data is None: # Smart detection
14+
initial_data = expected_type # They're using the pattern: An_Set({'a', 'b'})
15+
expected_type = None # Move the set to initial_data
16+
17+
self.expected_type = expected_type or self.__class__.expected_type # Use provided type, or fall back to class-level attribute
18+
19+
if self.expected_type is None: # Validate that we have type set
1420
raise ValueError(f"{self.__class__.__name__} requires expected_type")
1521

22+
if initial_data is not None: # Process initial data through our type-safe add
23+
for item in initial_data:
24+
self.add(item)
25+
1626
def __contains__(self, item):
1727
if super().__contains__(item): # First try direct lookup
1828
return True

osbot_utils/type_safe/type_safe_core/decorators/type_safe.py

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
11
import functools # For wrapping functions
2+
from typing import get_args
23
from osbot_utils.type_safe.Type_Safe__Base import Type_Safe__Base
34
from osbot_utils.type_safe.Type_Safe__Primitive import Type_Safe__Primitive
45
from osbot_utils.type_safe.type_safe_core.methods.Type_Safe__Method import Type_Safe__Method
6+
from osbot_utils.type_safe.type_safe_core.shared.Type_Safe__Cache import type_safe_cache
57

68

79
def type_safe(func): # Main decorator function
@@ -30,6 +32,8 @@ def wrapper(*args, **kwargs):
3032
if return_type is not None and result is not None: # Validate return type using existing type checking infrastructure
3133
if isinstance(return_type, type) and issubclass(return_type, Type_Safe__Primitive): # Try to convert Type_Safe__Primitive types
3234
result = return_type(result) # Since we are using a Type_Safe__Primitive, if there is bad data (like a negative number in Safe_UInt) this will trigger an exception
35+
36+
result = convert_return_value(result, return_type) # todo: refactor convert_return_value into another class (maybe validator or a new converter class)
3337
try:
3438
validator.is_instance_of_type(result, return_type)
3539
except TypeError as e:
@@ -38,3 +42,57 @@ def wrapper(*args, **kwargs):
3842
return result
3943
return wrapper # Return wrapped function
4044

45+
46+
# todo: see if we can optmise this return value conversion, namely in detecting if a conversion is needed
47+
def convert_return_value(result, return_type):
48+
"""Convert return value to match expected type with auto-conversion."""
49+
from osbot_utils.type_safe.type_safe_core.collections.Type_Safe__List import Type_Safe__List
50+
from osbot_utils.type_safe.type_safe_core.collections.Type_Safe__Set import Type_Safe__Set
51+
from osbot_utils.type_safe.type_safe_core.collections.Type_Safe__Dict import Type_Safe__Dict
52+
53+
origin = type_safe_cache.get_origin(return_type)
54+
args = get_args(return_type)
55+
56+
# # Handle List[T] -> Type_Safe__List conversion
57+
if origin is list and args and isinstance(result, list) and not isinstance(result, Type_Safe__List):
58+
item_type = args[0]
59+
type_safe_list = Type_Safe__List(expected_type=item_type)
60+
for item in result:
61+
type_safe_list.append(item) # Auto-converts items
62+
return type_safe_list
63+
64+
# Handle Set[T] -> Type_Safe__Set conversion
65+
if origin is set and args and isinstance(result, set) and not isinstance(result, Type_Safe__Set):
66+
item_type = args[0]
67+
type_safe_set = Type_Safe__Set(expected_type=item_type)
68+
for item in result:
69+
type_safe_set.add(item) # Auto-converts items
70+
return type_safe_set
71+
72+
# Handle Dict[K, V] -> Type_Safe__Dict conversion
73+
if origin is dict and args and len(args) == 2 and isinstance(result, dict) and not isinstance(result, Type_Safe__Dict):
74+
key_type, value_type = args
75+
type_safe_dict = Type_Safe__Dict(expected_key_type=key_type, expected_value_type=value_type)
76+
for k, v in result.items():
77+
type_safe_dict[k] = v # Auto-converts keys and values
78+
return type_safe_dict
79+
80+
# Handle Type_Safe__List subclass directly (e.g., -> An_List)
81+
if isinstance(return_type, type) and issubclass(return_type, Type_Safe__List):
82+
if isinstance(result, list) and not isinstance(result, return_type):
83+
return return_type(result) # Uses nice constructor: An_List(['a', 'b'])
84+
return result
85+
86+
# Handle Type_Safe__Set subclass directly (e.g., -> An_Set)
87+
if isinstance(return_type, type) and issubclass(return_type, Type_Safe__Set):
88+
if isinstance(result, (set, list, tuple)) and not isinstance(result, return_type):
89+
return return_type(result) # Uses nice constructor: An_Set({'a', 'b'})
90+
return result
91+
92+
# Handle Type_Safe__Dict subclass directly (e.g., -> Hash_Mapping)
93+
if isinstance(return_type, type) and issubclass(return_type, Type_Safe__Dict):
94+
if isinstance(result, dict) and not isinstance(result, return_type):
95+
return return_type(result) # Uses nice constructor: Hash_Mapping({'key': 'val'})
96+
return result
97+
98+
return result

tests/unit/type_safe/type_safe_core/_bugs/test_Type_Safe__List__bugs.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@
66

77
class test_Type_Safe__List__bugs(TestCase):
88

9-
109
def test__bug__type_safe_list_with_callable(self):
1110
from typing import Callable
1211

tests/unit/type_safe/type_safe_core/decorators/test__decorator__type_safe.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -433,7 +433,8 @@ def return_list_wrong() -> List[int]:
433433

434434
assert return_list_int() == [1, 2, 3]
435435

436-
with pytest.raises(TypeError, match="return type validation failed"):
436+
error_message = "In Type_Safe__List: Invalid type for item: Expected 'int', but got 'str'"
437+
with pytest.raises(TypeError, match=re.escape(error_message)):
437438
return_list_wrong()
438439

439440

@@ -472,10 +473,12 @@ def return_dict_wrong_value() -> Dict[str, int]:
472473

473474
assert return_dict() == {"a": 1, "b": 2}
474475

475-
with pytest.raises(TypeError, match="return type validation failed"):
476+
error_message_1 = "Expected 'str', but got 'int'"
477+
with pytest.raises(TypeError, match=error_message_1):
476478
return_dict_wrong_key()
477479

478-
with pytest.raises(TypeError, match="return type validation failed"):
480+
error_message_2 = "Expected 'int', but got 'str'"
481+
with pytest.raises(TypeError, match=error_message_2):
479482
return_dict_wrong_value()
480483

481484

tests/unit/type_safe/type_safe_core/decorators/test__decorator__type_safe__bugs.py

Lines changed: 57 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,64 @@
1+
import re
2+
13
import pytest
2-
from typing import List, Dict, Callable
3-
from unittest import TestCase
4-
from osbot_utils.type_safe.Type_Safe import Type_Safe
5-
from osbot_utils.type_safe.type_safe_core.decorators.type_safe import type_safe
4+
from typing import List, Dict, Callable
5+
from unittest import TestCase
6+
from osbot_utils.type_safe.primitives.domains.identifiers.safe_str.Safe_Str__Id import Safe_Str__Id
7+
from osbot_utils.type_safe.type_safe_core.collections.Type_Safe__List import Type_Safe__List
8+
from osbot_utils.type_safe.type_safe_core.decorators.type_safe import type_safe
69

710

811
class test__decorator__type_safe__bugs(TestCase):
912

13+
def test__regression__type_safe_decorator__doesnt_auto_convert_lists_return_value(self):
14+
15+
#BUG 1
16+
@type_safe
17+
def an_function_1(value:str) -> List[Safe_Str__Id]:
18+
values = ['42', value]
19+
return values
20+
21+
# error_message_1 = "Function 'test__decorator__type_safe__bugs.test__bug__type_safe_decorator__doesnt_auto_convert_lists_return_value.<locals>.an_function_1' return type validation failed: In list at index 0: Expected 'Safe_Str__Id', but got 'str'"
22+
# with pytest.raises(TypeError, match=re.escape(error_message_1)):
23+
# an_function_1('is the answer') # BUG
24+
assert an_function_1('is the answer') == ['42', 'is_the_answer'] # FIXED
25+
assert an_function_1('is the answer').obj() == ['42', 'is_the_answer'] # FIXED
26+
27+
# BUG 2
28+
class An_List(Type_Safe__List):
29+
expected_type = Safe_Str__Id
30+
31+
@type_safe
32+
def an_function_2(value:str) -> An_List:
33+
values= ['42', value]
34+
return values
35+
36+
# error_message_2 = "In Type_Safe__List: Invalid type for item: Expected 'An_List', but got 'str'"
37+
# with pytest.raises(TypeError, match=re.escape(error_message_2)):
38+
# an_function_2('is the answer') # BUG
39+
40+
assert an_function_2('is the answer') == ['42', 'is_the_answer'] # FIXED
41+
assert an_function_2('is the answer').obj () == ['42', 'is_the_answer'] # FIXED
42+
assert an_function_2('is the answer').json() == ['42', 'is_the_answer'] # FIXED
43+
44+
# Control test 1
45+
46+
@type_safe
47+
def an_function_3(value:str) -> List[Safe_Str__Id]:
48+
an_list = Type_Safe__List(expected_type=Safe_Str__Id)
49+
values = ['42', value]
50+
an_list.extend(values)
51+
return an_list
52+
53+
an_function_3('is the answer')
54+
55+
# # Control test 2
56+
an_list_1 = Type_Safe__List(expected_type=Safe_Str__Id)
57+
values_1 = ['42', 'value']
58+
an_list_1.extend(values_1)
59+
assert type(an_list_1) is Type_Safe__List
60+
assert type(an_list_1[0]) is Safe_Str__Id
61+
assert an_list_1 == ['42', 'value']
1062

1163
def test__bug__type_safe__list_callable_with_signatures(self):
1264

@@ -63,3 +115,4 @@ def create_openapi_spec(servers: List[Dict[str, str]]):
63115

64116

65117

118+

0 commit comments

Comments
 (0)