Skip to content

Commit 9d75f8e

Browse files
feat: 添加数据集收集功能,用于自动驾驶模型训练 (#186)
* feat: 添加数据集收集功能,用于自动驾驶模型训练 * 更改任务描述 * 添加i18n,优化文件结构
1 parent a6d8a8a commit 9d75f8e

11 files changed

Lines changed: 347 additions & 6 deletions

File tree

Lines changed: 222 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,222 @@
1+
import ctypes
2+
import json
3+
import time
4+
from datetime import datetime
5+
from pathlib import Path
6+
from typing import Any
7+
from PIL import Image
8+
9+
import cv2
10+
import numpy as np
11+
from maa.agent.agent_server import AgentServer
12+
from maa.custom_action import CustomAction
13+
from maa.context import Context
14+
15+
from utils.logger import logger
16+
from utils.maafocus import Print
17+
18+
19+
_KEY_LABELS = {
20+
0: "none",
21+
1: "A",
22+
2: "D",
23+
3: "W",
24+
4: "S",
25+
5: "AW",
26+
6: "AS",
27+
7: "DW",
28+
8: "DS",
29+
}
30+
_VK = {"W": 0x57, "A": 0x41, "S": 0x53, "D": 0x44}
31+
_EXAMPLES_PER_SECOND = 2.0
32+
_SEQUENCE_LENGTH = 5
33+
_IMAGE_SIZE = (480, 270)
34+
_DEFAULT_OUTPUT_DIR = Path(__file__).resolve().parents[3] / "debug" / "dataset"
35+
36+
37+
def _parse_params(raw: Any) -> dict[str, Any]:
    """Normalise the raw custom-action parameter into a dict.

    Args:
        raw: The parameter as handed over by the framework. A dict is
            returned unchanged; a string is decoded as JSON.

    Returns:
        The parameter mapping, or an empty dict when ``raw`` is falsy, is
        not valid JSON, does not decode to a JSON object, or has an
        unsupported type.
    """
    if not raw:
        return {}
    if isinstance(raw, dict):
        return raw
    if isinstance(raw, str):
        try:
            value = json.loads(raw)
        except json.JSONDecodeError:
            logger.warning(f"invalid dataset_recorder params: {raw!r}")
            return {}
        return value if isinstance(value, dict) else {}
    # Any other type (list, number, bytes, ...) used to fall through and
    # return None, which made the caller fail on ``params.get``.
    return {}
49+
50+
51+
def _resolve_output_dir(value: Any) -> Path:
    """Turn the ``output_dir`` parameter into the base dataset directory.

    A falsy value selects ``_DEFAULT_OUTPUT_DIR``. Otherwise the value is
    converted to a path with ``~`` expanded; an absolute path is returned
    as is and a relative one is anchored at ``parents[4]`` of this file.
    """
    if value:
        candidate = Path(str(value)).expanduser()
        if not candidate.is_absolute():
            candidate = Path(__file__).resolve().parents[4] / candidate
        return candidate
    return _DEFAULT_OUTPUT_DIR
58+
59+
60+
def _make_session_dir(base_dir: Path) -> Path:
    """Create and return a fresh, timestamp-named directory under ``base_dir``.

    The name is the current local time as ``YYYYMMDD_HHMMSS``. If that name
    is taken, ``_1``, ``_2``, ... is appended until an unused one is found.
    Missing parent directories are created.
    """
    stamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    candidate = base_dir / stamp
    attempt = 0
    while candidate.exists():
        attempt += 1
        candidate = base_dir / f"{stamp}_{attempt}"
    candidate.mkdir(parents=True, exist_ok=False)
    return candidate
69+
70+
71+
def _pressed_keys() -> set[str]:
    """Return the names of the ``_VK`` keys that are currently held down.

    Polls the Win32 ``GetAsyncKeyState`` API through ``ctypes.windll``, so
    this only works on Windows.
    """
    get_state = ctypes.windll.user32.GetAsyncKeyState
    # Bit 0x8000 of the returned state is set while the key is down.
    return {name for name, code in _VK.items() if get_state(code) & 0x8000}
77+
78+
79+
def _label_from_keys(pressed: set[str]) -> int:
    """Map the set of held movement keys to its class index.

    Only a single key or one of the four diagonal pairs is recognised. Any
    other combination (no keys, opposing keys, three or more keys) maps
    to 0.
    """
    combos = (
        ({"A"}, 1),
        ({"D"}, 2),
        ({"W"}, 3),
        ({"S"}, 4),
        ({"A", "W"}, 5),
        ({"A", "S"}, 6),
        ({"D", "W"}, 7),
        ({"D", "S"}, 8),
    )
    for combo, index in combos:
        if pressed == combo:
            return index
    return 0
97+
98+
99+
def _prepare_frame(frame: np.ndarray, size: tuple[int, int]) -> np.ndarray | None:
100+
if frame is None or not isinstance(frame, np.ndarray) or frame.size == 0:
101+
return None
102+
if len(frame.shape) == 3 and frame.shape[2] == 4:
103+
frame = cv2.cvtColor(frame, cv2.COLOR_BGRA2BGR)
104+
elif len(frame.shape) == 2:
105+
frame = cv2.cvtColor(frame, cv2.COLOR_GRAY2BGR)
106+
return cv2.resize(frame, size)
107+
108+
109+
def _save_sample(
    output_dir: Path,
    frames: list[np.ndarray],
    labels: list[int],
    number: int,
) -> Path:
    """Write one sample image and return its path.

    The frames are joined side by side (along axis 1), converted from BGR
    to RGB and saved as ``K<number>%<label>_<label>_....jpeg`` inside
    ``output_dir``.
    """
    encoded_labels = "_".join(map(str, labels))
    target = output_dir / f"K{number}%{encoded_labels}.jpeg"
    strip_bgr = np.concatenate(frames, axis=1)
    strip_rgb = cv2.cvtColor(strip_bgr, cv2.COLOR_BGR2RGB)
    Image.fromarray(strip_rgb).save(target)
    return target
121+
122+
123+
@AgentServer.custom_action("autonomous_driving_dataset_recorder")
class AutonomousDrivingDatasetRecorder(CustomAction):
    """Record screenshots labelled with the WASD keys held at capture time.

    Each saved sample is a single JPEG holding the most recent
    ``_SEQUENCE_LENGTH`` frames side by side; the per-frame labels are
    encoded in the file name (see ``_save_sample``).

    Recognised ``custom_action_param`` keys:
        output_dir: Base directory for session folders
            (see ``_resolve_output_dir``).
        duration_seconds: Recording time in seconds, default 60. ``0``
            records until the task is stopped.
        start_delay_seconds: Countdown before the first capture, default 1.
    """

    @staticmethod
    def _non_negative_seconds(params: dict[str, Any], key: str, default: float) -> float:
        """Read ``params[key]`` as a float clamped to >= 0, or ``default`` if unusable."""
        try:
            return max(0.0, float(params.get(key, default)))
        except (TypeError, ValueError):
            return default

    def run(self, context: Context, argv: CustomAction.RunArg) -> CustomAction.RunResult:
        """Run one recording session.

        Creates a session directory, writes ``metadata.json``, waits for the
        start delay, then captures frames at ``_EXAMPLES_PER_SECOND`` until
        the duration elapses or the tasker is stopping.

        Returns:
            ``RunResult(success=True)`` once the capture loop has ended.
        """
        params = _parse_params(argv.custom_action_param)
        dataset_dir = _resolve_output_dir(params.get("output_dir"))
        logger.info(f"Dataset recorder output directory: {dataset_dir}")
        output_dir = _make_session_dir(dataset_dir)
        logger.info(f"Dataset recorder session directory: {output_dir}")

        duration_seconds = self._non_negative_seconds(params, "duration_seconds", 60.0)
        start_delay_seconds = self._non_negative_seconds(
            params, "start_delay_seconds", 1.0
        )

        metadata = {
            "format": "K<number>%<label>_<label>_...jpeg",
            "labels": _KEY_LABELS,
            "sequence_length": _SEQUENCE_LENGTH,
            "image_width": _IMAGE_SIZE[0],
            "image_height": _IMAGE_SIZE[1],
            "examples_per_second": _EXAMPLES_PER_SECOND,
            "start_delay_seconds": start_delay_seconds,
        }
        (output_dir / "metadata.json").write_text(
            json.dumps(metadata, indent=2, ensure_ascii=True), encoding="utf-8"
        )

        controller = context.tasker.controller
        # Sliding window over the last _SEQUENCE_LENGTH frames and labels,
        # pre-filled with black frames / label 0 until real captures arrive.
        frames = [
            np.zeros((_IMAGE_SIZE[1], _IMAGE_SIZE[0], 3), dtype=np.uint8)
            for _ in range(_SEQUENCE_LENGTH)
        ]
        labels = [0] * _SEQUENCE_LENGTH
        saved_count = 0
        captured_count = 0
        last_status = None
        sample_interval = 1.0 / _EXAMPLES_PER_SECOND

        Print(
            context,
            f"Dataset recorder started: {output_dir} "
            f"({_EXAMPLES_PER_SECOND:g} samples/s, {duration_seconds:g}s, "
            f"delay={start_delay_seconds:g}s)",
        )

        # Intervals are measured with the monotonic clock so that wall-clock
        # adjustments cannot shorten or extend the session.
        delay_deadline = time.monotonic() + start_delay_seconds
        while not context.tasker.stopping:
            remaining = delay_deadline - time.monotonic()
            if remaining <= 0:
                break
            Print(context, f"Dataset recorder starts in {remaining:.1f}s")
            time.sleep(min(1.0, remaining))

        # The recording deadline starts only after the countdown; computing
        # it earlier let the start delay eat into the recording time.
        deadline = (
            time.monotonic() + duration_seconds if duration_seconds > 0 else None
        )

        while not context.tasker.stopping:
            start = time.monotonic()
            if deadline is not None and start >= deadline:
                break

            image = controller.post_screencap().wait().get()
            frame = _prepare_frame(image, _IMAGE_SIZE)
            if frame is None:
                logger.warning("dataset_recorder: empty screenshot, retrying")
                time.sleep(0.1)
                continue

            pressed = _pressed_keys()
            label = _label_from_keys(pressed)
            frames = frames[1:] + [frame]
            labels = labels[1:] + [label]
            captured_count += 1

            # Nothing is saved until the window holds only real captures.
            if captured_count >= _SEQUENCE_LENGTH:
                _save_sample(output_dir, frames, labels, saved_count)
                saved_count += 1

                now = time.monotonic()
                if last_status is None or now - last_status >= 2.0:
                    Print(
                        context,
                        f"Dataset recorder: saved={saved_count}, "
                        f"last={_KEY_LABELS[label]}, keys={''.join(sorted(pressed)) or '-'}",
                    )
                    last_status = now

            # Pace the loop to the configured sampling rate.
            wait_time = start + sample_interval - time.monotonic()
            if wait_time > 0:
                time.sleep(wait_time)

        Print(context, f"Dataset recorder stopped: saved={saved_count}, dir={output_dir}")
        return CustomAction.RunResult(success=True)

agent/custom/action/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
from .Common.alt_click import *
2020
from .furniture_claim import *
2121
from .auto_piano.action import *
22+
from .DatasetCollection.autonomous_driving_dataset_recorder import *
2223

2324
__all__ = [
2425
"AutoMakeCoffee",
@@ -40,4 +41,5 @@
4041
"AltClick",
4142
"FurnitureClaim",
4243
"AutoPlayPiano",
44+
"AutonomousDrivingDatasetRecorder",
4345
]

assets/interface.json

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,11 @@
9292
"name": "RealTimeAssist",
9393
"label": "$group.RealTimeAssist.label",
9494
"default_expand": true
95+
},
96+
{
97+
"name": "DatasetCollection",
98+
"label": "$group.DatasetCollection.label",
99+
"default_expand": true
95100
}
96101
],
97102
"task": [],
@@ -114,6 +119,8 @@
114119
// group:others
115120
"resource/tasks/FountainCheckin.json",
116121
"resource/tasks/AutoPiano.json",
122+
// group:DatasetCollection
123+
"resource/tasks/AutonomousDrivingDataset.json",
117124
// preset
118125
"resource/tasks/preset/AFK.json",
119126
// "resource/tasks/preset/FullDaily.json",
Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
{
2+
"AutonomousDrivingDatasetRecorder": {
3+
"enabled": true,
4+
"action": "Custom",
5+
"custom_action": "autonomous_driving_dataset_recorder",
6+
"custom_action_param": {
7+
"output_dir": "debug/dataset",
8+
"duration_seconds": 60,
9+
"start_delay_seconds": 5
10+
}
11+
}
12+
}

