1717from math import prod
1818
1919import ttnn
20+ from models .demos .gpt_oss .tt .common import row_major_reshape
2021
2122from .config import AllToAllCombineConfig , AllToAllDispatchConfig , ThroughputExpertConfig , ThroughputProgramConfig
2223from .weights import ThroughputExpertWeights
@@ -124,19 +125,22 @@ def decode_forward(
124125 tokens_per_device = input_shape [0 ] * input_shape [2 ] # B * S
125126
126127 # Reshape hidden states: put all tokens on dim -2
127- hidden_states = ttnn .reshape (hidden_states , (1 , 1 , tokens_per_device , config .hidden_size ))
128+ # hidden_states = ttnn.reshape(hidden_states, (1, 1, tokens_per_device, config.hidden_size))
129+ hidden_states = row_major_reshape (hidden_states , (1 , 1 , tokens_per_device , config .hidden_size ))
128130
129131 # typecast creates new tensors - safe to deallocate originals
130132 topk_expert_indices_orig = topk_expert_indices
131133 topk_expert_indices = ttnn .typecast (topk_expert_indices , dtype = ttnn .uint32 )
132134 ttnn .deallocate (topk_expert_indices_orig )
133135
134136 # Reshape indices: put all tokens on dim -2
135- topk_expert_indices = ttnn .reshape (topk_expert_indices , (1 , 1 , tokens_per_device , config .num_experts_per_tok ))
137+ # topk_expert_indices = ttnn.reshape(topk_expert_indices, (1, 1, tokens_per_device, config.num_experts_per_tok))
138+ topk_expert_indices = row_major_reshape (topk_expert_indices , (1 , 1 , tokens_per_device , config .num_experts_per_tok ))
136139 topk_expert_indices_u32 = topk_expert_indices
137140 topk_expert_indices = ttnn .typecast (topk_expert_indices , dtype = ttnn .uint16 )
138141 ttnn .deallocate (topk_expert_indices_u32 )
139- topk_expert_weights = ttnn .reshape (topk_expert_weights , (1 , 1 , tokens_per_device , config .num_experts_per_tok ))
142+ # topk_expert_weights = ttnn.reshape(topk_expert_weights, (1, 1, tokens_per_device, config.num_experts_per_tok))
143+ topk_expert_weights = row_major_reshape (topk_expert_weights , (1 , 1 , tokens_per_device , config .num_experts_per_tok ))
140144
141145 num_dispatch_devices = (
142146 mesh_device .shape [dispatch_config .cluster_axis ]
@@ -154,12 +158,14 @@ def decode_forward(
154158 hidden_rm = ttnn .to_layout (hidden_states , ttnn .ROW_MAJOR_LAYOUT )
155159 ttnn .deallocate (hidden_states )
156160 # Shape is already [1, 1, tokens_per_device, H], just ensure it's correct
157- hidden_rm = ttnn .reshape (hidden_rm , shape = (1 , 1 , tokens_per_device , config .hidden_size ))
161+ # hidden_rm = ttnn.reshape(hidden_rm, shape=(1, 1, tokens_per_device, config.hidden_size))
162+ hidden_rm = row_major_reshape (hidden_rm , (1 , 1 , tokens_per_device , config .hidden_size ))
158163
159164 # Expert indices: [1, 1, tokens_per_device, K]
160165 topk_indices_rm = ttnn .to_layout (topk_expert_indices , ttnn .ROW_MAJOR_LAYOUT )
161166 ttnn .deallocate (topk_expert_indices )
162- topk_indices_rm = ttnn .reshape (topk_indices_rm , shape = (1 , 1 , tokens_per_device , config .num_experts_per_tok ))
167+ # topk_indices_rm = ttnn.reshape(topk_indices_rm, shape=(1, 1, tokens_per_device, config.num_experts_per_tok))
168+ topk_indices_rm = row_major_reshape (topk_indices_rm , (1 , 1 , tokens_per_device , config .num_experts_per_tok ))
163169
164170 # ==========================================================================
165171 # STEP 2: ALL_TO_ALL_DISPATCH - Route tokens to expert devices
@@ -197,7 +203,8 @@ def decode_forward(
197203 # -> repeat to [1, dispatch_rows, tokens_per_device, num_experts]
198204 # -> reshape to [1, 1, total_tokens, num_experts] to match dispatch_metadata batch/seq dims
199205 remap_mask = ttnn .repeat (remap_topk_mask , ttnn .Shape ((1 , 1 , tokens_per_device , 1 )))
200- remap_mask = ttnn .reshape (remap_mask , (1 , 1 , total_tokens , config .num_experts ))
206+ # remap_mask = ttnn.reshape(remap_mask, (1, 1, total_tokens, config.num_experts))
207+ remap_mask = row_major_reshape (remap_mask , (1 , 1 , total_tokens , config .num_experts ))
201208 # moe_expert_token_remap returns:
202209 # - mapping: [D, tokens, 1, experts_per_device] - local expert activation weights
203210 # - sparsity: [D, 1, tokens/reduction_size, experts_per_device] - which blocks are active
@@ -222,15 +229,17 @@ def decode_forward(
222229 # The sparse matmul operates on blocks of tokens, with sparsity indicating
223230 # which (token_block, expert) pairs need computation.
224231 # Note: ttnn.reshape returned a view here; TODO confirm row_major_reshape also aliases its input. to_layout creates a new tensor
225- post_dispatch = ttnn .reshape (dispatch_output , shape = (1 , 1 , total_tokens , config .hidden_size ))
232+ # post_dispatch = ttnn.reshape(dispatch_output, shape=(1, 1, total_tokens, config.hidden_size))
233+ post_dispatch = row_major_reshape (dispatch_output , (1 , 1 , total_tokens , config .hidden_size ))
226234 post_dispatch_rm = post_dispatch
227235 post_dispatch = ttnn .to_layout (post_dispatch , ttnn .TILE_LAYOUT )
228236 ttnn .deallocate (post_dispatch_rm ) # This deallocates dispatch_output via the view
229237
230238 # Reshape to sparse block format for matmul
231239 # Note: written for ttnn.reshape, which returned a view - don't deallocate post_dispatch separately (TODO confirm this still holds for row_major_reshape)
232240 num_sparse_blocks = total_tokens // config .sparsity_block_size
233- expert_input = ttnn .reshape (
241+ # expert_input = ttnn.reshape(
242+ expert_input = row_major_reshape (
234243 post_dispatch ,
235244 shape = (1 , num_sparse_blocks , config .sparsity_block_size , config .hidden_size ),
236245 )
@@ -328,7 +337,8 @@ def decode_forward(
328337 ttnn .deallocate (expert_output_sparse )
329338 # Note: ttnn.reshape returned a view (TODO confirm row_major_reshape does too); to_layout creates a new tensor
330339 # With tokens on dim -2: [experts_per_device, 1, total_tokens, H]
331- expert_output = ttnn .reshape (
340+ # expert_output = ttnn.reshape(
341+ expert_output = row_major_reshape (
332342 expert_output ,
333343 shape = (config .num_experts_per_device , 1 , total_tokens , config .hidden_size ),
334344 )
0 commit comments