Skip to content

Commit 290993e

Browse files
committed
Add Optional trait type as shorthand for Union(None, ...)
1 parent ba745ce commit 290993e

File tree

6 files changed

+356
-6
lines changed

6 files changed

+356
-6
lines changed

traits/api.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -103,6 +103,7 @@
103103
ToolbarButton,
104104
Either,
105105
Union,
106+
Optional,
106107
Type,
107108
Subclass,
108109
Symbol,

traits/tests/test_constant.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -72,3 +72,19 @@ class TestClass(HasTraits):
7272

7373
# Check directly that both refer to the same object.
7474
self.assertIs(obj1.c_atr, obj2.c_atr)
75+
76+
@unittest.expectedFailure
77+
def test_constant_validator(self):
78+
"""
79+
XFAIL: `validate` on constant does not reject new values.
80+
81+
See enthought/traits#1784
82+
"""
83+
class TestClass(HasTraits):
84+
attribute = Constant(123)
85+
86+
a = TestClass()
87+
const_trait = a.traits()["attribute"]
88+
89+
with self.assertRaises(TraitError):
90+
const_trait.validate(a, "attribute", 456)

traits/tests/test_optional.py

Lines changed: 265 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,265 @@
1+
# (C) Copyright 2005-2024 Enthought, Inc., Austin, TX
2+
# All rights reserved.
3+
#
4+
# This software is provided without warranty under the terms of the BSD
5+
# license included in LICENSE.txt and may be redistributed only under
6+
# the conditions described in the aforementioned license. The license
7+
# is also available online at http://www.enthought.com/licenses/BSD.txt
8+
#
9+
# Thanks for using Enthought open source!
10+
11+
import unittest
12+
13+
from traits.api import (
14+
Bytes,
15+
DefaultValue,
16+
Float,
17+
HasTraits,
18+
Instance,
19+
Int,
20+
List,
21+
Str,
22+
TraitError,
23+
TraitType,
24+
Type,
25+
Optional,
26+
Constant,
27+
)
28+
from traits.trait_types import _NoneTrait
29+
30+
31+
class CustomClass(HasTraits):
32+
value = Int
33+
34+
35+
class CustomStrType(TraitType):
36+
37+
#: The default value type to use.
38+
default_value_type = DefaultValue.constant
39+
40+
#: The default value.
41+
default_value = "a string value"
42+
43+
def validate(self, obj, name, value):
44+
if not isinstance(value, Str):
45+
return value
46+
self.error(obj, name, value)
47+
48+
49+
class TestOptional(unittest.TestCase):
50+
51+
def test_optional_basic(self):
52+
class TestClass(HasTraits):
53+
attribute = Optional(Int)
54+
55+
TestClass(attribute=None)
56+
TestClass(attribute=3)
57+
58+
self.assertIsNone(TestClass(attribute=None).attribute)
59+
self.assertEqual(TestClass(attribute=3).attribute, 3)
60+
61+
with self.assertRaises(TraitError):
62+
TestClass(attribute="3")
63+
64+
def test_optional_list(self):
65+
class TestClass(HasTraits):
66+
attribute = Optional(List(Int))
67+
68+
TestClass(attribute=None)
69+
TestClass(attribute=[1, 2, 3])
70+
71+
with self.assertRaises(TraitError):
72+
TestClass(attribute=3)
73+
74+
def test_optional_instance(self):
75+
class TestClass(HasTraits):
76+
attribute = Optional(Instance(Int))
77+
78+
TestClass(attribute=None)
79+
TestClass(attribute=Int(3))
80+
81+
with self.assertRaises(TraitError):
82+
TestClass(attribute=3)
83+
with self.assertRaises(TraitError):
84+
TestClass(attribute=Int)
85+
86+
def test_optional_instance_custom_class(self):
87+
class TestClass(HasTraits):
88+
attribute = Optional(Instance(CustomClass))
89+
90+
TestClass(attribute=None)
91+
TestClass(attribute=CustomClass(value=5))
92+
93+
with self.assertRaises(TraitError):
94+
TestClass(attribute=5)
95+
96+
with self.assertRaises(TraitError):
97+
TestClass(attribute=CustomClass)
98+
99+
self.assertEqual(
100+
TestClass(attribute=CustomClass(value=5)).attribute.value, 5
101+
)
102+
103+
self.assertIsNone(TestClass().attribute)
104+
self.assertIsNone(TestClass(attribute=None).attribute)
105+
106+
def test_optional_user_defined_type(self):
107+
class TestClass(HasTraits):
108+
attribute = Optional(CustomStrType)
109+
110+
a = TestClass(attribute="my value")
111+
self.assertEqual(a.attribute, "my value")
112+
113+
b = TestClass()
114+
self.assertIsNone(b.attribute)
115+
116+
c = TestClass(attribute=3)
117+
self.assertEqual(c.attribute, 3)
118+
119+
def test_optional_instance_with_implicit_default_value(self):
120+
"""
121+
Implicit default is always ``None``
122+
"""
123+
124+
class TestClass(HasTraits):
125+
attribute = Optional(Int)
126+
127+
self.assertIsNone(TestClass().attribute)
128+
self.assertEqual(TestClass(attribute=3).attribute, 3)
129+
self.assertIsNone(TestClass(attribute=None).attribute)
130+
131+
def test_optional_instance_with_metadata_default_value(self):
132+
"""
133+
Setting the ``default_value`` at trait-level sets the default value
134+
"""
135+
136+
class TestClass(HasTraits):
137+
attribute = Optional(Int, default_value=5)
138+
139+
self.assertEqual(TestClass().attribute, 5)
140+
self.assertEqual(TestClass(attribute=3).attribute, 3)
141+
self.assertIsNone(TestClass(attribute=None).attribute)
142+
143+
def test_optional_instance_with_type_default_value(self):
144+
"""
145+
Setting the ``default_value`` of the inner trait does not set the
146+
default value of the ``Optional``
147+
"""
148+
# Note: may want to warn in this case
149+
# Discussion ref: enthought/traits#1298
150+
151+
class TestClass(HasTraits):
152+
attribute = Optional(Int(5))
153+
154+
self.assertIsNone(TestClass().attribute)
155+
self.assertEqual(TestClass(attribute=3).attribute, 3)
156+
self.assertIsNone(TestClass(attribute=None).attribute)
157+
158+
def test_optional_invalid_trait(self):
159+
with self.assertRaises(ValueError) as e:
160+
161+
class _TestClass(HasTraits):
162+
attribute = Optional(123)
163+
164+
self.assertEqual(
165+
str(e.exception),
166+
"Optional trait declaration expects a trait type or an instance "
167+
"of trait type or None, but got 123 instead",
168+
)
169+
170+
def test_optional_of_none(self):
171+
with self.assertRaises(TraitError) as e:
172+
173+
class _TestClass(HasTraits):
174+
attribute = Optional(None)
175+
176+
self.assertEqual(str(e.exception), "Optional type must not be None.")
177+
178+
def test_optional_unspecified_arguments(self):
179+
with self.assertRaises(TypeError) as e:
180+
181+
class TestClass(HasTraits):
182+
none = Optional()
183+
184+
self.assertIn(
185+
"missing 1 required positional argument", str(e.exception)
186+
)
187+
188+
def test_optional_multiple_type_arguments(self):
189+
with self.assertRaises(TypeError):
190+
191+
class TestClass(HasTraits):
192+
attribute = Optional(Int, Float)
193+
194+
def test_optional_default_raise_error(self):
195+
"""
196+
Behaviour inherited from ``Union``
197+
"""
198+
with self.assertRaises(ValueError) as e:
199+
200+
class TestClass(HasTraits):
201+
attribute = Optional(Int(), default=5)
202+
203+
self.assertEqual(
204+
str(e.exception),
205+
"Optional default value should be set via 'default_value', not "
206+
"'default'.",
207+
)
208+
209+
def test_optional_inner_traits(self):
210+
class TestClass(HasTraits):
211+
attribute = Optional(Int(3))
212+
213+
obj = TestClass()
214+
t1, t2 = obj.trait("attribute").inner_traits
215+
216+
self.assertEqual(type(t1.trait_type), _NoneTrait)
217+
self.assertEqual(t1.default_value(), (DefaultValue.constant, None))
218+
self.assertEqual(type(t2.trait_type), Int)
219+
self.assertEqual(t2.default_value(), (DefaultValue.constant, 3))
220+
221+
def test_optional_notification(self):
222+
class TestClass(HasTraits):
223+
attribute = Optional(Int)
224+
shadow_attribute = None
225+
226+
def _attribute_changed(self, new):
227+
self.shadow_attribute = new
228+
229+
obj = TestClass(attribute=3)
230+
231+
obj.attribute = 5
232+
self.assertEqual(obj.shadow_attribute, 5)
233+
234+
obj.attribute = None
235+
self.assertIsNone(obj.shadow_attribute)
236+
237+
def test_optional_extend_trait(self):
238+
class OptionalOrStr(Optional):
239+
def validate(self, obj, name, value):
240+
if isinstance(value, str):
241+
return value
242+
return super().validate(obj, name, value)
243+
244+
class TestClass(HasTraits):
245+
attribute = OptionalOrStr(Int)
246+
247+
self.assertEqual(TestClass(attribute=123).attribute, 123)
248+
self.assertEqual(TestClass(attribute="abc").attribute, "abc")
249+
self.assertIsNone(TestClass(attribute=None).attribute)
250+
self.assertIsNone(TestClass().attribute)
251+
252+
with self.assertRaises(TraitError):
253+
TestClass(attribute=1.23)
254+
255+
@unittest.expectedFailure # See enthought/traits#1784
256+
def test_optional_constant(self):
257+
class TestClass(HasTraits):
258+
attribute = Optional(Constant(123))
259+
260+
self.assertEqual(TestClass(attribute=123).attribute, 123)
261+
self.assertIsNone(TestClass(attribute=None).attribute)
262+
263+
# Fails here - internal trait validation fails
264+
with self.assertRaises(TraitError):
265+
TestClass(attribute=124)