assets/resource/locales/interface/en_us.json

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
"group.CityTycoon.label": "City Tycoon",
1111
"group.HethereauHobbies.label": "Hethereau Hobbies",
1212
"group.RealTimeAssist.label": "Real-time Assistance",
13+
"group.DatasetCollection.label": "Dataset Collection",
1314
"preset.AFK.label": "AFK Tasks",
1415
"preset.AFK.description": "Suitable for long-running AFK tasks",
1516
"preset.FullDaily.label": "Full Daily",
@@ -133,5 +134,12 @@
133134
"task_auto_piano_input_speed_label": "Playback Speed",
134135
"task_auto_piano_input_transpose_label": "Transpose",
135136
"task_auto_piano_input_speed_pattern_msg": "Please enter a number greater than 0, e.g. 1.0 or 1.25.",
136-
"task_auto_piano_input_transpose_pattern_msg": "Please enter an integer number of semitones, e.g. -12, 0, 7."
137+
"task_auto_piano_input_transpose_pattern_msg": "Please enter an integer number of semitones, e.g. -12, 0, 7.",
138+
"task_auto_drive_dataset_recorder_label": "Autonomous Driving Dataset Collection",
139+
"task_auto_drive_dataset_recorder_desc": "Before starting this task, make sure the minimap destination is set. The default recording duration is 60 seconds, with 2 samples per second.",
140+
"task_auto_drive_dataset_option_settings": "Recording Settings",
141+
"task_auto_drive_dataset_input_output_dir_label": "Output Directory",
142+
"task_auto_drive_dataset_input_duration_label": "Recording Duration",
143+
"task_auto_drive_dataset_input_duration_desc": "Recording duration, 60 seconds by default.",
144+
"task_auto_drive_dataset_input_start_delay_label": "Start Delay"
137145
}

