Skip to content

Commit 9a20c4a

Browse files
authored
cleanup: keep a copy of min/max values (#411)
When finding column min/max index and value, keep a scalar copy of the min/max value to avoid redundant field accesses.
1 parent 145328a commit 9a20c4a

File tree

1 file changed

+20
-16
lines changed

1 file changed

+20
-16
lines changed

ndsl/stencils/column_operations.py

Lines changed: 20 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -17,15 +17,16 @@ def column_max(field, start_index, end_index):
1717
Returns: [max value, index of max value]
1818
"""
1919
max_index = start_index
20+
max_value = field.at(K=max_index)
2021
level = start_index
2122
while level <= end_index:
22-
new = field.at(K=level)
23-
old = field.at(K=max_index)
24-
if new > old:
23+
value = field.at(K=level)
24+
if value > max_value:
25+
max_value = value
2526
max_index = level
2627
level += 1
2728

28-
return field.at(K=max_index), max_index
29+
return max_value, max_index
2930

3031

3132
@typing.no_type_check
@@ -42,15 +43,16 @@ def column_max_ddim(field, ddim, start_index, end_index):
4243
Returns: [max value, index of max value]
4344
"""
4445
max_index = start_index
46+
max_value = field.at(K=max_index, ddim=[ddim])
4547
level = start_index
4648
while level <= end_index:
47-
new = field.at(K=level, ddim=[ddim])
48-
old = field.at(K=max_index, ddim=[ddim])
49-
if new > old:
49+
value = field.at(K=level, ddim=[ddim])
50+
if value > max_value:
51+
max_value = value
5052
max_index = level
5153
level += 1
5254

53-
return field.at(K=max_index, ddim=[ddim]), max_index
55+
return max_value, max_index
5456

5557

5658
@typing.no_type_check
@@ -67,15 +69,16 @@ def column_min(field, start_index, end_index):
6769
Returns: [min value, index of min value]
6870
"""
6971
min_index = start_index
72+
min_value = field.at(K=min_index)
7073
level = start_index
7174
while level <= end_index:
72-
new = field.at(K=level)
73-
old = field.at(K=min_index)
74-
if new < old:
75+
value = field.at(K=level)
76+
if value < min_value:
77+
min_value = value
7578
min_index = level
7679
level += 1
7780

78-
return field.at(K=min_index), min_index
81+
return min_value, min_index
7982

8083

8184
@typing.no_type_check
@@ -92,12 +95,13 @@ def column_min_ddim(field, ddim, start_index, end_index):
9295
Returns: [min value, index of min value]
9396
"""
9497
min_index = start_index
98+
min_value = field.at(K=min_index, ddim=[ddim])
9599
level = start_index
96100
while level <= end_index:
97-
new = field.at(K=level, ddim=[ddim])
98-
old = field.at(K=min_index, ddim=[ddim])
99-
if new < old:
101+
value = field.at(K=level, ddim=[ddim])
102+
if value < min_value:
103+
min_value = value
100104
min_index = level
101105
level += 1
102106

103-
return field.at(K=min_index, ddim=[ddim]), min_index
107+
return min_value, min_index

0 commit comments

Comments
 (0)