Skip to content

Commit b5e8250

Browse files
authored
[Kernel Bench] Fix run all, avoid system crash (#43)
"Run all" was broken by shadowing of the `tasks` variable. ProcessPoolExecutor was crashing while trying to fetch the next item from a generator on multiple threads. The variable has been renamed, and the conflict is now gone. In addition, running all processes at the same time was crashing systems with less available memory, since some tasks take between 12GB and 16GB each. The number of workers is now limited based on the memory available at the start of the script.
1 parent b919274 commit b5e8250

File tree

2 files changed

+10
-4
lines changed

2 files changed

+10
-4
lines changed

examples/ingress/convert-kernel-bench-to-mlir.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
# there's an ignore list. Runs the conversion in parallel.
99

1010
import sys
11+
import psutil
1112

1213
from concurrent.futures import ProcessPoolExecutor
1314
from dataclasses import dataclass
@@ -20,6 +21,10 @@
2021
project_root = Path(__file__).parent.parent.parent
2122
torch_kernels_dir = project_root / "third_party" / "KernelBench" / "KernelBench"
2223
mlir_kernels_dir = project_root / "cache" / "ingress" / "KernelBench"
24+
free_mem_gb = psutil.virtual_memory().available // (1024**3)
25+
print(f"Available memory: {free_mem_gb} GB")
26+
# Budget roughly 12GB per worker (some workers need 12~16GB) and cap at the CPU
# count. Clamp to at least 1: ProcessPoolExecutor raises ValueError when
# max_workers <= 0, which would happen whenever less than 12GB is available.
# psutil.cpu_count() may return None if the count is undetermined, so fall back
# to 1 instead of letting min() fail on an int/None comparison.
max_workers = max(1, min(free_mem_gb // 12, psutil.cpu_count() or 1))
27+
print(f"Using max_workers={max_workers} based on available memory")
2328

2429
if not torch_kernels_dir.is_dir():
2530
print(
@@ -196,12 +201,12 @@ def process_task(task: KernelConversionTask):
196201
print(mlir_kernel, file=f)
197202

198203

199-
tasks = sorted(all_tasks(), key=lambda t: (t.level, t.id))
204+
sorted_tasks = sorted(all_tasks(), key=lambda t: (t.level, t.id))
200205

201206
if len(sys.argv) == 1:
202207

203208
def tasks_():
204-
for task in tasks:
209+
for task in sorted_tasks:
205210
if task.ignore_by_default:
206211
print(
207212
f"Skipping: {task.torch_path.parent}/{task.torch_path.name}",
@@ -217,9 +222,9 @@ def tasks_():
217222
lhs, rhs = arg.split(",")
218223
level_id, kernel_id = int(lhs), int(rhs)
219224
overall_idx = 100 * (level_id - 1) + (kernel_id - 1)
220-
tasks_.append(tasks[overall_idx])
225+
tasks_.append(sorted_tasks[overall_idx])
221226
tasks = tasks_
222227

223228
print("Output directory:", mlir_kernels_dir)
224-
for _ in ProcessPoolExecutor().map(process_task, tasks):
229+
for _ in ProcessPoolExecutor(max_workers=max_workers).map(process_task, tasks):
225230
pass # NB: obtain each result so that exceptions are propagated to the main process

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ dev = [
1212
"ruff==0.14.5", # Python linter and formatter
1313
"pre-commit", # Tool to manage and apply pre-commit hooks
1414
"pytest>=8.0.0",
15+
"psutil",
1516
]
1617

1718
[project.optional-dependencies]

0 commit comments

Comments
 (0)