Skip to content

Commit 5eaeb71

Browse files
erwei-xilinxclaude
andcommitted
Fix black formatting in driver.py and padded_matmul.py
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
1 parent 50362c9 commit 5eaeb71

2 files changed

Lines changed: 21 additions & 15 deletions

File tree

amd_triton_npu/backend/driver.py

Lines changed: 19 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1141,8 +1141,9 @@ def _generate_elf_launcher(constants, signature, kernel_name):
11411141
"""
11421142

11431143

1144-
def compile_module(launcher_src, kernel_placeholder_name, output_format="xclbin",
1145-
actual_sizes=None):
1144+
def compile_module(
1145+
launcher_src, kernel_placeholder_name, output_format="xclbin", actual_sizes=None
1146+
):
11461147
py_version = sys.version_info
11471148
if platform.system() == "Windows":
11481149
py_include_dir = os.path.join(sys.base_prefix, "include")
@@ -1186,8 +1187,9 @@ def launch(
11861187
air_proj_path = _get_air_project_path()
11871188
os.makedirs(air_proj_path, exist_ok=True)
11881189
Path(os.path.join(air_proj_path, "asm_src.mlir")).write_bytes(asm_src)
1189-
air_output = _ttshared_to_air(asm_src, gridX, gridY, gridZ,
1190-
actual_sizes=actual_sizes)
1190+
air_output = _ttshared_to_air(
1191+
asm_src, gridX, gridY, gridZ, actual_sizes=actual_sizes
1192+
)
11911193
with open(Path(os.path.join(air_proj_path, "asm_air_output.mlir")), "w") as f:
11921194
f.write(str(air_output))
11931195

@@ -1408,10 +1410,16 @@ def __init__(self, src, metadata):
14081410
n_val = raw_constants.get((n_idx,))
14091411
if m_val is not None and n_val is not None:
14101412
# Check if BLOCK_SIZE_M/N are available to determine alignment
1411-
bsm_idx = (arg_names.index("BLOCK_SIZE_M")
1412-
if "BLOCK_SIZE_M" in arg_names else None)
1413-
bsn_idx = (arg_names.index("BLOCK_SIZE_N")
1414-
if "BLOCK_SIZE_N" in arg_names else None)
1413+
bsm_idx = (
1414+
arg_names.index("BLOCK_SIZE_M")
1415+
if "BLOCK_SIZE_M" in arg_names
1416+
else None
1417+
)
1418+
bsn_idx = (
1419+
arg_names.index("BLOCK_SIZE_N")
1420+
if "BLOCK_SIZE_N" in arg_names
1421+
else None
1422+
)
14151423
bsm = raw_constants.get((bsm_idx,)) if bsm_idx is not None else None
14161424
bsn = raw_constants.get((bsn_idx,)) if bsn_idx is not None else None
14171425
needs_padding = True
@@ -1423,7 +1431,9 @@ def __init__(self, src, metadata):
14231431
# Later KERNEL_NAME_PLACEHOLDER will be used to assign the kernel name
14241432
# in the following launch function.
14251433
self.launch = compile_module(
1426-
launcher_src, kernel_placeholder_name, self.output_format,
1434+
launcher_src,
1435+
kernel_placeholder_name,
1436+
self.output_format,
14271437
actual_sizes=actual_sizes,
14281438
)
14291439

examples/padded_matmul/padded_matmul.py

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -41,9 +41,7 @@
4141
N_actual = 500
4242
K_val = 1024
4343

44-
assert K_val % K_L2_TILE == 0, (
45-
f"K={K_val} must be divisible by K_L2_TILE={K_L2_TILE}"
46-
)
44+
assert K_val % K_L2_TILE == 0, f"K={K_val} must be divisible by K_L2_TILE={K_L2_TILE}"
4745

4846
# === Padded/allocated dimensions ===
4947
M_padded = math.ceil(M_actual / LAUNCH_TILE_M) * LAUNCH_TILE_M # 512
@@ -175,9 +173,7 @@ def run_padded_matmul():
175173
if not np.isclose(actual, expected, rtol=0.1, atol=10.0):
176174
errors += 1
177175
if errors <= 5:
178-
print(
179-
f"Mismatch at ({m}, {n}): actual={actual}, expected={expected}"
180-
)
176+
print(f"Mismatch at ({m}, {n}): actual={actual}, expected={expected}")
181177

182178
total = len(sample_m)
183179
if errors == 0:

0 commit comments

Comments
 (0)