@@ -43,6 +43,7 @@ def __init__(self, config: "LayoutConfig"):
4343 self .cuda_visible_devices = config .cuda_visible_devices
4444
4545 self .threshold = config .threshold
46+ self .threshold_by_class = config .threshold_by_class
4647 self .layout_nms = config .layout_nms
4748 self .layout_unclip_ratio = config .layout_unclip_ratio
4849 self .layout_merge_bboxes_mode = config .layout_merge_bboxes_mode
@@ -88,6 +89,69 @@ def stop(self):
8889 self ._device = None
8990 logger .debug ("PP-DocLayoutV3 stopped." )
9091
92+ def _apply_per_class_threshold (self , raw_results : List [Dict ]):
93+ """Filter detections by per-class confidence thresholds.
94+
95+ For each detection, look up its class in threshold_by_class. Classes
96+ not listed fall back to self.threshold.
97+
98+ Args:
99+ raw_results: List of dicts from post_process_object_detection,
100+ each with 'scores', 'labels', 'boxes' tensors and optional
101+ 'order_seq' tensor and 'polygon_points' list.
102+
103+ Returns:
104+ Filtered list in the same format.
105+ """
106+ # Build mapping for label name to class id lookup.
107+ label2id = {name : int (cls_id ) for cls_id , name in self .id2label .items ()}
108+
109+ # Build a lookup mapping class_id (int) -> threshold (float).
110+ class_thresholds = {}
111+ for key , value in self .threshold_by_class .items ():
112+ if isinstance (key , str ):
113+ if key in label2id :
114+ class_thresholds [label2id [key ]] = float (value )
115+ else :
116+ logger .warning (
117+ "Unknown class name '%s' in threshold_by_class; "
118+ "this entry will be ignored. Known classes: %s" ,
119+ key ,
120+ ", " .join (sorted (label2id .keys ())),
121+ )
122+ else :
123+ class_thresholds [int (key )] = float (value )
124+
125+ fallback = self .threshold
126+
127+ filtered = []
128+ for result in raw_results :
129+ scores = result ["scores" ]
130+ labels = result ["labels" ]
131+
132+ # Build a per-detection threshold tensor: use the per-class value
133+ # if defined, otherwise fall back to the global threshold.
134+ thresholds = torch .full_like (scores , fallback )
135+ for class_id , thresh in class_thresholds .items ():
136+ thresholds [labels == class_id ] = thresh
137+
138+ keep = scores >= thresholds
139+
140+ new_result = {
141+ "scores" : scores [keep ],
142+ "labels" : labels [keep ],
143+ "boxes" : result ["boxes" ][keep ],
144+ }
145+ if "order_seq" in result :
146+ new_result ["order_seq" ] = result ["order_seq" ][keep ]
147+ if "polygon_points" in result :
148+ keep_list = keep .tolist ()
149+ new_result ["polygon_points" ] = [
150+ p for p , k in zip (result ["polygon_points" ], keep_list ) if k
151+ ]
152+ filtered .append (new_result )
153+ return filtered
154+
91155 def process (
92156 self ,
93157 images : List [Image .Image ],
@@ -154,11 +218,23 @@ def process(
154218 except Exception as e :
155219 logger .warning ("Pre-filter failed (%s), continuing..." , e )
156220
221+ if self .threshold_by_class :
222+ # Use the lowest threshold (per-class or global fallback)
223+ # so post-processing doesn't discard valid detections early.
224+ pre_threshold = min (
225+ self .threshold , min (self .threshold_by_class .values ())
226+ )
227+ else :
228+ pre_threshold = self .threshold
229+
157230 raw_results = self ._image_processor .post_process_object_detection (
158231 outputs ,
159- threshold = self . threshold ,
232+ threshold = pre_threshold ,
160233 target_sizes = target_sizes ,
161234 )
235+
236+ if self .threshold_by_class :
237+ raw_results = self ._apply_per_class_threshold (raw_results )
162238 img_sizes = [img .size for img in chunk_pil ]
163239 paddle_format_results = apply_layout_postprocess (
164240 raw_results = raw_results ,
0 commit comments