|
| 1 | +# |
| 2 | +# This file is licensed under the Apache License v2.0 with LLVM Exceptions. |
| 3 | +# See https://llvm.org/LICENSE.txt for license information. |
| 4 | +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| 5 | +# |
| 6 | +# Copyright (C) 2026, Advanced Micro Devices, Inc. |
| 7 | + |
| 8 | +import torch |
| 9 | +import torch.nn as nn |
| 10 | +import sys |
| 11 | +import os |
| 12 | +import numpy as np |
| 13 | +import aie.utils.test as test_utils |
| 14 | +import aie.iron as iron |
| 15 | +from aie.utils import TraceConfig, HostRuntime, NPUKernel, DefaultNPURuntime |
| 16 | +from aie.utils.ml import DataShaper |
| 17 | + |
# Seed PyTorch's global RNG so the random activations and weights drawn in
# main() are identical on every run, and request deterministic kernels so the
# reference convolution is reproducible as well.
torch.manual_seed(0)
torch.use_deterministic_algorithms(True)
| 20 | + |
| 21 | + |
# Channels are blocked in groups of 8 in every NPU-side layout used below.
_CH_GROUP = 8
# The design uses a 3x3x3 kernel (see the weight tensor shape in main()).
_KSIZE = 3


def _select_core_count(co):
    """Return the core count the host assumes for `co` output channels.

    Heuristic only: 4 cores for 32+ channels, 2 cores for 16+, otherwise 1.
    NOTE(review): the real core count is a property of the xclbin; this must
    stay in sync with how the design was built.
    """
    if co >= 32:
        return 4
    if co >= 16:
        return 2
    return 1


def _validate_config(depth, height, width, ci, co, n_cores):
    """Reject problem sizes the blocked layouts below cannot represent.

    Raises:
        ValueError: if a spatial dimension is not positive, a channel count is
            not a positive multiple of 8, or the output channels cannot be
            split into whole groups of 8 per core.
    """
    for label, value in (("depth", depth), ("height", height), ("width", width)):
        if value < 1:
            raise ValueError(f"{label} must be a positive integer, got {value}")
    for label, value in (("in_channels", ci), ("out_channels", co)):
        if value < _CH_GROUP or value % _CH_GROUP:
            raise ValueError(
                f"{label} must be a positive multiple of {_CH_GROUP}, got {value}"
            )
    if co % (n_cores * _CH_GROUP):
        raise ValueError(
            f"out_channels={co} cannot be split across {n_cores} cores in "
            f"groups of {_CH_GROUP} channels"
        )


