|
| 1 | +import os |
1 | 2 | import traceback |
2 | 3 | from pathlib import Path |
3 | 4 | from typing import Any, Dict, List |
@@ -51,28 +52,29 @@ def __init__(self, cfg: RapidLayoutInput): |
51 | 52 |
|
52 | 53 | def _init_config(self, cfg: DictConfig) -> Dict[str, str]: |
53 | 54 | config = {} |
54 | | - engine_cfg = cfg.get("engine_cfg", {}) |
55 | 55 |
|
56 | 56 | def _set(k, v, *, cast=str): |
57 | 57 | if v is not None and v != -1: |
58 | | - config[k] = cast(v) |
59 | | - |
60 | | - _set("INFERENCE_NUM_THREADS", |
61 | | - engine_cfg.get("inference_num_threads", -1), |
62 | | - cast=lambda x: str(min(x, os.cpu_count())) if x > 0 else None) |
63 | | - |
64 | | - _set("PERFORMANCE_HINT", |
65 | | - engine_cfg.get("performance_hint")) |
66 | | - _set("PERFORMANCE_HINT_NUM_REQUESTS", |
67 | | - engine_cfg.get("performance_num_requests")) |
68 | | - _set("ENABLE_CPU_PINNING", |
69 | | - engine_cfg.get("enable_cpu_pinning")) |
70 | | - _set("NUM_STREAMS", |
71 | | - engine_cfg.get("num_streams")) |
72 | | - _set("ENABLE_HYPER_THREADING", |
73 | | - engine_cfg.get("enable_hyper_threading")) |
74 | | - _set("SCHEDULING_CORE_TYPE", |
75 | | - engine_cfg.get("scheduling_core_type")) |
| 58 | + casted_value = cast(v) |
| 59 | + if casted_value is not None: |
| 60 | + config[k] = casted_value |
| 61 | + |
| 62 | + inference_num_threads = cfg.get("inference_num_threads", -1) |
| 63 | + if inference_num_threads > 0: |
| 64 | + cpu_count = os.cpu_count() |
| 65 | + if cpu_count is not None: |
| 66 | + _set( |
| 67 | + "INFERENCE_NUM_THREADS", |
| 68 | + inference_num_threads, |
| 69 | + cast=lambda x: str(min(x, cpu_count)), |
| 70 | + ) |
| 71 | + |
| 72 | + _set("PERFORMANCE_HINT", cfg.get("performance_hint")) |
| 73 | + _set("PERFORMANCE_HINT_NUM_REQUESTS", cfg.get("performance_num_requests")) |
| 74 | + _set("ENABLE_CPU_PINNING", cfg.get("enable_cpu_pinning")) |
| 75 | + _set("NUM_STREAMS", cfg.get("num_streams")) |
| 76 | + _set("ENABLE_HYPER_THREADING", cfg.get("enable_hyper_threading")) |
| 77 | + _set("SCHEDULING_CORE_TYPE", cfg.get("scheduling_core_type")) |
76 | 78 |
|
77 | 79 | if config: |
78 | 80 | self.logger.info("OpenVINO runtime config: %s", config) |
@@ -107,14 +109,22 @@ def characters(self): |
107 | 109 | return self.get_character_list() |
108 | 110 |
|
109 | 111 | def get_character_list(self, key: str = "character") -> List[str]: |
110 | | - val = self.model.get_rt_info()["framework"][key] |
111 | | - return val.value.splitlines() |
| 112 | + rt_info = self.model.get_rt_info() |
| 113 | + framework_info = rt_info.get("framework", {}) |
| 114 | + val = framework_info.get(key) |
| 115 | + if val is None or not hasattr(val, "value"): |
| 116 | + return [] |
| 117 | + value = getattr(val, "value", None) |
| 118 | + if value is None: |
| 119 | + return [] |
| 120 | + return value.splitlines() |
112 | 121 |
|
113 | 122 | def have_key(self, key: str = "character") -> bool: |
114 | 123 | try: |
115 | 124 | rt_info = self.model.get_rt_info() |
116 | | - return key in rt_info |
117 | | - except: |
| 125 | + framework_info = rt_info.get("framework", {}) |
| 126 | + return key in framework_info |
| 127 | + except (AttributeError, TypeError, KeyError): |
118 | 128 | return False |
119 | 129 |
|
120 | 130 |
|
|
0 commit comments