Skip to content

Commit 30d46e1

Browse files
authored
feat(study): 为伴学插件添加图片识别 (#1518)
* feat(study): add vision image support
* Fix study vision image submission safety
* fix(study): validate vision_image_base64 in study_explain_text entry path
  External callers can dispatch vision_image_base64 directly to study_explain_text via method(**args), bypassing the MIME/size/feature checks enforced in study_submit_image. Add _normalize_submitted_image_payload validation and llm_vision_enabled gating to prevent malformed or oversized payloads from reaching the LLM call.
* fix(study): address vision review feedback
* fix(study): clear stale vision cache on capture failure
* fix(study): preserve OCR fallback on image submit
* Use vision model for study image calls
1 parent 78b8f7b commit 30d46e1

7 files changed

Lines changed: 1152 additions & 21 deletions

File tree

plugin/plugins/study_companion/__init__.py

Lines changed: 130 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
import asyncio
44
import base64
5+
import binascii
56
from datetime import datetime
67
import math
78
from pathlib import Path
@@ -112,6 +113,60 @@ def _register_install_routes() -> None:
112113
)
113114

114115

116+
_MAX_SUBMITTED_IMAGE_BASE64_LENGTH = 10 * 1024 * 1024
117+
_MAX_SUBMITTED_IMAGE_BASE64_ENCODED_LENGTH = (
118+
((_MAX_SUBMITTED_IMAGE_BASE64_LENGTH + 2) // 3) * 4 + 64
119+
)
120+
_SUPPORTED_SUBMITTED_IMAGE_MIME_BY_DATA_URL_PREFIX = {
121+
"data:image/jpeg;base64": "image/jpeg",
122+
"data:image/png;base64": "image/png",
123+
}
124+
125+
126+
def _detect_submitted_image_mime(raw: bytes) -> str:
127+
if raw.startswith(b"\xff\xd8\xff"):
128+
return "image/jpeg"
129+
if raw.startswith(b"\x89PNG\r\n\x1a\n"):
130+
return "image/png"
131+
return ""
132+
133+
134+
def _normalize_submitted_image_payload(image_base64: str) -> str:
    """Validate a submitted image and return it as a JPEG/PNG data URL.

    The input may be either bare base64 text or a ``data:`` URL. The payload
    is size-checked, strictly base64-decoded, and its magic bytes are sniffed
    so that the MIME type in the returned URL reflects the actual image data.

    Args:
        image_base64: Bare base64 text or a ``data:image/...;base64,`` URL.
            Falsy values are treated as an empty string.

    Returns:
        ``"data:<mime>;base64,<payload>"`` where ``<mime>`` is the detected
        type and ``<payload>`` is the stripped base64 text as submitted.

    Raises:
        ValueError: If the payload is empty, a malformed or unsupported data
            URL, too large (encoded or decoded), not valid base64, not a
            JPEG/PNG image, or if the declared MIME type does not match the
            image data.
    """
    image_payload = str(image_base64 or "").strip()
    if not image_payload:
        raise ValueError("image_base64 is required")

    expected_mime = ""
    encoded_payload = image_payload
    # Look at the first five characters only: lower-casing the whole payload
    # would copy a potentially very large string before any size check ran.
    if image_payload[:5].lower() == "data:":
        header, separator, encoded_payload = image_payload.partition(",")
        if not separator or not encoded_payload.strip():
            raise ValueError("image_base64 data URL is malformed")
        # Resolve the header only once a comma was found; without one,
        # ``header`` is the entire payload and the lookup would be wasted work.
        expected_mime = _SUPPORTED_SUBMITTED_IMAGE_MIME_BY_DATA_URL_PREFIX.get(
            header.strip().lower(),
            "",
        )
        if not expected_mime:
            raise ValueError("only JPEG/PNG data URLs are supported")
        encoded_payload = encoded_payload.strip()

    # Reject oversized text before paying for the decode.
    if len(encoded_payload) > _MAX_SUBMITTED_IMAGE_BASE64_ENCODED_LENGTH:
        raise ValueError("image_base64 is too large (max 10MB)")
    try:
        # validate=True rejects characters outside the base64 alphabet
        # instead of silently discarding them.
        raw = base64.b64decode(encoded_payload, validate=True)
    except (binascii.Error, ValueError) as exc:
        raise ValueError("image_base64 is not valid base64") from exc
    if not raw:
        raise ValueError("image_base64 is not valid base64")
    if len(raw) > _MAX_SUBMITTED_IMAGE_BASE64_LENGTH:
        raise ValueError("image_base64 is too large (max 10MB)")

    actual_mime = _detect_submitted_image_mime(raw)
    if not actual_mime:
        raise ValueError("only JPEG/PNG images are supported")
    # A data URL's declared type must agree with the sniffed type.
    if expected_mime and actual_mime != expected_mime:
        raise ValueError("image_base64 MIME does not match image data")
    return f"data:{actual_mime};base64,{encoded_payload}"
168+
169+
115170
def _validated_pomodoro_focus_minutes(
116171
config: StudyConfig, focus_minutes: Any | None
117172
) -> int:
@@ -761,6 +816,25 @@ async def _build_learning_context(
761816
self._knowledge_tracker.get_status_summary,
762817
limit=5,
763818
)
819+
if bool(self._cfg.llm_vision_enabled):
820+
user_image = ""
821+
with self._lock:
822+
user_image = str(self._state.last_vision_image_base64 or "").strip()
823+
if user_image:
824+
context["vision_enabled"] = True
825+
context["vision_image_base64"] = user_image
826+
elif self._ocr_pipeline is not None:
827+
vision_snapshot = self._ocr_pipeline.latest_vision_snapshot()
828+
if vision_snapshot:
829+
context["vision_enabled"] = True
830+
context["vision_image_base64"] = str(
831+
vision_snapshot.get("vision_image_base64") or ""
832+
)
833+
context["vision_snapshot"] = {
834+
key: value
835+
for key, value in vision_snapshot.items()
836+
if key != "vision_image_base64"
837+
}
764838
if extra:
765839
context.update(extra)
766840
return context
@@ -2692,6 +2766,41 @@ async def study_ocr_snapshot(self, **_):
26922766
await self._persist_state()
26932767
return Ok(payload)
26942768

2769+
@plugin_entry(
    id="study_submit_image",
    name=tr("entries.submit_image.name", default="Submit Study Image"),
    description=tr(
        "entries.submit_image.description",
        default="Accept a user image and explain it with the configured vision model.",
    ),
    input_schema={
        "type": "object",
        "properties": {
            "image_base64": {"type": "string"},
            "text": {"type": "string", "default": ""},
        },
        "required": ["image_base64"],
    },
    timeout=60.0,
    llm_result_fields=["summary", "reply", "diagnostic"],
)
async def study_submit_image(self, image_base64: str, text: str = "", **_):
    """Validate a user-submitted image and hand it to ``study_explain_text``.

    Args:
        image_base64: Bare base64 text or a JPEG/PNG data URL.
        text: Optional accompanying text. When non-empty after stripping it is
            stored as ``self._state.last_ocr_text`` and used as the text to
            explain; otherwise a default prompt is used.
        **_: Extra keyword arguments are accepted and ignored.

    Returns:
        ``Err(SdkError(...))`` when vision is disabled in the config or the
        image fails validation; otherwise the result of
        ``study_explain_text``.
    """
    # Check the feature flag before decoding the payload, matching the order
    # used on the ``study_explain_text`` path and avoiding validation work for
    # a request that will be rejected anyway.
    if not bool(self._cfg.llm_vision_enabled):
        return Err(SdkError("llm_vision_enabled is not enabled"))
    try:
        image_payload = _normalize_submitted_image_payload(image_base64)
    except ValueError as exc:
        return Err(SdkError(str(exc)))

    normalized_text = str(text or "").strip()
    if normalized_text:
        # Record the caller's text as the latest OCR text under the state lock.
        with self._lock:
            self._state.last_ocr_text = normalized_text
    source_text = normalized_text or "请查看这张图片的内容"
    return await self.study_explain_text(
        text=source_text,
        vision_image_base64=image_payload,
    )
2803+
26952804
@plugin_entry(
26962805
id="study_explain_text",
26972806
name=tr("entries.explain_text.name", default="Explain Study Text"),
@@ -2703,12 +2812,13 @@ async def study_ocr_snapshot(self, **_):
27032812
"type": "object",
27042813
"properties": {
27052814
"text": {"type": "string", "default": ""},
2815+
"vision_image_base64": {"type": "string", "default": ""},
27062816
},
27072817
},
27082818
timeout=45.0,
27092819
llm_result_fields=["summary", "reply", "diagnostic"],
27102820
)
2711-
async def study_explain_text(self, text: str = "", **_):
2821+
async def study_explain_text(self, text: str = "", vision_image_base64: str = "", **_):
27122822
if self._agent is None:
27132823
return Err(SdkError("study tutor agent is not initialized"))
27142824
raw_text = str(text or "").strip()
@@ -2763,17 +2873,28 @@ async def study_explain_text(self, text: str = "", **_):
27632873
source_text = self._state.last_ocr_text
27642874
used_ocr_fallback = bool(source_text.strip())
27652875
# Phase 3: explain with the active mode selected above.
2876+
extra_context: dict[str, Any] = {
2877+
"source": "ocr_snapshot" if used_ocr_fallback or not raw_text else "manual",
2878+
"mode": active_mode,
2879+
"mode_switch": bool(mode_switch.get("changed")),
2880+
"source_text": source_text,
2881+
}
2882+
vision_image_payload = str(vision_image_base64 or "").strip()
2883+
if vision_image_payload:
2884+
if not bool(self._cfg.llm_vision_enabled):
2885+
return Err(SdkError("llm_vision_enabled is not enabled"))
2886+
try:
2887+
vision_image_payload = _normalize_submitted_image_payload(
2888+
vision_image_payload,
2889+
)
2890+
except ValueError as exc:
2891+
return Err(SdkError(str(exc)))
2892+
extra_context["vision_enabled"] = True
2893+
extra_context["vision_image_base64"] = vision_image_payload
27662894
tutor_context = await self._build_learning_context(
27672895
LLM_OPERATION_CONCEPT_EXPLAIN,
27682896
input_text=source_text,
2769-
extra={
2770-
"source": "ocr_snapshot"
2771-
if used_ocr_fallback or not raw_text
2772-
else "manual",
2773-
"mode": active_mode,
2774-
"mode_switch": bool(mode_switch.get("changed")),
2775-
"source_text": source_text,
2776-
},
2897+
extra=extra_context,
27772898
)
27782899
reply = await self._agent.concept_explain(
27792900
source_text,

plugin/plugins/study_companion/models.py

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -208,6 +208,8 @@ class StudyConfig:
208208
rapidocr_model_type: str = "mobile"
209209
rapidocr_ocr_version: str = "PP-OCRv4"
210210
llm_call_timeout_seconds: float = 30.0
211+
llm_vision_enabled: bool = False
212+
llm_vision_max_image_px: int = 768
211213
fsrs_retention_target: float = 0.90
212214
fsrs_auto_optimize_interval_days: int = 30
213215
knowledge_contribution_opt_in: bool = False
@@ -239,6 +241,10 @@ def __post_init__(self) -> None:
239241
self.llm_call_timeout_seconds = self._clamp_float(
240242
self.llm_call_timeout_seconds, 1.0, 3600.0, 30.0
241243
)
244+
self.llm_vision_enabled = bool(self.llm_vision_enabled)
245+
self.llm_vision_max_image_px = max(
246+
64, min(4096, self._coerce_int(self.llm_vision_max_image_px, 768))
247+
)
242248
self.fsrs_retention_target = self._clamp_float(
243249
self.fsrs_retention_target, 0.1, 0.99, 0.90
244250
)
@@ -316,6 +322,7 @@ class StudyState:
316322
last_error: str = ""
317323
last_started_at: str = ""
318324
last_ocr_text: str = ""
325+
last_vision_image_base64: str = ""
319326
last_ocr_at: str = ""
320327
last_screen_classification: dict[str, Any] = field(default_factory=dict)
321328
recent_screen_classifications: list[dict[str, Any]] = field(default_factory=list)
@@ -333,7 +340,9 @@ class StudyState:
333340
dependency_status: dict[str, Any] = field(default_factory=dict)
334341

335342
def to_dict(self) -> dict[str, Any]:
336-
return asdict(self)
343+
payload = asdict(self)
344+
payload.pop("last_vision_image_base64", None)
345+
return payload
337346

338347

339348
@dataclass(slots=True)
@@ -541,6 +550,16 @@ def _clamp(value: float, minimum: float, maximum: float, default: float) -> floa
541550
3600.0,
542551
30.0,
543552
),
553+
llm_vision_enabled=_bool(
554+
llm, "llm_vision_enabled", False, "llm_vision_enabled"
555+
),
556+
llm_vision_max_image_px=max(
557+
64,
558+
min(
559+
4096,
560+
_int(llm, "llm_vision_max_image_px", 768, "llm_vision_max_image_px"),
561+
),
562+
),
544563
fsrs_retention_target=_clamp(
545564
_float(fsrs, "retention_target", 0.90, "fsrs_retention_target"),
546565
0.1,

plugin/plugins/study_companion/plugin.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,8 @@ enabled = true
5757

5858
[llm]
5959
llm_call_timeout_seconds = 30
60+
llm_vision_enabled = false
61+
llm_vision_max_image_px = 768
6062

6163
[ocr_reader]
6264
enabled = true

0 commit comments

Comments
 (0)