131 changes: 74 additions & 57 deletions flashinfer/autotuner.py
@@ -466,64 +466,81 @@ def choose_one(
         }
 
         for p in profiles:
-            tensors = self._prepare_input_tensors(p, inputs)
-            is_cache_hit, runner_id, tactic, _ = self.search_cache(
-                custom_op, runners, p.get_opt_shapes(), tuning_config
-            )
-            if not is_cache_hit:
-                min_time = float("inf")
-                # Initialize runner and tactic as None in case of no valid tactic or runners are found
-                runner_id, tactic = None, None
-                for r_id, r in enumerate(runners):
-                    # TODO: use FakeTensor here.
-                    valid_tactics = r.get_valid_tactics(tensors, p)
-                    runner_arg_names = runner_arg_names_map[r]
-                    if "do_preparation" in runner_arg_names and len(valid_tactics) > 0:
-                        r(tensors, tactic=-1, do_preparation=True, **kwargs)
-                    for tac in valid_tactics:
-                        try:
-                            time_measured = self._profile_single_kernel(
-                                r, tensors, tac, **kwargs
-                            )
-                        except Exception as e:
-                            shapes = self._get_input_sizes(tensors)
-                            logger.warning(
-                                f"[Autotuner]: Skipping tactic {r} {tac}, due to failure while profiling: {e}"
-                            )
-
-                            # Log stacktrace as debug to not spam log
-                            logger.debug(
-                                f"[Autotuner]: Failed when profiling {r} {tac}, shapes={shapes}. Error occurred: {e}"
-                            )
-
-                            # Record the failed profiling combinations
-                            if custom_op not in self.stats.failed_profiling_count:
-                                self.stats.failed_profiling_count[custom_op] = set()
-                            self.stats.failed_profiling_count[custom_op].add(
-                                AutoTuner._get_cache_key(
-                                    custom_op, r, p.get_opt_shapes(), tuning_config
-                                )
-                            )
-
-                            # Set time_measured to inf to notify the failure of the tactic. This can happen when `get_valid_tactics` mistakenly return wrong tactics
-                            # or some runtime error occurs during profiling.
-                            time_measured = float("inf")
-                        if time_measured < min_time:
-                            min_time = time_measured
-                            runner_id, tactic = r_id, tac
-                if runner_id is not None:
-                    # At least one valid (runner, tactic) pair is found
-                    cache_key = AutoTuner._get_cache_key(
-                        custom_op, runners[runner_id], p.get_opt_shapes(), tuning_config
-                    )
-                    # inspect call stack
-                    self.profiling_cache[cache_key] = (runner_id, tactic, p)
-                    self.stats.tuned_op_successful_configs[custom_op] = (
-                        self.stats.tuned_op_successful_configs.get(custom_op, 0) + 1
-                    )
-                    logger.debug(
-                        f"[Autotuner]: profiling chosen runner: {runners[runner_id]} {tactic} for {cache_key}"
-                    )
+            try:
+                tensors = self._prepare_input_tensors(p, inputs)
+                is_cache_hit, runner_id, tactic, _ = self.search_cache(
+                    custom_op, runners, p.get_opt_shapes(), tuning_config
+                )
+                if not is_cache_hit:
+                    min_time = float("inf")
+                    # Initialize runner and tactic as None in case of no valid tactic or runners are found
+                    runner_id, tactic = None, None
+                    for r_id, r in enumerate(runners):
+                        # TODO: use FakeTensor here.
+                        valid_tactics = r.get_valid_tactics(tensors, p)
+                        runner_arg_names = runner_arg_names_map[r]
+                        if (
+                            "do_preparation" in runner_arg_names
+                            and len(valid_tactics) > 0
+                        ):
+                            r(tensors, tactic=-1, do_preparation=True, **kwargs)
+                        for tac in valid_tactics:
+                            try:
+                                time_measured = self._profile_single_kernel(
+                                    r, tensors, tac, **kwargs
+                                )
+                            except torch.cuda.OutOfMemoryError:
+                                raise
+                            except Exception as e:
+                                shapes = self._get_input_sizes(tensors)
+                                logger.warning(
+                                    f"[Autotuner]: Skipping tactic {r} {tac}, due to failure while profiling: {e}"
+                                )
+
+                                # Log stacktrace as debug to not spam log
+                                logger.debug(
+                                    f"[Autotuner]: Failed when profiling {r} {tac}, shapes={shapes}. Error occurred: {e}"
+                                )
+
+                                # Record the failed profiling combinations
+                                if custom_op not in self.stats.failed_profiling_count:
+                                    self.stats.failed_profiling_count[custom_op] = set()
+                                self.stats.failed_profiling_count[custom_op].add(
+                                    AutoTuner._get_cache_key(
+                                        custom_op, r, p.get_opt_shapes(), tuning_config
+                                    )
+                                )
+
+                                # Set time_measured to inf to notify the failure of the tactic. This can happen when `get_valid_tactics` mistakenly return wrong tactics
+                                # or some runtime error occurs during profiling.
+                                time_measured = float("inf")
+                            if time_measured < min_time:
+                                min_time = time_measured
+                                runner_id, tactic = r_id, tac
+
+                    if runner_id is not None:
+                        # At least one valid (runner, tactic) pair is found
+                        cache_key = AutoTuner._get_cache_key(
+                            custom_op,
+                            runners[runner_id],
+                            p.get_opt_shapes(),
+                            tuning_config,
+                        )
+                        # inspect call stack
+                        self.profiling_cache[cache_key] = (runner_id, tactic, p)
+                        self.stats.tuned_op_successful_configs[custom_op] = (
+                            self.stats.tuned_op_successful_configs.get(custom_op, 0) + 1
+                        )
+                        logger.debug(
+                            f"[Autotuner]: profiling chosen runner: {runners[runner_id]} {tactic} for {cache_key}"
+                        )
+
+            except torch.cuda.OutOfMemoryError:
+                torch.cuda.empty_cache()
+                logger.warning(
+                    "[Autotuner]: OOM detected, falling back to default tactic"
+                )
+                return runners[0], -1
 
         # Get the best runner and tactic from cache
         # If no valid tactic is found, the fallback runner and tactic will be used
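
For readers skimming the hunk, the behavioral change is a two-level exception strategy: the new inner `except torch.cuda.OutOfMemoryError: raise` keeps an OOM from being swallowed by the generic per-tactic handler, and the new outer handler converts it into a clean fallback, calling `torch.cuda.empty_cache()` and returning `runners[0], -1`. Below is a minimal, self-contained sketch of that pattern; `tune`, `profile_tactic`, `runners`, and `tactics` are hypothetical stand-ins for illustration, not the actual flashinfer API.

    # Hedged sketch of the PR's two-level OOM handling. Assumed names
    # (`tune`, `profile_tactic`, `runners`, `tactics`) are illustrative only.
    import logging

    import torch

    logger = logging.getLogger("autotuner_sketch")


    def tune(runners, tactics, profile_tactic):
        best_runner_id, best_tactic, min_time = None, None, float("inf")
        try:
            for r_id, runner in enumerate(runners):
                for tac in tactics:
                    try:
                        elapsed = profile_tactic(runner, tac)
                    except torch.cuda.OutOfMemoryError:
                        # Unlike other failures, OOM is not swallowed: re-raise
                        # so the outer handler abandons tuning for this op.
                        raise
                    except Exception:
                        # Any other failure merely disqualifies this tactic.
                        elapsed = float("inf")
                    if elapsed < min_time:
                        min_time = elapsed
                        best_runner_id, best_tactic = r_id, tac
        except torch.cuda.OutOfMemoryError:
            # Return cached allocator blocks to the driver, then fall back to
            # the default runner with the default tactic (-1).
            torch.cuda.empty_cache()
            logger.warning("OOM detected while profiling, using default tactic")
            return runners[0], -1
        if best_runner_id is None:
            return runners[0], -1  # no tactic profiled successfully
        return runners[best_runner_id], best_tactic

Re-raising OOM instead of recording it like other profiling failures is a deliberate choice: once the device has run out of memory, subsequent candidates would most likely fail the same way, so it is cheaper to stop tuning this op immediately and return the default runner and tactic.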