diff --git a/keras_hub/src/layers/modeling/cached_multi_head_attention.py b/keras_hub/src/layers/modeling/cached_multi_head_attention.py index 0441e71845..f84706978f 100644 --- a/keras_hub/src/layers/modeling/cached_multi_head_attention.py +++ b/keras_hub/src/layers/modeling/cached_multi_head_attention.py @@ -3,6 +3,21 @@ from keras_hub.src.api_export import keras_hub_export +# Check if SDPA is available for the PyTorch backend. +_TORCH_SDPA_AVAILABLE = None + + +def _check_torch_sdpa(): + global _TORCH_SDPA_AVAILABLE + if _TORCH_SDPA_AVAILABLE is None: + try: + import torch.nn.functional as F + + _TORCH_SDPA_AVAILABLE = hasattr(F, "scaled_dot_product_attention") + except ImportError: + _TORCH_SDPA_AVAILABLE = False + return _TORCH_SDPA_AVAILABLE + @keras_hub_export("keras_hub.layers.CachedMultiHeadAttention") class CachedMultiHeadAttention(keras.layers.MultiHeadAttention): @@ -121,3 +136,112 @@ def call( if cache is not None: return attention_output, cache return attention_output + + def call_cached( + self, + query, + attention_mask=None, + cache=None, + cache_update_index=None, + ): + """Ultra-fast path for cached autoregressive decoding. + + Bypasses Layer.__call__ overhead on all sublayers by calling + .call() directly. This is safe because: + - All sublayers are already built + - Input dtypes are already correct (same dtype flows through) + - No masking metadata needed + - No training-mode checks needed (always inference) + - No autocast scope changes needed + + This saves ~5 Layer.__call__ invocations per attention layer + (query_dense, key_dense, value_dense, output_dense, plus the + attention layer itself). + """ + # Directly call .call() on dense layers, bypassing Layer.__call__ + query_proj = self._query_dense.call(query) + + key_cache = cache[:, 0, ...] + value_cache = cache[:, 1, ...] + if cache_update_index is None: + key = key_cache + value = value_cache + else: + key_update = self._key_dense.call(query) + value_update = self._value_dense.call(query) + start = [0, cache_update_index, 0, 0] + key = ops.slice_update(key_cache, start, key_update) + value = ops.slice_update(value_cache, start, value_update) + cache = ops.stack((key, value), axis=1) + + attention_output, _ = self._compute_attention( + query=query_proj, + key=key, + value=value, + attention_mask=attention_mask, + training=False, + ) + + attention_output = self._output_dense.call(attention_output) + return attention_output, cache + + def _compute_attention( + self, + query, + key, + value, + attention_mask=None, + training=None, + return_attention_scores=False, + ): + """Override to use SDPA during cached inference on torch. + + Only activated when `_use_sdpa_override` is set True (by the + TransformerDecoder.call_cached fast path for self-attention). + Falls back to the parent implementation otherwise. + """ + if ( + keras.config.backend() == "torch" + and not return_attention_scores + and (training is None or training is False) + and len(query.shape) == 4 + and _check_torch_sdpa() + and getattr(self, "_use_sdpa_override", False) + ): + import torch + import torch.nn.functional as F + + # Transpose from (B, S, H, D) to (B, H, S, D) for SDPA. + query = query.transpose(1, 2) + key = key.transpose(1, 2) + value = value.transpose(1, 2) + + # Convert attention mask to SDPA format. + # Both Keras and PyTorch SDPA use the same convention for bool + # masks: True = attend, False = mask out. + # No inversion needed - pass through directly. + if attention_mask is not None: + attention_mask = attention_mask.to(dtype=torch.bool) + while attention_mask.dim() < 4: + attention_mask = attention_mask.unsqueeze(1) + + attention_output = F.scaled_dot_product_attention( + query, + key, + value, + attn_mask=attention_mask, + dropout_p=0.0, + ) + + # Transpose back from (B, H, T, D) to (B, T, H, D). + attention_output = attention_output.transpose(1, 2) + return attention_output, None + + return super()._compute_attention( + query=query, + key=key, + value=value, + attention_mask=attention_mask, + training=training, + return_attention_scores=return_attention_scores, + ) diff --git a/keras_hub/src/layers/modeling/position_embedding.py b/keras_hub/src/layers/modeling/position_embedding.py index 6e0f57906c..a5699ccd1a 100644 --- a/keras_hub/src/layers/modeling/position_embedding.py +++ b/keras_hub/src/layers/modeling/position_embedding.py @@ -104,11 +104,27 @@ def call(self, inputs, start_index=0, positions=None): # than the sequence_length of the layer. position_embeddings = ops.convert_to_tensor(self.position_embeddings) if positions is None: - position_embeddings = ops.slice( - position_embeddings, - (start_index, 0), - (sequence_length, feature_length), - ) + # Fast path for single-token cached decoding on torch: use direct + # indexing instead of ops.slice to avoid overhead. + # Only applies when both sequence_length and start_index are + # static Python ints (not traced values like in JAX JIT). + if ( + isinstance(sequence_length, int) + and sequence_length == 1 + and isinstance(start_index, int) + ): + position_embeddings = position_embeddings[ + start_index : start_index + 1, : + ] + position_embeddings = ops.expand_dims( + position_embeddings, axis=0 + ) + else: + position_embeddings = ops.slice( + position_embeddings, + (start_index, 0), + (sequence_length, feature_length), + ) else: # Take care of unbatched `positions`. if len(ops.shape(positions)) == 1: diff --git a/keras_hub/src/layers/modeling/transformer_decoder.py b/keras_hub/src/layers/modeling/transformer_decoder.py index 3ca19f5c73..c4a628f8e1 100644 --- a/keras_hub/src/layers/modeling/transformer_decoder.py +++ b/keras_hub/src/layers/modeling/transformer_decoder.py @@ -238,6 +238,78 @@ def build( # Create layers based on input shape. self.built = True + def call_cached( + self, + decoder_sequence, + self_attention_cache, + self_attention_cache_update_index, + self_attention_mask=None, + ): + """Ultra-fast path for cached autoregressive decoding (decoder-only). + + Bypasses ALL Layer.__call__ overhead by calling .call() directly + on every sublayer. This is safe during cached inference because: + - All layers are already built + - Input dtypes are already correct (same dtype flows through) + - No masking metadata needed + - No training-mode checks needed (always inference) + - No autocast scope changes needed + + This saves ~10 Layer.__call__ invocations per transformer layer: + - 1 for self_attention_layer_norm + - 1 for self_attention_layer (which internally saves 4 more for + query/key/value/output dense via call_cached) + - 1 for feedforward_layer_norm + - 1 for feedforward_intermediate_dense + - 1 for feedforward_output_dense + """ + x = decoder_sequence + + # Self attention block (normalize_first path for GPT-2). + residual = x + if self.normalize_first: + x = self._self_attention_layer_norm.call(x) + + # Compute mask only if not provided (fallback). + if self_attention_mask is None: + self_attention_mask = self._compute_self_attention_mask( + decoder_sequence=decoder_sequence, + decoder_padding_mask=None, + decoder_attention_mask=None, + use_causal_mask=True, + self_attention_cache=self_attention_cache, + self_attention_cache_update_index=( + self_attention_cache_update_index + ), + ) + + # Use call_cached() on the attention layer to bypass Layer.__call__ + # overhead on all dense sublayers (query, key, value, output). + self._self_attention_layer._use_sdpa_override = True + x, self_attention_cache = self._self_attention_layer.call_cached( + query=x, + attention_mask=self_attention_mask, + cache=self_attention_cache, + cache_update_index=self_attention_cache_update_index, + ) + self._self_attention_layer._use_sdpa_override = False + + x = x + residual + if not self.normalize_first: + x = self._self_attention_layer_norm.call(x) + + # Feedforward block - bypass Layer.__call__ on all dense layers. + residual = x + if self.normalize_first: + x = self._feedforward_layer_norm.call(x) + x = self._feedforward_intermediate_dense.call(x) + x = self._feedforward_output_dense.call(x) + x = x + residual + if not self.normalize_first: + x = self._feedforward_layer_norm.call(x) + + return (x, self_attention_cache) + def call( self, decoder_sequence, diff --git a/keras_hub/src/layers/modeling/transformer_layer_utils.py b/keras_hub/src/layers/modeling/transformer_layer_utils.py index 1331422a94..6ca77cfbdf 100644 --- a/keras_hub/src/layers/modeling/transformer_layer_utils.py +++ b/keras_hub/src/layers/modeling/transformer_layer_utils.py @@ -41,6 +41,18 @@ def compute_causal_mask(batch_size, input_length, output_length, cache_index=0): `(batch_size, output_length, input_length)` that can be passed to a attention layer. """ + # Fast path for autoregressive generation: when output_length=1 (single + # token), the causal mask is simply True for all positions up to + # cache_index and False after. This avoids ops.arange/expand_dims/ + # broadcast_to overhead that is significant when called 12×46 times. + if isinstance(output_length, int) and output_length == 1: + j = ops.arange(input_length, dtype="float32") + mask = ops.expand_dims( + ops.expand_dims(j <= ops.cast(cache_index, "float32"), axis=0), + axis=0, + ) + return ops.broadcast_to(mask, (batch_size, 1, input_length)) + i = ops.arange(output_length, dtype="float32") i = i + ops.cast(cache_index, "float32") i = ops.expand_dims(i, axis=1) diff --git a/keras_hub/src/models/bloom/bloom_causal_lm.py b/keras_hub/src/models/bloom/bloom_causal_lm.py index d167bd5cc5..272b6a0d0e 100644 --- a/keras_hub/src/models/bloom/bloom_causal_lm.py +++ b/keras_hub/src/models/bloom/bloom_causal_lm.py @@ -6,7 +6,6 @@ BloomCausalLMPreprocessor, ) from keras_hub.src.models.causal_lm import CausalLM -from keras_hub.src.utils.tensor_utils import any_equal @keras_hub_export("keras_hub.models.BloomCausalLM") @@ -206,78 +205,3 @@ def _build_cache(self, token_ids): # Seed the cache. _, hidden_states, cache = self.call_with_cache(token_ids, cache, 0) return hidden_states, cache - - def generate_step( - self, - inputs, - stop_token_ids=None, - ): - """A compilable generation function for a single batch of inputs. - - This function represents the inner, XLA-compilable, generation function - for a single batch of inputs. Inputs should have the same structure as - model inputs, a dictionary with keys `"token_ids"` and `"padding_mask"`. - - Args: - inputs: A dictionary with two keys `"token_ids"` and - `"padding_mask"` and batched tensor values. - stop_token_ids: Tuple of id's of end token's to stop on. If all - sequences have produced a new stop token, generation - will stop. - """ - token_ids, padding_mask = inputs["token_ids"], inputs["padding_mask"] - # Create and seed cache with a single forward pass. - hidden_states, cache = self._build_cache(token_ids) - # Compute the lengths of all user inputted tokens ids. - row_lengths = ops.sum(ops.cast(padding_mask, "int32"), axis=-1) - # Start at the first index that has no user inputted id. - index = ops.min(row_lengths) - - def next(prompt, cache, index): - # The cache index is the index of our previous token. - cache_update_index = index - 1 - batch_size = ops.shape(prompt)[0] - prompt = ops.slice(prompt, [0, cache_update_index], [batch_size, 1]) - logits, hidden_states, cache = self.call_with_cache( - prompt, - cache, - cache_update_index, - ) - return ( - ops.squeeze(logits, axis=1), - ops.squeeze(hidden_states, axis=1), - cache, - ) - - token_ids = self.sampler( - next=next, - prompt=token_ids, - cache=cache, - index=index, - mask=padding_mask, - stop_token_ids=stop_token_ids, - hidden_states=hidden_states, - model=self, - ) - - # Compute an output padding mask with the token ids we updated. - if stop_token_ids is not None: - # Build a mask of stop token locations not in the original - # prompt (not in locations where `padding_mask` is True). - end_locations = any_equal( - token_ids, stop_token_ids, ops.logical_not(padding_mask) - ) - - end_locations = ops.cast(end_locations, "int32") - # Use cumsum to get ones in all locations after end_locations. - cumsum = ops.cast(ops.cumsum(end_locations, axis=-1), "int32") - overflow = cumsum - end_locations - # Our padding mask is the inverse of these overflow locations. - padding_mask = ops.logical_not(ops.cast(overflow, "bool")) - else: - # Without early stopping, all locations will have been updated. - padding_mask = ops.ones_like(token_ids, dtype="bool") - return { - "token_ids": token_ids, - "padding_mask": padding_mask, - } diff --git a/keras_hub/src/models/causal_lm.py b/keras_hub/src/models/causal_lm.py index 32e00fb858..a588b6aa68 100644 --- a/keras_hub/src/models/causal_lm.py +++ b/keras_hub/src/models/causal_lm.py @@ -8,6 +8,7 @@ from keras_hub.src.api_export import keras_hub_export from keras_hub.src.models.task import Task from keras_hub.src.samplers.serialization import get as get_sampler +from keras_hub.src.utils.tensor_utils import any_equal try: import tensorflow as tf @@ -122,8 +123,133 @@ def compile( # Clear the compiled generate function. self.generate_function = None - def generate_step(self): - """Run generation on a single batch of input.""" + def generate_step( + self, + inputs, + stop_token_ids=None, + ): + """A compilable generation function for a single batch of inputs. + + This default implementation works for all CausalLM models that + implement `call_with_cache()` and `_build_cache()`. It includes + backend-specific optimizations (e.g., direct tensor indexing on + torch) that benefit all models automatically. + + Subclasses only need to override this if they require custom + generation logic beyond what `call_with_cache` provides. + + Args: + inputs: A dictionary with two keys `"token_ids"` and + `"padding_mask"` and batched tensor values. + stop_token_ids: List of id's of end token's to stop on. If all + sequences have produced a new stop token, generation + will stop. + """ + token_ids, padding_mask = inputs["token_ids"], inputs["padding_mask"] + # Create and seed cache with a single forward pass. + hidden_states, cache = self._build_cache(token_ids) + # Compute the lengths of all user inputted tokens ids. + row_lengths = ops.sum(ops.cast(padding_mask, "int32"), axis=-1) + # Start at the first index that has no user inputted id. + index = ops.min(row_lengths) + + # Use direct tensor indexing on torch backend to avoid + # ops.slice / convert_to_tensor overhead. JAX/TF need ops.slice + # for static shapes in JIT compilation. + _use_direct_indexing = keras.config.backend() == "torch" + + if _use_direct_indexing: + import torch + + def next(prompt, cache, index): + # The cache index is the index of our previous token. + cache_update_index = index - 1 + # Extract single token for cached forward pass. + if _use_direct_indexing: + prompt = prompt[:, cache_update_index : cache_update_index + 1] + # Ensure cache_update_index is a tensor for + # call_with_cache, as some models pass it through to + # sublayers which require tensor-typed kwargs. + cache_update_index = torch.tensor( + cache_update_index, dtype=torch.int32 + ) + else: + batch_size = ops.shape(prompt)[0] + prompt = ops.slice( + prompt, [0, cache_update_index], [batch_size, 1] + ) + logits, hidden_states, cache = self.call_with_cache( + prompt, + cache, + cache_update_index, + ) + return ( + logits[:, 0, :], + hidden_states[:, 0, :], + cache, + ) + + token_ids = self.sampler( + next=next, + prompt=token_ids, + cache=cache, + index=index, + mask=padding_mask, + stop_token_ids=stop_token_ids, + hidden_states=hidden_states, + model=self, + ) + + # Compute an output padding mask with the token ids we updated. + if stop_token_ids is not None: + # Build a mask of stop tokens locations not in the original + # prompt (not in locations where `padding_mask` is True). + end_locations = any_equal( + token_ids, stop_token_ids, ops.logical_not(padding_mask) + ) + end_locations = ops.cast(end_locations, "int32") + # Use cumsum to get ones in all locations after end_locations. + cumsum = ops.cast(ops.cumsum(end_locations, axis=-1), "int32") + overflow = cumsum - end_locations + # Our padding mask is the inverse of these overflow locations. + padding_mask = ops.logical_not(ops.cast(overflow, "bool")) + else: + # Without early stopping, all locations will have been updated. + padding_mask = ops.ones_like(token_ids, dtype="bool") + return { + "token_ids": token_ids, + "padding_mask": padding_mask, + } + + def call_with_cache( + self, + token_ids, + cache, + cache_update_index, + ): + """Forward pass with cache for autoregressive inference. + + Subclasses must override this method to define their specific + cached forward pass logic. + + Args: + token_ids: a dense int Tensor with shape + `(batch_size, max_length)`. + cache: a dense float Tensor, the cache of key and value. + cache_update_index: int, or int Tensor. The index of current + inputs in the whole sequence. + + Returns: + A (logits, hidden_states, cache) tuple. + """ + raise NotImplementedError + + def _build_cache(self, token_ids): + """Build an empty cache for use with `call_with_cache()`. + + Subclasses must override this method to define their specific + cache structure. + """ raise NotImplementedError def make_generate_function(self): @@ -150,7 +276,8 @@ def wrapped_generate_function( inputs, stop_token_ids=None, ): - with torch.no_grad(): + # Use torch.no_grad() and inference_mode for best performance + with torch.no_grad(), torch.inference_mode(): return self.generate_step(inputs, stop_token_ids) self.generate_function = wrapped_generate_function diff --git a/keras_hub/src/models/falcon/falcon_causal_lm.py b/keras_hub/src/models/falcon/falcon_causal_lm.py index 5626448282..6e8c426084 100644 --- a/keras_hub/src/models/falcon/falcon_causal_lm.py +++ b/keras_hub/src/models/falcon/falcon_causal_lm.py @@ -6,7 +6,6 @@ from keras_hub.src.models.falcon.falcon_causal_lm_preprocessor import ( FalconCausalLMPreprocessor, ) -from keras_hub.src.utils.tensor_utils import any_equal @keras_hub_export("keras_hub.models.FalconCausalLM") @@ -211,77 +210,3 @@ def _build_cache(self, token_ids): # Seed the cache. _, hidden_states, cache = self.call_with_cache(token_ids, cache, 0) return hidden_states, cache - - def generate_step( - self, - inputs, - stop_token_ids=None, - ): - """A compilable generation function for a single batch of inputs. - - This function represents the inner, XLA-compilable, generation function - for a single batch of inputs. Inputs should have the same structure as - model inputs, a dictionary with keys `"token_ids"` and `"padding_mask"`. - - Args: - inputs: A dictionary with two keys `"token_ids"` and - `"padding_mask"` and batched tensor values. - stop_token_ids: Tuple of id's of end token's to stop on. If all - sequences have produced a new stop token, generation - will stop. - """ - token_ids, padding_mask = inputs["token_ids"], inputs["padding_mask"] - # Create and seed cache with a single forward pass. - hidden_states, cache = self._build_cache(token_ids) - # Compute the lengths of all user inputted tokens ids. - row_lengths = ops.sum(ops.cast(padding_mask, "int32"), axis=-1) - # Start at the first index that has no user inputted id. - index = ops.min(row_lengths) - - def next(prompt, cache, index): - # The cache index is the index of our previous token. - cache_update_index = index - 1 - batch_size = ops.shape(prompt)[0] - prompt = ops.slice(prompt, [0, cache_update_index], [batch_size, 1]) - logits, hidden_states, cache = self.call_with_cache( - prompt, - cache, - cache_update_index, - ) - return ( - ops.squeeze(logits, axis=1), - ops.squeeze(hidden_states, axis=1), - cache, - ) - - token_ids = self.sampler( - next=next, - prompt=token_ids, - cache=cache, - index=index, - mask=padding_mask, - stop_token_ids=stop_token_ids, - hidden_states=hidden_states, - model=self, - ) - - # Compute an output padding mask with the token ids we updated. - if stop_token_ids is not None: - # Build a mask of stop token locations not in the original - # prompt (not in locations where `padding_mask` is True). - end_locations = any_equal( - token_ids, stop_token_ids, ops.logical_not(padding_mask) - ) - end_locations = ops.cast(end_locations, "int32") - # Use cumsum to get ones in all locations after end_locations. - cumsum = ops.cast(ops.cumsum(end_locations, axis=-1), "int32") - overflow = cumsum - end_locations - # Our padding mask is the inverse of these overflow locations. - padding_mask = ops.logical_not(ops.cast(overflow, "bool")) - else: - # Without early stopping, all locations will have been updated. - padding_mask = ops.ones_like(token_ids, dtype="bool") - return { - "token_ids": token_ids, - "padding_mask": padding_mask, - } diff --git a/keras_hub/src/models/gemma/gemma_causal_lm.py b/keras_hub/src/models/gemma/gemma_causal_lm.py index ba72240aec..b923611391 100644 --- a/keras_hub/src/models/gemma/gemma_causal_lm.py +++ b/keras_hub/src/models/gemma/gemma_causal_lm.py @@ -7,7 +7,6 @@ from keras_hub.src.models.gemma.gemma_causal_lm_preprocessor import ( GemmaCausalLMPreprocessor, ) -from keras_hub.src.utils.tensor_utils import any_equal @keras_hub_export("keras_hub.models.GemmaCausalLM") @@ -227,81 +226,6 @@ def _build_cache(self, token_ids): _, hidden_states, cache = self.call_with_cache(token_ids, cache, 0) return hidden_states, cache - def generate_step( - self, - inputs, - stop_token_ids=None, - ): - """A compilable generation function for a single batch of inputs. - - This function represents the inner, XLA-compilable, generation function - for a single batch of inputs. Inputs should have the same structure as - model inputs, a dictionary with keys `"token_ids"` and `"padding_mask"`. - - Args: - inputs: A dictionary with two keys `"token_ids"` and - `"padding_mask"` and batched tensor values. - stop_token_ids: Tuple of id's of end token's to stop on. If all - sequences have produced a new stop token, generation - will stop. - """ - token_ids, padding_mask = inputs["token_ids"], inputs["padding_mask"] - # Create and seed cache with a single forward pass. - hidden_states, cache = self._build_cache(token_ids) - # Compute the lengths of all user inputted tokens ids. - row_lengths = ops.sum(ops.cast(padding_mask, "int32"), axis=-1) - # Start at the first index that has no user inputted id. - index = ops.min(row_lengths) - - def next(prompt, cache, index): - # The cache index is the index of our previous token. - cache_update_index = index - 1 - batch_size = ops.shape(prompt)[0] - prompt = ops.slice(prompt, [0, cache_update_index], [batch_size, 1]) - logits, hidden_states, cache = self.call_with_cache( - prompt, - cache, - cache_update_index, - ) - return ( - ops.squeeze(logits, axis=1), - ops.squeeze(hidden_states, axis=1), - cache, - ) - - token_ids = self.sampler( - next=next, - prompt=token_ids, - cache=cache, - index=index, - mask=padding_mask, - stop_token_ids=stop_token_ids, - hidden_states=hidden_states, - model=self, - ) - - # Compute an output padding mask with the token ids we updated. - if stop_token_ids is not None: - # Build a mask of `stop_token_ids` locations not in the original - # prompt (not in locations where `padding_mask` is True). - end_locations = any_equal( - token_ids, stop_token_ids, ops.logical_not(padding_mask) - ) - - end_locations = ops.cast(end_locations, "int32") - # Use cumsum to get ones in all locations after end_locations. - cumsum = ops.cast(ops.cumsum(end_locations, axis=-1), "int32") - overflow = cumsum - end_locations - # Our padding mask is the inverse of these overflow locations. - padding_mask = ops.logical_not(ops.cast(overflow, "bool")) - else: - # Without early stopping, all locations will have been updated. - padding_mask = ops.ones_like(token_ids, dtype="bool") - return { - "token_ids": token_ids, - "padding_mask": padding_mask, - } - def score( self, token_ids, diff --git a/keras_hub/src/models/gpt2/gpt2_causal_lm.py b/keras_hub/src/models/gpt2/gpt2_causal_lm.py index a3a9c58b60..b90d75b3b6 100644 --- a/keras_hub/src/models/gpt2/gpt2_causal_lm.py +++ b/keras_hub/src/models/gpt2/gpt2_causal_lm.py @@ -7,7 +7,6 @@ from keras_hub.src.models.gpt2.gpt2_causal_lm_preprocessor import ( GPT2CausalLMPreprocessor, ) -from keras_hub.src.utils.tensor_utils import any_equal @keras_hub_export("keras_hub.models.GPT2CausalLM") @@ -187,17 +186,53 @@ def call_with_cache( ) x = self.backbone.embeddings_add((tokens, positions)) x = self.backbone.embeddings_dropout(x) + + # Precompute the causal mask once and share across all layers. + # This saves (num_layers - 1) calls to compute_causal_mask per step. + # For GPT-2 base (12 layers), this is 11 fewer mask computations. + batch_size = ops.shape(token_ids)[0] + seq_len = ops.shape(token_ids)[1] + cache_len = ops.shape(cache)[3] # max_length from cache shape + from keras_hub.src.layers.modeling.transformer_layer_utils import ( + compute_causal_mask, + ) + + causal_mask = compute_causal_mask( + batch_size, cache_len, seq_len, cache_update_index + ) + # Each decoder layer has a cache; we update them separately. - caches = [] - for i, transformer_layer in enumerate(self.backbone.transformer_layers): - current_cache = cache[:, i, ...] - x, next_cache = transformer_layer( - x, - self_attention_cache=current_cache, - self_attention_cache_update_index=cache_update_index, - ) - caches.append(next_cache) - cache = ops.stack(caches, axis=1) + if keras.config.backend() == "torch": + # On torch, slice_update is in-place through views, so the + # original cache tensor is already updated. We don't need to + # collect and re-stack layer caches, saving 12 ops.stack calls + # per generation step. + # Use call_cached() fast path to skip validation overhead and + # pass the precomputed mask to avoid redundant mask creation. + for i, transformer_layer in enumerate( + self.backbone.transformer_layers + ): + current_cache = cache[:, i, ...] + x, _ = transformer_layer.call_cached( + x, + self_attention_cache=current_cache, + self_attention_cache_update_index=cache_update_index, + self_attention_mask=causal_mask, + ) + else: + caches = [] + for i, transformer_layer in enumerate( + self.backbone.transformer_layers + ): + current_cache = cache[:, i, ...] + x, next_cache = transformer_layer.call_cached( + x, + self_attention_cache=current_cache, + self_attention_cache_update_index=cache_update_index, + self_attention_mask=causal_mask, + ) + caches.append(next_cache) + cache = ops.stack(caches, axis=1) hidden_states = x = self.backbone.layer_norm(x) logits = self.backbone.token_embedding(x, reverse=True) return logits, hidden_states, cache @@ -215,81 +250,6 @@ def _build_cache(self, token_ids): _, hidden_states, cache = self.call_with_cache(token_ids, cache, 0) return hidden_states, cache - def generate_step( - self, - inputs, - stop_token_ids=None, - ): - """A compilable generation function for a single batch of inputs. - - This function represents the inner, XLA-compilable, generation function - for a single batch of inputs. Inputs should have the same structure as - model inputs, a dictionary with keys `"token_ids"` and `"padding_mask"`. - - Args: - inputs: A dictionary with two keys `"token_ids"` and - `"padding_mask"` and batched tensor values. - stop_token_ids: List of id's of end token's to stop on. If all - sequences have produced a new stop token, generation - will stop. - """ - token_ids, padding_mask = inputs["token_ids"], inputs["padding_mask"] - # Create and seed cache with a single forward pass. - hidden_states, cache = self._build_cache(token_ids) - # Compute the lengths of all user inputted tokens ids. - row_lengths = ops.sum(ops.cast(padding_mask, "int32"), axis=-1) - # Start at the first index that has no user inputted id. - index = ops.min(row_lengths) - - def next(prompt, cache, index): - # The cache index is the index of our previous token. - cache_update_index = index - 1 - batch_size = ops.shape(prompt)[0] - prompt = ops.slice(prompt, [0, cache_update_index], [batch_size, 1]) - logits, hidden_states, cache = self.call_with_cache( - prompt, - cache, - cache_update_index, - ) - return ( - ops.squeeze(logits, axis=1), - ops.squeeze(hidden_states, axis=1), - cache, - ) - - token_ids = self.sampler( - next=next, - prompt=token_ids, - cache=cache, - index=index, - mask=padding_mask, - stop_token_ids=stop_token_ids, - hidden_states=hidden_states, - model=self, - ) - - # Compute an output padding mask with the token ids we updated. - if stop_token_ids is not None: - # Build a mask of stop tokens locations not in the original - # prompt (not in locations where `padding_mask` is True). - end_locations = any_equal( - token_ids, stop_token_ids, ops.logical_not(padding_mask) - ) - - end_locations = ops.cast(end_locations, "int32") - # Use cumsum to get ones in all locations after end_locations. - cumsum = ops.cast(ops.cumsum(end_locations, axis=-1), "int32") - overflow = cumsum - end_locations - # Our padding mask is the inverse of these overflow locations. - padding_mask = ops.logical_not(ops.cast(overflow, "bool")) - else: - # Without early stopping, all locations will have been updated. - padding_mask = ops.ones_like(token_ids, dtype="bool") - return { - "token_ids": token_ids, - "padding_mask": padding_mask, - } - def score( self, token_ids, diff --git a/keras_hub/src/models/gpt_neo_x/gpt_neo_x_causal_lm.py b/keras_hub/src/models/gpt_neo_x/gpt_neo_x_causal_lm.py index 16dc07ea0e..7079767b31 100644 --- a/keras_hub/src/models/gpt_neo_x/gpt_neo_x_causal_lm.py +++ b/keras_hub/src/models/gpt_neo_x/gpt_neo_x_causal_lm.py @@ -6,7 +6,6 @@ from keras_hub.src.models.gpt_neo_x.gpt_neo_x_causal_lm_preprocessor import ( GPTNeoXCausalLMPreprocessor, ) -from keras_hub.src.utils.tensor_utils import any_equal @keras_hub_export("keras_hub.models.GPTNeoXCausalLM") @@ -110,78 +109,3 @@ def _build_cache(self, token_ids): # Seed the cache. _, hidden_states, cache = self.call_with_cache(token_ids, cache, 0) return hidden_states, cache - - def generate_step( - self, - inputs, - stop_token_ids=None, - ): - """A compilable generation function for a single batch of inputs. - - This function represents the inner, XLA-compilable, generation function - for a single batch of inputs. Inputs should have the same structure as - model inputs, a dictionary with keys `"token_ids"` and `"padding_mask"`. - - Args: - inputs: A dictionary with two keys `"token_ids"` and - `"padding_mask"` and batched tensor values. - stop_token_ids: Tuple of id's of end token's to stop on. If all - sequences have produced a new stop token, generation - will stop. - """ - token_ids, padding_mask = inputs["token_ids"], inputs["padding_mask"] - # Create and seed cache with a single forward pass. - hidden_states, cache = self._build_cache(token_ids) - # Compute the lengths of all user inputted tokens ids. - row_lengths = ops.sum(ops.cast(padding_mask, "int32"), axis=-1) - # Start at the first index that has no user inputted id. - index = ops.min(row_lengths) - - def next(prompt, cache, index): - # The cache index is the index of our previous token. - cache_update_index = index - 1 - batch_size = ops.shape(prompt)[0] - prompt = ops.slice(prompt, [0, cache_update_index], [batch_size, 1]) - logits, hidden_states, cache = self.call_with_cache( - prompt, - cache, - cache_update_index, - ) - return ( - ops.squeeze(logits, axis=1), - ops.squeeze(hidden_states, axis=1), - cache, - ) - - token_ids = self.sampler( - next=next, - prompt=token_ids, - cache=cache, - index=index, - mask=padding_mask, - stop_token_ids=stop_token_ids, - hidden_states=hidden_states, - model=self, - ) - - # Compute an output padding mask with the token ids we updated. - if stop_token_ids is not None: - # Build a mask of stop_tokens locations not in the original - # prompt (not in locations where `padding_mask` is True). - end_locations = any_equal( - token_ids, stop_token_ids, ops.logical_not(padding_mask) - ) - - end_locations = ops.cast(end_locations, "int32") - # Use cumsum to get ones in all locations after end_locations. - cumsum = ops.cast(ops.cumsum(end_locations, axis=-1), "int32") - overflow = cumsum - end_locations - # Our padding mask is the inverse of these overflow locations. - padding_mask = ops.logical_not(ops.cast(overflow, "bool")) - else: - # Without early stopping, all locations will have been updated. - padding_mask = ops.ones_like(token_ids, dtype="bool") - return { - "token_ids": token_ids, - "padding_mask": padding_mask, - } diff --git a/keras_hub/src/models/gpt_oss/gpt_oss_causal_lm.py b/keras_hub/src/models/gpt_oss/gpt_oss_causal_lm.py index d6bdaede3d..b002cdc940 100644 --- a/keras_hub/src/models/gpt_oss/gpt_oss_causal_lm.py +++ b/keras_hub/src/models/gpt_oss/gpt_oss_causal_lm.py @@ -7,7 +7,6 @@ from keras_hub.src.models.gpt_oss.gpt_oss_causal_lm_preprocessor import ( GptOssCausalLMPreprocessor, ) -from keras_hub.src.utils.tensor_utils import any_equal @keras_hub_export("keras_hub.models.GptOssCausalLM") @@ -116,81 +115,6 @@ def _build_cache(self, token_ids): _, hidden_states, cache = self.call_with_cache(token_ids, cache, 0) return hidden_states, cache - def generate_step( - self, - inputs, - stop_token_ids=None, - ): - """A compilable generation function for a single batch of inputs. - - This function represents the inner, XLA-compilable, generation function - for a single batch of inputs. Inputs should have the same structure as - model inputs, a dictionary with keys `"token_ids"` and `"padding_mask"`. - - Args: - inputs: A dictionary with two keys `"token_ids"` and - `"padding_mask"` and batched tensor values. - stop_token_ids: List of id's of end token's to stop on. If all - sequences have produced a new stop token, generation - will stop. - """ - token_ids, padding_mask = inputs["token_ids"], inputs["padding_mask"] - # Create and seed cache with a single forward pass. - hidden_states, cache = self._build_cache(token_ids) - # Compute the lengths of all user inputted tokens ids. - row_lengths = ops.sum(ops.cast(padding_mask, "int32"), axis=-1) - # Start at the first index that has no user inputted id. - index = ops.min(row_lengths) - - def next(prompt, cache, index): - # The cache index is the index of our previous token. - cache_update_index = index - 1 - batch_size = ops.shape(prompt)[0] - prompt = ops.slice(prompt, [0, cache_update_index], [batch_size, 1]) - logits, hidden_states, cache = self.call_with_cache( - prompt, - cache, - cache_update_index, - ) - return ( - ops.squeeze(logits, axis=1), - ops.squeeze(hidden_states, axis=1), - cache, - ) - - token_ids = self.sampler( - next=next, - prompt=token_ids, - cache=cache, - index=index, - mask=padding_mask, - stop_token_ids=stop_token_ids, - hidden_states=hidden_states, - model=self, - ) - - # Compute an output padding mask with the token ids we updated. - if stop_token_ids is not None: - # Build a mask of stop_tokens locations not in the original - # prompt (not in locations where `padding_mask` is True). - end_locations = any_equal( - token_ids, stop_token_ids, ops.logical_not(padding_mask) - ) - - end_locations = ops.cast(end_locations, "int32") - # Use cumsum to get ones in all locations after end_locations. - cumsum = ops.cast(ops.cumsum(end_locations, axis=-1), "int32") - overflow = cumsum - end_locations - # Our padding mask is the inverse of these overflow locations. - padding_mask = ops.logical_not(ops.cast(overflow, "bool")) - else: - # Without early stopping, all locations will have been updated. - padding_mask = ops.ones_like(token_ids, dtype="bool") - return { - "token_ids": token_ids, - "padding_mask": padding_mask, - } - def score( self, token_ids, diff --git a/keras_hub/src/models/llama/llama_causal_lm.py b/keras_hub/src/models/llama/llama_causal_lm.py index 7f0f901d52..74fe62cc33 100644 --- a/keras_hub/src/models/llama/llama_causal_lm.py +++ b/keras_hub/src/models/llama/llama_causal_lm.py @@ -7,7 +7,6 @@ from keras_hub.src.models.llama.llama_causal_lm_preprocessor import ( LlamaCausalLMPreprocessor, ) -from keras_hub.src.utils.tensor_utils import any_equal @keras_hub_export("keras_hub.models.LlamaCausalLM") @@ -114,80 +113,6 @@ def _build_cache(self, token_ids): _, hidden_states, cache = self.call_with_cache(token_ids, cache, 0) return hidden_states, cache - def generate_step( - self, - inputs, - stop_token_ids=None, - ): - """A compilable generation function for a single batch of inputs. - - This function represents the inner, XLA-compilable, generation function - for a single batch of inputs. Inputs should have the same structure as - model inputs, a dictionary with keys `"token_ids"` and `"padding_mask"`. - - Args: - inputs: A dictionary with two keys `"token_ids"` and - `"padding_mask"` and batched tensor values. - stop_token_ids: Tuple of id's of the end token to stop on. If all - sequences have produced a new stop token, generation - will stop. - """ - token_ids, padding_mask = inputs["token_ids"], inputs["padding_mask"] - # Create and seed cache with a single forward pass. - hidden_states, cache = self._build_cache(token_ids) - # Compute the lengths of all user inputted tokens ids. - row_lengths = ops.sum(ops.cast(padding_mask, "int32"), axis=-1) - # Start at the first index that has no user inputted id. - index = ops.min(row_lengths) - - def next(prompt, cache, index): - # The cache index is the index of our previous token. - cache_update_index = index - 1 - batch_size = ops.shape(prompt)[0] - prompt = ops.slice(prompt, [0, cache_update_index], [batch_size, 1]) - logits, hidden_states, cache = self.call_with_cache( - prompt, - cache, - cache_update_index, - ) - return ( - ops.squeeze(logits, axis=1), - ops.squeeze(hidden_states, axis=1), - cache, - ) - - token_ids = self.sampler( - next=next, - prompt=token_ids, - cache=cache, - index=index, - mask=padding_mask, - stop_token_ids=stop_token_ids, - hidden_states=hidden_states, - model=self, - ) - - # Compute an output padding mask with the token ids we updated. - if stop_token_ids is not None: - # Build a mask of stop token locations not in the original - # prompt (not in locations where `padding_mask` is True). - end_locations = any_equal( - token_ids, stop_token_ids, ops.logical_not(padding_mask) - ) - end_locations = ops.cast(end_locations, "int32") - # Use cumsum to get ones in all locations after end_locations. - cumsum = ops.cast(ops.cumsum(end_locations, axis=-1), "int32") - overflow = cumsum - end_locations - # Our padding mask is the inverse of these overflow locations. - padding_mask = ops.logical_not(ops.cast(overflow, "bool")) - else: - # Without early stopping, all locations will have been updated. - padding_mask = ops.ones_like(token_ids, dtype="bool") - return { - "token_ids": token_ids, - "padding_mask": padding_mask, - } - def score( self, token_ids, diff --git a/keras_hub/src/models/mistral/mistral_causal_lm.py b/keras_hub/src/models/mistral/mistral_causal_lm.py index d28a7cad26..45877ec795 100644 --- a/keras_hub/src/models/mistral/mistral_causal_lm.py +++ b/keras_hub/src/models/mistral/mistral_causal_lm.py @@ -7,7 +7,6 @@ from keras_hub.src.models.mistral.mistral_causal_lm_preprocessor import ( MistralCausalLMPreprocessor, ) -from keras_hub.src.utils.tensor_utils import any_equal @keras_hub_export("keras_hub.models.MistralCausalLM") @@ -114,81 +113,6 @@ def _build_cache(self, token_ids): _, hidden_states, cache = self.call_with_cache(token_ids, cache, 0) return hidden_states, cache - def generate_step( - self, - inputs, - stop_token_ids=None, - ): - """A compilable generation function for a single batch of inputs. - - This function represents the inner, XLA-compilable, generation function - for a single batch of inputs. Inputs should have the same structure as - model inputs, a dictionary with keys `"token_ids"` and `"padding_mask"`. - - Args: - inputs: A dictionary with two keys `"token_ids"` and - `"padding_mask"` and batched tensor values. - stop_token_ids: List of id's of end token's to stop on. If all - sequences have produced a new stop token, generation - will stop. - """ - token_ids, padding_mask = inputs["token_ids"], inputs["padding_mask"] - # Create and seed cache with a single forward pass. - hidden_states, cache = self._build_cache(token_ids) - # Compute the lengths of all user inputted tokens ids. - row_lengths = ops.sum(ops.cast(padding_mask, "int32"), axis=-1) - # Start at the first index that has no user inputted id. - index = ops.min(row_lengths) - - def next(prompt, cache, index): - # The cache index is the index of our previous token. - cache_update_index = index - 1 - batch_size = ops.shape(prompt)[0] - prompt = ops.slice(prompt, [0, cache_update_index], [batch_size, 1]) - logits, hidden_states, cache = self.call_with_cache( - prompt, - cache, - cache_update_index, - ) - return ( - ops.squeeze(logits, axis=1), - ops.squeeze(hidden_states, axis=1), - cache, - ) - - token_ids = self.sampler( - next=next, - prompt=token_ids, - cache=cache, - index=index, - mask=padding_mask, - stop_token_ids=stop_token_ids, - hidden_states=hidden_states, - model=self, - ) - - # Compute an output padding mask with the token ids we updated. - if stop_token_ids is not None: - # Build a mask of stop_tokens locations not in the original - # prompt (not in locations where `padding_mask` is True). - end_locations = any_equal( - token_ids, stop_token_ids, ops.logical_not(padding_mask) - ) - - end_locations = ops.cast(end_locations, "int32") - # Use cumsum to get ones in all locations after end_locations. - cumsum = ops.cast(ops.cumsum(end_locations, axis=-1), "int32") - overflow = cumsum - end_locations - # Our padding mask is the inverse of these overflow locations. - padding_mask = ops.logical_not(ops.cast(overflow, "bool")) - else: - # Without early stopping, all locations will have been updated. - padding_mask = ops.ones_like(token_ids, dtype="bool") - return { - "token_ids": token_ids, - "padding_mask": padding_mask, - } - def score( self, token_ids, diff --git a/keras_hub/src/models/mixtral/mixtral_causal_lm.py b/keras_hub/src/models/mixtral/mixtral_causal_lm.py index 94d27e63e1..aeb7e45ab8 100644 --- a/keras_hub/src/models/mixtral/mixtral_causal_lm.py +++ b/keras_hub/src/models/mixtral/mixtral_causal_lm.py @@ -7,7 +7,6 @@ from keras_hub.src.models.mixtral.mixtral_causal_lm_preprocessor import ( MixtralCausalLMPreprocessor, ) -from keras_hub.src.utils.tensor_utils import any_equal @keras_hub_export("keras_hub.models.MixtralCausalLM") @@ -114,81 +113,6 @@ def _build_cache(self, token_ids): _, hidden_states, cache = self.call_with_cache(token_ids, cache, 0) return hidden_states, cache - def generate_step( - self, - inputs, - stop_token_ids=None, - ): - """A compilable generation function for a single batch of inputs. - - This function represents the inner, XLA-compilable, generation function - for a single batch of inputs. Inputs should have the same structure as - model inputs, a dictionary with keys `"token_ids"` and `"padding_mask"`. - - Args: - inputs: A dictionary with two keys `"token_ids"` and - `"padding_mask"` and batched tensor values. - stop_token_ids: List of id's of end token's to stop on. If all - sequences have produced a new stop token, generation - will stop. - """ - token_ids, padding_mask = inputs["token_ids"], inputs["padding_mask"] - # Create and seed cache with a single forward pass. - hidden_states, cache = self._build_cache(token_ids) - # Compute the lengths of all user inputted tokens ids. - row_lengths = ops.sum(ops.cast(padding_mask, "int32"), axis=-1) - # Start at the first index that has no user inputted id. - index = ops.min(row_lengths) - - def next(prompt, cache, index): - # The cache index is the index of our previous token. - cache_update_index = index - 1 - batch_size = ops.shape(prompt)[0] - prompt = ops.slice(prompt, [0, cache_update_index], [batch_size, 1]) - logits, hidden_states, cache = self.call_with_cache( - prompt, - cache, - cache_update_index, - ) - return ( - ops.squeeze(logits, axis=1), - ops.squeeze(hidden_states, axis=1), - cache, - ) - - token_ids = self.sampler( - next=next, - prompt=token_ids, - cache=cache, - index=index, - mask=padding_mask, - stop_token_ids=stop_token_ids, - hidden_states=hidden_states, - model=self, - ) - - # Compute an output padding mask with the token ids we updated. - if stop_token_ids is not None: - # Build a mask of stop_tokens locations not in the original - # prompt (not in locations where `padding_mask` is True). - end_locations = any_equal( - token_ids, stop_token_ids, ops.logical_not(padding_mask) - ) - - end_locations = ops.cast(end_locations, "int32") - # Use cumsum to get ones in all locations after end_locations. - cumsum = ops.cast(ops.cumsum(end_locations, axis=-1), "int32") - overflow = cumsum - end_locations - # Our padding mask is the inverse of these overflow locations. - padding_mask = ops.logical_not(ops.cast(overflow, "bool")) - else: - # Without early stopping, all locations will have been updated. - padding_mask = ops.ones_like(token_ids, dtype="bool") - return { - "token_ids": token_ids, - "padding_mask": padding_mask, - } - def score( self, token_ids, diff --git a/keras_hub/src/models/opt/opt_causal_lm.py b/keras_hub/src/models/opt/opt_causal_lm.py index c3f28d3006..cda8e962f5 100644 --- a/keras_hub/src/models/opt/opt_causal_lm.py +++ b/keras_hub/src/models/opt/opt_causal_lm.py @@ -6,7 +6,6 @@ from keras_hub.src.models.opt.opt_causal_lm_preprocessor import ( OPTCausalLMPreprocessor, ) -from keras_hub.src.utils.tensor_utils import any_equal @keras_hub_export("keras_hub.models.OPTCausalLM") @@ -209,78 +208,3 @@ def _build_cache(self, token_ids): # Seed the cache. _, hidden_states, cache = self.call_with_cache(token_ids, cache, 0) return hidden_states, cache - - def generate_step( - self, - inputs, - stop_token_ids=None, - ): - """A compilable generation function for a single batch of inputs. - - This function represents the inner, XLA-compilable, generation function - for a single batch of inputs. Inputs should have the same structure as - model inputs, a dictionary with keys `"token_ids"` and `"padding_mask"`. - - Args: - inputs: A dictionary with two keys `"token_ids"` and - `"padding_mask"` and batched tensor values. - stop_token_ids: Tuple of id's of end token's to stop on. If all - sequences have produced a new stop token, generation - will stop. - """ - token_ids, padding_mask = inputs["token_ids"], inputs["padding_mask"] - # Create and seed cache with a single forward pass. - hidden_states, cache = self._build_cache(token_ids) - # Compute the lengths of all user inputted tokens ids. - row_lengths = ops.sum(ops.cast(padding_mask, "int32"), axis=-1) - # Start at the first index that has no user inputted id. - index = ops.min(row_lengths) - - def next(prompt, cache, index): - # The cache index is the index of our previous token. - cache_update_index = index - 1 - batch_size = ops.shape(prompt)[0] - prompt = ops.slice(prompt, [0, cache_update_index], [batch_size, 1]) - logits, hidden_states, cache = self.call_with_cache( - prompt, - cache, - cache_update_index, - ) - return ( - ops.squeeze(logits, axis=1), - ops.squeeze(hidden_states, axis=1), - cache, - ) - - token_ids = self.sampler( - next=next, - prompt=token_ids, - cache=cache, - index=index, - mask=padding_mask, - stop_token_ids=stop_token_ids, - hidden_states=hidden_states, - model=self, - ) - - # Compute an output padding mask with the token ids we updated. - if stop_token_ids is not None: - # Build a mask of stop token locations not in the original - # prompt (not in locations where `padding_mask` is True). - end_locations = any_equal( - token_ids, stop_token_ids, ops.logical_not(padding_mask) - ) - - end_locations = ops.cast(end_locations, "int32") - # Use cumsum to get ones in all locations after end_locations. - cumsum = ops.cast(ops.cumsum(end_locations, axis=-1), "int32") - overflow = cumsum - end_locations - # Our padding mask is the inverse of these overflow locations. - padding_mask = ops.logical_not(ops.cast(overflow, "bool")) - else: - # Without early stopping, all locations will have been updated. - padding_mask = ops.ones_like(token_ids, dtype="bool") - return { - "token_ids": token_ids, - "padding_mask": padding_mask, - } diff --git a/keras_hub/src/models/phi3/phi3_causal_lm.py b/keras_hub/src/models/phi3/phi3_causal_lm.py index a60c336afb..125fc4041a 100644 --- a/keras_hub/src/models/phi3/phi3_causal_lm.py +++ b/keras_hub/src/models/phi3/phi3_causal_lm.py @@ -6,7 +6,6 @@ from keras_hub.src.models.phi3.phi3_causal_lm_preprocessor import ( Phi3CausalLMPreprocessor, ) -from keras_hub.src.utils.tensor_utils import any_equal @keras_hub_export("keras_hub.models.Phi3CausalLM") @@ -113,80 +112,6 @@ def _build_cache(self, token_ids): _, hidden_states, cache = self.call_with_cache(token_ids, cache, 0) return hidden_states, cache - def generate_step( - self, - inputs, - stop_token_ids=None, - ): - """A compilable generation function for a single batch of inputs. - - This function represents the inner, XLA-compilable, generation function - for a single batch of inputs. Inputs should have the same structure as - model inputs, a dictionary with keys `"token_ids"` and `"padding_mask"`. - - Args: - inputs: A dictionary with two keys `"token_ids"` and - `"padding_mask"` and batched tensor values. - stop_token_ids: Tuple of id's of the end token to stop on. If all - sequences have produced a new stop token, generation - will stop. - """ - token_ids, padding_mask = inputs["token_ids"], inputs["padding_mask"] - # Create and seed cache with a single forward pass. - hidden_states, cache = self._build_cache(token_ids) - # Compute the lengths of all user inputted tokens ids. - row_lengths = ops.sum(ops.cast(padding_mask, "int32"), axis=-1) - # Start at the first index that has no user inputted id. - index = ops.min(row_lengths) - - def next(prompt, cache, index): - # The cache index is the index of our previous token. - cache_update_index = index - 1 - batch_size = ops.shape(prompt)[0] - prompt = ops.slice(prompt, [0, cache_update_index], [batch_size, 1]) - logits, hidden_states, cache = self.call_with_cache( - prompt, - cache, - cache_update_index, - ) - return ( - ops.squeeze(logits, axis=1), - ops.squeeze(hidden_states, axis=1), - cache, - ) - - token_ids = self.sampler( - next=next, - prompt=token_ids, - cache=cache, - index=index, - mask=padding_mask, - stop_token_ids=stop_token_ids, - hidden_states=hidden_states, - model=self, - ) - - # Compute an output padding mask with the token ids we updated. - if stop_token_ids is not None: - # Build a mask of stop token locations not in the original - # prompt (not in locations where `padding_mask` is True). - end_locations = any_equal( - token_ids, stop_token_ids, ops.logical_not(padding_mask) - ) - end_locations = ops.cast(end_locations, "int32") - # Use cumsum to get ones in all locations after end_locations. - cumsum = ops.cast(ops.cumsum(end_locations, axis=-1), "int32") - overflow = cumsum - end_locations - # Our padding mask is the inverse of these overflow locations. - padding_mask = ops.logical_not(ops.cast(overflow, "bool")) - else: - # Without early stopping, all locations will have been updated. - padding_mask = ops.ones_like(token_ids, dtype="bool") - return { - "token_ids": token_ids, - "padding_mask": padding_mask, - } - def generate(self, inputs, max_length=None, stop_token_ids="auto"): if self.preprocessor and stop_token_ids == "auto": # Stop at: diff --git a/keras_hub/src/models/qwen/qwen_causal_lm.py b/keras_hub/src/models/qwen/qwen_causal_lm.py index 6689101133..89ba973e24 100644 --- a/keras_hub/src/models/qwen/qwen_causal_lm.py +++ b/keras_hub/src/models/qwen/qwen_causal_lm.py @@ -7,7 +7,6 @@ from keras_hub.src.models.qwen.qwen_causal_lm_preprocessor import ( QwenCausalLMPreprocessor, ) -from keras_hub.src.utils.tensor_utils import any_equal @keras_hub_export( @@ -222,80 +221,6 @@ def _build_cache(self, token_ids): _, hidden_states, cache = self.call_with_cache(token_ids, cache, 0) return hidden_states, cache - def generate_step( - self, - inputs, - stop_token_ids=None, - ): - """A compilable generation function for a single batch of inputs. - - This function represents the inner, XLA-compilable, generation function - for a single batch of inputs. Inputs should have the same structure as - model inputs, a dictionary with keys `"token_ids"` and `"padding_mask"`. - - Args: - inputs: A dictionary with two keys `"token_ids"` and - `"padding_mask"` and batched tensor values. - stop_token_ids: Tuple of id's of the end token to stop on. If all - sequences have produced a new stop token, generation - will stop. - """ - token_ids, padding_mask = inputs["token_ids"], inputs["padding_mask"] - # Create and seed cache with a single forward pass. - hidden_states, cache = self._build_cache(token_ids) - # Compute the lengths of all user inputted tokens ids. - row_lengths = ops.sum(ops.cast(padding_mask, "int32"), axis=-1) - # Start at the first index that has no user inputted id. - index = ops.min(row_lengths) - - def next(prompt, cache, index): - # The cache index is the index of our previous token. - cache_update_index = index - 1 - batch_size = ops.shape(prompt)[0] - prompt = ops.slice(prompt, [0, cache_update_index], [batch_size, 1]) - logits, hidden_states, cache = self.call_with_cache( - prompt, - cache, - cache_update_index, - ) - return ( - ops.squeeze(logits, axis=1), - ops.squeeze(hidden_states, axis=1), - cache, - ) - - token_ids = self.sampler( - next=next, - prompt=token_ids, - cache=cache, - index=index, - mask=padding_mask, - stop_token_ids=stop_token_ids, - hidden_states=hidden_states, - model=self, - ) - - # Compute an output padding mask with the token ids we updated. - if stop_token_ids is not None: - # Build a mask of stop token locations not in the original - # prompt (not in locations where `padding_mask` is True). - end_locations = any_equal( - token_ids, stop_token_ids, ops.logical_not(padding_mask) - ) - end_locations = ops.cast(end_locations, "int32") - # Use cumsum to get ones in all locations after end_locations. - cumsum = ops.cast(ops.cumsum(end_locations, axis=-1), "int32") - overflow = cumsum - end_locations - # Our padding mask is the inverse of these overflow locations. - padding_mask = ops.logical_not(ops.cast(overflow, "bool")) - else: - # Without early stopping, all locations will have been updated. - padding_mask = ops.ones_like(token_ids, dtype="bool") - return { - "token_ids": token_ids, - "padding_mask": padding_mask, - } - def score( self, token_ids, diff --git a/keras_hub/src/models/qwen3/qwen3_causal_lm.py b/keras_hub/src/models/qwen3/qwen3_causal_lm.py index f2d7b10b16..8278c5e87b 100644 --- a/keras_hub/src/models/qwen3/qwen3_causal_lm.py +++ b/keras_hub/src/models/qwen3/qwen3_causal_lm.py @@ -7,7 +7,6 @@ from keras_hub.src.models.qwen3.qwen3_causal_lm_preprocessor import ( Qwen3CausalLMPreprocessor, ) -from keras_hub.src.utils.tensor_utils import any_equal @keras_hub_export("keras_hub.models.Qwen3CausalLM") @@ -219,81 +218,6 @@ def _build_cache(self, token_ids): _, hidden_states, cache = self.call_with_cache(token_ids, cache, 0) return hidden_states, cache - def generate_step( - self, - inputs, - stop_token_ids=None, - ): - """A compilable generation function for a single batch of inputs. - - This function represents the inner, XLA-compilable, generation function - for a single batch of inputs. Inputs should have the same structure as - model inputs, a dictionary with keys `"token_ids"` and `"padding_mask"`. - - Args: - inputs: A dictionary with two keys `"token_ids"` and - `"padding_mask"` and batched tensor values. - stop_token_ids: Tuple of id's of the end token to stop on. If all - sequences have produced a new stop token, generation - will stop. - """ - token_ids, padding_mask = inputs["token_ids"], inputs["padding_mask"] - # Create and seed cache with a single forward pass. - hidden_states, cache = self._build_cache(token_ids) - # Compute the lengths of all user inputted tokens ids. - row_lengths = ops.sum(ops.cast(padding_mask, "int32"), axis=-1) - # Start at the first index that has no user inputted id. - index = ops.min(row_lengths) - - def next(prompt, cache, index): - # The cache index is the index of our previous token. - cache_update_index = index - 1 - batch_size = ops.shape(prompt)[0] - prompt = ops.slice(prompt, [0, cache_update_index], [batch_size, 1]) - logits, hidden_states, cache = self.call_with_cache( - prompt, - cache, - cache_update_index, - ) - return ( - ops.squeeze(logits, axis=1), - ops.squeeze(hidden_states, axis=1), - cache, - ) - - token_ids = self.sampler( - next=next, - prompt=token_ids, - cache=cache, - index=index, - mask=padding_mask, - stop_token_ids=stop_token_ids, - hidden_states=hidden_states, - model=self, - ) - print("generated token ids = ", token_ids[0]) - - # Compute an output padding mask with the token ids we updated. - if stop_token_ids is not None: - # Build a mask of stop token locations not in the original - # prompt (not in locations where `padding_mask` is True). - end_locations = any_equal( - token_ids, stop_token_ids, ops.logical_not(padding_mask) - ) - end_locations = ops.cast(end_locations, "int32") - # Use cumsum to get ones in all locations after end_locations. - cumsum = ops.cast(ops.cumsum(end_locations, axis=-1), "int32") - overflow = cumsum - end_locations - # Our padding mask is the inverse of these overflow locations. - padding_mask = ops.logical_not(ops.cast(overflow, "bool")) - else: - # Without early stopping, all locations will have been updated. - padding_mask = ops.ones_like(token_ids, dtype="bool") - return { - "token_ids": token_ids, - "padding_mask": padding_mask, - } - def score( self, token_ids, diff --git a/keras_hub/src/models/qwen3_moe/qwen3_moe_causal_lm.py b/keras_hub/src/models/qwen3_moe/qwen3_moe_causal_lm.py index 198e3af697..eb83237ff1 100644 --- a/keras_hub/src/models/qwen3_moe/qwen3_moe_causal_lm.py +++ b/keras_hub/src/models/qwen3_moe/qwen3_moe_causal_lm.py @@ -7,7 +7,6 @@ from keras_hub.src.models.qwen3_moe.qwen3_moe_causal_lm_preprocessor import ( Qwen3MoeCausalLMPreprocessor, ) -from keras_hub.src.utils.tensor_utils import any_equal @keras_hub_export( @@ -231,80 +230,6 @@ def _build_cache(self, token_ids): _, hidden_states, cache = self.call_with_cache(token_ids, cache, 0) return hidden_states, cache - def generate_step( - self, - inputs, - stop_token_ids=None, - ): - """A compilable generation function for a single batch of inputs. - - This function represents the inner, XLA-compilable, generation function - for a single batch of inputs. Inputs should have the same structure as - model inputs, a dictionary with keys `"token_ids"` and `"padding_mask"`. - - Args: - inputs: A dictionary with two keys `"token_ids"` and - `"padding_mask"` and batched tensor values. - stop_token_ids: Tuple of id's of the end token to stop on. If all - sequences have produced a new stop token, generation - will stop. - """ - token_ids, padding_mask = inputs["token_ids"], inputs["padding_mask"] - # Create and seed cache with a single forward pass. - hidden_states, cache = self._build_cache(token_ids) - # Compute the lengths of all user inputted tokens ids. - row_lengths = ops.sum(ops.cast(padding_mask, "int32"), axis=-1) - # Start at the first index that has no user inputted id. - index = ops.min(row_lengths) - - def next(prompt, cache, index): - # The cache index is the index of our previous token. - cache_update_index = index - 1 - batch_size = ops.shape(prompt)[0] - prompt = ops.slice(prompt, [0, cache_update_index], [batch_size, 1]) - logits, hidden_states, cache = self.call_with_cache( - prompt, - cache, - cache_update_index, - ) - return ( - ops.squeeze(logits, axis=1), - ops.squeeze(hidden_states, axis=1), - cache, - ) - - token_ids = self.sampler( - next=next, - prompt=token_ids, - cache=cache, - index=index, - mask=padding_mask, - stop_token_ids=stop_token_ids, - hidden_states=hidden_states, - model=self, - ) - - # Compute an output padding mask with the token ids we updated. - if stop_token_ids is not None: - # Build a mask of stop token locations not in the original - # prompt (not in locations where `padding_mask` is True). - end_locations = any_equal( - token_ids, stop_token_ids, ops.logical_not(padding_mask) - ) - end_locations = ops.cast(end_locations, "int32") - # Use cumsum to get ones in all locations after end_locations. - cumsum = ops.cast(ops.cumsum(end_locations, axis=-1), "int32") - overflow = cumsum - end_locations - # Our padding mask is the inverse of these overflow locations. - padding_mask = ops.logical_not(ops.cast(overflow, "bool")) - else: - # Without early stopping, all locations will have been updated. - padding_mask = ops.ones_like(token_ids, dtype="bool") - return { - "token_ids": token_ids, - "padding_mask": padding_mask, - } - def score( self, token_ids, diff --git a/keras_hub/src/models/qwen_moe/qwen_moe_causal_lm.py b/keras_hub/src/models/qwen_moe/qwen_moe_causal_lm.py index 3c605fe561..826944e2c8 100644 --- a/keras_hub/src/models/qwen_moe/qwen_moe_causal_lm.py +++ b/keras_hub/src/models/qwen_moe/qwen_moe_causal_lm.py @@ -7,7 +7,6 @@ from keras_hub.src.models.qwen_moe.qwen_moe_causal_lm_preprocessor import ( QwenMoeCausalLMPreprocessor, ) -from keras_hub.src.utils.tensor_utils import any_equal @keras_hub_export( @@ -224,80 +223,6 @@ def _build_cache(self, token_ids): _, hidden_states, cache = self.call_with_cache(token_ids, cache, 0) return hidden_states, cache - def generate_step( - self, - inputs, - stop_token_ids=None, - ): - """A compilable generation function for a single batch of inputs. - - This function represents the inner, XLA-compilable, generation function - for a single batch of inputs. Inputs should have the same structure as - model inputs, a dictionary with keys `"token_ids"` and `"padding_mask"`. - - Args: - inputs: A dictionary with two keys `"token_ids"` and - `"padding_mask"` and batched tensor values. - stop_token_ids: Tuple of id's of the end token to stop on. If all - sequences have produced a new stop token, generation - will stop. - """ - token_ids, padding_mask = inputs["token_ids"], inputs["padding_mask"] - # Create and seed cache with a single forward pass. - hidden_states, cache = self._build_cache(token_ids) - # Compute the lengths of all user inputted tokens ids. - row_lengths = ops.sum(ops.cast(padding_mask, "int32"), axis=-1) - # Start at the first index that has no user inputted id. - index = ops.min(row_lengths) - - def next(prompt, cache, index): - # The cache index is the index of our previous token. - cache_update_index = index - 1 - batch_size = ops.shape(prompt)[0] - prompt = ops.slice(prompt, [0, cache_update_index], [batch_size, 1]) - logits, hidden_states, cache = self.call_with_cache( - prompt, - cache, - cache_update_index, - ) - return ( - ops.squeeze(logits, axis=1), - ops.squeeze(hidden_states, axis=1), - cache, - ) - - token_ids = self.sampler( - next=next, - prompt=token_ids, - cache=cache, - index=index, - mask=padding_mask, - stop_token_ids=stop_token_ids, - hidden_states=hidden_states, - model=self, - ) - - # Compute an output padding mask with the token ids we updated. - if stop_token_ids is not None: - # Build a mask of stop token locations not in the original - # prompt (not in locations where `padding_mask` is True). - end_locations = any_equal( - token_ids, stop_token_ids, ops.logical_not(padding_mask) - ) - end_locations = ops.cast(end_locations, "int32") - # Use cumsum to get ones in all locations after end_locations. - cumsum = ops.cast(ops.cumsum(end_locations, axis=-1), "int32") - overflow = cumsum - end_locations - # Our padding mask is the inverse of these overflow locations. - padding_mask = ops.logical_not(ops.cast(overflow, "bool")) - else: - # Without early stopping, all locations will have been updated. - padding_mask = ops.ones_like(token_ids, dtype="bool") - return { - "token_ids": token_ids, - "padding_mask": padding_mask, - } - def score( self, token_ids, diff --git a/keras_hub/src/models/smollm3/smollm3_causal_lm.py b/keras_hub/src/models/smollm3/smollm3_causal_lm.py index 7881ba2e4d..7363602d41 100644 --- a/keras_hub/src/models/smollm3/smollm3_causal_lm.py +++ b/keras_hub/src/models/smollm3/smollm3_causal_lm.py @@ -7,7 +7,6 @@ from keras_hub.src.models.smollm3.smollm3_causal_lm_preprocessor import ( SmolLM3CausalLMPreprocessor, ) -from keras_hub.src.utils.tensor_utils import any_equal @keras_hub_export( @@ -107,81 +106,6 @@ def _build_cache(self, token_ids): _, hidden_states, cache = self.call_with_cache(token_ids, cache, index) return hidden_states, cache - def generate_step( - self, - inputs, - stop_token_ids=None, - ): - """A compilable generation function for a single batch of inputs. - - This function represents the inner, XLA-compilable, generation function - for a single batch of inputs. Inputs should have the same structure as - model inputs, a dictionary with keys `"token_ids"` and `"padding_mask"`. - - Args: - inputs: A dictionary with two keys `"token_ids"` and - `"padding_mask"` and batched tensor values. - stop_token_ids: Tuple of id's of the end token to stop on. If all - sequences have produced a new stop token, generation - will stop. - """ - token_ids, padding_mask = inputs["token_ids"], inputs["padding_mask"] - - hidden_states, cache = self._build_cache(token_ids) - # Compute the lengths of all user inputted tokens ids. - row_lengths = ops.sum(ops.cast(padding_mask, "int32"), axis=-1) - # Start at the first index that has no user inputted id. - index = ops.min(row_lengths) - - def next(prompt, cache, index): - # The cache index is the index of our previous token. - cache_update_index = index - 1 - batch_size = ops.shape(prompt)[0] - prompt = ops.slice(prompt, [0, cache_update_index], [batch_size, 1]) - - logits, hidden_states, cache = self.call_with_cache( - prompt, - cache, - cache_update_index, - ) - return ( - ops.squeeze(logits, axis=1), - ops.squeeze(hidden_states, axis=1), - cache, - ) - - token_ids = self.sampler( - next=next, - prompt=token_ids, - cache=cache, - index=index, - mask=padding_mask, - stop_token_ids=stop_token_ids, - hidden_states=hidden_states, - model=self, - ) - - # Compute an output padding mask with the token ids we updated. - if stop_token_ids is not None: - # Build a mask of stop token locations not in the original - # prompt (not in locations where `padding_mask` is True). - end_locations = any_equal( - token_ids, stop_token_ids, ops.logical_not(padding_mask) - ) - end_locations = ops.cast(end_locations, "int32") - # Use cumsum to get ones in all locations after end_locations. - cumsum = ops.cast(ops.cumsum(end_locations, axis=-1), "int32") - overflow = cumsum - end_locations - # Our padding mask is the inverse of these overflow locations. - padding_mask = ops.logical_not(ops.cast(overflow, "bool")) - else: - # Without early stopping, all locations will have been updated. - padding_mask = ops.ones_like(token_ids, dtype="bool") - return { - "token_ids": token_ids, - "padding_mask": padding_mask, - } - def score( self, token_ids, diff --git a/keras_hub/src/samplers/sampler.py b/keras_hub/src/samplers/sampler.py index 44c4168375..e92e2526f9 100644 --- a/keras_hub/src/samplers/sampler.py +++ b/keras_hub/src/samplers/sampler.py @@ -1,4 +1,5 @@ import keras +from keras import config from keras import ops from keras import random @@ -92,32 +93,67 @@ def __call__( # `ops.while_loop` will not accept `None` as a value for `loop_vars`. cache = () if cache is None else cache - # OpenVINO requires all parameters to be passed in the body. - # So we pass `mask` as well. - def cond(prompt, cache, index, mask): - if stop_token_ids is None: - return ops.convert_to_tensor(True, dtype="bool") - # Stop if all sequences have produced a *new* id from - # stop_token_ids. - end_tokens = any_equal(prompt, stop_token_ids, ~mask) - prompt_done = ops.any(end_tokens, axis=-1) - return ops.logical_not(ops.all(prompt_done)) - - def body(prompt, cache, index, mask): - # Compute the softmax distribution for the next token. - logits, _, cache = next(prompt, cache, index) - probabilities = self.compute_probabilities(logits) - # Compute the next token. - next_token = self.get_next_token(probabilities) - # Don't overwrite anywhere mask is True. - next_token = ops.cast(next_token, prompt.dtype) - next_token = ops.where(mask[:, index], prompt[:, index], next_token) - # Update the prompt with the next token. - next_token = next_token[:, None] - prompt = ops.slice_update(prompt, [0, index], next_token) - - # Return the next prompt, cache and incremented index. - return (prompt, cache, index + 1, mask) + _is_torch = config.backend() == "torch" + + if _is_torch: + import torch + + # CRITICAL: Convert index to Python int for torch backend. + # torch.Tensor index causes ~2.5x slowdown in call_with_cache + # due to tensor arithmetic overhead propagating through all + # downstream operations (position embedding, cache slicing, etc.). + index = int(index) + + # Torch-optimized cond/body using native ops to avoid + # ops dispatch overhead (~1.5ms savings per iteration). + def cond(prompt, cache, index, mask): + if stop_token_ids is None: + return True + end_tokens = any_equal(prompt, stop_token_ids, ~mask) + prompt_done = end_tokens.any(dim=-1) + return not prompt_done.all().item() + + def body(prompt, cache, index, mask): + logits, _, cache = next(prompt, cache, index) + probabilities = self.compute_probabilities(logits) + next_token = self.get_next_token(probabilities) + # Don't overwrite anywhere mask is True. + next_token = next_token.to(dtype=prompt.dtype) + next_token = torch.where( + mask[:, index], prompt[:, index], next_token + ) + # Update the prompt with the next token (in-place). + prompt[:, index] = next_token + return (prompt, cache, index + 1, mask) + else: + # OpenVINO requires all parameters to be passed in the body. + # So we pass `mask` as well. + def cond(prompt, cache, index, mask): + if stop_token_ids is None: + return ops.convert_to_tensor(True, dtype="bool") + # Stop if all sequences have produced a *new* id from + # stop_token_ids. + end_tokens = any_equal(prompt, stop_token_ids, ~mask) + prompt_done = ops.any(end_tokens, axis=-1) + return ops.logical_not(ops.all(prompt_done)) + + def body(prompt, cache, index, mask): + # Compute the softmax distribution for the next token. + logits, _, cache = next(prompt, cache, index) + probabilities = self.compute_probabilities(logits) + # Compute the next token. + next_token = self.get_next_token(probabilities) + # Don't overwrite anywhere mask is True. + next_token = ops.cast(next_token, prompt.dtype) + next_token = ops.where( + mask[:, index], prompt[:, index], next_token + ) + # Update the prompt with the next token. + next_token = next_token[:, None] + prompt = ops.slice_update(prompt, [0, index], next_token) + + # Return the next prompt, cache and incremented index. + return (prompt, cache, index + 1, mask) prompt, _, _, _ = self.run_loop( cond, @@ -134,8 +170,19 @@ def compute_probabilities(self, logits): This will always be done in full precision, regardless of dtype, and scale by `temperature`. """ - logits = ops.cast(logits, "float32") - return keras.activations.softmax(logits / self.temperature) + # Fast path for torch: avoid ops dispatch overhead. + if config.backend() == "torch": + import torch + + logits_f32 = logits.to(dtype=torch.float32) + if self.temperature != 1.0: + logits_f32 = logits_f32 / self.temperature + return torch.nn.functional.softmax(logits_f32, dim=-1) + + logits_scaled = ops.cast(logits, "float32") + if self.temperature != 1.0: + logits_scaled = logits_scaled / self.temperature + return keras.activations.softmax(logits_scaled) def run_loop( self, cond, body, model=None, loop_vars=None, maximum_iterations=None @@ -200,6 +247,8 @@ def stateless_body(state, *loop_vars): for ref_v, v in zip(self.variables, state[0]): ref_v.assign(v) else: + # Use ops.while_loop for all other backends + # The PyTorch backend's while_loop is now optimized in Keras Core loop_vars = ops.while_loop( cond=cond, body=body, diff --git a/keras_hub/src/samplers/top_k_sampler.py b/keras_hub/src/samplers/top_k_sampler.py index e0e4a339c5..4969f1a98a 100644 --- a/keras_hub/src/samplers/top_k_sampler.py +++ b/keras_hub/src/samplers/top_k_sampler.py @@ -1,3 +1,4 @@ +from keras import config from keras import ops from keras import random @@ -47,6 +48,32 @@ def __init__( self.seed_generator = random.SeedGenerator(seed) def get_next_token(self, probabilities): + # Fast path for torch backend: use native torch ops to avoid + # ops dispatch overhead (saves ~2ms per iteration). + if config.backend() == "torch": + import torch + + top_k_pred, top_k_indices = torch.topk( + probabilities, k=self.k, sorted=False + ) + # torch.multinomial on MPS/CPU with tiny tensors (batch=1, k=5) + # is much faster on CPU (~0.03ms vs ~1.7ms on MPS). + # For CUDA, keep on device. + device = top_k_pred.device + if device.type == "mps": + top_k_cpu = top_k_pred.to(device="cpu", dtype=torch.float32) + sample_indices = torch.multinomial(top_k_cpu, num_samples=1).to( + device=device + ) + else: + sample_indices = torch.multinomial( + top_k_pred.to(dtype=torch.float32), num_samples=1 + ) + # Gather the original token indices. + output = torch.gather(top_k_indices, 1, sample_indices) + return output.squeeze(-1) + + # Default path for JAX/TF: use keras ops. # Filter out top-k tokens. top_k_pred, top_k_indices = ops.top_k( probabilities, diff --git a/keras_hub/src/utils/tensor_utils.py b/keras_hub/src/utils/tensor_utils.py index e5be430e1d..cb3e150666 100644 --- a/keras_hub/src/utils/tensor_utils.py +++ b/keras_hub/src/utils/tensor_utils.py @@ -397,10 +397,34 @@ def any_equal(inputs, values, padding_mask): a value from any `values`. Padding mask will be applied before returning. """ - output = ops.equal(inputs, values[0]) - for value in values[1:]: - value_equality = ops.equal(inputs, value) - output = ops.logical_or(output, value_equality) + # Fast path for torch backend: use native torch ops to avoid + # ops dispatch overhead (~0.5ms saving per call). + if keras.config.backend() == "torch": + import torch + + def _to_comparable(v): + if isinstance(v, (int, float)): + return v + if not isinstance(v, torch.Tensor): + return torch.tensor(v, device=inputs.device) + return v + + v0 = _to_comparable(values[0]) + if len(values) == 1: + output = inputs.eq(v0) + else: + output = inputs.eq(v0) + for value in values[1:]: + output = output | inputs.eq(_to_comparable(value)) + return output & padding_mask + + # Fast path for single stop token (most common case). + if len(values) == 1: + output = ops.equal(inputs, values[0]) + else: + output = ops.equal(inputs, values[0]) + for value in values[1:]: + output = ops.logical_or(output, ops.equal(inputs, value)) return ops.logical_and(output, padding_mask)