Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions keras_hub/api/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -458,6 +458,15 @@
from keras_hub.src.models.mobilenetv5.mobilenetv5_image_classifier_preprocessor import (
MobileNetV5ImageClassifierPreprocessor as MobileNetV5ImageClassifierPreprocessor,
)
from keras_hub.src.models.moondream.moondream_backbone import (
MoondreamBackbone as MoondreamBackbone,
)
from keras_hub.src.models.moondream.moondream_causal_lm import (
MoondreamCausalLM as MoondreamCausalLM,
)
from keras_hub.src.models.moondream.moondream_preprocessor import (
MoondreamPreprocessor as MoondreamPreprocessor,
)
from keras_hub.src.models.moonshine.moonshine_audio_to_text import (
MoonshineAudioToText as MoonshineAudioToText,
)
Expand Down
4 changes: 4 additions & 0 deletions keras_hub/src/models/moondream/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
from keras_hub.src.models.moondream.moondream_backbone import MoondreamBackbone
from keras_hub.src.models.moondream.moondream_preprocessor import (
MoondreamPreprocessor,
)
135 changes: 135 additions & 0 deletions keras_hub/src/models/moondream/moondream_backbone.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,135 @@
import keras
from keras import ops

from keras_hub.src.api_export import keras_hub_export
from keras_hub.src.models.backbone import Backbone


@keras_hub_export("keras_hub.models.MoondreamBackbone")
class MoondreamBackbone(Backbone):
    """The Moondream Backbone model.

    This model connects a vision encoder (SigLIP) and a text decoder (Phi-1.5)
    using a projection layer. It is designed for vision-language tasks where
    image features are projected into the text embedding space and prepended
    to the text token embeddings before decoding.

    Args:
        vision_encoder: A Keras model (e.g., SigLIP). The vision encoder
            responsible for processing input images. Must map an image batch
            to a `(batch, num_patches, feature_dim)` tensor.
        text_decoder: A Keras model (e.g., Phi-1.5). The text decoder
            responsible for generating text tokens. Must expose a
            `get_input_embeddings(token_ids)` callable and accept
            `decoder_inputs_embeds` / `padding_mask` keyword arguments.
        projection_dim: int. The dimension to project image features into.
            Should match the text decoder's embedding dimension. Defaults
            to `2048`.
        **kwargs: Standard Keras keyword arguments.

    Example:
    ```python
    import keras
    import numpy as np
    from keras_hub.src.models.moondream.moondream_backbone import (
        MoondreamBackbone
    )

    # 1. Create Mock Encoders
    # Vision Encoder: Maps (378, 378, 3) -> (729, 1152)
    image_input = keras.Input(shape=(378, 378, 3))
    vision_output = keras.layers.Lambda(
        lambda x: keras.ops.ones((keras.ops.shape(x)[0], 729, 1152))
    )(image_input)
    vision_encoder = keras.Model(inputs=image_input, outputs=vision_output)

    # Text Decoder: Maps (Seq,) -> (Seq, 2048)
    text_input = keras.Input(shape=(None,), dtype="int32")
    text_output = keras.layers.Lambda(
        lambda x: keras.ops.ones(
            (keras.ops.shape(x)[0], keras.ops.shape(x)[1], 2048)
        )
    )(text_input)
    text_decoder = keras.Model(inputs=text_input, outputs=text_output)

    # Helper for embeddings
    text_decoder.get_input_embeddings = lambda x: keras.layers.Embedding(
        50000, 2048
    )(x)

    # 2. Instantiate Backbone
    backbone = MoondreamBackbone(
        vision_encoder=vision_encoder,
        text_decoder=text_decoder,
        projection_dim=2048
    )

    # 3. Run Forward Pass
    inputs = {
        "images": np.random.rand(2, 378, 378, 3),
        "token_ids": np.random.randint(0, 50000, (2, 10)),
        "padding_mask": np.ones((2, 10))
    }
    outputs = backbone(inputs)
    ```
    """

    def __init__(
        self, vision_encoder, text_decoder, projection_dim=2048, **kwargs
    ):
        # === Layers ===
        self.vision_encoder = vision_encoder
        self.text_decoder = text_decoder
        # Projects vision features into the text embedding space.
        self.vision_projection = keras.layers.Dense(
            projection_dim, name="vision_projection"
        )

        # === Functional Model ===
        images = keras.Input(shape=(None, None, 3), name="images")
        token_ids = keras.Input(shape=(None,), dtype="int32", name="token_ids")
        padding_mask = keras.Input(
            shape=(None,), dtype="int32", name="padding_mask"
        )

        inputs = {
            "images": images,
            "token_ids": token_ids,
            "padding_mask": padding_mask,
        }

        image_features = self.vision_encoder(images)
        projected_images = self.vision_projection(image_features)

        text_embeddings = self.text_decoder.get_input_embeddings(token_ids)

        # Prepend the projected image patches to the text embeddings so the
        # decoder attends over the image before the prompt tokens.
        combined_embeddings = ops.concatenate(
            [projected_images, text_embeddings], axis=1
        )

        batch_size = ops.shape(images)[0]
        num_patches = ops.shape(projected_images)[1]

        # Image patches are always valid, so their mask is all ones.
        image_mask = ops.ones((batch_size, num_patches), dtype="int32")
        combined_mask = ops.concatenate([image_mask, padding_mask], axis=1)

        outputs = self.text_decoder(
            inputs=None,
            decoder_inputs_embeds=combined_embeddings,
            padding_mask=combined_mask,
        )

        # Single functional init. The original code called `__init__` twice
        # (once with only `**kwargs`, once with inputs/outputs), which
        # re-initializes the model state and passes duplicate kwargs.
        super().__init__(inputs=inputs, outputs=outputs, **kwargs)

        # === Config ===
        self.projection_dim = projection_dim

    def get_config(self):
        config = super().get_config()
        config.update(
            {
                "vision_encoder": keras.saving.serialize_keras_object(
                    self.vision_encoder
                ),
                "text_decoder": keras.saving.serialize_keras_object(
                    self.text_decoder
                ),
                "projection_dim": self.projection_dim,
            }
        )
        return config
