diff --git a/python/src/llvm.cc b/python/src/llvm.cc index fa93102ff7ba..724228e0ee6e 100644 --- a/python/src/llvm.cc +++ b/python/src/llvm.cc @@ -406,12 +406,11 @@ std::string translateLLVMIRToASM(llvm::Module &module, return result; } -std::string translateMIRToASM(const std::string &mirPath, - const std::string &triple, - const std::string &proc, - const std::string &features, - const std::vector &flags, - bool enable_fp_fusion, bool isObject) { +std::string +translateMIRToASM(const std::string &mirPath, const std::string &triple, + const std::string &proc, const std::string &features, + const std::vector &flags, bool enable_fp_fusion, + bool isObject, bool enableMISched) { using namespace mlir; // We need to start before machine-scheduler and disable it instead of simply @@ -421,8 +420,9 @@ std::string translateMIRToASM(const std::string &mirPath, // Use RAII to set options and restore them when scope exits ScopedLLVMOption startBeforeGuard("start-before", "machine-scheduler"); - ScopedLLVMOption enableMISchedGuard("enable-misched", false); - ScopedLLVMOption enablePostMISchedGuard("enable-post-misched", false); + ScopedLLVMOption enableMISchedGuard("enable-misched", enableMISched); + ScopedLLVMOption enablePostMISchedGuard("enable-post-misched", + enableMISched); if (triton::tools::getBoolEnv("LLVM_IR_ENABLE_DUMP")) { setLLVMOption("print-after-all", true); @@ -843,18 +843,22 @@ void init_triton_llvm(py::module &&m) { "translate_mir_to_asm", [](std::string mirPath, std::string triple, std::string proc, std::string features, std::vector flags, - bool enable_fp_fusion, bool isObject) -> py::object { + bool enable_fp_fusion, bool isObject, + bool enableMISched) -> py::object { std::string result; { py::gil_scoped_release allow_threads; result = translateMIRToASM(mirPath, triple, proc, features, flags, - enable_fp_fusion, isObject); + enable_fp_fusion, isObject, enableMISched); } if (isObject) return py::bytes(result); else return py::str(result); }, + py::arg("mirPath"), py::arg("triple"), py::arg("proc"), + py::arg("features"), py::arg("flags"), py::arg("enable_fp_fusion"), + py::arg("isObject"), py::arg("enableMISched") = false, ret::take_ownership); m.def("init_targets", []() { diff --git a/python/test/backend/test_mir_stage.py b/python/test/backend/test_mir_stage.py index b9136dea2513..722c8fa61a4c 100644 --- a/python/test/backend/test_mir_stage.py +++ b/python/test/backend/test_mir_stage.py @@ -154,14 +154,7 @@ def copy_kernel(x_ptr, output_ptr, n_elements, BLOCK_SIZE: tl.constexpr): torch.testing.assert_close(output2, x) -def test_mir_swap_pipeline_passes(tmp_path): - """Test that MIR swap pipeline starts before machine-scheduler and disables schedulers.""" - import re - import os - import subprocess - - # Write test script to a file (required for @triton.jit to get source) - test_script = ''' +_SIMPLE_KERNEL_SCRIPT = ''' import triton import triton.language as tl import torch @@ -182,8 +175,15 @@ def simple_kernel(x_ptr, output_ptr, n_elements, BLOCK_SIZE: tl.constexpr): simple_kernel[grid](x, output, size, BLOCK_SIZE=128) ''' + +def test_mir_swap_pipeline_passes(tmp_path): + """Test that MIR swap pipeline starts before machine-scheduler and disables schedulers.""" + import re + import os + import subprocess + script_file = tmp_path / "test_kernel.py" - script_file.write_text(test_script) + script_file.write_text(_SIMPLE_KERNEL_SCRIPT) # Phase 1: Dump MIR env = os.environ.copy() @@ -294,3 +294,157 @@ def extract_machine_code(text): assert before_mc == after_mc, \ "post-RA scheduler should not modify MIR when disabled, but MIR changed" + + +def _dump_and_prepare_mir(tmp_path, script_file): + """Dump MIR for a kernel script and strip it for swapping. Returns the cleaned MIR file path.""" + import os + import subprocess + + env = os.environ.copy() + env["TRITON_DUMP_MIR"] = str(tmp_path) + env["TRITON_ALWAYS_COMPILE"] = "1" + + result = subprocess.run(["python", str(script_file)], capture_output=True, text=True, env=env, timeout=120) + assert result.returncode == 0, \ + f"Dump phase should succeed. stderr: {result.stderr[:1000]}" + + mir_files = list(tmp_path.glob("complex_kernel_*.txt")) + assert len(mir_files) == 1, "Exactly one MIR file should have been dumped" + + mir_file = mir_files[0] + mir_content = mir_file.read_text() + dag_marker = "\n---\n==========" + if dag_marker in mir_content: + mir_content = mir_content.split(dag_marker)[0] + if mir_content.rstrip().endswith("..."): + mir_content = mir_content.rstrip()[:-3] + mir_file.write_text(mir_content) + return mir_file + + +def _swap_mir_and_get_output(tmp_path, script_file, enable_misched): + """Swap MIR with LLVM_IR_ENABLE_DUMP and return stderr output.""" + import os + import subprocess + + env = os.environ.copy() + env["TRITON_SWAP_MIR"] = str(tmp_path) + env["TRITON_ALWAYS_COMPILE"] = "1" + env["LLVM_IR_ENABLE_DUMP"] = "1" + if enable_misched: + env["TRITON_SWAP_MIR_ENABLE_MISCHED"] = "1" + + result = subprocess.run(["python", str(script_file)], capture_output=True, text=True, env=env, timeout=120) + assert result.returncode == 0, \ + f"Swap phase (misched={'enabled' if enable_misched else 'disabled'}) should succeed. stderr: {result.stderr[:1000]}" + return result.stderr + + +def _extract_mc_around_sched(output_text): + """Extract machine code before and after the Machine Instruction Scheduler pass.""" + import re + + dumps = re.split(r'# \*\*\* IR Dump After ([^*]+) \*\*\*', output_text) + + machine_sched_idx = None + for i, part in enumerate(dumps): + if 'Machine Instruction Scheduler' in part and 'PostRA' not in part: + machine_sched_idx = i + break + + if machine_sched_idx is None or machine_sched_idx < 1 or machine_sched_idx + 1 >= len(dumps): + return None, None + + def extract_machine_code(text): + match = re.search(r'# Machine code for function.*', text, re.DOTALL) + return match.group(0).strip() if match else text.strip() + + before_mc = extract_machine_code(dumps[machine_sched_idx - 1]) + after_mc = extract_machine_code(dumps[machine_sched_idx + 1]) + return before_mc, after_mc + + +# Kernel script with enough independent operations for the scheduler to reorder +_COMPLEX_KERNEL_SCRIPT = ''' +import triton +import triton.language as tl +import torch + +@triton.jit +def complex_kernel(a_ptr, b_ptr, c_ptr, d_ptr, output_ptr, n_elements, BLOCK_SIZE: tl.constexpr): + pid = tl.program_id(axis=0) + block_start = pid * BLOCK_SIZE + offsets = block_start + tl.arange(0, BLOCK_SIZE) + mask = offsets < n_elements + # Multiple independent loads + a = tl.load(a_ptr + offsets, mask=mask) + b = tl.load(b_ptr + offsets, mask=mask) + c = tl.load(c_ptr + offsets, mask=mask) + d = tl.load(d_ptr + offsets, mask=mask) + # Independent arithmetic chains + ab = a * b + c + cd = c * d + a + bd = b + d + ac = a - c + # Merge results + result = ab * cd + bd * ac + tl.store(output_ptr + offsets, result, mask=mask) + +size = 1024 +a = torch.randn(size, device='cuda') +b = torch.randn(size, device='cuda') +c = torch.randn(size, device='cuda') +d = torch.randn(size, device='cuda') +output = torch.empty_like(a) +grid = lambda meta: (triton.cdiv(size, meta['BLOCK_SIZE']), ) +complex_kernel[grid](a, b, c, d, output, size, BLOCK_SIZE=256) + +expected = (a * b + c) * (c * d + a) + (b + d) * (a - c) +torch.testing.assert_close(output, expected) +''' + + +def test_mir_swap_enable_misched(tmp_path): + """Test that TRITON_SWAP_MIR_ENABLE_MISCHED=1 causes the machine scheduler to actually modify MIR.""" + script_file = tmp_path / "test_kernel.py" + script_file.write_text(_COMPLEX_KERNEL_SCRIPT) + + # Phase 1: Dump and prepare MIR + _dump_and_prepare_mir(tmp_path, script_file) + + # Phase 2: Swap with misched DISABLED (default) — scheduler should be a no-op + disabled_output = _swap_mir_and_get_output(tmp_path, script_file, enable_misched=False) + before_disabled, after_disabled = _extract_mc_around_sched(disabled_output) + + assert before_disabled is not None and after_disabled is not None, \ + "Should find machine code around scheduler pass (disabled case)" + assert before_disabled == after_disabled, \ + "Scheduler should NOT modify MIR when misched is disabled" + + # Phase 3: Swap with misched ENABLED — scheduler should actually reschedule + enabled_output = _swap_mir_and_get_output(tmp_path, script_file, enable_misched=True) + before_enabled, after_enabled = _extract_mc_around_sched(enabled_output) + + assert before_enabled is not None and after_enabled is not None, \ + "Should find machine code around scheduler pass (enabled case)" + assert before_enabled != after_enabled, \ + "Scheduler SHOULD modify MIR when misched is enabled" + + +def test_mir_swap_enable_misched_requires_swap_mir(tmp_path): + """Test that TRITON_SWAP_MIR_ENABLE_MISCHED raises an error without TRITON_SWAP_MIR.""" + import os + import subprocess + + script_file = tmp_path / "test_kernel.py" + script_file.write_text(_SIMPLE_KERNEL_SCRIPT) + + env = os.environ.copy() + env["TRITON_SWAP_MIR_ENABLE_MISCHED"] = "1" + env["TRITON_ALWAYS_COMPILE"] = "1" + # TRITON_SWAP_MIR is NOT set + + result = subprocess.run(["python", str(script_file)], capture_output=True, text=True, env=env, timeout=120) + assert result.returncode != 0 + assert "TRITON_SWAP_MIR_ENABLE_MISCHED requires TRITON_SWAP_MIR" in result.stderr diff --git a/python/triton/knobs.py b/python/triton/knobs.py index c8bf2d421a06..524357be6d4e 100644 --- a/python/triton/knobs.py +++ b/python/triton/knobs.py @@ -521,6 +521,8 @@ class amd_knobs(base_knobs): dump_mir: env_opt_str = env_opt_str("TRITON_DUMP_MIR") # Path to externally-provided MIR files to use instead of generated ones swap_mir: env_opt_str = env_opt_str("TRITON_SWAP_MIR") + # Enable machine instruction scheduler in MIR swap mode + swap_mir_enable_misched: env_bool = env_bool("TRITON_SWAP_MIR_ENABLE_MISCHED", False) class proton_knobs(base_knobs): diff --git a/third_party/amd/backend/compiler.py b/third_party/amd/backend/compiler.py index 0796216f7fe5..ff2f3bf973dc 100644 --- a/third_party/amd/backend/compiler.py +++ b/third_party/amd/backend/compiler.py @@ -473,10 +473,12 @@ def make_amdgcn(src, metadata, options): dump_file_id) llvm.dump_sched_dag(src, amd.TARGET_TRIPLE, options.arch, features, flags, options.enable_fp_fusion, dump_file_id) + if knobs.amd.swap_mir_enable_misched and not knobs.amd.swap_mir: + raise ValueError("TRITON_SWAP_MIR_ENABLE_MISCHED requires TRITON_SWAP_MIR to be set") if knobs.amd.swap_mir: - amdgcn = llvm.translate_mir_to_asm(os.path.join(knobs.amd.swap_mir, - dump_file_id + '.txt'), amd.TARGET_TRIPLE, options.arch, - features, flags, options.enable_fp_fusion, False) + amdgcn = llvm.translate_mir_to_asm(os.path.join(knobs.amd.swap_mir, dump_file_id + '.txt'), + amd.TARGET_TRIPLE, options.arch, features, flags, + options.enable_fp_fusion, False, knobs.amd.swap_mir_enable_misched) else: amdgcn = llvm.translate_to_asm(src, amd.TARGET_TRIPLE, options.arch, features, flags, options.enable_fp_fusion, False)