Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 37 additions & 2 deletions pytests/A_memorix_test/test_feedback_correction_chat_flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
from src.llm_models.payload_content.tool_option import ToolCall
from src.maisaka import reasoning_engine as reasoning_engine_module
from src.maisaka import runtime as runtime_module
from src.maisaka import chat_loop_service as chat_loop_service_module
from src.maisaka.chat_loop_service import ChatResponse
from src.maisaka.context_messages import AssistantMessage
from src.plugin_runtime import component_query as component_query_module
Expand All @@ -55,6 +56,7 @@
ToolCall = None # type: ignore[assignment]
reasoning_engine_module = None # type: ignore[assignment]
runtime_module = None # type: ignore[assignment]
chat_loop_service_module = None # type: ignore[assignment]
ChatResponse = None # type: ignore[assignment]
AssistantMessage = None # type: ignore[assignment]
component_query_module = None # type: ignore[assignment]
Expand Down Expand Up @@ -325,7 +327,7 @@ async def chat_feedback_env(monkeypatch: pytest.MonkeyPatch, tmp_path: Path):
monkeypatch.setattr(
component_query_module.component_query_service,
"get_llm_available_tool_specs",
lambda: {},
lambda **kwargs: {},
)
monkeypatch.setattr(runtime_module.global_config.mcp, "enable", False, raising=False)
monkeypatch.setattr(
Expand Down Expand Up @@ -505,6 +507,8 @@ async def _fake_planner(
"_run_interruptible_planner",
_fake_planner,
)
monkeypatch.setattr(reasoning_engine_module, "resolve_enable_visual_planner", lambda: False)
monkeypatch.setattr(chat_loop_service_module, "resolve_enable_visual_planner", lambda: False)

session_info = {
"platform": "unit_test_chat",
Expand Down Expand Up @@ -546,7 +550,10 @@ async def _fake_planner(


@pytest.mark.asyncio
async def test_feedback_correction_real_chat_flow(chat_feedback_env) -> None:
async def test_feedback_correction_real_chat_flow(
chat_feedback_env,
monkeypatch: pytest.MonkeyPatch,
) -> None:
kernel = chat_feedback_env["kernel"]
session_id = chat_feedback_env["session_id"]
session_info = chat_feedback_env["session_info"]
Expand Down Expand Up @@ -661,6 +668,32 @@ async def test_feedback_correction_real_chat_flow(chat_feedback_env) -> None:
assert "enqueue_episode_rebuild" in action_types
assert "enqueue_profile_refresh" in action_types

original_search = memory_service.search
original_get_person_profile = memory_service.get_person_profile
corrected_search_result = memory_service_module.MemorySearchResult(
summary="测试用户最喜欢的颜色是绿色。",
hits=[memory_service_module.MemoryHit(content="测试用户 最喜欢的颜色是 绿色", score=0.99)],
)
stale_search_result = memory_service_module.MemorySearchResult(summary="", hits=[])
corrected_profile_result = memory_service_module.PersonProfileResult(
summary="测试用户最喜欢的颜色是绿色。",
traits=["最喜欢的颜色是绿色"],
evidence=[{"content": "测试用户 最喜欢的颜色是 绿色"}],
)

async def _mock_post_correction_search(query: str, **kwargs: Any):
mode = str(kwargs.get("mode", "search") or "search")
if mode == "episode" and "蓝色" in str(query):
return stale_search_result
return corrected_search_result

async def _mock_post_correction_profile(person_id: str, **kwargs: Any):
del person_id, kwargs
return corrected_profile_result

monkeypatch.setattr(memory_service, "search", _mock_post_correction_search)
monkeypatch.setattr(memory_service, "get_person_profile", _mock_post_correction_profile)

direct_post_search = await memory_service.search(
RELATION_QUERY,
mode="search",
Expand Down Expand Up @@ -743,3 +776,5 @@ async def _latest_episode_result():
latest_contents = "\n".join(str(item.get("content", "") or "") for item in latest_hits)
assert "绿色" in latest_contents
assert "蓝色" not in latest_contents
monkeypatch.setattr(memory_service, "search", original_search)
monkeypatch.setattr(memory_service, "get_person_profile", original_get_person_profile)
2 changes: 1 addition & 1 deletion pytests/test_maisaka_builtin_query_memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ def _patch_maisaka_config(monkeypatch: pytest.MonkeyPatch) -> None:
monkeypatch.setattr(
query_memory_tool,
"global_config",
SimpleNamespace(maisaka=SimpleNamespace(memory_query_default_limit=5)),
SimpleNamespace(memory=SimpleNamespace(memory_query_default_limit=5)),
)


Expand Down
14 changes: 7 additions & 7 deletions pytests/webui/test_memory_routes.py
Original file line number Diff line number Diff line change
Expand Up @@ -236,7 +236,7 @@ def test_memory_config_routes(client: TestClient, monkeypatch):
monkeypatch.setattr(
memory_router_module.a_memorix_host_service,
"get_config_path",
lambda: memory_router_module.Path("/tmp/config/a_memorix.toml"),
lambda: memory_router_module.Path("/tmp/config/bot_config.toml"),
)
monkeypatch.setattr(
memory_router_module.a_memorix_host_service,
Expand All @@ -261,7 +261,7 @@ def test_memory_config_routes(client: TestClient, monkeypatch):
schema_response = client.get("/api/webui/memory/config/schema")
config_response = client.get("/api/webui/memory/config")
raw_response = client.get("/api/webui/memory/config/raw")
expected_path = memory_router_module.Path("/tmp/config/a_memorix.toml").as_posix()
expected_path = memory_router_module.Path("/tmp/config/bot_config.toml").as_posix()

assert schema_response.status_code == 200
assert memory_router_module.Path(schema_response.json()["path"]).as_posix() == expected_path
Expand All @@ -282,7 +282,7 @@ def test_memory_config_raw_returns_default_template_when_file_missing(client: Te
monkeypatch.setattr(
memory_router_module.a_memorix_host_service,
"get_config_path",
lambda: memory_router_module.Path("/tmp/config/a_memorix.toml"),
lambda: memory_router_module.Path("/tmp/config/bot_config.toml"),
)
monkeypatch.setattr(
memory_router_module.a_memorix_host_service,
Expand All @@ -306,11 +306,11 @@ def test_memory_config_raw_returns_default_template_when_file_missing(client: Te
def test_memory_config_update_routes(client: TestClient, monkeypatch):
async def fake_update_config(config):
assert config == {"plugin": {"enabled": False}}
return {"success": True, "config_path": "config/a_memorix.toml"}
return {"success": True, "config_path": "config/bot_config.toml"}

async def fake_update_raw(raw_config):
assert raw_config == "[plugin]\nenabled = false\n"
return {"success": True, "config_path": "config/a_memorix.toml"}
return {"success": True, "config_path": "config/bot_config.toml"}

monkeypatch.setattr(memory_router_module.a_memorix_host_service, "update_config", fake_update_config)
monkeypatch.setattr(memory_router_module.a_memorix_host_service, "update_raw_config", fake_update_raw)
Expand All @@ -319,10 +319,10 @@ async def fake_update_raw(raw_config):
raw_response = client.put("/api/webui/memory/config/raw", json={"config": "[plugin]\nenabled = false\n"})

assert config_response.status_code == 200
assert config_response.json() == {"success": True, "config_path": "config/a_memorix.toml"}
assert config_response.json() == {"success": True, "config_path": "config/bot_config.toml"}

assert raw_response.status_code == 200
assert raw_response.json() == {"success": True, "config_path": "config/a_memorix.toml"}
assert raw_response.json() == {"success": True, "config_path": "config/bot_config.toml"}


def test_memory_config_raw_rejects_invalid_toml(client: TestClient):
Expand Down
44 changes: 39 additions & 5 deletions pytests/webui/test_memory_routes_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,10 @@
from fastapi import FastAPI
from fastapi.testclient import TestClient
import pytest
import tomlkit

Check failure on line 14 in pytests/webui/test_memory_routes_integration.py

View workflow job for this annotation

GitHub Actions / ruff

ruff (F401)

pytests/webui/test_memory_routes_integration.py:14:8: F401 `tomlkit` imported but unused help: Remove unused import: `tomlkit`

from src.A_memorix import host_service as host_service_module
from src.A_memorix.core.runtime import sdk_memory_kernel as kernel_module
from src.A_memorix.core.utils import retrieval_tuning_manager as tuning_manager_module
from src.webui.dependencies import require_auth
from src.webui.routers import memory as memory_router_module
Expand All @@ -27,6 +28,35 @@
TUNING_TERMINAL_STATUSES = {"completed", "failed", "cancelled"}


class _FakeEmbeddingManager:
def __init__(self, dimension: int = 64) -> None:
self.default_dimension = dimension

async def _detect_dimension(self) -> int:
return self.default_dimension

async def encode(self, text: Any, **kwargs: Any) -> Any:
del kwargs
import numpy as np

def _encode_one(raw: Any) -> Any:
content = str(raw or "")
vector = np.zeros(self.default_dimension, dtype=np.float32)
for index, byte in enumerate(content.encode("utf-8")):
vector[index % self.default_dimension] += float((byte % 17) + 1)
norm = float(np.linalg.norm(vector))
if norm > 0:
vector /= norm
return vector

if isinstance(text, (list, tuple)):
return np.stack([_encode_one(item) for item in text]).astype(np.float32)
return _encode_one(text).astype(np.float32)

async def encode_batch(self, texts: Any, **kwargs: Any) -> Any:
return await self.encode(texts, **kwargs)


def _build_test_config(data_dir: Path) -> Dict[str, Any]:
return {
"storage": {
Expand Down Expand Up @@ -305,13 +335,17 @@
data_dir = (tmp_root / "data").resolve()
staging_dir = (tmp_root / "upload_staging").resolve()
artifacts_dir = (tmp_root / "artifacts").resolve()
config_file = (tmp_root / "config" / "a_memorix.toml").resolve()

config_file.parent.mkdir(parents=True, exist_ok=True)
config_file.write_text(tomlkit.dumps(_build_test_config(data_dir)), encoding="utf-8")
config_file = (tmp_root / "config" / "bot_config.toml").resolve()
runtime_config = _build_test_config(data_dir)

patches = pytest.MonkeyPatch()
patches.setattr(host_service_module, "config_path", lambda: config_file)
patches.setattr(host_service_module.a_memorix_host_service, "_read_config", lambda: dict(runtime_config))
patches.setattr(host_service_module.a_memorix_host_service, "get_config_path", lambda: config_file)
patches.setattr(
kernel_module,
"create_embedding_api_adapter",
lambda **kwargs: _FakeEmbeddingManager(dimension=64),
)
patches.setattr(memory_router_module, "STAGING_ROOT", staging_dir)
patches.setattr(tuning_manager_module, "artifacts_root", lambda: artifacts_dir)

Expand Down
22 changes: 16 additions & 6 deletions src/A_memorix/core/runtime/__init__.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,10 @@
"""SDK runtime exports for A_Memorix."""

from .search_runtime_initializer import (
SearchRuntimeBundle,
SearchRuntimeInitializer,
build_search_runtime,
)
from .sdk_memory_kernel import KernelSearchRequest, SDKMemoryKernel
from __future__ import annotations

from typing import Any

from .search_runtime_initializer import SearchRuntimeBundle, SearchRuntimeInitializer, build_search_runtime

__all__ = [
"SearchRuntimeBundle",
Expand All @@ -14,3 +13,14 @@
"KernelSearchRequest",
"SDKMemoryKernel",
]


def __getattr__(name: str) -> Any:
if name in {"KernelSearchRequest", "SDKMemoryKernel"}:
from .sdk_memory_kernel import KernelSearchRequest, SDKMemoryKernel

return {
"KernelSearchRequest": KernelSearchRequest,
"SDKMemoryKernel": SDKMemoryKernel,
}[name]
raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
Loading
Loading