def _activations_to_npu_layout(act):
    """Reorder a CDHW array into the D{C/8}H{C8}W layout (5-D result)."""
    c, d, h, w = act.shape
    blocked = act.reshape(c // _CH_GROUP, _CH_GROUP, d, h, w)
    # (c8, c, d, h, w) -> (d, c8, h, c, w)
    return np.ascontiguousarray(blocked.transpose(2, 0, 3, 1, 4))


def _weights_to_npu_layout(wts):
    """Reorder OIDHW weights into the {O/8}{I/8}KDHW{I8}{O8} layout (7-D)."""
    co, ci = wts.shape[:2]
    blocked = wts.reshape(
        co // _CH_GROUP, _CH_GROUP, ci // _CH_GROUP, _CH_GROUP, *wts.shape[2:]
    )
    # (o8, o, i8, i, kd, kh, kw) -> (o8, i8, kd, kh, kw, i, o)
    return np.ascontiguousarray(blocked.transpose(0, 2, 4, 5, 6, 3, 1))


def _npu_layout_to_cdhw(out):
    """Reorder a D{C/8}H{C8}W array back to CDHW as float32."""
    d, c8, h, group, w = out.shape
    # (d, c8, h, c, w) -> (c8, c, d, h, w) -> (C, d, h, w)
    cdhw = out.transpose(1, 3, 0, 2, 4).reshape(c8 * group, d, h, w)
    return np.ascontiguousarray(cdhw, dtype=np.float32)


def _golden_reference(int_inp, int_weight, conv_scale, int8_scale, min_val, max_val):
    """Compute the dequantized PyTorch reference output as a NumPy array.

    The input is replicate-padded by one voxel on every side and convolved with
    an unpadded, bias-free 3-D convolution. The accumulator is then scaled by
    `conv_scale`, requantized with `int8_scale`, clamped to
    [`min_val`, `max_val`] and dequantized again.
    """
    with torch.no_grad():
        # Pad order is (left, right, top, bottom, front, back) for (W, H, D).
        padded = torch.nn.functional.pad(
            int_inp, (1, 1, 1, 1, 1, 1), mode="replicate"
        )
        acc = torch.nn.functional.conv3d(padded, int_weight)
        quantized = torch.clamp(
            torch.round(acc * conv_scale / int8_scale), min_val, max_val
        )
        return (int8_scale * quantized).numpy()


def _as_numpy(tensor):
    """Return `tensor` as a NumPy array, calling `.numpy()` when needed."""
    return tensor if isinstance(tensor, np.ndarray) else tensor.numpy()


def _make_npu_buffers(ifm_flat, wts, shape_out, n_cores, co, dtype_in, dtype_wts, dtype_out):
    """Build the kernel argument list: inputs, then weights, then outputs.

    Single core: [input, weights, output]. Multi-core: the input is duplicated
    once per core, the weights are split along the leading output-channel-group
    axis, and each core gets an equal share of the output elements.
    """
    if n_cores == 1:
        out_size = np.prod(shape_out)
        return [
            iron.tensor(ifm_flat, dtype=dtype_in),
            iron.tensor(wts, dtype=dtype_wts),
            iron.zeros(out_size, dtype=dtype_out),
        ]

    oc8_per_core = (co // n_cores) // _CH_GROUP
    # Output size is in elements, not bytes.
    out_size_per_core = np.prod(shape_out) // n_cores

    buffers = [iron.tensor(ifm_flat, dtype=dtype_in) for _ in range(n_cores)]
    for core in range(n_cores):
        first = core * oc8_per_core
        wts_core = wts[first : first + oc8_per_core].flatten()
        buffers.append(iron.tensor(wts_core, dtype=dtype_wts))
    for _ in range(n_cores):
        buffers.append(iron.zeros(out_size_per_core, dtype=dtype_out))
    return buffers


def _gather_raw_output(buffers, n_cores, depth, height, width, co):
    """Return the raw (still quantized) output as one flat D{C/8}H{C8}W array."""
    if n_cores == 1:
        return _as_numpy(buffers[-1])

    oc8_per_core = (co // n_cores) // _CH_GROUP
    per_core_shape = (depth, oc8_per_core, height, _CH_GROUP, width)
    first_out = 2 * n_cores  # outputs follow the inputs and the weights
    parts = [
        _as_numpy(buffers[first_out + core]).reshape(per_core_shape)
        for core in range(n_cores)
    ]
    # Axis 1 is the output-channel-group axis each core owns a slice of.
    return np.concatenate(parts, axis=1).flatten()


def _extract_trace(buffers, trace_config):
    """Pull the trace words out of the argument list as a uint32 view."""
    trace_buffer, _ = HostRuntime.extract_trace_from_args(buffers, trace_config)
    return trace_buffer.view(np.uint32)


def main(opts):
    """Run the conv3d design on the NPU and check it against PyTorch.

    Random uint8 activations and int8 3x3x3 weights are generated, converted to
    the blocked layouts the kernel consumes, and run through the NPU runtime.
    The result is dequantized, converted back to CDHW and compared with a
    PyTorch reference within a tolerance of 16 quantization steps.

    Args:
        opts: parsed command-line options. Reads `xclbin`, `instr`, `kernel`,
            `trace_size`, `depth`, `height`, `width`, `in_channels` and
            `out_channels`.

    Raises:
        ValueError: if the requested sizes cannot be expressed in the blocked
            layouts (checked before any hardware is touched).
        SystemExit: always on completion; code 0 on PASS, -1 on mismatch.
    """
    print("Starting main function...")
    design = "conv3d"
    xclbin_path = opts.xclbin
    insts_path = opts.instr

    depth = int(opts.depth)
    height = int(opts.height)
    width = int(opts.width)
    ci = int(opts.in_channels)
    co = int(opts.out_channels)
    print(f"Parameters: d={depth}, h={height}, w={width}, ci={ci}, co={co}")

    # Fail fast on unsupported sizes instead of after the NPU run.
    n_cores = _select_core_count(co)
    _validate_config(depth, height, width, ci, co, n_cores)

    log_folder = "log"
    os.makedirs(log_folder, exist_ok=True)

    co8 = co // _CH_GROUP

    num_iter = 1
    npu_time_total = 0
    trace_size = opts.trace_size
    enable_trace = bool(trace_size)
    trace_file = os.path.join(log_folder, "trace_" + design + ".txt")

    # Data types
    dtype_in = np.dtype("uint8")
    dtype_wts = np.dtype("int8")
    dtype_out = np.dtype("uint8")

    # Output layout: D{C/8}H{C8}W
    shape_out = (depth, co8, height, _CH_GROUP, width)

    # Random integer-valued activations and weights, held as float tensors so
    # the PyTorch reference can convolve them directly.
    int_inp = torch.randint(1, 20, (1, ci, depth, height, width)).type(
        torch.FloatTensor
    )
    int_weight = torch.randint(-50, 50, (co, ci, _KSIZE, _KSIZE, _KSIZE)).type(
        torch.FloatTensor
    )

    # Quantization scales
    conv_scale = 7.6294e-06
    int8_scale = 0.0078
    min_val = 0
    max_val = 255

    # Load NPU kernel
    npu_kernel = NPUKernel(xclbin_path, insts_path, kernel_name=opts.kernel)
    kernel_handle = DefaultNPURuntime.load(npu_kernel)

    # Golden output, shape (1, co, depth, height, width)
    golden_output = _golden_reference(
        int_inp, int_weight, conv_scale, int8_scale, min_val, max_val
    )

    # Drop only the batch axis; squeeze() would also drop spatial axes of size 1.
    before_input = int_inp[0].numpy().astype(dtype_in)  # [ci, depth, height, width]
    before_input.tofile(
        os.path.join(log_folder, "before_ifm_conv3d.txt"), sep=",", format="%d"
    )

    # Reorder input: CDHW -> D{C/8}H{C8}W
    ifm_mem_fmt = _activations_to_npu_layout(before_input).flatten()
    ifm_mem_fmt.tofile(
        os.path.join(log_folder, "after_ifm_conv3d.txt"), sep=",", format="%d"
    )

    # Reorder weights: OIKDHW -> {O/8}{I/8}KDHW{I8}{O8}
    wts_orig = int_weight.numpy().astype(dtype_wts)  # [co, ci, 3, 3, 3]
    wts = _weights_to_npu_layout(wts_orig)
    wts.tofile(os.path.join(log_folder, "weights_conv3d.txt"), sep=",", format="%d")

    print(f"Using {n_cores} cores for inference")
    print(f"Output channels per core: {co // n_cores}")

    # Prepare NPU buffers
    buffers = _make_npu_buffers(
        ifm_mem_fmt, wts, shape_out, n_cores, co, dtype_in, dtype_wts, dtype_out
    )

    # Trace configuration
    trace_config = None
    if enable_trace:
        last_tensor = buffers[-1]
        trace_config = TraceConfig(
            trace_size=trace_size,
            trace_file=trace_file,
            trace_after_last_tensor=True,
            enable_ctrl_pkts=False,
            last_tensor_shape=last_tensor.shape,
            last_tensor_dtype=last_tensor.dtype,
        )
        HostRuntime.prepare_args_for_trace(buffers, trace_config)

    # Run on NPU
    print(f"Running on NPU with {len(buffers)} buffers...")
    for i in range(num_iter):
        try:
            print(f"Iteration {i}, calling NPU...")
            ret = DefaultNPURuntime.run(kernel_handle, buffers)
            print("NPU returned successfully")
            if enable_trace:
                trace_config.write_trace(_extract_trace(buffers, trace_config))

            # Collect and dequantize the output of all cores.
            raw_output = _gather_raw_output(
                buffers, n_cores, depth, height, width, co
            )
            data_buffer = raw_output * int8_scale
            npu_time_total += ret.npu_time
        except Exception as e:
            print(f"\nNPU execution error: {e}")
            if enable_trace:
                # Best effort: keep whatever trace data exists for debugging,
                # then re-raise the original failure.
                print("Extracting trace buffer for debugging...")
                try:
                    trace_buffer = _extract_trace(buffers, trace_config)

                    # Save raw trace buffer
                    raw_trace_file = "log/trace_conv3d_raw.bin"
                    trace_buffer.tofile(raw_trace_file)
                    print(f"Raw trace buffer saved to {raw_trace_file}")
                    print(f"Trace buffer size: {len(trace_buffer)} uint32 values")
                    print(f"First 20 values: {trace_buffer[:20]}")

                    trace_config.write_trace(trace_buffer)
                    print(f"Trace written to {trace_file}")
                except Exception as trace_err:
                    print(f"Failed to extract trace: {trace_err}")
                    import traceback

                    traceback.print_exc()
            raise

    # Reorder output data layout: D{C/8}H{C8}W -> CDHW
    ofm_mem_fmt = _npu_layout_to_cdhw(data_buffer.reshape(shape_out))
    # Values are dequantized floats, so log them as floats; "%d" would
    # truncate nearly every entry to 0 or 1.
    ofm_mem_fmt.tofile(
        os.path.join(log_folder, "after_ofm_conv3d.txt"), sep=",", format="%g"
    )
    npu_output = ofm_mem_fmt[np.newaxis, ...]  # add the batch axis back

    # Compare NPU output with golden reference
    # NOTE(review): the "us" label assumes ret.npu_time is in nanoseconds.
    print(f"\nAvg NPU time: {int((npu_time_total / num_iter) / 1000)}us.")
    print(f"Volume size: {depth}x{height}x{width}, Channels: {ci}→{co}")

    # 16 quantization steps of absolute tolerance: allows for rounding and
    # border-handling differences in the 3-D accumulation.
    if np.allclose(npu_output, golden_output, rtol=0, atol=16 * int8_scale):
        print("\nPASS!\n")
        sys.exit(0)

    max_diff = np.max(np.abs(npu_output - golden_output))
    print(f"\nFailed. Max difference: {max_diff}\n")
    sys.exit(-1)
| 304 | + |
| 305 | + |
if __name__ == "__main__":
    # Start from the shared example parser. main() also reads opts.xclbin,
    # opts.instr, opts.kernel and opts.trace_size, which are not added here —
    # presumably create_default_argparser() supplies them; verify if it changes.
    p = test_utils.create_default_argparser()

    # (short flag, long flag, dest, help) for each problem dimension.
    # All default to 8 and are parsed as int, so a malformed value is reported
    # by argparse instead of failing later inside main().
    dimension_args = (
        ("-d", "--depth", "depth", "Depth of 3D convolution volume"),
        ("-ht", "--height", "height", "Height of 3D convolution volume"),
        ("-wd", "--width", "width", "Width of 3D convolution volume"),
        ("-ic", "--in_channels", "in_channels", "Number of input channels"),
        ("-oc", "--out_channels", "out_channels", "Number of output channels"),
    )
    for short_flag, long_flag, dest, help_text in dimension_args:
        p.add_argument(
            short_flag,
            long_flag,
            dest=dest,
            default=8,
            type=int,
            help=help_text,
        )

    opts = p.parse_args(sys.argv[1:])
    main(opts)
0 commit comments