Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 37 additions & 16 deletions autotune/core/tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,8 @@ def __init__(self, par_axis: str, tile_sizes: Dict[str, int], tile_coordinates:
f"Do not have exactly the same axes."
)
self.tile_coordinates = tile_coordinates
self.multibuffering = multibuffering
self.tensor = None

def load(self, source: HBMTensor) -> None:
"""Load data from HBM tensor into SBUF tiles with automatic padding.
Expand All @@ -131,6 +133,9 @@ def load(self, source: HBMTensor) -> None:
self.max_par_size = source.sizes[self.par_axis]
self.max_free_size = source.sizes[self.free_axis]

if self.tensor is None:
self.init_as_zero(dtype=source.tensor.dtype)

self.init_as_zero(dtype=source.tensor.dtype)
par_indices = nl.arange(self.tile_sizes[self.par_axis])[:, None]
free_indices = nl.arange(self.tile_sizes[self.free_axis])[None, :]
Expand All @@ -144,9 +149,9 @@ def load(self, source: HBMTensor) -> None:
for free_tile_id in nl.affine_range(self.tile_coordinates[self.free_axis]["num_tiles"]):
free_start = (free_tile_offset + free_tile_id) * self.tile_sizes[self.free_axis]
free_mask = free_start + free_indices < self.max_free_size
self.tensor[par_indices, par_tile_id, free_tile_id, free_indices] = nl.load(
source.tensor[par_start + par_indices, free_start + free_indices], mask=par_mask & free_mask
)
self.tensor[par_indices, block_id % self.multibuffering, par_tile_id, free_tile_id, free_indices] = nl.load(
source.tensor[par_start + par_indices, free_start + free_indices], mask=par_mask & free_mask
)

def init_as_zero(self, dtype):
"""
Expand All @@ -157,12 +162,21 @@ def init_as_zero(self, dtype):
num_tiles (Dict[str, int]): _description_
dtype (_type_): _description_
"""
tensor_shape = (
self.tile_sizes[self.par_axis],
self.tile_coordinates[self.par_axis]["num_tiles"],
self.tile_coordinates[self.free_axis]["num_tiles"],
self.tile_sizes[self.free_axis],
)
if self.multibuffering:
tensor_shape = (
nl.par_dim(self.tile_sizes[self.par_axis]),
self.multibuffering,
self.tile_coordinates[self.par_axis]["num_tiles"],
self.tile_coordinates[self.free_axis]["num_tiles"],
self.tile_sizes[self.free_axis],
)
else:
tensor_shape = (
self.tile_sizes[self.par_axis],
self.tile_coordinates[self.par_axis]["num_tiles"],
self.tile_coordinates[self.free_axis]["num_tiles"],
self.tile_sizes[self.free_axis],
)
self.tensor = nl.zeros(tensor_shape, dtype=dtype, buffer=nl.sbuf)

def dump(self):
Expand All @@ -186,7 +200,7 @@ def dump(self):
)
return result

def tile_transpose(self):
def tile_transpose(self, block_id):
"""Transpose tensor tile-by-tile in place.

Performs transpose operation on each tile,
Expand All @@ -199,7 +213,7 @@ def tile_transpose(self):
tileT_dtype = np.float32

idx_transp = nl.mgrid[0:pmax, 0:pmax]
par_tile_size, num_par_tiles, num_free_tiles, free_tile_size = self.tensor.shape
multibuffer, par_tile_size, num_par_tiles, num_free_tiles, free_tile_size = self.tensor.shape
num_par_transp_tiles = math.ceil(par_tile_size / pmax)
num_free_transp_tiles = math.ceil(free_tile_size / pmax)
par_tile_offset = self.tile_coordinates[self.par_axis]["start_tile_index"]
Expand All @@ -219,13 +233,13 @@ def tile_transpose(self):

tileT = nl.ndarray((nl.par_dim(pmax), pmax), dtype=tileT_dtype, buffer=nl.psum)
tileT[idx_transp.p, idx_transp.x] = nisa.nc_transpose(
self.tensor[par_indices, par_tile_id, free_tile_id, free_indices], mask=mask
self.tensor[par_indices, block_id % self.multibuffering, par_tile_id, free_tile_id, free_indices], mask=mask
)
self.tensor[par_indices, par_tile_id, free_tile_id, free_indices] = nl.copy(
self.tensor[par_indices, block_id % self.multibuffering, par_tile_id, free_tile_id, free_indices] = nl.copy(
tileT, dtype=self.tensor.dtype
)

def read_tile(self, tile_indices: Dict[str, int]):
def read_tile(self, tile_indices: Dict[str, int], block_id=0):
"""Extract a specific tile from the tensor using global tile indices.

Args:
Expand All @@ -234,7 +248,10 @@ def read_tile(self, tile_indices: Dict[str, int]):
Returns:
The requested tile as a tensor
"""
par_tile_size, par_num_tiles, free_num_tiles, free_tile_size = self.tensor.shape
if self.multibuffering:
par_tile_size, multibuff, par_num_tiles, free_num_tiles, free_tile_size = self.tensor.shape
else:
par_tile_size, par_num_tiles, free_num_tiles, free_tile_size = self.tensor.shape

# Convert global indices to local indices
par_tile_index = tile_indices[self.par_axis] - self.tile_coordinates[self.par_axis]["start_tile_index"]
Expand All @@ -252,6 +269,10 @@ def read_tile(self, tile_indices: Dict[str, int]):

idx_tile = nl.mgrid[0:par_tile_size, 0:free_tile_size]
tile = self.tensor[idx_tile.p, par_tile_index, free_tile_index, idx_tile.x]
if self.multibuffering:
tile = self.tensor[idx_tile.p, block_id % self.multibuffering, par_tile_index, free_tile_index, idx_tile.x]
else:
tile = self.tensor[idx_tile.p, par_tile_index, free_tile_index, idx_tile.x]
return tile

def save_to_hbm(self, result):
Expand Down Expand Up @@ -299,7 +320,7 @@ def save_to_hbm(self, result):

# Read tile using global indices
tile_indices = {self.par_axis: global_row_tile, self.free_axis: global_column_tile}
tile_data = self.read_tile(tile_indices)
tile_data = self.read_tile(tile_indices, block_id=0)

# Store tile to result tensor
nl.store(
Expand Down
2 changes: 2 additions & 0 deletions autotune/gemm/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,6 +186,8 @@ def __init__(
self.op_positions["result"] = self.loop_order["K"]
self.op_positions["save"] = self.loop_order["K"]
self.op_positions["x_op"] = self.loop_order["K"]
self.lhs_rel_position = lhs_position
self.rhs_rel_position = rhs_position

def _parse_absolute_position(self, relative_position: int, axes: Tuple[str, ...]) -> int:
"""
Expand Down
Loading