diff --git a/iree/turbine/kernel/boo/driver/driver.py b/iree/turbine/kernel/boo/driver/driver.py index a45f1e843..2f278a527 100644 --- a/iree/turbine/kernel/boo/driver/driver.py +++ b/iree/turbine/kernel/boo/driver/driver.py @@ -9,6 +9,9 @@ import csv import gc import argparse +import multiprocessing +import signal +import time import traceback from typing import Callable, Sequence, NamedTuple import os @@ -50,6 +53,76 @@ class ZoneStats(NamedTuple): ZoneStatsSummary = dict[str, ZoneStats] +def _call_with_timeout_subprocess(fn, args, kwargs, return_pipe): + """Helper function to run in subprocess with timeout.""" + try: + result = fn(*args, **kwargs) + return_pipe.send(result) + except Exception as exc: + return_pipe.send(("exception", exc)) + + +def _call_with_timeout(fn, args, kwargs=None, timeout=None): + """ + Execute a function with a timeout. + + Args: + fn: Function to execute + args: Positional arguments for fn + kwargs: Keyword arguments for fn + timeout: Timeout in seconds (None means no timeout) + + Returns: + Result from fn + + Raises: + TimeoutError: If execution exceeds timeout + Exception: Any exception raised by fn + """ + if timeout is None: + # No timeout, run directly + return fn(*args, **(kwargs or {})) + + kwargs = kwargs or {} + parent_conn, child_conn = multiprocessing.Pipe() + start = time.time() + # Use spawn instead of fork to avoid CUDA re-initialization issues + ctx = multiprocessing.get_context('spawn') + proc = ctx.Process( + target=_call_with_timeout_subprocess, args=(fn, args, kwargs, child_conn) + ) + proc.start() + + while proc.is_alive(): + if parent_conn.poll(1): + result = parent_conn.recv() + proc.join() + if isinstance(result, tuple) and len(result) == 2 and result[0] == "exception": + raise result[1] + return result + if time.time() - start > timeout: + # Timeout exceeded - aggressively kill the process + # GPU operations can hang, so don't wait - use SIGKILL immediately + try: + os.kill(proc.pid, signal.SIGKILL) + except ProcessLookupError: + pass + proc.join(timeout=2) + if proc.is_alive(): + proc.kill() + proc.join() + raise TimeoutError(f"Command execution exceeded timeout of {timeout} seconds") + + proc.join() + if proc.exitcode == 0: + result = parent_conn.recv() + if isinstance(result, tuple) and len(result) == 2 and result[0] == "exception": + raise result[1] + return result + else: + raise RuntimeError(f"Subprocess exited with code {proc.exitcode}") + + def _get_main_driver_parser() -> argparse.ArgumentParser: parser = argparse.ArgumentParser( usage="%(prog)s [-h] [... MIOpenDriver command ...] [--commands-file COMMANDS_FILE]", @@ -68,6 +141,11 @@ def _get_main_driver_parser() -> argparse.ArgumentParser: formatter_class=argparse.RawDescriptionHelpFormatter, ) parser.add_argument("--commands-file", type=str, help="read commands from file") + parser.add_argument( + "--skip-commands-file", + type=str, + help="file containing commands to skip from the commands file", + ) parser.add_argument( "--backend", dest="backends", @@ -164,9 +242,180 @@ def _get_main_driver_parser() -> argparse.ArgumentParser: action="store_true", help="Skip structured pattern tests during numerics verification.", ) + parser.add_argument( + "--timeout", + type=float, + default=None, + help="Timeout in seconds for each command execution. If a command exceeds this time, it will be terminated and marked as timed out.", + ) return parser +def _execute_single_command( + driver_args: list[str], + meta_args: argparse.Namespace, + backends: list[str], + gpu_id: int, + csv_stats: list[str], + numerics_csv_cols: list[str], +) -> tuple[list[str], bool]: + """ + Execute a single driver command. + + Returns: + tuple of (csv_row, had_error) + """ + csv_row: list[str] = [] + had_error = False + + # Create timing parser and devices in subprocess to avoid pickling issues + timing_parser = get_timing_parser() + devices = _get_devices(gpu_id) + + timing_args, runner_args = timing_parser.parse_known_args(driver_args) + csv_row.append(shlex.join(driver_args)) + signature = BooOpRegistry.parse_command(runner_args) + + if signature is None: + if meta_args.verbose: + print( + f">>> Boo op registry failed to parse '{shlex.join(runner_args)}'." + ) + csv_row.append("N.A.") + csv_row += ["N.A."] * len(numerics_csv_cols) + return csv_row, True + + for backend in backends: + try: + _func = BACKEND_TO_FUNC_GENERATOR[backend](signature) + sample_inputs = _get_sample_args( + signature, meta_args.splat_input_value, devices + ) + + prof = run( + _func, + timing_args, + sample_inputs, + devices, + meta_args.verbose, + ) + except Exception as exc: + if meta_args.verbose: + traceback.print_exception(exc) + csv_row += ["N.A."] * len(csv_stats) + had_error = True + continue + + if not timing_args.time: + csv_row += ["untimed"] * len(csv_stats) + had_error = True + continue + + zones = _extract_zones(prof) + + if len(zones.keys()) == 0: + if meta_args.verbose: + print(">>> FAILED TO COLLECT TIMING INFO") + csv_row += ["failed to collect timing info"] * len(csv_stats) + had_error = True + continue + + # Get iree stats and print. + results = _get_zone_stats(zones) + if meta_args.verbose: + _print_zone_stats(results) + + aggregate_stats = get_aggregate_stats(csv_stats, results, timing_args.iter) + + # Check that the number of dispatches per launch is an integer + dispatches_per_launch = aggregate_stats.num_dispatches / timing_args.iter + if not dispatches_per_launch.is_integer(): + if meta_args.verbose: + print( + f">>> ERROR: Number of dispatches per launch is fractional: {dispatches_per_launch} " + f"(total dispatches: {aggregate_stats.num_dispatches}, iterations: {timing_args.iter}). " + f"This usually indicates the torch profiler failed to capture data for the entire run. " + f"Try lowering the iteration count with --iter." + ) + csv_row += ["incomplete profiling data"] * len(csv_stats) + had_error = True + continue + + if meta_args.verbose: + print( + f">>>\tPer-launch # GPU kernel dispatches ({backend}): {dispatches_per_launch}" + ) + print( + f">>>\tPer-launch GPU mean time ({backend}): {aggregate_stats.mean}us" + ) + + for stat in csv_stats: + csv_row.append(f"{aggregate_stats._asdict()[stat]}") + + # Run numerics verification if requested + if meta_args.verify_numerics: + from iree.turbine.kernel.boo.driver.numerics import ( + verify_numerics, + format_verdict_verbose, + format_verdict_simple, + ) + + gpu_id = meta_args.gpu_id if meta_args.gpu_id >= 0 else 0 + cmd = shlex.join(runner_args) + ref_dtype = {"float32": torch.float32, "float64": torch.float64}[ + meta_args.numerics_reference_dtype + ] + try: + verdicts = verify_numerics( + [cmd], + device=gpu_id, + min_samples=meta_args.numerics_min_samples, + stddev_check_rtol=meta_args.numerics_stddev_rtol, + stddev_check_atol=meta_args.numerics_stddev_atol, + mean_check_atol=meta_args.numerics_mean_atol, + mean_check_rtol=meta_args.numerics_mean_rtol, + run_structured_tests=not meta_args.skip_structured_tests, + reference_dtype=ref_dtype, + ) + verdict = verdicts[0] + except Exception as exc: + if meta_args.numerics_verbose: + traceback.print_exception(exc) + csv_row += ["N.A."] * len(numerics_csv_cols) + torch.compiler.reset() + return csv_row, True + + csv_row.append("PASS" if verdict.passed else "FAIL") + for stats in [ + verdict.boo_gpu_err, + verdict.pytorch_gpu_err, + verdict.boo_pytorch_diff, + ]: + if stats is not None: + csv_row += [ + f"{stats.mean:.6e}", + f"{stats.stddev:.6e}", + f"{stats.max_abs_err:.6e}", + ] + else: + csv_row += ["N.A."] * 3 + csv_row.append( + "N.A." + if verdict.structured_test_passed is None + else ("PASS" if verdict.structured_test_passed else "FAIL") + ) + + if meta_args.numerics_verbose: + print(format_verdict_verbose(verdict)) + else: + print(format_verdict_simple(verdict)) + + if not verdict.passed: + had_error = True + + return csv_row, had_error + + def main(args: list[str] = sys.argv[1:]) -> int: # Set saner defaults for pytorch/miopen environment variables. This affects # pytorch's inferred tensor layouts on AMDGPU, even when not actually using @@ -185,6 +434,7 @@ def main(args: list[str] = sys.argv[1:]) -> int: # separates to ['foo', 'bar'] extra_cli_args = [a for arg in extra_cli_args for a in arg.split("\t")] commands_file: str | None = meta_args.commands_file + skip_commands_file: str | None = meta_args.skip_commands_file if commands_file: splitter: Callable[[str], list[str]] = lambda s: ( @@ -196,6 +446,23 @@ def main(args: list[str] = sys.argv[1:]) -> int: for s in f.readlines() if s.strip() and not s.startswith("#") ] + + # Filter out commands listed in skip_commands_file + if skip_commands_file: + skip_splitter: Callable[[str], list[str]] = lambda s: ( + s.strip().split("\t") if skip_commands_file.endswith(".tsv") else shlex.split(s) + ) + with open(skip_commands_file) as f: + skip_commands_set = { + shlex.join(skip_splitter(s)) + for s in f.readlines() + if s.strip() and not s.startswith("#") + } + # Filter out commands that match any skip command (ignoring extra_cli_args) + mio_args = [ + cmd for cmd in mio_args + if shlex.join(cmd[:-len(extra_cli_args)] if extra_cli_args else cmd) not in skip_commands_set + ] else: mio_args = [extra_cli_args] # use CLI arguments @@ -232,163 +499,53 @@ def main(args: list[str] = sys.argv[1:]) -> int: csv_file.writerow(csv_headers) - timing_parser = get_timing_parser() - - devices = _get_devices(meta_args.gpu_id) test_count = 0 test_error = 0 for driver_args in mio_args: - csv_row: list[str] = [] test_count = test_count + 1 if meta_args.verbose: print(f"\n>>> {shlex.join(driver_args)}\n") else: print("Running test :", test_count) - timing_args, runner_args = timing_parser.parse_known_args(driver_args) - csv_row.append(shlex.join(driver_args)) - signature = BooOpRegistry.parse_command(runner_args) - if signature is None: + try: + csv_row, had_error = _call_with_timeout( + _execute_single_command, + args=( + driver_args, + meta_args, + backends, + meta_args.gpu_id, + csv_stats, + numerics_csv_cols, + ), + timeout=meta_args.timeout, + ) + if had_error: + test_error += 1 + except TimeoutError as exc: if meta_args.verbose: - print( - f">>> Boo op registry failed to parse '{shlex.join(runner_args)}'." - ) - csv_row.append("N.A.") + print(f">>> TIMEOUT: {exc}") + # Clean up GPU state after killing hung process + try: + torch.cuda.synchronize() + gc.collect() + torch.cuda.empty_cache() + except Exception: + pass # Ignore cleanup errors + csv_row = [shlex.join(driver_args)] + csv_row += [f"timeout ({meta_args.timeout}s)"] * len(csv_stats) csv_row += ["N.A."] * len(numerics_csv_cols) - csv_file.writerow(csv_row) test_error += 1 - continue - - for backend in backends: - try: - _func = BACKEND_TO_FUNC_GENERATOR[backend](signature) - sample_inputs = _get_sample_args( - signature, meta_args.splat_input_value, devices - ) - - prof = run( - _func, - timing_args, - sample_inputs, - devices, - meta_args.verbose, - ) - except Exception as exc: - if meta_args.verbose: - traceback.print_exception(exc) - csv_row += ["N.A."] * len(csv_stats) - test_error += 1 - continue - - if not timing_args.time: - csv_row += ["untimed"] * len(csv_stats) - test_error += 1 - continue - - zones = _extract_zones(prof) - - if len(zones.keys()) == 0: - if meta_args.verbose: - print(">>> FAILED TO COLLECT TIMING INFO") - csv_row += ["failed to collect timing info"] * len(csv_stats) - test_error += 1 - continue - - # Get iree stats and print. - results = _get_zone_stats(zones) + except Exception as exc: if meta_args.verbose: - _print_zone_stats(results) - - aggregate_stats = get_aggregate_stats(csv_stats, results, timing_args.iter) - - # Check that the number of dispatches per launch is an integer - dispatches_per_launch = aggregate_stats.num_dispatches / timing_args.iter - if not dispatches_per_launch.is_integer(): - if meta_args.verbose: - print( - f">>> ERROR: Number of dispatches per launch is fractional: {dispatches_per_launch} " - f"(total dispatches: {aggregate_stats.num_dispatches}, iterations: {timing_args.iter}). " - f"This usually indicates the torch profiler failed to capture data for the entire run. " - f"Try lowering the iteration count with --iter." - ) - csv_row += ["incomplete profiling data"] * len(csv_stats) - test_error += 1 - continue - - if meta_args.verbose: - print( - f">>>\tPer-launch # GPU kernel dispatches ({backend}): {dispatches_per_launch}" - ) - print( - f">>>\tPer-launch GPU mean time ({backend}): {aggregate_stats.mean}us" - ) - - for stat in csv_stats: - csv_row.append(f"{aggregate_stats._asdict()[stat]}") - - # Run numerics verification if requested - if meta_args.verify_numerics: - from iree.turbine.kernel.boo.driver.numerics import ( - verify_numerics, - format_verdict_verbose, - format_verdict_simple, - ) - - gpu_id = meta_args.gpu_id if meta_args.gpu_id >= 0 else 0 - cmd = shlex.join(runner_args) - ref_dtype = {"float32": torch.float32, "float64": torch.float64}[ - meta_args.numerics_reference_dtype - ] - try: - verdicts = verify_numerics( - [cmd], - device=gpu_id, - min_samples=meta_args.numerics_min_samples, - stddev_check_rtol=meta_args.numerics_stddev_rtol, - stddev_check_atol=meta_args.numerics_stddev_atol, - mean_check_atol=meta_args.numerics_mean_atol, - mean_check_rtol=meta_args.numerics_mean_rtol, - run_structured_tests=not meta_args.skip_structured_tests, - reference_dtype=ref_dtype, - ) - verdict = verdicts[0] - except Exception as exc: - if meta_args.numerics_verbose: - traceback.print_exception(exc) - csv_row += ["N.A."] * len(numerics_csv_cols) - test_error += 1 - csv_file.writerow(csv_row) - torch.compiler.reset() - continue - - csv_row.append("PASS" if verdict.passed else "FAIL") - for stats in [ - verdict.boo_gpu_err, - verdict.pytorch_gpu_err, - verdict.boo_pytorch_diff, - ]: - if stats is not None: - csv_row += [ - f"{stats.mean:.6e}", - f"{stats.stddev:.6e}", - f"{stats.max_abs_err:.6e}", - ] - else: - csv_row += ["N.A."] * 3 - csv_row.append( - "N.A." - if verdict.structured_test_passed is None - else ("PASS" if verdict.structured_test_passed else "FAIL") - ) - - if meta_args.numerics_verbose: - print(format_verdict_verbose(verdict)) - else: - print(format_verdict_simple(verdict)) - - if not verdict.passed: - test_error += 1 + print(f">>> ERROR: {exc}") + traceback.print_exception(exc) + csv_row = [shlex.join(driver_args)] + csv_row += ["error"] * len(csv_stats) + csv_row += ["N.A."] * len(numerics_csv_cols) + test_error += 1 csv_file.writerow(csv_row) # Exit code: zero if no errors, non-zero otherwise. diff --git a/iree/turbine/kernel/boo/driver/numerics.py b/iree/turbine/kernel/boo/driver/numerics.py index c041e011d..6bbfe8583 100644 --- a/iree/turbine/kernel/boo/driver/numerics.py +++ b/iree/turbine/kernel/boo/driver/numerics.py @@ -192,10 +192,33 @@ def collect_error_samples( num_batches = max(1, (min_samples + output_numel - 1) // output_numel) # Run PyTorch GPU - with torch.no_grad(): - pytorch_gpu_result = reference_module(*gpu_args) - if isinstance(pytorch_gpu_result, torch.Tensor): - pytorch_gpu_result = (pytorch_gpu_result,) + # WORKAROUND: PyTorch/ROCm has a bug with convolution_backward on GPU + # for half-precision types (bfloat16/float16). Cast to float32 to avoid + # floating-point exceptions. + needs_fp32_cast = any( + arg.is_floating_point() and arg.dtype in (torch.bfloat16, torch.float16) + for arg in gpu_args + ) + if needs_fp32_cast: + gpu_args_fp32 = tuple( + arg.to(dtype=torch.float32) if arg.is_floating_point() else arg + for arg in gpu_args + ) + with torch.no_grad(): + pytorch_gpu_result = reference_module(*gpu_args_fp32) + if isinstance(pytorch_gpu_result, torch.Tensor): + pytorch_gpu_result = (pytorch_gpu_result.to(dtype=gpu_args[0].dtype),) + else: + # Cast all float outputs back to original dtype + pytorch_gpu_result = tuple( + res.to(dtype=gpu_args[0].dtype) if res.is_floating_point() else res + for res in pytorch_gpu_result + ) + else: + with torch.no_grad(): + pytorch_gpu_result = reference_module(*gpu_args) + if isinstance(pytorch_gpu_result, torch.Tensor): + pytorch_gpu_result = (pytorch_gpu_result,) # Run BOO GPU try: @@ -384,17 +407,47 @@ def run_structured_test( return False, f"BOO compilation failed: {e}" # Run both - with torch.no_grad(): - pytorch_result = reference_module(*gpu_args) - try: - boo_result = boo_module(*gpu_args) - except Exception as e: - return False, f"BOO runtime failed: {e}" - - if isinstance(pytorch_result, torch.Tensor): - pytorch_result = (pytorch_result,) - if isinstance(boo_result, torch.Tensor): - boo_result = (boo_result,) + # WORKAROUND: PyTorch/ROCm has a bug with convolution_backward on GPU + # for half-precision types (bfloat16/float16). Cast to float32 to avoid + # floating-point exceptions. + needs_fp32_cast = any( + arg.is_floating_point() and arg.dtype in (torch.bfloat16, torch.float16) + for arg in gpu_args + ) + if needs_fp32_cast: + gpu_args_fp32 = tuple( + arg.to(dtype=torch.float32) if arg.is_floating_point() else arg + for arg in gpu_args + ) + with torch.no_grad(): + pytorch_result = reference_module(*gpu_args_fp32) + try: + boo_result = boo_module(*gpu_args) + except Exception as e: + return False, f"BOO runtime failed: {e}" + + if isinstance(pytorch_result, torch.Tensor): + pytorch_result = (pytorch_result.to(dtype=gpu_args[0].dtype),) + else: + # Cast all float outputs back to original dtype + pytorch_result = tuple( + res.to(dtype=gpu_args[0].dtype) if res.is_floating_point() else res + for res in pytorch_result + ) + if isinstance(boo_result, torch.Tensor): + boo_result = (boo_result,) + else: + with torch.no_grad(): + pytorch_result = reference_module(*gpu_args) + try: + boo_result = boo_module(*gpu_args) + except Exception as e: + return False, f"BOO runtime failed: {e}" + + if isinstance(pytorch_result, torch.Tensor): + pytorch_result = (pytorch_result,) + if isinstance(boo_result, torch.Tensor): + boo_result = (boo_result,) # Compare main results on GPU main_idx = sig.main_result_index