Merged

23 commits
f8e7236  Merge stable to main: Llama3.3-70b sampling params and logprobs fixes (djordje-tt, Jan 23, 2026)
6f3be71  Fix merging bugs (djordje-tt, Jan 23, 2026)
3150809  Another fix of merging (djordje-tt, Jan 23, 2026)
7f1c437  Fix ttnn.div round_mode -> rounding_mode (djordje-tt, Jan 26, 2026)
7337d9c  Fix output mismatch and TTFT drop 68.5->74ms (djordje-tt, Jan 26, 2026)
dc54c8b  Update ttnn.combine_device_tensors call (djordje-tt, Jan 26, 2026)
833c72f  Fix ttnn.combine_device_tensors call (djordje-tt, Jan 27, 2026)
a567436  Add prefill sampling support to TTT models (#35021) (sraizada-tt, Dec 27, 2025)
70cc1ae  Llama-3.1-8B decode TSU optimizations (#35142) (jonathansuTT, Dec 29, 2025)
28e4af8  Fix non-uniform seeding (#35906) (rdraskicTT, Jan 16, 2026)
3b1d0e4  Fix reduce_scatter_minimial_async merge conflict bugs (djordje-tt, Jan 30, 2026)
5ff6d53  Fix vLLM nightly CI test failures (djordje-tt, Feb 1, 2026)
3a40cd3  Update condition for setting links in model_config (djordje-tt, Feb 2, 2026)
7c89c43  Fix prefill all_gather num_links for non-Galaxy multichip devices (djordje-tt, Feb 2, 2026)
8703cd7  Fix gemma (rdraskicTT, Feb 4, 2026)
8f4d87f  Fix autoflake (rdraskicTT, Feb 4, 2026)
02427da  Temporarily enable cache gen (rdraskicTT, Feb 4, 2026)
ca530ea  Fix padded vocab size bug (rdraskicTT, Feb 5, 2026)
f3a15b5  Increase server timeout (rdraskicTT, Feb 5, 2026)
f59afb7  Resolve comments and remove dead code (djordje-tt, Feb 5, 2026)
b58330e  Revert "Temporarily enable cache gen" (rdraskicTT, Feb 6, 2026)
a807555  Reduce global_cb_size to 728 tiles back (djordje-tt, Feb 6, 2026)
bdfb675  Fix qwen test (djordje-tt, Feb 6, 2026)
4 changes: 2 additions & 2 deletions .github/workflows/vllm-nightly-tests-impl.yaml
@@ -133,7 +133,7 @@ jobs:
{
"name": "[WH-T3K][v1] Gemma3-27B",
"model": "google/gemma-3-27b-it",
"server-timeout": 20,
"server-timeout": 45,
"benchmark-timeout": 5,
"runner-label": "config-t3000",
"arch": "arch-wormhole_b0",
@@ -148,7 +148,7 @@
{
"name": "[WH-GLX][v1] Gemma3-27B DP=4",
"model": "google/gemma-3-27b-it",
"server-timeout": 20,
"server-timeout": 45,
"benchmark-timeout": 5,
"runner-label": "topology-6u",
"arch": "arch-wormhole_b0",
98 changes: 59 additions & 39 deletions models/common/sampling/generator.py
@@ -2,24 +2,25 @@

# SPDX-License-Identifier: Apache-2.0

- import logging
+ import random
+ import secrets
from dataclasses import dataclass, fields, replace
from typing import List, Optional

import torch
+ from loguru import logger

import ttnn

from .tt_penalties import TTPenalties
from .tt_sampling import TTSampling

- logger = logging.getLogger(__name__)


@dataclass(frozen=True)
class _TraceKey:
penalties_on: bool
log_probs_on: bool
+ force_argmax: bool


class SamplingGenerator:
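Context for the _TraceKey change above: because the dataclass is frozen (and therefore hashable), it can serve directly as a dictionary key, and adding force_argmax gives each sampling configuration its own cached trace. A minimal standalone sketch of the pattern (names simplified, not the repo's exact API):

```python
from dataclasses import dataclass

@dataclass(frozen=True)
class TraceKey:
    # frozen=True generates __hash__/__eq__, so instances can key a dict
    penalties_on: bool
    log_probs_on: bool
    force_argmax: bool  # omitting a config flag here would silently reuse a stale trace

trace_cache: dict[TraceKey, dict] = {}

def trace_slot(penalties_on: bool, log_probs_on: bool, force_argmax: bool) -> dict:
    key = TraceKey(penalties_on, log_probs_on, force_argmax)
    # one cached trace slot per distinct sampling configuration
    return trace_cache.setdefault(key, {"id": None, "input": None, "output": None})
```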
@@ -60,12 +61,13 @@ def __init__(
self._penalties_active = False

self._trace_states: dict[_TraceKey, dict] = {}
+ self.seed_manager = SeedManager(self.tt_sampling)

def _new_trace_state(self):
return {"id": None, "input": None, "output": None, "kwargs": {}}

- def _trace_slot(self, penalties_on: bool, log_probs_on: bool):
- key = _TraceKey(penalties_on=penalties_on, log_probs_on=log_probs_on)
+ def _trace_slot(self, penalties_on: bool, log_probs_on: bool, force_argmax: bool):
+ key = _TraceKey(penalties_on=penalties_on, log_probs_on=log_probs_on, force_argmax=force_argmax)
slot = self._trace_states.get(key)
if slot is None:
slot = self._new_trace_state()
@@ -79,10 +81,7 @@ def reset_trace(self):
for key, slot in self._trace_states.items():
if slot["id"] is not None:
logger.debug(
"Resetting sampling trace (penalties=%s, log_probs=%s, trace_id=%s)",
key.penalties_on,
key.log_probs_on,
slot["id"],
f"Resetting sampling trace (penalties={key.penalties_on}, log_probs={key.log_probs_on}, force_argmax={key.force_argmax}, trace_id={slot['id']})"
)
self._trace_states.clear()

@@ -98,7 +97,7 @@ def reset_prompt_tokens(self, prompt_tokens):
return
self.tt_penalties.reset_prompt_tokens(prompt_tokens)

- def reset_output_state(self, tokens):
+ def reset_output_state(self, tokens=None):
if not self._penalties_active:
return
self.tt_penalties.reset_output_tokens(tokens)
@@ -107,20 +106,30 @@ def reset_output_state(self, tokens=None):
# Sampling helpers
# ---------------------------------------------------------------------
def reset_sampling_params(self, sampling_params):
+ old_force_argmax_sampling = self.tt_sampling._force_argmax_sampling
self.tt_sampling.reset_params(
k=sampling_params.top_k,
p=sampling_params.top_p,
temp=sampling_params.temperature,
enable_log_probs=sampling_params.enable_log_probs,
)
- self.tt_penalties.reset_params(
- sampling_params.presence_penalty, sampling_params.frequency_penalty, sampling_params.repetition_penalty
- )
+ if self.tt_sampling._force_argmax_sampling != old_force_argmax_sampling:
+ self.reset_trace()

+ old_penalties_active = self._penalties_active
self._penalties_active = not (
self._is_default_penalty(sampling_params.presence_penalty, self._DEFAULT_PENALTIES["presence"])
and self._is_default_penalty(sampling_params.frequency_penalty, self._DEFAULT_PENALTIES["frequency"])
and self._is_default_penalty(sampling_params.repetition_penalty, self._DEFAULT_PENALTIES["repetition"])
)
+ if (
+ not self.tt_sampling._force_argmax_sampling
+ or self._penalties_active
+ or self._penalties_active != old_penalties_active
+ ):
+ self.tt_penalties.reset_params(
+ sampling_params.presence_penalty, sampling_params.frequency_penalty, sampling_params.repetition_penalty
+ )
self._log_probs_active = self.tt_sampling.log_probs_calculator.enable_log_probs

def _validate_trace_inputs(self, slot, logits: ttnn.Tensor, tt_out_tok: Optional[ttnn.Tensor]):
@@ -153,7 +162,7 @@ def _run_sampling(
tt_out_tok: Optional[ttnn.Tensor],
):
if penalties_on:
- self.tt_penalties.apply(logits)
+ logits = self.tt_penalties.apply(logits)
tt_tokens, tt_log_probs = self.tt_sampling(logits, tt_out_tok=tt_out_tok)
return tt_tokens, tt_log_probs
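The one-line penalties fix above matters because tt_penalties.apply evidently returns a new tensor rather than mutating its argument; discarding the return value left sampling running on unpenalized logits. A hedged plain-PyTorch illustration of that bug class (apply_penalty is hypothetical, not the repo's API):

```python
import torch

def apply_penalty(logits: torch.Tensor, penalty: float = 1.2) -> torch.Tensor:
    # builds and returns a NEW tensor; the caller's tensor is left untouched
    return logits / penalty

logits = torch.tensor([2.0, 4.0])

apply_penalty(logits)           # bug: return value dropped, logits unchanged
logits = apply_penalty(logits)  # fix: rebind the name to the penalized tensor
```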

@@ -167,11 +176,14 @@ def capture_trace(
Capture a trace of the sampling pipeline for the given configuration.
"""
penalties_on = self._penalties_active
- log_probs_on = self._log_probs_active
+ log_probs_on = getattr(self, "_log_probs_active", False)
+ force_argmax = self.tt_sampling._force_argmax_sampling

- key, slot = self._trace_slot(penalties_on, log_probs_on)
+ key, slot = self._trace_slot(penalties_on, log_probs_on, force_argmax)

- logger.debug("Pre-compiling sampling path before trace capture (penalties=%s)", penalties_on)
+ logger.debug(
+ f"Pre-compiling sampling path before trace capture (penalties={penalties_on},log_probs_on={log_probs_on},force_argmax={force_argmax})"
+ )
self._run_sampling(
logits,
penalties_on=penalties_on,
@@ -210,7 +222,6 @@ def _execute_trace(self, key: _TraceKey) -> ttnn.Tensor:
raise RuntimeError("Trace has not been captured yet.")

ttnn.execute_trace(self.mesh_device, slot["id"], cq_id=self.cq_id, blocking=False)

return slot["output"]

def sample(
@@ -226,7 +237,8 @@ def sample(
"""

penalties_on = self._penalties_active
- log_probs_on = self._log_probs_active
+ log_probs_on = getattr(self, "_log_probs_active", False)
+ force_argmax = self.tt_sampling._force_argmax_sampling
use_internal_trace = enable_trace and self.enable_internal_trace

if not use_internal_trace:
@@ -236,7 +248,7 @@
tt_out_tok=tt_out_tok,
)
else:
- key, slot = self._trace_slot(penalties_on, log_probs_on)
+ key, slot = self._trace_slot(penalties_on, log_probs_on, force_argmax)
if slot["id"] is None:
return self.capture_trace(
logits,
@@ -253,24 +265,6 @@
self.tt_penalties.update_output_tokens(tt_out)
return tt_out

- def reset_seed(self, seed):
- for i, s in enumerate(seed):
- if s is None:
- # set to default seed value which is 0
- seed[i] = 0
- seed = torch.tensor(seed)
- user_ids = torch.arange(seed.shape[0])
-
- user_ids_tt = ttnn.from_torch(
- user_ids, device=self.mesh_device, dtype=ttnn.uint32, layout=ttnn.ROW_MAJOR_LAYOUT
- )
- seeds_tt = ttnn.from_torch(seed, device=self.mesh_device, dtype=ttnn.uint32, layout=ttnn.ROW_MAJOR_LAYOUT)
-
- # reset seed for each user_id
- ttnn.manual_seed(seeds=seeds_tt, user_ids=user_ids_tt, sub_core_grids=self.sub_core_grids)
- seeds_tt.deallocate()
- user_ids_tt.deallocate()


def clamp(value, min_value, max_value):
if value < min_value:
@@ -297,7 +291,7 @@ def format_sampling_params(sampling_params, max_batch_size):
"presence_penalty": 0.0,
"frequency_penalty": 0.0,
"repetition_penalty": 1.0,
"seed": 0,
"seed": random.randint(0, 1000000), # set to random seed to have variability while using tensor manual_seed
}
target_len = max_batch_size
assert target_len == 32, "Sampling only support batch_size=32"
@@ -359,4 +353,30 @@

if sampling_params.repetition_penalty[i] == 0:
sampling_params.repetition_penalty[i] = default_params["repetition_penalty"]

+ if sampling_params.top_k[i] < 1:
+ sampling_params.top_k[i] = 32  # k<1 means no restriction so set it to max k (32)
Review comment on lines +357 to +358 (Copilot AI, Jan 26, 2026):

Duplicate check for top_k < 1 at lines 347-348 and 353-354. The second check (lines 353-354) is redundant and should be removed.

Suggested change (remove):
- if sampling_params.top_k[i] < 1:
- sampling_params.top_k[i] = 32  # k<1 means no restriction so set it to max k (32)
return sampling_params
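If the redundant check flagged above is removed, the surviving top_k normalization could also be expressed as one small helper; a sketch under the assumptions stated in the diff's own comments (k < 1 means unrestricted, 32 is the maximum supported k):

```python
def normalize_top_k(top_k: int, max_k: int = 32) -> int:
    # k < 1 conventionally means "no restriction", so fall back to the maximum
    return max_k if top_k < 1 else min(top_k, max_k)

assert normalize_top_k(0) == 32    # unrestricted -> max k
assert normalize_top_k(10) == 10   # in range -> unchanged
assert normalize_top_k(100) == 32  # above range -> clamped
```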


class SeedManager:
def __init__(self, tt_sampling):
self.seeds = [secrets.randbits(64) for _ in range(32)]
self.rngs = [random.Random(seed) for seed in self.seeds]
self.tt_sampling = tt_sampling

def reset_seed(self, seeds, user_ids):
for i, user in enumerate(user_ids):
self.rngs[user].seed(seeds[i])
self.seeds[user] = seeds[i]

def get_new_values(self, empty_slots=range(32), replicate_seeds=False):
# get new seeds for each user in empty_slots otherwise 0
new_seeds = [rng.randint(0, 1000000) if i in empty_slots else 0 for i, rng in enumerate(self.rngs)]

if replicate_seeds:
assert len(empty_slots) == 1, "Cannot replicate seeds if empty_slots is not length 1"
new_seeds = 32 * [new_seeds[empty_slots[0]]]
# send new seeds to sampling module
new_seed_tt = ttnn.from_torch(torch.tensor(new_seeds), dtype=ttnn.uint32, layout=ttnn.ROW_MAJOR_LAYOUT)
ttnn.copy_host_to_device_tensor(new_seed_tt, self.tt_sampling.seeds_tt_tensor)
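SeedManager keeps one independent random.Random stream per user slot, so re-seeding one user never disturbs the sequences of the others, and only slots listed in empty_slots draw new values. A host-only sketch of that property (no ttnn calls; slot count and value range mirror the class above):

```python
import random
import secrets

class HostSeedManager:
    """Host-side mirror of SeedManager's per-slot RNG streams."""

    def __init__(self, num_slots: int = 32):
        self.rngs = [random.Random(secrets.randbits(64)) for _ in range(num_slots)]

    def reset_seed(self, seeds, user_ids):
        for seed, user in zip(seeds, user_ids):
            self.rngs[user].seed(seed)

    def get_new_values(self, empty_slots=range(32)):
        # draw a fresh value only for slots being (re)filled; others stay 0
        return [rng.randint(0, 1_000_000) if i in empty_slots else 0 for i, rng in enumerate(self.rngs)]

mgr = HostSeedManager()
mgr.reset_seed([1234], user_ids=[5])
first = mgr.get_new_values(empty_slots=[5])[5]
mgr.reset_seed([1234], user_ids=[5])
assert mgr.get_new_values(empty_slots=[5])[5] == first  # same seed, same stream
```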