Skip to content

Commit 039c589

Browse files
committed
Merge branch 'main' into codeflash/optimize-zoom_image-metaix6e
2 parents 1cfe7e7 + 9383bac commit 039c589

File tree

2 files changed

+44
-30
lines changed

2 files changed

+44
-30
lines changed

CHANGELOG.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
## 1.0.8-dev2
22

33
* Enhancement: Optimized `zoom_image` (codeflash)
4+
* Enhancement: Optimized `cells_to_html` for an 8% speedup in some cases (codeflash)
5+
* Enhancement: Optimized `outputs_to_objects` for an 88% speedup in some cases (codeflash)
46

57
## 1.0.7
68

unstructured_inference/models/tables.py

Lines changed: 42 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -213,16 +213,18 @@ def outputs_to_objects(
213213
):
214214
"""Output table element types."""
215215
m = outputs["logits"].softmax(-1).max(-1)
216-
pred_labels = list(m.indices.detach().cpu().numpy())[0]
217-
pred_scores = list(m.values.detach().cpu().numpy())[0]
216+
pred_labels = m.indices.detach().cpu().numpy()[0]
217+
pred_scores = m.values.detach().cpu().numpy()[0]
218218
pred_bboxes = outputs["pred_boxes"].detach().cpu()[0]
219219

220220
pad = outputs.get("pad_for_structure_detection", 0)
221221
scale_size = (img_size[0] + pad * 2, img_size[1] + pad * 2)
222-
pred_bboxes = [elem.tolist() for elem in rescale_bboxes(pred_bboxes, scale_size)]
222+
rescaled = rescale_bboxes(pred_bboxes, scale_size)
223223
# unshift the padding; padding effectively shifted the bounding boxes of structures in the
224224
# original image with half of the total pad
225-
shift_size = pad
225+
if pad != 0:
226+
rescaled = rescaled - pad
227+
pred_bboxes = rescaled.tolist()
226228

227229
objects = []
228230
for label, score, bbox in zip(pred_labels, pred_scores, pred_bboxes):
@@ -232,7 +234,7 @@ def outputs_to_objects(
232234
{
233235
"label": class_label,
234236
"score": float(score),
235-
"bbox": [float(elem) - shift_size for elem in bbox],
237+
"bbox": bbox,
236238
},
237239
)
238240

@@ -279,7 +281,7 @@ def rescale_bboxes(out_bbox, size):
279281
"""Rescale relative bounding box to box of size given by size."""
280282
img_w, img_h = size
281283
b = box_cxcywh_to_xyxy(out_bbox)
282-
b = b * torch.tensor([img_w, img_h, img_w, img_h], dtype=torch.float32)
284+
b = b * torch.tensor([img_w, img_h, img_w, img_h], dtype=torch.float32, device=out_bbox.device)
283285
return b
284286

285287

@@ -688,25 +690,32 @@ def fill_cells(cells: List[dict]) -> List[dict]:
688690
if not cells:
689691
return []
690692

691-
table_rows_no = max({row for cell in cells for row in cell["row_nums"]})
692-
table_cols_no = max({col for cell in cells for col in cell["column_nums"]})
693-
filled = np.zeros((table_rows_no + 1, table_cols_no + 1), dtype=bool)
693+
# Find max row and col indices
694+
max_row = max(row for cell in cells for row in cell["row_nums"])
695+
max_col = max(col for cell in cells for col in cell["column_nums"])
696+
filled = set()
694697
for cell in cells:
695698
for row in cell["row_nums"]:
696699
for col in cell["column_nums"]:
697-
filled[row, col] = True
698-
# add cells for which filled is false
699-
header_rows = {row for cell in cells if cell["column header"] for row in cell["row_nums"]}
700+
filled.add((row, col))
701+
header_rows = set()
702+
for cell in cells:
703+
if cell["column header"]:
704+
header_rows.update(cell["row_nums"])
705+
706+
# Compose output list directly for speed
700707
new_cells = cells.copy()
701-
not_filled_idx = np.where(filled == False) # noqa: E712
702-
for row, col in zip(not_filled_idx[0], not_filled_idx[1]):
703-
new_cell = {
704-
"row_nums": [row],
705-
"column_nums": [col],
706-
"cell text": "",
707-
"column header": row in header_rows,
708-
}
709-
new_cells.append(new_cell)
708+
for row in range(max_row + 1):
709+
for col in range(max_col + 1):
710+
if (row, col) not in filled:
711+
new_cells.append(
712+
{
713+
"row_nums": [row],
714+
"column_nums": [col],
715+
"cell text": "",
716+
"column header": row in header_rows,
717+
}
718+
)
710719
return new_cells
711720

712721

@@ -725,18 +734,20 @@ def cells_to_html(cells: List[dict]) -> str:
725734
Returns:
726735
str: HTML table string
727736
"""
728-
cells = sorted(fill_cells(cells), key=lambda k: (min(k["row_nums"]), min(k["column_nums"])))
737+
# Pre-sort with tuple key, as per original
738+
cells_filled = fill_cells(cells)
739+
cells_sorted = sorted(cells_filled, key=lambda k: (min(k["row_nums"]), min(k["column_nums"])))
729740

730741
table = ET.Element("table")
731742
current_row = -1
732743

733-
table_header = None
734-
table_has_header = any(cell["column header"] for cell in cells)
735-
if table_has_header:
736-
table_header = ET.SubElement(table, "thead")
737-
744+
# Check if any column header exists
745+
table_has_header = any(cell["column header"] for cell in cells_sorted)
746+
table_header = ET.SubElement(table, "thead") if table_has_header else None
738747
table_body = ET.SubElement(table, "tbody")
739-
for cell in cells:
748+
749+
row = None
750+
for cell in cells_sorted:
740751
this_row = min(cell["row_nums"])
741752
attrib = {}
742753
colspan = len(cell["column_nums"])
@@ -754,8 +765,9 @@ def cells_to_html(cells: List[dict]) -> str:
754765
table_subelement = table_body
755766
cell_tag = "td"
756767
row = ET.SubElement(table_subelement, "tr") # type: ignore
757-
tcell = ET.SubElement(row, cell_tag, attrib=attrib)
758-
tcell.text = cell["cell text"]
768+
if row is not None:
769+
tcell = ET.SubElement(row, cell_tag, attrib=attrib)
770+
tcell.text = cell["cell text"]
759771

760772
return str(ET.tostring(table, encoding="unicode", short_empty_elements=False))
761773

0 commit comments

Comments
 (0)