Skip to content

Commit 4886c3c

Browse files
committed
refactor gemm backend
1 parent 5e9e274 commit 4886c3c

File tree

7 files changed

+202
-153
lines changed

7 files changed

+202
-153
lines changed

autotune/gemm/__init__.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
# SPDX-License-Identifier: Apache-2.0
3+
4+
"""
5+
GEMM (General Matrix Multiplication) package for NKI autotune.
6+
7+
This package contains all GEMM-related functionality including:
8+
- Kernel implementations
9+
- Configuration management
10+
- Validation and correctness checking
11+
- Utility functions
12+
"""
13+
14+
from autotune.gemm.config import GEMMConfig, generate_gemm_configs
15+
from autotune.gemm.kernels import MetaGEMM, lhs_rhs_meta_gemm, lhsT_rhs_meta_gemm
16+
from autotune.gemm.utils import calculate_tile_overlap_ranges
17+
from autotune.gemm.validation import GEMMCorrectness
18+
19+
# Public names re-exported at the package root; grouped to mirror the
# submodule each name comes from (config, kernels, validation, utils).
__all__ = [
    # Configuration
    "GEMMConfig",
    "generate_gemm_configs",
    # Kernels
    "lhsT_rhs_meta_gemm",
    "lhs_rhs_meta_gemm",
    "MetaGEMM",
    # Validation
    "GEMMCorrectness",
    # Utils
    "calculate_tile_overlap_ranges",
]
Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,6 @@
1+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
# SPDX-License-Identifier: Apache-2.0
3+
14
import math
25
from itertools import permutations, product
36
from typing import Dict, List, Tuple
@@ -62,7 +65,7 @@ def generate_configs(**kwargs) -> List[Dict]:
6265
return configs
6366

6467

