Add Gemma 3n model to KerasHub #2606
Conversation
Summary of Changes
Hello @laxmareddyp, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed. This pull request finalizes the integration of the Gemma 3n model into KerasHub, building upon previous foundational work. It introduces crucial architectural elements for multimodal processing, ensures numerical accuracy against reference implementations, and streamlines the overall code for production readiness. The changes enable the model to effectively handle text, image, and audio inputs with optimized performance.
Activity
Code Review
This pull request introduces a comprehensive implementation of the Gemma 3n model, including its multimodal capabilities for text, image, and audio. While no specific security vulnerabilities were identified, there are critical issues with backend-agnosticism in the preprocessing layers. Specifically, Gemma3nAudioConverter and Gemma3nCausalLMPreprocessor rely on TensorFlow-specific operations (tf.signal, tf.py_function, tf.strings, tf.RaggedTensor) rather than keras.ops, which is required for compatibility with JAX and PyTorch backends according to the KerasHub style guide. These layers must be refactored to use keras.ops to ensure backend-agnosticism. Addressing these issues will make this an excellent contribution.
…he mobilenetv5 presets
sachinprasadhs
left a comment
Thank you!
I have reviewed a few files and made comments; please check.
conf_num_attention_heads,
conf_attention_context_left,
conf_attention_context_right,

conf_num_attention_heads --> num_attention_heads
conf_attention_context_left --> num_attention_context_left
conf_attention_context_right --> num_attention_context_right
Args:
    hidden_size: int. The size of the hidden state.
    conf_num_attention_heads: int. The number of attention heads.

conf_num_attention_heads --> num_attention_heads

Make the same changes in the arg and in the description.
conf_num_attention_heads: int. The number of attention heads.
conf_attention_chunk_size: int. The size of each processing chunk.
conf_attention_context_right: int. The number of steps to attend to in
    the future.
conf_attention_context_left: int. The number of steps to attend to in
    the past, including the current step.
conf_attention_logit_cap: float. The soft cap value to apply to the

conf_num_attention_heads --> num_attention_heads
conf_attention_chunk_size --> attention_chunk_size
conf_attention_context_right --> num_attention_context_right
conf_attention_context_left --> num_attention_context_left
conf_attention_logit_cap --> attention_logit_cap
import keras
import numpy as np

Add from keras import ops and then use ops.xxx instead of keras.ops.xxx everywhere.
    np.arange(num_timescales, dtype="float32")
    * -log_timescale_increment
)
self.inv_timescales = keras.ops.expand_dims(

After importing ops, use ops.expand_dims; follow this pattern for all the ops.
self._allow_non_tensor_positional_args = True
self.built = True

def _create_fb_matrix(

Avoid abbreviations; name it something like _create_filterbank_matrix.
    return batch_outputs_features, None
return batch_outputs_features, batch_outputs_masks

def call(

Add call argument details.
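A hedged sketch of what the requested call-argument documentation could look like in KerasHub docstring style. The argument names, shapes, and return value here are assumptions for illustration, not the PR's actual signature:

```python
def call(self, audio_features, audio_mask=None):
    """Encode a batch of audio features.

    Args:
        audio_features: A tensor of shape
            `(batch_size, num_frames, feature_size)`. Hypothetical name and
            shape, used only to illustrate the documentation style.
        audio_mask: Optional boolean tensor of shape
            `(batch_size, num_frames)` marking valid frames. Hypothetical.

    Returns:
        A tuple of `(features, mask)` for the encoded audio.
    """
    raise NotImplementedError
```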
@keras_hub_export("keras_hub.layers.Gemma3nAudioConverter")
class Gemma3nAudioConverter(keras.layers.Layer):
    """Converts raw audio waveforms into log-mel spectrograms.

Add an example usage section.
| "mel_floor": self.mel_floor_arg, | ||
| "per_bin_mean": self.per_bin_mean_arg, | ||
| "per_bin_stddev": self.per_bin_stddev_arg, |
There was a problem hiding this comment.
Keep the names constant and avoid suffix like _arg
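A minimal sketch of the suggested pattern: store constructor arguments under their public names so the get_config() keys round-trip through the constructor without an _arg suffix. The class and defaults here are illustrative stand-ins, not the PR's actual layer:

```python
class AudioConverterSketch:
    # Hypothetical sketch: attribute names match constructor argument names,
    # so get_config() keys can be passed straight back to __init__.
    def __init__(self, mel_floor=1e-5, per_bin_mean=None, per_bin_stddev=None):
        self.mel_floor = mel_floor
        self.per_bin_mean = per_bin_mean
        self.per_bin_stddev = per_bin_stddev

    def get_config(self):
        return {
            "mel_floor": self.mel_floor,
            "per_bin_mean": self.per_bin_mean,
            "per_bin_stddev": self.per_bin_stddev,
        }


# Round trip: config keys match constructor argument names exactly.
layer = AudioConverterSketch(mel_floor=1e-4)
restored = AudioConverterSketch(**layer.get_config())
```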
@@ -0,0 +1,580 @@
import keras

Also add from keras import ops and use ops.xxx in this file.
sachinprasadhs
left a comment
A few more file reviews.
conf_residual_weight: float. The weight for the residual connection in
    the feed-forward layers.
conf_num_attention_heads: int. The number of attention heads.
conf_attention_chunk_size: int. The size of chunks for local attention.
conf_attention_context_right: int. The right context size for local
    attention.
conf_attention_context_left: int. The left context size for local
    attention.
conf_attention_logit_cap: float. The maximum value for the attention
    logits.
conf_conv_kernel_size: int. The kernel size for the 1D convolution
    layer.

Remove the conf_ prefix; it is not a KerasHub standard.
def compute_output_shape(self, input_shape):
    audio_encodings_shape, _ = input_shape
    return audio_encodings_shape

This is not consistent with the build shape check; it handles only one input type.
current_f_for_block_input = input_feat_size
self.calculated_block_padding = []
self.calculated_f_out_dims = []
for i in range(2):

Why is this hardcoded? Is it always assumed to be 2 for all configs, or should it be the length of sscp_conv_kernel_size?
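The fix the comment suggests can be sketched as deriving the loop bound from the config instead of hardcoding range(2). The config value and the padding formula below are illustrative assumptions, not the PR's actual logic:

```python
# Hypothetical config value: one (kernel_h, kernel_w) pair per conv block.
sscp_conv_kernel_size = [(3, 3), (3, 3)]

calculated_block_padding = []
for kernel_h, kernel_w in sscp_conv_kernel_size:  # instead of range(2)
    # Illustrative manual 'same'-style padding: (left, right, top, bottom).
    calculated_block_padding.append((0, kernel_w - 1, 0, kernel_h - 1))
```

This way a config with a different number of conv blocks produces a matching number of padding entries automatically.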
def build(self, input_shape):
    _, t_in, f_in = input_shape
    conv0_input_shape = (None, 1, t_in, f_in)
    self.conv_0.build(conv0_input_shape)
    if t_in is not None:
        pad_t_top_0, pad_t_bottom_0 = self.calculated_block_padding[0][2:4]
        kernel_h_0, _ = self.sscp_conv_kernel_size[0]
        stride_h_0, _ = self.sscp_conv_stride_size[0]
        t_padded_0 = t_in + pad_t_top_0 + pad_t_bottom_0
        t_out_0 = (t_padded_0 - kernel_h_0) // stride_h_0 + 1
    else:
        t_out_0 = None
    c_out_0 = self.sscp_conv_channel_size[0]
    f_out_0 = self.calculated_f_out_dims[0]
    conv1_input_shape = (None, c_out_0, t_out_0, f_out_0)
    self.conv_1.build(conv1_input_shape)
    if t_out_0 is not None:
        t_padded_1 = (
            t_out_0
            + self.calculated_block_padding[1][2]
            + self.calculated_block_padding[1][3]
        )
        kernel_h_1, _ = self.sscp_conv_kernel_size[1]
        stride_h_1, _ = self.sscp_conv_stride_size[1]
        t_out_1 = (t_padded_1 - kernel_h_1) // stride_h_1 + 1
    else:
        t_out_1 = None
    c_out_1 = self.sscp_conv_channel_size[1]
    f_out_1 = self.calculated_f_out_dims[1]
    proj_input_shape = (None, t_out_1, f_out_1 * c_out_1)
    self.input_proj_linear.build(proj_input_shape)
    super().build(input_shape)
def compute_output_shape(self, input_shape):
    b, t_in, f_in = input_shape
    if t_in is not None:
        _, _, pad_t_top_0, pad_t_bottom_0 = self.calculated_block_padding[0]
        kernel_h_0, _ = self.sscp_conv_kernel_size[0]
        stride_h_0, _ = self.sscp_conv_stride_size[0]
        t_padded_0 = t_in + pad_t_top_0 + pad_t_bottom_0
        t_out_0 = (t_padded_0 - kernel_h_0) // stride_h_0 + 1
        _, _, pad_t_top_1, pad_t_bottom_1 = self.calculated_block_padding[1]
        kernel_h_1, _ = self.sscp_conv_kernel_size[1]
        stride_h_1, _ = self.sscp_conv_stride_size[1]
        t_padded_1 = t_out_0 + pad_t_top_1 + pad_t_bottom_1
        t_out_1 = (t_padded_1 - kernel_h_1) // stride_h_1 + 1

This is also doing direct indexing of [0] and [1] based on the assumption that sscp_conv_channel_size has only 2 elements.
It's better to add a validation check that sscp_conv_channel_size does not exceed 2 elements, and to document why it is this way.
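The validation check the comment asks for could look something like the helper below. The function name, argument names, and error message are assumptions for illustration; only the two-block constraint comes from the review comment:

```python
def validate_sscp_config(sscp_conv_channel_size, sscp_conv_kernel_size):
    """Hypothetical validation helper for the two-block assumption.

    The shape math in build() and compute_output_shape() indexes blocks
    [0] and [1] directly, so the config must describe exactly two conv
    blocks.
    """
    if len(sscp_conv_channel_size) != 2 or len(sscp_conv_kernel_size) != 2:
        raise ValueError(
            "The audio sub-sample conv projection assumes exactly two conv "
            f"blocks; got {len(sscp_conv_channel_size)} channel sizes and "
            f"{len(sscp_conv_kernel_size)} kernel sizes."
        )
```

Called once in __init__, this fails fast with a clear message instead of producing a confusing IndexError or silently wrong shapes later.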
time_stride_product = 1
for stride_pair in self.sscp_conv_stride_size:
    time_stride_product *= stride_pair[0]

This is not used anywhere in the code.
max_position_embeddings: int. The maximum sequence length.
vocab_size_per_layer_input: int. The vocab size for per-layer inputs.
hidden_size_per_layer_input: int. The hidden size for per-layer inputs.
altup_num_inputs: int. The number of inputs for the AltUp mechanism.

Spell it out as Alternating Updates (AltUp) for better clarity.
dtype=input_ids_spec.dtype
if hasattr(input_ids_spec.dtype, "name")
else "float32",
)
num_layers = self.language_model.num_hidden_layers
per_layer_hidden_size = self.language_model.hidden_size_per_layer_input
per_layer_inputs_spec = keras.KerasTensor(
    shape=(batch_size, seq_len, num_layers, per_layer_hidden_size),
    dtype=input_ids_spec.dtype

Doesn't this have to be the model compute dtype, not the token id dtype?
    inputs_embeds,
)
if self.audio_encoder and self.embed_audio:
    audio_mask = input_ids >= self.embed_audio.vocab_offset

This is not consistent with the vision upper-bound logic.
input_data = {
    "token_ids": np.random.randint(0, 50, size=(1, 16), dtype="int32"),
    "attention_mask": np.ones((1, 1, 16, 16), dtype=bool),
    "pixel_values": np.random.rand(1, 1, 224, 224, 3).astype("float32"),

Shouldn't this be "images" instead of "pixel_values" for the model input, as per the implementation?
audio_indices = inputs.get("audio_indices", None)
vision_mask = inputs.get("vision_mask", None)
audio_mask = inputs.get("audio_mask", None)
audios = inputs.get("audios", None)

What is the use of this? It is not being used.
sachinprasadhs
left a comment
Reviewed the rest of the files. Please address the comments and mark the resolved ones as Resolved.
For generic comments like naming conventions and import style, apply them to all the files.
    input_features is not None
    and len(keras.ops.shape(input_features)) == 2
):
    input_features = keras.ops.expand_dims(input_features, axis=0)
if (
    input_features_mask is not None
    and len(keras.ops.shape(input_features_mask)) == 1
):
    input_features_mask = keras.ops.expand_dims(
        input_features_mask, axis=0

This logic contradicts the docstring: per the docstring, an unbatched input feature should be (num_audios, audio_seq_len, feature_size), but here you are only checking for rank 2. The same applies to input_features_mask.
if len(audios.shape) > 1:
    audios = tf.RaggedTensor.from_tensor(audios)
else:
    audios = tf.ragged.constant([audios.numpy()], dtype=tf.float32)

I suspect .numpy() would fail in graph mode, and currently this would not be caught by any of the tests.
):
    # If a 4D attention mask is passed,
    # squeeze it to 2D for standard processing.
    if padding_mask is not None and len(keras.ops.shape(padding_mask)) == 4:

Use the static rank, len(padding_mask.shape), for a reliable result and to avoid failure in graph mode.
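The distinction the comment draws can be sketched with numpy standing in for backend tensors: a tensor's static rank (the length of its declared shape) is known at trace time, which is why the reviewer prefers len(padding_mask.shape) over inspecting the dynamic shape op under graph mode. The mask shape below is an assumed example:

```python
import numpy as np

# Hypothetical 4D mask of shape (batch, 1, query_len, key_len).
padding_mask = np.ones((2, 1, 16, 16), dtype=bool)

# Static rank check, per the review comment: .shape is a plain tuple whose
# length is a Python int, available even before any values are computed.
rank = len(padding_mask.shape)
is_4d_mask = rank == 4
```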
decoder_mask = merge_padding_and_attention_mask(
    inputs=x, padding_mask=padding_mask, attention_mask=None
)

The padding_mask passed here is 4D, but merge_padding_and_attention_mask documents it as 2D; please check.
reshape_shape = modalities_shape[:-1] + (
    self.altup_num_inputs,
    self.altup_num_inputs,
)

It's better to use ops.concatenate here.
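A sketch of the suggested pattern, with numpy standing in for keras.ops (whose concatenate mirrors the numpy signature): build the target shape by concatenating the leading dims with the two AltUp dims instead of tuple arithmetic. The variable names follow the snippet above; the values are assumptions:

```python
import numpy as np

# Hypothetical values: a (batch, seq, altup^2) tensor to split into
# (batch, seq, altup, altup).
altup_num_inputs = 4
x = np.zeros((2, 16, altup_num_inputs * altup_num_inputs), dtype="float32")

# Concatenate the leading dims with the two new trailing dims,
# as the review comment suggests (np.concatenate standing in
# for ops.concatenate).
shape = np.array(x.shape)
reshape_shape = np.concatenate(
    [shape[:-1], [altup_num_inputs, altup_num_inputs]]
)
y = x.reshape(reshape_shape)
```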
MODEL_CONFIGS = {"mobilenetv5_300m_enc": mobilenetv5_config}


def convert_model(hf_config, dtype=None):

Is it possible to move the weight and config conversion under utils/transformers and keep the validation and other code here?
Description of the change
This PR completes the implementation of the Gemma 3n model, building upon the foundations laid in #2404.
It introduces critical architectural features, ensures numerical accuracy against the reference implementation, and streamlines the codebase for production readiness.
Note: Special thanks to @harshaljanjani for the initial work and foundations laid in #2404.
Key Changes & Improvements:
KV Sharing Implementation:
Causal Masking:
Numerical Parity Fixes:
Code Refactoring & Cleanup:
Relationship to Previous Work:
Reference
Colab Notebook
Numerical Verification Results:
Text-only Validation
Multimodal Validation (text + image + audio):
Text (14 positions): mean=0.00126, max=0.01021
Vision (257 positions): mean=0.00072, max=0.00865
Audio (189 positions): mean=0.00094, max=0.02665
Note on tolerance
Gemma3n has a uniquely deep architecture — 30 decoder layers with AltUp (4-way prediction/correction), Laurel blocks, and per-layer input gating.
Cross-framework float32 rounding differences (JAX/XLA vs PyTorch) accumulate ~5.6e-06 per layer, compounding to ~4.5e-04 at the logit level.
Layer-by-layer debugging confirmed that input embeddings match perfectly (0.00 diff) and error grows linearly through the decoder stack — there is no implementation bug.
At atol=1e-3, 99.7% of logit elements match; at atol=1e-4, approximately 70% match.
The 100% token prediction match at every position confirms the conversion is functionally correct.
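The arithmetic behind the tolerance note, as a back-of-the-envelope check (illustrative only). The linear sum over the decoder stack lands in the same order of magnitude as the observed logit-level difference; attributing the remaining gap to scaling through the final projection is an assumption consistent with the note above:

```python
# Per-layer cross-framework float32 discrepancy from the note above.
per_layer_diff = 5.6e-6
num_layers = 30

# Linear accumulation through the decoder stack: ~1.7e-4, the same order
# of magnitude as the observed ~4.5e-04 at the logit level.
accumulated = per_layer_diff * num_layers
```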
Note on Parameter Count Mismatch:
At atol=1e-3:


At atol=1e-4:
Checklist