Skip to content

Commit 04a9054

Browse files
committed
Update comments.
1 parent b56317a commit 04a9054

3 files changed

Lines changed: 23 additions & 23 deletions

File tree

iree/turbine/kernel/boo/driver/driver.py

Lines changed: 12 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@
3333
STDDEV_CHECK_ATOL_DEFAULT,
3434
STDDEV_CHECK_RTOL_DEFAULT,
3535
)
36-
from iree.turbine.kernel.boo.driver.utils import get_timing_parser
36+
from iree.turbine.kernel.boo.driver.utils import get_timing_parser, resolve_timing_args
3737
from iree.turbine.runtime.device import get_device_from_torch
3838

3939
ZoneData = dict[str, list[float]]
@@ -52,23 +52,23 @@ class ZoneStats(NamedTuple):
5252
ZoneStatsSummary = dict[str, ZoneStats]
5353

5454

55-
def compute_auto_iters(warmup_time: float, min_time: float, iter_fallback: int) -> int:
55+
def compute_auto_iters(warmup_time: float, min_time: float, min_iter: int) -> int:
5656
"""Compute the number of iterations needed to run for at least `min_time` seconds.
5757
58-
When min_time is active (> 0), its computed iteration count takes priority
59-
over --iter. The iter_fallback is only used when min_time is disabled.
58+
When min_time is active (> 0), its computed iteration count takes priority.
59+
The min_iter value is only used when min_time is disabled (i.e. via --iter).
6060
6161
Args:
6262
warmup_time: Time in seconds for a single warmup iteration.
6363
min_time: Minimum benchmark duration in seconds.
64-
iter_fallback: Fallback number of iterations when min_time is disabled (from --iter).
64+
min_iter: Number of iterations when min_time is disabled (from --min-iter).
6565
6666
Returns:
6767
The iteration count to use.
6868
"""
6969
if warmup_time > 0 and min_time > 0:
7070
return math.ceil(min_time / warmup_time)
71-
return iter_fallback
71+
return min_iter
7272

7373

7474
def _get_main_driver_parser() -> argparse.ArgumentParser:
@@ -267,6 +267,7 @@ def main(args: list[str] = sys.argv[1:]) -> int:
267267
else:
268268
print("Running test :", test_count)
269269
timing_args, runner_args = timing_parser.parse_known_args(driver_args)
270+
resolve_timing_args(timing_args)
270271
csv_row.append(shlex.join(driver_args))
271272
signature = BooOpRegistry.parse_command(runner_args)
272273

@@ -550,15 +551,16 @@ def pause_and_collect_mem():
550551
# Auto-adjust iteration count: ensure benchmark runs for at least min_time seconds.
551552
if timing_args.time:
552553
actual_iters = compute_auto_iters(
553-
warmup_time, timing_args.min_time, timing_args.iter
554+
warmup_time, timing_args.min_time, timing_args.min_iter
554555
)
555556
else:
556-
actual_iters = timing_args.iter
557+
actual_iters = timing_args.min_iter
557558

558-
if verbose and actual_iters != timing_args.iter:
559+
if verbose and actual_iters != timing_args.min_iter:
559560
print(
560561
f">>>\tAuto-adjusted iterations: {actual_iters} "
561-
f"(warmup: {warmup_time:.4f}s, target: {timing_args.min_time:.1f}s, floor: {timing_args.iter})"
562+
f"(warmup: {warmup_time:.4f}s, target: {timing_args.min_time:.1f}s, "
563+
f"min-iter: {timing_args.min_iter})"
562564
)
563565

564566
output_num_bytes = sum(x.element_size() * x.numel() for x in example_results)

