Skip to content

Commit a4dc0c2

Browse files
authored
Merge pull request #3159 from jsiirola/config-domain
Support config domains with either method or attribute domain_name
2 parents 0b4ea7d + 41d8197 commit a4dc0c2

File tree

2 files changed

+40
-1
lines changed

2 files changed

+40
-1
lines changed

pyomo/common/config.py

+5-1
Original file line numberDiff line numberDiff line change
@@ -1134,7 +1134,11 @@ def _domain_name(domain):
11341134
if domain is None:
11351135
return ""
11361136
elif hasattr(domain, 'domain_name'):
1137-
return domain.domain_name()
1137+
dn = domain.domain_name
1138+
if hasattr(dn, '__call__'):
1139+
return dn()
1140+
else:
1141+
return dn
11381142
elif domain.__class__ is type:
11391143
return domain.__name__
11401144
elif inspect.isfunction(domain):

pyomo/common/tests/test_config.py

+35
Original file line numberDiff line numberDiff line change
@@ -3265,6 +3265,41 @@ def __init__(
32653265
OUT.getvalue().replace('null', 'None'),
32663266
)
32673267

3268+
def test_domain_name(self):
3269+
cfg = ConfigDict()
3270+
3271+
cfg.declare('none', ConfigValue())
3272+
self.assertEqual(cfg.get('none').domain_name(), '')
3273+
3274+
def fcn(val):
3275+
return val
3276+
3277+
cfg.declare('fcn', ConfigValue(domain=fcn))
3278+
self.assertEqual(cfg.get('fcn').domain_name(), 'fcn')
3279+
3280+
fcn.domain_name = 'custom fcn'
3281+
self.assertEqual(cfg.get('fcn').domain_name(), 'custom fcn')
3282+
3283+
class functor:
3284+
def __call__(self, val):
3285+
return val
3286+
3287+
cfg.declare('functor', ConfigValue(domain=functor()))
3288+
self.assertEqual(cfg.get('functor').domain_name(), 'functor')
3289+
3290+
class cfunctor:
3291+
def __call__(self, val):
3292+
return val
3293+
3294+
def domain_name(self):
3295+
return 'custom functor'
3296+
3297+
cfg.declare('cfunctor', ConfigValue(domain=cfunctor()))
3298+
self.assertEqual(cfg.get('cfunctor').domain_name(), 'custom functor')
3299+
3300+
cfg.declare('type', ConfigValue(domain=int))
3301+
self.assertEqual(cfg.get('type').domain_name(), 'int')
3302+
32683303

32693304
if __name__ == "__main__":
32703305
unittest.main()

0 commit comments

Comments
 (0)