Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions keras_hub/src/models/moondream/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
from keras_hub.src.models.moondream.moondream_backbone import MoondreamBackbone
from keras_hub.src.models.moondream.moondream_preprocessor import \
MoondreamPreprocessor
68 changes: 68 additions & 0 deletions keras_hub/src/models/moondream/moondream_backbone.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
import keras
from keras import ops

from keras_hub.src.api_export import keras_hub_export
from keras_hub.src.models.backbone import Backbone


@keras_hub_export("keras_hub.models.MoondreamBackbone")
class MoondreamBackbone(Backbone):
    """Backbone that prepends projected image features to text embeddings.

    The vision encoder output is passed through a dense projection, joined
    with the text decoder's input embeddings along axis 1, and the combined
    sequence is run through the text decoder.

    Args:
        vision_encoder: Callable model/layer applied to `inputs["images"]`.
        text_decoder: Decoder exposing `get_input_embeddings(token_ids)` and
            callable with `inputs`, `decoder_inputs_embeds` and
            `padding_mask` keyword arguments.
        projection_dim: Number of units of the dense projection applied to
            the vision encoder output. Defaults to `2048`.
        **kwargs: Forwarded to the parent `Backbone` constructor.
    """

    def __init__(
        self, vision_encoder, text_decoder, projection_dim=2048, **kwargs
    ):
        super().__init__(**kwargs)

        self.vision_encoder = vision_encoder
        self.text_decoder = text_decoder

        # The connector: maps image features to the width given by
        # `projection_dim` before they are joined with the text embeddings.
        self.vision_projection = keras.layers.Dense(
            projection_dim, name="vision_projection"
        )

    def call(self, inputs):
        """Run images and tokens through the encoder, projection and decoder.

        Args:
            inputs: Dict with the keys `"images"`, `"token_ids"` and
                `"padding_mask"`.

        Returns:
            Whatever `text_decoder` returns for the combined sequence.
        """
        images = inputs["images"]
        token_ids = inputs["token_ids"]
        padding_mask = inputs["padding_mask"]

        # 1. Image features.
        image_features = self.vision_encoder(images)

        # 2. Project the image features.
        projected_images = self.vision_projection(image_features)

        # 3. Text embeddings.
        text_embeddings = self.text_decoder.get_input_embeddings(token_ids)

        # 4. Concatenate: image positions come first, then text positions.
        combined_embeddings = ops.concatenate(
            [projected_images, text_embeddings], axis=1
        )

        # 5. Masking: every image position is marked as valid, and the text
        # padding mask is appended after it so the mask lines up with
        # `combined_embeddings` along axis 1.
        batch_size = ops.shape(images)[0]
        num_patches = ops.shape(projected_images)[1]

        image_mask = ops.ones((batch_size, num_patches), dtype="bool")
        combined_mask = ops.concatenate([image_mask, padding_mask], axis=1)

        # 6. Decoder pass on the precomputed embeddings (no token ids).
        # NOTE(review): relies on the decoder accepting
        # `decoder_inputs_embeds` -- confirm against the real decoder.
        outputs = self.text_decoder(
            inputs=None,
            decoder_inputs_embeds=combined_embeddings,
            padding_mask=combined_mask,
        )

        return outputs

    def get_config(self):
        """Return the config, with both sub-models serialized as nested
        Keras object configs."""
        config = super().get_config()
        config.update(
            {
                "vision_encoder": keras.saving.serialize_keras_object(
                    self.vision_encoder
                ),
                "text_decoder": keras.saving.serialize_keras_object(
                    self.text_decoder
                ),
                "projection_dim": self.vision_projection.units,
            }
        )
        return config

    @classmethod
    def from_config(cls, config):
        """Rebuild the backbone from the output of `get_config()`.

        `get_config()` stores the sub-models as serialized dicts, so they
        are turned back into objects here before the constructor runs.
        Entries that are already objects are passed through untouched.
        """
        # Work on a copy so the caller's config dict is not mutated.
        config = dict(config)
        for key in ("vision_encoder", "text_decoder"):
            if isinstance(config.get(key), dict):
                config[key] = keras.saving.deserialize_keras_object(
                    config[key]
                )
        return super().from_config(config)
