Skip to content

Commit 326f180

Browse files
badGarnet and qued authored
Feat: add more output format for table inference (#263)
This PR addresses [CORE-2307](https://unstructured-ai.atlassian.net/browse/CORE-2307) - add a new kwarg to `UnstructuredTableTransformerModel.run_prediction`: `result_format` - the default `result_format` is `html`, which is the current behavior: output an HTML string representation of the table - another available option is `dataframe`, which returns a pandas dataframe representation of the table - for any other value of `result_format` it returns a list of dictionaries: the table cell format, which is the original output format from table transformer - `unstructured_inference.models.tables.recognize` no longer accepts the `out_html` kwarg and it now only returns the table cell format [CORE-2307]: https://unstructured-ai.atlassian.net/browse/CORE-2307?atlOrigin=eyJpIjoiNWRkNTljNzYxNjVmNDY3MDlhMDU5Y2ZhYzA5YTRkZjUiLCJwIjoiZ2l0aHViLWNvbS1KU1cifQ --------- Co-authored-by: qued <[email protected]>
1 parent 2ee38e6 commit 326f180

File tree

5 files changed

+107
-16
lines changed

5 files changed

+107
-16
lines changed

CHANGELOG.md

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,10 @@
1-
## 0.7.10-dev1
1+
## 0.7.10-dev2
22

33
* fix: Reduce Chipper memory consumption on x86_64 cpus
44
* fix: Skips ordering elements coming from Chipper
55
* fix: After refactoring to introduce Chipper, annotate() wasn't able to show text with extra info from elements; this is fixed now.
6+
* feat: add table cell and dataframe output formats to table transformer's `run_prediction` call
7+
* breaking change: function `unstructured_inference.models.tables.recognize` no longer takes `out_html` parameter and it now only returns table cell data format (lists of dictionaries)
68

79
## 0.7.9
810

test_unstructured_inference/conftest.py

Lines changed: 45 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,11 @@
22
import pytest
33
from PIL import Image
44

5-
from unstructured_inference.inference.elements import EmbeddedTextRegion, Rectangle, TextRegion
5+
from unstructured_inference.inference.elements import (
6+
EmbeddedTextRegion,
7+
Rectangle,
8+
TextRegion,
9+
)
610
from unstructured_inference.inference.layoutelement import LayoutElement
711

812

@@ -122,3 +126,43 @@ def mock_layout(mock_embedded_text_regions):
122126
LayoutElement(text=r.text, type="UncategorizedText", bbox=r.bbox)
123127
for r in mock_embedded_text_regions
124128
]
129+
130+
131+
@pytest.fixture()
132+
def example_table_cells():
133+
cells = [
134+
{"cell text": "Disability Category", "row_nums": [0, 1], "column_nums": [0]},
135+
{"cell text": "Participants", "row_nums": [0, 1], "column_nums": [1]},
136+
{"cell text": "Ballots Completed", "row_nums": [0, 1], "column_nums": [2]},
137+
{"cell text": "Ballots Incomplete/Terminated", "row_nums": [0, 1], "column_nums": [3]},
138+
{"cell text": "Results", "row_nums": [0], "column_nums": [4, 5]},
139+
{"cell text": "Accuracy", "row_nums": [1], "column_nums": [4]},
140+
{"cell text": "Time to complete", "row_nums": [1], "column_nums": [5]},
141+
{"cell text": "Blind", "row_nums": [2], "column_nums": [0]},
142+
{"cell text": "Low Vision", "row_nums": [3], "column_nums": [0]},
143+
{"cell text": "Dexterity", "row_nums": [4], "column_nums": [0]},
144+
{"cell text": "Mobility", "row_nums": [5], "column_nums": [0]},
145+
{"cell text": "5", "row_nums": [2], "column_nums": [1]},
146+
{"cell text": "5", "row_nums": [3], "column_nums": [1]},
147+
{"cell text": "5", "row_nums": [4], "column_nums": [1]},
148+
{"cell text": "3", "row_nums": [5], "column_nums": [1]},
149+
{"cell text": "1", "row_nums": [2], "column_nums": [2]},
150+
{"cell text": "2", "row_nums": [3], "column_nums": [2]},
151+
{"cell text": "4", "row_nums": [4], "column_nums": [2]},
152+
{"cell text": "3", "row_nums": [5], "column_nums": [2]},
153+
{"cell text": "4", "row_nums": [2], "column_nums": [3]},
154+
{"cell text": "3", "row_nums": [3], "column_nums": [3]},
155+
{"cell text": "1", "row_nums": [4], "column_nums": [3]},
156+
{"cell text": "0", "row_nums": [5], "column_nums": [3]},
157+
{"cell text": "34.5%, n=1", "row_nums": [2], "column_nums": [4]},
158+
{"cell text": "98.3% n=2 (97.7%, n=3)", "row_nums": [3], "column_nums": [4]},
159+
{"cell text": "98.3%, n=4", "row_nums": [4], "column_nums": [4]},
160+
{"cell text": "95.4%, n=3", "row_nums": [5], "column_nums": [4]},
161+
{"cell text": "1199 sec, n=1", "row_nums": [2], "column_nums": [5]},
162+
{"cell text": "1716 sec, n=3 (1934 sec, n=2)", "row_nums": [3], "column_nums": [5]},
163+
{"cell text": "1672.1 sec, n=4", "row_nums": [4], "column_nums": [5]},
164+
{"cell text": "1416 sec, n=3", "row_nums": [5], "column_nums": [5]},
165+
]
166+
for i in range(len(cells)):
167+
cells[i]["column header"] = False
168+
return [cells]

test_unstructured_inference/models/test_tables.py

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -361,6 +361,54 @@ def test_table_prediction_tesseract(table_transformer, example_image):
361361
) in prediction
362362

