Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
69 changes: 59 additions & 10 deletions iree/turbine/kernel/boo/driver/driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,9 @@
import csv
import gc
import argparse
import math
import traceback
import time
from typing import Callable, Sequence, NamedTuple
import os
import shlex
Expand All @@ -31,7 +33,7 @@
STDDEV_CHECK_ATOL_DEFAULT,
STDDEV_CHECK_RTOL_DEFAULT,
)
from iree.turbine.kernel.boo.driver.utils import get_timing_parser
from iree.turbine.kernel.boo.driver.utils import get_timing_parser, resolve_timing_args
from iree.turbine.kernel.boo.runtime.cache import set_cache_dir, toggle_cache_on
from iree.turbine.runtime.device import get_device_from_torch

Expand All @@ -51,6 +53,25 @@ class ZoneStats(NamedTuple):
ZoneStatsSummary = dict[str, ZoneStats]


def compute_auto_iters(warmup_time: float, min_time: float, min_iter: int) -> int:
    """Compute the number of iterations needed to run for at least `min_time` seconds.

    When both `warmup_time` and `min_time` are positive, the result is
    `ceil(min_time / warmup_time)` (never less than 1) and `min_iter` is
    ignored. In every other case `min_iter` is returned unchanged.

    Args:
        warmup_time: Time in seconds for a single warmup iteration.
        min_time: Minimum benchmark duration in seconds; a value <= 0 disables
            the time-based computation.
        min_iter: Iteration count returned when the time-based computation is
            disabled or cannot produce a usable number.

    Returns:
        The iteration count to use.
    """
    if warmup_time > 0 and min_time > 0:
        estimated = min_time / warmup_time
        # math.ceil() raises OverflowError on an infinite quotient (e.g. an
        # infinite `min_time`, or a denormal `warmup_time`), so fall back to
        # `min_iter` rather than crashing on an unsatisfiable target.
        if math.isfinite(estimated):
            # A positive time target always needs at least one iteration, even
            # if the quotient underflows to 0.0.
            return max(1, math.ceil(estimated))
    return min_iter


def _get_main_driver_parser() -> argparse.ArgumentParser:
parser = argparse.ArgumentParser(
usage="%(prog)s [-h] [... MIOpenDriver command ...] [--commands-file COMMANDS_FILE]",
Expand Down Expand Up @@ -257,6 +278,7 @@ def main(args: list[str] = sys.argv[1:]) -> int:
else:
print("Running test :", test_count)
timing_args, runner_args = timing_parser.parse_known_args(driver_args)
resolve_timing_args(timing_args)
csv_row.append(shlex.join(driver_args))
signature = BooOpRegistry.parse_command(runner_args)

Expand All @@ -278,7 +300,7 @@ def main(args: list[str] = sys.argv[1:]) -> int:
signature, meta_args.splat_input_value, devices
)

