
Commit f23992b

TeHelado and Copilot authored
feat: Added support for per-class confidence thresholds in layout detection. (#50)
* Added support for per-class confidence thresholds in layout detection.
* Update glmocr/layout/layout_detector.py

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
1 parent 88b7808 commit f23992b

2 files changed, +80 -2 lines

glmocr/config.yaml

Lines changed: 3 additions & 1 deletion
@@ -187,6 +187,8 @@ pipeline:
   # threshold_by_class: # per-class threshold override
   #   0: 0.5
   #   1: 0.3
+  #   text: 0.5
+  #   table: 0.2
 
   # Processing
   # batch_size: max images per model forward pass (reduce to 1 if OOM)
@@ -227,7 +229,7 @@ pipeline:
     20: large # seal
     21: large # table
     22: large # text
-    23: large # text
+    23: large # vertical_text
     24: large # vision_footnote
 
   # Map detected labels to OCR task types
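As the commented example above shows, threshold_by_class entries may be keyed either by numeric class id or by label name. A minimal sketch of how such mixed keys can be normalized into an id -> threshold lookup, mirroring the resolution logic added in layout_detector.py below (the id2label subset and the unknown key are illustrative only):

    # Illustrative id2label subset; the real table comes from the model/config.
    id2label = {21: "table", 22: "text", 23: "vertical_text"}
    label2id = {name: int(cls_id) for cls_id, name in id2label.items()}

    threshold_by_class = {"table": 0.2, 22: 0.5, "not_a_class": 0.4}  # mixed keys

    class_thresholds = {}
    for key, value in threshold_by_class.items():
        if isinstance(key, str):
            if key in label2id:
                class_thresholds[label2id[key]] = float(value)
            else:
                # The real code logs a warning instead of printing.
                print(f"Unknown class name {key!r}; entry ignored.")
        else:
            class_thresholds[int(key)] = float(value)

    print(class_thresholds)  # {21: 0.2, 22: 0.5}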

glmocr/layout/layout_detector.py

Lines changed: 77 additions & 1 deletion
@@ -43,6 +43,7 @@ def __init__(self, config: "LayoutConfig"):
         self.cuda_visible_devices = config.cuda_visible_devices
 
         self.threshold = config.threshold
+        self.threshold_by_class = config.threshold_by_class
         self.layout_nms = config.layout_nms
         self.layout_unclip_ratio = config.layout_unclip_ratio
         self.layout_merge_bboxes_mode = config.layout_merge_bboxes_mode
@@ -88,6 +89,69 @@ def stop(self):
         self._device = None
         logger.debug("PP-DocLayoutV3 stopped.")
 
+    def _apply_per_class_threshold(self, raw_results: List[Dict]):
+        """Filter detections by per-class confidence thresholds.
+
+        For each detection, look up its class in threshold_by_class. Classes
+        not listed fall back to self.threshold.
+
+        Args:
+            raw_results: List of dicts from post_process_object_detection,
+                each with 'scores', 'labels', 'boxes' tensors and optional
+                'order_seq' tensor and 'polygon_points' list.
+
+        Returns:
+            Filtered list in the same format.
+        """
+        # Build mapping for label name to class id lookup.
+        label2id = {name: int(cls_id) for cls_id, name in self.id2label.items()}
+
+        # Build a lookup mapping class_id (int) -> threshold (float).
+        class_thresholds = {}
+        for key, value in self.threshold_by_class.items():
+            if isinstance(key, str):
+                if key in label2id:
+                    class_thresholds[label2id[key]] = float(value)
+                else:
+                    logger.warning(
+                        "Unknown class name '%s' in threshold_by_class; "
+                        "this entry will be ignored. Known classes: %s",
+                        key,
+                        ", ".join(sorted(label2id.keys())),
+                    )
+            else:
+                class_thresholds[int(key)] = float(value)
+
+        fallback = self.threshold
+
+        filtered = []
+        for result in raw_results:
+            scores = result["scores"]
+            labels = result["labels"]
+
+            # Build a per-detection threshold tensor: use the per-class value
+            # if defined, otherwise fall back to the global threshold.
+            thresholds = torch.full_like(scores, fallback)
+            for class_id, thresh in class_thresholds.items():
+                thresholds[labels == class_id] = thresh
+
+            keep = scores >= thresholds
+
+            new_result = {
+                "scores": scores[keep],
+                "labels": labels[keep],
+                "boxes": result["boxes"][keep],
+            }
+            if "order_seq" in result:
+                new_result["order_seq"] = result["order_seq"][keep]
+            if "polygon_points" in result:
+                keep_list = keep.tolist()
+                new_result["polygon_points"] = [
+                    p for p, k in zip(result["polygon_points"], keep_list) if k
+                ]
+            filtered.append(new_result)
+        return filtered
+
     def process(
         self,
         images: List[Image.Image],
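The core of the new filter is a per-detection threshold tensor built with torch.full_like and then overwritten per class through a boolean mask. A self-contained toy run (scores, labels, and thresholds below are made up; ids 21/22 follow the table/text comments in config.yaml):

    import torch

    scores = torch.tensor([0.45, 0.80, 0.15, 0.40])
    labels = torch.tensor([21, 22, 21, 22])   # 21 = table, 22 = text (per config comments)
    class_thresholds = {21: 0.2}              # tables allowed down to 0.2
    fallback = 0.5                            # global threshold

    thresholds = torch.full_like(scores, fallback)   # start every detection at 0.5
    for class_id, thresh in class_thresholds.items():
        thresholds[labels == class_id] = thresh      # thresholds -> [0.2, 0.5, 0.2, 0.5]

    keep = scores >= thresholds                      # [True, True, False, False]
    print(scores[keep])                              # tensor([0.4500, 0.8000])

The 0.45 table box survives because its class threshold is 0.2, while the 0.40 text box is dropped against the 0.5 fallback.
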
@@ -154,11 +218,23 @@ def process(
             except Exception as e:
                 logger.warning("Pre-filter failed (%s), continuing...", e)
 
+            if self.threshold_by_class:
+                # Use the lowest threshold (per-class or global fallback)
+                # so post-processing doesn't discard valid detections early.
+                pre_threshold = min(
+                    self.threshold, min(self.threshold_by_class.values())
+                )
+            else:
+                pre_threshold = self.threshold
+
             raw_results = self._image_processor.post_process_object_detection(
                 outputs,
-                threshold=self.threshold,
+                threshold=pre_threshold,
                 target_sizes=target_sizes,
             )
+
+            if self.threshold_by_class:
+                raw_results = self._apply_per_class_threshold(raw_results)
             img_sizes = [img.size for img in chunk_pil]
             paddle_format_results = apply_layout_postprocess(
                 raw_results=raw_results,
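The pre-threshold passed to post_process_object_detection is the minimum over the global threshold and all per-class values, so detections of a lenient class are not pruned before _apply_per_class_threshold runs. With illustrative numbers:

    threshold = 0.5
    threshold_by_class = {"table": 0.2, "text": 0.5}

    pre_threshold = min(threshold, min(threshold_by_class.values()))
    print(pre_threshold)  # 0.2

    # A table box scored 0.3 now survives post-processing and is kept by the
    # per-class pass; a text box scored 0.3 is still dropped (0.3 < 0.5).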
