@@ -89,16 +89,27 @@ def drivers(self):
8989 r"""iterator: All drivers."""
9090 return itertools .chain (self .models , self .connections )
9191
92+ @classmethod
93+ def _get_commlist (cls , comm ):
94+ if isinstance (comm ['commtype' ], list ):
95+ return comm ['commtype' ]
96+ elif (comm ['commtype' ] in ['server' , 'client' , 'fork' ]
97+ and isinstance (comm .get ('comm_list' , None ), list )):
98+ return comm ['comm_list' ]
99+ elif (comm ['commtype' ] in ['server' , 'client' ]
100+ and isinstance (comm .get ('request_commtype' , None ), list )):
101+ return comm ['request_commtype' ]
102+ return None
103+
92104 @classmethod
93105 def _param_model_comm (cls , comm ):
94106 commlist = None
95107 if isinstance (comm , dict ):
96108 name = comm ['name' ]
97109 direction = comm ['direction' ]
98- model = comm [ 'model' ]
110+ model = comm . get ( 'model' , None )
99111 name = strip_model_prefix (name , model )
100- if isinstance (comm ['commtype' ], list ):
101- commlist = comm ['commtype' ]
112+ commlist = cls ._get_commlist (comm )
102113 else :
103114 name = comm .opp_name
104115 direction = comm .opp_direction
@@ -107,9 +118,15 @@ def _param_model_comm(cls, comm):
107118 if comm ._commtype == 'fork' :
108119 commlist = comm .comm_list
109120 elif comm ._commtype == 'server' and direction == 'send' :
110- commlist = [comm .icomm ]
121+ if comm .icomm ._commtype == 'fork' :
122+ commlist = comm .icomm .comm_list
123+ else :
124+ commlist = [comm .icomm ]
111125 elif comm ._commtype == 'client' and direction == 'recv' :
112- commlist = [comm .ocomm ]
126+ if comm .ocomm ._commtype == 'fork' :
127+ commlist = comm .ocomm .comm_list
128+ else :
129+ commlist = [comm .ocomm ]
113130 return (model , direction , name , commlist )
114131
115132 @classmethod
@@ -126,7 +143,8 @@ def _check_model_comm(cls, comm, model, direction, name):
126143 pass
127144 raise CommBase .CommError ("no match" )
128145
129- def find_model_comm (self , model , direction , name ):
146+ def find_model_comm (self , model , direction , name ,
147+ is_split_server = False ):
130148 r"""Locate a matching communicator from the registered
131149 connections.
132150
@@ -135,6 +153,8 @@ def find_model_comm(self, model, direction, name):
135153 direction (str): Direction that the model communicator
136154 operates in.
137155 name (str): Channel name for the communicator.
156+ is_split_server (bool, optional): If True, the comm is part
157+ of a split server comm.
138158
139159 Returns:
140160 dict, CommBase: Partner communicator or parameters for the
@@ -155,8 +175,13 @@ def find_model_comm(self, model, direction, name):
155175 comm , model , direction , name )
156176 except CommBase .CommError :
157177 pass
178+ if is_split_server :
179+ try :
180+ return self .find_model_comm (model , 'recv' , model )
181+ except BrokerError :
182+ pass
158183 raise BrokerError (f"Could not locate a { direction } communicator "
159- f"with name \" { name } \" for model \" { model } \" : "
184+ f"with name \" { name } \" for model \" { model } \" :\n "
160185 f"{ pprint .pformat (self .comms )} " )
161186
162187 @classmethod
@@ -512,6 +537,24 @@ def _send_request(cls, action, *args, **kwargs):
512537 # client.info(f"RESPONSE [{action}]: {response['return']}")
513538 return response ['return' ]
514539
540+ @classmethod
541+ def is_address_set (cls , comm ):
542+ r"""Check if an address is defined.
543+
544+ Args:
545+ comm (dict): Communication definition.
546+
547+ Returns:
548+ bool: True if the address is defined, False otherwise.
549+
550+ """
551+ if 'address' in comm :
552+ return True
553+ commlist = cls ._get_commlist (comm )
554+ if commlist is None :
555+ return False
556+ return all ('address' in x for x in commlist )
557+
515558 @classmethod
516559 def update_model_comm_kwargs (cls , name , kwargs , self = None ):
517560 r"""Update the communicator parameters for the partner model
@@ -525,15 +568,31 @@ def update_model_comm_kwargs(cls, name, kwargs, self=None):
525568 if self is None :
526569 return cls ._send_request ('update_model_comm_kwargs' ,
527570 name , kwargs )
571+ partner_model = kwargs .pop ('partner_model' , None )
528572 model = kwargs .pop ('model' )
529573 direction = kwargs .pop ('direction' )
530574 comm = self .find_model_comm (model , direction , name )
531575 assert isinstance (comm , dict )
576+ comm0 = comm
577+ commlist = self ._get_commlist (comm )
578+ if commlist is not None :
579+ assert partner_model is not None
580+ partners = [x ['partner_model' ] for x in commlist ]
581+ assert sum (x == partner_model for x in partners ) == 1
582+ idx = partners .index (partner_model )
583+ comm = commlist [idx ]
584+ if kwargs ['commtype' ] in ['server' , 'client' ]:
585+ assert comm0 ['commtype' ] == kwargs ['commtype' ]
586+ kwargs .pop ('commtype' )
587+ if 'request_commtype' in kwargs :
588+ kwargs ['commtype' ] = kwargs .pop ('request_commtype' )
532589 comm .update (kwargs )
533590 assert comm ['direction' ] == direction
534- if name in self . _awaiting_comm :
591+ if commlist is None :
535592 assert 'address' in comm
536- # self.server_comm.info(f"UNLOCKING: {name}")
593+ if name in self ._awaiting_comm and self .is_address_set (comm0 ):
594+ # self.server_comm.info(f"UNLOCKING: {name} "
595+ # f"{self._awaiting_comm}")
537596 self ._awaiting_comm .remove (name )
538597
539598 @classmethod
@@ -569,9 +628,9 @@ def model_comm_kwargs(cls, name, direction, self=None):
569628 if isinstance (model_driver .is_server , dict ):
570629 if ((name == model_driver .is_server ['input' ]
571630 or name == model_driver .is_server ['output' ])):
572- # TODO: Verify that this works
573631 is_split_server = True
574- comm = self .find_model_comm (model , direction , name )
632+ comm = self .find_model_comm (model , direction , name ,
633+ is_split_server )
575634 if isinstance (comm , dict ):
576635 out = copy .deepcopy (comm )
577636 if (('address' not in comm
@@ -591,8 +650,11 @@ def model_comm_kwargs(cls, name, direction, self=None):
591650 )
592651 if is_split_server or out ['commtype' ] == 'model_function' :
593652 out ['global_scope' ] = model
594- if 'address' not in out and 'partner_name' in out :
595- # self.server_comm.info(f"LOCKING {out['partner_name']}")
653+ if (('partner_name' in out
654+ and out ['partner_name' ] not in self ._awaiting_comm
655+ and (not self .is_address_set (out )))):
656+ # self.server_comm.info(f"LOCKING: {out['partner_name']} "
657+ # f"{self._awaiting_comm}")
596658 self ._awaiting_comm .append (out ['partner_name' ])
597659 return out
598660
@@ -611,7 +673,8 @@ def model_comm(cls, name, direction, **kwargs):
611673
612674 """
613675 kwargs = dict (cls .model_comm_kwargs (name , direction ), ** kwargs )
614- assert direction == kwargs ['direction' ]
676+ if kwargs .get ('commtype' , None ) not in ['client' , 'server' ]:
677+ assert direction == kwargs ['direction' ]
615678 name = kwargs ['name' ]
616679 global_scope = kwargs .pop ('global_scope' , False )
617680 global_name = name
@@ -622,13 +685,16 @@ def model_comm(cls, name, direction, **kwargs):
622685 # for case where server or function comm is split between
623686 # two aliases
624687 return cls ._global_scope_comms [global_name ]
625- if 'address' in kwargs :
688+ if cls . is_address_set ( kwargs ) :
626689 out = get_comm (** kwargs )
627690 else :
628691 partner_name = kwargs ['partner_name' ]
629692 out = new_comm (** kwargs )
630693 cls .update_model_comm_kwargs (
631- partner_name , out .model_comm_kwargs )
694+ partner_name ,
695+ dict (out .model_comm_kwargs ,
696+ partner_model = kwargs ['model' ]),
697+ )
632698 if global_scope :
633699 cls ._global_scope_comms [global_name ] = out
634700 return out
0 commit comments