@@ -103,7 +103,7 @@ def _is_bad_op_to_add(op: OpBase, partiton: _Partition):
103103 has_bad_input = True
104104 return has_bad_input , connected_to_parti
105105
106- def _select_next (ready_ops : List [Tuple [OpBase , int ]], info : Dict [OpBase , _PartitionOpInfo ], partiton : _Partition , f : Function ) -> OpBase :
106+ def _select_next (ready_ops : List [Tuple [OpBase , int ]], info : Dict [OpBase , _PartitionOpInfo ], partiton : _Partition , f : Function , edge_ops : Dict [ OpBase , bool ] ) -> OpBase :
107107 '''
108108 Select the next op to put into the partition.
109109 1. If there is CrossSectionalOp, always select it first
@@ -113,7 +113,8 @@ def _select_next(ready_ops: List[Tuple[OpBase, int]], info: Dict[OpBase, _Partit
113113 '''
114114 cur_best = (- 1 , - 1 , - 1 ) # (is_loop, critical, score)
115115 cur_best_op : OpBase = None
116- edge_ops = partiton .get_edge_ops (f )
116+ if edge_ops is None :
117+ edge_ops = partiton .get_edge_ops (f )
117118 for op , idx in ready_ops :
118119 if isinstance (op , CrossSectionalOp ):
119120 return op
@@ -152,7 +153,13 @@ def _select_next(ready_ops: List[Tuple[OpBase, int]], info: Dict[OpBase, _Partit
152153 cur_best = score_tuple
153154 cur_best_op = op
154155 return cur_best_op
155-
156+
157+ def _has_critical_ops (edge_ops : Dict [OpBase , bool ]) -> bool :
158+ for op , is_in_loop in edge_ops .items ():
159+ if is_in_loop :
160+ return True
161+ return False
162+
156163def _partition (f : Function , partition_thres = 3 ) -> List [_Partition ]:
157164 opinfo = _collect_op_info (f )
158165 partitions : List [_Partition ] = []
@@ -162,7 +169,7 @@ def _partition(f: Function, partition_thres = 3) -> List[_Partition]:
162169 while len (ready_ops ):
163170 partition = _Partition (OrderedDict (), set ())
164171 # print("============\nnew partition:", partition)
165- selected = _select_next (ready_ops , opinfo , partition , f )
172+ selected = _select_next (ready_ops , opinfo , partition , f , None )
166173 while selected :
167174 # remove the pending dependency. If an op is ready, put into ready queue
168175 def maintain_ready_queue (s_op : OpBase ):
@@ -198,7 +205,8 @@ def maintain_ready_queue(s_op: OpBase):
198205 partition .add (opinfo , inp )
199206 partition .add (opinfo , selected )
200207 # print("@@@add ", selected)
201- if partition .num_outputs > partition_thres :
208+ next_edge_ops = partition .get_edge_ops (f )
209+ if partition .num_outputs > partition_thres and not _has_critical_ops (next_edge_ops ):
202210 # if an output is directly connected with the partition, add it
203211 direct_output = None
204212 for candidate , bat in ready_ops :
@@ -213,7 +221,7 @@ def maintain_ready_queue(s_op: OpBase):
213221 continue
214222 # too many outputs visited, make a new partition
215223 break
216- selected = _select_next (ready_ops , opinfo , partition , f )
224+ selected = _select_next (ready_ops , opinfo , partition , f , next_edge_ops )
217225 if partition .ops .__len__ ():
218226 partitions .append (partition )
219227 if to_visit .__len__ () != 0 :
0 commit comments