-
Notifications
You must be signed in to change notification settings - Fork 330
Add Moondream architecture skeleton #2553
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Open
BharathC0
wants to merge
2
commits into
keras-team:master
Choose a base branch
from
BharathC0:moondream-architecture
base: master
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Open
Changes from all commits
Commits
Show all changes
2 commits
Select commit
Hold shift + click to select a range
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Some comments aren't visible on the classic Files Changed page.
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,4 @@ | ||
| from keras_hub.src.models.moondream.moondream_backbone import MoondreamBackbone | ||
| from keras_hub.src.models.moondream.moondream_preprocessor import ( | ||
| MoondreamPreprocessor, | ||
| ) | ||
BharathC0 marked this conversation as resolved.
Show resolved
Hide resolved
BharathC0 marked this conversation as resolved.
Show resolved
Hide resolved
BharathC0 marked this conversation as resolved.
Show resolved
Hide resolved
|
||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,135 @@ | ||
| import keras | ||
| from keras import ops | ||
|
|
||
| from keras_hub.src.api_export import keras_hub_export | ||
| from keras_hub.src.models.backbone import Backbone | ||
|
|
||
|
|
||
@keras_hub_export("keras_hub.models.MoondreamBackbone")
class MoondreamBackbone(Backbone):
    """The Moondream backbone model.

    This model connects a vision encoder (SigLIP) and a text decoder
    (Phi-1.5) using a projection layer. Image patch features are projected
    into the text embedding space, prepended to the token embeddings, and
    the combined sequence is run through the text decoder.

    Args:
        vision_encoder: A Keras model (e.g., SigLIP) mapping a batch of
            images to patch features of shape
            `(batch, num_patches, feature_dim)`.
        text_decoder: A Keras model (e.g., Phi-1.5). It must expose a
            `get_input_embeddings` callable mapping token ids to embeddings,
            and accept `decoder_inputs_embeds` and `padding_mask` keyword
            arguments when called.
        projection_dim: int. The dimension to project image features into.
            Should match the text decoder's embedding width. Defaults to
            `2048`.
        **kwargs: Standard Keras keyword arguments, forwarded to `Backbone`.

    Example:
    ```python
    import keras
    import numpy as np

    # Mock vision encoder: (378, 378, 3) -> (729, 1152).
    image_input = keras.Input(shape=(378, 378, 3))
    vision_output = keras.layers.Lambda(
        lambda x: keras.ops.ones((keras.ops.shape(x)[0], 729, 1152))
    )(image_input)
    vision_encoder = keras.Model(image_input, vision_output)

    # Mock text decoder: (seq,) -> (seq, 2048).
    text_input = keras.Input(shape=(None,), dtype="int32")
    text_output = keras.layers.Lambda(
        lambda x: keras.ops.ones(
            (keras.ops.shape(x)[0], keras.ops.shape(x)[1], 2048)
        )
    )(text_input)
    text_decoder = keras.Model(text_input, text_output)
    text_decoder.get_input_embeddings = lambda x: keras.layers.Embedding(
        50000, 2048
    )(x)

    backbone = MoondreamBackbone(
        vision_encoder=vision_encoder,
        text_decoder=text_decoder,
        projection_dim=2048,
    )
    outputs = backbone(
        {
            "images": np.random.rand(2, 378, 378, 3),
            "token_ids": np.random.randint(0, 50000, (2, 10)),
            "padding_mask": np.ones((2, 10)),
        }
    )
    ```
    """

    def __init__(
        self, vision_encoder, text_decoder, projection_dim=2048, **kwargs
    ):
        # === Layers ===
        # Sub-models are assigned before the functional `super().__init__`
        # so Keras tracks them, matching the keras_hub Backbone pattern.
        self.vision_encoder = vision_encoder
        self.text_decoder = text_decoder
        self.vision_projection = keras.layers.Dense(
            projection_dim, name="vision_projection"
        )

        # === Functional Model ===
        images = keras.Input(shape=(None, None, 3), name="images")
        token_ids = keras.Input(shape=(None,), dtype="int32", name="token_ids")
        padding_mask = keras.Input(
            shape=(None,), dtype="int32", name="padding_mask"
        )
        inputs = {
            "images": images,
            "token_ids": token_ids,
            "padding_mask": padding_mask,
        }

        # Project image patch features into the text embedding space.
        image_features = self.vision_encoder(images)
        projected_images = self.vision_projection(image_features)

        # Embed the prompt tokens and prepend the image "tokens".
        text_embeddings = self.text_decoder.get_input_embeddings(token_ids)
        combined_embeddings = ops.concatenate(
            [projected_images, text_embeddings], axis=1
        )

        # Image positions are always attended to, so extend the padding
        # mask with ones for every image patch.
        batch_size = ops.shape(images)[0]
        num_patches = ops.shape(projected_images)[1]
        image_mask = ops.ones((batch_size, num_patches), dtype="int32")
        combined_mask = ops.concatenate([image_mask, padding_mask], axis=1)

        # NOTE(review): assumes the decoder accepts `inputs=None` together
        # with `decoder_inputs_embeds` — confirm against the Phi decoder's
        # call signature.
        outputs = self.text_decoder(
            inputs=None,
            decoder_inputs_embeds=combined_embeddings,
            padding_mask=combined_mask,
        )

        # A single functional initialization. Calling `super().__init__`
        # twice (once plain, once with inputs/outputs) would consume
        # `**kwargs` twice and leave a half-initialized model state.
        super().__init__(inputs=inputs, outputs=outputs, **kwargs)

        # === Config ===
        self.projection_dim = projection_dim

    def get_config(self):
        """Return the serializable config, including nested sub-models."""
        config = super().get_config()
        config.update(
            {
                "vision_encoder": keras.saving.serialize_keras_object(
                    self.vision_encoder
                ),
                "text_decoder": keras.saving.serialize_keras_object(
                    self.text_decoder
                ),
                "projection_dim": self.projection_dim,
            }
        )
        return config