assets/resource/locales/interface/ja_jp.json

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
"group.CityTycoon.label": "シティ名声",
1111
"group.HethereauHobbies.label": "シティライフ",
1212
"group.RealTimeAssist.label": "リアルタイム支援",
13+
"group.DatasetCollection.label": "データセット収集",
1314
"preset.AFK.label": "放置タスク",
1415
"preset.AFK.description": "長時間の放置プレイに適したタスク",
1516
"preset.FullDaily.label": "フルデイリー",
@@ -133,5 +134,12 @@
133134
"task_auto_piano_input_speed_label": "再生速度",
134135
"task_auto_piano_input_transpose_label": "移調",
135136
"task_auto_piano_input_speed_pattern_msg": "0より大きい数値を入力してください。例:1.0、1.25。",
136-
"task_auto_piano_input_transpose_pattern_msg": "半音数を整数で入力してください。例:-12、0、7。"
137+
"task_auto_piano_input_transpose_pattern_msg": "半音数を整数で入力してください。例:-12、0、7。",
138+
"task_auto_drive_dataset_recorder_label": "自動運転データセット収集",
139+
"task_auto_drive_dataset_recorder_desc": "このタスクを開始する前に、ミニマップの目的地が設定されていることを確認してください。デフォルトの録画時間は 60 秒で、1 秒あたり 2 回サンプリングします。",
140+
"task_auto_drive_dataset_option_settings": "録画設定",
141+
"task_auto_drive_dataset_input_output_dir_label": "出力ディレクトリ",
142+
"task_auto_drive_dataset_input_duration_label": "録画時間",
143+
"task_auto_drive_dataset_input_duration_desc": "録画時間。デフォルトは 60 秒です。",
144+
"task_auto_drive_dataset_input_start_delay_label": "開始遅延"
137145
}

