Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion plugin/plugins/galgame_plugin/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -586,7 +586,7 @@ class GalgameRapidOcrConfig:
rapidocr_engine_type: str = "onnxruntime"
rapidocr_lang_type: str = "ch"
rapidocr_model_type: str = "mobile"
rapidocr_ocr_version: str = "PP-OCRv5"
rapidocr_ocr_version: str = "PP-OCRv4"


@dataclass(slots=True, init=False)
Expand Down
2 changes: 1 addition & 1 deletion plugin/plugins/galgame_plugin/plugin.toml
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,7 @@ install_timeout_seconds = 180
engine_type = "onnxruntime"
lang_type = "ch"
model_type = "mobile"
ocr_version = "PP-OCRv5"
ocr_version = "PP-OCRv4"

[plugin.i18n]
default_locale = "zh-CN"
Expand Down
78 changes: 77 additions & 1 deletion plugin/plugins/galgame_plugin/rapidocr_support.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
DEFAULT_RAPIDOCR_ENGINE_TYPE = "onnxruntime"
DEFAULT_RAPIDOCR_LANG_TYPE = "ch"
DEFAULT_RAPIDOCR_MODEL_TYPE = "mobile"
DEFAULT_RAPIDOCR_OCR_VERSION = "PP-OCRv5"
DEFAULT_RAPIDOCR_OCR_VERSION = "PP-OCRv4"
_INSTALL_STATE_NAME = "install_state.json"
# Leave one core free for the OS / interactive use; floor at 2 so 1-2 core hosts still parallelise.
_RAPIDOCR_INFERENCE_THREAD_LIMIT = max(2, (os.cpu_count() or 2) - 1)
Expand Down Expand Up @@ -99,6 +99,50 @@ def rapidocr_selected_model_name(
)


def _resolve_rapidocr_model_paths(
*,
model_cache_dir: Path,
package_models_dir: Path | None,
lang_type: str,
ocr_version: str,
model_type: str,
) -> tuple[str | None, str | None, str | None]:
lang = str(lang_type or DEFAULT_RAPIDOCR_LANG_TYPE).strip() or DEFAULT_RAPIDOCR_LANG_TYPE
version = str(ocr_version or DEFAULT_RAPIDOCR_OCR_VERSION).strip() or DEFAULT_RAPIDOCR_OCR_VERSION
# RapidOCR / PaddleOCR file naming (see SWHL/RapidOCR on HuggingFace):
# mobile: ch_PP-OCRv4_det_infer.onnx
# server: ch_PP-OCRv4_det_server_infer.onnx (i.e. "_server" sits between "_det" and "_infer")
# Anything other than "server" falls back to mobile silently — invalid values
# would otherwise miss every candidate file and the runtime default mobile model
# would still load via RapidOCR's bundled config.yaml.
mt = (str(model_type or DEFAULT_RAPIDOCR_MODEL_TYPE).strip() or DEFAULT_RAPIDOCR_MODEL_TYPE).lower()
server_infix = "_server" if mt == "server" else ""
det_name = f"{lang}_{version}_det{server_infix}_infer.onnx"
rec_name = f"{lang}_{version}_rec{server_infix}_infer.onnx"
# cls is shared across mobile/server variants in PaddleOCR's released model zoo.
cls_name = "ch_ppocr_mobile_v2.0_cls_infer.onnx"

det_path: str | None = None
cls_path: str | None = None
rec_path: str | None = None
for search_dir in (model_cache_dir, package_models_dir):
if not search_dir or not search_dir.is_dir():
continue
if det_path is None:
candidate = search_dir / det_name
if candidate.is_file():
det_path = str(candidate)
if cls_path is None:
candidate = search_dir / cls_name
if candidate.is_file():
cls_path = str(candidate)
if rec_path is None:
candidate = search_dir / rec_name
if candidate.is_file():
rec_path = str(candidate)
return det_path, cls_path, rec_path
Comment thread
coderabbitai[bot] marked this conversation as resolved.


