Skip to content

Commit e3e8acf

Browse files
authored
Merge pull request #56 from amd/add-target-to-npu-config
Add target to npu_config (fixes #54)
2 parents fa84eb1 + 9343ed8 commit e3e8acf

2 files changed

Lines changed: 55 additions & 7 deletions

File tree

amd_triton_npu/backend/config.py

Lines changed: 49 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,17 +16,18 @@
1616
# Direct attribute assignment
1717
npu_config.bf16_emulation = True
1818
npu_config.compile_only = True
19+
npu_config.target = "npu2" # cross-compile for npu2
1920
2021
# Or dict-style
21-
set_config(bf16_emulation=True, compile_only=False)
22+
set_config(bf16_emulation=True, compile_only=False, target="npu1")
2223
2324
# Temporary overrides via context manager
2425
from triton.backends.amd_triton_npu.config import config_context
25-
with config_context(compile_only=True):
26+
with config_context(compile_only=True, target="npu2"):
2627
kernel[grid](a, b, c)
2728
2829
# Environment variables still work as fallback
29-
# AMD_TRITON_NPU_BF16_EMULATION=1 python my_script.py
30+
# AMD_TRITON_NPU_TARGET=npu2 python my_script.py
3031
"""
3132

3233
import contextlib
@@ -36,6 +37,8 @@
3637

3738
_UNSET = object()
3839

40+
_VALID_TARGETS = frozenset(("npu1", "npu2"))
41+
3942

4043
class _NPUConfig:
4144
"""Process-global configuration for the NPU backend.
@@ -52,6 +55,7 @@ def __init__(self):
5255
self._output_format = _UNSET
5356
self._air_project_path = _UNSET
5457
self._debug = _UNSET
58+
self._target = _UNSET
5559

5660
# ---- compile_only ----
5761

@@ -175,6 +179,46 @@ def debug(self, value: bool):
175179
_drv = logging.getLogger("triton.backends.amd_triton_npu.driver")
176180
_drv.setLevel(logging.DEBUG if self._debug else logging.CRITICAL)
177181

182+
# ---- target ----
183+
184+
@property
185+
def target(self):
186+
"""Force the NPU target to ``"npu1"`` or ``"npu2"``.
187+
188+
When set, ``detect_npu_version()`` uses this value instead of
189+
querying hardware via xrt-smi. This enables cross-compilation
190+
without local NPU hardware.
191+
192+
Set to ``None`` for auto-detection from installed hardware.
193+
194+
Env var fallback: ``AMD_TRITON_NPU_TARGET``. If the environment
195+
variable is set to a non-empty unsupported value, a ``ValueError``
196+
is raised.
197+
"""
198+
if self._target is not _UNSET:
199+
return self._target
200+
v = os.getenv("AMD_TRITON_NPU_TARGET", "")
201+
if not v:
202+
return None
203+
v = v.lower()
204+
if v not in _VALID_TARGETS:
205+
raise ValueError(
206+
f"AMD_TRITON_NPU_TARGET must be one of {sorted(_VALID_TARGETS)} "
207+
f"or empty/unset; got {v!r}"
208+
)
209+
return v
210+
211+
@target.setter
212+
def target(self, value):
213+
if value is not None:
214+
value = value.lower()
215+
if value not in _VALID_TARGETS:
216+
raise ValueError(
217+
f"target must be one of {sorted(_VALID_TARGETS)} or None; "
218+
f"got {value!r}"
219+
)
220+
self._target = value
221+
178222
# ---- utilities ----
179223

180224
def reset(self):
@@ -185,6 +229,7 @@ def reset(self):
185229
self._output_format = _UNSET
186230
self._air_project_path = _UNSET
187231
self._debug = _UNSET
232+
self._target = _UNSET
188233

189234

190235
# Module-level singleton
@@ -207,6 +252,7 @@ def set_config(**kwargs):
207252
"bf16_emulation",
208253
"output_format",
209254
"air_project_path",
255+
"target",
210256
}
211257
for key, value in kwargs.items():
212258
if key not in valid_keys:

amd_triton_npu/backend/driver.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -190,15 +190,17 @@ def get_npu_device_info():
190190
def detect_npu_version():
191191
"""Map known device names to internal NPU version strings.
192192
193-
If AMD_TRITON_NPU_TARGET is set, use that value directly
193+
If ``npu_config.target`` is set (programmatically or via the
194+
``AMD_TRITON_NPU_TARGET`` env var), use that value directly
194195
(must be 'npu1' or 'npu2'). This enables cross-compilation
195196
without local NPU hardware.
196197
"""
197-
target = os.getenv("AMD_TRITON_NPU_TARGET", "").lower()
198-
if target:
198+
target = npu_config.target
199+
if target is not None:
199200
if target not in NPU_MODELS:
200201
raise RuntimeError(
201-
f"Invalid AMD_TRITON_NPU_TARGET='{target}'. "
202+
f"Invalid target='{target}' from npu_config.target "
203+
f"(or AMD_TRITON_NPU_TARGET). "
202204
f"Supported values: {list(NPU_MODELS.keys())}"
203205
)
204206
return target

0 commit comments

Comments
 (0)