-
Notifications
You must be signed in to change notification settings - Fork 18
/
Copy pathtest_document_figure_classifier.py
98 lines (79 loc) · 2.54 KB
/
test_document_figure_classifier.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
#
# Copyright IBM Corp. 2024 - 2024
# SPDX-License-Identifier: MIT
#
import os
import numpy as np
import pytest
from PIL import Image
from docling_ibm_models.document_figure_classifier_model.document_figure_classifier_predictor import (
DocumentFigureClassifierPredictor,
)
from huggingface_hub import snapshot_download
@pytest.fixture(scope="module")
def init() -> dict:
    r"""
    Build the configuration shared by the tests of this module.

    The returned dict carries the thread count, the labelled test images,
    the device info, and ``artifact_path``: the value returned by
    ``snapshot_download`` for the pinned model revision.
    """
    # Fetch the model artifacts from HF
    artifact_path = snapshot_download(
        repo_id="ds4sd/DocumentFigureClassifier", revision="v1.0.1"
    )

    # Each entry pairs an expected label with the image that should yield it
    test_imgs = [
        {
            "label": "bar_chart",
            "image_path": "tests/test_data/figure_classifier/images/bar_chart.jpg",
        },
        {
            "label": "map",
            "image_path": "tests/test_data/figure_classifier/images/map.jpg",
        },
    ]

    return {
        "num_threads": 1,
        "test_imgs": test_imgs,
        "info": {
            "device": "auto",
        },
        "artifact_path": artifact_path,
    }
def test_figure_classifier(init: dict):
    r"""
    Unit test for the DocumentFigureClassifierPredictor

    Checks that the predictor reports the requested device and number of
    threads, raises a TypeError for an unsupported input, and predicts the
    expected label for every test image: one at a time (as PIL image and as
    numpy array) and all together as a single batch.

    Parameters
    ----------
    init : dict
        Fixture data providing "artifact_path" and the "test_imgs" entries,
        each with a "label" and an "image_path".
    """
    # Function-scope import so the block does not depend on file-level changes
    from contextlib import ExitStack

    device = "cpu"
    num_threads = 2

    # Initialize DocumentFigureClassifierPredictor
    figure_classifier = DocumentFigureClassifierPredictor(
        init["artifact_path"], device=device, num_threads=num_threads
    )

    # Check info
    info = figure_classifier.info()
    assert info["device"] == device, "Wronly set device"
    assert info["num_threads"] == num_threads, "Wronly set number of threads"

    # Unsupported input image: iterate the result so the error surfaces even
    # if predict() evaluates lazily
    with pytest.raises(TypeError):
        for _ in figure_classifier.predict(["wrong"]):
            pass

    # Predict on test images, not batched
    for d in init["test_imgs"]:
        label = d["label"]
        img_path = d["image_path"]

        with Image.open(img_path) as img:
            output = figure_classifier.predict([img])
            predicted_class = output[0][0][0]
            assert predicted_class == label

            # Load images as numpy arrays
            np_arr = np.asarray(img)
            output = figure_classifier.predict([np_arr])
            predicted_class = output[0][0][0]
            assert predicted_class == label

    # Predict on test images, batched
    labels = [d["label"] for d in init["test_imgs"]]

    # ExitStack closes every opened image, even if a later open or the
    # prediction itself fails, so no file handles are leaked
    with ExitStack() as stack:
        images = [
            stack.enter_context(Image.open(d["image_path"]))
            for d in init["test_imgs"]
        ]
        # Consume the predictions while the images are still open
        outputs = [output[0][0] for output in figure_classifier.predict(images)]

    assert outputs == labels