Skip to content
Open
68 changes: 66 additions & 2 deletions amd_triton_npu/backend/driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -410,7 +410,57 @@ def _replace_include(m):
return result


def _get_transform_ir_string():
def _detect_element_type(ir_str):
"""Detect the primary element type from the Linalg IR function signature.

Scans memref types in the first func.func line for the element type.
Returns the MLIR type string (e.g., "bf16", "f32", "i8", "i16").
Falls back to "bf16" if detection fails.
"""
import re

# Match memref<...xTYPE> in the function signature
Comment thread
erwei-xilinx marked this conversation as resolved.
Outdated
match = re.search(r"memref<[^>]*x(\w+)>", ir_str)
if match:
return match.group(1)
return "bf16"


# Dtype-aware placeholder info: padding value and default vector size per NPU.
_DTYPE_PLACEHOLDER_INFO = {
"bf16": {"pad_val": "0.0 : bf16", "vector_size": {"npu1": 16, "npu2": 32}},
"f32": {"pad_val": "0.0 : f32", "vector_size": {"npu1": 16, "npu2": 16}},
"i8": {"pad_val": "0 : i8", "vector_size": {"npu1": 32, "npu2": 32}},
"i16": {"pad_val": "0 : i16", "vector_size": {"npu1": 32, "npu2": 32}},
"i32": {"pad_val": "0 : i32", "vector_size": {"npu1": 16, "npu2": 16}},
}


def _substitute_dtype_placeholders(script, dtype, npu_version):
"""Substitute dtype-aware placeholders in a transform script.

Replaces @DTYPE@, @PAD_VAL@, and @VECTOR_SIZE@ with values derived
from the detected element type and target NPU version.
No-op if the script contains no placeholders (backward compatible).
"""
if (
"@DTYPE@" not in script
and "@PAD_VAL@" not in script
and "@VECTOR_SIZE@" not in script
):
return script
info = _DTYPE_PLACEHOLDER_INFO.get(dtype)
if info is None:
return script
script = script.replace("@DTYPE@", dtype)
Comment thread
erwei-xilinx marked this conversation as resolved.
script = script.replace("@PAD_VAL@", info["pad_val"])
script = script.replace(
"@VECTOR_SIZE@", str(info["vector_size"].get(npu_version, 16))
)
return script


def _get_transform_ir_string(ir_str=None):
"""
Get the transform IR string for tiling operations.

Expand All @@ -421,6 +471,12 @@ def _get_transform_ir_string():
If the script uses `transform.include`, the shared transform library
(transform_library.mlir) is automatically injected.

If ir_str is provided, dtype-aware placeholders (@DTYPE@, @PAD_VAL@,
@VECTOR_SIZE@) are substituted before library injection.

Args:
ir_str: Optional Linalg IR string for dtype detection.

Returns:
str: The transform IR string to use for tiling
"""
Expand All @@ -436,6 +492,14 @@ def _get_transform_ir_string():
with open(custom_script_path, "r") as f:
print(f"Using custom tiling script from: {custom_script_path}")
user_script = f.read()
if ir_str is not None:
dtype = _detect_element_type(
ir_str if isinstance(ir_str, str) else str(ir_str)
)
npu_version = detect_npu_version()
Comment thread
erwei-xilinx marked this conversation as resolved.
Outdated
user_script = _substitute_dtype_placeholders(
user_script, dtype, npu_version
)
return _inject_transform_library(user_script)

# Default hardcoded transform IR string
Expand Down Expand Up @@ -493,7 +557,7 @@ def _ttshared_to_air(mod, gridX, gridY, gridZ, actual_sizes=None):
pm = air.passmanager.PassManager.parse(pipeline, context=air_context)
pm.run(air_module.operation)
# MLIR-AIR compilation step 2: tiling the launch body
transform_ir_string = _get_transform_ir_string()
transform_ir_string = _get_transform_ir_string(ir_str=mod)
transform_ir = Module.parse(transform_ir_string, context=air_context)
run_transform(transform_ir, air_module)
# MLIR-AIR compilation step 3: converting to AIR
Expand Down
94 changes: 94 additions & 0 deletions amd_triton_npu/backend/transform_library/elementwise.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -96,3 +96,97 @@ transform.named_sequence @pad_and_promote_binary_bf16(
{memory_space = 2, bufferize_destination_only, emit_dealloc} : !transform.any_op
transform.yield
}