BharathC0 marked this conversation as resolved.
Show resolved
Hide resolved
|
||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,86 @@ | ||
| from keras_hub.src.api_export import keras_hub_export | ||
| from keras_hub.src.models.causal_lm import CausalLM | ||
| from keras_hub.src.models.moondream.moondream_backbone import MoondreamBackbone | ||
| from keras_hub.src.models.moondream.moondream_preprocessor import ( | ||
| MoondreamPreprocessor, | ||
| ) | ||
|
|
||
|
|
||
@keras_hub_export("keras_hub.models.MoondreamCausalLM")
class MoondreamCausalLM(CausalLM):
    """An end-to-end Moondream model for causal language modeling.

    This model wraps `MoondreamBackbone` and handles the complete flow from
    raw inputs (images + text) to generated text output. It provides a
    high-level interface for image captioning and visual question answering.

    Args:
        backbone: A `MoondreamBackbone` instance connecting the vision
            encoder and text decoder.
        preprocessor: An optional `MoondreamPreprocessor` instance that
            handles tokenization and image resizing. If `None`, inputs are
            expected to be preprocessed already.
        **kwargs: Standard Keras keyword arguments, forwarded to `CausalLM`.

    Example:
    ```python
    import keras
    import numpy as np

    # Mock backbone with the expected input signature.
    images = keras.Input(shape=(None, None, 3), name="images")
    token_ids = keras.Input(shape=(None,), dtype="int32", name="token_ids")
    padding_mask = keras.Input(
        shape=(None,), dtype="int32", name="padding_mask"
    )
    outputs = keras.layers.Dense(2048)(token_ids)
    backbone = keras.Model(
        inputs={
            "images": images,
            "token_ids": token_ids,
            "padding_mask": padding_mask,
        },
        outputs=outputs,
    )

    model = MoondreamCausalLM(backbone=backbone)
    outputs = model(
        {
            "images": np.random.rand(2, 378, 378, 3),
            "token_ids": np.random.randint(0, 100, (2, 10)),
            "padding_mask": np.ones((2, 10)),
        }
    )
    ```
    """

    backbone_cls = MoondreamBackbone
    preprocessor_cls = MoondreamPreprocessor

    def __init__(
        self,
        backbone,
        preprocessor=None,
        **kwargs,
    ):
        # === Layers ===
        # Assign sub-layers before the functional `super().__init__` so
        # they are tracked, per the keras_hub Task convention.
        self.backbone = backbone
        self.preprocessor = preprocessor

        # === Functional Model ===
        inputs = backbone.input
        outputs = backbone(inputs)
        super().__init__(
            inputs=inputs,
            outputs=outputs,
            **kwargs,
        )
