Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 14 additions & 10 deletions python/src/llvm.cc
Original file line number Diff line number Diff line change
Expand Up @@ -406,12 +406,11 @@ std::string translateLLVMIRToASM(llvm::Module &module,
return result;
}

std::string translateMIRToASM(const std::string &mirPath,
const std::string &triple,
const std::string &proc,
const std::string &features,
const std::vector<std::string> &flags,
bool enable_fp_fusion, bool isObject) {
std::string
translateMIRToASM(const std::string &mirPath, const std::string &triple,
const std::string &proc, const std::string &features,
const std::vector<std::string> &flags, bool enable_fp_fusion,
bool isObject, bool enableMISched) {
using namespace mlir;

// We need to start before machine-scheduler and disable it instead of simply
Expand All @@ -421,8 +420,9 @@ std::string translateMIRToASM(const std::string &mirPath,
// Use RAII to set options and restore them when scope exits
ScopedLLVMOption<std::string> startBeforeGuard("start-before",
"machine-scheduler");
ScopedLLVMOption<bool> enableMISchedGuard("enable-misched", false);
ScopedLLVMOption<bool> enablePostMISchedGuard("enable-post-misched", false);
ScopedLLVMOption<bool> enableMISchedGuard("enable-misched", enableMISched);
ScopedLLVMOption<bool> enablePostMISchedGuard("enable-post-misched",
enableMISched);

if (triton::tools::getBoolEnv("LLVM_IR_ENABLE_DUMP")) {
setLLVMOption<bool>("print-after-all", true);
Expand Down Expand Up @@ -843,18 +843,22 @@ void init_triton_llvm(py::module &&m) {
"translate_mir_to_asm",
[](std::string mirPath, std::string triple, std::string proc,
std::string features, std::vector<std::string> flags,
bool enable_fp_fusion, bool isObject) -> py::object {
bool enable_fp_fusion, bool isObject,
bool enableMISched) -> py::object {
std::string result;
{
py::gil_scoped_release allow_threads;
result = translateMIRToASM(mirPath, triple, proc, features, flags,
enable_fp_fusion, isObject);
enable_fp_fusion, isObject, enableMISched);
}
if (isObject)
return py::bytes(result);
else
return py::str(result);
},
py::arg("mirPath"), py::arg("triple"), py::arg("proc"),
py::arg("features"), py::arg("flags"), py::arg("enable_fp_fusion"),
py::arg("isObject"), py::arg("enableMISched") = false,
ret::take_ownership);

m.def("init_targets", []() {
Expand Down
172 changes: 163 additions & 9 deletions python/test/backend/test_mir_stage.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,14 +154,7 @@ def copy_kernel(x_ptr, output_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
torch.testing.assert_close(output2, x)


def test_mir_swap_pipeline_passes(tmp_path):
"""Test that MIR swap pipeline starts before machine-scheduler and disables schedulers."""
import re
import os
import subprocess

# Write test script to a file (required for @triton.jit to get source)
test_script = '''
_SIMPLE_KERNEL_SCRIPT = '''
import triton
import triton.language as tl
import torch
Expand All @@ -182,8 +175,15 @@ def simple_kernel(x_ptr, output_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
simple_kernel[grid](x, output, size, BLOCK_SIZE=128)
'''


def test_mir_swap_pipeline_passes(tmp_path):
"""Test that MIR swap pipeline starts before machine-scheduler and disables schedulers."""
import re
import os
import subprocess

script_file = tmp_path / "test_kernel.py"
script_file.write_text(test_script)
script_file.write_text(_SIMPLE_KERNEL_SCRIPT)

# Phase 1: Dump MIR
env = os.environ.copy()
Expand Down Expand Up @@ -294,3 +294,157 @@ def extract_machine_code(text):

assert before_mc == after_mc, \
"post-RA scheduler should not modify MIR when disabled, but MIR changed"


def _dump_and_prepare_mir(tmp_path, script_file, kernel_name="complex_kernel"):
    """Run a kernel script with MIR dumping enabled and trim the dump for swapping.

    The script is executed in a child interpreter with ``TRITON_DUMP_MIR``
    pointing at ``tmp_path`` and ``TRITON_ALWAYS_COMPILE=1``. The single dumped
    file matching ``<kernel_name>_*.txt`` is then rewritten in place with:

    * everything from the ``"\\n---\\n=========="`` marker onwards removed, and
    * a trailing ``...`` removed.

    Args:
        tmp_path: ``pathlib.Path`` directory that receives the MIR dump.
        script_file: Path of the Python script to execute.
        kernel_name: Prefix of the dumped file name to look for. Defaults to
            ``"complex_kernel"``.

    Returns:
        Path of the cleaned MIR file.

    Raises:
        AssertionError: If the script exits non-zero or the number of matching
            dump files is not exactly one.
    """
    import os
    import subprocess
    import sys

    env = os.environ.copy()
    # Keep the dump phase hermetic: swap-related variables inherited from the
    # caller's shell would otherwise change what this compile does.
    env.pop("TRITON_SWAP_MIR", None)
    env.pop("TRITON_SWAP_MIR_ENABLE_MISCHED", None)
    env["TRITON_DUMP_MIR"] = str(tmp_path)
    env["TRITON_ALWAYS_COMPILE"] = "1"

    # sys.executable guarantees the child uses the same interpreter (and hence
    # the same installed packages) as the test runner; a bare "python" may
    # resolve to a different interpreter or not exist at all.
    result = subprocess.run([sys.executable, str(script_file)], capture_output=True, text=True, env=env, timeout=120)
    assert result.returncode == 0, \
        f"Dump phase should succeed. stderr: {result.stderr[:1000]}"

    mir_files = list(tmp_path.glob(f"{kernel_name}_*.txt"))
    assert len(mir_files) == 1, "Exactly one MIR file should have been dumped"

    mir_file = mir_files[0]
    mir_content = mir_file.read_text()

    # Drop everything from the separator marker onwards.
    dag_marker = "\n---\n=========="
    if dag_marker in mir_content:
        mir_content = mir_content.split(dag_marker)[0]

    # Drop a trailing "..." if present.
    stripped = mir_content.rstrip()
    if stripped.endswith("..."):
        mir_content = stripped[:-3]

    mir_file.write_text(mir_content)
    return mir_file


def _swap_mir_and_get_output(tmp_path, script_file, enable_misched):
    """Run a kernel script in MIR-swap mode and return what it wrote to stderr.

    The child process is started with ``TRITON_SWAP_MIR`` pointing at
    ``tmp_path``, ``TRITON_ALWAYS_COMPILE=1`` and ``LLVM_IR_ENABLE_DUMP=1``.
    ``TRITON_SWAP_MIR_ENABLE_MISCHED`` is set to ``"1"`` when ``enable_misched``
    is true and is removed from the child's environment otherwise.

    Args:
        tmp_path: Directory containing the MIR file(s) to swap in.
        script_file: Path of the Python script to execute.
        enable_misched: Whether to request the machine scheduler in swap mode.

    Returns:
        The child's captured stderr as text.

    Raises:
        AssertionError: If the script exits with a non-zero status.
    """
    import os
    import subprocess
    import sys

    env = os.environ.copy()
    env["TRITON_SWAP_MIR"] = str(tmp_path)
    env["TRITON_ALWAYS_COMPILE"] = "1"
    env["LLVM_IR_ENABLE_DUMP"] = "1"
    if enable_misched:
        env["TRITON_SWAP_MIR_ENABLE_MISCHED"] = "1"
    else:
        # The copied environment may already carry the flag (e.g. exported in
        # the developer's shell); without removing it the "disabled" run would
        # silently execute with the scheduler enabled.
        env.pop("TRITON_SWAP_MIR_ENABLE_MISCHED", None)

    # Use the interpreter running the tests rather than whatever "python"
    # happens to resolve to on PATH.
    result = subprocess.run([sys.executable, str(script_file)], capture_output=True, text=True, env=env, timeout=120)
    assert result.returncode == 0, \
        f"Swap phase (misched={'enabled' if enable_misched else 'disabled'}) should succeed. stderr: {result.stderr[:1000]}"
    return result.stderr


def _extract_mc_around_sched(output_text):
    """Return the machine code seen just before and just after the machine scheduler.

    ``output_text`` is split on ``# *** IR Dump After <pass> ***`` headers. The
    first header whose pass name contains ``Machine Instruction Scheduler`` but
    not ``PostRA`` is taken as the scheduler.

    Args:
        output_text: Text containing the IR dump headers and bodies.

    Returns:
        ``(before, after)`` where ``before`` comes from the text preceding the
        scheduler's header (the previous dump body, or whatever precedes the
        first header) and ``after`` from the scheduler's own dump body. From
        each piece the text starting at ``# Machine code for function`` is
        kept, stripped; if that marker is absent the whole piece is returned
        stripped. Returns ``(None, None)`` when no matching header exists.
    """
    import re

    # With one capture group, re.split yields
    #   [preamble, pass_name_1, body_1, pass_name_2, body_2, ...]
    # i.e. pass names sit at odd indices and each body follows its name.
    dumps = re.split(r'# \*\*\* IR Dump After ([^*]+) \*\*\*', output_text)

    # Only inspect the captured pass names. Matching against bodies or the
    # preamble would let an unrelated mention of the scheduler (e.g. a log
    # line) be mistaken for the pass header and misalign before/after.
    machine_sched_idx = None
    for i in range(1, len(dumps), 2):
        pass_name = dumps[i]
        if 'Machine Instruction Scheduler' in pass_name and 'PostRA' not in pass_name:
            machine_sched_idx = i
            break

    if machine_sched_idx is None:
        return None, None

    def extract_machine_code(text):
        match = re.search(r'# Machine code for function.*', text, re.DOTALL)
        return match.group(0).strip() if match else text.strip()

    # A name at odd index i always has a predecessor (i - 1 >= 0) and a body
    # (i + 1 < len(dumps)), so both lookups are in range.
    before_mc = extract_machine_code(dumps[machine_sched_idx - 1])
    after_mc = extract_machine_code(dumps[machine_sched_idx + 1])
    return before_mc, after_mc


# Source of a standalone script run in a subprocess by the misched tests.
# It defines `complex_kernel` with enough independent operations (four masked
# loads feeding separate arithmetic chains) for the scheduler to reorder,
# launches it once, and checks the output against the same expression
# computed with torch. The kernel name is what the "complex_kernel_*.txt"
# glob in _dump_and_prepare_mir looks for.
_COMPLEX_KERNEL_SCRIPT = '''
import triton
import triton.language as tl
import torch

@triton.jit
def complex_kernel(a_ptr, b_ptr, c_ptr, d_ptr, output_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
pid = tl.program_id(axis=0)
block_start = pid * BLOCK_SIZE
offsets = block_start + tl.arange(0, BLOCK_SIZE)
mask = offsets < n_elements
# Multiple independent loads
a = tl.load(a_ptr + offsets, mask=mask)
b = tl.load(b_ptr + offsets, mask=mask)
c = tl.load(c_ptr + offsets, mask=mask)
d = tl.load(d_ptr + offsets, mask=mask)
# Independent arithmetic chains
ab = a * b + c
cd = c * d + a
bd = b + d
ac = a - c
# Merge results
result = ab * cd + bd * ac
tl.store(output_ptr + offsets, result, mask=mask)

size = 1024
a = torch.randn(size, device='cuda')
b = torch.randn(size, device='cuda')
c = torch.randn(size, device='cuda')
d = torch.randn(size, device='cuda')
output = torch.empty_like(a)
grid = lambda meta: (triton.cdiv(size, meta['BLOCK_SIZE']), )
complex_kernel[grid](a, b, c, d, output, size, BLOCK_SIZE=256)

expected = (a * b + c) * (c * d + a) + (b + d) * (a - c)
torch.testing.assert_close(output, expected)
'''


def test_mir_swap_enable_misched(tmp_path):
    """Check both settings of TRITON_SWAP_MIR_ENABLE_MISCHED in MIR swap mode.

    Unset, the machine code around the scheduler pass must be identical;
    set to 1, the scheduler must have changed it.
    """
    kernel_script = tmp_path / "test_kernel.py"
    kernel_script.write_text(_COMPLEX_KERNEL_SCRIPT)

    # Step 1: produce the MIR file that the following runs swap in.
    _dump_and_prepare_mir(tmp_path, kernel_script)

    # Step 2: default behaviour (scheduler off) -- the pass must leave MIR as is.
    stderr_off = _swap_mir_and_get_output(tmp_path, kernel_script, enable_misched=False)
    mc_before_off, mc_after_off = _extract_mc_around_sched(stderr_off)

    assert None not in (mc_before_off, mc_after_off), \
        "Should find machine code around scheduler pass (disabled case)"
    assert mc_before_off == mc_after_off, \
        "Scheduler should NOT modify MIR when misched is disabled"

    # Step 3: scheduler on -- the pass is expected to reorder instructions.
    stderr_on = _swap_mir_and_get_output(tmp_path, kernel_script, enable_misched=True)
    mc_before_on, mc_after_on = _extract_mc_around_sched(stderr_on)

    assert None not in (mc_before_on, mc_after_on), \
        "Should find machine code around scheduler pass (enabled case)"
    assert mc_before_on != mc_after_on, \
        "Scheduler SHOULD modify MIR when misched is enabled"


def test_mir_swap_enable_misched_requires_swap_mir(tmp_path):
    """Setting TRITON_SWAP_MIR_ENABLE_MISCHED without TRITON_SWAP_MIR must fail.

    Runs the simple kernel script in a child interpreter with only the
    misched flag set and expects a non-zero exit whose stderr carries the
    "requires TRITON_SWAP_MIR" error text.
    """
    import os
    import subprocess
    import sys

    script_file = tmp_path / "test_kernel.py"
    script_file.write_text(_SIMPLE_KERNEL_SCRIPT)

    env = os.environ.copy()
    env["TRITON_SWAP_MIR_ENABLE_MISCHED"] = "1"
    env["TRITON_ALWAYS_COMPILE"] = "1"
    # TRITON_SWAP_MIR must really be absent: the copied environment may carry
    # it from the caller's shell, which would make the child take the swap
    # path instead of raising.
    env.pop("TRITON_SWAP_MIR", None)

    # Launch with the interpreter running the tests, not whatever "python"
    # resolves to on PATH.
    result = subprocess.run([sys.executable, str(script_file)], capture_output=True, text=True, env=env, timeout=120)
    assert result.returncode != 0
    assert "TRITON_SWAP_MIR_ENABLE_MISCHED requires TRITON_SWAP_MIR" in result.stderr
2 changes: 2 additions & 0 deletions python/triton/knobs.py
Original file line number Diff line number Diff line change
Expand Up @@ -521,6 +521,8 @@ class amd_knobs(base_knobs):
dump_mir: env_opt_str = env_opt_str("TRITON_DUMP_MIR")
# Path to externally-provided MIR files to use instead of generated ones
swap_mir: env_opt_str = env_opt_str("TRITON_SWAP_MIR")
# Enable machine instruction scheduler in MIR swap mode
swap_mir_enable_misched: env_bool = env_bool("TRITON_SWAP_MIR_ENABLE_MISCHED", False)


class proton_knobs(base_knobs):
Expand Down
8 changes: 5 additions & 3 deletions third_party/amd/backend/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -473,10 +473,12 @@ def make_amdgcn(src, metadata, options):
dump_file_id)
llvm.dump_sched_dag(src, amd.TARGET_TRIPLE, options.arch, features, flags, options.enable_fp_fusion,
dump_file_id)
if knobs.amd.swap_mir_enable_misched and not knobs.amd.swap_mir:
raise ValueError("TRITON_SWAP_MIR_ENABLE_MISCHED requires TRITON_SWAP_MIR to be set")
if knobs.amd.swap_mir:
amdgcn = llvm.translate_mir_to_asm(os.path.join(knobs.amd.swap_mir,
dump_file_id + '.txt'), amd.TARGET_TRIPLE, options.arch,
features, flags, options.enable_fp_fusion, False)
amdgcn = llvm.translate_mir_to_asm(os.path.join(knobs.amd.swap_mir, dump_file_id + '.txt'),
amd.TARGET_TRIPLE, options.arch, features, flags,
options.enable_fp_fusion, False, knobs.amd.swap_mir_enable_misched)
else:
amdgcn = llvm.translate_to_asm(src, amd.TARGET_TRIPLE, options.arch, features, flags,
options.enable_fp_fusion, False)
Expand Down
Loading