diff --git a/autotune/core/tensor.py b/autotune/core/tensor.py index bbcb1c0..eb1c711 100644 --- a/autotune/core/tensor.py +++ b/autotune/core/tensor.py @@ -115,6 +115,8 @@ def __init__(self, par_axis: str, tile_sizes: Dict[str, int], tile_coordinates: f"Do not have exactly the same axes." ) self.tile_coordinates = tile_coordinates + self.multibuffering = multibuffering + self.tensor = None - def load(self, source: HBMTensor) -> None: + def load(self, source: HBMTensor, block_id: int = 0) -> None: """Load data from HBM tensor into SBUF tiles with automatic padding. @@ -131,6 +133,8 @@ def load(self, source: HBMTensor) -> None: self.max_par_size = source.sizes[self.par_axis] self.max_free_size = source.sizes[self.free_axis] + if self.tensor is None: + self.init_as_zero(dtype=source.tensor.dtype) - self.init_as_zero(dtype=source.tensor.dtype) par_indices = nl.arange(self.tile_sizes[self.par_axis])[:, None] free_indices = nl.arange(self.tile_sizes[self.free_axis])[None, :] @@ -144,9 +148,9 @@ def load(self, source: HBMTensor) -> None: for free_tile_id in nl.affine_range(self.tile_coordinates[self.free_axis]["num_tiles"]): free_start = (free_tile_offset + free_tile_id) * self.tile_sizes[self.free_axis] free_mask = free_start + free_indices < self.max_free_size - self.tensor[par_indices, par_tile_id, free_tile_id, free_indices] = nl.load( - source.tensor[par_start + par_indices, free_start + free_indices], mask=par_mask & free_mask - ) + self.tensor[par_indices, block_id % self.multibuffering, par_tile_id, free_tile_id, free_indices] = nl.load( + source.tensor[par_start + par_indices, free_start + free_indices], mask=par_mask & free_mask + ) def init_as_zero(self, dtype): """ @@ -157,12 +161,21 @@ def init_as_zero(self, dtype): num_tiles (Dict[str, int]): _description_ dtype (_type_): _description_ """ - tensor_shape = ( - self.tile_sizes[self.par_axis], - self.tile_coordinates[self.par_axis]["num_tiles"], - self.tile_coordinates[self.free_axis]["num_tiles"], - self.tile_sizes[self.free_axis], - ) + if self.multibuffering: + 
tensor_shape = ( + nl.par_dim(self.tile_sizes[self.par_axis]), + self.multibuffering, + self.tile_coordinates[self.par_axis]["num_tiles"], + self.tile_coordinates[self.free_axis]["num_tiles"], + self.tile_sizes[self.free_axis], + ) + else: + tensor_shape = ( + self.tile_sizes[self.par_axis], + self.tile_coordinates[self.par_axis]["num_tiles"], + self.tile_coordinates[self.free_axis]["num_tiles"], + self.tile_sizes[self.free_axis], + ) self.tensor = nl.zeros(tensor_shape, dtype=dtype, buffer=nl.sbuf) def dump(self): @@ -186,7 +200,7 @@ def dump(self): ) return result - def tile_transpose(self): + def tile_transpose(self, block_id): """Transpose tensor tile-by-tile in place. Performs transpose operation on each tile, @@ -199,7 +213,7 @@ def tile_transpose(self): tileT_dtype = np.float32 idx_transp = nl.mgrid[0:pmax, 0:pmax] - par_tile_size, num_par_tiles, num_free_tiles, free_tile_size = self.tensor.shape + par_tile_size, multibuffer, num_par_tiles, num_free_tiles, free_tile_size = self.tensor.shape num_par_transp_tiles = math.ceil(par_tile_size / pmax) num_free_transp_tiles = math.ceil(free_tile_size / pmax) par_tile_offset = self.tile_coordinates[self.par_axis]["start_tile_index"] @@ -219,13 +233,13 @@ def tile_transpose(self): tileT = nl.ndarray((nl.par_dim(pmax), pmax), dtype=tileT_dtype, buffer=nl.psum) tileT[idx_transp.p, idx_transp.x] = nisa.nc_transpose( - self.tensor[par_indices, par_tile_id, free_tile_id, free_indices], mask=mask + self.tensor[par_indices, block_id % self.multibuffering, par_tile_id, free_tile_id, free_indices], mask=mask ) - self.tensor[par_indices, par_tile_id, free_tile_id, free_indices] = nl.copy( + self.tensor[par_indices, block_id % self.multibuffering, par_tile_id, free_tile_id, free_indices] = nl.copy( tileT, dtype=self.tensor.dtype ) - def read_tile(self, tile_indices: Dict[str, int]): + def read_tile(self, tile_indices: Dict[str, int], block_id=0): """Extract a specific tile from the tensor using global tile indices. 
Args: @@ -234,7 +248,10 @@ def read_tile(self, tile_indices: Dict[str, int]): Returns: The requested tile as a tensor """ - par_tile_size, par_num_tiles, free_num_tiles, free_tile_size = self.tensor.shape + if self.multibuffering: + par_tile_size, multibuff, par_num_tiles, free_num_tiles, free_tile_size = self.tensor.shape + else: + par_tile_size, par_num_tiles, free_num_tiles, free_tile_size = self.tensor.shape # Convert global indices to local indices par_tile_index = tile_indices[self.par_axis] - self.tile_coordinates[self.par_axis]["start_tile_index"] @@ -252,6 +269,9 @@ idx_tile = nl.mgrid[0:par_tile_size, 0:free_tile_size] - tile = self.tensor[idx_tile.p, par_tile_index, free_tile_index, idx_tile.x] + if self.multibuffering: + tile = self.tensor[idx_tile.p, block_id % self.multibuffering, par_tile_index, free_tile_index, idx_tile.x] + else: + tile = self.tensor[idx_tile.p, par_tile_index, free_tile_index, idx_tile.x] return tile def save_to_hbm(self, result): @@ -299,7 +319,7 @@ def save_to_hbm(self, result): # Read tile using global indices tile_indices = {self.par_axis: global_row_tile, self.free_axis: global_column_tile} - tile_data = self.read_tile(tile_indices) + tile_data = self.read_tile(tile_indices, block_id=0) # Store tile to result tensor nl.store( diff --git a/autotune/gemm/config.py b/autotune/gemm/config.py index 0981c37..9ffb36b 100644 --- a/autotune/gemm/config.py +++ b/autotune/gemm/config.py @@ -186,6 +186,8 @@ def __init__( self.op_positions["result"] = self.loop_order["K"] self.op_positions["save"] = self.loop_order["K"] self.op_positions["x_op"] = self.loop_order["K"] + self.lhs_rel_position = lhs_position + self.rhs_rel_position = rhs_position def _parse_absolute_position(self, relative_position: int, axes: Tuple[str, ...]) -> int: """ diff --git a/autotune/gemm/kernels.py b/autotune/gemm/kernels.py index 8b848e3..c331215 100644 --- a/autotune/gemm/kernels.py +++ 
b/autotune/gemm/kernels.py @@ -12,6 +12,8 @@ from autotune.gemm.config import GEMMConfig from autotune.gemm.utils import calculate_tile_overlap_ranges +MULTIBUFF = True +PSUM_BANKING = True class MetaGEMM: """ @@ -43,6 +45,10 @@ def __init__(self, transposed_lhs: bool, config: GEMMConfig) -> None: self.gemm_config = config self.axes = {"lhs": ("K", "M") if self.transposed_lhs else ("M", "K"), "rhs": ("K", "N"), "result": ("M", "N")} self.loop_ranges = {position: self._get_loop_range(position) for position in range(3)} + self.multibuff = 2 if MULTIBUFF else 1 + self.num_group_acc = 2 + self.lhs_tiles = None + self.rhs_tiles = None def _get_loop_range(self, position: int) -> int: """Check if any tensor operations at position > current will use this axis""" @@ -78,28 +84,78 @@ def __call__(self, lhs: tensor, rhs: tensor) -> Any: self.rhs_hbm = HBMTensor(rhs, axes=("K", "N")) self.result_hbm = nl.ndarray((self.gemm_config.M, self.gemm_config.N), dtype=lhs.dtype, buffer=nl.shared_hbm) loop_vars = {} - self.maybe_init(curr_position=0, loop_vars=loop_vars) - for block_id_0 in nl.affine_range(self.loop_ranges[0]): + + # Determine whether LHS + RHS will be multibuffered + self.lhs_multibuff = self.gemm_config.lhs_rel_position > 0 and self.multibuff > 1 + self.rhs_multibuff = self.gemm_config.rhs_rel_position > 0 and self.multibuff > 1 + + # If multibuffered --> initialize tensor one loop iteration before loading + if self.lhs_multibuff: + self.lhs_multibuff_init_pos = self.gemm_config.op_positions["lhs"] - 1 + if self.rhs_multibuff: + self.rhs_multibuff_init_pos = self.gemm_config.op_positions["rhs"] - 1 + + self.maybe_multibuff_init(curr_position=0, loop_vars=loop_vars) + self.maybe_init_or_load(curr_position=0, loop_vars=loop_vars) + + for block_id_0 in nl.affine_range(self.loop_ranges[0]): + loop_vars[self.gemm_config.loop_order[0]] = block_id_0 - self.maybe_init(curr_position=1, loop_vars=loop_vars) + self.maybe_multibuff_init(curr_position=1, loop_vars=loop_vars) + 
self.maybe_init_or_load(curr_position=1, loop_vars=loop_vars) + for block_id_1 in nl.affine_range(self.loop_ranges[1]): + loop_vars[self.gemm_config.loop_order[1]] = block_id_1 - self.maybe_init(curr_position=2, loop_vars=loop_vars) - for block_id_2 in nl.affine_range(self.loop_ranges[2]): + self.maybe_multibuff_init(curr_position=2, loop_vars=loop_vars) + self.maybe_init_or_load(curr_position=2, loop_vars=loop_vars) + + for block_id_2 in nl.affine_range(self.loop_ranges[2]): + loop_vars[self.gemm_config.loop_order[2]] = block_id_2 - self.maybe_init(curr_position=3, loop_vars=loop_vars) - matmul_tiles( - self.lhs_tiles, self.rhs_tiles, self.result_tiles, tile_transposed_lhs=not self.transposed_lhs - ) + self.maybe_init_or_load(curr_position=3, loop_vars=loop_vars) + + lhs_block_id = loop_vars[self.gemm_config.loop_order[self.lhs_multibuff_init_pos]] if self.lhs_multibuff else 0 + rhs_block_id = loop_vars[self.gemm_config.loop_order[self.rhs_multibuff_init_pos]] if self.rhs_multibuff else 0 + + if MULTIBUFF and PSUM_BANKING: + matmul_tiles_manual_alloc( + self.lhs_tiles, + self.rhs_tiles, + self.result_tiles, + tile_transposed_lhs=not self.transposed_lhs, + lhs_block_id=lhs_block_id, + rhs_block_id=rhs_block_id, + num_group_acc=self.num_group_acc + ) + + else: + matmul_tiles( + self.lhs_tiles, + self.rhs_tiles, + self.result_tiles, + tile_transposed_lhs=not self.transposed_lhs, + lhs_block_id=lhs_block_id, + rhs_block_id=rhs_block_id, + ) + del loop_vars[self.gemm_config.loop_order[2]] - self.maybe_store(curr_position=2) + self.maybe_store(curr_position=2) # written to hbm here (only happens 2x) + del loop_vars[self.gemm_config.loop_order[1]] self.maybe_store(curr_position=1) + del loop_vars[self.gemm_config.loop_order[0]] self.maybe_store(curr_position=0) + return self.result_hbm - def maybe_init(self, curr_position: int, loop_vars: Dict): + def maybe_init_or_load(self, curr_position: int, loop_vars: Dict): + if curr_position == 0: + block_id = 0 + else: + block_id 
= loop_vars[self.gemm_config.loop_order[curr_position-1]] + if self.gemm_config.op_positions["lhs"] == curr_position: lhs_tile_sizes: Dict[str, int] = {} lhs_tile_coordinates = TileCoordinates() @@ -111,13 +167,21 @@ def maybe_init(self, curr_position: int, loop_vars: Dict): else: start_tile_index = 0 num_tiles = getattr(self.gemm_config, f"TILES_IN_{axis}") - lhs_tile_coordinates.add_axis(axis, start_tile_index, num_tiles) - self.lhs_tiles = SBUFTensor( - par_axis=self.axes["lhs"][0], tile_sizes=lhs_tile_sizes, tile_coordinates=lhs_tile_coordinates - ) - self.lhs_tiles.load(source=self.lhs_hbm) + + if self.lhs_tiles: + self.lhs_tiles.tile_coordinates.add_axis(axis, start_tile_index, num_tiles) + else: + lhs_tile_coordinates.add_axis(axis, start_tile_index, num_tiles) + + if not self.lhs_tiles: + self.lhs_tiles = SBUFTensor( + par_axis=self.axes["lhs"][0], tile_sizes=lhs_tile_sizes, tile_coordinates=lhs_tile_coordinates, multibuffering=1, + ) + + self.lhs_tiles.load(source=self.lhs_hbm, block_id=block_id) if not self.transposed_lhs: - self.lhs_tiles.tile_transpose() + self.lhs_tiles.tile_transpose(block_id=block_id) + if self.gemm_config.op_positions["rhs"] == curr_position: rhs_tile_sizes: Dict[str, int] = {} rhs_tile_coordinates = TileCoordinates() @@ -129,11 +193,19 @@ def maybe_init(self, curr_position: int, loop_vars: Dict): else: start_tile_index = 0 num_tiles = getattr(self.gemm_config, f"TILES_IN_{axis}") - rhs_tile_coordinates.add_axis(axis, start_tile_index, num_tiles) - self.rhs_tiles = SBUFTensor( - par_axis=self.axes["rhs"][0], tile_sizes=rhs_tile_sizes, tile_coordinates=rhs_tile_coordinates - ) - self.rhs_tiles.load(source=self.rhs_hbm) + + if self.rhs_tiles: + self.rhs_tiles.tile_coordinates.add_axis(axis, start_tile_index, num_tiles) + else: + rhs_tile_coordinates.add_axis(axis, start_tile_index, num_tiles) + + if not self.rhs_tiles: + self.rhs_tiles = SBUFTensor( + par_axis=self.axes["rhs"][0], tile_sizes=rhs_tile_sizes, 
tile_coordinates=rhs_tile_coordinates, multibuffering=1, + ) + + self.rhs_tiles.load(source=self.rhs_hbm, block_id=block_id) + if self.gemm_config.op_positions["result"] == curr_position: result_tile_sizes = {} result_tile_coordinates = TileCoordinates() @@ -147,16 +219,136 @@ def maybe_init(self, curr_position: int, loop_vars: Dict): num_tiles = getattr(self.gemm_config, f"TILES_IN_{axis}") result_tile_coordinates.add_axis(axis, start_tile_index, num_tiles) self.result_tiles = SBUFTensor( - par_axis=self.axes["result"][0], tile_sizes=result_tile_sizes, tile_coordinates=result_tile_coordinates + par_axis=self.axes["result"][0], tile_sizes=result_tile_sizes, tile_coordinates=result_tile_coordinates, multibuffering=None, ) self.result_tiles.init_as_zero(self.result_hbm.dtype) + def maybe_multibuff_init(self, curr_position, loop_vars): + if self.lhs_multibuff and self.lhs_multibuff_init_pos == curr_position: + lhs_tile_sizes: Dict[str, int] = {} + lhs_tile_coordinates = TileCoordinates() + for axis in self.axes["lhs"]: + lhs_tile_sizes[axis] = getattr(self.gemm_config, f"TILE_{axis}") + if self.gemm_config.op_positions["lhs"] > self.gemm_config.loop_order[axis]: # Set size of tensor to that according to position of load + start_tile_index = 0 # Will be updated during load + num_tiles = getattr(self.gemm_config, f"TILES_PER_BLOCK_{axis}") + else: + start_tile_index = 0 + num_tiles = getattr(self.gemm_config, f"TILES_IN_{axis}") + lhs_tile_coordinates.add_axis(axis, start_tile_index, num_tiles) + + self.lhs_tiles = SBUFTensor( + par_axis=self.axes["lhs"][0], tile_sizes=lhs_tile_sizes, tile_coordinates=lhs_tile_coordinates, multibuffering=self.multibuff, + ) + self.lhs_tiles.init_as_zero(dtype=self.lhs_hbm.tensor.dtype) + + if self.rhs_multibuff and self.rhs_multibuff_init_pos == curr_position: + rhs_tile_sizes: Dict[str, int] = {} + rhs_tile_coordinates = TileCoordinates() + for axis in self.axes["rhs"]: + rhs_tile_sizes[axis] = getattr(self.gemm_config, f"TILE_{axis}") 
+ if self.gemm_config.op_positions["rhs"] > self.gemm_config.loop_order[axis]: + start_tile_index = 0 # Will be updated during load + num_tiles = getattr(self.gemm_config, f"TILES_PER_BLOCK_{axis}") + else: + start_tile_index = 0 + num_tiles = getattr(self.gemm_config, f"TILES_IN_{axis}") + rhs_tile_coordinates.add_axis(axis, start_tile_index, num_tiles) + + self.rhs_tiles = SBUFTensor( + par_axis=self.axes["rhs"][0], tile_sizes=rhs_tile_sizes, tile_coordinates=rhs_tile_coordinates, multibuffering=self.multibuff, + ) + self.rhs_tiles.init_as_zero(dtype=self.rhs_hbm.tensor.dtype) + def maybe_store(self, curr_position: int): if self.gemm_config.op_positions["save"] == curr_position: self.result_tiles.save_to_hbm(self.result_hbm) +def matmul_tiles_manual_alloc(lhs_tiles: SBUFTensor, rhs_tiles: SBUFTensor, result_tiles: SBUFTensor, tile_transposed_lhs: bool, lhs_block_id: int, rhs_block_id: int, num_group_acc: int): + """ + Perform tiled matrix multiplication between SBUF tiles. + + Computes result_tiles += matmul(lhs_tiles, rhs_tiles) for the overlapping regions within each block. + + Args: + lhs_tiles: Left-hand side matrix tiles stored in SBUF memory + rhs_tiles: Right-hand side matrix tiles stored in SBUF memory + result_tiles: Output matrix tiles stored in SBUF memory where results + will be accumulated + tile_transposed_lhs: (bool) - Whether lhs_tiles is transposed at the tile level. + Note that this is not the same as lhsT_tiles. 
+ """ + if tile_transposed_lhs: + TILE_M, _, _, _, TILE_K = lhs_tiles.tensor.shape + else: + TILE_K, _, _, _, TILE_M = lhs_tiles.tensor.shape + _TILE_K, _, _, _, TILE_N = rhs_tiles.tensor.shape + _TILE_M, _, _, _TILE_N = result_tiles.tensor.shape + assert ( + TILE_K == _TILE_K + ), f"lhs_tiles {lhs_tiles.tensor.shape} TILE_K mismatch with rhs_tiles {rhs_tiles.tensor.shape}" + assert ( + TILE_M == _TILE_M and TILE_N == _TILE_N + ), f"result_tiles {result_tiles.tensor.shape} shape mismatch with lhs_tiles {lhs_tiles.tensor.shape} @ rhs_tiles {rhs_tiles.tensor.shape}" + + # Calculate overlapping regions using the helper function + overlap_info = calculate_tile_overlap_ranges(lhs_tiles, rhs_tiles, result_tiles) + num_M_tiles, num_N_tiles, num_K_tiles = overlap_info["num_tiles"] + M_start = overlap_info["global_starts"]["M"] + N_start = overlap_info["global_starts"]["N"] + K_start = overlap_info["global_starts"]["K"] + result_M_offset, result_N_offset = overlap_info["result_offsets"] + + # Iterate over tiles using nl.affine_range for hardware optimization + idx_res = nl.mgrid[0:TILE_M, 0:TILE_N] + for tile_idx_N in nl.affine_range(num_N_tiles): + global_N_tile = N_start + tile_idx_N + + # Manual SplitAccGrp (multibuffering of PSUM result_tile) + num_M_groups = num_M_tiles // num_group_acc + num_M_tiles_in_group = num_group_acc + num_M_tiles_leftover = num_M_tiles % num_group_acc + + for m_group in nl.affine_range(num_M_groups): + result_tile = nl.zeros((num_group_acc, nl.par_dim(TILE_M), TILE_N), dtype=nl.float32, buffer=nl.psum) + + for m_tile in nl.affine_range(num_M_tiles_in_group): + tile_idx_M = (num_M_tiles_in_group * m_group) + m_tile + global_M_tile = M_start + tile_idx_M + + for tile_idx_K in nl.affine_range(num_K_tiles): + + global_K_tile = K_start + tile_idx_K + lhs_tile = lhs_tiles.read_tile(tile_indices={"M": global_M_tile, "K": global_K_tile}, block_id=lhs_block_id) + rhs_tile = rhs_tiles.read_tile(tile_indices={"K": global_K_tile, "N": global_N_tile}, 
block_id=rhs_block_id) + result_tile[m_tile, idx_res.p, idx_res.x] += nisa.nc_matmul(lhs_tile, rhs_tile) + + for m_tile in nl.affine_range(num_M_tiles_in_group): + tile_idx_M = (num_M_tiles_in_group * m_group) + m_tile result_tiles.tensor[ + idx_res.p, result_M_offset + tile_idx_M, result_N_offset + tile_idx_N, idx_res.x + ] += result_tile[m_tile, idx_res.p, idx_res.x] + + if num_M_tiles_leftover > 0: + result_tile = nl.zeros((num_M_tiles_leftover, nl.par_dim(TILE_M), TILE_N), dtype=nl.float32, buffer=nl.psum) + + for m_tile in nl.affine_range(num_M_tiles_leftover): + tile_idx_M = (num_M_groups * num_M_tiles_in_group) + m_tile + global_M_tile = M_start + tile_idx_M + + for tile_idx_K in nl.affine_range(num_K_tiles): + global_K_tile = K_start + tile_idx_K + lhs_tile = lhs_tiles.read_tile(tile_indices={"M": global_M_tile, "K": global_K_tile}, block_id=lhs_block_id) + rhs_tile = rhs_tiles.read_tile(tile_indices={"K": global_K_tile, "N": global_N_tile}, block_id=rhs_block_id) + result_tile[m_tile, idx_res.p, idx_res.x] += nisa.nc_matmul(lhs_tile, rhs_tile) + + for m_tile in nl.affine_range(num_M_tiles_leftover): + tile_idx_M = (num_M_groups * num_M_tiles_in_group) + m_tile result_tiles.tensor[ + idx_res.p, result_M_offset + tile_idx_M, result_N_offset + tile_idx_N, idx_res.x + ] += result_tile[m_tile, idx_res.p, idx_res.x] -def matmul_tiles(lhs_tiles: SBUFTensor, rhs_tiles: SBUFTensor, result_tiles: SBUFTensor, tile_transposed_lhs: bool): +def matmul_tiles(lhs_tiles: SBUFTensor, rhs_tiles: SBUFTensor, result_tiles: SBUFTensor, tile_transposed_lhs: bool, lhs_block_id: int, rhs_block_id: int): """ Perform tiled matrix multiplication between SBUF tiles. 
@@ -204,8 +396,8 @@ def matmul_tiles(lhs_tiles: SBUFTensor, rhs_tiles: SBUFTensor, result_tiles: SBU for tile_idx_K in nl.affine_range(num_K_tiles): global_K_tile = K_start + tile_idx_K # Read tiles using global indices (the read_tile method now handles conversion) - lhs_tile = lhs_tiles.read_tile(tile_indices={"M": global_M_tile, "K": global_K_tile}) - rhs_tile = rhs_tiles.read_tile(tile_indices={"K": global_K_tile, "N": global_N_tile}) + lhs_tile = lhs_tiles.read_tile(tile_indices={"M": global_M_tile, "K": global_K_tile}, block_id=lhs_block_id) + rhs_tile = rhs_tiles.read_tile(tile_indices={"K": global_K_tile, "N": global_N_tile}, block_id=rhs_block_id) result_tile += nisa.nc_matmul(lhs_tile, rhs_tile) # Store result using local indices for direct tensor access # FIXME: if K=1, just copy not add diff --git a/examples/gemm.py b/examples/gemm.py index 79a6a36..f1264ef 100644 --- a/examples/gemm.py +++ b/examples/gemm.py @@ -37,7 +37,7 @@ def collect_job_configs(shapes: List[Tuple[int, int, int]], transposed_lhs: bool examples_dir = os.path.dirname(current_file) # /path/to/nki-autotune/examples/ project_root = os.path.dirname(examples_dir) # /path/to/nki-autotune/ - data_type = "float32" + data_type = "bf16" if data_type == "float32": data_type = np.float32 postprocessing = GEMMCorrectness(transposed_lhs=transposed_lhs)