363363

364+
@pytest.mark.parametrize(
365+
("output_format", "expectation"),
366+
[
367+
("html", "<tr><td>Blind</td><td>5</td><td>1</td><td>4</td><td>34.5%, n=1</td>"),
368+
(
369+
"cells",
370+
{
371+
"column_nums": [0],
372+
"row_nums": [2],
373+
"column header": False,
374+
"cell text": "Blind",
375+
},
376+
),
377+
("dataframe", ["Blind", "5", "1", "4", "34.5%, n=1", "1199 sec, n=1"]),
378+
(None, "<tr><td>Blind</td><td>5</td><td>1</td><td>4</td><td>34.5%, n=1</td>"),
379+
],
380+
)
381+
def test_table_prediction_output_format(
382+
output_format,
383+
expectation,
384+
table_transformer,
385+
example_image,
386+
mocker,
387+
example_table_cells,
388+
):
389+
mocker.patch.object(tables, "recognize", return_value=example_table_cells)
390+
mocker.patch.object(
391+
tables.UnstructuredTableTransformerModel,
392+
"get_structure",
393+
return_value=None,
394+
)
395+
mocker.patch.object(tables.UnstructuredTableTransformerModel, "get_tokens", return_value=None)
396+
if output_format:
397+
result = table_transformer.run_prediction(example_image, result_format=output_format)
398+
else:
399+
result = table_transformer.run_prediction(example_image)
400+
401+
if output_format == "dataframe":
402+
assert expectation in result.values
403+
elif output_format == "cells":
404+
# other outputs like bbox are flaky to test since they depend on OCR and it may change
405+
# slightly when the OCR package changes or even on different machines
406+
validation_fields = ("column_nums", "row_nums", "column header", "cell text")
407+
assert expectation in [{key: cell[key] for key in validation_fields} for cell in result]
408+
else:
409+
assert expectation in result
410+
411+
364412
def test_table_prediction_tesseract_with_ocr_tokens(table_transformer, example_image):
365413
ocr_tokens = [
366414
{
Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
__version__ = "0.7.10-dev1" # pragma: no cover
1+
__version__ = "0.7.10-dev2" # pragma: no cover

unstructured_inference/models/tables.py

Lines changed: 10 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
from unstructured_inference.constants import (
2020
TESSERACT_TEXT_HEIGHT,
2121
)
22+
from unstructured_inference.inference.layoutelement import table_cells_to_dataframe
2223
from unstructured_inference.logger import logger
2324
from unstructured_inference.models.table_postprocess import Rect
2425
from unstructured_inference.models.unstructuredmodel import UnstructuredModel
@@ -176,6 +177,7 @@ def run_prediction(
176177
x: Image,
177178
pad_for_structure_detection: int = inference_config.TABLE_IMAGE_BACKGROUND_PAD,
178179
ocr_tokens: Optional[List[Dict]] = None,
180+
result_format: Optional[str] = "html",
179181
):
180182
"""Predict table structure"""
181183
outputs_structure = self.get_structure(x, pad_for_structure_detection)
@@ -186,8 +188,12 @@ def run_prediction(
186188
)
187189
ocr_tokens = self.get_tokens(x=x)
188190

189-
html = recognize(outputs_structure, x, tokens=ocr_tokens, out_html=True)["html"]
190-
prediction = html[0] if html else ""
191+
prediction = recognize(outputs_structure, x, tokens=ocr_tokens)[0]
192+
if result_format == "html":
193+
# Convert cells to HTML
194+
prediction = cells_to_html(prediction) or ""
195+
elif result_format == "dataframe":
196+
prediction = table_cells_to_dataframe(prediction)
191197
return prediction
192198

193199

@@ -234,10 +240,8 @@ def get_class_map(data_type: str):
234240
}
235241

236242

237-
def recognize(outputs: dict, img: Image, tokens: list, out_html: bool = False):
243+
def recognize(outputs: dict, img: Image, tokens: list):
238244
"""Recognize table elements."""
239-
out_formats = {}
240-
241245
str_class_name2idx = get_class_map("structure")
242246
str_class_idx2name = {v: k for k, v in str_class_name2idx.items()}
243247
str_class_thresholds = structure_class_thresholds
@@ -248,14 +252,7 @@ def recognize(outputs: dict, img: Image, tokens: list, out_html: bool = False):
248252
# Further process the detected objects so they correspond to a consistent table
249253
tables_structure = objects_to_structures(objects, tokens, str_class_thresholds)
250254
# Enumerate all table cells: grid cells and spanning cells
251-
tables_cells = [structure_to_cells(structure, tokens)[0] for structure in tables_structure]
252-
253-
# Convert cells to HTML
254-
if out_html:
255-
tables_htmls = [cells_to_html(cells) for cells in tables_cells]
256-
out_formats["html"] = tables_htmls
257-
258-
return out_formats
255+
return [structure_to_cells(structure, tokens)[0] for structure in tables_structure]
259256

260257

261258
def outputs_to_objects(outputs, img_size, class_idx2name):

0 commit comments

Comments
 (0)