Skip to content

Commit 1ac121a

Browse files
authored
Merge pull request #284 from 1bananachicken/feature-autonavi
feat: 实时定位推送到在线地图,方向预测
2 parents 8060e15 + 6079380 commit 1ac121a

28 files changed

Lines changed: 1403 additions & 12 deletions

.gitignore

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -460,7 +460,7 @@ debug
460460
.maa/Debugger/store.json
461461
assets/resource/base/model/detect/limbo_stage_lightest.onnx
462462
assets/resource/base/model/detect/sos_nodes.onnx
463-
463+
.claude/
464464
.nicegui/
465465

466466
# local gui test

.vscode/settings.json

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -58,5 +58,7 @@
5858
},
5959
"markdownlint.config": {
6060
"MD028": false
61-
}
61+
},
62+
"python-envs.defaultEnvManager": "ms-python.python:conda",
63+
"python-envs.defaultPackageManager": "ms-python.python:conda"
6264
}

agent/custom/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
from .action import *
1+
from .action import *
Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
from .navi_websocket import *
2+
3+
__all__ = [
4+
"NaviWebSocketAction",
5+
]
Lines changed: 211 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,211 @@
1+
import json
2+
import math
3+
import os
4+
import time
5+
from dataclasses import dataclass
6+
from pathlib import Path
7+
8+
import cv2
9+
import numpy as np
10+
import onnxruntime
11+
12+
from ..Common.logger import get_logger
13+
from .resources import resource_base_path
14+
15+
from maa.agent.agent_server import AgentServer
16+
from maa.context import Context
17+
from maa.custom_action import CustomAction
18+
19+
logger = get_logger(__name__)
20+
21+
22+
@dataclass
23+
class AnglePredictionResult:
24+
found: bool
25+
angle: float | None
26+
confidence: float
27+
bbox: tuple[int, int, int, int] | None = None
28+
tip: tuple[int, int] | None = None
29+
left: tuple[int, int] | None = None
30+
right: tuple[int, int] | None = None
31+
32+
33+
class AnglePredictor:
34+
def __init__(
35+
self,
36+
backend: str | None = None,
37+
threshold: float = 0.0,
38+
debug: bool = False,
39+
):
40+
model_path = resource_base_path() / "model/navi/pointer_model.onnx"
41+
self.model_path = Path(model_path)
42+
self.backend = self.resolve_backend(backend)
43+
self.pointer_roi = [73, 60, 64, 64]
44+
self.threshold = threshold
45+
self.debug = debug
46+
self._session_cache = {}
47+
self._provider_name_map = {
48+
"cpu": "CPUExecutionProvider",
49+
"directml": "DmlExecutionProvider",
50+
"dml": "DmlExecutionProvider",
51+
}
52+
53+
def predict(self, frame: np.ndarray) -> AnglePredictionResult:
54+
session, _ = self.get_session()
55+
input_name = session.get_inputs()[0].name
56+
57+
if frame.shape[2] == 4:
58+
frame = cv2.cvtColor(frame, cv2.COLOR_BGRA2BGR)
59+
60+
x, y, w, h = self.pointer_roi
61+
img_crop = frame[y : y + h, x : x + w].copy()
62+
img_rgb = cv2.cvtColor(img_crop, cv2.COLOR_BGR2RGB)
63+
64+
img_input = (img_rgb / 255.0).transpose(2, 0, 1).astype(np.float32)
65+
img_input = np.expand_dims(img_input, axis=0)
66+
67+
output = session.run(None, {input_name: img_input})[0][0]
68+
confidence = output[:, 4]
69+
best_idx = int(np.argmax(confidence))
70+
best_pred = output[best_idx]
71+
max_conf = float(confidence[best_idx])
72+
73+
result = AnglePredictionResult(found=False, angle=None, confidence=max_conf)
74+
if max_conf > self.threshold:
75+
kpts = best_pred[6:].reshape(3, 3)
76+
tip = kpts[0][:2]
77+
left = kpts[1][:2]
78+
right = kpts[2][:2]
79+
tail_center = (left + right) / 2
80+
81+
dx = tip[0] - tail_center[0]
82+
dy = tip[1] - tail_center[1]
83+
angle = math.degrees(math.atan2(dx, -dy)) % 360
84+
85+
x1, y1, x2, y2 = best_pred[0:4]
86+
result = AnglePredictionResult(
87+
found=True,
88+
angle=float(angle),
89+
confidence=max_conf,
90+
bbox=(int(x1), int(y1), int(x2), int(y2)),
91+
tip=(int(tip[0]), int(tip[1])),
92+
left=(int(left[0]), int(left[1])),
93+
right=(int(right[0]), int(right[1])),
94+
)
95+
96+
if self.debug:
97+
self.show_debug(img_crop, result)
98+
99+
return result
100+
101+
def show_debug(self, img_crop: np.ndarray, result: AnglePredictionResult) -> None:
102+
display_img = img_crop.copy()
103+
if result.found and result.bbox and result.tip and result.left and result.right:
104+
cv2.rectangle(
105+
display_img,
106+
(result.bbox[0], result.bbox[1]),
107+
(result.bbox[2], result.bbox[3]),
108+
(0, 255, 0),
109+
1,
110+
)
111+
tail = (
112+
int((result.left[0] + result.right[0]) / 2),
113+
int((result.left[1] + result.right[1]) / 2),
114+
)
115+
cv2.line(display_img, tail, result.tip, (255, 0, 255), 2)
116+
cv2.circle(display_img, result.tip, 2, (0, 0, 255), -1)
117+
cv2.circle(display_img, result.left, 2, (255, 255, 0), -1)
118+
cv2.circle(display_img, result.right, 2, (255, 255, 0), -1)
119+
120+
display_img = cv2.resize(display_img, (400, 400), interpolation=cv2.INTER_CUBIC)
121+
if result.found and result.angle is not None:
122+
cv2.putText(
123+
display_img,
124+
f"Angle: {result.angle:05.1f} deg",
125+
(10, 25),
126+
cv2.FONT_HERSHEY_SIMPLEX,
127+
0.6,
128+
(0, 255, 255),
129+
1,
130+
cv2.LINE_AA,
131+
)
132+
cv2.putText(
133+
display_img,
134+
f"Conf: {result.confidence:.2f}",
135+
(10, 50),
136+
cv2.FONT_HERSHEY_SIMPLEX,
137+
0.6,
138+
(0, 255, 0),
139+
1,
140+
cv2.LINE_AA,
141+
)
142+
else:
143+
cv2.putText(
144+
display_img,
145+
"NO TARGET",
146+
(10, 30),
147+
cv2.FONT_HERSHEY_SIMPLEX,
148+
0.8,
149+
(0, 0, 255),
150+
2,
151+
cv2.LINE_AA,
152+
)
153+
cv2.imshow("Angle Predictor", display_img)
154+
155+
def close_debug(self) -> None:
156+
if self.debug:
157+
cv2.destroyWindow("Angle Predictor")
158+
159+
def provider_name(self) -> str:
160+
_, provider_name = self.get_session()
161+
return provider_name
162+
163+
def resolve_backend(self, backend: str | None) -> str:
164+
backend = (
165+
str(backend or os.environ.get("MAA_ONNX_BACKEND", "cpu")).strip().lower()
166+
)
167+
if backend == "auto":
168+
available = onnxruntime.get_available_providers()
169+
if "DmlExecutionProvider" in available:
170+
return "directml"
171+
return "cpu"
172+
173+
provider_name_map = {
174+
"cpu": "CPUExecutionProvider",
175+
"directml": "DmlExecutionProvider",
176+
"dml": "DmlExecutionProvider",
177+
}
178+
if backend not in provider_name_map:
179+
logger.warning(f"Unknown inference backend {backend}, fallback to CPU")
180+
return "cpu"
181+
return backend
182+
183+
def get_session(self):
184+
backend = self.backend
185+
if backend in self._session_cache:
186+
return self._session_cache[backend]
187+
188+
if not self.model_path.exists():
189+
raise FileNotFoundError(f"Angle model not found: {self.model_path}")
190+
191+
provider_name = self._provider_name_map[backend]
192+
available = onnxruntime.get_available_providers()
193+
if provider_name not in available:
194+
logger.warning(
195+
f"Requested provider {provider_name} is unavailable, available providers: {available}; fallback to CPU"
196+
)
197+
backend = "cpu"
198+
self.backend = backend
199+
provider_name = self._provider_name_map[backend]
200+
201+
provider_options = (
202+
[{"device_id": 0}] if provider_name == "DmlExecutionProvider" else None
203+
)
204+
session = onnxruntime.InferenceSession(
205+
str(self.model_path),
206+
sess_options=onnxruntime.SessionOptions(),
207+
providers=[provider_name],
208+
provider_options=provider_options,
209+
)
210+
self._session_cache[backend] = (session, provider_name)
211+
return self._session_cache[backend]

0 commit comments

Comments
 (0)