37 changes: 37 additions & 0 deletions keras_hub/src/models/moondream/moondream_causal_lm.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
import keras

from keras_hub.src.api_export import keras_hub_export
from keras_hub.src.models.causal_lm import CausalLM
from keras_hub.src.models.moondream.moondream_backbone import MoondreamBackbone
from keras_hub.src.models.moondream.moondream_preprocessor import \
MoondreamPreprocessor


@keras_hub_export("keras_hub.models.MoondreamCausalLM")
class MoondreamCausalLM(CausalLM):
    """Causal LM task wrapping a `MoondreamBackbone`.

    Args:
        backbone: The backbone model that `call` delegates to.
        preprocessor: Optional preprocessor stored on the task. Defaults to
            `None`.
        **kwargs: Forwarded to the parent `CausalLM` constructor.
    """

    backbone_cls = MoondreamBackbone
    preprocessor_cls = MoondreamPreprocessor

    def __init__(self, backbone, preprocessor=None, **kwargs):
        # Look up the backbone's `input` before the parent constructor runs;
        # a backbone without that attribute yields `None`.
        backbone_inputs = getattr(backbone, "input", None)

        super().__init__(**kwargs)

        # Attach the sub-components by hand.
        self.backbone = backbone
        self.preprocessor = preprocessor

        # Only record a tensor spec when the backbone provided one.
        if backbone_inputs is None:
            return
        self.input_tensor_spec = backbone_inputs

    def call(self, inputs, training=False):
        """Forward `inputs` to the backbone and return its output.

        Raises:
            ValueError: If no backbone is attached.
        """
        backbone = self.backbone
        if backbone is None:
            raise ValueError("Backbone not initialized")
        return backbone(inputs)
57 changes: 57 additions & 0 deletions keras_hub/src/models/moondream/moondream_preprocessor.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
import keras

from keras_hub.src.api_export import keras_hub_export
from keras_hub.src.models.causal_lm_preprocessor import CausalLMPreprocessor


@keras_hub_export("keras_hub.models.MoondreamPreprocessor")
class MoondreamPreprocessor(CausalLMPreprocessor):
    """Causal LM preprocessor that also carries images through to the output.

    Args:
        tokenizer: Tokenizer forwarded to the parent preprocessor.
        image_converter: Optional callable applied to `x["images"]`.
            Defaults to `None`, in which case images pass through as-is.
        sequence_length: Forwarded to the parent. Defaults to `1024`.
        add_start_token: Forwarded to the parent. Defaults to `True`.
        add_end_token: Forwarded to the parent. Defaults to `True`.
        **kwargs: Forwarded to the parent constructor.
    """

    def __init__(
        self,
        tokenizer,
        image_converter=None,
        sequence_length=1024,
        add_start_token=True,
        add_end_token=True,
        **kwargs,
    ):
        super().__init__(
            tokenizer=tokenizer,
            sequence_length=sequence_length,
            add_start_token=add_start_token,
            add_end_token=add_end_token,
            **kwargs,
        )
        self.image_converter = image_converter

    def call(self, x, y=None, sample_weight=None):
        """Run the parent preprocessing, then attach `"images"` to its
        feature dict when the raw input `x` is a dict containing images.

        Returns:
            The parent's output (a dict, or a tuple whose first element is
            the feature dict), with the feature dict updated in place.
        """
        # NOTE(review): `x` is forwarded unchanged to the parent even when
        # it is a dict -- confirm the parent accepts that input form.
        packed = super().call(x, y, sample_weight)

        # The feature dict is the first element when the parent returned a
        # tuple such as (x, y, sample_weight); otherwise it is the output.
        features = packed[0] if isinstance(packed, tuple) else packed

        # Nothing to attach unless the features are a dict and the raw
        # input carried an "images" entry.
        has_images = isinstance(x, dict) and "images" in x
        if not (isinstance(features, dict) and has_images):
            return packed

        converter = self.image_converter
        raw_images = x["images"]
        features["images"] = converter(raw_images) if converter else raw_images

        return packed

    def get_config(self):
        """Return the parent config plus the serialized image converter."""
        config = super().get_config()
        config["image_converter"] = keras.saving.serialize_keras_object(
            self.image_converter
        )
        return config
Loading