86 changes: 86 additions & 0 deletions keras_hub/src/models/moondream/moondream_causal_lm.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
from keras_hub.src.api_export import keras_hub_export
from keras_hub.src.models.causal_lm import CausalLM
from keras_hub.src.models.moondream.moondream_backbone import MoondreamBackbone
from keras_hub.src.models.moondream.moondream_preprocessor import (
MoondreamPreprocessor,
)


@keras_hub_export("keras_hub.models.MoondreamCausalLM")
class MoondreamCausalLM(CausalLM):
    """An end-to-end Moondream model for causal language modeling.

    This model wraps `MoondreamBackbone` and handles the complete flow from
    raw inputs (images + text) to generated text output. It provides a
    high-level interface for image captioning and visual question answering.

    Args:
        backbone: A `MoondreamBackbone` instance. The backbone model that
            connects the vision encoder and text decoder.
        preprocessor: A `MoondreamPreprocessor` instance. Handles data
            preprocessing (tokenization and image resizing). May be `None`,
            in which case inputs must be preprocessed by the caller.
        **kwargs: Standard Keras keyword arguments.

    Example:
    ```python
    import keras
    import numpy as np
    from keras_hub.src.models.moondream.moondream_backbone import (
        MoondreamBackbone
    )
    from keras_hub.src.models.moondream.moondream_causal_lm import (
        MoondreamCausalLM
    )

    # 1. Setup Mock Backbone
    images = keras.Input(shape=(None, None, 3), name="images")
    token_ids = keras.Input(shape=(None,), dtype="int32", name="token_ids")
    padding_mask = keras.Input(
        shape=(None,), dtype="int32", name="padding_mask"
    )

    outputs = keras.layers.Dense(2048)(token_ids)

    backbone = keras.Model(
        inputs={
            "images": images,
            "token_ids": token_ids,
            "padding_mask": padding_mask
        },
        outputs=outputs
    )

    # 2. Instantiate CausalLM
    model = MoondreamCausalLM(backbone=backbone)

    # 3. Run Forward Pass
    inputs = {
        "images": np.random.rand(2, 378, 378, 3),
        "token_ids": np.random.randint(0, 100, (2, 10)),
        "padding_mask": np.ones((2, 10))
    }
    outputs = model(inputs)
    ```
    """

    backbone_cls = MoondreamBackbone
    preprocessor_cls = MoondreamPreprocessor

    def __init__(
        self,
        backbone,
        preprocessor=None,
        **kwargs,
    ):
        # === Layers ===
        # Assign sub-models before the functional init so they are tracked,
        # matching the keras_hub CausalLM convention.
        self.backbone = backbone
        self.preprocessor = preprocessor

        # === Functional Model ===
        inputs = backbone.input
        # NOTE(review): this exposes the backbone's raw output directly; no
        # vocabulary projection (LM head / reverse embedding) is applied
        # here — confirm the backbone already produces token logits.
        outputs = backbone(inputs)

        super().__init__(
            inputs=inputs,
            outputs=outputs,
            **kwargs,
        )