prof = run(
prof, actual_iter = run(
_func,
timing_args,
sample_inputs,
Expand Down Expand Up @@ -311,15 +333,15 @@ def main(args: list[str] = sys.argv[1:]) -> int:
if meta_args.verbose:
_print_zone_stats(results)

aggregate_stats = get_aggregate_stats(csv_stats, results, timing_args.iter)
aggregate_stats = get_aggregate_stats(csv_stats, results, actual_iter)

# Check that the number of dispatches per launch is an integer
dispatches_per_launch = aggregate_stats.num_dispatches / timing_args.iter
dispatches_per_launch = aggregate_stats.num_dispatches / actual_iter
if not dispatches_per_launch.is_integer():
if meta_args.verbose:
print(
f">>> ERROR: Number of dispatches per launch is fractional: {dispatches_per_launch} "
f"(total dispatches: {aggregate_stats.num_dispatches}, iterations: {timing_args.iter}). "
f"(total dispatches: {aggregate_stats.num_dispatches}, iterations: {actual_iter}). "
f"This usually indicates the torch profiler failed to capture data for the entire run. "
f"Try lowering the iteration count with --iter."
)
Expand Down Expand Up @@ -503,10 +525,11 @@ def run(
per_device_args: Sequence[tuple[torch.Tensor, ...]],
devices: Sequence[torch.device],
verbose: bool,
) -> profile | None:
) -> tuple[profile | None, int]:
"""Distributes `iter`-many applications of `func` to `per_device_args`. If
timing is requested, returns a torch profiler object that can be inspected
to recover time-related information."""
to recover time-related information, along with the actual iteration count
used (which may be auto-adjusted upward from --iter)."""

def pause_and_collect_mem():
for device in devices:
Expand All @@ -524,7 +547,33 @@ def pause_and_collect_mem():
for device in devices:
get_device_from_torch(device).hal_device.allocator.trim()

# First call: triggers JIT compilation for compiled backends.
torch.cuda.synchronize(devices[0])
example_results = func(*per_device_args[0])
torch.cuda.synchronize(devices[0])

# Second call: measures actual kernel run time (excluding compilation).
torch.cuda.synchronize(devices[0])
warmup_start = time.time()
example_results = func(*per_device_args[0])
torch.cuda.synchronize(devices[0])
warmup_time = time.time() - warmup_start

# Auto-adjust iteration count: ensure benchmark runs for at least min_time seconds.
if timing_args.time:
actual_iters = compute_auto_iters(
warmup_time, timing_args.min_time, timing_args.min_iter
)
else:
actual_iters = timing_args.min_iter

if verbose and actual_iters != timing_args.min_iter:
print(
f">>>\tAuto-adjusted iterations: {actual_iters} "
f"(warmup: {warmup_time:.4f}s, target: {timing_args.min_time:.1f}s, "
f"min-iter: {timing_args.min_iter})"
)

output_num_bytes = sum(x.element_size() * x.numel() for x in example_results)
input_num_bytes = sum(x.element_size() * x.numel() for x in per_device_args[0])
num_devices = len(per_device_args)
Expand All @@ -536,14 +585,14 @@ def pause_and_collect_mem():
), "Cannot reliably profile if cleanup is needed after every step."

schedule_fn, total_num_iters, needs_cleanup = make_profiler_schedule(
timing_args.iter, num_devices, iter_thresh
actual_iters, num_devices, iter_thresh
)

if timing_args.time:
profile_context = make_profiler_context(schedule_fn)
else:
# When not profiling, just run as many times as requested.
total_num_iters = timing_args.iter
total_num_iters = actual_iters
profile_context = nullcontext()

results: tuple[torch.Tensor, ...] | torch.Tensor | None = None
Expand All @@ -569,7 +618,7 @@ def pause_and_collect_mem():
print(
f">>>\tresult #{i} shape: {list(result.shape)}; stride: {list(result.stride())}; dtype: {result.dtype}; device type: {result.device.type}"
)
return prof if timing_args.time else None
return (prof if timing_args.time else None, actual_iters)


DEFAULT_BACKEND = "iree_boo_experimental"
Expand Down
25 changes: 24 additions & 1 deletion iree/turbine/kernel/boo/driver/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,34 @@ def get_timing_parser() -> argparse.ArgumentParser:
"--time", "-t", type=int, help="Enable timing", default=1
)
timing_parser.add_argument(
"--iter", type=int, help="Number of iterations to run", default=100
"--iter",
type=int,
help="Exact number of iterations (disables auto-adjustment; "
"shorthand for --min-iter X --min-time 0)",
default=None,
)
timing_parser.add_argument(
"--min-iter",
type=int,
help="Minimum number of iterations when auto-adjusting (default: 100)",
default=100,
)
timing_parser.add_argument(
"--min-time",
type=float,
help="Minimum benchmark duration in seconds (default: 3.0)",
default=3.0,
)
return timing_parser


def resolve_timing_args(timing_args: argparse.Namespace) -> None:
    """Expand the `--iter` shorthand on `timing_args` in place.

    If `timing_args.iter` is not None, its value is copied to
    `timing_args.min_iter` and `timing_args.min_time` is set to 0.0.
    Otherwise the namespace is left untouched.
    """
    requested_iters = timing_args.iter
    if requested_iters is None:
        # No explicit --iter: keep whatever --min-iter / --min-time resolved to.
        return
    timing_args.min_iter = requested_iters
    timing_args.min_time = 0.0


def load_commands(commands_file: str) -> list[str]:
"""Loads commands of a given kind from a text file.

Expand Down
5 changes: 3 additions & 2 deletions tests/kernel/boo/driver/csv_output_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,8 +81,9 @@ def test_roundtrip_csv_commands_file():
# Check we don't encounter an error.
assert driver.main(args) == 0
data = _read_csv_as_dicts(csv_file, backends=[backend])
for d, c in zip(data, commands, strict=True):
expected_iters = [4, 2]
for d, c, iters in zip(data, commands, expected_iters, strict=True):
# Check the arguments column contains the individual command.
assert d["arguments"] == shlex.join(c)
# Check all convs have a single dispatch per launch.
assert d[f"{backend} num_dispatches"] == c[-1]
assert d[f"{backend} num_dispatches"] == str(iters)
39 changes: 39 additions & 0 deletions tests/kernel/boo/driver/profiler_schedule_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from torch.profiler import ProfilerAction

from iree.turbine.kernel.boo.driver.driver import (
compute_auto_iters,
make_profiler_schedule,
make_profiler_context,
)
Expand Down Expand Up @@ -147,3 +148,41 @@ def dummy_kernel():
assert (
len(cuda_events) == timing_iter
), f"Expected {timing_iter} CUDA events, got {len(cuda_events)}"


class TestComputeAutoIters:
    """Checks `compute_auto_iters` on the time-based path and on the fallback path."""

    def test_short_kernel_gets_more_iters(self):
        """1ms per iteration against a 3s target gives 3000 iterations."""
        assert compute_auto_iters(warmup_time=0.001, min_time=3.0, min_iter=100) == 3000

    def test_long_kernel_uses_min_time(self):
        """10s per iteration against a 3s target gives 1 iteration, not min_iter."""
        assert compute_auto_iters(warmup_time=10.0, min_time=3.0, min_iter=100) == 1

    def test_exact_match(self):
        """0.03s per iteration against a 3s target gives exactly 100 iterations."""
        assert compute_auto_iters(warmup_time=0.03, min_time=3.0, min_iter=100) == 100

    def test_rounds_up(self):
        """A fractional quotient is rounded up so the time target is met."""
        # 3.0 / 0.007 = 428.57... -> ceil = 429
        assert compute_auto_iters(warmup_time=0.007, min_time=3.0, min_iter=100) == 429

    def test_min_time_overrides_high_min_iter(self):
        """min_time wins over a larger min_iter: 3s / 0.1s = 30, not 500."""
        assert compute_auto_iters(warmup_time=0.1, min_time=3.0, min_iter=500) == 30

    def test_zero_min_time_uses_fallback(self):
        """A min_time of 0 returns min_iter unchanged."""
        assert compute_auto_iters(warmup_time=0.001, min_time=0.0, min_iter=100) == 100

    def test_zero_warmup_time_uses_fallback(self):
        """A warmup_time of 0 returns min_iter unchanged."""
        assert compute_auto_iters(warmup_time=0.0, min_time=3.0, min_iter=100) == 100
Loading