-
Notifications
You must be signed in to change notification settings - Fork 330
Expand file tree
/
Copy pathmoondream_causal_lm.py
More file actions
37 lines (29 loc) · 1.05 KB
/
moondream_causal_lm.py
File metadata and controls
37 lines (29 loc) · 1.05 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
import keras
from keras_hub.src.api_export import keras_hub_export
from keras_hub.src.models.causal_lm import CausalLM
from keras_hub.src.models.moondream.moondream_backbone import MoondreamBackbone
from keras_hub.src.models.moondream.moondream_preprocessor import \
MoondreamPreprocessor
@keras_hub_export("keras_hub.models.MoondreamCausalLM")
class MoondreamCausalLM(CausalLM):
    """Causal language-model task built on a Moondream backbone.

    Wraps a `MoondreamBackbone` together with an optional
    `MoondreamPreprocessor`. The forward pass (`call`) runs the backbone on
    its inputs and returns the backbone's output unchanged; no additional
    head is applied in this class.

    NOTE(review): this class does not override any generation hooks that
    `CausalLM` may expect subclasses to provide; confirm whether
    `generate()` is supported before relying on it.

    Args:
        backbone: The `MoondreamBackbone` instance to run.
        preprocessor: Optional `MoondreamPreprocessor`, or `None`.
        **kwargs: Forwarded unchanged to the `CausalLM` base constructor.
    """

    backbone_cls = MoondreamBackbone
    preprocessor_cls = MoondreamPreprocessor

    def __init__(
        self,
        backbone,
        preprocessor=None,
        **kwargs,
    ):
        # Capture the backbone's symbolic inputs, if it exposes an `input`
        # attribute; the `None` default covers backbones without one.
        inputs = getattr(backbone, "input", None)

        # Only **kwargs reach the base class: no functional inputs/outputs
        # are passed, so the forward pass is whatever `call` below defines.
        super().__init__(**kwargs)

        # Attach the sub-components after the base class is initialized.
        self.backbone = backbone
        self.preprocessor = preprocessor

        # Keep the backbone's input structure for later inspection, when
        # one was available.
        if inputs is not None:
            self.input_tensor_spec = inputs

    def call(self, inputs, training=False):
        """Run the backbone on `inputs`.

        Args:
            inputs: Passed straight to the backbone; presumably the same
                structure the backbone was built with — TODO confirm.
            training: Whether the forward pass runs in training mode. It is
                forwarded explicitly to the backbone so any training-dependent
                layers inside it see the same mode, even when `call` is
                invoked directly rather than through `__call__`.

        Returns:
            The backbone's output, unchanged.

        Raises:
            ValueError: If no backbone is attached.
        """
        if self.backbone is None:
            raise ValueError("Backbone not initialized")
        return self.backbone(inputs, training=training)