Skip to content

Commit 5a15333

Browse files
committed
Added support for accessing timesync channels from models imported as functions
Added support for returning all variables to models via a timesync conneciton if True passed in the timesync parameter 'additional_variables' for a model Fix bugs in model queue handling which dropped some buffered messages from the model stdout Fix bug in use of server name in bytes vs str
1 parent 289801a commit 5a15333

File tree

14 files changed

+423
-136
lines changed

14 files changed

+423
-136
lines changed

HISTORY.rst

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,9 @@ History
1313
* Migrate from setup.py + setup.cfg to pyproject.toml
1414
* Migrate build system to scikit-build-core
1515
* Introduce disassembler compilation tool class for dissecting compiled binaries
16+
* Added support for accessing timesync channels from models imported as functions
17+
* Added support for returning all variables to models via a timesync conneciton if True passed in the timesync parameter 'additional_variables' for a model
18+
* Fix bugs in model queue handling which dropped some buffered messages from the model stdout
1619

1720
1.10.2 (2023-10-12) Minor bug fixes and dependency updates
1821
-------------------

tests/test_runner.py

Lines changed: 85 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -91,23 +91,26 @@ def test_import_as_function():
9191
r"""Test import_as_function."""
9292
yamlfile = ex_yamls['fakeplant']['python']
9393
fmodel = import_as_function(yamlfile, remove_products=True)
94-
input_args = {}
95-
for x in fmodel.arguments:
96-
input_args[x] = 1.0
97-
fmodel.model_info()
98-
result = fmodel(**input_args)
99-
for x in fmodel.returns:
100-
assert x in result
101-
result = fmodel(*list(input_args.values()))
102-
for x in fmodel.returns:
103-
assert x in result
104-
fmodel.reload()
105-
fmodel.run()
106-
result = fmodel(**input_args)
107-
for x in fmodel.returns:
108-
assert x in result
109-
fmodel.stop()
110-
fmodel.stop()
94+
try:
95+
input_args = {}
96+
for x in fmodel.arguments:
97+
input_args[x] = 1.0
98+
fmodel.function_info
99+
assert len(fmodel.returns) == 1
100+
result = fmodel(**input_args)
101+
for x in fmodel.returns:
102+
assert x in result
103+
result = fmodel(*list(input_args.values()))
104+
for x in fmodel.returns:
105+
assert x in result
106+
fmodel.reload()
107+
fmodel.run()
108+
result = fmodel(**input_args)
109+
for x in fmodel.returns:
110+
assert x in result
111+
fmodel.stop()
112+
finally:
113+
fmodel.stop()
111114

112115

113116
def test_import_as_function_server():
@@ -119,18 +122,67 @@ def test_import_as_function_server():
119122
break
120123
assert yamlfile
121124
fmodel = import_as_function(yamlfile, remove_products=True)
122-
input_args = {}
123-
for x in fmodel.arguments:
124-
input_args[x] = 'hello'
125-
fmodel.model_info()
126-
result = fmodel(**input_args)
127-
for x in fmodel.returns:
128-
assert x in result
129-
result = fmodel(*list(input_args.values()))
130-
for x in fmodel.returns:
131-
assert x in result
132-
fmodel.stop()
133-
fmodel.stop()
125+
try:
126+
input_args = {}
127+
for x in fmodel.arguments:
128+
input_args[x] = 'hello'
129+
fmodel.function_info
130+
assert len(fmodel.returns) == 1
131+
result = fmodel(**input_args)
132+
for x in fmodel.returns:
133+
assert x in result
134+
result = fmodel(*list(input_args.values()))
135+
for x in fmodel.returns:
136+
assert x in result
137+
fmodel.stop()
138+
finally:
139+
fmodel.stop()
140+
141+
142+
def test_import_as_function_timesync():
143+
r"""Test import_as_function with timesync."""
144+
from yggdrasil import units
145+
contents = r"""models:
146+
- name: modelA
147+
language: python
148+
args:
149+
- ./src/timesync.py
150+
- 2
151+
- day
152+
timesync: True
153+
outputs:
154+
name: output
155+
default_file:
156+
name: modelA_output.txt
157+
in_temp: True
158+
filetype: table"""
159+
yamlfile = os.path.join(os.path.dirname(ex_yamls['timesync1']['python']),
160+
'test_import_timesync.yml')
161+
assert not os.path.isfile(yamlfile)
162+
with open(yamlfile, 'w') as fd:
163+
fd.write(contents)
164+
fmodel = None
165+
try:
166+
# TODO: Test where x/y modified in call
167+
steps = [
168+
({'timesync': (units.add_units(0.0, 'day'), {})},
169+
{'x': 0.0, 'y': 1.0}),
170+
({'timesync': (units.add_units(1.0, 'day'), {})},
171+
{'x': 0.47552825814757677, 'y': 0.09549150281252626}),
172+
]
173+
fmodel = import_as_function(yamlfile, remove_products=True,
174+
partial_timesync=True,
175+
partial_comms=['timesync'])
176+
fmodel.function_info
177+
input_args = {}
178+
for input_args, output_args in steps:
179+
result = fmodel(**input_args)
180+
assert result == output_args
181+
fmodel.stop()
182+
finally:
183+
if fmodel is not None:
184+
fmodel.stop()
185+
os.remove(yamlfile)
134186

