Conversation
Summary of Changes

Hello @danielhanchen, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed. This pull request refines data type handling within the Gemma3 model compilation and training components to ensure correct data type alignment.
Activity
💡 Codex Review
Here are some automated review suggestions for this pull request.
Reviewed commit: 4b62c658ba
```python
attn_impl = getattr(self.config, "_attn_implementation", "sdpa")
if _UNSLOTH_FLEX_ATTENTION_DISABLED:
    attn_impl = "sdpa"
if attn_impl == "flex_attention":
    attention_interface = ALL_ATTENTION_FUNCTIONS[attn_impl]
```
Restore flex-attention disable override
This change drops the UNSLOTH_ENABLE_FLEX_ATTENTION=0 guard, so Gemma3 now takes the flex_attention path whenever self.config._attn_implementation is set to "flex_attention". In environments where users explicitly disable flex attention because of kernel/runtime incompatibility, this is a regression from the previous behavior and can trigger runtime failures instead of falling back to SDPA; the same removal appears in both Gemma3 attention patch variants.
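For context, here is a minimal sketch of the kind of guard being described, assuming the flag is derived once at import time from UNSLOTH_ENABLE_FLEX_ATTENTION; the helper name `_select_attn_impl` and the default value of `"1"` are illustrative, not the project's actual code:

```python
import os

# Read the opt-out once at import time so the per-forward check is a cheap
# boolean test. Assumes UNSLOTH_ENABLE_FLEX_ATTENTION=0 means "never use
# flex attention" (illustrative default, not necessarily the real one).
_UNSLOTH_FLEX_ATTENTION_DISABLED = (
    os.environ.get("UNSLOTH_ENABLE_FLEX_ATTENTION", "1") == "0"
)

def _select_attn_impl(config):
    # Hypothetical helper: prefer the configured implementation, but fall
    # back to SDPA whenever flex attention has been explicitly disabled.
    attn_impl = getattr(config, "_attn_implementation", "sdpa")
    if _UNSLOTH_FLEX_ATTENTION_DISABLED and attn_impl == "flex_attention":
        attn_impl = "sdpa"
    return attn_impl
```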
Code Review
This pull request introduces several fixes for Gemma3 dtype handling, which seem correct and align with the PR's objectives. I've identified a few areas for improvement. In unsloth_zoo/compiler.py, there's duplicated code for lm_head input casting that could be refactored for better maintainability. In unsloth_zoo/temporary_patches/gemma.py, I've noted a performance concern with accessing an environment variable on a hot path and suggested a cleaner way to handle a try-except block. Addressing these points will enhance the code's performance and readability.
```python
intermediate_fp32 = activated_fp32 * up_proj_fp32  # Product in fp32

def forward(self, x):
    # If forcing float32, keep the original float32 path.
    if os.environ.get("UNSLOTH_FORCE_FLOAT32", "0") == "1":
```
Accessing the environment variable with os.environ.get() inside the forward method is inefficient: forward is on a hot path, and the lookup and string comparison are repeated on every call. This check should be performed only once, when patch_Gemma3MLP is called. Store the result in a variable within patch_Gemma3MLP's scope, and let the forward function read it from its closure.
For example:

```python
def patch_Gemma3MLP():
    # ...
    _force_float32 = os.environ.get("UNSLOTH_FORCE_FLOAT32", "0") == "1"

    def forward(self, x):
        if _force_float32:
            # ...
```

```
_lm_head_input = hidden_states\\1
if torch.is_floating_point(_lm_head_input):
    _lm_head_weight = self.lm_head.weight
    _lm_head_dtype = getattr(_lm_head_weight, "dtype", None)
    if _lm_head_dtype is not None and _lm_head_input.dtype != _lm_head_dtype:
        _lm_head_input = _lm_head_input.to(_lm_head_dtype)
logits = self.lm_head(_lm_head_input)
```
This block of code for lm_head input dtype casting is duplicated in three places within this file (here, at lines 1071-1077, and 1110-1116). This repetition makes the code harder to maintain and prone to inconsistencies if updates are needed. Consider defining this logic once in a shared string variable and reusing it in each template to improve maintainability.
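One possible shape for that refactor, sketched under the assumption that the three templates only differ in the expression appended to hidden_states; the names `_LM_HEAD_CAST_SNIPPET` and `build_lm_head_cast` are illustrative:

```python
# Define the lm_head casting logic once and splice it into each forward
# template instead of repeating the block three times.
_LM_HEAD_CAST_SNIPPET = """\
_lm_head_input = hidden_states{slice_expr}
if torch.is_floating_point(_lm_head_input):
    _lm_head_weight = self.lm_head.weight
    _lm_head_dtype = getattr(_lm_head_weight, "dtype", None)
    if _lm_head_dtype is not None and _lm_head_input.dtype != _lm_head_dtype:
        _lm_head_input = _lm_head_input.to(_lm_head_dtype)
logits = self.lm_head(_lm_head_input)
"""

def build_lm_head_cast(slice_expr: str = "") -> str:
    # Hypothetical helper: each template passes its own slice or
    # backreference expression rather than duplicating the whole block.
    return _LM_HEAD_CAST_SNIPPET.format(slice_expr=slice_expr)
```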
```python
try:
    is_float = torch.is_floating_point(torch.empty((), dtype=target_dtype))
except Exception:
    is_float = False
```
Using a broad except Exception: can hide unrelated errors and make debugging more difficult. Additionally, creating a temporary tensor with torch.empty is unnecessary. You can call torch.is_floating_point(target_dtype) directly and catch a more specific TypeError, which is raised for invalid dtype arguments. This will make the code cleaner and more robust.
Suggested change:

```diff
 try:
-    is_float = torch.is_floating_point(torch.empty((), dtype=target_dtype))
-except Exception:
+    is_float = torch.is_floating_point(target_dtype)
+except TypeError:
     is_float = False
```
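A related alternative sketch: torch.dtype exposes an is_floating_point property, which sidesteps both the temporary tensor and the question of which argument types torch.is_floating_point accepts (it expects a Tensor). The helper name `_dtype_is_floating` is illustrative:

```python
import torch

def _dtype_is_floating(target_dtype) -> bool:
    # torch.dtype carries .is_floating_point directly; AttributeError covers
    # the case where target_dtype is not actually a torch.dtype.
    try:
        return bool(target_dtype.is_floating_point)
    except AttributeError:
        return False
```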