|
| 1 | +import ctypes |
| 2 | +import json |
| 3 | +import time |
| 4 | +from datetime import datetime |
| 5 | +from pathlib import Path |
| 6 | +from typing import Any |
| 7 | +from PIL import Image |
| 8 | + |
| 9 | +import cv2 |
| 10 | +import numpy as np |
| 11 | +from maa.agent.agent_server import AgentServer |
| 12 | +from maa.custom_action import CustomAction |
| 13 | +from maa.context import Context |
| 14 | + |
| 15 | +from utils.logger import logger |
| 16 | +from utils.maafocus import Print |
| 17 | + |
| 18 | + |
| 19 | +_KEY_LABELS = { |
| 20 | + 0: "none", |
| 21 | + 1: "A", |
| 22 | + 2: "D", |
| 23 | + 3: "W", |
| 24 | + 4: "S", |
| 25 | + 5: "AW", |
| 26 | + 6: "AS", |
| 27 | + 7: "DW", |
| 28 | + 8: "DS", |
| 29 | +} |
| 30 | +_VK = {"W": 0x57, "A": 0x41, "S": 0x53, "D": 0x44} |
| 31 | +_EXAMPLES_PER_SECOND = 2.0 |
| 32 | +_SEQUENCE_LENGTH = 5 |
| 33 | +_IMAGE_SIZE = (480, 270) |
| 34 | +_DEFAULT_OUTPUT_DIR = Path(__file__).resolve().parents[3] / "debug" / "dataset" |
| 35 | + |
| 36 | + |
| 37 | +def _parse_params(raw: Any) -> dict[str, Any]: |
| 38 | + if not raw: |
| 39 | + return {} |
| 40 | + if isinstance(raw, dict): |
| 41 | + return raw |
| 42 | + if isinstance(raw, str): |
| 43 | + try: |
| 44 | + value = json.loads(raw) |
| 45 | + return value if isinstance(value, dict) else {} |
| 46 | + except json.JSONDecodeError: |
| 47 | + logger.warning(f"invalid dataset_recorder params: {raw!r}") |
| 48 | + return {} |
| 49 | + |
| 50 | + |
| 51 | +def _resolve_output_dir(value: Any) -> Path: |
| 52 | + if not value: |
| 53 | + return _DEFAULT_OUTPUT_DIR |
| 54 | + path = Path(str(value)).expanduser() |
| 55 | + if path.is_absolute(): |
| 56 | + return path |
| 57 | + return Path(__file__).resolve().parents[4] / path |
| 58 | + |
| 59 | + |
| 60 | +def _make_session_dir(base_dir: Path) -> Path: |
| 61 | + timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") |
| 62 | + session_dir = base_dir / timestamp |
| 63 | + suffix = 1 |
| 64 | + while session_dir.exists(): |
| 65 | + session_dir = base_dir / f"{timestamp}_{suffix}" |
| 66 | + suffix += 1 |
| 67 | + session_dir.mkdir(parents=True, exist_ok=False) |
| 68 | + return session_dir |
| 69 | + |
| 70 | + |
| 71 | +def _pressed_keys() -> set[str]: |
| 72 | + pressed = set() |
| 73 | + for key, vk in _VK.items(): |
| 74 | + if ctypes.windll.user32.GetAsyncKeyState(vk) & 0x8000: |
| 75 | + pressed.add(key) |
| 76 | + return pressed |
| 77 | + |
| 78 | + |
| 79 | +def _label_from_keys(pressed: set[str]) -> int: |
| 80 | + if pressed == {"A"}: |
| 81 | + return 1 |
| 82 | + if pressed == {"D"}: |
| 83 | + return 2 |
| 84 | + if pressed == {"W"}: |
| 85 | + return 3 |
| 86 | + if pressed == {"S"}: |
| 87 | + return 4 |
| 88 | + if pressed == {"A", "W"}: |
| 89 | + return 5 |
| 90 | + if pressed == {"A", "S"}: |
| 91 | + return 6 |
| 92 | + if pressed == {"D", "W"}: |
| 93 | + return 7 |
| 94 | + if pressed == {"D", "S"}: |
| 95 | + return 8 |
| 96 | + return 0 |
| 97 | + |
| 98 | + |
| 99 | +def _prepare_frame(frame: np.ndarray, size: tuple[int, int]) -> np.ndarray | None: |
| 100 | + if frame is None or not isinstance(frame, np.ndarray) or frame.size == 0: |
| 101 | + return None |
| 102 | + if len(frame.shape) == 3 and frame.shape[2] == 4: |
| 103 | + frame = cv2.cvtColor(frame, cv2.COLOR_BGRA2BGR) |
| 104 | + elif len(frame.shape) == 2: |
| 105 | + frame = cv2.cvtColor(frame, cv2.COLOR_GRAY2BGR) |
| 106 | + return cv2.resize(frame, size) |
| 107 | + |
| 108 | + |
| 109 | +def _save_sample( |
| 110 | + output_dir: Path, |
| 111 | + frames: list[np.ndarray], |
| 112 | + labels: list[int], |
| 113 | + number: int, |
| 114 | +) -> Path: |
| 115 | + label_part = "_".join(str(label) for label in labels) |
| 116 | + filename = f"K{number}%{label_part}.jpeg" |
| 117 | + path = output_dir / filename |
| 118 | + image = np.concatenate(frames, axis=1) |
| 119 | + Image.fromarray(cv2.cvtColor(image, cv2.COLOR_BGR2RGB)).save(path) |
| 120 | + return path |
| 121 | + |
| 122 | + |
| 123 | +@AgentServer.custom_action("autonomous_driving_dataset_recorder") |
| 124 | +class AutonomousDrivingDatasetRecorder(CustomAction): |
| 125 | + def run(self, context: Context, argv: CustomAction.RunArg) -> CustomAction.RunResult: |
| 126 | + params = _parse_params(argv.custom_action_param) |
| 127 | + dataset_dir = _resolve_output_dir(params.get("output_dir")) |
| 128 | + print(f"Dataset recorder output directory: {dataset_dir}") |
| 129 | + output_dir = _make_session_dir(dataset_dir) |
| 130 | + print(f"Dataset recorder session directory: {output_dir}") |
| 131 | + |
| 132 | + try: |
| 133 | + duration_seconds = max(0.0, float(params.get("duration_seconds", 60.0))) |
| 134 | + except (TypeError, ValueError): |
| 135 | + duration_seconds = 60.0 |
| 136 | + try: |
| 137 | + start_delay_seconds = max( |
| 138 | + 0.0, float(params.get("start_delay_seconds", 1.0)) |
| 139 | + ) |
| 140 | + except (TypeError, ValueError): |
| 141 | + start_delay_seconds = 1.0 |
| 142 | + |
| 143 | + metadata = { |
| 144 | + "format": "K<number>%<label>_<label>_...jpeg", |
| 145 | + "labels": _KEY_LABELS, |
| 146 | + "sequence_length": _SEQUENCE_LENGTH, |
| 147 | + "image_width": _IMAGE_SIZE[0], |
| 148 | + "image_height": _IMAGE_SIZE[1], |
| 149 | + "examples_per_second": _EXAMPLES_PER_SECOND, |
| 150 | + "start_delay_seconds": start_delay_seconds, |
| 151 | + } |
| 152 | + (output_dir / "metadata.json").write_text( |
| 153 | + json.dumps(metadata, indent=2, ensure_ascii=True), encoding="utf-8" |
| 154 | + ) |
| 155 | + |
| 156 | + controller = context.tasker.controller |
| 157 | + frames = [ |
| 158 | + np.zeros((_IMAGE_SIZE[1], _IMAGE_SIZE[0], 3), dtype=np.uint8) |
| 159 | + for _ in range(_SEQUENCE_LENGTH) |
| 160 | + ] |
| 161 | + labels = [0 for _ in range(_SEQUENCE_LENGTH)] |
| 162 | + sample_no = 0 |
| 163 | + saved_count = 0 |
| 164 | + captured_count = 0 |
| 165 | + deadline = time.time() + duration_seconds if duration_seconds > 0 else None |
| 166 | + last_status = 0.0 |
| 167 | + |
| 168 | + Print( |
| 169 | + context, |
| 170 | + f"Dataset recorder started: {output_dir} " |
| 171 | + f"({_EXAMPLES_PER_SECOND:g} samples/s, {duration_seconds:g}s, " |
| 172 | + f"delay={start_delay_seconds:g}s)", |
| 173 | + ) |
| 174 | + |
| 175 | + delay_deadline = time.time() + start_delay_seconds |
| 176 | + while not context.tasker.stopping and time.time() < delay_deadline: |
| 177 | + remaining = max(0.0, delay_deadline - time.time()) |
| 178 | + Print(context, f"Dataset recorder starts in {remaining:.1f}s") |
| 179 | + time.sleep(min(1.0, remaining)) |
| 180 | + |
| 181 | + while not context.tasker.stopping: |
| 182 | + if deadline is not None and time.time() >= deadline: |
| 183 | + break |
| 184 | + start = time.time() |
| 185 | + image = controller.post_screencap().wait().get() |
| 186 | + frame = _prepare_frame(image, _IMAGE_SIZE) |
| 187 | + if frame is None: |
| 188 | + logger.warning("dataset_recorder: empty screenshot, retrying") |
| 189 | + time.sleep(0.1) |
| 190 | + continue |
| 191 | + |
| 192 | + pressed = _pressed_keys() |
| 193 | + label = _label_from_keys(pressed) |
| 194 | + frames = frames[1:] + [frame] |
| 195 | + labels = labels[1:] + [label] |
| 196 | + captured_count += 1 |
| 197 | + |
| 198 | + if captured_count < _SEQUENCE_LENGTH: |
| 199 | + wait_time = (start + 1.0 / _EXAMPLES_PER_SECOND) - time.time() |
| 200 | + if wait_time > 0: |
| 201 | + time.sleep(wait_time) |
| 202 | + continue |
| 203 | + |
| 204 | + _save_sample(output_dir, frames, labels, sample_no) |
| 205 | + sample_no += 1 |
| 206 | + saved_count += 1 |
| 207 | + |
| 208 | + now = time.time() |
| 209 | + if now - last_status >= 2.0: |
| 210 | + Print( |
| 211 | + context, |
| 212 | + f"Dataset recorder: saved={saved_count}, " |
| 213 | + f"last={_KEY_LABELS[label]}, keys={''.join(sorted(pressed)) or '-'}", |
| 214 | + ) |
| 215 | + last_status = now |
| 216 | + |
| 217 | + wait_time = (start + 1.0 / _EXAMPLES_PER_SECOND) - time.time() |
| 218 | + if wait_time > 0: |
| 219 | + time.sleep(wait_time) |
| 220 | + |
| 221 | + Print(context, f"Dataset recorder stopped: saved={saved_count}, dir={output_dir}") |
| 222 | + return CustomAction.RunResult(success=True) |
0 commit comments