135187

136188
@pytest.mark.language('c')
@@ -148,9 +200,10 @@ def test_import_as_function_C():
148200
assert not os.path.isfile(yamlfile)
149201
with open(yamlfile, 'w') as fd:
150202
fd.write(contents)
203+
fmodel = None
151204
try:
152205
fmodel = import_as_function(yamlfile, remove_products=True)
153-
fmodel.model_info()
206+
fmodel.function_info
154207
input_args = {}
155208
for x in fmodel.arguments:
156209
input_args[x] = b'hello'
@@ -161,6 +214,7 @@ def test_import_as_function_C():
161214
for x in fmodel.returns:
162215
assert x in result
163216
fmodel.stop()
164-
fmodel.stop()
165217
finally:
218+
if fmodel is not None:
219+
fmodel.stop()
166220
os.remove(yamlfile)

yggdrasil/.ygg_schema.yml

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3474,7 +3474,9 @@ definitions:
34743474
additionalProperties:
34753475
items:
34763476
type: string
3477-
type: array
3477+
type:
3478+
- boolean
3479+
- array
34783480
default: {}
34793481
type: object
34803482
aggregation:
@@ -5453,7 +5455,9 @@ definitions:
54535455
additionalProperties:
54545456
items:
54555457
type: string
5456-
type: array
5458+
type:
5459+
- boolean
5460+
- array
54575461
default: {}
54585462
type: object
54595463
aggregation:

yggdrasil/broker.py

Lines changed: 82 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -89,16 +89,27 @@ def drivers(self):
8989
r"""iterator: All drivers."""
9090
return itertools.chain(self.models, self.connections)
9191

92+
@classmethod
93+
def _get_commlist(cls, comm):
94+
if isinstance(comm['commtype'], list):
95+
return comm['commtype']
96+
elif (comm['commtype'] in ['server', 'client', 'fork']
97+
and isinstance(comm.get('comm_list', None), list)):
98+
return comm['comm_list']
99+
elif (comm['commtype'] in ['server', 'client']
100+
and isinstance(comm.get('request_commtype', None), list)):
101+
return comm['request_commtype']
102+
return None
103+
92104
@classmethod
93105
def _param_model_comm(cls, comm):
94106
commlist = None
95107
if isinstance(comm, dict):
96108
name = comm['name']
97109
direction = comm['direction']
98-
model = comm['model']
110+
model = comm.get('model', None)
99111
name = strip_model_prefix(name, model)
100-
if isinstance(comm['commtype'], list):
101-
commlist = comm['commtype']
112+
commlist = cls._get_commlist(comm)
102113
else:
103114
name = comm.opp_name
104115
direction = comm.opp_direction
@@ -107,9 +118,15 @@ def _param_model_comm(cls, comm):
107118
if comm._commtype == 'fork':
108119
commlist = comm.comm_list
109120
elif comm._commtype == 'server' and direction == 'send':
110-
commlist = [comm.icomm]
121+
if comm.icomm._commtype == 'fork':
122+
commlist = comm.icomm.comm_list
123+
else:
124+
commlist = [comm.icomm]
111125
elif comm._commtype == 'client' and direction == 'recv':
112-
commlist = [comm.ocomm]
126+
if comm.ocomm._commtype == 'fork':
127+
commlist = comm.ocomm.comm_list
128+
else:
129+
commlist = [comm.ocomm]
113130
return (model, direction, name, commlist)
114131

115132
@classmethod
@@ -126,7 +143,8 @@ def _check_model_comm(cls, comm, model, direction, name):
126143
pass
127144
raise CommBase.CommError("no match")
128145

