Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions include/aie-c/TargetModel.h
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,11 @@ MLIR_CAPI_EXPORTED int aieTargetModelRows(AieTargetModel targetModel);
/// Returns true if this is an NPU target model.
MLIR_CAPI_EXPORTED bool aieTargetModelIsNPU(AieTargetModel targetModel);

/// Returns the AIE architecture (as the underlying value of xilinx::AIE::AIEArch:
/// AIE1=1, AIE2=2, AIE2p=3).
MLIR_CAPI_EXPORTED uint32_t
aieTargetModelGetTargetArch(AieTargetModel targetModel);

/// Returns the tile type for the given coordinates.
MLIR_CAPI_EXPORTED uint32_t
aieTargetModelGetTileType(AieTargetModel targetModel, int col, int row);
Expand Down
4 changes: 4 additions & 0 deletions lib/CAPI/TargetModel.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -201,6 +201,10 @@ bool aieTargetModelIsNPU(AieTargetModel targetModel) {
return unwrap(targetModel).hasProperty(xilinx::AIE::AIETargetModel::IsNPU);
}

uint32_t aieTargetModelGetTargetArch(AieTargetModel targetModel) {
return static_cast<uint32_t>(unwrap(targetModel).getTargetArch());
}

uint32_t aieTargetModelGetColumnShift(AieTargetModel targetModel) {
return unwrap(targetModel).getColumnShift();
}
Expand Down
4 changes: 4 additions & 0 deletions python/AIEMLIRModule.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -358,6 +358,10 @@ NB_MODULE(_aie, m) {
[](PyAieTargetModel &self) {
return aieTargetModelIsNPU(self.get());
})
.def("get_target_arch",
[](PyAieTargetModel &self) {
return aieTargetModelGetTargetArch(self.get());
})
.def("get_column_shift",
[](PyAieTargetModel &self) {
return aieTargetModelGetColumnShift(self.get());
Expand Down
7 changes: 6 additions & 1 deletion python/iron/device/device.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from typing import Generator

from ... import ir # type: ignore
from ...dialects._aie_enum_gen import AIETileType, WireBundle # type: ignore
from ...dialects._aie_enum_gen import AIEArch, AIETileType, WireBundle # type: ignore
from ...dialects.aie import (
AIEDevice,
logical_tile,
Expand Down Expand Up @@ -45,6 +45,11 @@ def rows(self) -> int:
"""Number of rows in the device tile array."""
return self._tm.rows()

@property
def arch(self) -> AIEArch:
"""AIE architecture of the device (AIE1, AIE2, or AIE2p)."""
return AIEArch(self._tm.get_target_arch())

def _validate_coordinates(self, col, row):
"""Raise ValueError if coordinates are outside the device grid."""
if col < 0 or col >= self._tm.columns() or row < 0 or row >= self._tm.rows():
Expand Down
21 changes: 11 additions & 10 deletions python/utils/jit.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
from aie.extras.context import mlir_mod_ctx
from .compile import compile_mlir_module, compile_external_kernel
from .npukernel import NPUKernel
from aie.dialects.aie import AIEDevice
from .compile.cache.circular_cache import CircularCache
from .compile.cache.utils import _create_function_cache_key, file_lock
from .compile import NPU_CACHE_HOME
Expand Down Expand Up @@ -48,8 +47,10 @@ def jit(function=None, use_cache=True):

@functools.wraps(function)
def decorator(*args, **kwargs):
from aie.iron.device import NPU1, NPU2, NPU1Col1, NPU2Col1
from aie.iron.device import Device
from aie.iron.kernel import ExternalFunction
from aie.dialects._aie_enum_gen import AIEArch
from aie.dialects.aie import get_target_model
from . import DefaultNPURuntime

if DefaultNPURuntime is None:
Expand Down Expand Up @@ -118,17 +119,17 @@ def decorator(*args, **kwargs):

current_device = DefaultNPURuntime.device()

# Determine target architecture based on device type
if isinstance(current_device, (NPU2, NPU2Col1)):
target_arch = "aie2p"
elif isinstance(current_device, (NPU1, NPU1Col1)):
target_arch = "aie2"
elif current_device in (AIEDevice.npu2, AIEDevice.npu2_1col):
# Determine target architecture from the device's target model.
if isinstance(current_device, Device):
arch = current_device.arch
else:
arch = AIEArch(get_target_model(current_device).get_target_arch())
if arch == AIEArch.AIE2p:
target_arch = "aie2p"
elif current_device in (AIEDevice.npu1, AIEDevice.npu1_1col):
elif arch == AIEArch.AIE2:
target_arch = "aie2"
else:
raise RuntimeError(f"Unsupported device type: {type(current_device)}")
raise RuntimeError(f"Unsupported device arch: {arch}")

# Hash of the IR string, ExternalFunction compiler options, and target architecture
module_hash = hash_module(mlir_module, external_kernels, target_arch)
Expand Down
Loading