
Performance optimizations #2594

Draft
pctablet505 wants to merge 31 commits into keras-team:master from pctablet505:performance-optimizations

Conversation

@pctablet505
Collaborator

This pull request introduces several optimizations for cached autoregressive decoding in GPT-2 and related transformer models, primarily targeting the PyTorch backend. The changes improve generation speed by reducing redundant mask computations, minimizing tensor slicing overhead, and leveraging PyTorch's scaled_dot_product_attention (SDPA) when available. There are also minor improvements to sampler logic and tensor utilities for efficiency.

Transformer and Attention Layer Optimizations

  • Added a fast path in CachedMultiHeadAttention to use PyTorch's SDPA for cached inference, including a runtime check for SDPA availability and an override mechanism for self-attention. [1] [2]
  • Introduced a call_cached method in TransformerDecoder to bypass validation, skip redundant mask computations, and enable SDPA override during autoregressive decoding.
  • Modified compute_causal_mask to use a fast path for single-token generation, avoiding unnecessary tensor operations when output length is 1.
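
The runtime availability check can be sketched as follows (a minimal stand-in; the PR's actual `_check_torch_sdpa` helper may differ in detail):

```python
def check_torch_sdpa():
    """Return True if PyTorch's fused scaled_dot_product_attention
    (available since PyTorch 2.0) can be used, False otherwise.
    A sketch of the runtime check, not the exact PR code."""
    try:
        import torch.nn.functional as F
    except ImportError:
        # PyTorch is not installed; the fast path cannot apply.
        return False
    return hasattr(F, "scaled_dot_product_attention")


# Typically evaluated once at import time and cached in a module-level
# flag, so the attribute lookup is not repeated on every forward pass.
TORCH_SDPA_AVAILABLE = check_torch_sdpa()
```

Caching the result in a module-level flag mirrors the `_TORCH_SDPA_AVAILABLE` global described in the changelog below.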

GPT-2 Model Generation Improvements

  • Changed GPT2CausalLM.call_with_cache to precompute and share the causal mask across all decoder layers, reducing repeated mask creation and leveraging in-place cache updates on PyTorch.
  • Updated generate_step to use direct tensor indexing for single-token extraction on PyTorch, avoiding ops.slice overhead.
  • Improved inference performance by using both torch.no_grad() and torch.inference_mode() in the generation wrapper.
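
The single-token indexing change can be illustrated with a small NumPy stand-in (the real code runs on torch tensors; shapes and names here are illustrative):

```python
import numpy as np

# Stand-in for a [batch, seq_len] prompt of token ids.
prompt = np.array([[11, 12, 13, 14],
                   [21, 22, 23, 24]])
cache_update_index = 2

# Equivalent to ops.slice(prompt, [0, index], [batch_size, 1]), but
# plain indexing avoids the generic slice-op dispatch on the torch
# backend while keeping the [batch, 1] shape.
token = prompt[:, cache_update_index:cache_update_index + 1]
# token.tolist() == [[13], [23]], shape (2, 1)
```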

Sampler and Utility Enhancements

  • Refactored temperature scaling in Sampler.compute_probabilities for clarity and efficiency, applying division only when temperature is not 1.0.
  • Clarified backend handling in Sampler.stateless_body, noting optimized while loop support for PyTorch.
  • Added a fast path for single stop token detection in any_equal to avoid unnecessary logical operations.
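
Both sampler-level tweaks are easy to sketch in plain Python (illustrative stand-ins, not the keras_hub tensor implementations):

```python
import math


def compute_probabilities(logits, temperature=1.0):
    """Softmax with temperature. The division is skipped entirely for
    the common temperature == 1.0 case."""
    if temperature != 1.0:
        logits = [l / temperature for l in logits]
    # Subtract the max before exponentiating for numerical stability.
    m = max(logits)
    exps = [math.exp(l - m) for l in logits]
    total = sum(exps)
    return [e / total for e in exps]


def any_equal(inputs, values, padding_mask):
    """Fast path sketch: with a single stop token, one equality test
    replaces a chain of logical_or operations over per-value masks."""
    if len(values) == 1:
        hits = [x == values[0] for x in inputs]
    else:
        hits = [any(x == v for v in values) for x in inputs]
    return [h and bool(m) for h, m in zip(hits, padding_mask)]
```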

Position Embedding Optimization

  • Added a fast path in PositionEmbedding.call for single-token decoding on PyTorch, using direct indexing to avoid slicing overhead.
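
A NumPy stand-in shows the idea (the real layer operates on its position-embedding weight, a torch tensor, and broadcasts over the batch):

```python
import numpy as np

# Stand-in for the [max_length, hidden_dim] position-embedding table.
table = np.arange(12, dtype=np.float32).reshape(6, 2)
cache_update_index = 3

# Slice-based path: a length-1 window starting at the index.
sliced = table[cache_update_index:cache_update_index + 1, :]

# Fast path: pick the row directly, then restore the sequence axis.
direct = table[cache_update_index][np.newaxis, :]

# Both yield the same [1, hidden_dim] embedding for the new token.
```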

Reference

Colab Notebook

Checklist

  • I have added all the necessary unit tests for my change.
  • I have verified that my change does not break existing code and works with all backends (TensorFlow, JAX, and PyTorch).
  • My PR is based on the latest changes of the main branch (if unsure, rebase the code).
  • I have followed the Keras Hub Model contribution guidelines in making these changes.
  • I have followed the Keras Hub API design guidelines in making these changes.
  • I have signed the Contributor License Agreement.

pctablet505 and others added 29 commits April 17, 2025 10:26
Added checks for invalid inputs
Added tests to check invalid inputs
Fix for model not loading when using numpy behaviour with tensorflow
Casts indices to int32 before using them in ops.take_along_axis to prevent type mismatch issues in non-TensorFlow backends. This improves compatibility and avoids potential runtime errors.
Replaces direct access to the _keras_mask attribute with the get_keras_mask utility in TokenAndPositionEmbeddingTest. This improves compatibility with changes in Keras mask handling.
- Added SDPA override for self-attention in CachedMultiHeadAttention
- Fast path for position embedding single-token decoding
- New call_cached() method in TransformerDecoder for decoder-only inference
- Optimized causal mask computation for autoregressive decoding
- Direct tensor indexing in GPT-2 call_with_cache
- Temperature division skip in sampler when temp=1.0
- Fast path for single stop token in any_equal
- Added torch.inference_mode() for better PyTorch performance

Benchmark: PyTorch 1.58s vs JAX 0.98s (1.61x gap)
@gemini-code-assist
Contributor

Summary of Changes

Hello @pctablet505, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request focuses on significant performance optimizations for cached autoregressive decoding in GPT-2 and similar transformer models, primarily targeting the PyTorch backend. The changes aim to enhance generation speed by reducing computational overheads associated with mask calculations, minimizing tensor slicing operations, and integrating highly optimized PyTorch functionalities like scaled_dot_product_attention. These improvements lead to a more efficient and faster inference process for language models.

Highlights

  • Transformer and Attention Layer Optimizations: Introduced a fast path in CachedMultiHeadAttention to leverage PyTorch's scaled_dot_product_attention (SDPA) for cached inference, including a runtime availability check and an override mechanism. A new call_cached method in TransformerDecoder bypasses validation and redundant mask computations, enabling SDPA during autoregressive decoding. The compute_causal_mask function now has a fast path for single-token generation, avoiding unnecessary tensor operations.
  • GPT-2 Model Generation Improvements: The GPT2CausalLM.call_with_cache method now precomputes and shares the causal mask across all decoder layers, significantly reducing repeated mask creation. It also leverages in-place cache updates on PyTorch for efficiency. generate_step was updated to use direct tensor indexing for single-token extraction on PyTorch, minimizing ops.slice overhead. Inference performance is further boosted by using both torch.no_grad() and torch.inference_mode() in the generation wrapper.
  • Sampler and Utility Enhancements: Temperature scaling in Sampler.compute_probabilities has been refactored for clarity and efficiency, applying division only when the temperature is not 1.0. Backend handling in Sampler.stateless_body now clarifies optimized while_loop support for PyTorch. A fast path was added to any_equal for single stop token detection, avoiding unnecessary logical operations.
  • Position Embedding Optimization: A fast path was added to PositionEmbedding.call for single-token decoding on PyTorch, utilizing direct indexing to avoid slicing overhead.


Changelog
  • keras_hub/src/layers/modeling/cached_multi_head_attention.py
    • Added _check_torch_sdpa function and _TORCH_SDPA_AVAILABLE global variable to detect PyTorch's scaled_dot_product_attention availability.
    • Overrode _compute_attention to use torch.nn.functional.scaled_dot_product_attention for cached inference on PyTorch when _use_sdpa_override is true.
  • keras_hub/src/layers/modeling/position_embedding.py
    • Added a fast path in call for single-token decoding on PyTorch, using direct indexing instead of ops.slice to avoid overhead.
  • keras_hub/src/layers/modeling/transformer_decoder.py
    • Introduced call_cached method for fast-path autoregressive decoding, skipping validation, cross-attention, and mask computation, and enabling SDPA override.
  • keras_hub/src/layers/modeling/transformer_layer_utils.py
    • Added a fast path to compute_causal_mask for single-token generation (output_length=1), simplifying mask creation.
  • keras_hub/src/models/causal_lm.py
    • Modified wrapped_generate_function to use both torch.no_grad() and torch.inference_mode() for improved inference performance.
  • keras_hub/src/models/gpt2/gpt2_causal_lm.py
    • In call_with_cache, precomputed the causal mask once and shared it across all transformer layers, reducing redundant computations.
    • Optimized cache updates for PyTorch by leveraging in-place updates and using call_cached for transformer layers.
    • In generate_step, added a conditional check to use direct tensor indexing (prompt[:, cache_update_index:cache_update_index + 1]) for single-token extraction on PyTorch, avoiding ops.slice overhead.
    • Changed ops.squeeze to direct indexing [:, 0, :] for logits and hidden states.
  • keras_hub/src/samplers/sampler.py
    • Refactored compute_probabilities to apply temperature division only when self.temperature is not 1.0.
    • Added a comment in stateless_body clarifying that PyTorch backend's while_loop is now optimized in Keras Core.
  • keras_hub/src/utils/tensor_utils.py
    • Added a fast path to any_equal for the common case of a single stop token, avoiding unnecessary logical_or operations.
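
The compute_causal_mask fast path described above can be sketched like this (NumPy stand-in; argument names follow the changelog entry, but the real utility builds backend tensors):

```python
import numpy as np


def compute_causal_mask(batch_size, input_length, output_length, cache_index=0):
    """Sketch of the causal-mask fast path. When generating a single
    token (output_length == 1), every already-cached position may be
    attended to, so the mask is simply a row of ones and no triangular
    construction is needed."""
    if output_length == 1:
        return np.ones((batch_size, 1, input_length), dtype=bool)
    # General case: lower-triangular mask shifted by the cache index.
    i = np.arange(output_length)[:, None] + cache_index
    j = np.arange(input_length)[None, :]
    return np.broadcast_to(j <= i, (batch_size, output_length, input_length))
```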
Activity
  • No human activity (comments, reviews) has been recorded on this pull request yet.

@gemini-code-assist bot left a comment

Code Review

This pull request introduces a set of significant performance optimizations for autoregressive decoding, mainly for the PyTorch backend. The changes are well-thought-out, leveraging backend-specific features like PyTorch's scaled_dot_product_attention (SDPA) and in-place updates to improve generation speed. The introduction of a call_cached fast path and pre-computation of causal masks are excellent strategies for reducing overhead.

I've found one critical issue in the implementation of the SDPA fast path where the attention mask logic is inverted, which would lead to incorrect attention results. I've provided a suggestion for the fix.

Overall, this is a high-quality contribution that will substantially improve model performance. Great work!


# Convert attention mask to SDPA format.
if attention_mask is not None:
    attention_mask = attention_mask.to(dtype=torch.bool)

critical

The conversion of the attention mask for PyTorch's scaled_dot_product_attention appears to be incorrect. The Keras convention for attention masks is that 1 or True indicates a position should be attended to. However, PyTorch's SDPA attn_mask expects True for positions that should be ignored (masked out).

The current implementation attention_mask.to(dtype=torch.bool) will convert attending positions (1/True) to True, which causes SDPA to ignore them, effectively inverting the attention logic.

To fix this, the boolean mask should be inverted before being passed to scaled_dot_product_attention. This is a critical issue as it will lead to incorrect model outputs during cached inference on the PyTorch backend.

Suggested change:

    - attention_mask = attention_mask.to(dtype=torch.bool)
    + attention_mask = ~attention_mask.to(dtype=torch.bool)

Keras convention: 1/True = attend, 0/False = don't attend
PyTorch SDPA convention: True = mask out, False = attend
The mask needs to be inverted when passed to scaled_dot_product_attention.
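
The inversion itself is a one-liner on a boolean tensor (NumPy stand-in below). Whether it is actually needed depends on SDPA's boolean attn_mask convention, and note that a later commit in this PR takes the opposite position, passing boolean masks through without inversion:

```python
import numpy as np

# Keras convention: True (or 1) marks positions to attend to.
keras_mask = np.array([[True, True, True, False]])

# Under the convention the reviewer describes (True = mask out),
# the Keras mask would be inverted before reaching SDPA.
sdpa_mask = ~keras_mask
# sdpa_mask.tolist() == [[False, False, False, True]]
```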
Introduce an ultra-fast cached decoding path and centralize generation logic.

- Add CachedMultiHeadAttention.call_cached to bypass Layer.__call__ and directly invoke sublayer .call() for query/key/value/output dense ops, reducing overhead during cached autoregressive decoding. Also adjust boolean attention-mask handling to pass through without inversion.
- Update TransformerDecoder to use .call() on layer-norms/denses and to call the new attention.call_cached path, avoiding repeated Layer.__call__ overhead in inference.
- Move the per-model generate_step implementations into a single default implementation on CausalLM (with backend-specific optimizations such as direct tensor indexing for torch). Add abstract call_with_cache and _build_cache stubs that subclasses must implement. Remove now-duplicate generate_step code and related any_equal imports from many model files.

These changes reduce runtime overhead during cached inference and consolidate generation behavior across models.