traits/tests/test_union.py

Lines changed: 27 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212

1313
from traits.api import (
1414
Bytes, DefaultValue, Float, HasTraits, Instance, Int, List, Str,
15-
TraitError, TraitType, Type, Union)
15+
TraitError, TraitType, Type, Union, Constant)
1616

1717

1818
class CustomClass(HasTraits):
@@ -133,6 +133,20 @@ def test_default_raise_error(self):
133133
"'default'."
134134
)
135135

136+
def test_default_raise_error_subclass(self):
137+
# Name used in error message inherited by subclass
138+
class TestUnion(Union):
139+
pass
140+
141+
with self.assertRaises(ValueError) as exception_context:
142+
TestUnion(Int(), Float(), default=1.0)
143+
144+
self.assertEqual(
145+
str(exception_context.exception),
146+
"TestUnion default value should be set via 'default_value', not "
147+
"'default'."
148+
)
149+
136150
def test_inner_traits(self):
137151
class TestClass(HasTraits):
138152
atr = Union(Float, Int, Str)
@@ -214,3 +228,15 @@ class HasUnionWithList(HasTraits):
214228
has_union.trait("nested").default_value(),
215229
(DefaultValue.constant, ""),
216230
)
231+
232+
@unittest.expectedFailure # See enthought/traits#1784
233+
def test_union_constant(self):
234+
class TestClass(HasTraits):
235+
attribute = Union(None, Constant(123))
236+
237+
self.assertEqual(TestClass(attribute=123).attribute, 123)
238+
self.assertIsNone(TestClass(attribute=None).attribute)
239+
240+
# Fails here - internal trait validation fails
241+
with self.assertRaises(TraitError):
242+
TestClass(attribute=124)

0 commit comments

Comments
 (0)