@contextmanager
def _rapidocr_import_context(
*,
Expand Down Expand Up @@ -177,11 +221,35 @@ def _build_runtime_constructor_kwargs(
model_type: str,
ocr_version: str,
model_cache_dir: Path,
package_models_dir: Path | None = None,
) -> dict[str, Any]:
try:
parameters = inspect.signature(runtime_class).parameters
except (TypeError, ValueError):
return {}

has_var_kwargs = any(
parameter.kind == inspect.Parameter.VAR_KEYWORD
for parameter in parameters.values()
)
if has_var_kwargs:
Comment thread
wehos marked this conversation as resolved.
det_path, cls_path, rec_path = _resolve_rapidocr_model_paths(
model_cache_dir=model_cache_dir,
package_models_dir=package_models_dir,
lang_type=lang_type,
ocr_version=ocr_version,
model_type=model_type,
)
kwargs: dict[str, Any] = {}
if det_path and rec_path:
kwargs["det_model_path"] = det_path
kwargs["rec_model_path"] = rec_path
if cls_path:
kwargs["cls_model_path"] = cls_path
if engine_type:
kwargs["engine_type"] = engine_type
return kwargs

kwargs: dict[str, Any] = {}
direct_values = {
"engine_type": engine_type,
Expand Down Expand Up @@ -273,6 +341,13 @@ def load_rapidocr_runtime(
runtime_class = getattr(module, "RapidOCR", None)
if runtime_class is None:
raise RuntimeError("RapidOCR runtime class not found")
module_file = getattr(module, "__file__", "") or ""
# Sentinel must be None (not Path()) — Path() resolves to CWD and would
# let _resolve_rapidocr_model_paths inadvertently scan the working
# directory if `__file__` were ever missing.
package_models_dir: Path | None = (
Path(module_file).resolve().parent / "models" if module_file else None
)
with _onnxruntime_intra_op_thread_cap(_RAPIDOCR_INFERENCE_THREAD_LIMIT):
runtime = runtime_class(
**_build_runtime_constructor_kwargs(
Expand All @@ -282,6 +357,7 @@ def load_rapidocr_runtime(
model_type=model_type,
ocr_version=ocr_version,
model_cache_dir=model_cache_dir,
package_models_dir=package_models_dir,
)
)
metadata = {
Expand Down
3 changes: 2 additions & 1 deletion plugin/plugins/galgame_plugin/test_ocr_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
TesseractOcrBackend,
_default_window_scanner,
)
from plugin.plugins.galgame_plugin.rapidocr_support import DEFAULT_RAPIDOCR_OCR_VERSION


def _noop_logger():
Expand Down Expand Up @@ -159,7 +160,7 @@ async def main() -> None:
rapidocr_engine_type="onnxruntime",
rapidocr_lang_type="ch",
rapidocr_model_type="mobile",
rapidocr_ocr_version="PP-OCRv5",
rapidocr_ocr_version=DEFAULT_RAPIDOCR_OCR_VERSION,
)
mgr = OcrReaderManager(logger=_noop_logger(), config=config)
tick = await mgr.tick(bridge_sdk_available=False, memory_reader_runtime={})
Expand Down
164 changes: 164 additions & 0 deletions plugin/tests/unit/plugins/test_galgame_rapidocr_support.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,164 @@
from __future__ import annotations

from contextlib import nullcontext
from pathlib import Path
from types import SimpleNamespace

import pytest

from plugin.plugins.galgame_plugin import rapidocr_support


# Apply the ``plugin_unit`` marker to every test in this module.
pytestmark = pytest.mark.plugin_unit


class _RapidOcrWithKwargs:
captured_kwargs: dict[str, object] | None = None

def __init__(self, config_path=None, **kwargs) -> None:
del config_path
type(self).captured_kwargs = dict(kwargs)


def _touch(path: Path) -> Path:
    """Create an empty file at ``path`` (making parent directories) and return ``path``.

    An existing file at ``path`` is overwritten with empty content.
    """
    path.parent.mkdir(parents=True, exist_ok=True)
    path.write_text("", encoding="utf-8")
    return path


def test_rapidocr_kwargs_resolve_configured_model_paths(tmp_path: Path) -> None:
    """Models found only in the package directory are passed as explicit paths."""
    # The cache directory is never created, so only the package directory can match.
    cache_dir = tmp_path / "RapidOCR" / "models"
    bundled_dir = tmp_path / "package" / "models"
    expected = {
        "det_model_path": str(_touch(bundled_dir / "ch_PP-OCRv4_det_infer.onnx")),
        "cls_model_path": str(_touch(bundled_dir / "ch_ppocr_mobile_v2.0_cls_infer.onnx")),
        "rec_model_path": str(_touch(bundled_dir / "ch_PP-OCRv4_rec_infer.onnx")),
        "engine_type": "onnxruntime",
    }

    actual = rapidocr_support._build_runtime_constructor_kwargs(
        _RapidOcrWithKwargs,
        engine_type="onnxruntime",
        lang_type="ch",
        model_type="mobile",
        ocr_version="PP-OCRv4",
        model_cache_dir=cache_dir,
        package_models_dir=bundled_dir,
    )

    assert actual == expected


def test_rapidocr_kwargs_prefers_user_model_cache(tmp_path: Path) -> None:
    """Files in the user model cache win over same-named files in the package directory.

    The cls model exists only in the package directory, so it is still taken
    from there while det/rec come from the cache.
    """
    model_cache_dir = tmp_path / "RapidOCR" / "models"
    package_models_dir = tmp_path / "package" / "models"
    user_det_path = _touch(model_cache_dir / "japan_PP-OCRv4_det_infer.onnx")
    user_rec_path = _touch(model_cache_dir / "japan_PP-OCRv4_rec_infer.onnx")
    package_cls_path = _touch(package_models_dir / "ch_ppocr_mobile_v2.0_cls_infer.onnx")
    # Same det/rec names in the package directory: these must NOT be selected.
    _touch(package_models_dir / "japan_PP-OCRv4_det_infer.onnx")
    _touch(package_models_dir / "japan_PP-OCRv4_rec_infer.onnx")

    kwargs = rapidocr_support._build_runtime_constructor_kwargs(
        _RapidOcrWithKwargs,
        engine_type="onnxruntime",
        lang_type="japan",
        model_type="mobile",
        ocr_version="PP-OCRv4",
        model_cache_dir=model_cache_dir,
        package_models_dir=package_models_dir,
    )

    # Compare the whole dict (as the sibling tests do) so unexpected extra
    # keys or a dropped engine_type are caught as well.
    assert kwargs == {
        "det_model_path": str(user_det_path),
        "rec_model_path": str(user_rec_path),
        "cls_model_path": str(package_cls_path),
        "engine_type": "onnxruntime",
    }


def test_rapidocr_kwargs_resolves_server_variant_filenames(tmp_path: Path) -> None:
    """``model_type="server"`` selects the ``*_server_infer.onnx`` det/rec files."""
    cache_dir = tmp_path / "RapidOCR" / "models"
    bundled_dir = tmp_path / "package" / "models"
    expected = {
        "det_model_path": str(_touch(cache_dir / "ch_PP-OCRv4_det_server_infer.onnx")),
        "rec_model_path": str(_touch(cache_dir / "ch_PP-OCRv4_rec_server_infer.onnx")),
        "cls_model_path": str(_touch(bundled_dir / "ch_ppocr_mobile_v2.0_cls_infer.onnx")),
        "engine_type": "onnxruntime",
    }
    # Mobile variants exist alongside server ones to ensure model_type drives selection.
    for mobile_name in ("ch_PP-OCRv4_det_infer.onnx", "ch_PP-OCRv4_rec_infer.onnx"):
        _touch(bundled_dir / mobile_name)

    actual = rapidocr_support._build_runtime_constructor_kwargs(
        _RapidOcrWithKwargs,
        engine_type="onnxruntime",
        lang_type="ch",
        model_type="server",
        ocr_version="PP-OCRv4",
        model_cache_dir=cache_dir,
        package_models_dir=bundled_dir,
    )

    assert actual == expected


def test_rapidocr_kwargs_omits_model_paths_when_configured_model_is_missing(tmp_path: Path) -> None:
    """No model paths are passed when the requested OCR version has no files on disk."""
    cache_dir = tmp_path / "RapidOCR" / "models"
    bundled_dir = tmp_path / "package" / "models"
    # Only PP-OCRv4 files are present; the call below asks for PP-OCRv5.
    for present_name in (
        "ch_PP-OCRv4_det_infer.onnx",
        "ch_ppocr_mobile_v2.0_cls_infer.onnx",
        "ch_PP-OCRv4_rec_infer.onnx",
    ):
        _touch(bundled_dir / present_name)

    actual = rapidocr_support._build_runtime_constructor_kwargs(
        _RapidOcrWithKwargs,
        engine_type="onnxruntime",
        lang_type="ch",
        model_type="mobile",
        ocr_version="PP-OCRv5",
        model_cache_dir=cache_dir,
        package_models_dir=bundled_dir,
    )

    assert actual == {"engine_type": "onnxruntime"}


def test_load_rapidocr_runtime_uses_imported_package_models_dir(
    tmp_path: Path,
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """``load_rapidocr_runtime`` finds models in ``<imported package dir>/models``.

    The imported module is faked with a ``__file__`` pointing at a temporary
    package directory, and the thread-cap context manager is replaced with a
    no-op.
    """
    install_target = tmp_path / "RapidOCR"
    bundled_package_dir = tmp_path / "bundled" / "rapidocr_onnxruntime"
    _touch(bundled_package_dir / "__init__.py")
    det_path = _touch(bundled_package_dir / "models" / "ch_PP-OCRv4_det_infer.onnx")
    cls_path = _touch(bundled_package_dir / "models" / "ch_ppocr_mobile_v2.0_cls_infer.onnx")
    rec_path = _touch(bundled_package_dir / "models" / "ch_PP-OCRv4_rec_infer.onnx")
    # Reset via monkeypatch so the class attribute is restored on teardown and
    # the kwargs captured here do not leak into other tests.
    monkeypatch.setattr(_RapidOcrWithKwargs, "captured_kwargs", None)

    monkeypatch.setattr(
        rapidocr_support.importlib,
        "import_module",
        lambda name: SimpleNamespace(
            RapidOCR=_RapidOcrWithKwargs,
            __file__=str(bundled_package_dir / "__init__.py"),
        ),
    )
    monkeypatch.setattr(rapidocr_support, "_onnxruntime_intra_op_thread_cap", lambda _limit: nullcontext())

    runtime, metadata = rapidocr_support.load_rapidocr_runtime(
        install_target_dir_raw=str(install_target),
        engine_type="onnxruntime",
        lang_type="ch",
        model_type="mobile",
        ocr_version="PP-OCRv4",
    )

    assert isinstance(runtime, _RapidOcrWithKwargs)
    assert _RapidOcrWithKwargs.captured_kwargs == {
        "det_model_path": str(det_path),
        "cls_model_path": str(cls_path),
        "rec_model_path": str(rec_path),
        "engine_type": "onnxruntime",
    }
    assert metadata["detected_path"] == str(bundled_package_dir.resolve())
    assert metadata["selected_model"] == "PP-OCRv4/ch/mobile"
Loading