-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmultimodal_chunker.py
581 lines (488 loc) · 23.6 KB
/
multimodal_chunker.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
import base64
import logging
import os
import re
from typing import Generator
from tools import BlobClient
from ..exceptions import UnsupportedFormatError
from .doc_analysis_chunker import DocumentAnalysisChunker
class MultimodalChunker(DocumentAnalysisChunker):
"""
`MultimodalChunker` processes documents containing both text and figures.
It splits the document into chunks, replaces figure tags with identifiers,
attaches figures to the appropriate chunks, and handles storage and captioning.
"""
def __init__(
self,
data: dict,
max_chunk_size: int | None = None,
minimum_chunk_size: int | None = None,
token_overlap: int | None = None,
):
super().__init__(data, max_chunk_size, minimum_chunk_size, token_overlap)
self.image_container = os.getenv("STORAGE_CONTAINER_IMAGES", "documents-images")
self.storage_account_name = os.getenv(
"STORAGE_ACCOUNT_NAME", "set-storage-account-name-env-var"
)
self.minimum_figure_area_percentage = float(
os.getenv("MINIMUM_FIGURE_AREA_PERCENTAGE", "4.0")
)
def get_chunks(self) -> list[dict]:
"""
Splits the document content into chunks based on the specified format and criteria.
Returns:
list[dict]: A list of dictionaries, each representing a chunk of the document content.
Raises:
UnsupportedFormatError: If the document's extension is not supported.
Exception: If there are errors during document analysis.
"""
if self.extension not in self.supported_formats:
raise UnsupportedFormatError(
f"[MultimodalChunker] [get_chunks] {self.extension} format is not supported."
)
logging.info(
f"[MultimodalChunker] [get_chunks] [{self.filename}] Running `get_chunks`."
)
document, analysis_errors = self._analyze_document_with_retry()
if analysis_errors:
formatted_errors = ", ".join(map(str, analysis_errors))
raise Exception(
f"Error in `DocumentAnalysisChunker` analyzing {self.filename}: {formatted_errors}."
)
chunks = self._process_document_chunks(document)
return chunks
def _process_document_chunks(self, document: dict) -> list[dict]:
"""
Processes the document by replacing figure tags, creating text chunks,
and attaching figures to the corresponding chunks.
Args:
document (dict): The analyzed document containing content and figures.
Returns:
list[dict]: A list of processed document chunks.
"""
# 1. Replace <figure>...</figure> with <figure{id}> in sequence
if "figures" in document and document["figures"]:
document["content"] = self._replace_figures_in_sequence(
document["content"], document["figures"]
)
# 2. Create text chunks
chunks = self._create_text_chunks(document)
# 3. Attach figures to chunks
self._attach_figures_to_chunks(document, chunks)
return chunks # ensure chunks are returned
def _replace_figures_in_sequence(self, content: str, figures: list[dict]) -> str:
"""
Replace all occurrences of <figure>...</figure> with <figure{id}> in the order
of the figures list. If we run out of <figure> tags or figures, we stop.
Args:
content (str): The document content containing figure tags.
figures (list[dict]): A list of figure dictionaries with `id` keys.
Returns:
str: The updated content with figure tags replaced by identifiers.
"""
for fig in figures:
figure_id = fig.get("id")
if not figure_id:
continue
start_index = content.find("<figure>")
if start_index == -1:
break # no more <figure> tags
end_index = content.find("</figure>", start_index)
if end_index == -1:
break # malformed or missing closing </figure>
# Replace everything from <figure> to </figure> with <figure{id}>
content = (
content[:start_index]
+ f"<figure{figure_id}>"
+ content[end_index + len("</figure>") :]
)
return content
def _create_text_chunks(self, document: dict) -> list[dict]:
"""
Splits the document content into chunks based on specified format and criteria.
Args:
document (dict): The document containing content to be chunked.
Returns:
list[dict]: A list of chunk dictionaries with content and metadata.
"""
chunks = []
document_content = document["content"]
document_content = self._number_pagebreaks(document_content)
# Returns a Generator that yields the truncated chunks, chunk sizes when iterated over
text_chunks = self._chunk_content(document_content)
chunk_id = 0
skipped_chunks = 0
current_page = 1
for text_chunk, num_tokens, chunk_offset, _ in text_chunks:
current_page = self._update_page(text_chunk, current_page)
chunk_page = self._determine_chunk_page(text_chunk, current_page)
if num_tokens >= self.minimum_chunk_size:
chunk_id += 1
chunk = self._create_chunk(
chunk_id=chunk_id,
content=text_chunk,
page=chunk_page,
offset=chunk_offset,
)
chunks.append(chunk)
else:
skipped_chunks += 1
logging.debug(
f"[MultimodalChunker] [_create_text_chunks] [{self.filename}] {len(chunks)} chunk(s) created."
)
if skipped_chunks > 0:
logging.debug(
f"[MultimodalChunker] [_create_text_chunks] [{self.filename}] {skipped_chunks} chunk(s) skipped."
)
return chunks
def _chunk_content(
self, content: str
) -> Generator[tuple[str, int, int, int], None, None]:
"""
Splits the document content into chunks based on the specified format and criteria.
Args:
content (str): The content to chunk.
Yields:
Generator[tuple[str, int, int, int], None, None]:
A tuple containing the chunked content, the number of tokens in the chunk (i.e. the chunk size),
the offset and the length of the chunk.
"""
# Split the content according to the Document Intelligence version
splitter = self._choose_splitter()
chunks = splitter.split_text(content)
offset = 0
for chunked_content in chunks:
chunk_size = self.token_estimator.estimate_tokens(chunked_content)
chunk_length = len(chunked_content)
yield chunked_content, chunk_size, offset, chunk_length
offset += chunk_length
def _attach_figures_to_chunks(self, document: dict, chunks: list[dict]) -> None:
"""
Associates figures from the document with their corresponding text chunks.
by scanning each chunk for `<figureX.Y>` placeholders.
For each figure reference in a chunk:
1. Retrieve the figure from `document["figures"]` by ID
2. Upload the image to Blob Storage
3. Generate descriptions (captions)
4. Generate embeddings
5. Build one combined caption string that references all figures in this chunk
6. Attach caption and embeddings to the chunk via the method `_append_figures_to_chunk`
Args:
document (dict): The document containing the figures.
chunks (list[dict]): The list of dictionaries as chunks.
"""
if "figures" not in document or not document["figures"]:
logging.info(
f"[MultimodalChunker] [_attach_figures_to_chunks] [{self.filename}] No figures to attach."
)
return
result_id = document.get("result_id")
model_id = document.get("model_id")
if not result_id or not model_id:
logging.warning(
f"[MultimodalChunker] [_attach_figures_to_chunks] [{self.filename}] Missing `result_id` or `model_id` in document analysis results."
)
return
logging.info(
f"[MultimodalChunker] [_attach_figures_to_chunks] [{self.filename}] Attaching figures to chunks using "
f"`result_id`: {result_id} and `model_id`: {model_id}."
)
# Create a quick-access dictionary for the figures by their ID
figures_dict = {fig["id"]: fig for fig in document["figures"] if "id" in fig}
# Regex to find all <figureX.Y> (or <figureX> if single integer)
figure_tag_pattern = re.compile(r"<figure(\d+(?:\.\d+)*)>")
for chunk in chunks:
chunk_content = chunk.get("content", "")
figure_refs = figure_tag_pattern.findall(chunk_content)
if not figure_refs:
# No figure references in this chunk; move to the next
continue
# Build arrays to store references for this chunk
figure_urls = []
figure_descriptions = []
for figure_id in figure_refs:
# Attempt to find the figure in the dictionary
figure = figures_dict.get(figure_id)
if not figure:
logging.warning(
f"[MultimodalChunker] [_attach_figures_to_chunks] [{self.filename}] Figure with id={figure_id} not found in `document['figures']`."
)
chunk_content = chunk_content.replace(f"<figure{figure_id}>", "")
continue
try:
# 1. Check dimensions
figure_area_percentage = round(
self._figure_area(figure, document["pages"]), 2
)
if figure_area_percentage <= self.minimum_figure_area_percentage:
logging.warning(
f"[MultimodalChunker] [_attach_figures_to_chunks] [{self.filename}] Image for figure {figure_id} "
f"has insufficient percentual area ({figure_area_percentage}). Skipping."
)
chunk_content = chunk_content.replace(
f"<figure{figure_id}>", ""
)
continue
# 2. Fetch the figure image
image_binary = self.docint_client.get_figure(
model_id, result_id, figure_id
)
if not image_binary:
logging.warning(
f"[MultimodalChunker] [_attach_figures_to_chunks] [{self.filename}] No image data retrieved for figure {figure_id}."
)
chunk_content = chunk_content.replace(
f"<figure{figure_id}>", ""
)
continue
# Check dimensions
# image = Image.open(io.BytesIO(image_binary))
# width, height = image.size
# pixel_area = width * height
# if pixel_area <= self.minimum_pixel_area:
# logging.warning(
# f"[MultimodalChunker] [_attach_figures_to_chunks] [{self.filename}] Image for figure {figure_id} "
# f"has insufficient pixel area ({pixel_area}). Skipping."
# )
# chunk_content = chunk_content.replace(f"<figure{figure_id}>", "")
# continue
# 3. Upload to blob
blob_name = f"{self.filename}-figure-{figure_id}.png"
url = self._upload_figure_blob(image_binary, blob_name)
# 4. Generate caption
logging.info(
f"[MultimodalChunker] [_attach_figures_to_chunks] [{self.filename}] Generating caption for figure {figure_id}. Percent area: {figure_area_percentage}."
)
figure_caption = self._generate_caption_for_figure(
{
"id": figure_id,
"image": base64.b64encode(image_binary).decode("utf-8"),
"blob_name": blob_name,
}
)
# Store references
figure_urls.append(url)
figure_descriptions.append(f"[{blob_name}]: {figure_caption}")
# Replace <figureX.Y> with a simpler marker or remove it
chunk_content = chunk_content.replace(
f"<figure{figure_id}>", f"<figure>{blob_name}</figure>"
)
except Exception as e:
logging.error(
f"[MultimodalChunker] [_attach_figures_to_chunks] [{self.filename}] Error processing figure {figure_id}: {str(e)}."
)
# Update the chunk content with placeholders updated
chunk["content"] = chunk_content
# 5. Build the combined caption string
# Example:
# [myfile-figure-1.1.png]: figure (myfile-figure-1.1.png) description: ...
# [myfile-figure-1.2.png]: figure (myfile-figure-1.2.png) description: ...
combined_caption = "\n".join(figure_descriptions)
caption_vector = self.aoai_client.get_embeddings(combined_caption)
# 6. Attach everything to the chunk
if figure_urls or combined_caption:
self._append_figures_to_chunk(
chunk, figure_urls, combined_caption, caption_vector
)
logging.info(
f"[MultimodalChunker] [_attach_figures_to_chunks] [{self.filename}] Attached {len(figure_urls)} figures to chunk {chunk['chunk_id']}."
)
def _figure_area(self, figure: dict, pages: list[dict]) -> float:
"""
Calculate the total figure area by summing the areas of all bounding regions across pages.
Args:
figure (dict): A dictionary representing the figure with `boundingRegions`,
where each bounding region contains `pageNumber` and `polygon`.
pages (list[dict]): A list of page dictionaries each containing `pageNumber`, `width`, and `height`.
Returns:
float: The total area of all valid bounding regions across pages.
Returns 0.0 if no valid bounding regions are found or an error occurs.
"""
total_area = 0.0
# Ensure `boundingRegions` exists in the figure
bounding_regions = figure.get("boundingRegions", [])
if not bounding_regions:
logging.warning(
f"[MultimodalChunker] [_figure_area] [{self.filename}] No `boundingRegions` found in figure."
)
return total_area # returns 0.0
# Create a lookup dictionary for pages to optimize performance
page_lookup = {page["pageNumber"]: page for page in pages}
for idx, bounding_region in enumerate(bounding_regions, start=1):
try:
# Extract bounding region details
page_number = bounding_region["pageNumber"]
polygon = bounding_region["polygon"]
except KeyError as e:
logging.error(
f"[MultimodalChunker] [_figure_area] [{self.filename}] Bounding region {idx} is missing key: {e}."
)
continue # skip this bounding region
# Find the corresponding page using the lookup dictionary
page = page_lookup.get(page_number)
if not page:
logging.info(
f"[MultimodalChunker] [_figure_area] [{self.filename}] No matching page found for pageNumber: {page_number} in bounding region {idx}."
)
continue # skip this bounding region
page_width = page.get("width")
page_height = page.get("height")
# Validate page dimensions
if page_width is None or page_height is None:
logging.error(
f"[MultimodalChunker] [_figure_area] [{self.filename}] Page {page_number} is missing 'width' or 'height'."
)
continue # skip this bounding region
try:
# Calculate polygon area using a helper method
polygon_area = self._calculate_polygon_area(polygon)
except ValueError as ve:
logging.error(
f"[MultimodalChunker] [_figure_area] [{self.filename}] Error calculating area for figure on page {page_number}, bounding region {idx}: {ve}."
)
continue # Skip this bounding region
# Optionally, validate that the polygon area does not exceed the page area
page_area = page_width * page_height
if polygon_area > page_area:
logging.warning(
f"[MultimodalChunker] [_figure_area] [{self.filename}] Polygon area {polygon_area:.2f} exceeds page area {page_area:.2f} on page {page_number}, bounding region {idx}."
)
# Depending on requirements, we might choose to:
# - Skip adding this area
# - Cap the polygon area to the page area
# - Include the area as is (current implementation)
# Here, we'll include it.
# Accumulate the total area
total_area += polygon_area
logging.debug(
f"[MultimodalChunker] [_figure_area] [{self.filename}] Figure on Page {page_number}, Bounding Region {idx}: "
f"Polygon Coordinates: {polygon}, "
f"Polygon Area: {polygon_area:.2f}, "
f"Accumulated Total Area: {total_area:.2f}."
)
if total_area == 0.0:
logging.warning(
f"[MultimodalChunker] [_figure_area] [{self.filename}] No valid bounding regions found to calculate total area."
)
return total_area
def _calculate_polygon_area(self, polygon: list[float]) -> float:
"""
Calculate the area of a polygon using the Shoelace formula.
Args:
polygon (list[float]): polygon (List[float]): A list of coordinates [x1, y1, x2, y2, ..., xn, yn].
Returns:
float: The absolute area of the polygon.
"""
if len(polygon) < 6:
raise ValueError("A polygon must have at least 3 points (6 coordinates).")
area = 0.0
num_points = len(polygon) // 2
for i in range(num_points):
x1, y1 = polygon[2 * i], polygon[2 * i + 1]
x2, y2 = (
polygon[2 * ((i + 1) % num_points)],
polygon[2 * ((i + 1) % num_points) + 1],
)
area += (x1 * y2) - (x2 * y1)
return abs(area) / 2.0
def _append_figures_to_chunk(
self,
chunk: dict,
figure_urls: list[str],
combined_caption: str,
caption_vector: list[float],
):
"""
Appends the combined figure data (URLs, a single combined caption string,
and a list of caption vectors) to the chunk.
Args:
chunk (dict): The chunk.
figure_urls (list[str]): The list of figure URLs.
combined_caption (str): The combined caption string.
caption_vector (list[float]): The caption vector.
"""
# 1. Related images (URLs)
if "relatedImages" not in chunk:
chunk["relatedImages"] = []
chunk["relatedImages"].extend(figure_urls)
# 2. Combined caption text
if "imageCaptions" not in chunk:
chunk["imageCaptions"] = ""
if chunk["imageCaptions"]:
chunk["imageCaptions"] += "\n"
chunk["imageCaptions"] += combined_caption
# 3. Assign the caption vector to the chunk
chunk["captionVector"] = caption_vector
def _find_chunks_for_figure(self, figure_id: str, chunks: list[dict]) -> list[dict]:
"""
Searches for all chunks whose content contains <figure{figure_id}>.
Returns a list of such chunks.
Args:
figure_id (str): The identifier of the figure to find.
chunks (list[dict]): The list of text chunks to search within.
Returns:
list[dict]: A list of chunks containing the figure placeholder.
"""
matched_chunks = []
for chunk in chunks:
chunk_content = chunk.get("content", "")
if f"<figure{figure_id}>" in chunk_content:
matched_chunks.append(chunk)
if not matched_chunks:
# Log a warning if no chunks contain <figure{id}>
logging.warning(
f"[MultimodalChunker] [_find_chunks_for_figure] [{self.filename}] Could not find <figure{figure_id}> in any chunk."
)
return matched_chunks
def _upload_figure_blob(self, image_bytes: bytes, blob_name: str) -> str:
"""
Uploads the image bytes to Blob Storage and returns the SAS URL.
Args:
image_bytes (bytes): The binary data of the image to upload.
blob_name (str): The name to assign to the uploaded blob.
Returns:
str: The URL of the uploaded blob, or an empty string if upload fails.
"""
try:
blob_url = f"https://{self.storage_account_name}.blob.core.windows.net/{self.image_container}/{blob_name}"
blob_client = BlobClient(blob_url)
blob_service_client = blob_client.blob_service_client
container_client = blob_service_client.get_container_client(
blob_client.container_name
)
blob_client_instance = container_client.get_blob_client(
blob_client.blob_name
)
blob_client_instance.upload_blob(image_bytes, overwrite=True)
return blob_url
except Exception as e:
logging.error(
f"[MultimodalChunker] [_upload_figure_blob] [{self.filename}] Failed to upload figure {blob_name}: {str(e)}."
)
return ""
def _generate_caption_for_figure(self, figure: dict) -> str:
"""
Generates a caption for the figure using the Azure OpenAI client.
Args:
figure (dict): The figure data containing image information.
Returns:
str: The generated caption for the figure.
"""
try:
caption_prompt = (
"Generate a detailed description of the following figure, including "
"its key elements and context, to optimize it for retrieval purposes. "
"Use no more than 200 words."
)
caption = self.aoai_client.get_completion(
prompt=caption_prompt, image_base64=figure["image"]
)
logging.debug(
f"[MultimodalChunker] [_generate_caption_for_figure] [{self.filename}] Generated caption for figure {figure.get('id', 'unknown')}: {caption}."
)
return caption
except Exception as e:
logging.error(
f"[MultimodalChunker] [_generate_caption_for_figure] [{self.filename}] Failed to generate caption for figure {figure.get('id', 'unknown')}: {str(e)}."
)
return "No caption available."