Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
52 changes: 49 additions & 3 deletions amd_triton_npu/backend/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,17 +16,18 @@
# Direct attribute assignment
npu_config.bf16_emulation = True
npu_config.compile_only = True
npu_config.target = "npu2" # cross-compile for npu2

# Or dict-style
set_config(bf16_emulation=True, compile_only=False)
set_config(bf16_emulation=True, compile_only=False, target="npu1")

# Temporary overrides via context manager
from triton.backends.amd_triton_npu.config import config_context
with config_context(compile_only=True):
with config_context(compile_only=True, target="npu2"):
kernel[grid](a, b, c)

# Environment variables still work as fallback
# AMD_TRITON_NPU_BF16_EMULATION=1 python my_script.py
# AMD_TRITON_NPU_TARGET=npu2 python my_script.py
"""

import contextlib
Expand All @@ -36,6 +37,8 @@

_UNSET = object()

_VALID_TARGETS = frozenset(("npu1", "npu2"))


class _NPUConfig:
"""Process-global configuration for the NPU backend.
Expand All @@ -52,6 +55,7 @@ def __init__(self):
self._output_format = _UNSET
self._air_project_path = _UNSET
self._debug = _UNSET
self._target = _UNSET

# ---- compile_only ----

Expand Down Expand Up @@ -175,6 +179,46 @@ def debug(self, value: bool):
_drv = logging.getLogger("triton.backends.amd_triton_npu.driver")
_drv.setLevel(logging.DEBUG if self._debug else logging.CRITICAL)

# ---- target ----

@property
def target(self):
    """Return the forced NPU target, or ``None`` for auto-detection.

    The result is ``"npu1"``, ``"npu2"``, or ``None``. A non-``None``
    value makes ``detect_npu_version()`` use it instead of querying
    hardware via xrt-smi, which enables cross-compilation without
    local NPU hardware.

    Resolution order:

    1. A value assigned programmatically (including an explicit
       ``None``) is returned as-is.
    2. Otherwise the ``AMD_TRITON_NPU_TARGET`` environment variable is
       read and matched case-insensitively. An unset or empty variable
       yields ``None``.

    Raises:
        ValueError: If ``AMD_TRITON_NPU_TARGET`` is set to a non-empty
            value that is not a supported target.
    """
    override = self._target
    if override is not _UNSET:
        return override

    # Lowercasing first is safe: an empty string stays empty.
    env_target = os.getenv("AMD_TRITON_NPU_TARGET", "").lower()
    if env_target == "":
        return None
    if env_target in _VALID_TARGETS:
        return env_target
    raise ValueError(
        f"AMD_TRITON_NPU_TARGET must be one of {sorted(_VALID_TARGETS)} "
        f"or empty/unset; got {env_target!r}"
    )

@target.setter
def target(self, value):
    """Set or clear the forced NPU target.

    Args:
        value: ``"npu1"`` or ``"npu2"`` (matched case-insensitively and
            stored lowercased), or ``None``. ``None`` is stored as an
            explicit value, so the getter returns ``None`` without
            consulting the ``AMD_TRITON_NPU_TARGET`` environment variable.

    Raises:
        ValueError: If ``value`` is neither ``None`` nor a supported
            target name. This includes non-string values.
    """
    if value is not None:
        is_text = isinstance(value, str)
        if is_text:
            value = value.lower()
        # Reject non-strings here so callers get the same ValueError as
        # for an unknown name, rather than an AttributeError from
        # ``.lower()`` (or a TypeError from hashing an unhashable value).
        if not is_text or value not in _VALID_TARGETS:
            raise ValueError(
                f"target must be one of {sorted(_VALID_TARGETS)} or None; "
                f"got {value!r}"
            )
    self._target = value

# ---- utilities ----

def reset(self):
Expand All @@ -185,6 +229,7 @@ def reset(self):
self._output_format = _UNSET
self._air_project_path = _UNSET
self._debug = _UNSET
self._target = _UNSET


# Module-level singleton
Expand All @@ -207,6 +252,7 @@ def set_config(**kwargs):
"bf16_emulation",
"output_format",
"air_project_path",
"target",
}
for key, value in kwargs.items():
if key not in valid_keys:
Expand Down
10 changes: 6 additions & 4 deletions amd_triton_npu/backend/driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,15 +190,17 @@ def get_npu_device_info():
def detect_npu_version():
"""Map known device names to internal NPU version strings.

If AMD_TRITON_NPU_TARGET is set, use that value directly
If ``npu_config.target`` is set (programmatically or via the
``AMD_TRITON_NPU_TARGET`` env var), use that value directly
(must be 'npu1' or 'npu2'). This enables cross-compilation
without local NPU hardware.
"""
Comment thread
erwei-xilinx marked this conversation as resolved.
target = os.getenv("AMD_TRITON_NPU_TARGET", "").lower()
if target:
target = npu_config.target
if target is not None:
if target not in NPU_MODELS:
raise RuntimeError(
f"Invalid AMD_TRITON_NPU_TARGET='{target}'. "
f"Invalid target='{target}' from npu_config.target "
f"(or AMD_TRITON_NPU_TARGET). "
f"Supported values: {list(NPU_MODELS.keys())}"
)
return target
Expand Down
Loading