Skip to content

Commit de1c035

Browse files
committed
fix
1 parent c8fd721 commit de1c035

File tree

2 files changed

+16
-8
lines changed

2 files changed

+16
-8
lines changed

KunQuant/passes/Partitioner.py

Lines changed: 14 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -103,7 +103,7 @@ def _is_bad_op_to_add(op: OpBase, partiton: _Partition):
103103
has_bad_input = True
104104
return has_bad_input, connected_to_parti
105105

106-
def _select_next(ready_ops: List[Tuple[OpBase, int]], info: Dict[OpBase, _PartitionOpInfo], partiton: _Partition, f: Function) -> OpBase:
106+
def _select_next(ready_ops: List[Tuple[OpBase, int]], info: Dict[OpBase, _PartitionOpInfo], partiton: _Partition, f: Function, edge_ops: Dict[OpBase, bool]) -> OpBase:
107107
'''
108108
Select the next op to put into the partition.
109109
1. If there is CrossSectionalOp, always select it first
@@ -113,7 +113,8 @@ def _select_next(ready_ops: List[Tuple[OpBase, int]], info: Dict[OpBase, _Partit
113113
'''
114114
cur_best = (-1, -1, -1) # (is_loop, critical, score)
115115
cur_best_op: OpBase = None
116-
edge_ops = partiton.get_edge_ops(f)
116+
if edge_ops is None:
117+
edge_ops = partiton.get_edge_ops(f)
117118
for op, idx in ready_ops:
118119
if isinstance(op, CrossSectionalOp):
119120
return op
@@ -152,7 +153,13 @@ def _select_next(ready_ops: List[Tuple[OpBase, int]], info: Dict[OpBase, _Partit
152153
cur_best = score_tuple
153154
cur_best_op = op
154155
return cur_best_op
155-
156+
157+
def _has_critical_ops(edge_ops: Dict[OpBase, bool]) -> bool:
158+
for op, is_in_loop in edge_ops.items():
159+
if is_in_loop:
160+
return True
161+
return False
162+
156163
def _partition(f: Function, partition_thres = 3) -> List[_Partition]:
157164
opinfo = _collect_op_info(f)
158165
partitions: List[_Partition] = []
@@ -162,7 +169,7 @@ def _partition(f: Function, partition_thres = 3) -> List[_Partition]:
162169
while len(ready_ops):
163170
partition = _Partition(OrderedDict(), set())
164171
# print("============\nnew partition:", partition)
165-
selected = _select_next(ready_ops, opinfo, partition, f)
172+
selected = _select_next(ready_ops, opinfo, partition, f, None)
166173
while selected:
167174
# remove the pending dependency. If an op is ready, put into ready queue
168175
def maintain_ready_queue(s_op: OpBase):
@@ -198,7 +205,8 @@ def maintain_ready_queue(s_op: OpBase):
198205
partition.add(opinfo, inp)
199206
partition.add(opinfo, selected)
200207
# print("@@@add ", selected)
201-
if partition.num_outputs > partition_thres:
208+
next_edge_ops = partition.get_edge_ops(f)
209+
if partition.num_outputs > partition_thres and not _has_critical_ops(next_edge_ops):
202210
# if an output is directly connected with the partition, add it
203211
direct_output = None
204212
for candidate, bat in ready_ops:
@@ -213,7 +221,7 @@ def maintain_ready_queue(s_op: OpBase):
213221
continue
214222
# too many outputs visited, make a new partition
215223
break
216-
selected = _select_next(ready_ops, opinfo, partition, f)
224+
selected = _select_next(ready_ops, opinfo, partition, f, next_edge_ops)
217225
if partition.ops.__len__():
218226
partitions.append(partition)
219227
if to_visit.__len__() != 0:

cpp/Kun/Ops/Quantile.hpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -42,9 +42,9 @@ struct SkipListStateImpl {
4242
}
4343
};
4444

45-
template <typename T, int simdLen, int window>
45+
template <typename T, int simdLen, int expectedwindow>
4646
struct SkipListState : SkipListStateImpl<T, simdLen> {
47-
SkipListState() : SkipListStateImpl<T, simdLen>(window) {}
47+
SkipListState() : SkipListStateImpl<T, simdLen>(expectedwindow) {}
4848
};
4949
} // namespace
5050

0 commit comments

Comments
 (0)