136 changes: 136 additions & 0 deletions keras_hub/src/models/moondream/moondream_preprocessor.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,136 @@
import keras

from keras_hub.src.api_export import keras_hub_export
from keras_hub.src.models.causal_lm_preprocessor import CausalLMPreprocessor


@keras_hub_export("keras_hub.models.MoondreamPreprocessor")
class MoondreamPreprocessor(CausalLMPreprocessor):
    """Moondream Causal LM Preprocessor.

    This class handles the preprocessing of images and text for the Moondream
    model. It combines image resizing/rescaling logic with text tokenization
    to prepare inputs for the model.

    Args:
        tokenizer: The tokenizer to be used for text inputs.
        image_converter: An optional layer or callable for image preprocessing
            (e.g., resizing, normalization). If `None`, images are passed
            through unchanged.
        sequence_length: int. The context length for tokenization.
            Defaults to 1024.
        add_start_token: bool. Whether to add the start token.
            Defaults to True.
        add_end_token: bool. Whether to add the end token.
            Defaults to True.
        **kwargs: Standard Keras keyword arguments.

    Example:
    ```python
    import keras
    import numpy as np
    from keras_hub.src.models.moondream.moondream_preprocessor import (
        MoondreamPreprocessor
    )

    # 1. Create a Mock Tokenizer
    class MockTokenizer:
        def __call__(self, x):
            return keras.ops.convert_to_tensor([[1, 2, 3]] * len(x))
        def detokenize(self, x):
            return x

    tokenizer = MockTokenizer()

    # 2. Create an Image Converter
    image_converter = keras.layers.Resizing(height=378, width=378)

    # 3. Instantiate Preprocessor
    preprocessor = MoondreamPreprocessor(
        tokenizer=tokenizer,
        image_converter=image_converter,
        sequence_length=128
    )

    # 4. Preprocess Data
    inputs = {
        "images": np.random.randint(0, 255, (2, 500, 500, 3)),
        "text": ["Describe this image.", "What is in the photo?"]
    }

    outputs = preprocessor(inputs)
    ```
    """

    def __init__(
        self,
        tokenizer,
        image_converter=None,
        sequence_length=1024,
        add_start_token=True,
        add_end_token=True,
        **kwargs,
    ):
        super().__init__(
            tokenizer=tokenizer,
            sequence_length=sequence_length,
            add_start_token=add_start_token,
            add_end_token=add_end_token,
            **kwargs,
        )
        self.image_converter = image_converter

    def _unpack_inputs(self, x):
        """Split a raw input into its `(text, images)` parts.

        Plain (non-dict) inputs are treated as text-only.
        """
        if isinstance(x, dict):
            return x.get("text", ""), x.get("images", None)
        return x, None

    def _convert_images(self, images):
        """Apply the optional image converter; pass through when unset."""
        if images is not None and self.image_converter:
            images = self.image_converter(images)
        return images

    def call(self, x, y=None, sample_weight=None):
        """Tokenize text and attach (optionally converted) images."""
        text_input, images = self._unpack_inputs(x)

        # Delegate text tokenization/packing to the base preprocessor.
        output = super().call(text_input, y=y, sample_weight=sample_weight)

        # The base class may return either `x` alone or `(x, y, sw)`.
        x_out = output[0] if isinstance(output, tuple) else output

        if images is not None:
            images = self._convert_images(images)
            # Mutating `x_out` in place also updates the tuple element,
            # since tuples hold references.
            if isinstance(x_out, dict):
                x_out["images"] = images

        return output

    def generate_preprocess(self, x, sequence_length=None):
        """Preprocess inputs for generation, attaching images if present."""
        text_input, images = self._unpack_inputs(x)

        output = super().generate_preprocess(
            text_input, sequence_length=sequence_length
        )

        if images is not None:
            output["images"] = self._convert_images(images)

        return output

    def get_config(self):
        config = super().get_config()
        config.update(
            {
                "image_converter": keras.saving.serialize_keras_object(
                    self.image_converter
                ),
            }
        )
        return config
Loading