-
Notifications
You must be signed in to change notification settings - Fork 6
Expand file tree
/
Copy pathdriver.py
More file actions
1667 lines (1415 loc) · 57.1 KB
/
Copy pathdriver.py
File metadata and controls
1667 lines (1415 loc) · 57.1 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
# Copyright (C) 2026, Advanced Micro Devices, Inc. All rights reserved.
# SPDX-License-Identifier: MIT
import hashlib
import json
import tempfile
import sys
import sysconfig
import os, subprocess, tempfile, platform
import importlib.util
import shutil
from pathlib import Path
from triton.runtime.cache import get_cache_manager
from triton.backends.driver import DriverBase
from triton.backends.compiler import GPUTarget
import aie.compiler.aiecc.main as aiecc
import air.compiler.aircc.main as aircc
from air.compiler.util import run_transform
from air.ir import *
import air.passmanager
# When True, _generate_launcher emits C++ that times the kernel run and writes
# the elapsed microseconds to "data.txt" in the working directory.
autotune_time: bool = False
# -------------------- Launcher ----------------------------
def _ty_to_cpp(ty):
if ty[0] == "*":
return "void*"
if ty == "constexpr":
return "PyObject*"
return {
"i1": "int32_t",
"i8": "int8_t",
"i16": "int16_t",
"i32": "int32_t",
"i64": "int64_t",
"u1": "uint32_t",
"u8": "uint8_t",
"u16": "uint16_t",
"u32": "uint32_t",
"u64": "uint64_t",
"fp16": "float",
"bf16": "bfloat16",
"fp32": "float",
"f32": "float",
"fp64": "double",
}[ty]
def _extracted_type(ty):
if ty[0] == "*":
return "PyObject*"
if ty == "constexpr":
return "PyObject*"
return _ty_to_cpp(ty)
def _format_of(ty):
return {
"PyObject*": "O",
"constexpr": "O",
"float": "f",
"double": "d",
"long": "l",
"int8_t": "b",
"int16_t": "h",
"int32_t": "i",
"int64_t": "l",
"uint8_t": "B",
"uint16_t": "H",
"uint32_t": "I",
"uint64_t": "K",
}[ty]
def _get_air_opt_path() -> str:
    """Locate the ``air-opt`` binary shipped with the pip-installed mlir-air package.

    The aircc entry module sits at
    ``<mlir_air>/python/air/compiler/aircc/main.py``; the binary is expected
    at ``<mlir_air>/bin/air-opt``.

    Returns:
        str: Path to the ``air-opt`` binary.

    Raises:
        RuntimeError: If nothing exists at the expected location.
    """
    package_root = Path(aircc.__file__).resolve()
    # Five levels up: main.py -> aircc -> compiler -> air -> python -> <mlir_air>
    for _ in range(5):
        package_root = package_root.parent
    candidate = package_root / "bin" / "air-opt"
    if candidate.exists():
        return str(candidate)
    raise RuntimeError(f"Could not find air-opt binary at {candidate}")
def _get_xrt_path() -> str:
path = os.getenv("XILINX_XRT", "")
if path == "":
raise Exception("XILINX_XRT is not set. Is xrt installed in system?")
return path
def _get_aie_test_utils_path() -> Path:
    """Return mlir-aie's ``runtime_lib/x86_64/test_lib`` directory.

    The package root is taken to be five directory levels above the file of
    the imported ``aie.compiler.aiecc.main`` module. The returned path is not
    resolved, and its existence is not checked.

    Returns:
        Path: ``<package root>/runtime_lib/x86_64/test_lib``.
    """
    package_root = Path(aiecc.__file__)
    for _ in range(5):
        package_root = package_root.parent
    return package_root / "runtime_lib" / "x86_64" / "test_lib"
def _get_air_project_path() -> Path:
"""
Get the path for air_project directory.
If AMD_TRITON_NPU_AIR_PROJECT_PATH is set, use that path.
Otherwise, default to 'air_project' in the current working directory.
Returns:
Path: The path to the air_project directory
"""
custom_path = os.getenv("AMD_TRITON_NPU_AIR_PROJECT_PATH")
if custom_path:
return Path(custom_path)
return Path(os.getcwd()) / "air_project"
def _dump_ir_if_needed(files):
    """Copy every file in *files* into the air_project directory.

    Despite the name, the copy is unconditional. The destination comes from
    ``_get_air_project_path()`` and is created when missing. Each file keeps
    its basename, and an existing file with that name is overwritten.
    """
    destination_root = _get_air_project_path()
    os.makedirs(destination_root, exist_ok=True)
    for source in files:
        destination = os.path.join(destination_root, os.path.basename(source))
        shutil.copy(source, destination)