136 changes: 136 additions & 0 deletions
136
keras_hub/src/models/moondream/moondream_preprocessor.py
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,136 @@ | ||
| import keras | ||
|
|
||
| from keras_hub.src.api_export import keras_hub_export | ||
| from keras_hub.src.models.causal_lm_preprocessor import CausalLMPreprocessor | ||
|
|
||
|
|
||
@keras_hub_export("keras_hub.models.MoondreamPreprocessor")
class MoondreamPreprocessor(CausalLMPreprocessor):
    """Moondream causal LM preprocessor.

    This class handles the preprocessing of images and text for the
    Moondream model. It combines image resizing/rescaling logic with text
    tokenization to prepare inputs for the model. Inputs may be a plain
    text batch, or a dict with `"text"` and optional `"images"` keys.

    Args:
        tokenizer: The tokenizer to be used for text inputs.
        image_converter: An optional layer or callable for image
            preprocessing (e.g., resizing, normalization).
        sequence_length: int. The context length for tokenization.
            Defaults to `1024`.
        add_start_token: bool. Whether to add the start token. Defaults to
            `True`.
        add_end_token: bool. Whether to add the end token. Defaults to
            `True`.
        **kwargs: Standard Keras keyword arguments, forwarded to
            `CausalLMPreprocessor`.

    Example:
    ```python
    import keras
    import numpy as np

    preprocessor = MoondreamPreprocessor(
        tokenizer=tokenizer,
        image_converter=keras.layers.Resizing(height=378, width=378),
        sequence_length=128,
    )
    outputs = preprocessor(
        {
            "images": np.random.randint(0, 255, (2, 500, 500, 3)),
            "text": ["Describe this image.", "What is in the photo?"],
        }
    )
    ```
    """

    def __init__(
        self,
        tokenizer,
        image_converter=None,
        sequence_length=1024,
        add_start_token=True,
        add_end_token=True,
        **kwargs,
    ):
        super().__init__(
            tokenizer=tokenizer,
            sequence_length=sequence_length,
            add_start_token=add_start_token,
            add_end_token=add_end_token,
            **kwargs,
        )
        self.image_converter = image_converter

    def _split_inputs(self, x):
        # Accept either a plain text input or a dict with "text"/"images".
        if isinstance(x, dict):
            return x.get("text", ""), x.get("images", None)
        return x, None

    def _convert_images(self, images):
        # Apply the optional image converter (resize/rescale); a no-op
        # when no converter is configured.
        if images is not None and self.image_converter:
            images = self.image_converter(images)
        return images

    def call(self, x, y=None, sample_weight=None):
        text_input, images = self._split_inputs(x)
        output = super().call(text_input, y=y, sample_weight=sample_weight)

        images = self._convert_images(images)
        if images is None:
            return output

        # Re-pack explicitly rather than mutating the tuple element in
        # place, so the result does not depend on dict aliasing.
        if isinstance(output, tuple):
            x_out, *rest = output
            if isinstance(x_out, dict):
                x_out = {**x_out, "images": images}
            # NOTE(review): when `x_out` is not a dict the images are
            # silently dropped (matches the original behavior) — confirm
            # this is intended.
            return (x_out, *rest)
        if isinstance(output, dict):
            return {**output, "images": images}
        return output

    def generate_preprocess(self, x, sequence_length=None):
        text_input, images = self._split_inputs(x)
        output = super().generate_preprocess(
            text_input, sequence_length=sequence_length
        )
        images = self._convert_images(images)
        if images is not None:
            output["images"] = images
        return output

    def get_config(self):
        """Return the serializable config, including the image converter."""
        config = super().get_config()
        config.update(
            {
                "image_converter": keras.saving.serialize_keras_object(
                    self.image_converter
                ),
            }
        )
        return config

    @classmethod
    def from_config(cls, config):
        # `get_config` stores a serialized image converter; restore it
        # here so a config round-trip does not hand `__init__` a raw
        # serialized dict.
        config = config.copy()
        if "image_converter" in config:
            config["image_converter"] = keras.saving.deserialize_keras_object(
                config["image_converter"]
            )
        return super().from_config(config)
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.