
Commit 7471ad0

zack041 authored and raayandhar committed
Fix autotuner oom (flashinfer-ai#2442)
## 📌 Description

Add graceful OOM handling during autotuning. When `torch.cuda.OutOfMemoryError` occurs, the autotuner now clears the CUDA cache and falls back to the default tactic `(runners[0], -1)` instead of crashing. The try-except block wraps the entire profiling loop, covering methods like `_prepare_input_tensors()` that could also cause OOM. An OOM raised inside the inner profiling loop is re-raised so that the outer exception handler catches it.

## 🔍 Related Issues

Fixes flashinfer-ai#2357

## 🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.

### ✅ Pre-commit Checks

- [x] I have installed `pre-commit` by running `pip install pre-commit` (or used your preferred method).
- [x] I have installed the hooks with `pre-commit install`.
- [x] I have run the hooks manually with `pre-commit run --all-files` and fixed any reported issues.

> If you are unsure about how to set up `pre-commit`, see [the pre-commit documentation](https://pre-commit.com/).

## 🧪 Tests

- [ ] Tests have been added or updated as needed.
- [ ] All tests are passing (`unittest`, etc.).

## Reviewer Notes

No tests were added because OOM during autotuning is difficult to reproduce reliably in a test environment.

## Summary by CodeRabbit

* **Bug Fixes**
  * Improved profiling error handling so individual tactic failures are caught, logged, recorded, and do not abort tuning.
  * Added robust out-of-memory handling that clears GPU resources and falls back to safe/previous configurations instead of crashing.
  * Ensured tuning continues after non-OOM errors, preserves cache/metrics consistency, and still selects the best measured configuration when available.
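To make the new control flow concrete, the sketch below reduces the pattern this change introduces to its essentials. It is not the FlashInfer code: `choose_one_sketch`, `profile_fn`, `get_valid_tactics`, and `DEFAULT_TACTIC` are illustrative stand-ins, and only the exception-handling structure mirrors the patch (an inner re-raise of `torch.cuda.OutOfMemoryError` and an outer handler that clears the CUDA cache and returns the default `(runner, tactic)` pair).

```python
import logging

import torch

logger = logging.getLogger("autotuner_oom_sketch")

DEFAULT_TACTIC = -1  # illustrative sentinel: "let the runner use its default path"


def choose_one_sketch(runners, profile_fn, inputs):
    """Profile every (runner, tactic) pair; on OOM anywhere in the loop,
    clear the CUDA cache and fall back to the first runner's default tactic."""
    try:
        best, best_time = None, float("inf")
        for runner in runners:
            for tactic in runner.get_valid_tactics(inputs):
                try:
                    elapsed = profile_fn(runner, tactic, inputs)
                except torch.cuda.OutOfMemoryError:
                    # Re-raise so the *outer* handler performs the fallback.
                    raise
                except Exception as exc:
                    # Non-OOM failures only skip this tactic; tuning continues.
                    logger.warning("skipping tactic %s: %s", tactic, exc)
                    continue
                if elapsed < best_time:
                    best, best_time = (runner, tactic), elapsed
        # No tactic survived profiling: fall back to the default pair.
        return best if best is not None else (runners[0], DEFAULT_TACTIC)
    except torch.cuda.OutOfMemoryError:
        torch.cuda.empty_cache()  # release cached blocks held by failed attempts
        logger.warning("OOM during autotuning, falling back to default tactic")
        return runners[0], DEFAULT_TACTIC
```

The ordering of the inner handlers matters: `except torch.cuda.OutOfMemoryError: raise` is listed before the generic `except Exception` so an OOM from a single kernel is not swallowed by the per-tactic skip logic; instead it propagates to the outer handler, which frees cached memory via `torch.cuda.empty_cache()` and returns the fallback pair, matching the diff below.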
1 parent cbab15c · commit 7471ad0

1 file changed: flashinfer/autotuner.py (74 additions, 57 deletions)
@@ -466,64 +466,81 @@ def choose_one(
             }
 
         for p in profiles:
-            tensors = self._prepare_input_tensors(p, inputs)
-            is_cache_hit, runner_id, tactic, _ = self.search_cache(
-                custom_op, runners, p.get_opt_shapes(), tuning_config
-            )
-            if not is_cache_hit:
-                min_time = float("inf")
-                # Initialize runner and tactic as None in case of no valid tactic or runners are found
-                runner_id, tactic = None, None
-                for r_id, r in enumerate(runners):
-                    # TODO: use FakeTensor here.
-                    valid_tactics = r.get_valid_tactics(tensors, p)
-                    runner_arg_names = runner_arg_names_map[r]
-                    if "do_preparation" in runner_arg_names and len(valid_tactics) > 0:
-                        r(tensors, tactic=-1, do_preparation=True, **kwargs)
-                    for tac in valid_tactics:
-                        try:
-                            time_measured = self._profile_single_kernel(
-                                r, tensors, tac, **kwargs
-                            )
-                        except Exception as e:
-                            shapes = self._get_input_sizes(tensors)
-                            logger.warning(
-                                f"[Autotuner]: Skipping tactic {r} {tac}, due to failure while profiling: {e}"
-                            )
-
-                            # Log stacktrace as debug to not spam log
-                            logger.debug(
-                                f"[Autotuner]: Failed when profiling {r} {tac}, shapes={shapes}. Error occurred: {e}"
-                            )
-
-                            # Record the failed profiling combinations
-                            if custom_op not in self.stats.failed_profiling_count:
-                                self.stats.failed_profiling_count[custom_op] = set()
-                            self.stats.failed_profiling_count[custom_op].add(
-                                AutoTuner._get_cache_key(
-                                    custom_op, r, p.get_opt_shapes(), tuning_config
-                                )
-                            )
-
-                            # Set time_measured to inf to notify the failure of the tactic. This can happen when `get_valid_tactics` mistakenly return wrong tactics
-                            # or some runtime error occurs during profiling.
-                            time_measured = float("inf")
-                        if time_measured < min_time:
-                            min_time = time_measured
-                            runner_id, tactic = r_id, tac
-                if runner_id is not None:
-                    # At least one valid (runner, tactic) pair is found
-                    cache_key = AutoTuner._get_cache_key(
-                        custom_op, runners[runner_id], p.get_opt_shapes(), tuning_config
-                    )
-                    # inspect call stack
-                    self.profiling_cache[cache_key] = (runner_id, tactic, p)
-                    self.stats.tuned_op_successful_configs[custom_op] = (
-                        self.stats.tuned_op_successful_configs.get(custom_op, 0) + 1
-                    )
-                    logger.debug(
-                        f"[Autotuner]: profiling chosen runner: {runners[runner_id]} {tactic} for {cache_key}"
-                    )
+            try:
+                tensors = self._prepare_input_tensors(p, inputs)
+                is_cache_hit, runner_id, tactic, _ = self.search_cache(
+                    custom_op, runners, p.get_opt_shapes(), tuning_config
+                )
+                if not is_cache_hit:
+                    min_time = float("inf")
+                    # Initialize runner and tactic as None in case of no valid tactic or runners are found
+                    runner_id, tactic = None, None
+                    for r_id, r in enumerate(runners):
+                        # TODO: use FakeTensor here.
+                        valid_tactics = r.get_valid_tactics(tensors, p)
+                        runner_arg_names = runner_arg_names_map[r]
+                        if (
+                            "do_preparation" in runner_arg_names
+                            and len(valid_tactics) > 0
+                        ):
+                            r(tensors, tactic=-1, do_preparation=True, **kwargs)
+                        for tac in valid_tactics:
+                            try:
+                                time_measured = self._profile_single_kernel(
+                                    r, tensors, tac, **kwargs
+                                )
+                            except torch.cuda.OutOfMemoryError:
+                                raise
+                            except Exception as e:
+                                shapes = self._get_input_sizes(tensors)
+                                logger.warning(
+                                    f"[Autotuner]: Skipping tactic {r} {tac}, due to failure while profiling: {e}"
+                                )
+
+                                # Log stacktrace as debug to not spam log
+                                logger.debug(
+                                    f"[Autotuner]: Failed when profiling {r} {tac}, shapes={shapes}. Error occurred: {e}"
+                                )
+
+                                # Record the failed profiling combinations
+                                if custom_op not in self.stats.failed_profiling_count:
+                                    self.stats.failed_profiling_count[custom_op] = set()
+                                self.stats.failed_profiling_count[custom_op].add(
+                                    AutoTuner._get_cache_key(
+                                        custom_op, r, p.get_opt_shapes(), tuning_config
+                                    )
+                                )
+
+                                # Set time_measured to inf to notify the failure of the tactic. This can happen when `get_valid_tactics` mistakenly return wrong tactics
+                                # or some runtime error occurs during profiling.
+                                time_measured = float("inf")
+                            if time_measured < min_time:
+                                min_time = time_measured
+                                runner_id, tactic = r_id, tac
+
+                    if runner_id is not None:
+                        # At least one valid (runner, tactic) pair is found
+                        cache_key = AutoTuner._get_cache_key(
+                            custom_op,
+                            runners[runner_id],
+                            p.get_opt_shapes(),
+                            tuning_config,
+                        )
+                        # inspect call stack
+                        self.profiling_cache[cache_key] = (runner_id, tactic, p)
+                        self.stats.tuned_op_successful_configs[custom_op] = (
+                            self.stats.tuned_op_successful_configs.get(custom_op, 0) + 1
+                        )
+                        logger.debug(
+                            f"[Autotuner]: profiling chosen runner: {runners[runner_id]} {tactic} for {cache_key}"
+                        )
+
+            except torch.cuda.OutOfMemoryError:
+                torch.cuda.empty_cache()
+                logger.warning(
+                    "[Autotuner]: OOM detected, falling back to default tactic"
+                )
+                return runners[0], -1
 
         # Get the best runner and tactic from cache
         # If no valid tactic is found, the fallback runner and tactic will be used