assets/resource/locales/interface/ko_kr.json

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
"group.CityTycoon.label": "도시 타이쿤",
1111
"group.HethereauHobbies.label": "도시 일상",
1212
"group.RealTimeAssist.label": "실시간 지원",
13+
"group.DatasetCollection.label": "데이터셋 수집",
1314
"preset.AFK.label": "잠수 작업",
1415
"preset.AFK.description": "장시간 잠수 모드에 적합한 작업",
1516
"preset.FullDaily.label": "전체 일일",
@@ -133,5 +134,12 @@
133134
"task_auto_piano_input_speed_label": "재생 속도",
134135
"task_auto_piano_input_transpose_label": "조옮김",
135136
"task_auto_piano_input_speed_pattern_msg": "0보다 큰 숫자를 입력하세요. 예: 1.0, 1.25.",
136-
"task_auto_piano_input_transpose_pattern_msg": "반음 수를 정수로 입력하세요. 예: -12, 0, 7."
137+
"task_auto_piano_input_transpose_pattern_msg": "반음 수를 정수로 입력하세요. 예: -12, 0, 7.",
138+
"task_auto_drive_dataset_recorder_label": "자율 주행 데이터세트 수집",
139+
"task_auto_drive_dataset_recorder_desc": "이 작업을 시작하기 전에 미니맵 목적지가 설정되어 있는지 확인하세요. 기본 녹화 시간은 60초이며 초당 2회 샘플링합니다.",
140+
"task_auto_drive_dataset_option_settings": "녹화 설정",
141+
"task_auto_drive_dataset_input_output_dir_label": "출력 디렉터리",
142+
"task_auto_drive_dataset_input_duration_label": "녹화 시간",
143+
"task_auto_drive_dataset_input_duration_desc": "녹화 시간이며 기본값은 60초입니다.",
144+
"task_auto_drive_dataset_input_start_delay_label": "시작 지연"
137145
}