// Binary variant for f32: 2 inputs + 1 output = 3 operands.
// Used with bf16-emulation (f32 data, bf16 compute on AIE cores).
transform.named_sequence @pad_and_promote_binary_f32(
%module: !transform.any_op {transform.readonly}) {
%op = transform.structured.match ops{["linalg.generic"]} in %module
: (!transform.any_op) -> !transform.any_op
%padded_op, %pad_op, %__ = transform.structured.pad %op {
padding_values=[0.0 : f32, 0.0 : f32, 0.0 : f32],
padding_dimensions=[0, 1, 2],
nofold_flags=[1, 1, 1],
copy_back_op="linalg.copy"
} : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op)
%pad_dps = transform.structured.rewrite_in_destination_passing_style %pad_op
: (!transform.any_op) -> !transform.any_op
%padded_lhs = transform.get_producer_of_operand %padded_op[0]
: (!transform.any_op) -> (!transform.any_op)
%padded_lhs_buffer, %padded_lhs_new =
transform.structured.bufferize_to_allocation %padded_lhs
{memory_space = 2, bufferize_destination_only, emit_dealloc} : !transform.any_op
%padded_rhs = transform.get_producer_of_operand %padded_op[1]
: (!transform.any_op) -> (!transform.any_op)
%padded_rhs_buffer, %padded_rhs_new =
transform.structured.bufferize_to_allocation %padded_rhs
{memory_space = 2, bufferize_destination_only, emit_dealloc} : !transform.any_op
%padded_result = transform.get_producer_of_operand %padded_op[2]
: (!transform.any_op) -> (!transform.any_op)
%padded_result_buffer, %padded_result_new =
transform.structured.bufferize_to_allocation %padded_result
{memory_space = 2, bufferize_destination_only, emit_dealloc} : !transform.any_op
transform.yield
}

// Binary variant for i8: 2 inputs + 1 output = 3 operands.
transform.named_sequence @pad_and_promote_binary_i8(
%module: !transform.any_op {transform.readonly}) {
%op = transform.structured.match ops{["linalg.generic"]} in %module
: (!transform.any_op) -> !transform.any_op
%padded_op, %pad_op, %__ = transform.structured.pad %op {
padding_values=[0 : i8, 0 : i8, 0 : i8],
padding_dimensions=[0, 1, 2],
nofold_flags=[1, 1, 1],
copy_back_op="linalg.copy"
} : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op)
%pad_dps = transform.structured.rewrite_in_destination_passing_style %pad_op
: (!transform.any_op) -> !transform.any_op
%padded_lhs = transform.get_producer_of_operand %padded_op[0]
: (!transform.any_op) -> (!transform.any_op)
%padded_lhs_buffer, %padded_lhs_new =
transform.structured.bufferize_to_allocation %padded_lhs
{memory_space = 2, bufferize_destination_only, emit_dealloc} : !transform.any_op
%padded_rhs = transform.get_producer_of_operand %padded_op[1]
: (!transform.any_op) -> (!transform.any_op)
%padded_rhs_buffer, %padded_rhs_new =
transform.structured.bufferize_to_allocation %padded_rhs
{memory_space = 2, bufferize_destination_only, emit_dealloc} : !transform.any_op
%padded_result = transform.get_producer_of_operand %padded_op[2]
: (!transform.any_op) -> (!transform.any_op)
%padded_result_buffer, %padded_result_new =
transform.structured.bufferize_to_allocation %padded_result
{memory_space = 2, bufferize_destination_only, emit_dealloc} : !transform.any_op
transform.yield
}

