Skip to content

Commit c348dc3

Browse files
committed
various type fixes
1 parent a512a07 commit c348dc3

6 files changed

Lines changed: 23 additions & 17 deletions

File tree

src/mmore/colpali/retriever.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -62,9 +62,11 @@ def load_model(model_name: str, device: str):
6262
Tuple of (model, processor) ready for inference
6363
"""
6464
logger.info(f"Loading ColPali model: {model_name}")
65+
66+
bfloat16: torch.dtype = torch.bfloat16
6567
model = ColPali.from_pretrained(
6668
model_name,
67-
torch_dtype=torch.bfloat16,
69+
torch_dtype=bfloat16,
6870
device_map=device,
6971
).eval()
7072
processor = ColPaliProcessor.from_pretrained(model_name)
@@ -95,7 +97,7 @@ def embed_queries(texts: List[str], model, processor) -> List[np.ndarray]:
9597
with torch.no_grad():
9698
batch_query = {k: v.to(model.device) for k, v in batch_query.items()}
9799
emb = model(**batch_query)
98-
vectors.extend(list(torch.unbind(emb.to("cpu"))))
100+
vectors.extend(list(emb.to("cpu").unbind()))
99101
return [v.float().numpy() for v in vectors]
100102

101103

src/mmore/colpali/run_process.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -86,9 +86,9 @@ def cleanup(self):
8686
class ColPaliEmbedder:
8787
def __init__(self, model_name: str = "vidore/colpali-v1.3", device: str = "cuda:0"):
8888
self.device = device
89-
dtype = torch.bfloat16
89+
bfloat16: torch.dtype = torch.bfloat16
9090
self.model = ColPali.from_pretrained(
91-
model_name, torch_dtype=dtype, device_map=device
91+
model_name, torch_dtype=bfloat16, device_map=device
9292
).eval()
9393
self.processor = ColPaliProcessor.from_pretrained(model_name)
9494

@@ -112,7 +112,7 @@ def embed_images(
112112
with torch.no_grad():
113113
batch_doc = {k: v.to(self.model.device) for k, v in batch_doc.items()}
114114
embeddings_doc = self.model(**batch_doc)
115-
ds.extend(list(torch.unbind(embeddings_doc.to(self.device))))
115+
ds.extend(list(embeddings_doc.to(self.device).unbind()))
116116
ds_np = [d.float().cpu().numpy() for d in ds]
117117
return ds_np
118118

src/mmore/process/processors/base.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -216,12 +216,13 @@ def __getstate__(self):
216216
del state["_pool"]
217217
return state
218218

219-
def __setstate__(self, state):
219+
def __setstate__(self, state: Dict[str, Any]):
220220
"""
221221
Called when the object is unpickled (received by the worker).
222222
We restore the state and set _pool to None (workers don't need the pool manager).
223223
"""
224-
self.__dict__.update(state)
224+
for key, value in state.items():
225+
setattr(self, key, value)
225226
# Initialize _pool as None in the worker process
226227
self._pool = None
227228
# Workers should never own the pool

src/mmore/process/processors/media_processor.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,14 @@
11
import logging
22
import os
33
import tempfile
4-
from typing import List
4+
from typing import List, cast
55

6+
import numpy as np
67
import torch
78
from moviepy.audio.io.AudioFileClip import AudioFileClip
89
from moviepy.video.io.VideoFileClip import VideoFileClip
910
from PIL import Image
11+
from torch._C import device as torch_device
1012
from transformers.pipelines import pipeline as pipeline_t
1113

1214
from ...type import FileDescriptor, MultimodalSample
@@ -19,10 +21,10 @@ class MediaProcessor(Processor):
1921
@staticmethod
2022
def _get_available_devices():
2123
if torch.cuda.is_available():
22-
return [torch.device(f"cuda:{i}") for i in range(torch.cuda.device_count())]
24+
return [torch_device(f"cuda:{i}") for i in range(torch.cuda.device_count())]
2325
if torch.backends.mps.is_available():
24-
return [torch.device("mps")]
25-
return [torch.device("cpu")]
26+
return [torch_device("mps")]
27+
return [torch_device("cpu")]
2628

2729
devices = _get_available_devices()
2830
pipelines = []
@@ -155,7 +157,7 @@ def _extract_video_frames(file_path: str) -> List[Image.Image]:
155157
for i in range(num_thumbnails):
156158
t = min(i * sample_rate, duration - 0.1)
157159
frame = clip.get_frame(t)
158-
image = Image.fromarray(frame).convert("RGB")
160+
image = Image.fromarray(cast(np.ndarray, frame).convert("RGB"))
159161
images.append(image)
160162
logger.info(f"Extracted {len(images)} images from {file_path}.")
161163
except Exception as e:

src/mmore/process/processors/pdf_processor.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
import logging
33
import re
44
from multiprocessing import Manager, Process, set_start_method
5-
from typing import List, Optional, Tuple, cast
5+
from typing import Any, Dict, List, Optional, Tuple, cast
66

77
import pymupdf
88
import torch
@@ -148,7 +148,7 @@ def process(self, file_path: str) -> MultimodalSample:
148148

149149
paragraph_starts, text = self._parse_pagination(cast(str, text))
150150

151-
metadata = {"file_path": file_path}
151+
metadata: Dict[str, Any] = {"file_path": file_path}
152152
if paragraph_starts:
153153
metadata["paragraph_starts"] = paragraph_starts
154154

@@ -218,7 +218,7 @@ def _extract_images(pdf_doc, xref) -> Optional[Image.Image]:
218218
if image_bytes is None:
219219
logging.error(f"No image data found for xref {xref}")
220220

221-
return Image.open(io.BytesIO(image_bytes)).convert("RGB")
221+
return Image.open(io.BytesIO(cast(bytes, image_bytes))).convert("RGB")
222222

223223
except KeyError as e:
224224
logging.error(f"KeyError while extracting image: {e}")
@@ -236,7 +236,7 @@ def _extract_images(pdf_doc, xref) -> Optional[Image.Image]:
236236
)
237237
return None
238238

239-
for page_num, page in enumerate(pdf_doc):
239+
for page_num, page in enumerate(pdf_doc): # pyright: ignore[reportArgumentType]
240240
text = clean_text(page.get_text()) # type: ignore[attr-defined]
241241

242242
if text.strip():

src/mmore/rag/model/dense/multimodal.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,9 @@
1414
class MultimodalEmbeddings(Embeddings):
1515
def __init__(self, model_name: str):
1616
super().__init__()
17+
float16: torch.dtype = torch.float16
1718
self.model = AutoModelForImageTextToText.from_pretrained(
18-
model_name, torch_dtype=torch.float16, device_map="auto"
19+
model_name, torch_dtype=float16, device_map="auto"
1920
)
2021
self.processor = AutoProcessor.from_pretrained(model_name)
2122
self.device = self.model.device

0 commit comments

Comments
 (0)