|
1 | | -import neuronxcc.nki.language as nl |
2 | | -import numpy as np |
| 1 | +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. |
| 2 | +# SPDX-License-Identifier: Apache-2.0 |
3 | 3 |
|
4 | | -from autotune.core.metrics import check_correctness |
5 | | -from autotune.core.tensor import SBUFTensor |
6 | | -from autotune.typing import INPUT_TENSORS_DTYPE, KERNEL_KWARGS_DTYPE, OUTPUT_TENSORS_DTYPE |
7 | | - |
8 | | - |
9 | | -class GEMMCorrectness: |
10 | | - def __init__(self, transposed_lhs: bool) -> None: |
11 | | - self.transposed_lhs = transposed_lhs |
12 | | - |
13 | | - def __call__( |
14 | | - self, |
15 | | - input_tensors: INPUT_TENSORS_DTYPE, |
16 | | - kernel_kwargs: KERNEL_KWARGS_DTYPE, |
17 | | - nki_out_tensors: OUTPUT_TENSORS_DTYPE, |
18 | | - ): |
19 | | - data_type = np.float32 |
20 | | - atol, rtol = 1e-5, 1e-2 |
21 | | - lhs, rhs = input_tensors |
22 | | - if self.transposed_lhs: |
23 | | - golden = nl.static_cast(lhsT_rhs_gemm_np(lhs, rhs), data_type) |
24 | | - else: |
25 | | - golden = nl.static_cast(lhs_rhs_gemm_np(lhs, rhs), data_type) |
26 | | - nki_out_tensor = nl.static_cast(nki_out_tensors[0], data_type) |
27 | | - |
28 | | - # Use the centralized check_correctness function from metrics module |
29 | | - check_correctness(golden, nki_out_tensor, atol, rtol) |
30 | | - |
31 | | - |
32 | | -def max_nki_index(index1, index2): |
33 | | - if index1 == 0: |
34 | | - max_index = index2 |
35 | | - elif index2 == 0: |
36 | | - max_index = index1 |
37 | | - else: |
38 | | - max_index = max(index1, index2) |
39 | | - return max_index |
40 | | - |
41 | | - |
42 | | -def calculate_tile_overlap(coords1: dict, coords2: dict) -> tuple[int, int]: |
43 | | - """ |
44 | | - Calculate the overlapping tile region between two tile coordinate ranges. |
45 | | -
|
46 | | - Args: |
47 | | - coords1: First tile coordinates dictionary with 'start_tile_index' and 'num_tiles' |
48 | | - coords2: Second tile coordinates dictionary with 'start_tile_index' and 'num_tiles' |
49 | | -
|
50 | | - Returns: |
51 | | - Tuple of (overlap_start, num_overlap_tiles) |
52 | | - """ |
53 | | - start_1 = coords1["start_tile_index"] |
54 | | - start_2 = coords2["start_tile_index"] |
55 | | - overlap_start = max_nki_index(start_1, start_2) |
56 | | - # print(f"start_1 {start_1} start_2 {start_2} --> overlap_start = {overlap_start}.") |
57 | | - |
58 | | - num_tiles_1 = coords1["num_tiles"] |
59 | | - num_tiles_2 = coords2["num_tiles"] |
60 | | - num_overlap_tiles = min(num_tiles_1, num_tiles_2) |
61 | | - # print(f"num_tiles_1 {num_tiles_1} num_tiles_2 {num_tiles_2} --> num_overlap_tiles = {num_overlap_tiles}.") |
62 | | - |
63 | | - return overlap_start, num_overlap_tiles |
64 | | - |
65 | | - |
66 | | -def calculate_tile_overlap_ranges(lhs_tiles: SBUFTensor, rhs_tiles: SBUFTensor, result_tiles: SBUFTensor) -> dict: |
67 | | - """ |
68 | | - Calculate the overlapping tile ranges for matrix multiplication. |
69 | | -
|
70 | | - Args: |
71 | | - lhs_tiles: Left-hand side SBUF tensor |
72 | | - rhs_tiles: Right-hand side SBUF tensor |
73 | | - result_tiles: Result SBUF tensor |
74 | | -
|
75 | | - Returns: |
76 | | - Dictionary containing: |
77 | | - - num_tiles: (num_M_tiles, num_N_tiles, num_K_tiles) |
78 | | - - global_starts: Dictionary with the global start tile index for each dimension |
79 | | - - M: M_start |
80 | | - - N: N_start |
81 | | - - K: K_start |
82 | | - - result_offsets: (M_offset, N_offset) for result tensor local indexing |
83 | | - """ |
84 | | - # Calculate overlapping regions for each dimension (in global coordinates) |
85 | | - K_start, num_K_tiles = calculate_tile_overlap(lhs_tiles.tile_coordinates["K"], rhs_tiles.tile_coordinates["K"]) |
86 | | - M_start, num_M_tiles = calculate_tile_overlap(lhs_tiles.tile_coordinates["M"], result_tiles.tile_coordinates["M"]) |
87 | | - N_start, num_N_tiles = calculate_tile_overlap(rhs_tiles.tile_coordinates["N"], result_tiles.tile_coordinates["N"]) |
88 | | - |
89 | | - # Calculate local offsets for result tensor (still needed for direct tensor access) |
90 | | - result_M_offset = M_start - result_tiles.tile_coordinates["M"]["start_tile_index"] |
91 | | - result_N_offset = N_start - result_tiles.tile_coordinates["N"]["start_tile_index"] |
92 | | - |
93 | | - return { |
94 | | - "num_tiles": (num_M_tiles, num_N_tiles, num_K_tiles), |
95 | | - "global_starts": {"M": M_start, "N": N_start, "K": K_start}, |
96 | | - "result_offsets": (result_M_offset, result_N_offset), |
97 | | - } |
98 | | - |
99 | | - |
100 | | -def lhs_rhs_gemm_np(lhs, rhs): |
101 | | - """ |
102 | | - Calculate the general matrix multiplication (GEMM) between lhs and rhs. |
103 | | -
|
104 | | - Parameters: |
105 | | - ----------- |
106 | | - lhs : numpy.ndarray |
107 | | - Left-hand side matrix or tensor. Can have an extra batch dimension. |
108 | | - rhs : numpy.ndarray |
109 | | - Right-hand side matrix. |
110 | | -
|
111 | | - Returns: |
112 | | - -------- |
113 | | - numpy.ndarray |
114 | | - Result of the matrix multiplication. |
115 | | - """ |
116 | | - return np.matmul(lhs, rhs) |
117 | | - |
118 | | - |
119 | | -def lhsT_rhs_gemm_np(lhsT, rhs): |
120 | | - """ |
121 | | - Calculate the general matrix multiplication (GEMM) between lhsT and rhs. |
122 | | -
|
123 | | - Parameters: |
124 | | - ----------- |
125 | | - lhsT : numpy.ndarray |
126 | | - Transposed left-hand side matrix (2-D) or batched tensor (3-D, leading batch dimension); its last two axes are swapped before the multiplication. |
127 | | - rhs : numpy.ndarray |
128 | | - Right-hand side matrix. |
129 | | -
|
130 | | - Returns: |
131 | | - -------- |
132 | | - numpy.ndarray |
133 | | - Result of the matrix multiplication. |
134 | | - """ |
135 | | - if len(lhsT.shape) == 2: |
136 | | - lhs = np.transpose(lhsT, (1, 0)) |
137 | | - elif len(lhsT.shape) == 3: # Batch dimension exists |
138 | | - lhs = np.transpose(lhsT, (0, 2, 1)) |
139 | | - else: |
140 | | - raise NotImplementedError(f"lhsT shape {lhsT.shape} is not supported in GEMM.") |
141 | | - return np.matmul(lhs, rhs) |
| 4 | +# This module is for non-GEMM matmul functionality. |
| 5 | +# GEMM-specific functionality has been moved to the autotune.gemm package. |
0 commit comments