def get_npu_device_info():
    """List the NPU devices reported by ``xrt-smi examine``.

    Returns:
        list[dict]: One ``{"bdf": ..., "name": ...}`` entry per matching
        device row. If ``xrt-smi`` fails, or any other error occurs, a
        message is printed and an empty list is returned.
    """
    try:
        import re

        completed = subprocess.run(
            ["xrt-smi", "examine"],
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE,
            check=True,
            text=True,
        )
        # A device row looks like "[<bdf>] | <name> |" or "[<bdf>] || <name> |".
        row_re = re.compile(
            r"\[(?P<bdf>[0-9a-fA-F:.]+)\]\s*\|{1,2}\s*(?P<name>.+?)\s*\|"
        )
        return [
            {"bdf": bdf, "name": name.strip()}
            for bdf, name in row_re.findall(completed.stdout)
        ]
    except subprocess.CalledProcessError as err:
        print("Failed to run xrt-smi:", err.stderr)
        return []
    except Exception as err:
        print("Unexpected error:", str(err))
        return []
# Internal NPU version -> device-name keywords. detect_npu_version matches
# these case-insensitively as substrings of the names reported by xrt-smi.
# Aligned with mlir-aie (lit_config_helpers.py, iron_setup.py).
NPU_MODELS = dict(
    npu1=["npu1", "Phoenix"],
    npu2=["npu4", "Strix", "npu5", "Strix Halo", "npu6", "Krackan"],
)
def detect_npu_version():
    """Return the internal NPU version string for the host's device.

    Devices from ``get_npu_device_info`` are checked in order. The first
    device whose name contains (case-insensitively) a keyword listed in
    ``NPU_MODELS`` decides the version.

    Raises:
        RuntimeError: If no device is found, or none matches a known model.
    """
    devices = get_npu_device_info()
    if not devices:
        raise RuntimeError(
            "No NPU devices found. Ensure XRT is installed and xrt-smi is available."
        )
    for device in devices:
        lowered_name = device["name"].lower()
        for version, keywords in NPU_MODELS.items():
            if any(keyword.lower() in lowered_name for keyword in keywords):
                return version
    unsupported = [d["name"] for d in devices]
    raise RuntimeError(
        f"Unsupported NPU device(s): {unsupported}. "
        f"Supported models: {dict(NPU_MODELS)}"
    )
def _get_output_format():
    """Choose the compiled artifact format, ``"elf"`` or ``"xclbin"``.

    An ``AMD_TRITON_NPU_OUTPUT_FORMAT`` value of ``elf`` or ``xclbin``
    (case-insensitive) is honoured. Any other value is ignored and the
    format is auto-detected: ``"elf"`` on npu2, ``"xclbin"`` otherwise.
    The NPU is detected first in every case, so errors from
    ``detect_npu_version`` propagate.

    Raises:
        RuntimeError: If ELF output is requested on an npu1 (AIE2) device.
    """
    npu_version = detect_npu_version()
    requested = os.getenv("AMD_TRITON_NPU_OUTPUT_FORMAT", "").lower()
    if requested not in ("elf", "xclbin"):
        # Auto-detect: ELF for npu2, xclbin for npu1
        return "elf" if npu_version == "npu2" else "xclbin"
    if requested == "elf" and npu_version == "npu1":
        raise RuntimeError(
            "ELF output format is not supported on npu1 (AIE2) devices. "
            "Use 'xclbin' or unset AMD_TRITON_NPU_OUTPUT_FORMAT."
        )
    return requested
