Skip to content

Commit 49ab40b

Browse files
djordje-ttsraizada-tttchedaTTjonathansuTTalnah005
authored and committed
[Merge stable to main] Llama3.3-70b and 3.1-8b - Fix sampling parameters (#36476)
#36325 This PR fixes couple of different issues for Llama3.3-70b: - Non-uniform seeding - Penalty trap bug - Penalty bugs for Llama3.1-8b - batched prefill determinism - diff between batched and non-batched prefill - missing logprobs support for Llama3.3-70b - Fixes same sampling parameters for Llama3.1-8b - Bring over the log-probs support for Galaxy (optional log-softmaxed logits output), matching the behavior already validated on stable in TT-Metal, vLLM nightly, and Models CI. - Integrate the deterministic seeding flow (host-side RNG + SamplingSeedManager + `ttnn.manual_seed` usage before `ttnn.sampling`) so prefill + decode produce deterministic sequences across repeats when seeds are fixed. - Ensure the penalties path matches the shared implementation, fixing the earlier divergence across users. - Updated matmul configs to support same behaviour across batched and non-batched prefill with couple additional fixes for divergence. Performance numbers on text_demo in t/s/u: | branch | without penalties | with penalties | |-------|-------|-------| | branch | 71.88 t/s/u | 42.36 t/s/u | | main | 72.05 t/s/u | - | **TTFT**: **68.5**ms -> **73.9**ms drop due to disabling use_2d_grid in rms norm is expected. 
- [ ] [All post-commit tests](https://github.com/tenstorrent/tt-metal/actions/runs/21355526046) - [x] [Galaxy Demo](https://github.com/tenstorrent/tt-metal/actions/runs/21361481284) - [x] [vllm nightly](https://github.com/tenstorrent/tt-metal/actions/runs/21361542050) - [x] [Models CI](https://github.com/tenstorrent/tt-shield/actions/runs/21435406349/job/61728802475) Last pipelines list 6th Feb: - [ ] [vllm Nightly](https://github.com/tenstorrent/tt-metal/actions/runs/21754091798) - [ ] [Shield CI](https://github.com/tenstorrent/tt-shield/actions/runs/21753926206/job/62758873631) - [ ] [Galaxy Demo](https://github.com/tenstorrent/tt-metal/actions/runs/21754402409) --------- Co-authored-by: Stuti Raizada <159130512+sraizada-tt@users.noreply.github.com> Co-authored-by: Tomasz Cheda <tcheda@tenstorrent.com> Co-authored-by: Jonathan Su <jonathansu@tenstorrent.com> Co-authored-by: alnah005 <salnahari@tenstorrent.com> Co-authored-by: Alberto Perez Vicente <aperezvicente@tenstorrent.com> Co-authored-by: handrewsTT <handrews@tenstorrent.com> Co-authored-by: Mohamed Bahnas <mbahnas@tenstorrent.com> Co-authored-by: Radoica Draskic <rdraskic@tenstorrent.com> Co-authored-by: kpaigwar <kpaigwar@tenstorrent.com> Co-authored-by: Stuti Raizada <sraizada@tenstorrent.com> Co-authored-by: Claude Opus 4.5 <noreply@anthropic.com>
1 parent f87c34a commit 49ab40b

38 files changed

+2008
-600
lines changed

.github/workflows/vllm-nightly-tests-impl.yaml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -146,7 +146,7 @@ jobs:
146146
{
147147
"name": "[WH-T3K][v1] Gemma3-27B",
148148
"model": "google/gemma-3-27b-it",
149-
"server-timeout": 20,
149+
"server-timeout": 45,
150150
"benchmark-timeout": 5,
151151
"runner-label": "config-t3000",
152152
"arch": "arch-wormhole_b0",
@@ -161,7 +161,7 @@ jobs:
161161
{
162162
"name": "[WH-GLX][v1] Gemma3-27B DP=4",
163163
"model": "google/gemma-3-27b-it",
164-
"server-timeout": 20,
164+
"server-timeout": 45,
165165
"benchmark-timeout": 5,
166166
"runner-label": "topology-6u",
167167
"arch": "arch-wormhole_b0",

models/common/sampling/generator.py

Lines changed: 59 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -2,24 +2,25 @@
22

33
# SPDX-License-Identifier: Apache-2.0
44

5-
import logging
5+
import random
6+
import secrets
67
from dataclasses import dataclass, fields, replace
78
from typing import List, Optional
89

910
import torch
11+
from loguru import logger
1012

1113
import ttnn
1214

1315
from .tt_penalties import TTPenalties
1416
from .tt_sampling import TTSampling
1517

16-
logger = logging.getLogger(__name__)
17-
1818

1919
@dataclass(frozen=True)
2020
class _TraceKey:
2121
penalties_on: bool
2222
log_probs_on: bool
23+
force_argmax: bool
2324

2425

2526
class SamplingGenerator:
@@ -60,12 +61,13 @@ def __init__(
6061
self._penalties_active = False
6162

6263
self._trace_states: dict[_TraceKey, dict] = {}
64+
self.seed_manager = SeedManager(self.tt_sampling)
6365

6466
def _new_trace_state(self):
6567
return {"id": None, "input": None, "output": None, "kwargs": {}}
6668

67-
def _trace_slot(self, penalties_on: bool, log_probs_on: bool):
68-
key = _TraceKey(penalties_on=penalties_on, log_probs_on=log_probs_on)
69+
def _trace_slot(self, penalties_on: bool, log_probs_on: bool, force_argmax: bool):
70+
key = _TraceKey(penalties_on=penalties_on, log_probs_on=log_probs_on, force_argmax=force_argmax)
6971
slot = self._trace_states.get(key)
7072
if slot is None:
7173
slot = self._new_trace_state()
@@ -79,10 +81,7 @@ def reset_trace(self):
7981
for key, slot in self._trace_states.items():
8082
if slot["id"] is not None:
8183
logger.debug(
82-
"Resetting sampling trace (penalties=%s, log_probs=%s, trace_id=%s)",
83-
key.penalties_on,
84-
key.log_probs_on,
85-
slot["id"],
84+
f"Resetting sampling trace (penalties={key.penalties_on}, log_probs={key.log_probs_on}, force_argmax={key.force_argmax}, trace_id={slot['id']})"
8685
)
8786
self._trace_states.clear()
8887

@@ -98,7 +97,7 @@ def reset_prompt_tokens(self, prompt_tokens):
9897
return
9998
self.tt_penalties.reset_prompt_tokens(prompt_tokens)
10099

101-
def reset_output_state(self, tokens):
100+
def reset_output_state(self, tokens=None):
102101
if not self._penalties_active:
103102
return
104103
self.tt_penalties.reset_output_tokens(tokens)
@@ -107,20 +106,30 @@ def reset_output_state(self, tokens):
107106
# Sampling helpers
108107
# ---------------------------------------------------------------------
109108
def reset_sampling_params(self, sampling_params):
109+
old_force_argmax_sampling = self.tt_sampling._force_argmax_sampling
110110
self.tt_sampling.reset_params(
111111
k=sampling_params.top_k,
112112
p=sampling_params.top_p,
113113
temp=sampling_params.temperature,
114114
enable_log_probs=sampling_params.enable_log_probs,
115115
)
116-
self.tt_penalties.reset_params(
117-
sampling_params.presence_penalty, sampling_params.frequency_penalty, sampling_params.repetition_penalty
118-
)
116+
if self.tt_sampling._force_argmax_sampling != old_force_argmax_sampling:
117+
self.reset_trace()
118+
119+
old_penalties_active = self._penalties_active
119120
self._penalties_active = not (
120121
self._is_default_penalty(sampling_params.presence_penalty, self._DEFAULT_PENALTIES["presence"])
121122
and self._is_default_penalty(sampling_params.frequency_penalty, self._DEFAULT_PENALTIES["frequency"])
122123
and self._is_default_penalty(sampling_params.repetition_penalty, self._DEFAULT_PENALTIES["repetition"])
123124
)
125+
if (
126+
not self.tt_sampling._force_argmax_sampling
127+
or self._penalties_active
128+
or self._penalties_active != old_penalties_active
129+
):
130+
self.tt_penalties.reset_params(
131+
sampling_params.presence_penalty, sampling_params.frequency_penalty, sampling_params.repetition_penalty
132+
)
124133
self._log_probs_active = self.tt_sampling.log_probs_calculator.enable_log_probs
125134

126135
def _validate_trace_inputs(self, slot, logits: ttnn.Tensor, tt_out_tok: Optional[ttnn.Tensor]):
@@ -153,7 +162,7 @@ def _run_sampling(
153162
tt_out_tok: Optional[ttnn.Tensor],
154163
):
155164
if penalties_on:
156-
self.tt_penalties.apply(logits)
165+
logits = self.tt_penalties.apply(logits)
157166
tt_tokens, tt_log_probs = self.tt_sampling(logits, tt_out_tok=tt_out_tok)
158167
return tt_tokens, tt_log_probs
159168

@@ -167,11 +176,14 @@ def capture_trace(
167176
Capture a trace of the sampling pipeline for the given configuration.
168177
"""
169178
penalties_on = self._penalties_active
170-
log_probs_on = self._log_probs_active
179+
log_probs_on = getattr(self, "_log_probs_active", False)
180+
force_argmax = self.tt_sampling._force_argmax_sampling
171181

172-
key, slot = self._trace_slot(penalties_on, log_probs_on)
182+
key, slot = self._trace_slot(penalties_on, log_probs_on, force_argmax)
173183

174-
logger.debug("Pre-compiling sampling path before trace capture (penalties=%s)", penalties_on)
184+
logger.debug(
185+
f"Pre-compiling sampling path before trace capture (penalties={penalties_on},log_probs_on={log_probs_on},force_argmax={force_argmax})"
186+
)
175187
self._run_sampling(
176188
logits,
177189
penalties_on=penalties_on,
@@ -210,7 +222,6 @@ def _execute_trace(self, key: _TraceKey) -> ttnn.Tensor:
210222
raise RuntimeError("Trace has not been captured yet.")
211223

212224
ttnn.execute_trace(self.mesh_device, slot["id"], cq_id=self.cq_id, blocking=False)
213-
214225
return slot["output"]
215226

216227
def sample(
@@ -226,7 +237,8 @@ def sample(
226237
"""
227238

228239
penalties_on = self._penalties_active
229-
log_probs_on = self._log_probs_active
240+
log_probs_on = getattr(self, "_log_probs_active", False)
241+
force_argmax = self.tt_sampling._force_argmax_sampling
230242
use_internal_trace = enable_trace and self.enable_internal_trace
231243

232244
if not use_internal_trace:
@@ -236,7 +248,7 @@ def sample(
236248
tt_out_tok=tt_out_tok,
237249
)
238250
else:
239-
key, slot = self._trace_slot(penalties_on, log_probs_on)
251+
key, slot = self._trace_slot(penalties_on, log_probs_on, force_argmax)
240252
if slot["id"] is None:
241253
return self.capture_trace(
242254
logits,
@@ -253,24 +265,6 @@ def sample(
253265
self.tt_penalties.update_output_tokens(tt_out)
254266
return tt_out
255267

256-
def reset_seed(self, seed):
257-
for i, s in enumerate(seed):
258-
if s is None:
259-
# set to default seed value which is 0
260-
seed[i] = 0
261-
seed = torch.tensor(seed)
262-
user_ids = torch.arange(seed.shape[0])
263-
264-
user_ids_tt = ttnn.from_torch(
265-
user_ids, device=self.mesh_device, dtype=ttnn.uint32, layout=ttnn.ROW_MAJOR_LAYOUT
266-
)
267-
seeds_tt = ttnn.from_torch(seed, device=self.mesh_device, dtype=ttnn.uint32, layout=ttnn.ROW_MAJOR_LAYOUT)
268-
269-
# reset seed for each user_id
270-
ttnn.manual_seed(seeds=seeds_tt, user_ids=user_ids_tt, sub_core_grids=self.sub_core_grids)
271-
seeds_tt.deallocate()
272-
user_ids_tt.deallocate()
273-
274268

275269
def clamp(value, min_value, max_value):
276270
if value < min_value:
@@ -297,7 +291,7 @@ def format_sampling_params(sampling_params, max_batch_size):
297291
"presence_penalty": 0.0,
298292
"frequency_penalty": 0.0,
299293
"repetition_penalty": 1.0,
300-
"seed": 0,
294+
"seed": random.randint(0, 1000000), # set to random seed to have variability while using tensor manual_seed
301295
}
302296
target_len = max_batch_size
303297
assert target_len == 32, "Sampling only support batch_size=32"
@@ -355,4 +349,30 @@ def format_sampling_params(sampling_params, max_batch_size):
355349

356350
if sampling_params.repetition_penalty[i] == 0:
357351
sampling_params.repetition_penalty[i] = default_params["repetition_penalty"]
352+
353+
if sampling_params.top_k[i] < 1:
354+
sampling_params.top_k[i] = 32 # k<1 means no restriction so set it to max k (32)
358355
return sampling_params
356+
357+
358+
class SeedManager:
359+
def __init__(self, tt_sampling):
360+
self.seeds = [secrets.randbits(64) for _ in range(32)]
361+
self.rngs = [random.Random(seed) for seed in self.seeds]
362+
self.tt_sampling = tt_sampling
363+
364+
def reset_seed(self, seeds, user_ids):
365+
for i, user in enumerate(user_ids):
366+
self.rngs[user].seed(seeds[i])
367+
self.seeds[user] = seeds[i]
368+
369+
def get_new_values(self, empty_slots=range(32), replicate_seeds=False):
370+
# get new seeds for each user in empty_slots otherwise 0
371+
new_seeds = [rng.randint(0, 1000000) if i in empty_slots else 0 for i, rng in enumerate(self.rngs)]
372+
373+
if replicate_seeds:
374+
assert len(empty_slots) == 1, "Cannot replicate seeds if empty_slots is not length 1"
375+
new_seeds = 32 * [new_seeds[empty_slots[0]]]
376+
# send new seeds to sampling module
377+
new_seed_tt = ttnn.from_torch(torch.tensor(new_seeds), dtype=ttnn.uint32, layout=ttnn.ROW_MAJOR_LAYOUT)
378+
ttnn.copy_host_to_device_tensor(new_seed_tt, self.tt_sampling.seeds_tt_tensor)

0 commit comments

Comments
 (0)