// Binary variant for i16: 2 inputs + 1 output = 3 operands.
transform.named_sequence @pad_and_promote_binary_i16(
%module: !transform.any_op {transform.readonly}) {
%op = transform.structured.match ops{["linalg.generic"]} in %module
: (!transform.any_op) -> !transform.any_op
%padded_op, %pad_op, %__ = transform.structured.pad %op {
padding_values=[0 : i16, 0 : i16, 0 : i16],
padding_dimensions=[0, 1, 2],
nofold_flags=[1, 1, 1],
copy_back_op="linalg.copy"
} : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op)
%pad_dps = transform.structured.rewrite_in_destination_passing_style %pad_op
: (!transform.any_op) -> !transform.any_op
%padded_lhs = transform.get_producer_of_operand %padded_op[0]
: (!transform.any_op) -> (!transform.any_op)
%padded_lhs_buffer, %padded_lhs_new =
transform.structured.bufferize_to_allocation %padded_lhs
{memory_space = 2, bufferize_destination_only, emit_dealloc} : !transform.any_op
%padded_rhs = transform.get_producer_of_operand %padded_op[1]
: (!transform.any_op) -> (!transform.any_op)
%padded_rhs_buffer, %padded_rhs_new =
transform.structured.bufferize_to_allocation %padded_rhs
{memory_space = 2, bufferize_destination_only, emit_dealloc} : !transform.any_op
%padded_result = transform.get_producer_of_operand %padded_op[2]
: (!transform.any_op) -> (!transform.any_op)
%padded_result_buffer, %padded_result_new =
transform.structured.bufferize_to_allocation %padded_result
{memory_space = 2, bufferize_destination_only, emit_dealloc} : !transform.any_op
transform.yield
}
10 changes: 6 additions & 4 deletions examples/vec-add/transform_aie2.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,10 @@
////////////////////////////////////////////////////////////////////////////////
// Transform Script for Vector Addition (AIE2)
// Simple elementwise add: out = a + b
// Binary op (2 inputs + 1 output). No fusion needed. Vec tile = 16 (AIE2).
// No type casts needed (bf16 add is native).
// Binary op (2 inputs + 1 output). No fusion needed.
// No type casts needed (bf16/i8/i16 add is native; f32 uses bf16-emulation).
// Dtype-generic: uses @DTYPE@ and @VECTOR_SIZE@ placeholders substituted
// by the driver based on the IR element type and NPU version.
// Uses shared library sequences from transform_library.mlir (auto-injected).
////////////////////////////////////////////////////////////////////////////////

Expand All @@ -18,7 +20,7 @@ module attributes {transform.with_named_sequence} {
(%arg1) : (!transform.any_op) -> ()
transform.include @canonicalize_with_cse failures(propagate)
(%arg1) : (!transform.any_op) -> ()
transform.include @pad_and_promote_binary_bf16 failures(propagate)
transform.include @pad_and_promote_binary_@DTYPE@ failures(propagate)
(%arg1) : (!transform.any_op) -> ()
transform.include @canonicalize_with_cse failures(propagate)
(%arg1) : (!transform.any_op) -> ()
Expand All @@ -27,7 +29,7 @@ module attributes {transform.with_named_sequence} {
transform.include @post_bufferize_cleanup failures(propagate)
(%arg1) : (!transform.any_op) -> ()

transform.include @vectorize_generics_at_16 failures(propagate)
transform.include @vectorize_generics_at_@VECTOR_SIZE@ failures(propagate)
(%arg1) : (!transform.any_op) -> ()
%vh = transform.include @air_herd_mapping_and_vectorize
failures(propagate) (%arg1) : (!transform.any_op) -> !transform.any_op
Expand Down
10 changes: 6 additions & 4 deletions examples/vec-add/transform_aie2p.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,10 @@
////////////////////////////////////////////////////////////////////////////////
// Transform Script for Vector Addition (AIE2P)
// Simple elementwise add: out = a + b
// Binary op (2 inputs + 1 output). No fusion needed. Vec tile = 32 (AIE2P).
// No type casts needed (bf16 add is native).
// Binary op (2 inputs + 1 output). No fusion needed.
// No type casts needed (bf16/i8/i16 add is native; f32 uses bf16-emulation).
// Dtype-generic: uses @DTYPE@ and @VECTOR_SIZE@ placeholders substituted
// by the driver based on the IR element type and NPU version.
// Uses shared library sequences from transform_library.mlir (auto-injected).
////////////////////////////////////////////////////////////////////////////////

Expand All @@ -18,7 +20,7 @@ module attributes {transform.with_named_sequence} {
(%arg1) : (!transform.any_op) -> ()
transform.include @canonicalize_with_cse failures(propagate)
(%arg1) : (!transform.any_op) -> ()
transform.include @pad_and_promote_binary_bf16 failures(propagate)
transform.include @pad_and_promote_binary_@DTYPE@ failures(propagate)
(%arg1) : (!transform.any_op) -> ()
transform.include @canonicalize_with_cse failures(propagate)
(%arg1) : (!transform.any_op) -> ()
Expand All @@ -27,7 +29,7 @@ module attributes {transform.with_named_sequence} {
transform.include @post_bufferize_cleanup failures(propagate)
(%arg1) : (!transform.any_op) -> ()

transform.include @vectorize_generics_at_32 failures(propagate)
transform.include @vectorize_generics_at_@VECTOR_SIZE@ failures(propagate)
(%arg1) : (!transform.any_op) -> ()
%vh = transform.include @air_herd_mapping_and_vectorize
failures(propagate) (%arg1) : (!transform.any_op) -> !transform.any_op
Expand Down
Loading
Loading