
Conversation

SauravMaheshkar (Contributor) commented May 18, 2025

Adds T5 Implementation (#34)

  • fixes linter/type-checker warnings (ruff + mypy). Affected Files:
    • jaxgarden/__init__.py
    • jaxgarden/attention/rope_multi_head_attention.py
    • jaxgarden/models/base.py
    • tests/attention/test_RoPEMultiHeadAttention.py
  • also fixes #28 (Fix type checking with class mixins). Affected Files:
    • jaxgarden/models/gemma2.py
    • jaxgarden/models/llama.py
  • makes minor modifications to jaxgarden/tokenization.py to handle edge cases involving special tokens

gemini-code-assist (bot) left a comment

Hello @SauravMaheshkar, I'm Gemini Code Assist¹! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

Summary of Changes

As gemini-code-assist, I've reviewed this pull request. The primary goal of this PR is to introduce an initial implementation of the T5 model architecture. This includes adding core components like the T5 configuration, layer normalization, and the feed-forward network (MLP). Alongside the new T5 code, the PR also includes minor formatting adjustments and type hinting fixes in existing files, as well as a change to the base class order for the Gemma2 and Llama models.

Highlights

  • New T5 Model Components: Adds T5Config, T5LayerNorm, and T5MLP implementations using Flax NNX in a new file, jaxgarden/models/t5.py (a sketch of these modules follows this list).
  • Model Base Class Order: Changes the inheritance order for Gemma2ForCausalLM and LlamaForCausalLM to prioritize GenerationMixin over BaseModel.
  • Formatting and Type Hinting: Includes minor code formatting improvements (e.g., multi-line function signatures, spacing) and adds type: ignore comments for type checking workarounds in jaxgarden/__init__.py and jaxgarden/models/llama.py. Also updates devcontainer settings for code actions.
  • New Tests: Adds a new test file (tests/models/test_t5.py) specifically for the new T5 LayerNorm and MLP components.
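
To make these components concrete, here is a minimal sketch of T5LayerNorm and T5MLP in Flax NNX. It assumes the standard T5 formulation (an RMS-style norm with no mean subtraction or bias, and a bias-free dense-ReLU-dense feed-forward block); class and argument names are illustrative and may not match jaxgarden/models/t5.py exactly.

import jax.numpy as jnp
from flax import nnx

class T5LayerNorm(nnx.Module):
    # T5-style layer norm: RMS rescaling with a learned weight,
    # no mean subtraction and no bias term.
    def __init__(self, dim: int, eps: float = 1e-6):
        self.weight = nnx.Param(jnp.ones(dim))
        self.eps = eps

    def __call__(self, x: jnp.ndarray) -> jnp.ndarray:
        orig_dtype = x.dtype
        # Compute the mean square in float32 for numerical stability.
        variance = jnp.mean(jnp.square(x.astype(jnp.float32)), axis=-1, keepdims=True)
        x = x / jnp.sqrt(variance + self.eps)
        return self.weight.value * x.astype(orig_dtype)

class T5MLP(nnx.Module):
    # T5's feed-forward block: dense -> ReLU -> dense, both projections bias-free.
    def __init__(self, dim: int, hidden_dim: int, *, rngs: nnx.Rngs):
        self.wi = nnx.Linear(dim, hidden_dim, use_bias=False, rngs=rngs)
        self.wo = nnx.Linear(hidden_dim, dim, use_bias=False, rngs=rngs)

    def __call__(self, x: jnp.ndarray) -> jnp.ndarray:
        return self.wo(nnx.relu(self.wi(x)))

For example, T5MLP(dim=512, hidden_dim=2048, rngs=nnx.Rngs(0)) maps a (batch, seq, 512) array back to the same shape.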

Changelog

  • .devcontainer/devcontainer.json
    • Updated source.organizeImports and source.fixAll settings from "true" to "always".
  • jaxgarden/__init__.py
    • Added # type: ignore comment to the Tokenizer import.
  • jaxgarden/attention/rope_multi_head_attention.py
    • Reformatted the precompute_rotary_embeddings function signature.
    • Minor formatting and comment adjustments in setup and __call__ methods.
    • Removed an extra blank line.
    • Adjusted formatting for mask handling logic comments.
  • jaxgarden/models/gemma2.py
    • Swapped the base class order for Gemma2ForCausalLM from BaseModel, GenerationMixin to GenerationMixin, BaseModel (see the MRO sketch after this changelog).
  • jaxgarden/models/llama.py
    • Swapped the base class order for LlamaForCausalLM from BaseModel, GenerationMixin to GenerationMixin, BaseModel.
    • Added # type: ignore comment to the attention_mask line in the generate method.
  • jaxgarden/models/t5.py
    • Added new file implementing T5Config, T5LayerNorm, and T5MLP.
  • tests/attention/test_RoPEMultiHeadAttention.py
    • Removed an extra blank line.
    • Minor formatting adjustments in test comments.
  • tests/models/test_t5.py
    • Added new file with tests for T5LayerNorm and T5MLP.
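
On the base-class swap in gemma2.py and llama.py noted above: Python resolves attributes left to right along the Method Resolution Order, so the base listed first wins whenever both bases define the same name. A standalone illustration with toy classes (the describe method is hypothetical, purely to expose the difference, and is not part of the actual jaxgarden API):

class BaseModel:
    def describe(self) -> str:
        return "BaseModel.describe"

class GenerationMixin:
    # Hypothetical overlapping method, just to show which base wins.
    def describe(self) -> str:
        return "GenerationMixin.describe"

class OldOrder(BaseModel, GenerationMixin):
    pass

class NewOrder(GenerationMixin, BaseModel):
    pass

print(OldOrder().describe())  # BaseModel.describe
print(NewOrder().describe())  # GenerationMixin.describe -- the mixin now takes precedence
print([c.__name__ for c in NewOrder.__mro__])
# ['NewOrder', 'GenerationMixin', 'BaseModel', 'object']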


New model arrives,
T5 joins the garden,
Code grows, tests pass.

Footnotes

  1. Review the Privacy Notices, Generative AI Prohibited Use Policy, Terms of Service, and learn how to configure Gemini Code Assist in GitHub here. Gemini can make mistakes, so double check it and use code with caution.

gemini-code-assist (bot) left a comment

Code Review

This pull request introduces a T5 model implementation, which is a valuable addition. The new T5 components (T5Config, T5LayerNorm, T5MLP) and their corresponding tests are well-structured. The PR also includes various refactorings and fixes across other model files, with many changes being stylistic improvements that enhance code consistency.

I've identified a few areas that require attention, primarily concerning potential runtime errors due to type handling and the implications of changes to class inheritance order. Addressing these will help ensure the robustness and clarity of the codebase.

Summary of Findings

  • Potential runtime error with None attention mask in Llama model: In jaxgarden/models/llama.py, the attention_mask parameter can be None but is used directly in jnp.where. This could lead to a runtime error if attention_mask is indeed None at that point. This was commented on with high severity.
  • Class inheritance order changes: The Method Resolution Order (MRO) for Gemma2ForCausalLM and LlamaForCausalLM has changed due to reordering of base classes. It's important to confirm this has no unintended side effects on method/attribute resolution. This was commented on with medium severity.
  • Suppressed type error in __init__.py: A # type: ignore was added to jaxgarden/__init__.py for the Tokenizer import (see the snippet after this list). Clarification on why this is necessary would help determine whether there is an underlying typing issue. This was commented on with medium severity.
  • Devcontainer configuration update: In .devcontainer/devcontainer.json, editor.codeActionsOnSave values for source.organizeImports and source.fixAll were changed from string "true" to "always". This is a minor preference/explicitness improvement and seems fine. (Not commented directly due to review severity settings).
  • Stylistic cleanups: Numerous stylistic cleanups (formatting, blank lines, comment spacing) were applied, particularly in jaxgarden/attention/rope_multi_head_attention.py and its tests. These improve code consistency and are generally good. (Not commented directly due to review severity settings).
  • New T5 Implementation and Tests: The new T5 model components (T5Config, T5LayerNorm, T5MLP) in jaxgarden/models/t5.py and the corresponding tests in tests/models/test_t5.py are well-implemented and a good addition to the library. (No issues found, positive feedback).
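
On the suppressed type error above: a bare # type: ignore silences every error on its line, while mypy also accepts a code-scoped form that keeps unrelated errors visible. A sketch of the two options (the import path follows the files listed in this PR, and the attr-defined error code is an assumption):

# Blanket suppression: hides all type errors on this line.
from jaxgarden.tokenization import Tokenizer  # type: ignore

# Narrower alternative: suppress a single mypy error code (e.g. attr-defined),
# so any new, unrelated error on this line still surfaces.
from jaxgarden.tokenization import Tokenizer  # type: ignore[attr-defined]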

Merge Readiness

The pull request introduces significant new functionality with the T5 implementation and includes several useful refactorings. However, there is a high-severity issue in jaxgarden/models/llama.py concerning the handling of a potentially None attention mask that could lead to runtime errors. Additionally, there are medium-severity questions about class inheritance order and a suppressed type error that should be clarified.

I recommend that these issues, particularly the high-severity one, be addressed before merging. I am not authorized to approve pull requests, but based on this review, changes are requested.

The review comment below is attached to this diff in jaxgarden/models/llama.py:

  print(input_ids.shape)
  position_ids = jnp.arange(input_ids.shape[-1])[None, :].astype(jnp.int32)
- attention_mask = jnp.where(attention_mask, 0.0, -jnp.inf)[None, None, ...]
+ attention_mask = jnp.where(attention_mask, 0.0, -jnp.inf)[None, None, ...]  # type: ignore

Severity: high

The attention_mask parameter in the __call__ method is typed as jnp.ndarray | None. However, this line uses attention_mask directly as the condition in jnp.where. If attention_mask is None when this line is reached, jnp.where(None, ...) will result in a runtime error. The added # type: ignore might be suppressing a type checker warning related to this or to the ... (Ellipsis) indexing.

To prevent a potential runtime error, how is the None case for attention_mask intended to be handled here? If None implies, for example, that no tokens are padded (i.e., all tokens are valid), the mask should effectively be all Trues before being converted to the additive float mask.

A robust way to handle this would be to ensure attention_mask is a boolean jnp.ndarray before this line. For example:

# Preceding line: position_ids = ...

_current_boolean_mask: jnp.ndarray
if attention_mask is None:
    # Default to all True if None (no padding)
    _current_boolean_mask = jnp.ones_like(input_ids, dtype=jnp.bool_)
elif attention_mask.dtype != jnp.bool_:
    # Coerce to boolean if it's, for instance, int32 from GenerationMixin
    _current_boolean_mask = attention_mask.astype(jnp.bool_)
else:
    _current_boolean_mask = attention_mask

# Now use _current_boolean_mask in jnp.where
# This line (514) would then become something like:
# attention_mask_additive = jnp.where(_current_boolean_mask, 0.0, -jnp.inf)[None, None, ...]
# And then use attention_mask_additive in the loop for layers.

This change would involve modifying how attention_mask is processed and passed to the layers. Could you clarify the expected behavior or update the handling for the None case?
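
For reference, a small self-contained demonstration of the boolean-to-additive conversion discussed above (illustrative values, not jaxgarden code):

import jax.numpy as jnp

# Boolean padding mask for one sequence of length 4; the last position is padding.
attention_mask = jnp.array([[True, True, True, False]])

# Additive form: 0.0 where attention is allowed, -inf where it is masked out.
additive_mask = jnp.where(attention_mask, 0.0, -jnp.inf)[None, None, ...]
print(additive_mask.shape)  # (1, 1, 1, 4)

# jnp.where(None, 0.0, -jnp.inf) raises a TypeError, which is why the None case
# must be resolved to a boolean array before this conversion.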

monatis (Member) commented May 19, 2025

Hi @SauravMaheshkar, thanks for this PR! Do you think it's ready for review, or are you continuing to work on it?

SauravMaheshkar (Contributor, Author) replied:

> Hi @SauravMaheshkar, thanks for this PR! Do you think it's ready for review, or are you continuing to work on it?

Still working on this, will tag you for review once ready 😃

SauravMaheshkar marked this pull request as ready for review May 23, 2025 01:09
SauravMaheshkar (Contributor, Author) commented:

@monatis PR Ready for Review 😃

SauravMaheshkar (Contributor, Author) commented:

@monatis gentle ping

monatis (Member) commented Jun 2, 2025

Hi @SauravMaheshkar, sorry, I missed the notification 😄 LGTM, I'll be merging it shortly. Thanks for the contribution!

monatis (Member) left a review comment

Also, thanks for the additional fixes!

monatis merged commit afb228d into ml-gde:main Jun 2, 2025
6 checks passed
SauravMaheshkar mentioned this pull request Jun 2, 2025
SauravMaheshkar deleted the saurav/add-t5 branch June 2, 2025 10:36

Development

Successfully merging this pull request may close these issues: Fix type checking with class mixins (#28).