tunix/generate/sampler.py  (6 changes: 4 additions & 2 deletions)

@@ -119,8 +119,10 @@ def sample_top_p(
   # Upcast to float32 for numerical stability of softmax and subsequent cumsum.
   next_token_logits = logits[:, -1].astype(jnp.float32) / temperature
 
+  # top_k=0 or None both mean "no top-k filtering" — use full vocabulary.
+  _no_topk = top_k is None or top_k <= 0
   # Skip softmax and sorting if top_p is 1.0 and top_k is full vocab.
-  if top_p >= 1.0 and top_k is None:
+  if top_p >= 1.0 and _no_topk:
     next_token = jax.random.categorical(key, logits=next_token_logits)
     if not return_logprobs:
       return next_token, None

@@ -130,7 +132,7 @@ def sample_top_p(
     return next_token, logp_sampled
 
   probs = jax.nn.softmax(next_token_logits, axis=-1)
-  k = probs.shape[-1] if top_k is None else top_k
+  k = probs.shape[-1] if _no_topk else top_k
 
   probs_sorted, indices = jax.lax.top_k(probs, k=k)
   cumsum_probs = jnp.cumsum(probs_sorted, axis=-1)
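For reference, a standalone sketch (toy probabilities, illustrative only) of what the `_no_topk` fallback does: `top_k=None`, `0`, and negative values now all select the full vocabulary.

import jax
import jax.numpy as jnp

probs = jnp.array([[0.5, 0.3, 0.15, 0.05]])  # toy [B=1, V=4] distribution

for top_k in (None, 0, -3):
  _no_topk = top_k is None or top_k <= 0
  k = probs.shape[-1] if _no_topk else top_k
  # k == 4 in every iteration: top-k filtering is disabled, so lax.top_k
  # just returns the full vocabulary in sorted order.
  probs_sorted, indices = jax.lax.top_k(probs, k=k)
  assert probs_sorted.shape[-1] == probs.shape[-1]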
tunix/generate/vllm_sampler.py  (19 changes: 15 additions & 4 deletions)

@@ -349,10 +349,13 @@ def detokenize(
         input_strings, request_outputs
     ):
       for idx, single_output in enumerate(multi_sampling_output.outputs):
-        # vLLM still returns 1 eos id even if we ask it to stop at eos.
-        if single_output.token_ids[-1] == self.tokenizer.eos_id():
-          single_output.token_ids = single_output.token_ids[:-1]
-          single_output.logprobs = single_output.logprobs[:-1]
+        # KEEP the eos token in the returned token_ids — needed so multi-turn
+        # consumers (agentic engine) can reconstruct the exact sequence the
+        # next turn's prompt was rendered from. Combined with
+        # `include_stop_str_in_output=True`, vLLM emits one eos at the end of
+        # each generation. Stripping it (the previous behavior) made
+        # trainer-side concatenation miss `<|im_end|>` at every turn boundary
+        # and produced 30+ nat sampler-trainer logp diffs.
 
         out_tokens[idx].append(
             np.array(single_output.token_ids, dtype=np.int32)

@@ -461,6 +464,14 @@ def __call__(
     sampling_params.prompt_logprobs = 0
     sampling_params.stop_token_ids = [self.tokenizer.eos_id()]
     sampling_params.skip_special_tokens = True
+    # Keep the stop token in the returned ``token_ids`` so multi-turn
+    # consumers can reconstruct the exact sequence the model was sampled
+    # on. This makes the trainer-side concatenation align with what
+    # ``apply_chat_template`` produces for the next turn's prompt; without
+    # it, the trailing ``<|im_end|>`` (or equivalent eos token) is missing
+    # at every turn boundary in the recorded sequence, biasing logp
+    # recomputation against the model's actual sampling context.
+    sampling_params.include_stop_str_in_output = True
 
     if top_p is not None:
       sampling_params.top_p = top_p
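A toy illustration of the turn-boundary mismatch these comments describe (token ids are made up, not from a real tokenizer; `EOS` stands in for `<|im_end|>`):

EOS = 151645  # hypothetical id for <|im_end|>

turn1_prompt = [1, 2, 3]
completion_kept = [10, 11, EOS]   # new behavior: trailing eos retained
completion_stripped = [10, 11]    # old behavior: trailing eos removed

# apply_chat_template renders the next prompt *including* the previous
# assistant turn's closing <|im_end|>:
turn2_prompt = [1, 2, 3, 10, 11, EOS, 4, 5]

prefix = turn2_prompt[: len(turn1_prompt) + len(completion_kept)]
assert turn1_prompt + completion_kept == prefix           # aligns exactly
assert turn1_prompt + completion_stripped == prefix[:-1]  # eos missing at the boundary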
tunix/models/qwen3/model.py  (88 changes: 72 additions & 16 deletions)

@@ -486,6 +486,7 @@ def block(
       segment_pos: jaxtyping.Array,
       cache: LayerCache | None,
       attn_mask: jaxtyping.Array | None,
+      segment_ids: jaxtyping.Array | None = None,

  Review comment (Collaborator): this is used for sequence packing, right?
  Right now we haven't enabled that yet, since the seq packing support is
  still WIP.

  Reply: yeah, here it's not for packing, but it's needed for per-batch
  padding masks. The splash attention kernel only accepts a static causal
  mask kernel-side, so per-batch dynamic pad positions can only flow in
  through SegmentIds. Without this, left-padded prompts leak attention from
  real queries onto pad keys, corrupting the softmax and shifting the
  logits. We use segment_id=0 for pad and segment_id=1 for real tokens to
  avoid that.

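A minimal sketch of the padding-mask scheme from that reply (assuming splash attention's SegmentIds semantics, where a query may only attend to keys with the same segment id; `pad_id=0` and the toy batch are illustrative):

import jax.numpy as jnp

def make_pad_segment_ids(input_ids, pad_id):
  # segment_id=0 for pad positions, segment_id=1 for real tokens ([B, L]).
  return (input_ids != pad_id).astype(jnp.int32)

input_ids = jnp.array([[0, 0, 5, 6, 7],    # left-padded prompts
                       [0, 9, 8, 7, 6]])
seg = make_pad_segment_ids(input_ids, pad_id=0)
# seg == [[0, 0, 1, 1, 1],
#         [0, 1, 1, 1, 1]]
# With SegmentIds(q=seg, kv=seg), a real query (segment 1) can no longer
# attend to a pad key (segment 0), so pad logits stop polluting the softmax.
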
   ) -> tuple[LayerCache | None, jaxtyping.Array]:
     seq_len = x.shape[1]
 

@@ -571,19 +572,59 @@ def block(
           shd.NamedSharding(mesh, P(shd_n, shd_t))
       )
 
-      @partial(
-          shard_map,
-          mesh=mesh,
-          in_specs=(kernel_spec, shd_spec, unsharded_seq, unsharded_seq),
-          out_specs=shd_spec,
-          check_rep=False,
-      )
-      def sharded_splash_attn(kernel, q_block, k_block, v_block):
-        return jax.vmap(kernel)(q_block, k_block, v_block)
+      # Per-position segment ids let splash suppress cross-segment attention
+      # (e.g. real-token to pad-token, or sequence-packing cross-boundary).
+      # The pallas splash kernel only accepts a static causal mask kernel-side,
+      # so per-batch dynamic padding masks have to flow in via segment_ids.
+      if segment_ids is not None:
+        seg_spec = P(shd_b, shd_t)
+        unsharded_seg_spec = P(shd_b, None)
+
+        @partial(
+            shard_map,
+            mesh=mesh,
+            in_specs=(
+                kernel_spec,
+                shd_spec,
+                unsharded_seq,
+                unsharded_seq,
+                seg_spec,
+                unsharded_seg_spec,
+            ),
+            out_specs=shd_spec,
+            check_rep=False,
+        )
+        def sharded_splash_attn(
+            kernel, q_block, k_block, v_block, q_seg_block, kv_seg_block
+        ):
+          seg_ids = splash.SegmentIds(q=q_seg_block, kv=kv_seg_block)
+          return jax.vmap(kernel)(
+              q_block, k_block, v_block, segment_ids=seg_ids
+          )
+
+        qkv = sharded_splash_attn(
+            splash_attn_kernel,
+            query_proj,
+            key_proj,
+            value_proj,
+            segment_ids,
+            segment_ids,
+        )
+      else:
 
-      qkv = sharded_splash_attn(
-          splash_attn_kernel, query_proj, key_proj, value_proj
-      )
+        @partial(
+            shard_map,
+            mesh=mesh,
+            in_specs=(kernel_spec, shd_spec, unsharded_seq, unsharded_seq),
+            out_specs=shd_spec,
+            check_rep=False,
+        )
+        def sharded_splash_attn(kernel, q_block, k_block, v_block):
+          return jax.vmap(kernel)(q_block, k_block, v_block)
+
+        qkv = sharded_splash_attn(
+            splash_attn_kernel, query_proj, key_proj, value_proj
+        )
       qkv = qkv.transpose(0, 2, 1, 3)
     else:
       # GQA
Expand Down Expand Up @@ -621,6 +662,7 @@ def __call__(
       segment_pos: jaxtyping.Array,
       cache: LayerCache | None,
       attn_mask: jaxtyping.Array | None,
+      segment_ids: jaxtyping.Array | None = None,
   ) -> tuple[LayerCache | None, jaxtyping.Array]:
     if (
         self.config.remat_config == RematConfig.BLOCK

@@ -629,10 +671,10 @@
       # nnx.remat needs to be applied to the unbound function and take self
       # as the first argument.
       return nnx.remat(self.block.__func__)(
-          self, x, segment_pos, cache, attn_mask
+          self, x, segment_pos, cache, attn_mask, segment_ids
       )
     else:
-      return self.block(x, segment_pos, cache, attn_mask)
+      return self.block(x, segment_pos, cache, attn_mask, segment_ids=segment_ids)

   @property
   def head_dim(self):

@@ -1052,13 +1094,15 @@ def block(
       segment_pos: jaxtyping.Array,
       cache: LayerCache | None,
       attn_mask: jaxtyping.Array,
+      segment_ids: jaxtyping.Array | None = None,
   ) -> tuple[LayerCache | None, jaxtyping.Array]:
     inputs_normalized = self.input_layernorm(x)
     cache, attn_output = self.attn(
         inputs_normalized,
         segment_pos,
         cache,
         attn_mask,
+        segment_ids=segment_ids,
     )
     attn_output += x
     residual = attn_output

@@ -1073,14 +1117,19 @@ def __call__(
       segment_pos: jaxtyping.Array,
       cache: LayerCache | None,
       attn_mask: jaxtyping.Array,
+      segment_ids: jaxtyping.Array | None = None,
   ) -> tuple[LayerCache | None, jaxtyping.Array]:
     if (
         self.config.remat_config == RematConfig.DECODER
         or self.config.remat_config == RematConfig.DECODER.value
     ):
-      return nnx.remat(self.block.__func__)(self, x, segment_pos, cache, attn_mask)
+      return nnx.remat(self.block.__func__)(
+          self, x, segment_pos, cache, attn_mask, segment_ids
+      )
     else:
-      return self.block(x, segment_pos, cache, attn_mask)
+      return self.block(
+          x, segment_pos, cache, attn_mask, segment_ids=segment_ids
+      )


 class Qwen3(BackendMappingMixin, nnx.Module):

@@ -1146,6 +1195,7 @@ def __call__(
       cache: Cache | None,  # (sequence length L')
       attention_mask: jaxtyping.Array,  # [B, L, L']
       output_hidden_states: bool = False,
+      segment_ids: jaxtyping.Array | None = None,  # [B, L]
   ) -> tuple[jaxtyping.Array, Cache | None]:
     """Qwen3 model.
 
@@ -1155,6 +1205,11 @@
       cache: Attention KV cache or None.
       attention_mask: transformer input mask.
       output_hidden_states: whether to output the hidden states.
+      segment_ids: optional per-position segment identifiers, [B, L]. Used by
+        flash attention to suppress cross-segment attention (e.g. real-token
+        to pad-token, or sequence-packing across document boundaries). Pass a
+        1/0 mask to skip pad positions; pass increasing integer ids per packed
+        document for sequence packing.
 
     Returns:
       predicted_logits, new_cache
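
A sketch of the two usages this docstring describes (shapes and ids illustrative; `model` stands for a constructed Qwen3 instance):

import jax.numpy as jnp

# 1) Padding mask: 0 = pad, 1 = real token, shape [B, L].
pad_segment_ids = jnp.array([[0, 0, 1, 1, 1, 1]])

# 2) Sequence packing: one increasing id per packed document, so attention
#    cannot cross document boundaries (doc A = 1, doc B = 2).
packed_segment_ids = jnp.array([[1, 1, 1, 2, 2, 2]])

# Either array flows straight through to the attention layers:
# logits, cache = model(
#     tokens, positions, cache, attention_mask,
#     segment_ids=pad_segment_ids,
# )
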
@@ -1173,6 +1228,7 @@
           positions,
           layer_cache,
           attention_mask,
+          segment_ids=segment_ids,
       )
       if cache is not None:
         new_cache[layer_name] = layer_cache  # pytype: disable=container-type-mismatch