feat: add T5 implementation #35
Conversation
Hello @SauravMaheshkar, I'm Gemini Code Assist[^1]! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!
Summary of Changes
As gemini-code-assist, I've reviewed this pull request. The primary goal of this PR is to introduce an initial implementation of the T5 model architecture. This includes adding core components like the T5 configuration, layer normalization, and the feed-forward network (MLP). Alongside the new T5 code, the PR also includes minor formatting adjustments and type hinting fixes in existing files, as well as a change to the base class order for the Gemma2 and Llama models.
Highlights
- **New T5 Model Components**: Adds `T5Config`, `T5LayerNorm`, and `T5MLP` implementations using Flax NNX in a new file `jaxgarden/models/t5.py`.
- **Model Base Class Order**: Changes the inheritance order for `Gemma2ForCausalLM` and `LlamaForCausalLM` to prioritize `GenerationMixin` over `BaseModel`.
- **Formatting and Type Hinting**: Includes minor code formatting improvements (e.g., multi-line function signatures, spacing) and adds `# type: ignore` comments for type checking workarounds in `jaxgarden/__init__.py` and `jaxgarden/models/llama.py`. Also updates devcontainer settings for code actions.
- **New Tests**: Adds a new test file (`tests/models/test_t5.py`) specifically for the new T5 LayerNorm and MLP components.
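For context on what these new components compute, here is a minimal, hedged sketch of an RMS-style T5 layer norm and a T5 feed-forward block in Flax NNX. This is not the PR's actual code: the class names match the PR, but the constructor signatures, argument names, and the choice of a plain ReLU feed-forward (rather than the gated variant used in T5 v1.1) are assumptions.

```python
# Illustrative sketch only; signatures are assumptions, not the PR's actual API.
import jax
import jax.numpy as jnp
from flax import nnx


class T5LayerNorm(nnx.Module):
    """T5-style RMS norm: scale only, no mean subtraction, no bias."""

    def __init__(self, dim: int, eps: float = 1e-6):
        self.weight = nnx.Param(jnp.ones((dim,)))
        self.eps = eps

    def __call__(self, x: jnp.ndarray) -> jnp.ndarray:
        variance = jnp.mean(jnp.square(x), axis=-1, keepdims=True)
        return self.weight.value * x / jnp.sqrt(variance + self.eps)


class T5MLP(nnx.Module):
    """T5 feed-forward block: up-projection, ReLU, down-projection, no biases."""

    def __init__(self, dim: int, hidden_dim: int, *, rngs: nnx.Rngs):
        self.wi = nnx.Linear(dim, hidden_dim, use_bias=False, rngs=rngs)
        self.wo = nnx.Linear(hidden_dim, dim, use_bias=False, rngs=rngs)

    def __call__(self, x: jnp.ndarray) -> jnp.ndarray:
        return self.wo(jax.nn.relu(self.wi(x)))


# Shape check: both modules preserve the trailing model dimension.
x = jnp.ones((2, 8, 512))
y = T5MLP(512, 2048, rngs=nnx.Rngs(0))(T5LayerNorm(512)(x))
assert y.shape == (2, 8, 512)
```

T5 deliberately omits the mean-centering and bias of standard LayerNorm, which is why a dedicated `T5LayerNorm` is implemented rather than reusing a stock LayerNorm module.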
Changelog
- `.devcontainer/devcontainer.json`
  - Updated `source.organizeImports` and `source.fixAll` settings from `"true"` to `"always"`.
- `jaxgarden/__init__.py`
  - Added `# type: ignore` comment to the `Tokenizer` import.
- `jaxgarden/attention/rope_multi_head_attention.py`
  - Reformatted the `precompute_rotary_embeddings` function signature.
  - Minor formatting and comment adjustments in the `setup` and `__call__` methods.
  - Removed an extra blank line.
  - Adjusted formatting for mask handling logic comments.
- `jaxgarden/models/gemma2.py`
  - Swapped the base class order for `Gemma2ForCausalLM` from `BaseModel, GenerationMixin` to `GenerationMixin, BaseModel`.
- `jaxgarden/models/llama.py`
  - Swapped the base class order for `LlamaForCausalLM` from `BaseModel, GenerationMixin` to `GenerationMixin, BaseModel`.
  - Added `# type: ignore` comment to the `attention_mask` line in the `generate` method.
- `jaxgarden/models/t5.py`
  - Added new file implementing `T5Config`, `T5LayerNorm`, and `T5MLP`.
- `tests/attention/test_RoPEMultiHeadAttention.py`
  - Removed an extra blank line.
  - Minor formatting adjustments in test comments.
- `tests/models/test_t5.py`
  - Added new file with tests for `T5LayerNorm` and `T5MLP`.
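For an idea of what tests for these components typically look like, here is a hedged sketch of a shape and scale check. It imports from the module added in this PR, but the constructor signatures it assumes may differ from the actual implementation.

```python
# Hypothetical test sketch; real constructor signatures in jaxgarden.models.t5 may differ.
import jax.numpy as jnp
from flax import nnx

from jaxgarden.models.t5 import T5MLP, T5LayerNorm  # module added in this PR


def test_t5_layernorm_unit_rms():
    norm = T5LayerNorm(16)  # assumed signature
    x = jnp.arange(2 * 4 * 16, dtype=jnp.float32).reshape(2, 4, 16)
    y = norm(x)
    assert y.shape == x.shape
    # With unit weights, each normalized vector should have roughly unit RMS.
    rms = jnp.sqrt(jnp.mean(jnp.square(y), axis=-1))
    assert jnp.allclose(rms, 1.0, atol=1e-3)


def test_t5_mlp_preserves_model_dim():
    mlp = T5MLP(16, 64, rngs=nnx.Rngs(0))  # assumed signature
    x = jnp.ones((2, 4, 16))
    assert mlp(x).shape == (2, 4, 16)
```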
Using Gemini Code Assist
The full guide for Gemini Code Assist can be found on our documentation page; here are some quick tips.
Invoking Gemini
You can request assistance from Gemini at any point in your pull request by creating an issue comment (i.e., commenting on the pull request page) using either /gemini <command> or @gemini-code-assist <command>. Below is a summary of the supported commands.
| Feature | Command | Description |
|---|---|---|
| Code Review | `/gemini review` | Performs a code review for the current pull request in its current state. |
| Pull Request Summary | `/gemini summary` | Provides a summary of the current pull request in its current state. |
| Comment | `@gemini-code-assist` | Responds in comments when explicitly tagged, both in issue comments and review comments. |
| Help | `/gemini help` | Displays a list of available commands. |
Customization
To customize the Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by creating and adding files to a .gemini/ folder in the base of the repository. Detailed instructions can be found here.
Limitations & Feedback
Gemini Code Assist is currently in preview and may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counterproductive. You can react with 👍 and 👎 on @gemini-code-assist comments to provide feedback.
You can also get AI-powered code generation, chat, as well as code reviews directly in the IDE at no cost with the Gemini Code Assist IDE Extension.
New model arrives,
T5 joins the garden,
Code grows, tests pass.
Footnotes
[^1]: Review the Privacy Notices, Generative AI Prohibited Use Policy, Terms of Service, and learn how to configure Gemini Code Assist in GitHub here. Gemini can make mistakes, so double check it and use code with caution.
Code Review
This pull request introduces a T5 model implementation, which is a valuable addition. The new T5 components (T5Config, T5LayerNorm, T5MLP) and their corresponding tests are well-structured. The PR also includes various refactorings and fixes across other model files, with many changes being stylistic improvements that enhance code consistency.
I've identified a few areas that require attention, primarily concerning potential runtime errors due to type handling and the implications of changes to class inheritance order. Addressing these will help ensure the robustness and clarity of the codebase.
Summary of Findings
- **Potential runtime error with `None` attention mask in Llama model**: In `jaxgarden/models/llama.py`, the `attention_mask` parameter can be `None` but is used directly in `jnp.where`. This could lead to a runtime error if `attention_mask` is indeed `None` at that point. This was commented on with high severity.
- **Class inheritance order changes**: The Method Resolution Order (MRO) for `Gemma2ForCausalLM` and `LlamaForCausalLM` has changed due to reordering of base classes. It's important to confirm this has no unintended side effects on method/attribute resolution. This was commented on with medium severity.
- **Suppressed type error in `__init__.py`**: A `# type: ignore` was added to `jaxgarden/__init__.py` for the `Tokenizer` import. Clarification on why this is necessary would be helpful to understand if there's an underlying typing issue. This was commented on with medium severity.
- **Devcontainer configuration update**: In `.devcontainer/devcontainer.json`, `editor.codeActionsOnSave` values for `source.organizeImports` and `source.fixAll` were changed from the string `"true"` to `"always"`. This is a minor preference/explicitness improvement and seems fine. (Not commented directly due to review severity settings.)
- **Stylistic cleanups**: Numerous stylistic cleanups (formatting, blank lines, comment spacing) were applied, particularly in `jaxgarden/attention/rope_multi_head_attention.py` and its tests. These improve code consistency and are generally good. (Not commented directly due to review severity settings.)
- **New T5 Implementation and Tests**: The new T5 model components (`T5Config`, `T5LayerNorm`, `T5MLP`) in `jaxgarden/models/t5.py` and the corresponding tests in `tests/models/test_t5.py` are well-implemented and a good addition to the library. (No issues found; positive feedback.)
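To illustrate the second finding, swapping the order of base classes changes Python's Method Resolution Order, so a method defined on both classes resolves differently. The example below uses hypothetical stand-in classes, not jaxgarden's actual `BaseModel`/`GenerationMixin`, whose methods may or may not overlap.

```python
# Hypothetical stand-ins; not jaxgarden's actual classes.
class BaseModel:
    def generate(self) -> str:
        return "BaseModel.generate"


class GenerationMixin:
    def generate(self) -> str:
        return "GenerationMixin.generate"


# Before this PR: BaseModel is listed first, so its methods win ties.
class CausalLMOld(BaseModel, GenerationMixin):
    pass


# After this PR: GenerationMixin is listed first, so its methods win ties.
class CausalLMNew(GenerationMixin, BaseModel):
    pass


print(CausalLMOld().generate())  # -> "BaseModel.generate"
print(CausalLMNew().generate())  # -> "GenerationMixin.generate"
print([c.__name__ for c in CausalLMNew.__mro__])
# -> ['CausalLMNew', 'GenerationMixin', 'BaseModel', 'object']
```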
Merge Readiness
The pull request introduces significant new functionality with the T5 implementation and includes several useful refactorings. However, there is a high-severity issue in jaxgarden/models/llama.py concerning the handling of a potentially None attention mask that could lead to runtime errors. Additionally, there are medium-severity questions about class inheritance order and a suppressed type error that should be clarified.
I recommend that these issues, particularly the high-severity one, be addressed before merging. I am not authorized to approve pull requests, but based on this review, changes are requested.
```diff
 print(input_ids.shape)
 position_ids = jnp.arange(input_ids.shape[-1])[None, :].astype(jnp.int32)
-attention_mask = jnp.where(attention_mask, 0.0, -jnp.inf)[None, None, ...]
+attention_mask = jnp.where(attention_mask, 0.0, -jnp.inf)[None, None, ...]  # type: ignore
```
The `attention_mask` parameter in the `__call__` method is typed as `jnp.ndarray | None`. However, this line uses `attention_mask` directly as the condition in `jnp.where`. If `attention_mask` is `None` when this line is reached, `jnp.where(None, ...)` will result in a runtime error. The added `# type: ignore` might be suppressing a type checker warning related to this or the use of `...`.

To prevent a potential runtime error, how is the `None` case for `attention_mask` intended to be handled here? If `None` implies, for example, that no tokens are padded (i.e., all tokens are valid), the mask should effectively be all `True`s before being converted to the additive float mask.

A robust way to handle this would be to ensure `attention_mask` is a boolean `jnp.ndarray` before this line. For example:

```python
# Preceding line: position_ids = ...
_current_boolean_mask: jnp.ndarray
if attention_mask is None:
    # Default to all True if None (no padding)
    _current_boolean_mask = jnp.ones_like(input_ids, dtype=jnp.bool_)
elif attention_mask.dtype != jnp.bool_:
    # Coerce to boolean if it's, for instance, int32 from GenerationMixin
    _current_boolean_mask = attention_mask.astype(jnp.bool_)
else:
    _current_boolean_mask = attention_mask

# Now use _current_boolean_mask in jnp.where.
# This line (514) would then become something like:
# attention_mask_additive = jnp.where(_current_boolean_mask, 0.0, -jnp.inf)[None, None, ...]
# And then use attention_mask_additive in the loop for layers.
```

This change would involve modifying how `attention_mask` is processed and passed to the layers. Could you clarify the expected behavior or update the handling for the `None` case?
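For readers less familiar with additive attention masks, the small, self-contained example below (not taken from jaxgarden's code) shows how a boolean padding mask becomes a 0 / -inf mask that is added to the attention scores before the softmax, so padded positions receive zero attention weight.

```python
import jax
import jax.numpy as jnp

# Boolean padding mask for one sequence: True = real token, False = padding.
bool_mask = jnp.array([[True, True, True, False]])   # (batch=1, seq=4)

# Convert to an additive mask: 0.0 where attention is allowed, -inf where not.
# jnp.where(None, ...) would raise, so a None mask needs a default first,
# e.g. jnp.ones_like(input_ids, dtype=bool).
additive_mask = jnp.where(bool_mask, 0.0, -jnp.inf)   # (1, 4)

# Toy attention scores for one head: (batch, q_len, k_len).
scores = jnp.zeros((1, 4, 4))
weights = jax.nn.softmax(scores + additive_mask[:, None, :], axis=-1)

print(weights[0, 0])  # -> approximately [0.333, 0.333, 0.333, 0.0]
```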
Hi @SauravMaheshkar, thanks for this PR! Do you think it's ready for review, or are you continuing to work on it?

Still working on this, will tag you for review once ready 😃

@monatis PR Ready for Review 😃

@monatis gentle ping

Hi @SauravMaheshkar sorry, missed the notification 😄 LGTM, I'll be merging it shortly. Thanks for the contribution!
monatis left a comment
Also thanks for additional fixes
Adds T5 Implementation (#34)
- Formatting and type-checking fixes (ruff + mypy). Affected files: `jaxgarden/__init__.py`, `jaxgarden/attention/rope_multi_head_attention.py`, `jaxgarden/models/base.py`, `tests/attention/test_RoPEMultiHeadAttention.py`, `jaxgarden/models/gemma2.py`, `jaxgarden/models/llama.py`
- `jaxgarden/tokenization.py` to handle edge cases of special tokens

References