[Merge stable to main] Llama3.3-70b and 3.1-8b - Fix sampling parameters #36476
djordje-tt merged 23 commits into main
Conversation
/codeowners ping
CodeOwners Group Analysis: This PR requires approval from one member of each of the following groups. Summary: 2 pending groups, 0 approved groups. Group Information:
Note: At least one approval from each group is sufficient.
Hi Ambrose Ling (@alingTT), Stuti Raizada (@sraizada-tt), Utku Aydonat (@uaydonat), Mark O'Connor (@yieldthought), this PR [Merge stable to main] Llama3.3-70b - Fix sampling parameters by Djordje Ivanovic (@djordje-tt) needs your approval/review before it can be merged.
Pull request overview
This PR merges sampling and log-probs enhancements from stable to main for Llama3.3-70b on Galaxy, addressing several key issues including non-uniform seeding, penalty bugs, batched prefill determinism, and missing log-probs support.
Changes:
- Introduces a deterministic seeding flow with host-side RNG via SeedManager and ttnn.manual_seed integration for reproducible sampling across prefill and decode (see the sketch after this list)
- Adds log-probs support for Galaxy with on-device log-softmax calculation using distributed reduction operations
- Updates matmul configurations to ensure consistency between batched and non-batched prefill paths
- Fixes penalty application bugs (frequency/presence/repetition penalties) with corrected tensor type handling and proper masking
- Extends prefill path to support on-device sampling with sharded logits accumulation
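For orientation, here is a minimal host-side sketch of the kind of per-user seed bookkeeping described above. The class and method names (HostSeedManagerSketch, reset_seed, next_seeds) are illustrative assumptions rather than the actual SeedManager API, and the hand-off to ttnn.manual_seed / ttnn.sampling is referenced only in comments.

```python
import random


class HostSeedManagerSketch:
    """Illustrative host-side RNG bookkeeping; names are hypothetical, not the PR's SeedManager API."""

    def __init__(self, num_users: int, base_seed: int = 0):
        # One independent RNG stream per user slot so resetting one user
        # does not perturb the sequences of the others.
        self.rngs = [random.Random(base_seed + slot) for slot in range(num_users)]

    def reset_seed(self, slot: int, seed: int) -> None:
        # Called when a request (re)starts with an explicit seed.
        self.rngs[slot] = random.Random(seed)

    def next_seeds(self) -> list[int]:
        # Drawn once per decode step; in the real flow these values would be
        # written into the per-user seed tensor consumed by ttnn.manual_seed
        # before ttnn.sampling (per the PR description).
        return [rng.randrange(2**31) for rng in self.rngs]


# With fixed per-user seeds, repeated runs draw identical seed sequences,
# which is what makes prefill + decode reproducible across repeats.
mgr = HostSeedManagerSketch(num_users=32, base_seed=1234)
step0_seeds = mgr.next_seeds()
```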
Reviewed changes
Copilot reviewed 16 out of 16 changed files in this pull request and generated 8 comments.
Summary per file:
| File | Description |
|---|---|
| qwen_model_config.py | Updated matmul program config overwrite_per_core_k values for seq_len=128 operations to improve batched prefill consistency |
| model_config.py | Similar matmul config adjustments for Llama model (non-Qwen variant) |
| llama_model.py | Added process_output_prefill_logits method for on-device sampling; updated return signatures to include log_probs |
| llama_mlp.py | Threaded batch_size parameter through forward_prefill for CCL operations |
| llama_decoder.py | Passed batch_size to feed_forward layer for proper CCL buffer selection |
| llama_ccl.py | Added log-probs persistent buffers; updated reduce_scatter/all_gather for batched operations; removed WO from reduce_scatter (now uses WO_AG) |
| llama_attention.py | Updated SDPA program config and WO buffer key for batched prefill; deallocate intermediate output_11SH tensor |
| generator.py (galaxy) | Implemented on-device prefill sampling with logits accumulation; added slot-based parameter scattering; integrated SeedManager |
| text_demo.py | Added test configuration for batch-32 with non-uniform sampling and log-probs |
| demo_qwen_decode.py | Updated to extract and track log-probs during decode |
| outputs_batch_1.json | Expected output reference update (generation variance) |
| utils.py | Refactored LogProbsCalculator for Galaxy (32-device) support with proper all-gather operations and dimension handling |
| test_sampling.py | Added test_log_probs_with_sub_core_grids_on_galaxy for validating log-probs on 32-device mesh |
| tt_sampling.py | Added force_argmax_sampling optimization path; integrated manual_seed with per-user seed tensors |
| tt_penalties.py | Fixed penalty application order and type casting bugs; removed vocab expansion workaround; proper -1 padding handling |
| generator.py (sampling) | Introduced SeedManager class for deterministic per-user RNG; updated trace key to include force_argmax flag |
Comments suppressed due to low confidence (2)
models/common/utils.py:303
- Several intermediate tensors are not deallocated, which could lead to memory leaks. Consider deallocating: global_idx_tilized_tensor after line 229, chip_ids_tensor after line 249, remainder_tensor after line 239, out after line 276, and log_global_exp_sum after line 276. Also, relevant_logits after line 303.
def _prepare_relevant_logits(self, logits_tensor: ttnn.Tensor, global_idx_tensor: ttnn.Tensor):
    """
    Prepare global idx tensor with correct values on all devices.
    """
    size_per_device = logits_tensor.shape[-1]
    # convert global_idx_tensor to ttnn.TILE_LAYOUT
    global_idx_tilized_tensor = ttnn.to_layout(global_idx_tensor, ttnn.TILE_LAYOUT, **self.common_args)
    # TODO: Raise an issue on this since for UINT_32 ttnn.div produces incorrect output (all zeros)
    global_idx_tilized_tensor = ttnn.typecast(global_idx_tilized_tensor, ttnn.float32, **self.common_args)
    # Get chip_id for each user based on global_idx values in global_idx_tensor
    chip_ids_tensor = ttnn.div(
        global_idx_tilized_tensor,
        size_per_device,
        round_mode="floor",
        memory_config=ttnn.DRAM_MEMORY_CONFIG,
        **self.common_args,
    )
    # Get local index for each user based on global_idx values in global_idx_tensor
    remainder_tensor = ttnn.remainder(
        global_idx_tilized_tensor,
        size_per_device,
        memory_config=ttnn.DRAM_MEMORY_CONFIG,
        **self.common_args,
    )
    # Convert remainder_tensor to int32
    remainder_tensor = ttnn.typecast(remainder_tensor, ttnn.uint32, **self.common_args)
    # convert to ROW_MAJOR_LAYOUT due to memory clobbering which affects all ttnn.reshape ops with TILE_LAYOUT
    remainder_tensor = ttnn.to_layout(remainder_tensor, ttnn.ROW_MAJOR_LAYOUT, **self.common_args)
    remainder_tensor = ttnn.reshape(remainder_tensor, (1, 1, 32, 1), **self.common_args)
    remainder_tensor = ttnn.to_layout(remainder_tensor, ttnn.TILE_LAYOUT, **self.common_args)
    # Get logits for each user on each chip based on local index
    selected_logits_tensor = ttnn.gather(logits_tensor, dim=3, index=remainder_tensor, **self.common_args)
    # convert to ROW_MAJOR_LAYOUT due to memory clobbering which affects all ttnn.reshape ops with TILE_LAYOUT
    selected_logits_tensor = ttnn.to_layout(selected_logits_tensor, ttnn.ROW_MAJOR_LAYOUT, **self.common_args)
    selected_logits_tensor = ttnn.reshape(selected_logits_tensor, (1, 1, 1, 32), **self.common_args)
    selected_logits_tensor = ttnn.to_layout(selected_logits_tensor, ttnn.TILE_LAYOUT, **self.common_args)
    # Compare mask to chip_ids tensor and select correct positions for each user on all chips inplace
    ttnn.eq_(chip_ids_tensor, self.mask, **self.common_args)
    # Multiply selected_logits_tensor with chip_ids_tensor to get expected logits for each user
    selected_logits_tensor = ttnn.multiply(selected_logits_tensor, chip_ids_tensor, **self.common_args)
    # All gather logits across all devices
    selected_logits_tensor = self._perform_all_gather(
        selected_logits_tensor,
        dim=1,
        num_links=1,
        buffer_key="LOGPROBS_LOGITS",
    )
    selected_logits_tensor = ttnn.to_layout(selected_logits_tensor, ttnn.ROW_MAJOR_LAYOUT, **self.common_args)
    selected_logits_tensor = ttnn.reshape(selected_logits_tensor, (1, 1, 8, 32), **self.common_args)
    selected_logits_tensor = ttnn.to_layout(selected_logits_tensor, ttnn.TILE_LAYOUT, **self.common_args)
    # Apply sum over device dimension to get logits for each user on all chips
    selected_logits_tensor = ttnn.sum(selected_logits_tensor, dim=2, keepdim=True, **self.common_args)
    return selected_logits_tensor

def _calculate_log_probs(self, sampled_logits_tensor: ttnn.Tensor):
    """
    Calculate log-probs for a given logits tensor with formula:
    log-prob(x) = logits(x) - global_max - log(global_exp_sum)
    """
    out = ttnn.subtract(sampled_logits_tensor, self.global_max, **self.common_args)
    log_global_exp_sum = ttnn.log(self.global_exp_sum, **self.common_args)
    # Subtract and put result to self.output_tensor
    ttnn.subtract(out, log_global_exp_sum, output_tensor=self.output_tensor, **self.common_args)

def calculate_log_probs(
    self,
    logits_tensor: ttnn.Tensor,
    indices_tensor: ttnn.Tensor,
):
    """
    Calculate log-probs for a given logits tensor and indices tensor.
    """
    if not self.enable_log_probs:
        return self.output_tensor
    if self.mesh_device.get_num_devices() not in [8, 32]:
        return self.output_tensor
    # Calculating log-probs requires bfloat16 precision for near-stable sum-exp calculation
    if logits_tensor.dtype == ttnn.bfloat8_b:
        logits_tensor = ttnn.typecast(logits_tensor, ttnn.bfloat16, **self.common_args)
    # Compute global max and global sum(exp(logits - global_max)) for each chip
    self._compute_global_stats(logits_tensor)
    # Prepare relevant logits for each user on each chip
    relevant_logits = self._prepare_relevant_logits(logits_tensor, indices_tensor)
    # Calculate log-probs for each user on each chip and stores in self.output_tensor
    self._calculate_log_probs(relevant_logits)
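As a cross-check on the formula in _calculate_log_probs (log-prob(x) = logits(x) - global_max - log(global_exp_sum)), here is a minimal host-side reference sketch in plain PyTorch of the same numerically stable computation over simulated per-device vocabulary shards; the shard count, shapes, and variable names are illustrative assumptions, not the ttnn code paths.

```python
import torch

# Simulate a vocabulary split across 8 devices: one logits shard per device,
# shape (users, vocab_per_device), plus the globally sampled token ids.
torch.manual_seed(0)
users, num_devices, vocab_per_device = 32, 8, 128
shards = [torch.randn(users, vocab_per_device) for _ in range(num_devices)]
full_logits = torch.cat(shards, dim=-1)                  # (users, vocab)
sampled_ids = torch.randint(0, num_devices * vocab_per_device, (users,))

# Stable distributed log-softmax stats: local max -> global max,
# then local sum(exp(x - global_max)) -> global exp sum.
local_max = torch.stack([s.max(dim=-1).values for s in shards], dim=0)   # (devices, users)
global_max = local_max.max(dim=0).values                                 # (users,)
local_exp_sum = torch.stack(
    [(s - global_max[:, None]).exp().sum(dim=-1) for s in shards], dim=0
)
global_exp_sum = local_exp_sum.sum(dim=0)                                # (users,)

# log-prob(x) = logits(x) - global_max - log(global_exp_sum)
sampled_logits = full_logits.gather(-1, sampled_ids[:, None]).squeeze(-1)
log_probs = sampled_logits - global_max - global_exp_sum.log()

# Matches a single-device log_softmax on the concatenated logits.
reference = torch.log_softmax(full_logits, dim=-1).gather(-1, sampled_ids[:, None]).squeeze(-1)
assert torch.allclose(log_probs, reference, atol=1e-5)
```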
models/demos/llama3_70b_galaxy/tt/llama_mlp.py:323
- The test is always true, because of this condition: `if 1024 <= seq_len < 4096:`
Code context:
freq_term = ttnn.multiply(output_counts_bf16, context.frequency_penalties, **op_kwargs)
freq_term_bf16 = ttnn.typecast(freq_term, ttnn.bfloat16, **op_kwargs)
The intermediate variable freq_term is created but never deallocated. This could lead to memory leaks in long-running scenarios. Consider adding freq_term.deallocate() after line 50.
Suggested change:
freq_term_bf16 = ttnn.typecast(freq_term, ttnn.bfloat16, **op_kwargs)
freq_term.deallocate()
Code context:
# presence
presence_term = ttnn.multiply(context.output_mask, context.presence_penalties, **op_kwargs)
presence_term = ttnn.multiply(
The intermediate variable presence_term is created but never deallocated after creating presence_term_bf16. This could lead to memory leaks. Consider adding presence_term.deallocate() after line 41.
Code context (diff excerpt; old and new lines interleaved):
local_max_tensor = ttnn.max(logits_tensor, dim=-1, keepdim=True, **self.common_args)
# All-gather local max to get global max
gathered_max_tensors = ttnn.all_gather(
gathered_max_tensors = self._perform_all_gather(
local_max_tensor,
dim=3,
dim=1,
num_links=1,
memory_config=ttnn.DRAM_MEMORY_CONFIG,
cluster_axis=None,
topology=ttnn.Topology.Linear,
buffer_key="LOGPROBS_MAX_REDUCTION",
)
self.global_max = ttnn.max(gathered_max_tensors, dim=-1, keepdim=True)
# TODO: Convert to ROW_MAJOR_LAYOUT due to memory clobbering which affects all ttnn.reshape ops with TILE_LAYOUT
gathered_max_tensors = ttnn.to_layout(gathered_max_tensors, ttnn.ROW_MAJOR_LAYOUT, **self.common_args)
gathered_max_tensors = ttnn.reshape(gathered_max_tensors, (1, 1, 8, 32), **self.common_args)
gathered_max_tensors = ttnn.to_layout(gathered_max_tensors, ttnn.TILE_LAYOUT, **self.common_args)
self.global_max = ttnn.max(gathered_max_tensors, dim=2, keepdim=True, **self.common_args)
global_max_to_subtract = ttnn.to_layout(self.global_max, ttnn.ROW_MAJOR_LAYOUT, **self.common_args)
global_max_to_subtract = ttnn.reshape(global_max_to_subtract, (1, 1, 32, 1), **self.common_args)
global_max_to_subtract = ttnn.to_layout(global_max_to_subtract, ttnn.TILE_LAYOUT, **self.common_args)
# Calculate stable local sum-exp using subtract of global-max from each local logit
subtracted_tensor = ttnn.subtract(logits_tensor, self.global_max)
sum_exp_tensor = ttnn.sum(ttnn.exp(subtracted_tensor), dim=-1, keepdim=True)
subtracted_tensor = ttnn.subtract(logits_tensor, global_max_to_subtract, **self.common_args)
exp_tensor = ttnn.exp(subtracted_tensor, **self.common_args)
sum_exp_tensor = ttnn.sum(exp_tensor, dim=-1, keepdim=True, **self.common_args)
# All-gather stable local sum-exp to get global sum-exp
gathered_sum_exp_tensors = ttnn.all_gather(
gathered_sum_exp_tensors = self._perform_all_gather(
sum_exp_tensor,
dim=3,
dim=1,
num_links=1,
memory_config=ttnn.DRAM_MEMORY_CONFIG,
cluster_axis=None,
topology=ttnn.Topology.Linear,
buffer_key="LOGPROBS_SUM_EXP_REDUCTION",
)
self.global_exp_sum = ttnn.sum(gathered_sum_exp_tensors, dim=-1, keepdim=True)
gathered_sum_exp_tensors = ttnn.to_layout(gathered_sum_exp_tensors, ttnn.ROW_MAJOR_LAYOUT, **self.common_args)
gathered_sum_exp_tensors = ttnn.reshape(gathered_sum_exp_tensors, (1, 1, 8, 32), **self.common_args)
gathered_sum_exp_tensors = ttnn.to_layout(gathered_sum_exp_tensors, ttnn.TILE_LAYOUT, **self.common_args)
# reshape global_max and global_exp_sum to support same output shape as sampling output -> (1, 1, 1, 32)
# convert to ROW_MAJOR_LAYOUT due to memory clobbering which affects all ttnn.reshape ops with TILE_LAYOUT
self.global_max = ttnn.to_layout(self.global_max, ttnn.ROW_MAJOR_LAYOUT)
self.global_max = ttnn.reshape(self.global_max, (1, 1, 1, 32))
self.global_max = ttnn.to_layout(self.global_max, ttnn.TILE_LAYOUT)
# convert to ROW_MAJOR_LAYOUT due to memory clobbering which affects all ttnn.reshape ops with TILE_LAYOUT
self.global_exp_sum = ttnn.to_layout(self.global_exp_sum, ttnn.ROW_MAJOR_LAYOUT)
self.global_exp_sum = ttnn.reshape(self.global_exp_sum, (1, 1, 1, 32))
self.global_exp_sum = ttnn.to_layout(self.global_exp_sum, ttnn.TILE_LAYOUT)
self.global_exp_sum = ttnn.sum(gathered_sum_exp_tensors, dim=2, keepdim=True, **self.common_args)
Several intermediate tensors are not deallocated, which could lead to memory leaks. Consider deallocating: local_max_tensor after line 173, subtracted_tensor after line 187, exp_tensor after line 188, sum_exp_tensor after line 188, and global_max_to_subtract after line 186.
Code context:
if sampling_params.top_k[i] < 1:
    sampling_params.top_k[i] = 32  # k<1 means no restriction so set it to max k (32)
Duplicate check for top_k < 1 at lines 347-348 and 353-354. The second check (lines 353-354) is redundant and should be removed.
Suggested change: delete the duplicated check above.
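If the duplicated check is removed, one hedged option is to normalize the whole top_k list once before the per-user loop; sampling_params.top_k and the max k of 32 come from the quoted snippet, while the stand-in object below is purely illustrative.

```python
from types import SimpleNamespace

# Hypothetical stand-in for the real sampling_params object; only top_k matters here.
sampling_params = SimpleNamespace(top_k=[0, 5, -1, 10])

# Clamp unrestricted values (k < 1) to the max supported k (32) once, up front,
# so the later duplicate per-user check is no longer needed.
sampling_params.top_k = [k if k >= 1 else 32 for k in sampling_params.top_k]
assert sampling_params.top_k == [32, 5, 32, 10]
```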
Code context (@@ -630,28 +743,40 @@ def _decode_easy_trace_text):
return trace_tok_rm
When enable_split_sampling is True and return_logits is False, the function returns the result from self.model.sampling.sample(), which should return a tuple (tt_tokens, tt_log_probs) according to lines 645-650 of llama_model.py. However, when enable_split_sampling is False or return_logits is True, the function returns trace_tok_rm, which is the raw output. This creates inconsistent return types. The caller at line 565 expects a tuple (tt_tok, tt_log_probs). Consider ensuring consistent return types.
Suggested change:
# For consistency, always return (tt_out, tt_log_probs) where log_probs may be None.
return trace_tok_rm, None
Code context:
# frequency
output_counts_bf16 = ttnn.typecast(context.output_counts, ttnn.bfloat16, **op_kwargs)
freq_term = ttnn.multiply(output_counts_bf16, context.frequency_penalties, **op_kwargs)
The intermediate variable output_counts_bf16 is created but never deallocated. This could lead to memory leaks in long-running scenarios. Consider adding output_counts_bf16.deallocate() after line 48.
Suggested change:
freq_term = ttnn.multiply(output_counts_bf16, context.frequency_penalties, **op_kwargs)
output_counts_bf16.deallocate()
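For readers unfamiliar with these penalty terms, here is a minimal host-side sketch in plain PyTorch of the conventional frequency/presence penalty application that output_counts, output_mask, and the penalty tensors above correspond to; the shapes and values are illustrative assumptions, and this is not the ttnn implementation.

```python
import torch

# Illustrative shapes: 2 users, vocabulary of 6 tokens.
logits = torch.randn(2, 6)
output_counts = torch.tensor([[0, 2, 0, 1, 0, 0],      # how often each token was generated
                              [3, 0, 0, 0, 0, 1]], dtype=torch.float32)
frequency_penalties = torch.tensor([[0.5], [0.2]])      # per-user penalty strengths
presence_penalties = torch.tensor([[0.1], [0.0]])

# freq_term scales with how many times a token appeared; presence_term only
# cares whether it appeared at all (the output mask).
output_mask = (output_counts > 0).float()
freq_term = output_counts * frequency_penalties
presence_term = output_mask * presence_penalties

# Common convention: subtract both terms from the logits before sampling.
penalized_logits = logits - freq_term - presence_term
```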
Code context:
padded_batch = 32
Variable padded_batch is not used.
Suggested change: remove the unused assignment.
Force-pushed from b4abed2 to 01544b5.
Codex review: The patch introduces at least one runtime break (missing SamplingGenerator.reset_seed) and a logic regression in batched prefill logits. It also changes a critical API keyword in log-prob calculation that can corrupt results when log_probs are enabled. Full review comments:
Force-pushed from 7a1b14e to e44be11.
Use dynamic num_links from tt_ccl.get_num_links(1) for prefill mode instead of a hardcoded 1. This matches the behavior on main and fixes L1 buffer clash errors on P150x2 (BH-DB) during the structured output benchmark. For P150x2, get_num_links(1) returns 2, which is required for proper tensor memory layout before rmsnorm operations. Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
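A rough sketch of the described change; only tt_ccl.get_num_links(1) and the P150x2 behavior come from the commit message, while the stub class, mode variable, and call site below are assumptions for illustration.

```python
class _TtCclStub:
    """Hypothetical stand-in for the real tt_ccl object; only get_num_links is mocked."""

    def get_num_links(self, cluster_axis: int) -> int:
        # Per the commit message, this returns 2 on P150x2; other platforms may differ.
        return 2


tt_ccl = _TtCclStub()
mode = "prefill"

# Instead of a hardcoded num_links = 1 for prefill, query the CCL helper
# so platforms like P150x2 get the link count they need.
num_links = tt_ccl.get_num_links(1) if mode == "prefill" else 1
```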
This reverts commit b63c903.
Force-pushed from 9602978 to bdfb675.
Code context:
x = self.lm_head(x)
if mode == "prefill":
Why is it safe to remove this?
Code context:
if self.model_config["LM_HEAD_INPUT_MEMCFG"].is_sharded():
    logits = ttnn.interleaved_to_sharded(logits, self.model_config["LM_HEAD_INPUT_MEMCFG"])
logits = self.lm_head(logits)
logits = ttnn.to_layout(logits, layout=ttnn.ROW_MAJOR_LAYOUT, memory_config=ttnn.DRAM_MEMORY_CONFIG)
Is it safe to remove this as well? Is logits already in DRAM? Does downstream usage of this function continue to expect ROW_MAJOR? Please double-check nothing breaks here.
Pipelines like the Llama3.1-8b demo, vLLM nightly for Llama3.1-8b, and Models CI for Llama3.1-8b are passing. I can double check, but I am pretty confident it's not needed anymore!
Ticket
#36325
Problem description
This PR fixes a couple of different issues for Llama3.3-70b:
- Non-uniform seeding
- Penalty trap bug
- Penalty bugs for Llama3.1-8b
- Batched prefill determinism
- Diff between batched and non-batched prefill
- Missing log-probs support for Llama3.3-70b
- Fixes the same sampling-parameter issues for Llama3.1-8b
What's changed
- Bring over the log-probs support for Galaxy (optional log-softmaxed logits output), matching the behavior already validated on stable in TT-Metal, vLLM nightly, and Models CI.
- Integrate the deterministic seeding flow (host-side RNG + SamplingSeedManager + `ttnn.manual_seed` usage before `ttnn.sampling`) so prefill + decode produce deterministic sequences across repeats when seeds are fixed.
- Ensure the penalties path matches the shared implementation, fixing the earlier divergence across users.
- Update matmul configs to support the same behaviour across batched and non-batched prefill, with a couple of additional fixes for divergence.
Performance numbers on text_demo in t/s/u:
| branch | without penalties | with penalties |
|---|---|---|
| branch | 71.88 t/s/u | 42.36 t/s/u |
| main | 72.05 t/s/u | - |
**TTFT**: **68.5** ms -> **73.9** ms; the drop, due to disabling use_2d_grid in rms norm, is expected.
Checklist
- [ ] [All post-commit tests](https://github.com/tenstorrent/tt-metal/actions/runs/21355526046)
Model tests
- [x] [Galaxy Demo](https://github.com/tenstorrent/tt-metal/actions/runs/21361481284)
- [x] [vllm nightly](https://github.com/tenstorrent/tt-metal/actions/runs/21361542050)
- [x] [Models CI](https://github.com/tenstorrent/tt-shield/actions/runs/21435406349/job/61728802475)
Last pipelines list, 6th Feb:
- [ ] [vllm Nightly](https://github.com/tenstorrent/tt-metal/actions/runs/21754091798)
- [ ] [Shield CI](https://github.com/tenstorrent/tt-shield/actions/runs/21753926206/job/62758873631)
- [ ] [Galaxy Demo](https://github.com/tenstorrent/tt-metal/actions/runs/21754402409)