
Commit 1e16021

rm reshape
1 parent 18de958 commit 1e16021

4 files changed, +43 -11 lines changed

models/demos/gpt_oss/tt/attention/decode.py

Lines changed: 3 additions & 1 deletion

@@ -2,6 +2,7 @@
 # SPDX-License-Identifier: Apache-2.0

 import ttnn
+from models.demos.gpt_oss.tt.common import row_major_reshape

 from .config import AttentionConfig, ProgramConfig
 from .operations import apply_allreduce, apply_rope
@@ -160,7 +161,8 @@ def decode_forward(
     tt_sdpa_out.deallocate(True)
     tt_out = ttnn.add(tt_out, weights.o_proj_bias, memory_config=ttnn.L1_MEMORY_CONFIG)
     tt_out = ttnn.typecast(tt_out, ttnn.bfloat8_b)
-    tt_out = ttnn.reshape(
+    # tt_out = ttnn.reshape(
+    tt_out = row_major_reshape(
         tt_out,
         (1, 1, batch_size, hidden_size),
         (1, 1, 32, hidden_size),

models/demos/gpt_oss/tt/common.py

Lines changed: 18 additions & 0 deletions

@@ -10,6 +10,24 @@
 from models.tt_transformers.tt.common import PagedAttentionConfig


+def row_major_reshape(tensor: ttnn.Tensor, shape: ttnn.Shape) -> ttnn.Tensor:
+    """Reshape a tensor to row major layout.
+
+    Args:
+        tensor: Input tensor
+        shape: New shape
+
+    Returns:
+    """
+    tensor_is_tile = tensor.layout == ttnn.TILE_LAYOUT
+    if tensor_is_tile:
+        tensor = ttnn.to_layout(tensor, ttnn.ROW_MAJOR_LAYOUT)
+    out = ttnn.reshape(tensor, shape)
+    if tensor_is_tile:
+        out = ttnn.to_layout(out, ttnn.TILE_LAYOUT)
+    return out
+
+
 def create_tt_model(
     mesh_device,
     max_batch_size,
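For readers unfamiliar with the new helper, a minimal usage sketch follows. It is not part of the commit: it assumes a single local Tenstorrent device, uses standard ttnn entry points (ttnn.open_device, ttnn.from_torch, ttnn.close_device), and the shapes are purely illustrative.

import torch
import ttnn

from models.demos.gpt_oss.tt.common import row_major_reshape

# Assumption: one local device is available; device id 0 is illustrative.
device = ttnn.open_device(device_id=0)

# A small tile-layout tensor standing in for the activations reshaped in the diffs below.
x = ttnn.from_torch(
    torch.randn(1, 1, 64, 128),
    dtype=ttnn.bfloat16,
    layout=ttnn.TILE_LAYOUT,
    device=device,
)

# The helper drops to ROW_MAJOR_LAYOUT, reshapes, then restores TILE_LAYOUT,
# so call sites can keep passing tile-layout tensors exactly as before.
y = row_major_reshape(x, (1, 1, 32, 256))
print(y.shape, y.layout)  # expected: (1, 1, 32, 256), TILE_LAYOUT

ttnn.close_device(device)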

models/demos/gpt_oss/tt/experts_throughput/decode.py

Lines changed: 19 additions & 9 deletions

@@ -17,6 +17,7 @@
 from math import prod

 import ttnn
+from models.demos.gpt_oss.tt.common import row_major_reshape

 from .config import AllToAllCombineConfig, AllToAllDispatchConfig, ThroughputExpertConfig, ThroughputProgramConfig
 from .weights import ThroughputExpertWeights
@@ -124,19 +125,22 @@ def decode_forward(
     tokens_per_device = input_shape[0] * input_shape[2]  # B * S

     # Reshape hidden states: put all tokens on dim -2
-    hidden_states = ttnn.reshape(hidden_states, (1, 1, tokens_per_device, config.hidden_size))
+    # hidden_states = ttnn.reshape(hidden_states, (1, 1, tokens_per_device, config.hidden_size))
+    hidden_states = row_major_reshape(hidden_states, (1, 1, tokens_per_device, config.hidden_size))

     # typecast creates new tensors - safe to deallocate originals
     topk_expert_indices_orig = topk_expert_indices
     topk_expert_indices = ttnn.typecast(topk_expert_indices, dtype=ttnn.uint32)
     ttnn.deallocate(topk_expert_indices_orig)

     # Reshape indices: put all tokens on dim -2
-    topk_expert_indices = ttnn.reshape(topk_expert_indices, (1, 1, tokens_per_device, config.num_experts_per_tok))
+    # topk_expert_indices = ttnn.reshape(topk_expert_indices, (1, 1, tokens_per_device, config.num_experts_per_tok))
+    topk_expert_indices = row_major_reshape(topk_expert_indices, (1, 1, tokens_per_device, config.num_experts_per_tok))
     topk_expert_indices_u32 = topk_expert_indices
     topk_expert_indices = ttnn.typecast(topk_expert_indices, dtype=ttnn.uint16)
     ttnn.deallocate(topk_expert_indices_u32)
-    topk_expert_weights = ttnn.reshape(topk_expert_weights, (1, 1, tokens_per_device, config.num_experts_per_tok))
+    # topk_expert_weights = ttnn.reshape(topk_expert_weights, (1, 1, tokens_per_device, config.num_experts_per_tok))
+    topk_expert_weights = row_major_reshape(topk_expert_weights, (1, 1, tokens_per_device, config.num_experts_per_tok))

     num_dispatch_devices = (
         mesh_device.shape[dispatch_config.cluster_axis]
@@ -154,12 +158,14 @@
     hidden_rm = ttnn.to_layout(hidden_states, ttnn.ROW_MAJOR_LAYOUT)
     ttnn.deallocate(hidden_states)
     # Shape is already [1, 1, tokens_per_device, H], just ensure it's correct
-    hidden_rm = ttnn.reshape(hidden_rm, shape=(1, 1, tokens_per_device, config.hidden_size))
+    # hidden_rm = ttnn.reshape(hidden_rm, shape=(1, 1, tokens_per_device, config.hidden_size))
+    hidden_rm = row_major_reshape(hidden_rm, (1, 1, tokens_per_device, config.hidden_size))

     # Expert indices: [1, 1, tokens_per_device, K]
     topk_indices_rm = ttnn.to_layout(topk_expert_indices, ttnn.ROW_MAJOR_LAYOUT)
     ttnn.deallocate(topk_expert_indices)
-    topk_indices_rm = ttnn.reshape(topk_indices_rm, shape=(1, 1, tokens_per_device, config.num_experts_per_tok))
+    # topk_indices_rm = ttnn.reshape(topk_indices_rm, shape=(1, 1, tokens_per_device, config.num_experts_per_tok))
+    topk_indices_rm = row_major_reshape(topk_indices_rm, (1, 1, tokens_per_device, config.num_experts_per_tok))

     # ==========================================================================
     # STEP 2: ALL_TO_ALL_DISPATCH - Route tokens to expert devices
@@ -197,7 +203,8 @@
     # -> repeat to [1, dispatch_rows, tokens_per_device, num_experts]
     # -> reshape to [1, 1, total_tokens, num_experts] to match dispatch_metadata batch/seq dims
     remap_mask = ttnn.repeat(remap_topk_mask, ttnn.Shape((1, 1, tokens_per_device, 1)))
-    remap_mask = ttnn.reshape(remap_mask, (1, 1, total_tokens, config.num_experts))
+    # remap_mask = ttnn.reshape(remap_mask, (1, 1, total_tokens, config.num_experts))
+    remap_mask = row_major_reshape(remap_mask, (1, 1, total_tokens, config.num_experts))
     # moe_expert_token_remap returns:
     # - mapping: [D, tokens, 1, experts_per_device] - local expert activation weights
     # - sparsity: [D, 1, tokens/reduction_size, experts_per_device] - which blocks are active
@@ -222,15 +229,17 @@
     # The sparse matmul operates on blocks of tokens, with sparsity indicating
     # which (token_block, expert) pairs need computation.
     # Note: reshape returns view, but to_layout creates new tensor
-    post_dispatch = ttnn.reshape(dispatch_output, shape=(1, 1, total_tokens, config.hidden_size))
+    # post_dispatch = ttnn.reshape(dispatch_output, shape=(1, 1, total_tokens, config.hidden_size))
+    post_dispatch = row_major_reshape(dispatch_output, (1, 1, total_tokens, config.hidden_size))
     post_dispatch_rm = post_dispatch
     post_dispatch = ttnn.to_layout(post_dispatch, ttnn.TILE_LAYOUT)
     ttnn.deallocate(post_dispatch_rm)  # This deallocates dispatch_output via the view

     # Reshape to sparse block format for matmul
     # Note: reshape returns a view - don't deallocate post_dispatch separately
     num_sparse_blocks = total_tokens // config.sparsity_block_size
-    expert_input = ttnn.reshape(
+    # expert_input = ttnn.reshape(
+    expert_input = row_major_reshape(
         post_dispatch,
         shape=(1, num_sparse_blocks, config.sparsity_block_size, config.hidden_size),
     )
@@ -328,7 +337,8 @@
     ttnn.deallocate(expert_output_sparse)
     # Note: reshape returns a view, to_layout creates new tensor
     # With tokens on dim -2: [experts_per_device, 1, total_tokens, H]
-    expert_output = ttnn.reshape(
+    # expert_output = ttnn.reshape(
+    expert_output = row_major_reshape(
         expert_output,
         shape=(config.num_experts_per_device, 1, total_tokens, config.hidden_size),
     )
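The sparse-block reshape in this file is easiest to follow with concrete numbers. The sketch below uses illustrative sizes only (the real values come from the dispatch config and ThroughputExpertConfig), and it assumes total_tokens is tokens_per_device gathered across all dispatch devices, as the remap_mask reshape suggests.

# Illustrative sizes; not taken from the actual model config.
tokens_per_device = 32                  # B * S handled by one device
num_dispatch_devices = 4                # devices along the dispatch cluster axis
total_tokens = tokens_per_device * num_dispatch_devices   # 128 tokens after dispatch
sparsity_block_size = 32
hidden_size = 2880                      # placeholder hidden dim

num_sparse_blocks = total_tokens // sparsity_block_size   # 4

# post_dispatch: (1, 1, total_tokens, hidden_size)                        -> (1, 1, 128, 2880)
# expert_input:  (1, num_sparse_blocks, sparsity_block_size, hidden_size) -> (1, 4, 32, 2880)
print((1, num_sparse_blocks, sparsity_block_size, hidden_size))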

models/demos/gpt_oss/tt/topk.py

Lines changed: 3 additions & 1 deletion

@@ -11,6 +11,7 @@
 """

 import ttnn
+from models.demos.gpt_oss.tt.common import row_major_reshape
 from models.demos.gpt_oss.utils.general_utils import get_cache_file_name


@@ -146,7 +147,8 @@ def __call__(self, hidden_states, use_throughput_experts):
         # )
         mem_config = ttnn.DRAM_MEMORY_CONFIG

-        hidden_states = ttnn.reshape(hidden_states, (-1, self.hidden_dim))
+        # hidden_states = ttnn.reshape(hidden_states, (-1, self.hidden_dim))
+        hidden_states = row_major_reshape(hidden_states, (-1, self.hidden_dim))
         router_logits = ttnn.linear(
             hidden_states,
             self.weight,
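As a quick shape check for the router reshape above: flattening to (-1, hidden_dim) gives the linear layer one row per token. A torch-only sketch with illustrative sizes (the real hidden_dim and expert count come from the model config, and router_weight here is just a stand-in for self.weight):

import torch

tokens, hidden_dim, num_experts = 32, 2880, 8    # illustrative only
hidden_states = torch.randn(1, 1, tokens, hidden_dim)

# (-1, hidden_dim) collapses the leading dims so each row is one token.
flat = hidden_states.reshape(-1, hidden_dim)              # (32, 2880)
router_weight = torch.randn(hidden_dim, num_experts)      # stand-in for self.weight
router_logits = flat @ router_weight                      # (32, 8)
print(flat.shape, router_logits.shape)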
