@@ -33,18 +33,14 @@ class AnglePredictionResult:
3333class AnglePredictor :
3434 def __init__ (
3535 self ,
36- model_path : Path | None = None ,
3736 backend : str | None = None ,
38- pointer_roi : list [int ] | None = None ,
3937 threshold : float = 0.0 ,
4038 debug : bool = False ,
4139 ):
42- if model_path is None :
43- model_path = resource_base_path () / "model/navi/pointer_model.onnx"
44-
40+ model_path = resource_base_path () / "model/navi/pointer_model.onnx"
4541 self .model_path = Path (model_path )
46- self .backend = self ._resolve_backend (backend )
47- self .pointer_roi = pointer_roi or [73 , 60 , 64 , 64 ]
42+ self .backend = self .resolve_backend (backend )
43+ self .pointer_roi = [73 , 60 , 64 , 64 ]
4844 self .threshold = threshold
4945 self .debug = debug
5046 self ._session_cache = {}
@@ -55,7 +51,7 @@ def __init__(
5551 }
5652
5753 def predict (self , frame : np .ndarray ) -> AnglePredictionResult :
58- session , _provider_name = self ._get_session ()
54+ session , _ = self .get_session ()
5955 input_name = session .get_inputs ()[0 ].name
6056
6157 if frame .shape [2 ] == 4 :
@@ -161,10 +157,10 @@ def close_debug(self) -> None:
161157 cv2 .destroyWindow ("Angle Predictor" )
162158
163159 def provider_name (self ) -> str :
164- _session , provider_name = self ._get_session ()
160+ _ , provider_name = self .get_session ()
165161 return provider_name
166162
167- def _resolve_backend (self , backend : str | None ) -> str :
163+ def resolve_backend (self , backend : str | None ) -> str :
168164 backend = (
169165 str (backend or os .environ .get ("MAA_ONNX_BACKEND" , "cpu" )).strip ().lower ()
170166 )
@@ -184,7 +180,7 @@ def _resolve_backend(self, backend: str | None) -> str:
184180 return "cpu"
185181 return backend
186182
187- def _get_session (self ):
183+ def get_session (self ):
188184 backend = self .backend
189185 if backend in self ._session_cache :
190186 return self ._session_cache [backend ]
@@ -213,71 +209,3 @@ def _get_session(self):
213209 )
214210 self ._session_cache [backend ] = (session , provider_name )
215211 return self ._session_cache [backend ]
216-
217-
218- @AgentServer .custom_action ("predict_angle" )
219- class AnglePredictorTestAction (CustomAction ):
220- def run (
221- self , context : Context , argv : CustomAction .RunArg
222- ) -> CustomAction .RunResult :
223- params = _load_params (argv .custom_action_param )
224- debug = bool (params .get ("debug" , False ))
225- frame_interval = float (params .get ("frame_interval" , 0.016 ))
226- threshold = float (params .get ("threshold" , 0.0 ))
227- pointer_roi = params .get ("pointer_roi" ) or None
228- backend = params .get ("backend" )
229-
230- try :
231- predictor = AnglePredictor (
232- backend = backend ,
233- pointer_roi = pointer_roi ,
234- threshold = threshold ,
235- debug = debug ,
236- )
237- provider_name = predictor .provider_name ()
238- except Exception as exc :
239- logger .error (f"Angle predictor init failed: { exc } " )
240- return CustomAction .RunResult (success = False )
241-
242- logger .info (
243- f"Angle predictor started: backend={ predictor .backend } , provider={ provider_name } , debug={ debug } "
244- )
245- controller = context .tasker .controller
246- last_result = AnglePredictionResult (found = False , angle = None , confidence = 0.0 )
247-
248- try :
249- while not context .tasker .stopping :
250- started = time .perf_counter ()
251- frame = controller .post_screencap ().wait ().get ()
252- if frame is None :
253- continue
254-
255- last_result = predictor .predict (frame )
256-
257- if debug and (cv2 .waitKey (1 ) & 0xFF == ord ("q" )):
258- break
259-
260- sleep_time = frame_interval - (time .perf_counter () - started )
261- if sleep_time > 0 :
262- time .sleep (sleep_time )
263- except Exception as exc :
264- logger .error (f"Angle predictor failed: { exc } " )
265- return CustomAction .RunResult (success = False )
266- finally :
267- if debug :
268- cv2 .destroyAllWindows ()
269-
270- return CustomAction .RunResult (success = last_result .found )
271-
272-
273- def _load_params (custom_action_param ) -> dict :
274- if not custom_action_param :
275- return {}
276- if isinstance (custom_action_param , dict ):
277- return custom_action_param
278- try :
279- params = json .loads (custom_action_param )
280- return params if isinstance (params , dict ) else {}
281- except Exception as exc :
282- logger .warning (f"Parse custom_action_param failed, use defaults: { exc } " )
283- return {}
0 commit comments