65-
def _generate_blocks_for_axis(axis: str, size: int, tile_size: int) -> List[Dict[str, int]]:
68+
def _generate_blocks_for_axis(size: int, tile_size: int) -> List[Dict[str, int]]:
6669
"""
6770
Generate valid block configurations for tiling an axis.
6871
@@ -276,9 +279,9 @@ def generate_gemm_configs(M: int, N: int, K: int) -> List[Dict]:
276279
TILE_N = nl.tile_size.gemm_moving_fmax # 512
277280
TILE_K = nl.tile_size.pmax # 128
278281

279-
m_configs = _generate_blocks_for_axis("M", M, TILE_M)
280-
n_configs = _generate_blocks_for_axis("N", N, TILE_N)
281-
k_configs = _generate_blocks_for_axis("K", K, TILE_K)
282+
m_configs = _generate_blocks_for_axis(M, TILE_M)
283+
n_configs = _generate_blocks_for_axis(N, TILE_N)
284+
k_configs = _generate_blocks_for_axis(K, TILE_K)
282285
loop_orders = ["".join(loop_order) for loop_order in permutations("MNK")]
283286
lhs_positions = [0, 1, 2]
284287
rhs_positions = [0, 1, 2]
Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,16 @@
11
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
22
# SPDX-License-Identifier: Apache-2.0
33

4-
54
from typing import Any, Dict
65

76
import neuronxcc.nki as nki
87
import neuronxcc.nki.isa as nisa
98
import neuronxcc.nki.language as nl
109
from neuronxcc.nki.typing import tensor
1110

12-
from autotune.core.gemm_config import GEMMConfig
1311
from autotune.core.tensor import HBMTensor, SBUFTensor, TileCoordinates
14-
from autotune.modules.matmul import calculate_tile_overlap_ranges
12+
from autotune.gemm.config import GEMMConfig
13+
from autotune.gemm.utils import calculate_tile_overlap_ranges
1514

1615

1716
class MetaGEMM:

autotune/gemm/utils.py

Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,72 @@
1+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
# SPDX-License-Identifier: Apache-2.0
3+
4+
from autotune.core.tensor import SBUFTensor
5+
6+
7+
def max_nki_index(index1, index2):
    """Return the larger of two tile indices, short-circuiting on 0.

    NOTE(review): an operand equal to 0 immediately yields the other
    operand — presumably 0 denotes an unset/base NKI index for which
    plain ``max`` is not meaningful; confirm before relying on this
    with negative indices.
    """
    if index1 == 0:
        return index2
    if index2 == 0:
        return index1
    return max(index1, index2)
15+
16+
17+
def calculate_tile_overlap(coords1: dict, coords2: dict) -> tuple[int, int]:
    """
    Compute the overlapping tile window of two tile-coordinate ranges.

    Args:
        coords1: Tile coordinates with 'start_tile_index' and 'num_tiles'.
        coords2: Tile coordinates with 'start_tile_index' and 'num_tiles'.

    Returns:
        Tuple of (overlap_start, num_overlap_tiles).
    """
    # Left edge of the overlap is the later start; max_nki_index treats a
    # start of 0 as "unset" and returns the other operand.
    overlap_start = max_nki_index(coords1["start_tile_index"], coords2["start_tile_index"])
    # The shorter range bounds how many tiles the two windows can share.
    num_overlap_tiles = min(coords1["num_tiles"], coords2["num_tiles"])
    return overlap_start, num_overlap_tiles
39+
40+
41+
def calculate_tile_overlap_ranges(lhs_tiles: SBUFTensor, rhs_tiles: SBUFTensor, result_tiles: SBUFTensor) -> dict:
    """
    Compute the overlapping tile ranges for a tiled matrix multiplication.

    Args:
        lhs_tiles: Left-hand side SBUF tensor.
        rhs_tiles: Right-hand side SBUF tensor.
        result_tiles: Result SBUF tensor.

    Returns:
        Dictionary containing:
            - num_tiles: (num_M_tiles, num_N_tiles, num_K_tiles)
            - global_starts: global start tile index per dimension
                - M: M_start
                - N: N_start
                - K: K_start
            - result_offsets: (M_offset, N_offset) local offsets into the
              result tensor for direct tensor access
    """
    lhs_coords = lhs_tiles.tile_coordinates
    rhs_coords = rhs_tiles.tile_coordinates
    out_coords = result_tiles.tile_coordinates

    # Per-dimension overlaps in global tile coordinates: M is shared by
    # lhs and the result, N by rhs and the result, K by the two operands.
    M_start, num_M_tiles = calculate_tile_overlap(lhs_coords["M"], out_coords["M"])
    N_start, num_N_tiles = calculate_tile_overlap(rhs_coords["N"], out_coords["N"])
    K_start, num_K_tiles = calculate_tile_overlap(lhs_coords["K"], rhs_coords["K"])

    # Translate the global M/N starts into the result tensor's local frame.
    result_offsets = (
        M_start - out_coords["M"]["start_tile_index"],
        N_start - out_coords["N"]["start_tile_index"],
    )

    return {
        "num_tiles": (num_M_tiles, num_N_tiles, num_K_tiles),
        "global_starts": {"M": M_start, "N": N_start, "K": K_start},
        "result_offsets": result_offsets,
    }

autotune/gemm/validation.py

Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,75 @@
1+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
# SPDX-License-Identifier: Apache-2.0
3+
4+
import neuronxcc.nki.language as nl
5+
import numpy as np
6+
7+
from autotune.core.metrics import check_correctness
8+
from autotune.typing import INPUT_TENSORS_DTYPE, KERNEL_KWARGS_DTYPE, OUTPUT_TENSORS_DTYPE
9+
10+
11+
class GEMMCorrectness:
    """Callable that validates a NKI GEMM output against a NumPy reference.

    The reference product is ``np.matmul`` (with the lhs de-transposed
    first when ``transposed_lhs`` is set), both sides cast to float32
    before comparison with fixed tolerances.
    """

    def __init__(self, transposed_lhs: bool) -> None:
        # Whether the kernel receives lhs in transposed (contraction-first) layout.
        self.transposed_lhs = transposed_lhs

    def __call__(
        self,
        input_tensors: INPUT_TENSORS_DTYPE,
        kernel_kwargs: KERNEL_KWARGS_DTYPE,
        nki_out_tensors: OUTPUT_TENSORS_DTYPE,
    ):
        compare_dtype = np.float32
        abs_tol, rel_tol = 1e-5, 1e-2
        lhs, rhs = input_tensors
        # Pick the matching NumPy reference implementation.
        reference_gemm = lhsT_rhs_gemm_np if self.transposed_lhs else lhs_rhs_gemm_np
        golden = nl.static_cast(reference_gemm(lhs, rhs), compare_dtype)
        actual = nl.static_cast(nki_out_tensors[0], compare_dtype)

        # Delegate the tolerance comparison to the shared metrics helper.
        check_correctness(golden, actual, abs_tol, rel_tol)
32+
33+
34+
def lhs_rhs_gemm_np(lhs, rhs):
    """
    NumPy reference GEMM: the matrix product of ``lhs`` and ``rhs``.

    Parameters:
    -----------
    lhs : numpy.ndarray
        Left-hand side matrix or tensor. May carry an extra leading
        batch dimension, which ``matmul`` broadcasts over.
    rhs : numpy.ndarray
        Right-hand side matrix.

    Returns:
    --------
    numpy.ndarray
        Result of the matrix multiplication.
    """
    # Plain matmul is the golden reference the NKI kernels are checked against.
    return np.matmul(lhs, rhs)
51+
52+
53+
def lhsT_rhs_gemm_np(lhsT, rhs):
    """
    NumPy reference GEMM with a transposed left operand: ``lhsT.T @ rhs``.

    Generalized from the original 2-D/3-D-only implementation: only the
    trailing two axes of ``lhsT`` are swapped, so 2-D (K, M), 3-D
    (B, K, M), and higher-rank batched inputs are all handled uniformly.
    Results for 2-D and 3-D inputs are unchanged.

    Parameters:
    -----------
    lhsT : numpy.ndarray
        Transposed left-hand side: the contraction axis leads the last
        two dimensions, optionally preceded by batch dimensions.
    rhs : numpy.ndarray
        Right-hand side matrix.

    Returns:
    --------
    numpy.ndarray
        Result of the matrix multiplication.

    Raises:
    -------
    NotImplementedError
        If ``lhsT`` has fewer than 2 dimensions.
    """
    if lhsT.ndim < 2:
        raise NotImplementedError(f"lhsT shape {lhsT.shape} is not supported in GEMM.")
    # swapaxes(-1, -2) replaces the per-rank transpose tuples and covers
    # every batched layout in one expression.
    return np.matmul(np.swapaxes(lhsT, -1, -2), rhs)

autotune/modules/matmul.py

Lines changed: 4 additions & 140 deletions
Original file line numberDiff line numberDiff line change
@@ -1,141 +1,5 @@
1-
import neuronxcc.nki.language as nl
2-
import numpy as np
1+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
# SPDX-License-Identifier: Apache-2.0
33

4-
from autotune.core.metrics import check_correctness
5-
from autotune.core.tensor import SBUFTensor
6-
from autotune.typing import INPUT_TENSORS_DTYPE, KERNEL_KWARGS_DTYPE, OUTPUT_TENSORS_DTYPE
7-
8-
9-
class GEMMCorrectness:
10-
def __init__(self, transposed_lhs: bool) -> None:
11-
self.transposed_lhs = transposed_lhs
12-
13-
def __call__(
14-
self,
15-
input_tensors: INPUT_TENSORS_DTYPE,
16-
kernel_kwargs: KERNEL_KWARGS_DTYPE,
17-
nki_out_tensors: OUTPUT_TENSORS_DTYPE,
18-
):
19-
data_type = np.float32
20-
atol, rtol = 1e-5, 1e-2
21-
lhs, rhs = input_tensors
22-
if self.transposed_lhs:
23-
golden = nl.static_cast(lhsT_rhs_gemm_np(lhs, rhs), data_type)
24-
else:
25-
golden = nl.static_cast(lhs_rhs_gemm_np(lhs, rhs), data_type)
26-
nki_out_tensor = nl.static_cast(nki_out_tensors[0], data_type)
27-
28-
# Use the centralized check_correctness function from metrics module
29-
check_correctness(golden, nki_out_tensor, atol, rtol)
30-
31-
32-
def max_nki_index(index1, index2):
33-
if index1 == 0:
34-
max_index = index2
35-
elif index2 == 0:
36-
max_index = index1
37-
else:
38-
max_index = max(index1, index2)
39-
return max_index
40-
41-
42-
def calculate_tile_overlap(coords1: dict, coords2: dict) -> tuple[int, int]:
43-
"""
44-
Calculate the overlapping tile region between two tile coordinate ranges.
45-
46-
Args:
47-
coords1: First tile coordinates dictionary with 'start_tile_index' and 'num_tiles'
48-
coords2: Second tile coordinates dictionary with 'start_tile_index' and 'num_tiles'
49-
50-
Returns:
51-
Tuple of (overlap_start, num_overlap_tiles)
52-
"""
53-
start_1 = coords1["start_tile_index"]
54-
start_2 = coords2["start_tile_index"]
55-
overlap_start = max_nki_index(start_1, start_2)
56-
# print(f"start_1 {start_1} start_2 {start_2} --> overlap_start = {overlap_start}.")
57-
58-
num_tiles_1 = coords1["num_tiles"]
59-
num_tiles_2 = coords2["num_tiles"]
60-
num_overlap_tiles = min(num_tiles_1, num_tiles_2)
61-
# print(f"num_tiles_1 {num_tiles_1} num_tiles_2 {num_tiles_2} --> num_overlap_tiles = {num_overlap_tiles}.")
62-
63-
return overlap_start, num_overlap_tiles
64-
65-
66-
def calculate_tile_overlap_ranges(lhs_tiles: SBUFTensor, rhs_tiles: SBUFTensor, result_tiles: SBUFTensor) -> dict:
67-
"""
68-
Calculate the overlapping tile ranges for matrix multiplication.
69-
70-
Args:
71-
lhs_tiles: Left-hand side SBUF tensor
72-
rhs_tiles: Right-hand side SBUF tensor
73-
result_tiles: Result SBUF tensor
74-
75-
Returns:
76-
Dictionary containing:
77-
- num_tiles: (num_M_tiles, num_N_tiles, num_K_tiles)
78-
- global_starts: Dictionary with global tile ranges for each dimension
79-
- M: M_start
80-
- N: N_start
81-
- K: K_start
82-
- result_offsets: (M_offset, N_offset) for result tensor local indexing
83-
"""
84-
# Calculate overlapping regions for each dimension (in global coordinates)
85-
K_start, num_K_tiles = calculate_tile_overlap(lhs_tiles.tile_coordinates["K"], rhs_tiles.tile_coordinates["K"])
86-
M_start, num_M_tiles = calculate_tile_overlap(lhs_tiles.tile_coordinates["M"], result_tiles.tile_coordinates["M"])
87-
N_start, num_N_tiles = calculate_tile_overlap(rhs_tiles.tile_coordinates["N"], result_tiles.tile_coordinates["N"])
88-
89-
# Calculate local offsets for result tensor (still needed for direct tensor access)
90-
result_M_offset = M_start - result_tiles.tile_coordinates["M"]["start_tile_index"]
91-
result_N_offset = N_start - result_tiles.tile_coordinates["N"]["start_tile_index"]
92-
93-
return {
94-
"num_tiles": (num_M_tiles, num_N_tiles, num_K_tiles),
95-
"global_starts": {"M": M_start, "N": N_start, "K": K_start},
96-
"result_offsets": (result_M_offset, result_N_offset),
97-
}
98-
99-
100-
def lhs_rhs_gemm_np(lhs, rhs):
101-
"""
102-
Calculate the general matrix multiplication (GEMM) between lhs and rhs.
103-
104-
Parameters:
105-
-----------
106-
lhs : numpy.ndarray
107-
Left-hand side matrix or tensor. Can have an extra batch dimension.
108-
rhs : numpy.ndarray
109-
Right-hand side matrix.
110-
111-
Returns:
112-
--------
113-
numpy.ndarray
114-
Result of the matrix multiplication.
115-
"""
116-
return np.matmul(lhs, rhs)
117-
118-
119-
def lhsT_rhs_gemm_np(lhsT, rhs):
120-
"""
121-
Calculate the general matrix multiplication (GEMM) between lhsT and rhs.
122-
123-
Parameters:
124-
-----------
125-
lhs : numpy.ndarray
126-
Left-hand side matrix or tensor. Can have an extra batch dimension.
127-
rhs : numpy.ndarray
128-
Right-hand side matrix.
129-
130-
Returns:
131-
--------
132-
numpy.ndarray
133-
Result of the matrix multiplication.
134-
"""
135-
if len(lhsT.shape) == 2:
136-
lhs = np.transpose(lhsT, (1, 0))
137-
elif len(lhsT.shape) == 3: # Batch dimension exists
138-
lhs = np.transpose(lhsT, (0, 2, 1))
139-
else:
140-
raise NotImplementedError(f"lhsT shape {lhsT.shape} is not supported in GEMM.")
141-
return np.matmul(lhs, rhs)
4+
# This module is for non-GEMM matmul functionality
5+
# GEMM-specific functionality has been moved to autotune.gemm package

0 commit comments

Comments (0)