Skip to content

Commit b706f27

Browse files
authored
Merge pull request #16 from jethronap/8_testing
8 testing
2 parents 4a9648c + d1463b6 commit b706f27

File tree

12 files changed

+298
-19
lines changed

12 files changed

+298
-19
lines changed

local_test_pipeline.sh

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
#!/bin/bash
2+
PYTHONPATH="${PYTHONPATH}:$(realpath "./src")"
3+
export PYTHONPATH
4+
5+
# Navigate to the script's directory (project root)
6+
cd "$(dirname "$0")" || exit
7+
8+
echo "Running base agent tests"
9+
pytest tests/test_base_agent.py
10+
echo "Done..."
11+
echo "==============================================="
12+
13+
echo "Running dummy agent tests"
14+
pytest tests/test_dummy_agent.py
15+
echo "Done..."
16+
echo "==============================================="
17+
18+
echo "Running llm wrapper tests"
19+
pytest tests/test_llm_wrapper.py
20+
echo "Done..."
21+
echo "==============================================="
22+
23+
echo "Running memory store tests"
24+
pytest tests/test_memory_store.py
25+
echo "Done..."
26+
echo "==============================================="
27+
28+
echo "Running ollama client tests"
29+
pytest tests/test_ollama_client.py
30+
echo "Done..."
31+
echo "==============================================="

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ dependencies = [
88
"loguru>=0.7.3",
99
"pre-commit>=4.2.0",
1010
"pydantic-settings>=2.9.1",
11+
"pytest>=8.3.5",
1112
"requests>=2.32.3",
1213
"ruff>=0.11.7",
1314
"sqlmodel>=0.0.24",

src/agents/dummy_agent.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,4 +47,5 @@ def act(self, action: str, context: Dict[str, Any]) -> Any:
4747
logger.info(f"{self.name} executing action: {action} with context: {context}")
4848
result = f"result_of_{action}"
4949
logger.info(f"{self.name} executed action: {action} -> {result}")
50+
print(f"{self.name} executed action: {action} -> {result}")
5051
return result

src/config/settings.py

Lines changed: 16 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -10,12 +10,18 @@ class OllamaSettings(BaseSettings):
1010
Base settings for Ollama interaction
1111
"""
1212

13-
base_url: str = Field(description="The Ollama base url")
14-
timeout_seconds: float
13+
base_url: str = Field(
14+
default="http://localhost:11434/api/generate", description="The Ollama base url"
15+
)
16+
timeout_seconds: float = Field(
17+
default=30.0, description="Timeout for calling Ollama"
18+
)
1519
stream: bool = Field(default=False, description="Flag to denote chunked streaming.")
16-
model: str = Field(description="Ollama model name.")
20+
model: str = Field(default="gemma2", description="Ollama model name.")
1721

18-
model_config = SettingsConfigDict(env_file=".env", env_prefix="OLLAMA_")
22+
model_config = SettingsConfigDict(
23+
env_file=".env", env_prefix="OLLAMA_", extra="allow"
24+
)
1925

2026

2127
class DatabaseSettings(BaseSettings):
@@ -36,16 +42,18 @@ class LoggingSettings(BaseSettings):
3642
Configuration for Loguru logging sinks.
3743
"""
3844

39-
level: str = Field(description="The log level")
45+
level: str = Field(default="DEBUG", description="The log level")
4046
console: bool = Field(default=True, description="Show logs in console")
4147
enable_file: bool = Field(
4248
default=False, description="Flag to denote persistence of logs"
4349
)
4450
filepath: Optional[Path] = Field(
4551
default=None, description="Optional file path for logs"
4652
)
47-
rotation: str = Field(description="Roll log after this size")
48-
retention: str = Field(description="Keep logs for this amount of time")
49-
compression: str = Field(description="Compress old logs")
53+
rotation: str = Field(default="10 MB", description="Roll log after this size")
54+
retention: str = Field(
55+
default="7 days", description="Keep logs for this amount of time"
56+
)
57+
compression: str = Field(default="zip", description="Compress old logs")
5058

5159
model_config = SettingsConfigDict(env_file=".env", env_prefix="LOG_", extra="allow")

src/tools/llm_wrapper.py

Lines changed: 19 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -42,18 +42,26 @@ def generate_response(self, prompt: str, **kwargs: Any) -> str:
4242

4343
# Extract generated text from response
4444
text = ""
45-
if "response" in response:
46-
text = response.get("response", "").strip()
47-
logger.debug("Extracted 'response' field from Ollama response")
48-
elif "choices" in response:
49-
text = response.get("choices", [{}][0].get("text", "").strip())
50-
logger.debug("Extracted 'choices' field from Ollama response")
51-
elif "results" in response:
52-
text = response.get("results", [{}])[0].get("text", "").strip()
53-
logger.debug("Extracted 'results' field from Ollama response")
45+
if isinstance(response, dict):
46+
if "response" in response and isinstance(response["response"], str):
47+
text = response["response"].strip()
48+
logger.debug("Extracted 'response' field from Ollama response")
49+
elif "choices" in response and isinstance(response["choices"], list):
50+
choice = response["choices"][0]
51+
text = (
52+
choice.get("text", "").strip() if isinstance(choice, dict) else ""
53+
)
54+
logger.debug("Extracted 'choices' field from Ollama response")
55+
elif "results" in response and isinstance(response["results"], list):
56+
result = response["results"][0]
57+
text = (
58+
result.get("text", "").strip() if isinstance(result, dict) else ""
59+
)
60+
logger.debug("Extracted 'results' field from Ollama response")
61+
else:
62+
logger.warning(f"Unexpected response format: {response}")
5463
else:
55-
logger.warning(f"Unexpected response format: {response}")
56-
text = str(response).strip()
64+
logger.warning(f"Invalid response type: {type(response)}")
5765

5866
if not text:
5967
logger.warning("Received empty response from LLM")

tests/conftest.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
import pytest
2+
from src.config.settings import OllamaSettings, DatabaseSettings
3+
4+
5+
@pytest.fixture
6+
def ollama_settings():
7+
return OllamaSettings(
8+
base_url="http://test-server/api/generate",
9+
timeout_seconds=0.1,
10+
model="test-model",
11+
)
12+
13+
14+
@pytest.fixture
15+
def db_settings(tmp_path):
16+
return DatabaseSettings(
17+
url=f"sqlite:///{tmp_path / 'test_memory.db'}",
18+
echo=False,
19+
)
20+
21+
22+
@pytest.fixture(autouse=True)
23+
def patch_client(monkeypatch):
24+
# Replace OllamaClient used inside LocalLLM with our stub
25+
monkeypatch.setattr("src.tools.llm_wrapper.OllamaClient", StubClient)
26+
27+
28+
class StubClient:
29+
def __init__(self, settings):
30+
pass
31+
32+
def __call__(self, prompt, **kwargs):
33+
return self._response

tests/test_base_agent.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
import pytest
2+
from src.agents.base_agent import Agent
3+
4+
5+
def test_base_agent_abstract():
6+
with pytest.raises(TypeError):
7+
Agent(name="X", tools={}, memory=None)
8+
9+
10+
class SimpleAgent(Agent):
11+
def perceive(self, input_data):
12+
return {"in": input_data}
13+
14+
def plan(self, observation):
15+
return ["step"]
16+
17+
def act(self, action, context):
18+
return "ok"
19+
20+
21+
def test_simple_agent_runs():
22+
agent = SimpleAgent(name="Simple", tools={}, memory=None)
23+
# Should run without errors
24+
agent.achieve_goal("DATA")

tests/test_dummy_agent.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
from src.agents.dummy_agent import DummyAgent
2+
3+
4+
def test_dummy_agent_cycle(capsys):
5+
agent = DummyAgent(name="TestDummy")
6+
agent.achieve_goal("INPUT")
7+
captured = capsys.readouterr().out.strip().splitlines()
8+
# Expect three lines for dummy_step1..3
9+
assert len(captured) == 3
10+
for i, line in enumerate(captured, start=1):
11+
assert f"executed action: dummy_step{i}" in line

tests/test_llm_wrapper.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
from src.config.settings import OllamaSettings
2+
from src.tools.llm_wrapper import LocalLLM
3+
from tests.conftest import StubClient
4+
5+
6+
def make_wrapper(resp):
7+
settings = OllamaSettings(base_url="http://x", timeout_seconds=1, model="m")
8+
stub = StubClient(settings)
9+
stub._response = resp
10+
# ensure LocalLLM uses our stub instance
11+
wrapper = LocalLLM(settings)
12+
wrapper.client = stub
13+
return wrapper
14+
15+
16+
def test_generate_from_response_field():
17+
wrapper = make_wrapper({"response": "hi there"})
18+
assert wrapper.generate_response("x") == "hi there"
19+
20+
21+
def test_generate_from_choices_field():
22+
wrapper = make_wrapper({"choices": [{"text": "foo "}]})
23+
assert wrapper.generate_response("x") == "foo"
24+
25+
26+
def test_generate_empty_and_warn(caplog):
27+
wrapper = make_wrapper({})
28+
caplog.set_level("WARNING")
29+
out = wrapper.generate_response("x")
30+
assert out == ""
31+
assert "" in caplog.text.lower()

tests/test_memory_store.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
from src.memory.store import MemoryStore
2+
from src.memory.models import Memory
3+
4+
5+
def test_store_and_load_roundtrip(db_settings):
6+
store = MemoryStore(db_settings)
7+
# store 2 memories for AgentA
8+
m1 = store.store("AgentA", "step1", "data1")
9+
m2 = store.store("AgentA", "step2", "data2")
10+
assert isinstance(m1, Memory) and isinstance(m2, Memory)
11+
12+
records = store.load("AgentA")
13+
assert len(records) == 2
14+
# ensure content matches
15+
contents = {r.content for r in records}
16+
assert contents == {"data1", "data2"}
17+
18+
19+
def test_isolated_databases(tmp_path):
20+
# two separate stores shouldn't see each other's data
21+
from src.config.settings import DatabaseSettings
22+
23+
s1 = DatabaseSettings(url=f"sqlite:///{tmp_path / 'a.db'}")
24+
s2 = DatabaseSettings(url=f"sqlite:///{tmp_path / 'b.db'}")
25+
st1 = MemoryStore(s1)
26+
st2 = MemoryStore(s2)
27+
st1.store("X", "s", "d")
28+
assert len(st1.load("X")) == 1
29+
assert st2.load("X") == []

0 commit comments

Comments
 (0)