Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
31 commits
Select commit Hold shift + click to select a range
af584b4
Update gemma3_causal_lm_preprocessor.py
pctablet505 Apr 17, 2025
dc4ae8c
Update gemma3_causal_lm_preprocessor.py
pctablet505 Apr 17, 2025
07c5c77
Update gemma3_causal_lm_preprocessor_test.py
pctablet505 Apr 17, 2025
3fdc7fd
Update reversible_embedding.py
pctablet505 Jun 10, 2025
fa57e33
Merge branch 'master' of https://github.com/pctablet505/keras-hub
pctablet505 Jun 10, 2025
8da3303
upadated Gemma3InterleaveEmbeddings
pctablet505 Jun 19, 2025
adac2c6
Update gemma3_interleave_embeddings.py
pctablet505 Jun 19, 2025
bd27ec0
Revert "Update reversible_embedding.py"
pctablet505 Jun 19, 2025
f5163e8
Merge branch 'keras-team:master' into master
pctablet505 Jun 19, 2025
1904136
Update gemma3_interleave_embeddings.py
pctablet505 Jun 19, 2025
552fecb
Merge branch 'keras-team:master' into master
pctablet505 Jul 7, 2025
3aa11e9
Merge branch 'master' of https://github.com/pctablet505/keras-hub
pctablet505 Nov 17, 2025
63d529a
Merge branch 'keras-team:master' into master
pctablet505 Nov 27, 2025
fcada92
Merge branch 'keras-team:master' into master
pctablet505 Dec 8, 2025
2cfe17b
Merge branch 'keras-team:master' into master
pctablet505 Dec 19, 2025
f3f85cb
Merge branch 'keras-team:master' into master
pctablet505 Dec 23, 2025
ddf14d5
Merge branch 'keras-team:master' into master
pctablet505 Dec 23, 2025
2e176e7
Merge branch 'keras-team:master' into master
pctablet505 Jan 6, 2026
a69c99c
Ensure int32 type for indices in NMS layer
pctablet505 Jan 6, 2026
2fd457c
Merge branch 'master' of https://github.com/pctablet505/keras-hub
pctablet505 Jan 6, 2026
d39d485
Update mask assertion in embedding layer test
pctablet505 Jan 6, 2026
527c427
Revert "Update mask assertion in embedding layer test"
pctablet505 Jan 7, 2026
3eaa5f4
Merge branch 'keras-team:master' into master
pctablet505 Jan 8, 2026
9b03ed9
Merge branch 'keras-team:master' into master
pctablet505 Jan 23, 2026
2e84113
Merge branch 'keras-team:master' into master
pctablet505 Jan 27, 2026
ae39725
Merge branch 'keras-team:master' into master
pctablet505 Jan 29, 2026
dbdf2d2
Merge branch 'keras-team:master' into master
pctablet505 Feb 3, 2026
bff2ac9
Merge branch 'keras-team:master' into master
pctablet505 Feb 9, 2026
a63d237
Performance optimizations for keras-hub generate() pipeline
pctablet505 Feb 9, 2026
3197607
Fix critical attention mask inversion for PyTorch SDPA
pctablet505 Feb 10, 2026
1dfe107
Add fast cached decoding path; unify generate_step
pctablet505 Feb 12, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
124 changes: 124 additions & 0 deletions keras_hub/src/layers/modeling/cached_multi_head_attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,21 @@

from keras_hub.src.api_export import keras_hub_export

# Lazily-resolved availability flag: None until first probed, then a bool.
_TORCH_SDPA_AVAILABLE = None


def _check_torch_sdpa():
    """Return whether torch provides `scaled_dot_product_attention`.

    The probe (importing `torch.nn.functional` and checking the
    attribute) runs at most once; the outcome is memoized in the
    module-level `_TORCH_SDPA_AVAILABLE` flag so repeated calls on the
    per-token decoding path stay cheap. Returns `False` when torch is
    not installed.
    """
    global _TORCH_SDPA_AVAILABLE
    if _TORCH_SDPA_AVAILABLE is not None:
        return _TORCH_SDPA_AVAILABLE
    try:
        import torch.nn.functional as F
    except ImportError:
        _TORCH_SDPA_AVAILABLE = False
    else:
        _TORCH_SDPA_AVAILABLE = hasattr(F, "scaled_dot_product_attention")
    return _TORCH_SDPA_AVAILABLE


