Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
## 1.0.9

* Make object detection (OD) model loading thread-safe

## 1.0.8-dev2

* Enhancement: Optimized `zoom_image` (codeflash)
Expand Down
44 changes: 44 additions & 0 deletions test_unstructured_inference/models/test_model.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import json
import threading
from typing import Any
from unittest import mock

Expand Down Expand Up @@ -40,6 +41,49 @@ def test_get_model(monkeypatch):
assert isinstance(models.get_model("yolox"), MockModel)


def test_get_model_threaded(monkeypatch):
    """Check that concurrent ``get_model`` calls succeed and share one model instance.

    Several threads request the same model name at the same moment. The test
    passes only if no thread raises, every thread receives a ``MockModel``,
    and all threads receive the *same* object, i.e. the model was constructed
    and cached once rather than once per racing thread.
    """
    # Start from an empty cache so the first get_model call has to build the model.
    monkeypatch.setattr(models, "models", {})

    num_threads = 10

    # Outcomes reported by the workers; list.append is thread-safe in CPython.
    results = []
    exceptions = []

    # Hold every worker at the barrier until all are running, so the get_model
    # calls genuinely overlap instead of executing one after another.
    start_barrier = threading.Barrier(num_threads)

    def get_model_worker(thread_id):
        """Fetch the model from one thread and record the result or the error."""
        try:
            # A timeout keeps the test from hanging if a sibling never arrives;
            # the resulting BrokenBarrierError is recorded like any other failure.
            start_barrier.wait(timeout=10)
            model = models.get_model("yolox")
            results.append((thread_id, model))
        except Exception as e:
            exceptions.append((thread_id, e))

    # Keep the class-map patch active until every worker has finished.
    with mock.patch.dict(models.model_class_map, {"yolox": MockModel}):
        threads = [
            threading.Thread(target=get_model_worker, args=(i,)) for i in range(num_threads)
        ]
        for thread in threads:
            thread.start()

        # Wait for all threads to complete
        for thread in threads:
            thread.join()

    # Verify no exceptions occurred
    assert len(exceptions) == 0, f"Exceptions occurred in threads: {exceptions}"

    # Verify all threads got results
    assert len(results) == num_threads, f"Expected {num_threads} results, got {len(results)}"

    # Verify all results are MockModel instances
    for thread_id, model in results:
        assert isinstance(
            model, MockModel
        ), f"Thread {thread_id} got unexpected model type: {type(model)}"

    # Verify the model was created once and shared by every thread
    distinct_instances = {id(model) for _, model in results}
    assert (
        len(distinct_instances) == 1
    ), f"Expected a single shared model instance, got {len(distinct_instances)}"


def test_register_new_model():
assert "foo" not in models.model_class_map
assert "foo" not in models.model_config_map
Expand Down
2 changes: 1 addition & 1 deletion unstructured_inference/__version__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = "1.0.8-dev2" # pragma: no cover
__version__ = "1.0.9" # pragma: no cover
38 changes: 23 additions & 15 deletions unstructured_inference/models/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,8 @@ def __setitem__(self, key: str, value: UnstructuredModel):

models: Models = Models()

models_lock = threading.Lock()


def get_default_model_mappings() -> Tuple[
Dict[str, Type[UnstructuredModel]],
Expand Down Expand Up @@ -78,24 +80,30 @@ def get_model(model_name: Optional[str] = None) -> UnstructuredModel:
if model_name in models:
return models[model_name]

initialize_param_json = os.environ.get("UNSTRUCTURED_DEFAULT_MODEL_INITIALIZE_PARAMS_JSON_PATH")
if initialize_param_json is not None:
with open(initialize_param_json) as fp:
initialize_params = json.load(fp)
label_map_int_keys = {
int(key): value for key, value in initialize_params["label_map"].items()
}
initialize_params["label_map"] = label_map_int_keys
else:
if model_name in model_config_map:
initialize_params = model_config_map[model_name]
with models_lock:
if model_name in models:
return models[model_name]

initialize_param_json = os.environ.get(
"UNSTRUCTURED_DEFAULT_MODEL_INITIALIZE_PARAMS_JSON_PATH"
)
if initialize_param_json is not None:
with open(initialize_param_json) as fp:
initialize_params = json.load(fp)
label_map_int_keys = {
int(key): value for key, value in initialize_params["label_map"].items()
}
initialize_params["label_map"] = label_map_int_keys
else:
raise UnknownModelException(f"Unknown model type: {model_name}")
if model_name in model_config_map:
initialize_params = model_config_map[model_name]
else:
raise UnknownModelException(f"Unknown model type: {model_name}")

model: UnstructuredModel = model_class_map[model_name]()
model: UnstructuredModel = model_class_map[model_name]()

model.initialize(**initialize_params)
models[model_name] = model
model.initialize(**initialize_params)
models[model_name] = model
return model


Expand Down