Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 27 additions & 10 deletions amd_triton_npu/backend/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from typing import Any, Dict, Tuple
from types import ModuleType
import hashlib
import sys
import tempfile
import os
import re
Expand Down Expand Up @@ -57,16 +58,32 @@ def _ttir_to_ttsharedir(mod):
dst_path = os.path.join(tmpdir, "ttshared.mlir")
Path(src_path).write_text(ttir_code)
amd_triton_npu_opt_path = _get_amd_triton_npu_opt_path()
subprocess.check_call(
[
amd_triton_npu_opt_path,
src_path,
"--triton-to-linalg-experimental",
"--mlir-print-debuginfo",
"-o",
dst_path,
]
)
cmd = [
amd_triton_npu_opt_path,
src_path,
"--triton-to-linalg-experimental",
"--mlir-print-debuginfo",
"-o",
dst_path,
]
if npu_config.debug:
subprocess.check_call(cmd)
else:
result = subprocess.run(
cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT
)
if result.returncode != 0:
if result.stdout:
stderr_buf = getattr(sys.stderr, "buffer", None)
if stderr_buf is not None:
stderr_buf.write(result.stdout)
else:
sys.stderr.write(
result.stdout.decode("utf-8", errors="replace")
)
raise subprocess.CalledProcessError(
result.returncode, cmd, output=result.stdout
)
_dump_ir_if_needed([src_path])
return Path(dst_path).read_text()

Expand Down
24 changes: 24 additions & 0 deletions amd_triton_npu/backend/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
"""

import contextlib
import logging
import os
from pathlib import Path

Expand All @@ -50,6 +51,7 @@ def __init__(self):
self._bf16_emulation = _UNSET
self._output_format = _UNSET
self._air_project_path = _UNSET
self._debug = _UNSET

# ---- compile_only ----

Expand Down Expand Up @@ -153,6 +155,26 @@ def air_project_path(self, value):
return
self._air_project_path = value

# ---- debug ----

@property
def debug(self) -> bool:
    """Whether verbose output from subprocesses and the C++ launcher is on.

    A value assigned through the setter takes precedence. While nothing
    has been assigned (the stored value is still ``_UNSET``), the answer
    comes from the ``AMD_TRITON_NPU_DEBUG`` environment variable, and
    only the exact string ``"1"`` enables it.
    """
    if self._debug is _UNSET:
        env_value = os.getenv("AMD_TRITON_NPU_DEBUG", "0")
        return env_value == "1"
    return self._debug

@debug.setter
def debug(self, value: bool):
    """Store ``value`` coerced to ``bool`` and update the driver logger level."""
    enabled = bool(value)
    self._debug = enabled
    # A programmatic toggle has to reach the driver's logger as well,
    # otherwise its logger.debug() calls would stay at the old level.
    if enabled:
        level = logging.DEBUG
    else:
        level = logging.CRITICAL
    driver_logger = logging.getLogger("triton.backends.amd_triton_npu.driver")
    driver_logger.setLevel(level)

# ---- utilities ----

def reset(self):
Expand All @@ -162,6 +184,7 @@ def reset(self):
self._bf16_emulation = _UNSET
self._output_format = _UNSET
self._air_project_path = _UNSET
self._debug = _UNSET


# Module-level singleton
Expand All @@ -179,6 +202,7 @@ def set_config(**kwargs):
"""
valid_keys = {
"compile_only",
"debug",
"transform_tiling_script",
"bf16_emulation",
"output_format",
Expand Down
56 changes: 49 additions & 7 deletions amd_triton_npu/backend/driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@

logger = logging.getLogger(__name__)
logger.setLevel(logging.CRITICAL)
if os.getenv("AMD_TRITON_NPU_DEBUG", "0") == "1":
if npu_config.debug:
logger.setLevel(logging.DEBUG)
if not logger.handlers:
Comment thread
erwei-xilinx marked this conversation as resolved.
_handler = logging.StreamHandler()
Expand Down Expand Up @@ -606,7 +606,7 @@ def _generate_launcher(constants, signature, kernel_name):
std::vector<uint32_t> instr_v =
test_utils::load_instr_binary(insts_path);

int verbosity = 1;
int verbosity = {1 if npu_config.debug else 0};
if (verbosity >= 1)
std::cout << "Sequence instr count: " << instr_v.size() << std::endl;

Expand All @@ -627,9 +627,9 @@ def _generate_launcher(constants, signature, kernel_name):
// Get the kernel from the xclbin
auto xkernels = xclbin.get_kernels();
auto xkernel = *std::find_if(xkernels.begin(), xkernels.end(),
[Node](xrt::xclbin::kernel &k) {{
[Node, verbosity](xrt::xclbin::kernel &k) {{
auto name = k.get_name();
std::cout << "Name: " << name << std::endl;
if (verbosity >= 1) std::cout << "Name: " << name << std::endl;
return name.rfind(Node, 0) == 0;
}});
auto kernelName = xkernel.get_name();
Expand Down Expand Up @@ -939,7 +939,7 @@ def _generate_elf_launcher(constants, signature, kernel_name):
static void _launch(int gridX, int gridY, int gridZ, {', '.join(f"long size{i}" for i, ty in ptr_args)}, {arg_decls}) {{
if (gridX*gridY*gridZ > 0) {{

int verbosity = 1;
int verbosity = {1 if npu_config.debug else 0};

// Get a device handle
unsigned int device_index = 0;
Expand Down Expand Up @@ -1275,7 +1275,28 @@ def launch(
"-ltest_utils",
]
compile_flags += ["-o", so_path]
subprocess.check_call(compile_flags)
if npu_config.debug:
subprocess.check_call(compile_flags)
else:
result = subprocess.run(
compile_flags,
stdout=subprocess.PIPE,
stderr=subprocess.STDOUT,
)
if result.returncode != 0:
if result.stdout:
stderr_buf = getattr(sys.stderr, "buffer", None)
if stderr_buf is not None:
stderr_buf.write(result.stdout)
else:
sys.stderr.write(
result.stdout.decode("utf-8", errors="replace")
)
raise subprocess.CalledProcessError(
result.returncode,
compile_flags,
output=result.stdout,
)
Comment thread
erwei-xilinx marked this conversation as resolved.

###### Compile to binary (ELF or xclbin + insts)
air_mlir_path = os.path.join(air_proj_path, "asm_air_output.mlir")
Expand Down Expand Up @@ -1324,7 +1345,28 @@ def launch(
# default changed from [4,4] to [] in mlir-air #1470).
aircc_cmd.insert(-1, "--air-runtime-loop-tiling-sizes=4")
aircc_cmd.insert(-1, "--air-runtime-loop-tiling-sizes=4")
subprocess.check_call(aircc_cmd)
if npu_config.debug:
subprocess.check_call(aircc_cmd)
else:
result = subprocess.run(
aircc_cmd,
stdout=subprocess.PIPE,
stderr=subprocess.STDOUT,
)
if result.returncode != 0:
if result.stdout:
stderr_buf = getattr(sys.stderr, "buffer", None)
if stderr_buf is not None:
stderr_buf.write(result.stdout)
else:
sys.stderr.write(
result.stdout.decode("utf-8", errors="replace")
)
raise subprocess.CalledProcessError(
result.returncode,
aircc_cmd,
output=result.stdout,
)
Comment thread
erwei-xilinx marked this conversation as resolved.

# Cache format-specific artifacts first, then the .so last.
# This avoids partial cache entries if aircc or kernel name
Expand Down
Loading