From 085b74e8f2671c1cae921408cb175d47b9dedb18 Mon Sep 17 00:00:00 2001 From: CYJiang Date: Tue, 9 Jun 2026 22:31:38 +0800 Subject: [PATCH] Handle invalid UTF-8 in memory markdown --- src/memsearch/cli.py | 3 +- src/memsearch/core.py | 3 +- src/memsearch/io.py | 28 +++++++++++++++ src/memsearch/maintenance.py | 5 +-- tests/test_cli_error_handling.py | 33 ++++++++++++++++++ tests/test_core_encoding.py | 60 ++++++++++++++++++++++++++++++++ tests/test_maintenance.py | 13 ++++++- 7 files changed, 140 insertions(+), 5 deletions(-) create mode 100644 src/memsearch/io.py create mode 100644 tests/test_core_encoding.py diff --git a/src/memsearch/cli.py b/src/memsearch/cli.py index 0e1a4d32..28a48145 100644 --- a/src/memsearch/cli.py +++ b/src/memsearch/cli.py @@ -21,6 +21,7 @@ save_config, set_config_value, ) +from .io import read_utf8_text_replace try: from pymilvus.exceptions import MilvusException @@ -373,7 +374,7 @@ def expand( click.echo(f"Source file not found: {source}", err=True) sys.exit(1) - all_lines = source_path.read_text(encoding="utf-8").splitlines() + all_lines = read_utf8_text_replace(source_path).splitlines() if lines is not None: # Show N lines before/after the chunk diff --git a/src/memsearch/core.py b/src/memsearch/core.py index b3eabb3b..7460132b 100644 --- a/src/memsearch/core.py +++ b/src/memsearch/core.py @@ -15,6 +15,7 @@ from .chunker import Chunk, chunk_markdown, clean_content_for_embedding, compute_chunk_id from .compact import compact_chunks from .embeddings import EmbeddingProvider, get_provider +from .io import read_utf8_text_replace from .scanner import ScannedFile, scan_paths from .store import MilvusStore @@ -125,7 +126,7 @@ async def index_file(self, path: str | Path) -> int: async def _index_file(self, f: ScannedFile, *, force: bool = False) -> int: source = str(f.path) - text = f.path.read_text(encoding="utf-8") + text = read_utf8_text_replace(f.path) chunks = chunk_markdown( text, source=source, diff --git a/src/memsearch/io.py b/src/memsearch/io.py new file mode 100644 index 00000000..a97fac3a --- /dev/null +++ b/src/memsearch/io.py @@ -0,0 +1,28 @@ +"""Text file helpers.""" + +from __future__ import annotations + +import logging +from pathlib import Path + +logger = logging.getLogger(__name__) + + +def read_utf8_text_replace(path: str | Path) -> str: + """Read text as UTF-8, replacing invalid byte sequences. + + Markdown memory files are expected to be UTF-8, but hooks may append + agent/tool output that contains malformed bytes. Keep indexing usable + and preserve as much surrounding text as possible. + """ + p = Path(path) + data = p.read_bytes() + try: + return data.decode("utf-8") + except UnicodeDecodeError as e: + logger.warning( + "File %s contains invalid UTF-8 at byte %d; replacing invalid bytes", + p, + e.start, + ) + return data.decode("utf-8", errors="replace") diff --git a/src/memsearch/maintenance.py b/src/memsearch/maintenance.py index 3f067fd8..8c8f6f85 100644 --- a/src/memsearch/maintenance.py +++ b/src/memsearch/maintenance.py @@ -24,6 +24,7 @@ config_to_dict, resolve_env_ref, ) +from .io import read_utf8_text_replace TASKS = ("project_review", "user_profile") MAX_PROMPT_CHARS = 80_000 @@ -255,8 +256,8 @@ def _read_recent_journals(input_dir: Path, max_files: int = 12) -> str: chunks: list[str] = [] files = sorted((p for p in input_dir.rglob("*.md") if p.is_file()), key=lambda p: p.stat().st_mtime)[-max_files:] for path in files: - with contextlib.suppress(OSError, UnicodeDecodeError): - chunks.append(f"\n\n{path.read_text(encoding='utf-8')}") + with contextlib.suppress(OSError): + chunks.append(f"\n\n{read_utf8_text_replace(path)}") return "\n".join(chunks) diff --git a/tests/test_cli_error_handling.py b/tests/test_cli_error_handling.py index 273b8ce3..d1329f12 100644 --- a/tests/test_cli_error_handling.py +++ b/tests/test_cli_error_handling.py @@ -77,3 +77,36 @@ def fake_load(_path): assert result.exit_code == 1 assert "Configuration error:" in result.stderr assert "DEFINITELY_NOT_SET_MEMSEARCH_API_KEY" in result.stderr + + +def test_expand_replaces_invalid_utf8_source_bytes(monkeypatch, tmp_path) -> None: + source = tmp_path / "bad.md" + source.write_bytes(b"# Bad\n\nbroken \xff byte\n") + + class FakeStore: + def __init__(self, **_kwargs): + pass + + def query(self, filter_expr: str): + assert filter_expr == 'chunk_hash == "abc123"' + return [ + { + "source": str(source), + "start_line": 1, + "end_line": 3, + "heading": "Bad", + "heading_level": 1, + } + ] + + def close(self) -> None: + pass + + monkeypatch.setattr(cli_module, "resolve_config", lambda _overrides=None: MemSearchConfig()) + monkeypatch.setattr(store_module, "MilvusStore", FakeStore) + + runner = CliRunner() + result = runner.invoke(cli, ["expand", "abc123"]) + + assert result.exit_code == 0 + assert "broken \ufffd byte" in result.output diff --git a/tests/test_core_encoding.py b/tests/test_core_encoding.py new file mode 100644 index 00000000..535d03a3 --- /dev/null +++ b/tests/test_core_encoding.py @@ -0,0 +1,60 @@ +from __future__ import annotations + +from pathlib import Path +from typing import Any + +import pytest + +from memsearch.core import MemSearch + + +class FakeEmbedder: + @property + def model_name(self) -> str: + return "fake" + + @property + def dimension(self) -> int: + return 4 + + async def embed(self, texts: list[str]) -> list[list[float]]: + return [[0.0] * self.dimension for _ in texts] + + +class RecordingStore: + def __init__(self) -> None: + self.records: list[dict[str, Any]] = [] + + def hashes_by_source(self, source: str) -> set[str]: + return set() + + def delete_by_hashes(self, hashes: list[str]) -> None: + pass + + def upsert(self, records: list[dict[str, Any]]) -> int: + self.records.extend(records) + return len(records) + + +def make_memsearch() -> tuple[MemSearch, RecordingStore]: + ms = MemSearch.__new__(MemSearch) + ms._paths = [] + ms._max_chunk_size = 1500 + ms._overlap_lines = 2 + ms._embedder = FakeEmbedder() + store = RecordingStore() + ms._store = store + ms._reranker_model = "" + return ms, store + + +@pytest.mark.asyncio +async def test_index_file_replaces_invalid_utf8_bytes(tmp_path: Path) -> None: + note = tmp_path / "bad.md" + note.write_bytes(b"# Bad UTF-8\n\nThis line has an invalid byte: \xff.\n") + ms, store = make_memsearch() + + indexed = await ms.index_file(note) + + assert indexed == 1 + assert store.records[0]["content"] == "# Bad UTF-8\n\nThis line has an invalid byte: \ufffd." diff --git a/tests/test_maintenance.py b/tests/test_maintenance.py index 6a594d7c..aa576e81 100644 --- a/tests/test_maintenance.py +++ b/tests/test_maintenance.py @@ -6,7 +6,7 @@ from types import SimpleNamespace from memsearch.config import LLMProviderConfig, MemSearchConfig, PluginMaintenanceTaskConfig -from memsearch.maintenance import TaskContext, run_due_tasks, run_memory_command, run_task_llm +from memsearch.maintenance import TaskContext, _read_recent_journals, run_due_tasks, run_memory_command, run_task_llm def test_maintenance_routes_gemini_provider_to_tool_runner(tmp_path: Path, monkeypatch) -> None: @@ -35,6 +35,17 @@ def fake_gemini(ctx, prompt: str, model: str | None, provider_cfg) -> str: assert captured == {"model": "gemini-test", "provider_type": "gemini"} +def test_read_recent_journals_replaces_invalid_utf8_bytes(tmp_path: Path) -> None: + memory = tmp_path / "memory" + memory.mkdir() + (memory / "2026-06-09.md").write_bytes(b"### 10:00\n- broken \xff byte\n") + + journals = _read_recent_journals(memory) + + assert "