def _extract_elf_kernel_name(config_json_path):
"""Extract the ELF kernel name from full_elf_config.json.
The kernel name for XRT is "{kernel_name}:{instance_id}".
Looks for the "main" kernel entry (the runtime dispatch kernel)
and uses its first instance ID.
"""
with open(config_json_path) as f:
config = json.load(f)
for kernel in config["xrt-kernels"]:
if kernel["name"] == "main" and kernel.get("instance"):
instance_id = kernel["instance"][0]["id"]
return f"main:{instance_id}"
# Fallback: use the last kernel entry (which is typically "main")
last_kernel = config["xrt-kernels"][-1]
instance_id = last_kernel["instance"][0]["id"]
return f"{last_kernel['name']}:{instance_id}"
def _inject_transform_library(user_script):
"""
Process library references in a user transform script.
Two mechanisms:
1. transform.include calls are expanded inline (parameter substitution,
SSA renaming) to avoid segfaults in mlir-air's transform interpreter
when resolving transform.include across region boundaries.
2. foreach_match @name symbol references are resolved by injecting the
referenced named_sequence definitions into the module (these cannot
be inlined because foreach_match resolves symbols at runtime).
Args:
user_script: The user's transform script as a string
Returns:
str: The processed script
"""
has_includes = "transform.include" in user_script
has_foreach_match = "foreach_match" in user_script
if not has_includes and not has_foreach_match:
return user_script
# Load library content from transform_library/ directory
lib_dir = os.path.join(os.path.dirname(__file__), "transform_library")
if not os.path.isdir(lib_dir):
return user_script
parts = []
for fname in sorted(os.listdir(lib_dir)):
if fname.endswith(".mlir"):
with open(os.path.join(lib_dir, fname), "r") as f:
parts.append(f.read())
lib_content = "\n".join(parts)
import re
# Parse all named sequences: full text (for injection) and decomposed (for inlining)
full_seq_pattern = re.compile(
r"((?://[^\n]*\n)*"
r"transform\.named_sequence\s+@(\w+)\s*\([^)]*\)"
r"(?:\s*->\s*!transform\.any_op)?"
r"\s*\{.*?\n\})",
re.DOTALL,
)
full_sequences = {}
for m in full_seq_pattern.finditer(lib_content):
full_sequences[m.group(2)] = m.group(1)
# Parse inlinable sequences (readonly or consumed param, for transform.include)
inline_seq_pattern = re.compile(
r"transform\.named_sequence\s+@(\w+)\s*\(\s*"
r"%(\w+)\s*:\s*!transform\.any_op\s*\{transform\.(?:readonly|consumed)\}\s*\)"
r"(\s*->\s*!transform\.any_op)?"
r"\s*\{(.*?)\n\}",
re.DOTALL,
)
sequences = {}
for match in inline_seq_pattern.finditer(lib_content):
name = match.group(1)
param = match.group(2)
has_result = match.group(3) is not None
body = match.group(4)
sequences[name] = (param, body, has_result)
if not sequences and not full_sequences:
return user_script
# Inline transform.include calls to avoid mlir-air segfaults
include_pattern = re.compile(
r"(?:(%\w+)\s*=\s*)?"
r"transform\.include\s+@(\w+)\s+"
r"failures\(\w+\)\s*"
r"\((%\w+)\)\s*"
r":\s*\(!transform\.any_op\)\s*->\s*"
r"(?:!transform\.any_op|\(\s*\))"
)
_counter = [0]
def _expand(text, depth=0):
if depth > 20 or "transform.include" not in text:
return text
def _replace_include(m):
result_var = m.group(1)
seq_name = m.group(2)
actual_arg = m.group(3)
if seq_name not in sequences:
return m.group(0)
param, body, has_result = sequences[seq_name]
expanded = body.replace(f"%{param}", actual_arg)
yield_match = re.search(
r"transform\.yield(?:\s+(%\w+)\s*:\s*!transform\.any_op)?",
expanded,
)
if yield_match:
yielded_var = yield_match.group(1)
expanded = expanded[: yield_match.start()].rstrip()
if result_var and yielded_var:
expanded = expanded.replace(yielded_var, result_var)
suffix = f"_lib{_counter[0]}"
_counter[0] += 1
local_vars = set(re.findall(r"%(\w+)", expanded))
actual_name = actual_arg.lstrip("%")
result_name = result_var.lstrip("%") if result_var else ""
skip = {actual_name, result_name, "__", ""}
for var in local_vars:
if var not in skip and not var.startswith("_lib"):
expanded = re.sub(
rf"(?<!\w)%{re.escape(var)}(?!\w)",
f"%{var}{suffix}",
expanded,
)
return expanded
text = include_pattern.sub(_replace_include, text)
return _expand(text, depth + 1)
result = _expand(user_script) if has_includes else user_script
# Inject named sequences referenced by foreach_match (symbol references
# that cannot be inlined — they must exist as definitions in the module).
if has_foreach_match or "foreach_match" in result:
all_refs = set(re.findall(r"@(\w+)", result))
all_refs.discard("__transform_main")
# Transitively resolve dependencies
needed = set()
worklist = [n for n in all_refs if n in full_sequences]
while worklist:
name = worklist.pop()
if name in needed:
continue
needed.add(name)
for dep in re.findall(r"@(\w+)", full_sequences[name]):
if dep in full_sequences and dep not in needed:
worklist.append(dep)
# Inject definitions for all unresolved @name references
# (matchers/actions referenced by foreach_match, plus their deps)
if needed:
module_marker = "module attributes {transform.with_named_sequence} {"
idx = result.find(module_marker)
if idx != -1:
insert_pos = idx + len(module_marker)
injection = "\n\n".join(
full_sequences[n] for n in full_sequences if n in needed
)
result = (
result[:insert_pos]
+ "\n\n"
+ injection
+ "\n\n"
+ result[insert_pos:]
)
return result
def _detect_element_type(ir_str):
"""Detect the primary element type from the Linalg IR function signature.
Scans memref types in the first func.func line for the element type.
Returns the MLIR type string (e.g., "bf16", "f32", "i8", "i16").
Falls back to "bf16" if detection fails.
"""
import re
# Match memref<...xTYPE> in the function signature
match = re.search(r"memref<[^>]*x(\w+)>", ir_str)
if match:
return match.group(1)
return "bf16"
# Dtype-aware placeholder info: padding value and default vector size per NPU.
_DTYPE_PLACEHOLDER_INFO = {
"bf16": {"pad_val": "0.0 : bf16", "vector_size": {"npu1": 16, "npu2": 32}},
"f32": {"pad_val": "0.0 : f32", "vector_size": {"npu1": 16, "npu2": 16}},
"i8": {"pad_val": "0 : i8", "vector_size": {"npu1": 32, "npu2": 32}},
"i16": {"pad_val": "0 : i16", "vector_size": {"npu1": 32, "npu2": 32}},
"i32": {"pad_val": "0 : i32", "vector_size": {"npu1": 16, "npu2": 16}},
}
def _substitute_dtype_placeholders(script, dtype, npu_version):
"""Substitute dtype-aware placeholders in a transform script.
Replaces @DTYPE@, @PAD_VAL@, and @VECTOR_SIZE@ with values derived
from the detected element type and target NPU version.
No-op if the script contains no placeholders (backward compatible).
"""
if (
"@DTYPE@" not in script
and "@PAD_VAL@" not in script
and "@VECTOR_SIZE@" not in script
):
return script
info = _DTYPE_PLACEHOLDER_INFO.get(dtype)
if info is None:
return script
script = script.replace("@DTYPE@", dtype)
script = script.replace("@PAD_VAL@", info["pad_val"])
script = script.replace(
"@VECTOR_SIZE@", str(info["vector_size"].get(npu_version, 16))
)
return script
def _get_transform_ir_string(ir_str=None):
"""
Get the transform IR string for tiling operations.
If the environment variable AIR_TRANSFORM_TILING_SCRIPT is set,
read the transform IR from that file. Otherwise, use the default
hardcoded transform IR string.
If the script uses `transform.include`, the shared transform library
(transform_library.mlir) is automatically injected.
If ir_str is provided, dtype-aware placeholders (@DTYPE@, @PAD_VAL@,
@VECTOR_SIZE@) are substituted before library injection.
Args:
ir_str: Optional Linalg IR string for dtype detection.
Returns:
str: The transform IR string to use for tiling
"""
custom_script_path = os.getenv("AIR_TRANSFORM_TILING_SCRIPT")
if custom_script_path:
if not os.path.isfile(custom_script_path):
raise FileNotFoundError(
f"AIR_TRANSFORM_TILING_SCRIPT is set to '{custom_script_path}' "
f"but the file was not found (cwd: {os.getcwd()}). "
f"Use an absolute path or run from the directory containing the script."
)
with open(custom_script_path, "r") as f:
print(f"Using custom tiling script from: {custom_script_path}")
user_script = f.read()
if ir_str is not None:
dtype = _detect_element_type(
ir_str if isinstance(ir_str, str) else str(ir_str)
)
npu_version = detect_npu_version()
user_script = _substitute_dtype_placeholders(
user_script, dtype, npu_version
)
return _inject_transform_library(user_script)
# Default hardcoded transform IR string
matmul_tiling_size_l1_m = 32
matmul_tiling_size_l1_n = 32
matmul_tiling_size_l1_k = 32
elemwise_tiling_size_l1_m = 32
elemwise_tiling_size_l1_n = 32
return f"""
module attributes {{transform.with_named_sequence}} {{
transform.named_sequence @__transform_main(%arg1: !transform.any_op {{transform.readonly}}) {{
%mul = transform.structured.match ops{{["linalg.mul"]}} in %arg1 : (!transform.any_op) -> !transform.any_op
%mul_1, %loop = transform.air.linalg_tile %mul [{elemwise_tiling_size_l1_m}, {elemwise_tiling_size_l1_n}]
transform.air.linalg_promote %mul_1 {{"operands_to_promote"=[2], "memory_space"="L1"}}
transform.air.linalg_promote %mul_1 {{"operands_to_promote"=[0,1], "memory_space"="L1"}}
%add = transform.structured.match ops{{["linalg.add"]}} in %arg1 : (!transform.any_op) -> !transform.any_op
%add_1, %add_loop = transform.air.linalg_tile %add [{elemwise_tiling_size_l1_m}, {elemwise_tiling_size_l1_n}]
transform.air.linalg_promote %add_1 {{"operands_to_promote"=[2], "memory_space"="L1"}}
transform.air.linalg_promote %add_1 {{"operands_to_promote"=[0,1], "memory_space"="L1"}}
%matmul = transform.structured.match ops{{["linalg.matmul"]}} in %arg1 : (!transform.any_op) -> !transform.any_op
%fill = transform.structured.match ops{{["linalg.fill"]}} in %arg1 : (!transform.any_op) -> !transform.any_op
%matmul_1, %matmul_loop = transform.air.linalg_tile %matmul [{matmul_tiling_size_l1_m}, {matmul_tiling_size_l1_n}]
%fill_1 = transform.air.fuse_into_containing_op %fill into %matmul_loop
transform.air.linalg_promote %fill_1 {{"operands_to_promote"=[1], "memory_space"="L1"}}
transform.air.linalg_promote %matmul_1 {{"operands_to_promote"=[2], "memory_space"="L1"}}
%matmul_2, %reduction_loop = transform.air.linalg_tile %matmul_1 [0, 0, {matmul_tiling_size_l1_k}]
transform.air.linalg_promote %matmul_2 {{"operands_to_promote"=[0,1], "memory_space"="L1"}}
transform.yield
}}
}}
"""
def _ttshared_to_air(mod, gridX, gridY, gridZ, actual_sizes=None):
    """Lower Triton-shared Linalg IR into an AIR module.

    Everything runs on a fresh ``air.ir.Context``:
      1. passes ``air-resolve-tensor-opoperand-conflicts`` and
         ``air-override-memref-memory-space{scope=func memory-space=1}``;
      2. the tiling transform script from ``_get_transform_ir_string``;
      3. ``air-wrap-func-with-parallel`` with ``loop-bounds=gridX,gridY,gridZ``
         (plus ``actual-sizes`` when given), then ``air-par-to-launch``,
         ``canonicalize``, ``cse`` and ``air-copy-to-dma``.

    The resulting IR is written to ``airinput.mlir`` in a temporary directory,
    and that file is copied out via ``_dump_ir_if_needed``.

    Args:
        mod: Linalg IR accepted by ``Module.parse``.
        gridX, gridY, gridZ: Launch grid dimensions.
        actual_sizes: Optional value forwarded verbatim as the
            ``actual-sizes`` option; ignored when falsy.

    Returns:
        The transformed AIR module.

    Raises:
        RuntimeError: If the air-opt binary cannot be located.
    """

    def _module_pipeline(*passes):
        return "builtin.module(" + ",".join(passes) + ")"

    with tempfile.TemporaryDirectory() as scratch_dir:
        dump_path = os.path.join(scratch_dir, "airinput.mlir")
        # Called only for its check that air-opt is installed; the returned
        # path is not used by this in-process pipeline.
        _get_air_opt_path()

        ctx = air.ir.Context()
        air_module = Module.parse(mod, context=ctx)

        # Step 1: resolve operand conflicts and override the memory space.
        prepare = _module_pipeline(
            "air-resolve-tensor-opoperand-conflicts",
            "air-override-memref-memory-space{scope=func memory-space=1}",
        )
        air.passmanager.PassManager.parse(prepare, context=ctx).run(
            air_module.operation
        )

        # Step 2: tile the launch body with the transform script.
        tiling_script = Module.parse(
            _get_transform_ir_string(ir_str=mod), context=ctx
        )
        run_transform(tiling_script, air_module)

        # Step 3: map the grid onto air.launch and convert copies to AIR DMAs.
        wrap_options = f"loop-bounds={gridX},{gridY},{gridZ}"
        if actual_sizes:
            wrap_options += f" actual-sizes={actual_sizes}"
        lower = _module_pipeline(
            f"func.func(air-wrap-func-with-parallel{{{wrap_options}}})",
            "air-par-to-launch{depth=0 has-air-segment=true}",
            "canonicalize",
            "cse",
            "air-copy-to-dma",
        )
        air.passmanager.PassManager.parse(lower, context=ctx).run(
            air_module.operation
        )

        with open(dump_path, "w") as out:
            out.write(str(air_module))
        _dump_ir_if_needed([dump_path])
    return air_module
