# SPDX-FileCopyrightText: (c) 2026 Tenstorrent AI ULC
#
# SPDX-License-Identifier: Apache-2.0
import torch
import ttl
import ttnn


def from_torch(tensor: torch.Tensor) -> ttnn.Tensor:
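    """Upload a host torch.Tensor to the device as a tiled bfloat16 ttnn.Tensor in DRAM."""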
    return ttnn.from_torch(
        tensor,
        dtype=ttnn.bfloat16,
        layout=ttnn.TILE_LAYOUT,
        device=device,
        memory_config=ttnn.DRAM_MEMORY_CONFIG,
    )


# A tile is 32x32 elements; each core moves and computes blocks of
# GRANULARITY x GRANULARITY tiles at a time.
TILE_SIZE = 32
GRANULARITY = 4


@ttl.kernel(grid=(4, 4))
def _demo_kernel(a: ttnn.Tensor, b: ttnn.Tensor, c: ttnn.Tensor, y: ttnn.Tensor):
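    """Compute y = a * b + c block by block across a 4x4 core grid.

    One compute routine and two data-movement routines cooperate through
    double-buffered circular buffers.
    """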
    row_tiles_per_block = GRANULARITY
    col_tiles_per_block = GRANULARITY

    grid_cols, grid_rows = ttl.grid_size(dims=2)

    # Blocks each core handles per dimension: rows are split across grid rows,
    # columns across grid columns.
    rows_per_core = a.shape[0] // TILE_SIZE // row_tiles_per_block // grid_rows
    cols_per_core = a.shape[1] // TILE_SIZE // col_tiles_per_block // grid_cols

    # Double-buffered (buffer_factor=2) circular buffers holding one block
    # each, so data movement can overlap with compute.
    a_cb = ttl.make_circular_buffer_like(
        a, shape=(row_tiles_per_block, col_tiles_per_block), buffer_factor=2
    )
    b_cb = ttl.make_circular_buffer_like(
        b, shape=(row_tiles_per_block, col_tiles_per_block), buffer_factor=2
    )
    c_cb = ttl.make_circular_buffer_like(
        c, shape=(row_tiles_per_block, col_tiles_per_block), buffer_factor=2
    )
    y_cb = ttl.make_circular_buffer_like(
        y, shape=(row_tiles_per_block, col_tiles_per_block), buffer_factor=2
    )

    @ttl.compute()
    def demo_compute():
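        # wait() blocks until the reader has filled a block; reserve() blocks
        # until a free output slot is available. Exiting the `with` releases
        # the input blocks and publishes the output block to the writer.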
        for _ in range(rows_per_core):
            for _ in range(cols_per_core):
                with (
                    a_cb.wait() as a_blk,
                    b_cb.wait() as b_blk,
                    c_cb.wait() as c_blk,
                    y_cb.reserve() as y_blk,
                ):
                    y_blk.store(a_blk * b_blk + c_blk)

    @ttl.datamovement()
    def demo_read():
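        # Reader: stream this core's input blocks from DRAM into the
        # circular buffers in row-major block order.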
        core_col, core_row = ttl.core(dims=2)

        for local_row in range(rows_per_core):
            row = core_row * rows_per_core + local_row
            start_row_tile = row * row_tiles_per_block
            end_row_tile = (row + 1) * row_tiles_per_block

            for local_col in range(cols_per_core):
                col = core_col * cols_per_core + local_col
                start_col_tile = col * col_tiles_per_block
                end_col_tile = (col + 1) * col_tiles_per_block

                with (
                    a_cb.reserve() as a_blk,
                    b_cb.reserve() as b_blk,
                    c_cb.reserve() as c_blk,
                ):
                    # Issue the three block reads, then wait for all of them
                    # before handing the blocks to the compute routine.
                    tx_a = ttl.copy(
                        a[
                            start_row_tile:end_row_tile,
                            start_col_tile:end_col_tile,
                        ],
                        a_blk,
                    )
                    tx_b = ttl.copy(
                        b[
                            start_row_tile:end_row_tile,
                            start_col_tile:end_col_tile,
                        ],
                        b_blk,
                    )
                    tx_c = ttl.copy(
                        c[
                            start_row_tile:end_row_tile,
                            start_col_tile:end_col_tile,
                        ],
                        c_blk,
                    )

                    tx_a.wait()
                    tx_b.wait()
                    tx_c.wait()

    @ttl.datamovement()
    def demo_write():
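        # Writer: copy finished output blocks from the circular buffer back
        # to DRAM, mirroring the reader's traversal.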
        core_col, core_row = ttl.core(dims=2)

        for local_row in range(rows_per_core):
            row = core_row * rows_per_core + local_row
            start_row_tile = row * row_tiles_per_block
            end_row_tile = (row + 1) * row_tiles_per_block

            for local_col in range(cols_per_core):
                col = core_col * cols_per_core + local_col
                start_col_tile = col * col_tiles_per_block
                end_col_tile = (col + 1) * col_tiles_per_block

                with y_cb.wait() as y_blk:
                    tx = ttl.copy(
                        y_blk,
                        y[
                            start_row_tile:end_row_tile,
                            start_col_tile:end_col_tile,
                        ],
                    )
                    tx.wait()


def demo_kernel(a: ttnn.Tensor, b: ttnn.Tensor, c: ttnn.Tensor) -> ttnn.Tensor:
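    """Allocate the output on device and launch the fused multiply-add kernel."""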
    y = from_torch(torch.zeros((a.shape[0], a.shape[1]), dtype=torch.bfloat16))
    _demo_kernel(a, b, c, y)
    return y


torch.manual_seed(42)

device = ttnn.open_device(device_id=0)

try:
    shape = (2048, 2048)

    a = torch.rand(shape, dtype=torch.bfloat16)
    b = torch.rand(shape, dtype=torch.bfloat16)
    c = torch.rand(shape, dtype=torch.bfloat16)
    d = torch.rand(shape, dtype=torch.bfloat16)

    # Host-side reference for checking the device result.
    expected_y = (a * b + c) * d

    a = from_torch(a)
    b = from_torch(b)
    c = from_torch(c)
    d = from_torch(d)

    # Run the custom kernel for a * b + c, then finish with a stock ttnn op.
    y = ttnn.multiply(demo_kernel(a, b, c), d)

    y = ttnn.to_torch(y)
    print(y)
    print(expected_y)

    assert torch.allclose(y, expected_y, rtol=1e-2, atol=1e-2), "Tensors do not match"

finally:
    ttnn.close_device(device)