129-
def find_model_comm(self, model, direction, name):
146+
def find_model_comm(self, model, direction, name,
147+
is_split_server=False):
130148
r"""Locate a matching communicator from the registered
131149
connections.
132150
@@ -135,6 +153,8 @@ def find_model_comm(self, model, direction, name):
135153
direction (str): Direction that the model communicator
136154
operates in.
137155
name (str): Channel name for the communicator.
156+
is_split_server (bool, optional): If True, the comm is part
157+
of a split server comm.
138158
139159
Returns:
140160
dict, CommBase: Partner communicator or parameters for the
@@ -155,8 +175,13 @@ def find_model_comm(self, model, direction, name):
155175
comm, model, direction, name)
156176
except CommBase.CommError:
157177
pass
178+
if is_split_server:
179+
try:
180+
return self.find_model_comm(model, 'recv', model)
181+
except BrokerError:
182+
pass
158183
raise BrokerError(f"Could not locate a {direction} communicator "
159-
f"with name \"{name}\" for model \"{model}\": "
184+
f"with name \"{name}\" for model \"{model}\":\n"
160185
f"{pprint.pformat(self.comms)}")
161186

162187
@classmethod
@@ -512,6 +537,24 @@ def _send_request(cls, action, *args, **kwargs):
512537
# client.info(f"RESPONSE [{action}]: {response['return']}")
513538
return response['return']
514539

540+
@classmethod
541+
def is_address_set(cls, comm):
542+
r"""Check if an address is defined.
543+
544+
Args:
545+
comm (dict): Communication definition.
546+
547+
Returns:
548+
bool: True if the address is defined, False otherwise.
549+
550+
"""
551+
if 'address' in comm:
552+
return True
553+
commlist = cls._get_commlist(comm)
554+
if commlist is None:
555+
return False
556+
return all('address' in x for x in commlist)
557+
515558
@classmethod
516559
def update_model_comm_kwargs(cls, name, kwargs, self=None):
517560
r"""Update the communicator parameters for the partner model
@@ -525,15 +568,31 @@ def update_model_comm_kwargs(cls, name, kwargs, self=None):
525568
if self is None:
526569
return cls._send_request('update_model_comm_kwargs',
527570
name, kwargs)
571+
partner_model = kwargs.pop('partner_model', None)
528572
model = kwargs.pop('model')
529573
direction = kwargs.pop('direction')
530574
comm = self.find_model_comm(model, direction, name)
531575
assert isinstance(comm, dict)
576+
comm0 = comm
577+
commlist = self._get_commlist(comm)
578+
if commlist is not None:
579+
assert partner_model is not None
580+
partners = [x['partner_model'] for x in commlist]
581+
assert sum(x == partner_model for x in partners) == 1
582+
idx = partners.index(partner_model)
583+
comm = commlist[idx]
584+
if kwargs['commtype'] in ['server', 'client']:
585+
assert comm0['commtype'] == kwargs['commtype']
586+
kwargs.pop('commtype')
587+
if 'request_commtype' in kwargs:
588+
kwargs['commtype'] = kwargs.pop('request_commtype')
532589
comm.update(kwargs)
533590
assert comm['direction'] == direction
534-
if name in self._awaiting_comm:
591+
if commlist is None:
535592
assert 'address' in comm
536-
# self.server_comm.info(f"UNLOCKING: {name}")
593+
if name in self._awaiting_comm and self.is_address_set(comm0):
594+
# self.server_comm.info(f"UNLOCKING: {name} "
595+
# f"{self._awaiting_comm}")
537596
self._awaiting_comm.remove(name)
538597

539598
@classmethod
@@ -569,9 +628,9 @@ def model_comm_kwargs(cls, name, direction, self=None):
569628
if isinstance(model_driver.is_server, dict):
570629
if ((name == model_driver.is_server['input']
571630
or name == model_driver.is_server['output'])):
572-
# TODO: Verify that this works
573631
is_split_server = True
574-
comm = self.find_model_comm(model, direction, name)
632+
comm = self.find_model_comm(model, direction, name,
633+
is_split_server)
575634
if isinstance(comm, dict):
576635
out = copy.deepcopy(comm)
577636
if (('address' not in comm
@@ -591,8 +650,11 @@ def model_comm_kwargs(cls, name, direction, self=None):
591650
)
592651
if is_split_server or out['commtype'] == 'model_function':
593652
out['global_scope'] = model
594-
if 'address' not in out and 'partner_name' in out:
595-
# self.server_comm.info(f"LOCKING {out['partner_name']}")
653+
if (('partner_name' in out
654+
and out['partner_name'] not in self._awaiting_comm
655+
and (not self.is_address_set(out)))):
656+
# self.server_comm.info(f"LOCKING: {out['partner_name']} "
657+
# f"{self._awaiting_comm}")
596658
self._awaiting_comm.append(out['partner_name'])
597659
return out
598660

@@ -611,7 +673,8 @@ def model_comm(cls, name, direction, **kwargs):
611673
612674
"""
613675
kwargs = dict(cls.model_comm_kwargs(name, direction), **kwargs)
614-
assert direction == kwargs['direction']
676+
if kwargs.get('commtype', None) not in ['client', 'server']:
677+
assert direction == kwargs['direction']
615678
name = kwargs['name']
616679
global_scope = kwargs.pop('global_scope', False)
617680
global_name = name
@@ -622,13 +685,16 @@ def model_comm(cls, name, direction, **kwargs):
622685
# for case where server or function comm is split between
623686
# two aliases
624687
return cls._global_scope_comms[global_name]
625-
if 'address' in kwargs:
688+
if cls.is_address_set(kwargs):
626689
out = get_comm(**kwargs)
627690
else:
628691
partner_name = kwargs['partner_name']
629692
out = new_comm(**kwargs)
630693
cls.update_model_comm_kwargs(
631-
partner_name, out.model_comm_kwargs)
694+
partner_name,
695+
dict(out.model_comm_kwargs,
696+
partner_model=kwargs['model']),
697+
)
632698
if global_scope:
633699
cls._global_scope_comms[global_name] = out
634700
return out

0 commit comments

Comments
 (0)