def _generate_launcher(constants, signature, kernel_name):
    """Generate the C++ source of the xclbin-based XRT launcher extension.

    The generated Python extension module ``__npu_dispatch`` exposes:
      * ``set_paths(aie_path, insts_path)``: remembers the xclbin and the
        instruction-sequence file used at launch time;
      * ``launch(gridX, gridY, gridZ, kernel_metadata, launch_metadata,
        launch_enter_hook, launch_exit_hook, *args)``: copies every
        non-constant tensor argument into a host-only XRT buffer object,
        runs the first xclbin kernel whose name starts with ``MLIR_AIE``, and
        copies the LAST tensor argument back into its Python tensor.

    Python errors raised while extracting arguments, and C++ exceptions
    thrown by XRT, are reported as Python exceptions rather than printed or
    allowed to escape.

    Args:
        constants: Collection of signature indices that are compile-time
            constants; they get no buffer object.
        signature: Mapping from argument index to Triton type string
            (e.g. ``"*fp32"``, ``"i32"``, ``"constexpr"``).
        kernel_name: Not used by this launcher; kept for interface compatibility.

    Returns:
        str: C++ source code for the extension module.
    """
    global autotune_time

    # PyArg_ParseTuple format: grid (iii), the two metadata objects and the
    # two launch hooks (OOOO), then one code per signature entry.
    parse_format = "iiiOOOO" + "".join(
        _format_of(_extracted_type(ty)) for ty in signature.values()
    )
    args_list = (
        ", " + ", ".join(f"&_arg{i}" for i in signature)
        if len(signature) > 0
        else ""
    )

    # Non-constant tensor (pointer) arguments: each one gets an XRT buffer
    # object whose byte size is computed from the Python tensor in launch().
    ptr_args = [
        i for i, ty in signature.items() if i not in constants and ty[0] == "*"
    ]
    # TODO: Assumes the last tensor argument is the only output tensor.
    last_ptr_idx = ptr_args[-1] if ptr_args else None
    # Every pointer argument (constants included) is validated and sized in launch().
    tensor_args = [i for i, ty in signature.items() if ty[0] == "*"]

    # Lists are joined so that an empty group never leaves a dangling comma.
    launch_params = ", ".join(
        ["int gridX", "int gridY", "int gridZ"]
        + [f"long size{i}" for i in ptr_args]
        + [f"{_ty_to_cpp(ty)} arg{i}" for i, ty in signature.items()]
    )
    alloc_bos = " ".join(
        f"auto bo_{i} = xrt::bo(device, size{i}, XRT_BO_FLAGS_HOST_ONLY, kernel.group_id({i+3}));"
        for i in ptr_args
    )
    map_and_copy = " ".join(
        f"void *buf{i} = bo_{i}.map<void *>(); memcpy(buf{i}, arg{i}, size{i});"
        for i in ptr_args
    )
    sync_to_device = " ".join(
        f"bo_{i}.sync(XCL_BO_SYNC_BO_TO_DEVICE);" for i in ptr_args
    )
    kernel_call_args = ", ".join(
        ["opcode", "bo_instr", "instr_v.size()"] + [f"bo_{i}" for i in ptr_args]
    )
    copy_back = (
        f"bo_{last_ptr_idx}.sync(XCL_BO_SYNC_BO_FROM_DEVICE);\n"
        f"memcpy(arg{last_ptr_idx}, buf{last_ptr_idx}, size{last_ptr_idx});"
        if last_ptr_idx is not None
        else ""
    )
    timing_start = (
        "auto start = std::chrono::high_resolution_clock::now();"
        if autotune_time
        else ""
    )
    timing_stop = (
        "auto stop = std::chrono::high_resolution_clock::now(); "
        "float npu_time = std::chrono::duration_cast<std::chrono::microseconds>(stop - start).count();"
        if autotune_time
        else ""
    )
    timing_dump = (
        'std::ofstream file("data.txt"); file << npu_time << std::endl; file.close();'
        if autotune_time
        else ""
    )

    arg_locals = " ".join(
        f"{_extracted_type(ty)} _arg{i}; " for i, ty in signature.items()
    )
    get_pointers = " ".join(
        f"DevicePtrInfo ptr_info{i} = getPointer(_arg{i}, {i}); "
        f"if (!ptr_info{i}.valid) return NULL;"
        for i in tensor_args
    )
    # Check each factor separately: a failed helper returns -1 with a Python
    # exception set, and a product of two failures would look valid (1).
    get_volumes = " ".join(
        f"long num_elems{i} = getNumElements(_arg{i}); if (num_elems{i} < 0) return NULL; "
        f"long elem_size{i} = getElementSizeInBytes(_arg{i}); if (elem_size{i} < 0) return NULL; "
        f"long tensor_volume{i} = num_elems{i} * elem_size{i};"
        for i in tensor_args
    )
    launch_call_args = ", ".join(
        ["gridX", "gridY", "gridZ"]
        + [f"tensor_volume{i}" for i in ptr_args]
        + [
            f"ptr_info{i}.dev_ptr" if ty[0] == "*" else f"_arg{i}"
            for i, ty in signature.items()
        ]
    )

    return f"""
#include <Python.h>
#include <algorithm>
#include <assert.h>
#include <cstring>
#include <exception>
#include <fstream>
#include <iostream>
#include <stdbool.h>
#include <string>
#include <vector>
#include "ExecutionEngine/CRunnerUtils.h"
#include "ExecutionEngine/CRunnerUtils.cpp"
#include "test_utils.h"
#include <chrono>
#include <cstdlib>
#include <ctime>
#include "xrt/xrt_bo.h"
#include "xrt/xrt_device.h"
#include "xrt/xrt_kernel.h"
static char aie_path[1024] = {{0}};
static char insts_path[1024] = {{0}};
static PyObject* py_set_paths(PyObject* self, PyObject* args) {{
const char* aie;
const char* insts;
if (!PyArg_ParseTuple(args, "ss", &aie, &insts)) {{
return NULL;
}}
strncpy(aie_path, aie, sizeof(aie_path) - 1);
strncpy(insts_path, insts, sizeof(insts_path) - 1);
aie_path[sizeof(aie_path) - 1] = '\\0';
insts_path[sizeof(insts_path) - 1] = '\\0';
Py_RETURN_NONE;
}}
// Call to XRT goes here:
static void _launch({launch_params}) {{
if (gridX*gridY*gridZ > 0) {{
std::vector<uint32_t> instr_v =
test_utils::load_instr_binary(insts_path);
int verbosity = 1;
if (verbosity >= 1)
std::cout << "Sequence instr count: " << instr_v.size() << std::endl;
// Get a device handle
unsigned int device_index = 0;
auto device = xrt::device(device_index);
// Load the xclbin
if (verbosity >= 1)
std::cout << "Loading xclbin." << std::endl;
auto xclbin = xrt::xclbin(std::string(aie_path));
if (verbosity >= 1)
std::cout << "Kernel opcode: " << "MLIR_AIE" << std::endl;
std::string Node = "MLIR_AIE";
// Get the kernel from the xclbin
auto xkernels = xclbin.get_kernels();
auto xkernel_it = std::find_if(xkernels.begin(), xkernels.end(),
[Node](xrt::xclbin::kernel &k) {{
auto name = k.get_name();
std::cout << "Name: " << name << std::endl;
return name.rfind(Node, 0) == 0;
}});
if (xkernel_it == xkernels.end()) {{
// Dereferencing end() is undefined behavior; report a Python error instead.
PyErr_SetString(PyExc_RuntimeError, "No kernel whose name starts with MLIR_AIE was found in the xclbin");
return;
}}
auto xkernel = *xkernel_it;
auto kernelName = xkernel.get_name();
if (verbosity >= 1)
std::cout << "Registering xclbin." << std::endl;
device.register_xclbin(xclbin);
// get a hardware context
if (verbosity >= 1)
std::cout << "Getting hardware context." << std::endl;
xrt::hw_context context(device, xclbin.get_uuid());
// get a kernel handle
if (verbosity >= 1)
std::cout << "Getting handle to kernel:" << kernelName << std::endl;
auto kernel = xrt::kernel(context, kernelName);
// get instruction sequence
auto bo_instr = xrt::bo(device, instr_v.size() * sizeof(int),
XCL_BO_FLAGS_CACHEABLE, kernel.group_id(1));
{alloc_bos}
if (verbosity >= 1)
std::cout << "Writing data into buffer objects." << std::endl;
{map_and_copy}
void *bufInstr = bo_instr.map<void *>();
memcpy(bufInstr, instr_v.data(), instr_v.size() * sizeof(int));
bo_instr.sync(XCL_BO_SYNC_BO_TO_DEVICE);
{sync_to_device}
if (verbosity >= 1)
std::cout << "Running Kernel." << std::endl;
unsigned int opcode = 3;
{timing_start}
auto run = kernel({kernel_call_args});
run.wait();
{timing_stop}
{timing_dump}
if (verbosity >= 1)
std::cout << "Copying results." << std::endl;
// TODO: Assuming the last tensor is the only output tensor.
{copy_back}
if (verbosity >= 1)
std::cout << "Launch finished." << std::endl;
}}
}}
typedef struct _DevicePtrInfo {{
void *dev_ptr;
bool valid;
}} DevicePtrInfo;
// On failure, returns valid == false with a Python exception set.
static inline DevicePtrInfo getPointer(PyObject *obj, int idx) {{
DevicePtrInfo ptr_info;
ptr_info.dev_ptr = 0;
ptr_info.valid = true;
if (PyLong_Check(obj)) {{
ptr_info.dev_ptr = reinterpret_cast<void *>(PyLong_AsUnsignedLongLong(obj));
return ptr_info;
}}
if (obj == Py_None) {{
// valid nullptr
return ptr_info;
}}
PyObject *ptr = PyObject_GetAttrString(obj, "data_ptr");
if (ptr) {{
PyObject *empty_tuple = PyTuple_New(0);
PyObject *ret = PyObject_Call(ptr, empty_tuple, NULL);
Py_DECREF(empty_tuple);
Py_DECREF(ptr);
if (!ret) {{
// data_ptr() raised; keep its exception.
ptr_info.valid = false;
return ptr_info;
}}
if (!PyLong_Check(ret)) {{
Py_DECREF(ret);
PyErr_SetString(PyExc_TypeError, "data_ptr method of Pointer object must return 64-bit int");
ptr_info.valid = false;
return ptr_info;
}}
ptr_info.dev_ptr = reinterpret_cast<void *>(PyLong_AsUnsignedLongLong(ret));
Py_DECREF(ret);
return ptr_info;
}}
PyErr_SetString(PyExc_TypeError, "Pointer argument must be either uint64 or have data_ptr method");
ptr_info.valid = false;
return ptr_info;
}}
// Product of obj.shape. Returns -1 with a Python exception set on failure.
static long getNumElements(PyObject *obj) {{
PyObject *shape = PyObject_GetAttrString(obj, "shape");
if (!shape) {{
return -1;
}}
if (!PySequence_Check(shape)) {{
Py_DECREF(shape);
PyErr_SetString(PyExc_TypeError, "Attribute 'shape' is not a sequence.");
return -1;
}}
Py_ssize_t ndim = PySequence_Size(shape);
if (ndim < 0) {{
Py_DECREF(shape);
return -1;
}}
long num_elements = 1;
for (Py_ssize_t i = 0; i < ndim; ++i) {{
PyObject *dim_obj = PySequence_GetItem(shape, i);
if (!dim_obj) {{
Py_DECREF(shape);
return -1;
}}
long dim = PyLong_AsLong(dim_obj);
Py_DECREF(dim_obj);
if (dim == -1 && PyErr_Occurred()) {{
Py_DECREF(shape);
return -1;
}}
num_elements *= dim;
}}
Py_DECREF(shape);
return num_elements;
}}
// obj.dtype.itemsize. Returns -1 with a Python exception set on failure.
static long getElementSizeInBytes(PyObject *obj) {{
if (!obj) {{
PyErr_SetString(PyExc_ValueError, "NULL tensor object");
return -1;
}}
PyObject *dtype = PyObject_GetAttrString(obj, "dtype");
if (!dtype) {{
return -1;
}}
PyObject *itemsize = PyObject_GetAttrString(dtype, "itemsize");
Py_DECREF(dtype);
if (!itemsize) {{
return -1;
}}
long size = PyLong_AsLong(itemsize);
Py_DECREF(itemsize);
if (size == -1 && PyErr_Occurred()) {{
return -1;
}}
return size;
}}
static PyObject* launch(PyObject* self, PyObject* args) {{
int gridX, gridY, gridZ;
PyObject *launch_enter_hook = NULL;
PyObject *launch_exit_hook = NULL;
PyObject *kernel_metadata = NULL;
PyObject *launch_metadata = NULL;
{arg_locals}
if(!PyArg_ParseTuple(args, "{parse_format}", &gridX, &gridY, &gridZ,
&kernel_metadata, &launch_metadata,
&launch_enter_hook, &launch_exit_hook {args_list})) {{
return NULL;
}}
// extract launch metadata
if (launch_enter_hook != Py_None){{
PyObject* hook_args = Py_BuildValue("(O)", launch_metadata);
if (!hook_args)
return NULL;
PyObject* ret = PyObject_CallObject(launch_enter_hook, hook_args);
Py_DECREF(hook_args);
if (!ret)
return NULL;
Py_DECREF(ret);
}}
// raise exception asap
{get_pointers}
{get_volumes}
try {{
_launch({launch_call_args});
}} catch (const std::exception &e) {{
// XRT and test_utils report failures with C++ exceptions, which must not
// propagate through the CPython C API.
PyErr_SetString(PyExc_RuntimeError, e.what());
return NULL;
}}
if (PyErr_Occurred()) {{
return NULL;
}}
if(launch_exit_hook != Py_None){{
PyObject* hook_args = Py_BuildValue("(O)", launch_metadata);
if (!hook_args)
return NULL;
PyObject* ret = PyObject_CallObject(launch_exit_hook, hook_args);
Py_DECREF(hook_args);
if (!ret)
return NULL;
Py_DECREF(ret);
}}
Py_RETURN_NONE;
}}
static PyMethodDef ModuleMethods[] = {{
{{"launch", launch, METH_VARARGS, "Entry point for all kernels with this signature"}},
{{"set_paths", py_set_paths, METH_VARARGS, "Set paths to aie.bin and insts.bin"}},
{{NULL, NULL, 0, NULL}} // sentinel
}};
static struct PyModuleDef ModuleDef = {{
PyModuleDef_HEAD_INIT,
"__npu_dispatch",
NULL, //documentation
-1, //size
ModuleMethods
}};
PyMODINIT_FUNC PyInit___npu_dispatch(void) {{
PyObject *m = PyModule_Create(&ModuleDef);
if(m == NULL) {{
return NULL;
}}
PyModule_AddFunctions(m, ModuleMethods);
return m;
}}
"""
def _generate_elf_launcher(constants, signature, kernel_name):
"""Generate C++ launcher code using XRT ELF APIs (for NPU2/AIE2P only)."""
arg_decls = ", ".join(f"{_ty_to_cpp(ty)} arg{i}" for i, ty in signature.items())
args_format = "".join(
[_format_of(_extracted_type(ty)) for ty in signature.values()]
)
format = "iiiOOOO" + args_format
args_list = (
", " + ", ".join(f"&_arg{i}" for i, ty in signature.items())
if len(signature) > 0
else ""
)
global autotune_time
# Collect pointer (tensor) args excluding constants
ptr_args = [
(i, ty) for i, ty in signature.items() if i not in constants and ty[0] == "*"
]
last_ptr_idx = next(
(
i
for i, ty in reversed(signature.items())
if i not in constants and ty[0] == "*"
),
None,
)
# Build set_arg lines for kernel invocation
set_arg_lines = "\n ".join(
f"run.set_arg({idx}, bo_{i});" for idx, (i, ty) in enumerate(ptr_args)
)
return f"""
#include <assert.h>
#include <fstream>
#include <iostream>
#include <stdbool.h>
#include <Python.h>
#include "ExecutionEngine/CRunnerUtils.h"
#include "ExecutionEngine/CRunnerUtils.cpp"
#include <chrono>
#include <cstdlib>
#include <ctime>
#include "xrt/xrt_bo.h"
#include "xrt/xrt_device.h"
#include "xrt/xrt_kernel.h"
#include <xrt/experimental/xrt_elf.h>
#include <xrt/experimental/xrt_ext.h>
static char elf_path[1024] = {{0}};
static char elf_kernel_name[256] = {{0}};
static PyObject* py_set_paths(PyObject* self, PyObject* args) {{
const char* elf;
const char* kname;
if (!PyArg_ParseTuple(args, "ss", &elf, &kname)) {{
return NULL;
}}
strncpy(elf_path, elf, sizeof(elf_path) - 1);
elf_path[sizeof(elf_path) - 1] = '\\0';
strncpy(elf_kernel_name, kname, sizeof(elf_kernel_name) - 1);
elf_kernel_name[sizeof(elf_kernel_name) - 1] = '\\0';
Py_RETURN_NONE;
}}
// ELF-based XRT launch:
static void _launch(int gridX, int gridY, int gridZ, {', '.join(f"long size{i}" for i, ty in ptr_args)}, {arg_decls}) {{
if (gridX*gridY*gridZ > 0) {{
int verbosity = 1;
// Get a device handle
unsigned int device_index = 0;
auto device = xrt::device(device_index);
// Load the ELF
if (verbosity >= 1)