diff --git a/.cursor_rules b/.cursorrules
similarity index 99%
rename from .cursor_rules
rename to .cursorrules
index ede9113..25c7036 100644
--- a/.cursor_rules
+++ b/.cursorrules
@@ -1,6 +1,9 @@
+# Rules
+
 You are an AI expert specialized in developing simulations that model complex human behavior and group dynamics based on Narrative Field Theory. Your focus is on integrating LLMs for natural language-based decision making and interactions.
 
 Core Competencies:
+
 - Multi-agent systems and emergent behavior
 - Psychological modeling and group dynamics
 - LLM integration and prompt engineering
@@ -8,6 +11,7 @@ Core Competencies:
 - Machine learning and neural networks
 
 Key Scientific Foundations:
+
 - Cognitive Science & Psychology
 - Complex Systems Theory
 - Social Network Analysis
@@ -15,6 +19,7 @@ Key Scientific Foundations:
 - Organizational Behavior
 
 Technical Stack:
+
 - Python (core language)
 - PyTorch (ML components)
 - Transformers (LLM integration)
@@ -23,6 +28,7 @@ Technical Stack:
 - Redis (state management)
 
 Code Quality Standards:
+
 1. Style and Formatting
    - Follow PEP 8 style guide
    - Use black for code formatting
@@ -76,6 +82,7 @@ Architecture Focus:
 - State-to-text conversion
 
 Development Workflow:
+
 1. Version Control
    - Git flow branching model
    - Semantic versioning
@@ -100,6 +107,7 @@ Development Workflow:
 - Performance benchmarks
 
 Key Patterns:
+
 - Loosely coupled components
 - Event-driven communication
 - Asynchronous processing
@@ -107,6 +115,7 @@ Key Patterns:
 - Observable systems
 
 Best Practices:
+
 1. Clear separation of concerns
 2. Efficient state management
 3. Robust error handling
diff --git a/.github/ISSUE_TEMPLATE/bug_report.md b/.github/ISSUE_TEMPLATE/bug_report.md
index dd84ea7..3205926 100644
--- a/.github/ISSUE_TEMPLATE/bug_report.md
+++ b/.github/ISSUE_TEMPLATE/bug_report.md
@@ -12,6 +12,7 @@ A clear and concise description of what the bug is.
 
 **To Reproduce**
 Steps to reproduce the behavior:
+
 1. Go to '...'
 2. Click on '....'
 3. Scroll down to '....'
@@ -24,15 +25,17 @@ A clear and concise description of what you expected to happen.
 If applicable, add screenshots to help explain your problem.
 
 **Desktop (please complete the following information):**
- - OS: [e.g. iOS]
- - Browser [e.g. chrome, safari]
- - Version [e.g. 22]
+
+- OS: [e.g. iOS]
+- Browser [e.g. chrome, safari]
+- Version [e.g. 22]
 
 **Smartphone (please complete the following information):**
- - Device: [e.g. iPhone6]
- - OS: [e.g. iOS8.1]
- - Browser [e.g. stock browser, safari]
- - Version [e.g. 22]
+
+- Device: [e.g. iPhone6]
+- OS: [e.g. iOS8.1]
+- Browser [e.g. stock browser, safari]
+- Version [e.g. 22]
 
 **Additional context**
 Add any other context about the problem here.
diff --git a/.github/ISSUE_TEMPLATE/issue-.md b/.github/ISSUE_TEMPLATE/issue-.md
index 4a7fa0e..e4f69a3 100644
--- a/.github/ISSUE_TEMPLATE/issue-.md
+++ b/.github/ISSUE_TEMPLATE/issue-.md
@@ -8,24 +8,31 @@ assignees: ''
 ---
 
 ## Title
+
 A clear and concise title for the issue.
 
 ## Description
+
 A detailed description of the issue.
 
 ## Steps to Reproduce
+
 1. Step one
 2. Step two
 3. Step three
 
 ## Expected Behavior
+
 What you expected to happen.
 
 ## Actual Behavior
+
 What actually happened.
 
 ## Screenshots/Logs
+
 Any relevant screenshots or logs.
 
 ## Environment
+
 Information about the environment where the issue occurred (e.g., OS, browser, version).
diff --git a/src/config.py b/src/config.py
index f0736aa..52fb3bf 100644
--- a/src/config.py
+++ b/src/config.py
@@ -1,9 +1,19 @@
+"""
+Configuration module for the lab-politik application.
+
+This module contains global configuration variables and model-specific
+configurations for the application.
+"""
+
 from pathlib import Path
+from typing import Dict, Any
 
+# Application-wide constants
 APP_NAME: str = "lab-politik"
 IS_DEVELOPMENT: bool = True
 
-MODEL_CONFIGS = {
+# Model configuration dictionary
+MODEL_CONFIGS: Dict[str, Dict[str, Any]] = {
     "balanced": {
         "chat": {
             "path": Path(
@@ -32,3 +42,21 @@
         },
     },
 }
+
+def get_model_config(config_name: str = "balanced") -> Dict[str, Any]:
+    """
+    Retrieve the model configuration for a given configuration name.
+
+    Args:
+        config_name (str): The name of the configuration to retrieve.
+            Defaults to "balanced".
+
+    Returns:
+        Dict[str, Any]: The model configuration dictionary.
+
+    Raises:
+        KeyError: If the specified config_name is not found in MODEL_CONFIGS.
+    """
+    if config_name not in MODEL_CONFIGS:
+        raise KeyError(f"Configuration '{config_name}' not found in MODEL_CONFIGS")
+    return MODEL_CONFIGS[config_name]
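Example (illustrative sketch, not part of the patch above; assumes the module is importable as src.config in this package layout):

    from src.config import get_model_config

    chat_config = get_model_config("balanced")["chat"]  # raises KeyError for unknown presets
    chat_model_path = chat_config["path"]               # Path to the chat model, as defined in MODEL_CONFIGS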
diff --git a/src/embedding_cache.py b/src/embedding_cache.py
index 5190d47..f34f986 100644
--- a/src/embedding_cache.py
+++ b/src/embedding_cache.py
@@ -1,28 +1,90 @@
 """Module for caching embeddings."""
 
 import hashlib
-from typing import Dict, List, Optional
+from typing import Dict, List, Optional, Union
 
 
 class EmbeddingCache:
-    """Hash-based cache for storing embeddings."""
+    """
+    Hash-based cache for storing embeddings.
+
+    This class provides a simple key-value store for embeddings, using SHA-256 hashes
+    of the input text as keys. This gives every input text, whatever its length, a
+    stable fixed-size lookup key.
+
+    Attributes:
+        _cache (Dict[str, List[float]]): Internal storage for hashed key-value pairs.
+    """
 
     def __init__(self):
         self._cache: Dict[str, List[float]] = {}
 
     @staticmethod
     def get_stable_hash(text: str) -> str:
-        """Generate a stable hash for the given text."""
+        """
+        Generate a stable hash for the given text.
+
+        Args:
+            text (str): The input text to hash.
+
+        Returns:
+            str: A SHA-256 hash of the input text.
+        """
         return hashlib.sha256(text.encode()).hexdigest()
 
     def get(self, key: str) -> Optional[List[float]]:
-        """Retrieve an embedding from the cache using a hashed key."""
+        """
+        Retrieve an embedding from the cache using a hashed key.
+
+        Args:
+            key (str): The text key to look up.
+
+        Returns:
+            Optional[List[float]]: The stored embedding if found, None otherwise.
+        """
         return self._cache.get(self.get_stable_hash(key))
 
     def set(self, key: str, value: List[float]) -> None:
-        """Store an embedding in the cache using a hashed key."""
+        """
+        Store an embedding in the cache using a hashed key.
+
+        Args:
+            key (str): The text key to associate with the embedding.
+            value (List[float]): The embedding to store.
+        """
         self._cache[self.get_stable_hash(key)] = value
 
     def clear(self) -> None:
         """Clear all entries from the cache."""
         self._cache.clear()
+
+    def __len__(self) -> int:
+        """
+        Get the number of entries in the cache.
+
+        Returns:
+            int: The number of cached embeddings.
+        """
+        return len(self._cache)
+
+    def __contains__(self, key: str) -> bool:
+        """
+        Check if a key exists in the cache.
+
+        Args:
+            key (str): The text key to check.
+
+        Returns:
+            bool: True if the key exists in the cache, False otherwise.
+        """
+        return self.get_stable_hash(key) in self._cache
+
+    def update(self, items: Dict[str, List[float]]) -> None:
+        """
+        Update the cache with multiple key-value pairs.
+
+        Args:
+            items (Dict[str, List[float]]): A dictionary of text keys and their embeddings.
+        """
+        for key, value in items.items():
+            self.set(key, value)
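Example (illustrative sketch, not part of the patch above; assumes the class is importable as src.embedding_cache): with the new __len__, __contains__, and update helpers the cache behaves like a small dictionary keyed by hashed text.

    from src.embedding_cache import EmbeddingCache

    cache = EmbeddingCache()
    cache.update({"hello world": [0.1, 0.2], "goodbye": [0.3, 0.4]})
    assert "hello world" in cache          # membership check goes through get_stable_hash()
    assert len(cache) == 2
    assert cache.get("hello world") == [0.1, 0.2]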
+ """ + for key, value in items.items(): + self.set(key, value) diff --git a/src/language_models.py b/src/language_models.py index a824575..ad7af0c 100644 --- a/src/language_models.py +++ b/src/language_models.py @@ -1,3 +1,10 @@ +""" +This module defines abstract and concrete implementations of language models. + +It includes classes for interfacing with Ollama and Llama models, as well as +utility functions and custom exceptions for error handling. +""" + from __future__ import annotations from abc import ABC, abstractmethod from typing import List, Callable @@ -41,6 +48,16 @@ class ModelInitializationError(Exception): def async_error_handler(func: Callable) -> Callable: + """ + A decorator to handle errors in asynchronous functions. + + Args: + func (Callable): The asynchronous function to be wrapped. + + Returns: + Callable: The wrapped function that catches and re-raises exceptions as ModelError. + """ + @wraps(func) async def wrapper(*args, **kwargs): try: @@ -52,9 +69,17 @@ async def wrapper(*args, **kwargs): class LanguageModel(ABC): - """Abstract base class for language models.""" + """ + Abstract base class for language models. + + This class defines the interface for language model implementations and + provides common functionality such as embedding caching. + """ def __init__(self): + """ + Initialize the LanguageModel with a logger and embedding cache. + """ self.logger = logging.getLogger(self.__class__.__name__) self.logger.info(f"Initializing {self.__class__.__name__}") self.embedding_cache: EmbeddingCache = EmbeddingCache() @@ -76,8 +101,7 @@ async def generate_embedding(self, text: str) -> List[float]: self.logger.warning("Attempted to generate embedding for empty text") return [] - cached_embedding = self.embedding_cache.get(text) - if cached_embedding: + if cached_embedding := self.embedding_cache.get(text): self.logger.info(f"Embedding found in cache for text: {text[:50]}...") return cached_embedding @@ -105,9 +129,24 @@ async def cleanup(self) -> None: class OllamaInterface(LanguageModel): - """Interface for the Ollama language model.""" + """ + Interface for the Ollama language model. + + This class provides methods to interact with Ollama models for text generation + and embedding creation. + """ def __init__(self, quality_preset: str = "balanced"): + """ + Initialize the OllamaInterface with the specified quality preset. + + Args: + quality_preset (str): The quality preset to use for model configuration. + Defaults to "balanced". + + Raises: + ModelInitializationError: If the quality preset is invalid or configuration is missing. + """ super().__init__() try: self.chat_model_name = MODEL_CONFIGS[quality_preset]["chat"]["model_name"] @@ -121,7 +160,12 @@ def __init__(self, quality_preset: str = "balanced"): ) from e def _setup_models(self) -> None: - """Set up the language models.""" + """ + Set up the Ollama models for chat and embedding. + + Raises: + ModelInitializationError: If Ollama fails to start. + """ self.logger.info(f"Setting up Ollama models for {self.chat_model_name}") self.logger.info( f"Setting up Ollama embedding model for {self.embedding_model_name}" @@ -165,9 +209,24 @@ async def _generate_embedding(self, text: str) -> List[float]: class LlamaInterface(LanguageModel): - """Interface for the Llama language model.""" + """ + Interface for the Llama language model. + + This class provides methods to interact with Llama models for text generation + and embedding creation using local model files. 
+ """ def __init__(self, quality_preset: str = "balanced"): + """ + Initialize the LlamaInterface with the specified quality preset. + + Args: + quality_preset (str): The quality preset to use for model configuration. + Defaults to "balanced". + + Raises: + ModelInitializationError: If the quality preset is invalid or configuration is missing. + """ super().__init__() try: self.chat_model_path = MODEL_CONFIGS[quality_preset]["chat"]["path"] @@ -184,7 +243,12 @@ def __init__(self, quality_preset: str = "balanced"): ) from e def _setup_models(self) -> None: - """Set up the language models.""" + """ + Set up the Llama models for chat and embedding. + + Raises: + ModelInitializationError: If model initialization fails. + """ chat_model_filename = Path(self.chat_model_path).name embedding_model_filename = Path(self.embedding_model_path).name @@ -243,9 +307,14 @@ async def _generate_embedding(self, text: str) -> List[float]: return await asyncio.to_thread(self.embedding_model.embed, text) async def cleanup(self) -> None: - """Clean up resources used by the model.""" + """ + Clean up resources used by the Llama models. + + This method extends the base class cleanup by also deleting the Llama model instances. + """ await super().cleanup() if self.llm: del self.llm if self.embedding_model: del self.embedding_model + diff --git a/src/logging_config.py b/src/logging_config.py index 010ccb9..b5368f0 100644 --- a/src/logging_config.py +++ b/src/logging_config.py @@ -12,7 +12,18 @@ def setup_logging() -> None: - """Set up logging configuration for the application.""" + """ + Set up logging configuration for the application. + + This function initializes the logging system, creating log directories, + setting up file and console handlers, and cleaning up old log files. + + The log directory is determined based on whether the application is running + in development mode or not. + + Returns: + None + """ if IS_DEVELOPMENT: log_dir = os.path.join(os.path.dirname(__file__), "..", "logs") else: @@ -35,7 +46,15 @@ def setup_logging() -> None: def _create_file_handler(log_file: str) -> logging.Handler: - """Create and configure a file handler for logging.""" + """ + Create and configure a file handler for logging. + + Args: + log_file (str): The path to the log file. + + Returns: + logging.Handler: A configured RotatingFileHandler instance. + """ file_handler = logging.handlers.RotatingFileHandler( log_file, maxBytes=MAX_LOG_FILE_SIZE, backupCount=MAX_LOG_FILES - 1 ) @@ -48,7 +67,12 @@ def _create_file_handler(log_file: str) -> logging.Handler: def _create_console_handler() -> logging.Handler: - """Create and configure a console handler for logging.""" + """ + Create and configure a console handler for logging. + + Returns: + logging.Handler: A configured StreamHandler instance. + """ console_handler = logging.StreamHandler() console_handler.setLevel(logging.INFO) console_formatter = logging.Formatter("%(asctime)s - %(levelname)s - %(message)s") @@ -57,7 +81,15 @@ def _create_console_handler() -> logging.Handler: def _cleanup_old_logs(log_dir: str) -> None: - """Remove old log files if the total number exceeds MAX_LOG_FILES.""" + """ + Remove old log files if the total number exceeds MAX_LOG_FILES. + + Args: + log_dir (str): The directory containing the log files. 
diff --git a/src/nfs_simple_lab_scenario.py b/src/nfs_simple_lab_scenario.py
index 3629373..4600c2b 100644
--- a/src/nfs_simple_lab_scenario.py
+++ b/src/nfs_simple_lab_scenario.py
@@ -191,8 +191,7 @@ def get_average_duration(self, operation: str) -> float:
     def print_summary(self):
         print("\nPerformance Metrics Summary:")
         for operation, data in self.metrics.items():
-            durations = data["durations"]
-            if durations:
+            if durations := data["durations"]:
                 avg_duration = sum(durations) / len(durations)
                 min_duration = min(durations)
                 max_duration = max(durations)
@@ -284,7 +283,7 @@ def __init__(self, collection_name: str = "narrative_field"):
 
         try:
             self.collection = self.client.get_collection(collection_name)
-        except:
+        except Exception:
             self.collection = self.client.create_collection(
                 name=collection_name, metadata={"hnsw:space": "cosine"}
             )
@@ -553,10 +552,10 @@ async def demo_scenario():
 
     llm = LlamaInterface()
     vector_store: VectorStore = ChromaStore(collection_name="research_lab")
-    logger.info(f"Initialized Chroma vector store")
+    logger.info("Initialized Chroma vector store")
 
     field = NarrativeField(llm, vector_store)
-    logger.info(f"Initialized narrative field")
+    logger.info("Initialized narrative field")
 
     # Research Lab Scenario with Multiple Characters and events
     stories = [
diff --git a/tests/test_embedding_cache.py b/tests/test_embedding_cache.py
index 9ce543a..ef8de8d 100644
--- a/tests/test_embedding_cache.py
+++ b/tests/test_embedding_cache.py
@@ -3,10 +3,12 @@
 
 @pytest.fixture
 def cache():
+    """Fixture to provide a fresh EmbeddingCache instance for each test."""
     return EmbeddingCache()
 
 def test_get_stable_hash():
     text = "Hello, world!"
+    # This hash is pre-computed and should remain stable
    expected_hash = "315f5bdb76d078c43b8ac0064e4a0164612b1fce77c869345bfc94c75894edd3"
     assert EmbeddingCache.get_stable_hash(text) == expected_hash
 
@@ -17,12 +19,17 @@ def test_set_and_get(cache):
     assert cache.get(key) == value
 
 def test_get_nonexistent_key(cache):
+    # Should return None for keys that haven't been set
     assert cache.get("nonexistent_key") is None
 
 def test_clear(cache):
+    # Set up the cache with some values
     cache.set("key1", [1.0, 2.0])
     cache.set("key2", [3.0, 4.0])
+
     cache.clear()
+
+    # Verify that the cache is empty after clearing
     assert cache.get("key1") is None
     assert cache.get("key2") is None
 
@@ -30,6 +37,34 @@ def test_multiple_sets_same_key(cache):
     key = "test_key"
     value1 = [0.1, 0.2, 0.3]
     value2 = [0.4, 0.5, 0.6]
+
     cache.set(key, value1)
     cache.set(key, value2)
+
+    # Verify that the most recent value is retrieved
     assert cache.get(key) == value2
+
+def test_len(cache):
+    assert len(cache) == 0
+    cache.set("key1", [1.0, 2.0])
+    assert len(cache) == 1
+    cache.set("key2", [3.0, 4.0])
+    assert len(cache) == 2
+
+def test_contains(cache):
+    cache.set("existing_key", [1.0, 2.0])
+    assert "existing_key" in cache
+    assert "non_existing_key" not in cache
+
+def test_update(cache):
+    initial_data = {"key1": [1.0, 2.0], "key2": [3.0, 4.0]}
+    cache.update(initial_data)
+    assert cache.get("key1") == [1.0, 2.0]
+    assert cache.get("key2") == [3.0, 4.0]
+
+    # Test updating existing and adding new key
+    update_data = {"key2": [5.0, 6.0], "key3": [7.0, 8.0]}
+    cache.update(update_data)
+    assert cache.get("key1") == [1.0, 2.0]  # Unchanged
+    assert cache.get("key2") == [5.0, 6.0]  # Updated
+    assert cache.get("key3") == [7.0, 8.0]  # Newly added
diff --git a/tests/test_language_models.py b/tests/test_language_models.py
index 3a962e0..5816589 100644
--- a/tests/test_language_models.py
+++ b/tests/test_language_models.py
@@ -2,20 +2,30 @@
 from unittest.mock import AsyncMock, MagicMock, patch
 from pathlib import Path
 
-# Mock the config module
+# Mock configuration for language models
+# This simulates the structure of the actual MODEL_CONFIGS in the config module
 mock_MODEL_CONFIGS = {
-    "balanced": {
-        "chat": {"model_name": "test_chat_model", "path": Path("/path/to/chat/model")},
+    "balanced": {  # Quality preset
+        "chat": {
+            "model_name": "test_chat_model",
+            "path": Path("/path/to/chat/model")
+        },
         "embedding": {
             "model_name": "test_embedding_model",
-            "path": Path("/path/to/embedding/model"),
+            "path": Path("/path/to/embedding/model")
         },
-        "optimal_config": {"n_ctx": 2048, "n_batch": 512},
+        "optimal_config": {
+            "n_ctx": 2048,  # Context window size
+            "n_batch": 512  # Batch size for processing
+        }
     }
+    # Additional quality presets could be added here
 }
 
-# Mock the EmbeddingCache
+# Mock the EmbeddingCache class
+# This allows us to simulate cache behavior without using a real cache in tests
 mock_EmbeddingCache = MagicMock()
+# Methods like get() and set() can be mocked on this object as needed in tests
 
 # Patch both config and embedding_cache imports
 with patch.dict("sys.modules", {
@@ -59,6 +69,7 @@ def mock_config():
 class TestLanguageModel:
     @pytest.fixture
     def concrete_language_model(self):
+        # Define a concrete implementation of LanguageModel for testing
         class ConcreteLanguageModel(LanguageModel):
             @async_error_handler
             async def generate(self, prompt: str) -> str:
@@ -66,6 +77,7 @@ async def generate(self, prompt: str) -> str:
 
             @async_error_handler
             async def _generate_embedding(self, text: str) -> list[float]:
+                # Simulate error for empty text
                 if not text:
                     raise Exception("Test error")
                 return [0.1, 0.2, 0.3]
@@ -77,42 +89,59 @@ async def test_generate_embedding_cached(self, concrete_language_model):
         text = "Test text"
         cached_embedding = [0.4, 0.5, 0.6]
         concrete_language_model.embedding_cache.get.return_value = cached_embedding
+
+        # Spy on the _generate_embedding method
+        with patch.object(concrete_language_model, '_generate_embedding', wraps=concrete_language_model._generate_embedding) as mock_generate_embedding:
+            result = await concrete_language_model.generate_embedding(text)
+
+        assert result == cached_embedding
+        mock_generate_embedding.assert_not_called()
 
-        result = await concrete_language_model.generate_embedding(text)
-        assert result == cached_embedding
+        # No need to check if set() was called, as it shouldn't be for cached results
 
     @pytest.mark.asyncio
     async def test_generate_embedding_not_cached(self, concrete_language_model):
+        # Test when the embedding is not in the cache
         text = "New test text"
         expected_embedding = [0.1, 0.2, 0.3]
         concrete_language_model.embedding_cache.get.return_value = None
 
         result = await concrete_language_model.generate_embedding(text)
 
         assert result == expected_embedding
+        # Verify that the new embedding was cached
         concrete_language_model.embedding_cache.set.assert_called_once_with(text, expected_embedding)
 
     @pytest.mark.asyncio
     async def test_generate_embedding_empty_text(self, concrete_language_model, caplog):
+        # Test behavior when given empty text
         result = await concrete_language_model.generate_embedding("")
         assert result == []
+        # Check that the appropriate warning was logged
         assert "Attempted to generate embedding for empty text" in caplog.text
 
 
 class TestOllamaInterface:
     @pytest.mark.asyncio
     async def test_init(self, mock_ollama, mock_config):
+        # Test initialization of OllamaInterface
         ollama_interface = OllamaInterface()
+        # Verify that model names are correctly set from the mock config
         assert ollama_interface.chat_model_name == "test_chat_model"
         assert ollama_interface.embedding_model_name == "test_embedding_model"
+        # Ensure that the Ollama process status check is called during initialization
         mock_ollama.ps.assert_called_once()
 
     @pytest.mark.asyncio
     async def test_generate(self, mock_ollama, mock_config):
+        # Test the generate method of OllamaInterface
         ollama_interface = OllamaInterface()
+        # Mock the response from Ollama's chat method
         mock_ollama.chat.return_value = {"message": {"content": "Generated response"}}
+        # Call the generate method and check the response
         response = await ollama_interface.generate("Test prompt")
         assert response == "Generated response"
+        # Verify that Ollama's chat method was called with correct parameters
         mock_ollama.chat.assert_called_once_with(
             model="test_chat_model",
             messages=[{"role": "user", "content": "Test prompt"}],
@@ -120,57 +149,152 @@ async def test_generate_embedding(self, mock_ollama, mock_config):
+        # Test the generate_embedding method of OllamaInterface
         ollama_interface = OllamaInterface()
+        # Mock the response from Ollama's embeddings method
         mock_ollama.embeddings.return_value = {"embedding": [0.1, 0.2, 0.3]}
+        # Call the generate_embedding method and check the result
         embedding = await ollama_interface.generate_embedding("Test text")
         assert embedding == [0.1, 0.2, 0.3]
+        # Verify that Ollama's embeddings method was called with correct parameters
         mock_ollama.embeddings.assert_called_once_with(
             model="test_embedding_model", prompt="Test text"
         )
 
+    @pytest.mark.asyncio
+    async def test_generate_network_error(self, mock_ollama, mock_config):
+        ollama_interface = OllamaInterface()
+        mock_ollama.chat.side_effect = Exception("Network error")
+
+        with pytest.raises(ModelError) as exc_info:
+            await ollama_interface.generate("Test prompt")
+        assert "Failed to generate response" in str(exc_info.value)
+
+    @pytest.mark.asyncio
+    async def test_generate_invalid_response(self, mock_ollama, mock_config):
+        ollama_interface = OllamaInterface()
+        mock_ollama.chat.return_value = {"invalid_key": "Invalid response"}
+
+        with pytest.raises(ModelError) as exc_info:
+            await ollama_interface.generate("Test prompt")
+        assert "Failed to generate response: 'message'" in str(exc_info.value)
+
+    @pytest.mark.asyncio
+    async def test_generate_embedding_network_error(self, mock_ollama, mock_config):
+        ollama_interface = OllamaInterface()
+        mock_ollama.embeddings.side_effect = Exception("Network error")
+
+        with pytest.raises(ModelError) as exc_info:
+            await ollama_interface.generate_embedding("Test text")
+        assert "Network error" in str(exc_info.value)
+
+    @pytest.mark.asyncio
+    async def test_generate_embedding_invalid_response(self, mock_ollama, mock_config):
+        ollama_interface = OllamaInterface()
+        mock_ollama.embeddings.return_value = {"invalid_key": "Invalid response"}
+
+        with pytest.raises(ModelError) as exc_info:
+            await ollama_interface.generate_embedding("Test text")
+        assert "'embedding'" in str(exc_info.value)
+
+    def test_init_ollama_not_running(self, mock_ollama, mock_config):
+        mock_ollama.ps.side_effect = Exception("Ollama not running")
+
+        with pytest.raises(ModelInitializationError) as exc_info:
+            OllamaInterface()
+        assert "Failed to start Ollama: Ollama not running" in str(exc_info.value)
+
 
 class TestLlamaInterface:
     @pytest.mark.asyncio
     async def test_init(self, mock_llama, mock_config):
+        # Test initialization of LlamaInterface
         llama_interface = LlamaInterface()
+        # Verify that model paths and configurations are correctly set from the mock config
         assert llama_interface.chat_model_path == Path('/path/to/chat/model')
         assert llama_interface.embedding_model_path == Path('/path/to/embedding/model')
         assert llama_interface.optimal_config == {'n_ctx': 2048, 'n_batch': 512}
+        # Ensure that the Llama model is initialized during interface creation
         mock_llama.assert_called()
 
     @pytest.mark.asyncio
     async def test_generate(self, mock_llama, mock_config):
+        # Test the generate method of LlamaInterface
         llama_interface = LlamaInterface()
+        # Mock the response from Llama's create_chat_completion method
         mock_llama.return_value.create_chat_completion.return_value = {
             'choices': [{'message': {'content': 'Generated response'}}]
         }
+        # Call the generate method and check the response
         response = await llama_interface.generate("Test prompt")
         assert response == 'Generated response'
+        # Verify that Llama's create_chat_completion method was called with correct parameters
         mock_llama.return_value.create_chat_completion.assert_called_once_with(
             messages=[{'role': 'user', 'content': 'Test prompt'}]
         )
 
     @pytest.mark.asyncio
     async def test_generate_embedding(self, mock_llama, mock_config):
+        # Test the generate_embedding method of LlamaInterface
         llama_interface = LlamaInterface()
+        # Mock the response from Llama's embed method
         mock_llama.return_value.embed.return_value = [0.1, 0.2, 0.3]
+        # Call the generate_embedding method and check the result
         embedding = await llama_interface.generate_embedding("Test text")
         assert embedding == [0.1, 0.2, 0.3]
+        # Verify that Llama's embed method was called with correct parameters
         mock_llama.return_value.embed.assert_called_once_with("Test text")
 
     @pytest.mark.asyncio
     async def test_cleanup(self, mock_llama, mock_config):
+        # Test the cleanup method of LlamaInterface
         llama_interface = LlamaInterface()
+        # Call the cleanup method
         await llama_interface.cleanup()
 
         # Check that the models have been deleted
         assert not hasattr(llama_interface, 'llm')
         assert not hasattr(llama_interface, 'embedding_model')
 
+    @pytest.mark.asyncio
+    async def test_generate_model_error(self, mock_llama, mock_config):
+        llama_interface = LlamaInterface()
+        mock_llama.return_value.create_chat_completion.side_effect = Exception("Model error")
+
+        with pytest.raises(ModelError) as exc_info:
+            await llama_interface.generate("Test prompt")
+        assert "Failed to generate response" in str(exc_info.value)
+
+    @pytest.mark.asyncio
+    async def test_generate_invalid_response(self, mock_llama, mock_config):
+        llama_interface = LlamaInterface()
+        mock_llama.return_value.create_chat_completion.return_value = {"invalid_key": "Invalid response"}
+
+        with pytest.raises(ModelError) as exc_info:
+            await llama_interface.generate("Test prompt")
+        assert "Failed to generate response: 'choices'" in str(exc_info.value)
+
+    @pytest.mark.asyncio
+    async def test_generate_embedding_model_error(self, mock_llama, mock_config):
+        llama_interface = LlamaInterface()
+        mock_llama.return_value.embed.side_effect = Exception("Model error")
+
+        with pytest.raises(ModelError) as exc_info:
+            await llama_interface.generate_embedding("Test text")
+        assert "Model error" in str(exc_info.value)
+
+    def test_init_model_load_error(self, mock_llama, mock_config):
+        mock_llama.side_effect = Exception("Failed to load model")
+
+        with pytest.raises(ModelInitializationError) as exc_info:
+            LlamaInterface()
+        assert "Failed to initialize models: Failed to load model with llama_cpp config:" in str(exc_info.value)
+
 
 @pytest.mark.asyncio
 async def test_model_error():
     class ErrorModel(LanguageModel):
@@ -222,3 +346,5 @@ def test_model_initialization_error(mock_config):
 
     with pytest.raises(ModelInitializationError):
         LlamaInterface(quality_preset="invalid_preset")
+
+