assets/resource/locales/interface/zh_cn.json

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
"group.CityTycoon.label": "都市大亨",
1111
"group.HethereauHobbies.label": "都市闲趣",
1212
"group.RealTimeAssist.label": "实时辅助",
13+
"group.DatasetCollection.label": "数据集收集",
1314
"preset.AFK.label": "挂机任务",
1415
"preset.AFK.description": "适合挂机长时间进行的任务",
1516
"preset.FullDaily.label": "全套日常",
@@ -133,5 +134,12 @@
133134
"task_auto_piano_input_speed_label": "播放速度",
134135
"task_auto_piano_input_transpose_label": "转调",
135136
"task_auto_piano_input_speed_pattern_msg": "请输入大于 0 的数字,例如 1.0 或 1.25。",
136-
"task_auto_piano_input_transpose_pattern_msg": "请输入整数半音数,例如 -12、0、7。"
137+
"task_auto_piano_input_transpose_pattern_msg": "请输入整数半音数,例如 -12、0、7。",
138+
"task_auto_drive_dataset_recorder_label": "自动驾驶数据集收集",
139+
"task_auto_drive_dataset_recorder_desc": "开始此任务前确保小地图已设置好目的地,默认录制时长为 60 秒,每秒采样 2 次。",
140+
"task_auto_drive_dataset_option_settings": "录制设置",
141+
"task_auto_drive_dataset_input_output_dir_label": "输出目录",
142+
"task_auto_drive_dataset_input_duration_label": "录制时长",
143+
"task_auto_drive_dataset_input_duration_desc": "录制时长,默认为 60 秒。",
144+
"task_auto_drive_dataset_input_start_delay_label": "录制延迟"
137145
}

assets/resource/locales/interface/zh_tw.json

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
"group.CityTycoon.label": "都市大亨",
1111
"group.HethereauHobbies.label": "都市閒趣",
1212
"group.RealTimeAssist.label": "實時輔助",
13+
"group.DatasetCollection.label": "資料集收集",
1314
"preset.AFK.label": "掛機任務",
1415
"preset.AFK.description": "適合掛機長時間進行的任務",
1516
"preset.FullDaily.label": "全套日常",
@@ -133,5 +134,12 @@
133134
"task_auto_piano_input_speed_label": "播放速度",
134135
"task_auto_piano_input_transpose_label": "轉調",
135136
"task_auto_piano_input_speed_pattern_msg": "請輸入大於 0 的數字,例如 1.0 或 1.25。",
136-
"task_auto_piano_input_transpose_pattern_msg": "請輸入整數半音數,例如 -12、0、7。"
137+
"task_auto_piano_input_transpose_pattern_msg": "請輸入整數半音數,例如 -12、0、7。",
138+
"task_auto_drive_dataset_recorder_label": "自動駕駛資料集收集",
139+
"task_auto_drive_dataset_recorder_desc": "開始此任務前,請確保小地圖已設定好目的地,預設錄製時長為 60 秒,每秒取樣 2 次。",
140+
"task_auto_drive_dataset_option_settings": "錄製設定",
141+
"task_auto_drive_dataset_input_output_dir_label": "輸出目錄",
142+
"task_auto_drive_dataset_input_duration_label": "錄製時長",
143+
"task_auto_drive_dataset_input_duration_desc": "錄製時長,預設為 60 秒。",
144+
"task_auto_drive_dataset_input_start_delay_label": "錄製延遲"
137145
}

assets/resource/tasks/AutoFScroll.json

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,9 @@
55
"label": "$task_autofscroll_label",
66
"entry": "AutoFScroll",
77
"description": "$task_autofscroll_desc",
8-
"group": ["RealTimeAssist"]
8+
"group": [
9+
"RealTimeAssist"
10+
]
911
}
1012
]
1113
}

0 commit comments

Comments
 (0)