
Fix Gemma3 compile dtype handling #462

Open

danielhanchen wants to merge 1 commit into main from fix/gemma3-compile-dtype

Conversation

@danielhanchen
Contributor

Summary

  1. Add lm_head input dtype casting in the compiled logits paths so the input matches the lm_head weight dtype.
  2. Update the Gemma3 MLP to keep its native dtype unless force float32 is set, and cast to the down_proj compute dtype when needed.
  3. Set UnslothSFTTrainer _is_vlm based on processing_class.image_processor (see the sketch after this list).
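
A minimal sketch of the VLM check in item 3, assuming the trainer's processing_class is either a plain tokenizer or a multimodal processor; the helper name _detect_is_vlm is illustrative and not the actual patch:

def _detect_is_vlm(processing_class) -> bool:
    # A processing class that bundles an image_processor (e.g. a VLM's processor)
    # counts as vision-language; a plain tokenizer does not.
    return getattr(processing_class, "image_processor", None) is not None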

Testing

  1. Ran the FunctionGemma 270M Mobile Actions notebook with transformers 4.57.6 and 5.0.0, using trl 0.27.1.

@gemini-code-assist
Contributor

Summary of Changes

Hello @danielhanchen, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request refines data type handling within the Gemma3 model compilation and training components. It ensures correct data type alignment for lm_head operations, optimizes the Gemma3 MLP's floating-point precision management, and improves the UnslothSFTTrainer's ability to detect Vision-Language Models. These changes collectively enhance the robustness and efficiency of Gemma3 model compilation and training workflows.

Highlights

  • LM Head Dtype Casting: Implemented input dtype casting for lm_head in compiled logits paths to ensure it matches the lm_head weight dtype, preventing potential type mismatches and improving numerical stability.
  • Gemma3 MLP Dtype Handling: Updated the Gemma3 MLP's forward method to preserve native data types for intermediate calculations unless UNSLOTH_FORCE_FLOAT32 is explicitly set. It now conditionally casts the intermediate tensor to the down_proj's compute dtype (or weight dtype) only when necessary, optimizing precision management (a sketch follows this list).
  • UnslothSFTTrainer VLM Detection: Modified the UnslothSFTTrainer's _is_vlm attribute to be dynamically determined based on the presence of an image_processor in the processing_class, rather than being hardcoded to False, enhancing flexibility for Vision-Language Models.
  • Flex Attention Environment Variable Removal: Removed the _UNSLOTH_FLEX_ATTENTION_DISABLED environment variable check and related logic from gemma.py, simplifying attention implementation selection and reducing reliance on environment variables.
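
A minimal sketch of the native-dtype path described in the second highlight (the forced-float32 branch is omitted), assuming the standard gate_proj/up_proj/act_fn/down_proj layout and using down_proj.weight.dtype as a stand-in for its compute dtype; this is an illustration, not the actual patch:

import torch

def gemma3_mlp_native_dtype_forward(self, x: torch.Tensor) -> torch.Tensor:
    # Keep the intermediate activations in their native dtype.
    intermediate = self.act_fn(self.gate_proj(x)) * self.up_proj(x)
    # Cast only when the intermediate dtype differs from down_proj's dtype,
    # so the common case stays cast-free.
    target_dtype = self.down_proj.weight.dtype
    if intermediate.dtype != target_dtype:
        intermediate = intermediate.to(target_dtype)
    return self.down_proj(intermediate)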


Changelog
  • unsloth_zoo/compiler.py
    • Added logic to dynamically set UnslothSFTTrainer._is_vlm based on processing_class.image_processor.
    • Introduced dtype casting for lm_head input to match lm_head.weight dtype in compiled logits paths.
  • unsloth_zoo/temporary_patches/gemma.py
    • Removed the _UNSLOTH_FLEX_ATTENTION_DISABLED environment variable and its associated conditional logic.
    • Modified Gemma3MLP.forward to handle data types more flexibly: it now maintains native dtypes unless UNSLOTH_FORCE_FLOAT32 is active, and casts to down_proj's compute dtype only when required.
Activity
  • The pull request author, danielhanchen, has implemented these changes.
  • Testing was performed by running the FunctionGemma 270M Mobile Actions notebook with transformers 4.57.6 and 5.0.0 using trl 0.27.1.


@chatgpt-codex-connector chatgpt-codex-connector bot left a comment


💡 Codex Review

Here are some automated review suggestions for this pull request.

Reviewed commit: 4b62c658ba


Comment on lines 497 to 499

attn_impl = getattr(self.config, "_attn_implementation", "sdpa")
if _UNSLOTH_FLEX_ATTENTION_DISABLED:
    attn_impl = "sdpa"
if attn_impl == "flex_attention":
    attention_interface = ALL_ATTENTION_FUNCTIONS[attn_impl]


P1: Restore flex-attention disable override

This change drops the UNSLOTH_ENABLE_FLEX_ATTENTION=0 guard, so Gemma3 will still execute the flex_attention path whenever self.config._attn_implementation is set to "flex_attention". In environments where users explicitly disable flex attention due to kernel/runtime incompatibility, this is a regression from the previous behavior and can trigger runtime failures instead of falling back to SDPA; the same removal appears in both Gemma3 attention patch variants.
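
A minimal sketch of the fallback the reviewer wants restored, using the UNSLOTH_ENABLE_FLEX_ATTENTION and _UNSLOTH_FLEX_ATTENTION_DISABLED names quoted in this thread; the helper select_attn_impl is illustrative and not taken from the repo:

import os

# UNSLOTH_ENABLE_FLEX_ATTENTION=0 means flex attention is explicitly disabled.
_UNSLOTH_FLEX_ATTENTION_DISABLED = (
    os.environ.get("UNSLOTH_ENABLE_FLEX_ATTENTION", "1") == "0"
)

def select_attn_impl(config) -> str:
    # Honor the explicit opt-out: fall back to SDPA even when the config
    # requests flex_attention.
    attn_impl = getattr(config, "_attn_implementation", "sdpa")
    if _UNSLOTH_FLEX_ATTENTION_DISABLED and attn_impl == "flex_attention":
        attn_impl = "sdpa"
    return attn_impl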

Useful? React with 👍 / 👎.

Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request introduces several fixes for Gemma3 dtype handling, which seem correct and align with the PR's objectives. I've identified a few areas for improvement. In unsloth_zoo/compiler.py, there's duplicated code for lm_head input casting that could be refactored for better maintainability. In unsloth_zoo/temporary_patches/gemma.py, I've noted a performance concern with accessing an environment variable on a hot path and suggested a cleaner way to handle a try-except block. Addressing these points will enhance the code's performance and readability.

intermediate_fp32 = activated_fp32 * up_proj_fp32  # Product in fp32
def forward(self, x):
    # If forcing float32, keep the original float32 path.
    if os.environ.get("UNSLOTH_FORCE_FLOAT32", "0") == "1":
Contributor


Severity: high

Accessing the environment variable with os.environ.get() inside the forward method is inefficient because it's on a hot path and involves repeated system calls. This check should be performed only once when patch_Gemma3MLP is called. You can store the result in a variable within the patch_Gemma3MLP function's scope, and the forward function can then access it from its closure.

For example:

def patch_Gemma3MLP():
    # ...
    _force_float32 = os.environ.get("UNSLOTH_FORCE_FLOAT32", "0") == "1"

    def forward(self, x):
        if _force_float32:
            # ...

Comment on lines +987 to +993

_lm_head_input = hidden_states\\1
if torch.is_floating_point(_lm_head_input):
    _lm_head_weight = self.lm_head.weight
    _lm_head_dtype = getattr(_lm_head_weight, "dtype", None)
    if _lm_head_dtype is not None and _lm_head_input.dtype != _lm_head_dtype:
        _lm_head_input = _lm_head_input.to(_lm_head_dtype)
logits = self.lm_head(_lm_head_input)
Contributor


Severity: medium

This block of code for lm_head input dtype casting is duplicated in three places within this file (here, at lines 1071-1077, and 1110-1116). This repetition makes the code harder to maintain and prone to inconsistencies if updates are needed. Consider defining this logic once in a shared string variable and reusing it in each template to improve maintainability.
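
A minimal sketch of the refactor suggested above, assuming the compiled-logits templates in compiler.py are plain source strings that get assembled or substituted; the constant LM_HEAD_CAST_SNIPPET and the template names are illustrative only:

# Hypothetical shared snippet; the trailing \\1 is kept verbatim since it
# appears to be a regex replacement placeholder in the existing templates.
LM_HEAD_CAST_SNIPPET = (
    "_lm_head_input = hidden_states\\1\n"
    "if torch.is_floating_point(_lm_head_input):\n"
    "    _lm_head_weight = self.lm_head.weight\n"
    "    _lm_head_dtype = getattr(_lm_head_weight, 'dtype', None)\n"
    "    if _lm_head_dtype is not None and _lm_head_input.dtype != _lm_head_dtype:\n"
    "        _lm_head_input = _lm_head_input.to(_lm_head_dtype)\n"
    "logits = self.lm_head(_lm_head_input)\n"
)

# Each of the three compiled-logits templates would then reuse the same snippet,
# e.g. template = TEMPLATE_HEADER + LM_HEAD_CAST_SNIPPET + TEMPLATE_FOOTER,
# so any future dtype fix only has to be made in one place.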

Comment on lines +331 to +334

try:
    is_float = torch.is_floating_point(torch.empty((), dtype=target_dtype))
except Exception:
    is_float = False
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Severity: medium

Using a broad except Exception: can hide unrelated errors and make debugging more difficult. Additionally, creating a temporary tensor with torch.empty is unnecessary. You can call torch.is_floating_point(target_dtype) directly and catch a more specific TypeError, which is raised for invalid dtype arguments. This will make the code cleaner and more robust.

Suggested change

- try:
-     is_float = torch.is_floating_point(torch.empty((), dtype=target_dtype))
- except Exception:
-     is_float = False
+ try:
+     is_float = torch.is_floating_point(target_dtype)
+ except TypeError:
+     is_float = False

