Skip to content

Commit 6ed8ac2

Browse files
authored
enhancement: get model name and initialization params externally (#291)
Added a method to externally inject a supergradients ONNX model.

By setting the environment variable `UNSTRUCTURED_DEFAULT_MODEL_NAME`, one can override the default model. By setting the environment variable `UNSTRUCTURED_DEFAULT_MODEL_INITIALIZE_PARAMS_JSON_PATH`, one can specify the path to a JSON file containing model initialization parameters.

#### Testing:

```python
import os
from unstructured_inference.models.base import get_model

os.environ["UNSTRUCTURED_DEFAULT_MODEL_NAME"] = "detectron2_onnx"
model = get_model()
print(type(model))
```

Output should be `UnstructuredDetectronONNXModel` as opposed to `UnstructuredYoloXModel`.

```python
from unittest import mock
import os
import json
from unstructured_inference.models.base import get_model
from huggingface_hub import hf_hub_download

label_map = {0: "Blue", 1: "Red"}
model_path = hf_hub_download("unstructuredio/yolo_x_layout", "yolox_tiny.onnx")
json_dict = {"model_path": model_path, "label_map": label_map}
os.environ["UNSTRUCTURED_DEFAULT_MODEL_NAME"] = "yolox"
os.environ["UNSTRUCTURED_DEFAULT_MODEL_INITIALIZE_PARAMS_JSON_PATH"] = "some/fake/path.json"
with mock.patch("builtins.open", mock.mock_open(read_data=json.dumps(json_dict))):
    model = get_model()
print(model.layout_classes)
```

Output should be `{0: "Blue", 1: "Red"}` as opposed to the normal YoloX labels.
1 parent fe383c2 commit 6ed8ac2

File tree

7 files changed

+121
-68
lines changed

7 files changed

+121
-68
lines changed

CHANGELOG.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,7 @@
1+
## 0.7.15
2+
3+
* enhancement: Enable env variables for model definition
4+
15
## 0.7.14
26

37
* enhancement: Remove Super-Gradients Dependency and Allow General Onnx Models Instead

test_unstructured_inference/inference/test_layout.py

Lines changed: 34 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -145,14 +145,12 @@ def test_read_pdf(monkeypatch, mock_initial_layout, mock_final_layout, mock_imag
145145

146146
layouts = [mock_initial_layout, mock_initial_layout]
147147

148-
monkeypatch.setattr(
149-
models,
150-
"UnstructuredDetectronModel",
151-
partial(MockLayoutModel, layout=mock_final_layout),
152-
)
153148
monkeypatch.setattr(detectron2, "is_detectron2_available", lambda *args: True)
154149

155-
with patch.object(layout, "load_pdf", return_value=(layouts, image_paths)):
150+
with patch.object(layout, "load_pdf", return_value=(layouts, image_paths)), patch.dict(
151+
models.model_class_map,
152+
{"detectron2_lp": partial(MockLayoutModel, layout=mock_final_layout)},
153+
):
156154
model = layout.get_model("detectron2_lp")
157155
doc = layout.DocumentLayout.from_file("fake-file.pdf", detection_model=model)
158156

@@ -266,15 +264,26 @@ def __init__(
266264

267265
@pytest.mark.parametrize(
268266
("text", "expected"),
269-
[("base", 0.0), ("", 0.0), ("(cid:2)", 1.0), ("(cid:1)a", 0.5), ("c(cid:1)ab", 0.25)],
267+
[
268+
("base", 0.0),
269+
("", 0.0),
270+
("(cid:2)", 1.0),
271+
("(cid:1)a", 0.5),
272+
("c(cid:1)ab", 0.25),
273+
],
270274
)
271275
def test_cid_ratio(text, expected):
272276
assert elements.cid_ratio(text) == expected
273277

274278

275279
@pytest.mark.parametrize(
276280
("text", "expected"),
277-
[("base", False), ("(cid:2)", True), ("(cid:1234567890)", True), ("jkl;(cid:12)asdf", True)],
281+
[
282+
("base", False),
283+
("(cid:2)", True),
284+
("(cid:1234567890)", True),
285+
("jkl;(cid:12)asdf", True),
286+
],
278287
)
279288
def test_is_cid_present(text, expected):
280289
assert elements.is_cid_present(text) == expected
@@ -389,7 +398,11 @@ def test_page_numbers_in_page_objects():
389398
@pytest.mark.parametrize(
390399
("fixed_layouts", "called_method", "not_called_method"),
391400
[
392-
([MockLayout()], "get_elements_from_layout", "get_elements_with_detection_model"),
401+
(
402+
[MockLayout()],
403+
"get_elements_from_layout",
404+
"get_elements_with_detection_model",
405+
),
393406
(None, "get_elements_with_detection_model", "get_elements_from_layout"),
394407
],
395408
)
@@ -470,7 +483,11 @@ def test_load_pdf_raises_with_path_only_no_output_folder():
470483
def test_load_pdf_with_multicolumn_layout(filename="sample-docs/design-thinking.pdf"):
471484
layouts, images = layout.load_pdf(filename)
472485
doc = layout.process_file_with_model(filename=filename, model_name=None)
473-
test_snippets = ["Key to design thinking", "Design thinking also", "But in recent years"]
486+
test_snippets = [
487+
"Key to design thinking",
488+
"Design thinking also",
489+
"But in recent years",
490+
]
474491

475492
test_elements = []
476493
for element in doc.pages[0].elements:
@@ -590,7 +607,9 @@ def test_get_elements_using_image_extraction(mock_image, inplace, expected):
590607
assert page.get_elements_using_image_extraction(inplace=inplace) == expected
591608

592609

593-
def test_get_elements_using_image_extraction_raises_with_no_extraction_model(mock_image):
610+
def test_get_elements_using_image_extraction_raises_with_no_extraction_model(
611+
mock_image,
612+
):
594613
page = layout.PageLayout(1, mock_image, None, element_extraction_model=None)
595614
with pytest.raises(ValueError):
596615
page.get_elements_using_image_extraction()
@@ -707,7 +726,10 @@ def test_exposed_pdf_image_dpi(pdf_image_dpi, expected, monkeypatch):
707726

708727
@pytest.mark.parametrize(
709728
("filename", "img_num", "should_complete"),
710-
[("sample-docs/empty-document.pdf", 0, True), ("sample-docs/empty-document.pdf", 10, False)],
729+
[
730+
("sample-docs/empty-document.pdf", 0, True),
731+
("sample-docs/empty-document.pdf", 10, False),
732+
],
711733
)
712734
def test_get_image(filename, img_num, should_complete):
713735
doc = layout.DocumentLayout.from_file(filename)

test_unstructured_inference/models/test_model.py

Lines changed: 47 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -27,12 +27,8 @@ def predict(self, x: Any) -> Any:
2727

2828
def test_get_model(monkeypatch):
2929
monkeypatch.setattr(models, "models", {})
30-
monkeypatch.setattr(
31-
models,
32-
"UnstructuredDetectronModel",
33-
MockModel,
34-
)
35-
assert isinstance(models.get_model("checkbox"), MockModel)
30+
with mock.patch.dict(models.model_class_map, {"checkbox": MockModel}):
31+
assert isinstance(models.get_model("checkbox"), MockModel)
3632

3733

3834
def test_raises_invalid_model():
@@ -48,20 +44,15 @@ def test_raises_uninitialized():
4844
def test_model_initializes_once():
4945
from unstructured_inference.inference import layout
5046

51-
with mock.patch.object(models, "UnstructuredYoloXModel", MockModel), mock.patch.object(
52-
models,
53-
"models",
54-
{},
47+
with mock.patch.dict(models.model_class_map, {"yolox": MockModel}), mock.patch.object(
48+
models, "models", {}
5549
):
5650
doc = layout.DocumentLayout.from_file("sample-docs/loremipsum.pdf")
5751
doc.pages[0].detection_model.initializer.assert_called_once()
58-
assert hasattr(
59-
doc.pages[0].elements[0],
60-
"prob",
61-
) # NOTE(pravin) New Assertion to Make Sure Elements have probability attribute
62-
assert (
63-
doc.pages[0].elements[0].prob is None
64-
) # NOTE(pravin) New Assertion to Make Sure Uncategorized Text has None Probability
52+
# NOTE(pravin) New Assertion to Make Sure Elements have probability attribute
53+
assert hasattr(doc.pages[0].elements[0], "prob")
54+
# NOTE(pravin) New Assertion to Make Sure Uncategorized Text has None Probability
55+
assert doc.pages[0].elements[0].prob is None
6556

6657

6758
def test_deduplicate_detected_elements():
@@ -107,7 +98,12 @@ def test_enhance_regions():
10798
model = get_model("yolox_tiny")
10899
elements = model.enhance_regions(elements, 0.5)
109100
assert len(elements) == 1
110-
assert (elements[0].bbox.x1, elements[0].bbox.y1, elements[0].bbox.x2, elements[0].bbox.x2) == (
101+
assert (
102+
elements[0].bbox.x1,
103+
elements[0].bbox.y1,
104+
elements[0].bbox.x2,
105+
elements[0].bbox.x2,
106+
) == (
111107
0,
112108
0,
113109
1.10,
@@ -138,9 +134,36 @@ def test_clean_type():
138134
model = get_model("yolox_tiny")
139135
elements = model.clean_type(elements, type_to_clean="Table")
140136
assert len(elements) == 1
141-
assert (elements[0].bbox.x1, elements[0].bbox.y1, elements[0].bbox.x2, elements[0].bbox.x2) == (
142-
0,
143-
0,
144-
1,
145-
1,
146-
)
137+
assert (
138+
elements[0].bbox.x1,
139+
elements[0].bbox.y1,
140+
elements[0].bbox.x2,
141+
elements[0].bbox.x2,
142+
) == (0, 0, 1, 1)
143+
144+
145+
def test_env_variables_override_default_model(monkeypatch):
146+
# When an environment variable specifies a different default model and we call get_model with no
147+
# args, we should get back the model the env var calls for
148+
monkeypatch.setattr(models, "models", {})
149+
with mock.patch.dict(
150+
models.os.environ, {"UNSTRUCTURED_DEFAULT_MODEL_NAME": "checkbox"}
151+
), mock.patch.dict(models.model_class_map, {"checkbox": MockModel}):
152+
model = models.get_model()
153+
assert isinstance(model, MockModel)
154+
155+
156+
def test_env_variables_override_intialization_params(monkeypatch):
157+
# When initialization params are specified in an environment variable, and we call get_model, we
158+
# should see that the model was initialized with those params
159+
monkeypatch.setattr(models, "models", {})
160+
with mock.patch.dict(
161+
models.os.environ,
162+
{"UNSTRUCTURED_DEFAULT_MODEL_INITIALIZE_PARAMS_JSON_PATH": "fake_json.json"},
163+
), mock.patch.object(models, "DEFAULT_MODEL", "fake"), mock.patch.dict(
164+
models.model_class_map, {"fake": mock.MagicMock()}
165+
), mock.patch(
166+
"builtins.open", mock.mock_open(read_data='{"date": "3/26/81"}')
167+
):
168+
model = models.get_model()
169+
model.initialize.assert_called_once_with(date="3/26/81")
Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
__version__ = "0.7.14" # pragma: no cover
1+
__version__ = "0.7.15" # pragma: no cover

unstructured_inference/models/base.py

Lines changed: 31 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
1-
from typing import Dict, Optional
1+
import json
2+
import os
3+
from typing import Dict, Optional, Type
24

35
from unstructured_inference.models.chipper import MODEL_TYPES as CHIPPER_MODEL_TYPES
46
from unstructured_inference.models.chipper import UnstructuredChipperModel
@@ -29,46 +31,47 @@
2931

3032
models: Dict[str, UnstructuredModel] = {}
3133

34+
model_class_map: Dict[str, Type[UnstructuredModel]] = {
35+
**{name: UnstructuredDetectronModel for name in DETECTRON2_MODEL_TYPES},
36+
**{name: UnstructuredDetectronONNXModel for name in DETECTRON2_ONNX_MODEL_TYPES},
37+
**{name: UnstructuredYoloXModel for name in YOLOX_MODEL_TYPES},
38+
**{name: UnstructuredChipperModel for name in CHIPPER_MODEL_TYPES},
39+
"super_gradients": UnstructuredSuperGradients,
40+
}
3241

33-
def get_model(
34-
model_name: Optional[str] = None,
35-
model_path: Optional[str] = None,
36-
label_map: Optional[dict] = None,
37-
input_shape: Optional[tuple] = None,
38-
) -> UnstructuredModel:
42+
43+
def get_model(model_name: Optional[str] = None) -> UnstructuredModel:
3944
"""Gets the model object by model name."""
4045
# TODO(alan): These cases are similar enough that we can probably do them all together with
4146
# importlib
4247

4348
global models
4449

4550
if model_name is None:
46-
model_name = DEFAULT_MODEL
51+
default_name_from_env = os.environ.get("UNSTRUCTURED_DEFAULT_MODEL_NAME")
52+
model_name = default_name_from_env if default_name_from_env is not None else DEFAULT_MODEL
4753

4854
if model_name in models:
4955
return models[model_name]
5056

51-
if model_name in DETECTRON2_MODEL_TYPES:
52-
model: UnstructuredModel = UnstructuredDetectronModel()
53-
initialize_params = {**DETECTRON2_MODEL_TYPES[model_name]}
54-
elif model_name in DETECTRON2_ONNX_MODEL_TYPES:
55-
model = UnstructuredDetectronONNXModel()
56-
initialize_params = {**DETECTRON2_ONNX_MODEL_TYPES[model_name]}
57-
elif model_name in YOLOX_MODEL_TYPES:
58-
model = UnstructuredYoloXModel()
59-
initialize_params = {**YOLOX_MODEL_TYPES[model_name]}
60-
elif model_name in CHIPPER_MODEL_TYPES:
61-
model = UnstructuredChipperModel()
62-
initialize_params = {**CHIPPER_MODEL_TYPES[model_name]}
63-
elif model_name == "super_gradients":
64-
model = UnstructuredSuperGradients()
65-
initialize_params = {
66-
"model_path": model_path,
67-
"label_map": label_map,
68-
"input_shape": input_shape,
69-
}
57+
initialize_param_json = os.environ.get("UNSTRUCTURED_DEFAULT_MODEL_INITIALIZE_PARAMS_JSON_PATH")
58+
if initialize_param_json is not None:
59+
with open(initialize_param_json) as fp:
60+
initialize_params = json.load(fp)
7061
else:
71-
raise UnknownModelException(f"Unknown model type: {model_name}")
62+
if model_name in DETECTRON2_MODEL_TYPES:
63+
initialize_params = DETECTRON2_MODEL_TYPES[model_name]
64+
elif model_name in DETECTRON2_ONNX_MODEL_TYPES:
65+
initialize_params = DETECTRON2_ONNX_MODEL_TYPES[model_name]
66+
elif model_name in YOLOX_MODEL_TYPES:
67+
initialize_params = YOLOX_MODEL_TYPES[model_name]
68+
elif model_name in CHIPPER_MODEL_TYPES:
69+
initialize_params = CHIPPER_MODEL_TYPES[model_name]
70+
else:
71+
raise UnknownModelException(f"Unknown model type: {model_name}")
72+
73+
model: UnstructuredModel = model_class_map[model_name]()
74+
7275
model.initialize(**initialize_params)
7376
models[model_name] = model
7477
return model

unstructured_inference/models/chipper.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323
)
2424
from unstructured_inference.utils import LazyDict, strip_tags
2525

26-
MODEL_TYPES: Dict[Optional[str], Union[LazyDict, dict]] = {
26+
MODEL_TYPES: Dict[str, Union[LazyDict, dict]] = {
2727
"chipperv1": {
2828
"pre_trained_model_repo": "unstructuredio/ved-fine-tuning",
2929
"swap_head": False,

unstructured_inference/models/detectron2onnx.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@
3131

3232
# NOTE(alan): Entries are implemented as LazyDicts so that models aren't downloaded until they are
3333
# needed.
34-
MODEL_TYPES: Dict[Optional[str], Union[LazyDict, dict]] = {
34+
MODEL_TYPES: Dict[str, Union[LazyDict, dict]] = {
3535
"detectron2_onnx": LazyDict(
3636
model_path=LazyEvaluateInfo(
3737
hf_hub_download,
@@ -124,7 +124,8 @@ def initialize(
124124

125125
def preprocess(self, image: Image.Image) -> Dict[str, np.ndarray]:
126126
"""Process input image into required format for ingestion into the Detectron2 ONNX binary.
127-
This involves resizing to a fixed shape and converting to a specific numpy format."""
127+
This involves resizing to a fixed shape and converting to a specific numpy format.
128+
"""
128129
# TODO (benjamin): check other shapes for inference
129130
img = np.array(image)
130131
# TODO (benjamin): We should use models.get_model() but currently returns Detectron model

0 commit comments

Comments
 (0)