Skip to content

Commit 9d03085

Browse files
fix: linknet hyperparameters postprocessing + demo for rotation model (#865)
* fix: linknet parameters * feat: add demo rotation * feat: add rotation in demo
1 parent 9878d03 commit 9d03085

File tree

3 files changed

+21
-16
lines changed

3 files changed

+21
-16
lines changed

demo/app.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
from doctr.models import ocr_predictor
2222
from doctr.utils.visualization import visualize_page
2323

24-
DET_ARCHS = ["db_resnet50", "db_mobilenet_v3_large"]
24+
DET_ARCHS = ["db_resnet50", "db_mobilenet_v3_large", "linknet_resnet18_rotation"]
2525
RECO_ARCHS = ["crnn_vgg16_bn", "crnn_mobilenet_v3_small", "master", "sar_resnet31"]
2626

2727

@@ -73,7 +73,10 @@ def main():
7373

7474
else:
7575
with st.spinner('Loading model...'):
76-
predictor = ocr_predictor(det_arch, reco_arch, pretrained=True)
76+
predictor = ocr_predictor(
77+
det_arch, reco_arch, pretrained=True,
78+
assume_straight_pages=(det_arch != "linknet_resnet18_rotation")
79+
)
7780

7881
with st.spinner('Analyzing...'):
7982

@@ -97,8 +100,9 @@ def main():
97100

98101
# Page reconsitution under input page
99102
page_export = out.pages[0].export()
100-
img = out.pages[0].synthesize()
101-
cols[3].image(img, clamp=True)
103+
if det_arch != "linknet_resnet18_rotation":
104+
img = out.pages[0].synthesize()
105+
cols[3].image(img, clamp=True)
102106

103107
# Display JSON
104108
st.markdown("\nHere are your analysis results in JSON format:")

doctr/models/detection/linknet/base.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ class LinkNetPostProcessor(DetectionPostProcessor):
3030
"""
3131
def __init__(
3232
self,
33-
bin_thresh: float = 0.5,
33+
bin_thresh: float = 0.1,
3434
box_thresh: float = 0.1,
3535
assume_straight_pages: bool = True,
3636
) -> None:
@@ -39,7 +39,7 @@ def __init__(
3939
bin_thresh,
4040
assume_straight_pages
4141
)
42-
self.unclip_ratio = 1.5
42+
self.unclip_ratio = 1.2
4343

4444
def polygon_to_box(
4545
self,
@@ -103,13 +103,12 @@ def bitmap_to_boxes(
103103
containing x, y, w, h, alpha, score for the box
104104
"""
105105
height, width = bitmap.shape[:2]
106-
min_size_box = 1 + int(height / 512)
107106
boxes = []
108107
# get contours from connected components on the bitmap
109108
contours, _ = cv2.findContours(bitmap.astype(np.uint8), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
110109
for contour in contours:
111110
# Check whether smallest enclosing bounding box is not too small
112-
if np.any(contour[:, 0].max(axis=0) - contour[:, 0].min(axis=0) < min_size_box):
111+
if np.any(contour[:, 0].max(axis=0) - contour[:, 0].min(axis=0) < 2):
113112
continue
114113
# Compute objectness
115114
if self.assume_straight_pages:

doctr/utils/visualization.py

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -218,13 +218,16 @@ def visualize_page(
218218
int(page['dimensions'][1] * word['geometry'][0][0]),
219219
int(page['dimensions'][0] * word['geometry'][0][1])
220220
)
221-
ax.text(
222-
*text_loc,
223-
word['value'],
224-
size=10,
225-
alpha=0.5,
226-
color=(0, 0, 1),
227-
)
221+
222+
if len(word['geometry']) == 2:
223+
# We draw only if boxes are in straight format
224+
ax.text(
225+
*text_loc,
226+
word['value'],
227+
size=10,
228+
alpha=0.5,
229+
color=(0, 0, 1),
230+
)
228231

229232
if display_artefacts:
230233
for artefact in block['artefacts']:
@@ -251,7 +254,6 @@ def visualize_page(
251254
def synthesize_page(
252255
page: Dict[str, Any],
253256
draw_proba: bool = False,
254-
font_size: int = 13,
255257
font_family: Optional[str] = None,
256258
) -> np.ndarray:
257259
"""Draw a the content of the element page (OCR response) on a blank page.

0 commit comments

Comments
 (0)