-
Notifications
You must be signed in to change notification settings - Fork 330
Expand file tree
/
Copy pathmoondream_preprocessor.py
More file actions
57 lines (49 loc) · 1.74 KB
/
moondream_preprocessor.py
File metadata and controls
57 lines (49 loc) · 1.74 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
import keras
from keras_hub.src.api_export import keras_hub_export
from keras_hub.src.models.causal_lm_preprocessor import CausalLMPreprocessor
@keras_hub_export("keras_hub.models.MoondreamPreprocessor")
class MoondreamPreprocessor(CausalLMPreprocessor):
    """Causal LM preprocessor that additionally carries an image converter.

    All text handling is delegated to `CausalLMPreprocessor`. When the raw
    input is a dict with an `"images"` entry, those images are run through
    `image_converter` (if one was provided) and attached to the preprocessed
    feature dict under the same `"images"` key.

    Args:
        tokenizer: Tokenizer forwarded to `CausalLMPreprocessor`.
        image_converter: Optional callable applied to `x["images"]` in
            `call`. When `None`, images are attached unchanged.
        sequence_length: Forwarded to `CausalLMPreprocessor`. Defaults to
            `1024`.
        add_start_token: Forwarded to `CausalLMPreprocessor`. Defaults to
            `True`.
        add_end_token: Forwarded to `CausalLMPreprocessor`. Defaults to
            `True`.
        **kwargs: Extra keyword arguments forwarded to
            `CausalLMPreprocessor`.
    """

    def __init__(
        self,
        tokenizer,
        image_converter=None,
        sequence_length=1024,
        add_start_token=True,
        add_end_token=True,
        **kwargs,
    ):
        super().__init__(
            tokenizer=tokenizer,
            sequence_length=sequence_length,
            add_start_token=add_start_token,
            add_end_token=add_end_token,
            **kwargs,
        )
        # Assigned after the parent constructor so any parent-side default
        # for this attribute does not overwrite the caller's value.
        self.image_converter = image_converter

    def call(self, x, y=None, sample_weight=None, sequence_length=None):
        """Preprocess `x` with the parent class, then attach images.

        Args:
            x: Raw input. If it is a dict containing `"images"`, the images
                are copied (after optional conversion) into the output.
            y: Optional labels, forwarded to the parent `call`.
            sample_weight: Optional sample weights, forwarded to the parent
                `call`.
            sequence_length: Optional per-call override, forwarded to the
                parent `call` only when it is not `None`.

        Returns:
            Whatever the parent `call` returned (a feature dict, or a tuple
            whose first element is the feature dict), with an `"images"`
            entry added to the feature dict when `x` supplied one.
        """
        # NOTE(review): `x` is handed to the parent `call` unchanged, so the
        # parent must be able to consume a dict input for the "images"
        # branch below to be reached -- confirm against the tokenizer.
        #
        # Only forward `sequence_length` when the caller set it, so the
        # default path issues exactly the same parent call as before.
        overrides = {}
        if sequence_length is not None:
            overrides["sequence_length"] = sequence_length
        output = super().call(x, y, sample_weight, **overrides)

        # With labels/weights the parent returns a tuple whose first element
        # is the feature dict; otherwise it returns the features directly.
        features = output[0] if isinstance(output, tuple) else output

        # Explicit isinstance checks double as type guards for static
        # checkers, which cannot otherwise tell these are dicts.
        has_images = isinstance(x, dict) and "images" in x
        if not (isinstance(features, dict) and has_images):
            return output

        images = x["images"]
        # Compare against None rather than relying on truthiness, so a
        # converter object that happens to be falsy is still applied.
        if self.image_converter is not None:
            images = self.image_converter(images)
        # `features` is the same object held by `output`, so this in-place
        # update is visible in the returned value.
        features["images"] = images
        return output

    def get_config(self):
        """Return the parent config plus the serialized `image_converter`."""
        config = super().get_config()
        config.update(
            {
                "image_converter": keras.saving.serialize_keras_object(
                    self.image_converter
                ),
            }
        )
        return config