@@ -213,16 +213,18 @@ def outputs_to_objects(
213213):
214214 """Output table element types."""
215215 m = outputs ["logits" ].softmax (- 1 ).max (- 1 )
216- pred_labels = list ( m .indices .detach ().cpu ().numpy () )[0 ]
217- pred_scores = list ( m .values .detach ().cpu ().numpy () )[0 ]
216+ pred_labels = m .indices .detach ().cpu ().numpy ()[0 ]
217+ pred_scores = m .values .detach ().cpu ().numpy ()[0 ]
218218 pred_bboxes = outputs ["pred_boxes" ].detach ().cpu ()[0 ]
219219
220220 pad = outputs .get ("pad_for_structure_detection" , 0 )
221221 scale_size = (img_size [0 ] + pad * 2 , img_size [1 ] + pad * 2 )
222- pred_bboxes = [ elem . tolist () for elem in rescale_bboxes (pred_bboxes , scale_size )]
222+ rescaled = rescale_bboxes (pred_bboxes , scale_size )
223223 # unshift the padding; padding effectively shifted the bounding boxes of structures in the
224224 # original image with half of the total pad
225- shift_size = pad
225+ if pad != 0 :
226+ rescaled = rescaled - pad
227+ pred_bboxes = rescaled .tolist ()
226228
227229 objects = []
228230 for label , score , bbox in zip (pred_labels , pred_scores , pred_bboxes ):
@@ -232,7 +234,7 @@ def outputs_to_objects(
232234 {
233235 "label" : class_label ,
234236 "score" : float (score ),
235- "bbox" : [ float ( elem ) - shift_size for elem in bbox ] ,
237+ "bbox" : bbox ,
236238 },
237239 )
238240
@@ -279,7 +281,7 @@ def rescale_bboxes(out_bbox, size):
279281 """Rescale relative bounding box to box of size given by size."""
280282 img_w , img_h = size
281283 b = box_cxcywh_to_xyxy (out_bbox )
282- b = b * torch .tensor ([img_w , img_h , img_w , img_h ], dtype = torch .float32 )
284+ b = b * torch .tensor ([img_w , img_h , img_w , img_h ], dtype = torch .float32 , device = out_bbox . device )
283285 return b
284286
285287
@@ -688,25 +690,32 @@ def fill_cells(cells: List[dict]) -> List[dict]:
688690 if not cells :
689691 return []
690692
691- table_rows_no = max ({row for cell in cells for row in cell ["row_nums" ]})
692- table_cols_no = max ({col for cell in cells for col in cell ["column_nums" ]})
693- filled = np .zeros ((table_rows_no + 1 , table_cols_no + 1 ), dtype = bool )
693+ # Find max row and col indices
694+ max_row = max (row for cell in cells for row in cell ["row_nums" ])
695+ max_col = max (col for cell in cells for col in cell ["column_nums" ])
696+ filled = set ()
694697 for cell in cells :
695698 for row in cell ["row_nums" ]:
696699 for col in cell ["column_nums" ]:
697- filled [row , col ] = True
698- # add cells for which filled is false
699- header_rows = {row for cell in cells if cell ["column header" ] for row in cell ["row_nums" ]}
700+ filled .add ((row , col ))
701+ header_rows = set ()
702+ for cell in cells :
703+ if cell ["column header" ]:
704+ header_rows .update (cell ["row_nums" ])
705+
706+ # Compose output list directly for speed
700707 new_cells = cells .copy ()
701- not_filled_idx = np .where (filled == False ) # noqa: E712
702- for row , col in zip (not_filled_idx [0 ], not_filled_idx [1 ]):
703- new_cell = {
704- "row_nums" : [row ],
705- "column_nums" : [col ],
706- "cell text" : "" ,
707- "column header" : row in header_rows ,
708- }
709- new_cells .append (new_cell )
708+ for row in range (max_row + 1 ):
709+ for col in range (max_col + 1 ):
710+ if (row , col ) not in filled :
711+ new_cells .append (
712+ {
713+ "row_nums" : [row ],
714+ "column_nums" : [col ],
715+ "cell text" : "" ,
716+ "column header" : row in header_rows ,
717+ }
718+ )
710719 return new_cells
711720
712721
@@ -725,18 +734,20 @@ def cells_to_html(cells: List[dict]) -> str:
725734 Returns:
726735 str: HTML table string
727736 """
728- cells = sorted (fill_cells (cells ), key = lambda k : (min (k ["row_nums" ]), min (k ["column_nums" ])))
737+ # Pre-sort with tuple key, as per original
738+ cells_filled = fill_cells (cells )
739+ cells_sorted = sorted (cells_filled , key = lambda k : (min (k ["row_nums" ]), min (k ["column_nums" ])))
729740
730741 table = ET .Element ("table" )
731742 current_row = - 1
732743
733- table_header = None
734- table_has_header = any (cell ["column header" ] for cell in cells )
735- if table_has_header :
736- table_header = ET .SubElement (table , "thead" )
737-
744+ # Check if any column header exists
745+ table_has_header = any (cell ["column header" ] for cell in cells_sorted )
746+ table_header = ET .SubElement (table , "thead" ) if table_has_header else None
738747 table_body = ET .SubElement (table , "tbody" )
739- for cell in cells :
748+
749+ row = None
750+ for cell in cells_sorted :
740751 this_row = min (cell ["row_nums" ])
741752 attrib = {}
742753 colspan = len (cell ["column_nums" ])
@@ -754,8 +765,9 @@ def cells_to_html(cells: List[dict]) -> str:
754765 table_subelement = table_body
755766 cell_tag = "td"
756767 row = ET .SubElement (table_subelement , "tr" ) # type: ignore
757- tcell = ET .SubElement (row , cell_tag , attrib = attrib )
758- tcell .text = cell ["cell text" ]
768+ if row is not None :
769+ tcell = ET .SubElement (row , cell_tag , attrib = attrib )
770+ tcell .text = cell ["cell text" ]
759771
760772 return str (ET .tostring (table , encoding = "unicode" , short_empty_elements = False ))
761773
0 commit comments