@keras_hub_export("keras_hub.layers.CachedMultiHeadAttention")
class CachedMultiHeadAttention(keras.layers.MultiHeadAttention):
Expand Down Expand Up @@ -121,3 +136,112 @@ def call(
if cache is not None:
return attention_output, cache
return attention_output

def call_cached(
    self,
    query,
    attention_mask=None,
    cache=None,
    cache_update_index=None,
):
    """Fast path for cached autoregressive decoding.

    Invokes `.call()` directly on every sublayer, skipping the
    `Layer.__call__` bookkeeping (build checks, dtype autocasting,
    mask metadata, training-mode resolution). That is safe here
    because all sublayers are already built, the same dtype flows
    through, and this path only ever runs at inference time. It
    removes roughly five `Layer.__call__` invocations per attention
    layer (query/key/value/output dense plus the layer itself).

    Args:
        query: query tensor for the current decoding step.
        attention_mask: optional mask forwarded to attention.
        cache: stacked key/value cache with keys at index 0 and
            values at index 1 along axis 1.
        cache_update_index: position at which to write this step's
            key/value projections; `None` reads the cache without
            updating it.

    Returns:
        An `(attention_output, cache)` tuple.
    """
    # Bypass Layer.__call__ and project the query directly.
    projected_query = self._query_dense.call(query)

    keys = cache[:, 0, ...]
    values = cache[:, 1, ...]
    if cache_update_index is not None:
        # Write this step's projected key/value into the cache at the
        # given sequence position, then re-stack the updated cache.
        offset = [0, cache_update_index, 0, 0]
        keys = ops.slice_update(keys, offset, self._key_dense.call(query))
        values = ops.slice_update(
            values, offset, self._value_dense.call(query)
        )
        cache = ops.stack((keys, values), axis=1)

    outputs, _ = self._compute_attention(
        query=projected_query,
        key=keys,
        value=values,
        attention_mask=attention_mask,
        training=False,
    )

    return self._output_dense.call(outputs), cache

def _compute_attention(
    self,
    query,
    key,
    value,
    attention_mask=None,
    training=None,
    return_attention_scores=False,
):
    """Compute attention, dispatching to torch SDPA when eligible.

    The SDPA path is taken only when every condition holds: torch
    backend, no attention scores requested, inference mode, a rank-4
    query, SDPA present in the installed torch, and the
    `_use_sdpa_override` flag set (by the decoder's cached fast path
    for self-attention). In every other case this defers to the
    parent implementation.
    """
    sdpa_eligible = (
        keras.config.backend() == "torch"
        and not return_attention_scores
        and (training is None or training is False)
        and len(query.shape) == 4
        and _check_torch_sdpa()
        and getattr(self, "_use_sdpa_override", False)
    )
    if not sdpa_eligible:
        return super()._compute_attention(
            query=query,
            key=key,
            value=value,
            attention_mask=attention_mask,
            training=training,
            return_attention_scores=return_attention_scores,
        )

    import torch
    import torch.nn.functional as F

    # Keras projections are laid out (B, S, H, D); SDPA wants
    # (B, H, S, D), so swap the sequence and head axes.
    q = query.transpose(1, 2)
    k = key.transpose(1, 2)
    v = value.transpose(1, 2)

    mask = attention_mask
    if mask is not None:
        # Keras and torch SDPA share the bool-mask convention
        # (True = attend, False = mask out), so no inversion is
        # needed; only dtype and rank are normalized here.
        mask = mask.to(dtype=torch.bool)
        while mask.dim() < 4:
            mask = mask.unsqueeze(1)

    out = F.scaled_dot_product_attention(
        q,
        k,
        v,
        attn_mask=mask,
        dropout_p=0.0,
    )

    # Back to the Keras layout (B, S, H, D); no scores are produced
    # on this path.
    return out.transpose(1, 2), None
26 changes: 21 additions & 5 deletions keras_hub/src/layers/modeling/position_embedding.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,11 +104,27 @@ def call(self, inputs, start_index=0, positions=None):
# than the sequence_length of the layer.
position_embeddings = ops.convert_to_tensor(self.position_embeddings)
if positions is None:
position_embeddings = ops.slice(
position_embeddings,
(start_index, 0),
(sequence_length, feature_length),
)
# Fast path for single-token cached decoding on torch: use direct
# indexing instead of ops.slice to avoid overhead.
# Only applies when both sequence_length and start_index are
# static Python ints (not traced values like in JAX JIT).
if (
isinstance(sequence_length, int)
and sequence_length == 1
and isinstance(start_index, int)
):
position_embeddings = position_embeddings[
start_index : start_index + 1, :
]
position_embeddings = ops.expand_dims(
position_embeddings, axis=0
)
else:
position_embeddings = ops.slice(
position_embeddings,
(start_index, 0),
(sequence_length, feature_length),
)
else:
# Take care of unbatched `positions`.
if len(ops.shape(positions)) == 1:
Expand Down
72 changes: 72 additions & 0 deletions keras_hub/src/layers/modeling/transformer_decoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -238,6 +238,78 @@ def build(
# Create layers based on input shape.
self.built = True

def call_cached(
    self,
    decoder_sequence,
    self_attention_cache,
    self_attention_cache_update_index,
    self_attention_mask=None,
):
    """Ultra-fast path for cached autoregressive decoding (decoder-only).

    Bypasses ALL `Layer.__call__` overhead by calling `.call()` directly
    on every sublayer. This is safe during cached inference because:
    - All layers are already built
    - Input dtypes are already correct (same dtype flows through)
    - No masking metadata needed
    - No training-mode checks needed (always inference)
    - No autocast scope changes needed

    This saves ~10 `Layer.__call__` invocations per transformer layer:
    - 1 for self_attention_layer_norm
    - 1 for self_attention_layer (which internally saves 4 more for
      query/key/value/output dense via `call_cached`)
    - 1 for feedforward_layer_norm
    - 1 for feedforward_intermediate_dense
    - 1 for feedforward_output_dense

    Args:
        decoder_sequence: input tensor for the current decoding step.
        self_attention_cache: stacked key/value cache for
            self-attention.
        self_attention_cache_update_index: position at which to write
            the new cache entries.
        self_attention_mask: optional precomputed attention mask; when
            `None`, a causal mask is computed here as a fallback.

    Returns:
        A `(hidden_states, self_attention_cache)` tuple.
    """
    x = decoder_sequence

    # Self-attention block (pre-norm when `normalize_first`, e.g. GPT-2).
    residual = x
    if self.normalize_first:
        x = self._self_attention_layer_norm.call(x)

    # Compute the causal mask only if the caller did not supply one.
    if self_attention_mask is None:
        self_attention_mask = self._compute_self_attention_mask(
            decoder_sequence=decoder_sequence,
            decoder_padding_mask=None,
            decoder_attention_mask=None,
            use_causal_mask=True,
            self_attention_cache=self_attention_cache,
            self_attention_cache_update_index=(
                self_attention_cache_update_index
            ),
        )

    # Use `call_cached()` on the attention layer to bypass
    # `Layer.__call__` overhead on all dense sublayers. The SDPA
    # override is scoped with try/finally so that an exception inside
    # attention cannot leave the flag set and silently change the
    # behavior of later, non-cached calls on this layer.
    self._self_attention_layer._use_sdpa_override = True
    try:
        x, self_attention_cache = self._self_attention_layer.call_cached(
            query=x,
            attention_mask=self_attention_mask,
            cache=self_attention_cache,
            cache_update_index=self_attention_cache_update_index,
        )
    finally:
        self._self_attention_layer._use_sdpa_override = False

    x = x + residual
    if not self.normalize_first:
        x = self._self_attention_layer_norm.call(x)

    # Feedforward block — bypass `Layer.__call__` on all dense layers.
    residual = x
    if self.normalize_first:
        x = self._feedforward_layer_norm.call(x)
    x = self._feedforward_intermediate_dense.call(x)
    x = self._feedforward_output_dense.call(x)
    x = x + residual
    if not self.normalize_first:
        x = self._feedforward_layer_norm.call(x)

    return (x, self_attention_cache)

def call(
self,
decoder_sequence,
Expand Down
12 changes: 12 additions & 0 deletions keras_hub/src/layers/modeling/transformer_layer_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,18 @@ def compute_causal_mask(batch_size, input_length, output_length, cache_index=0):
`(batch_size, output_length, input_length)` that can be passed to a
attention layer.
"""
# Fast path for autoregressive generation: when output_length=1 (single
# token), the causal mask is simply True for all positions up to
# cache_index and False after. This skips the 2-D (output x input)
# index-grid construction and comparison of the general path below,
# which adds up when the mask is rebuilt per layer per decoded token.
if isinstance(output_length, int) and output_length == 1:
j = ops.arange(input_length, dtype="float32")
mask = ops.expand_dims(
ops.expand_dims(j <= ops.cast(cache_index, "float32"), axis=0),
axis=0,
)
return ops.broadcast_to(mask, (batch_size, 1, input_length))

i = ops.arange(output_length, dtype="float32")
i = i + ops.cast(cache_index, "float32")
i = ops.expand_dims(i, axis=1)
Expand Down
76 changes: 0 additions & 76 deletions keras_hub/src/models/bloom/bloom_causal_lm.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
BloomCausalLMPreprocessor,
)
from keras_hub.src.models.causal_lm import CausalLM
from keras_hub.src.utils.tensor_utils import any_equal


@keras_hub_export("keras_hub.models.BloomCausalLM")
Expand Down Expand Up @@ -206,78 +205,3 @@ def _build_cache(self, token_ids):
# Seed the cache.
_, hidden_states, cache = self.call_with_cache(token_ids, cache, 0)
return hidden_states, cache

def generate_step(
    self,
    inputs,
    stop_token_ids=None,
):
    """A compilable generation function for a single batch of inputs.

    This function represents the inner, XLA-compilable, generation function
    for a single batch of inputs. Inputs should have the same structure as
    model inputs, a dictionary with keys `"token_ids"` and `"padding_mask"`.

    Args:
        inputs: A dictionary with two keys `"token_ids"` and
            `"padding_mask"` and batched tensor values.
        stop_token_ids: Tuple of id's of end token's to stop on. If all
            sequences have produced a new stop token, generation
            will stop.

    Returns:
        A dictionary with the generated `"token_ids"` and an updated
        `"padding_mask"` covering the generated tokens.
    """
    token_ids, padding_mask = inputs["token_ids"], inputs["padding_mask"]
    # Create and seed cache with a single forward pass.
    hidden_states, cache = self._build_cache(token_ids)
    # Compute the lengths of all user inputted tokens ids.
    row_lengths = ops.sum(ops.cast(padding_mask, "int32"), axis=-1)
    # Start at the first index that has no user inputted id.
    index = ops.min(row_lengths)

    def next(prompt, cache, index):
        # Per-step callback handed to the sampler: consumes the token
        # at `index - 1` and produces logits for position `index`.
        # The cache index is the index of our previous token.
        cache_update_index = index - 1
        batch_size = ops.shape(prompt)[0]
        # Slice out just the single previous token for each row.
        prompt = ops.slice(prompt, [0, cache_update_index], [batch_size, 1])
        logits, hidden_states, cache = self.call_with_cache(
            prompt,
            cache,
            cache_update_index,
        )
        # Squeeze the length-1 sequence axis; the sampler expects
        # per-token logits and hidden states.
        return (
            ops.squeeze(logits, axis=1),
            ops.squeeze(hidden_states, axis=1),
            cache,
        )

    # Delegate the decoding loop to the configured sampler (greedy,
    # top-k, etc.), which repeatedly invokes `next` above.
    token_ids = self.sampler(
        next=next,
        prompt=token_ids,
        cache=cache,
        index=index,
        mask=padding_mask,
        stop_token_ids=stop_token_ids,
        hidden_states=hidden_states,
        model=self,
    )

    # Compute an output padding mask with the token ids we updated.
    if stop_token_ids is not None:
        # Build a mask of stop token locations not in the original
        # prompt (not in locations where `padding_mask` is True).
        end_locations = any_equal(
            token_ids, stop_token_ids, ops.logical_not(padding_mask)
        )

        end_locations = ops.cast(end_locations, "int32")
        # Use cumsum to get ones in all locations after end_locations.
        cumsum = ops.cast(ops.cumsum(end_locations, axis=-1), "int32")
        # `overflow` is 1 strictly after the first stop token in each
        # row, so it marks tokens that should be treated as padding.
        overflow = cumsum - end_locations
        # Our padding mask is the inverse of these overflow locations.
        padding_mask = ops.logical_not(ops.cast(overflow, "bool"))
    else:
        # Without early stopping, all locations will have been updated.
        padding_mask = ops.ones_like(token_ids, dtype="bool")
    return {
        "token_ids": token_ids,
        "padding_mask": padding_mask,
    }
Loading
Loading