tests/kernel/boo/driver/csv_output_test.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ def test_roundtrip_csv_single_command():
3838
csv_file = Path(td) / "conv_stats.csv"
3939
iters = 4
4040
meta_args = [f"--csv={csv_file}"]
41-
command_args = ["convbfp16", "-F=1", f"--iter={iters}", "--min-time=0"]
41+
command_args = ["convbfp16", "-F=1", f"--iter={iters}"]
4242
args = meta_args + command_args
4343
# Check we don't encounter an error.
4444
assert driver.main(args) == 0
@@ -58,7 +58,7 @@ def test_roundtrip_csv_commands_file():
5858
with tempfile.TemporaryDirectory() as td:
5959
commands_file = Path(td) / "commands.txt"
6060
commands = [
61-
["convbfp16", "-F", "1", "--iter", "4", "--min-time", "0"],
61+
["convbfp16", "-F", "1", "--iter", "4"],
6262
[
6363
"convbfp16",
6464
"-F",
@@ -71,8 +71,6 @@ def test_roundtrip_csv_commands_file():
7171
"NHWC",
7272
"--iter",
7373
"2",
74-
"--min-time",
75-
"0",
7674
],
7775
]
7876
commands_file.write_text("\n".join([shlex.join(c) for c in commands]))

tests/kernel/boo/driver/profiler_schedule_test.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -153,36 +153,36 @@ def dummy_kernel():
153153
class TestComputeAutoIters:
154154
def test_short_kernel_gets_more_iters(self):
155155
"""A 1ms kernel should need 3000 iters for 3s target."""
156-
result = compute_auto_iters(warmup_time=0.001, min_time=3.0, iter_fallback=100)
156+
result = compute_auto_iters(warmup_time=0.001, min_time=3.0, min_iter=100)
157157
assert result == 3000
158158

159159
def test_long_kernel_uses_min_time(self):
160160
"""A 10s kernel with 3s target → 1 iter. min_time overrides --iter."""
161-
result = compute_auto_iters(warmup_time=10.0, min_time=3.0, iter_fallback=100)
161+
result = compute_auto_iters(warmup_time=10.0, min_time=3.0, min_iter=100)
162162
assert result == 1
163163

164164
def test_exact_match(self):
165165
"""A 0.03s kernel needs exactly 100 iters for 3s target."""
166-
result = compute_auto_iters(warmup_time=0.03, min_time=3.0, iter_fallback=100)
166+
result = compute_auto_iters(warmup_time=0.03, min_time=3.0, min_iter=100)
167167
assert result == 100
168168

169169
def test_rounds_up(self):
170170
"""Should round up to ensure minimum time is met."""
171-
result = compute_auto_iters(warmup_time=0.007, min_time=3.0, iter_fallback=100)
171+
result = compute_auto_iters(warmup_time=0.007, min_time=3.0, min_iter=100)
172172
# 3.0 / 0.007 = 428.57... -> ceil = 429
173173
assert result == 429
174174

175-
def test_min_time_overrides_high_iter(self):
176-
"""min_time takes priority over --iter. 3s / 0.1s = 30 iters, not 500."""
177-
result = compute_auto_iters(warmup_time=0.1, min_time=3.0, iter_fallback=500)
175+
def test_min_time_overrides_high_min_iter(self):
176+
"""min_time takes priority over --min-iter. 3s / 0.1s = 30 iters, not 500."""
177+
result = compute_auto_iters(warmup_time=0.1, min_time=3.0, min_iter=500)
178178
assert result == 30
179179

180180
def test_zero_min_time_uses_fallback(self):
181181
"""When min_time is 0, use the fallback (disables auto-adjust)."""
182-
result = compute_auto_iters(warmup_time=0.001, min_time=0.0, iter_fallback=100)
182+
result = compute_auto_iters(warmup_time=0.001, min_time=0.0, min_iter=100)
183183
assert result == 100
184184

185185
def test_zero_warmup_time_uses_fallback(self):
186186
"""When warmup_time is 0 (shouldn't happen), use the fallback."""
187-
result = compute_auto_iters(warmup_time=0.0, min_time=3.0, iter_fallback=100)
187+
result = compute_auto_iters(warmup_time=0.0, min_time=3.0, min_iter=100)
188188
assert result == 100

0 commit comments

Comments
 (0)