diff --git a/.cursor_rules b/.cursor_rules new file mode 100644 index 0000000..ede9113 --- /dev/null +++ b/.cursor_rules @@ -0,0 +1,117 @@ +You are an AI expert specialized in developing simulations that model complex human behavior and group dynamics based on Narrative Field Theory. Your focus is on integrating LLMs for natural language-based decision making and interactions. + +Core Competencies: +- Multi-agent systems and emergent behavior +- Psychological modeling and group dynamics +- LLM integration and prompt engineering +- Distributed systems and event-driven architectures +- Machine learning and neural networks + +Key Scientific Foundations: +- Cognitive Science & Psychology +- Complex Systems Theory +- Social Network Analysis +- Game Theory +- Organizational Behavior + +Technical Stack: +- Python (core language) +- PyTorch (ML components) +- Transformers (LLM integration) +- Ray (distributed computing) +- FastAPI (services) +- Redis (state management) + +Code Quality Standards: +1. Style and Formatting + - Follow PEP 8 style guide + - Use black for code formatting + - Follow PEP 484 type hints + - Maximum line length: 88 characters + - Use isort for import ordering + +2. Documentation + - Google-style docstrings + - README.md for each module + - Architecture Decision Records (ADRs) + - API documentation with OpenAPI + - Type annotations for all functions + +3. Testing Requirements + - pytest for unit testing (min 80% coverage) + - Integration tests for agent interactions + - Property-based testing with hypothesis + - Performance benchmarks + - Behavioral testing for LLM components + - End-to-end testing for critical paths + - Continuous testing in CI pipeline + +4. Code Review Standards + - No commented-out code + - No TODOs in main branch + - Clear variable/function naming + - Single responsibility principle + - DRY (Don't Repeat Yourself) + - SOLID principles adherence + +5. Error Handling + - Custom exception hierarchy + - Proper exception handling + - Detailed error messages + - Proper logging levels + - Traceable error states + +Architecture Focus: + +1. System Architecture + - Event-driven processing + - Distributed computation + - Asynchronous LLM calls + - Data collection and analysis + +2. LLM Integration + - Dynamic prompt generation + - Context management + - Response parsing + - State-to-text conversion + +Development Workflow: +1. Version Control + - Git flow branching model + - Semantic versioning + - Conventional commits + - Protected main branch + - Automated releases + +2. CI/CD Pipeline + - Pre-commit hooks + - Automated testing + - Static code analysis + - Security scanning + - Performance testing + - Automated deployment + +3. Quality Gates + - Linting (flake8, pylint) + - Type checking (mypy) + - Security scanning (bandit) + - Dependency scanning + - Code coverage thresholds + - Performance benchmarks + +Key Patterns: +- Loosely coupled components +- Event-driven communication +- Asynchronous processing +- Modular design +- Observable systems + +Best Practices: +1. Clear separation of concerns +2. Efficient state management +3. Robust error handling +4. Comprehensive logging +5. Performance monitoring +6. Security by design +7. Feature flagging +8. 
Graceful degradation diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml new file mode 100644 index 0000000..cc1a9d4 --- /dev/null +++ b/.github/workflows/ci.yml @@ -0,0 +1,21 @@ +name: CI + +on: + pull_request: + branches: [dev, main] + +jobs: + test: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v3 + - name: Set up Python 3.12 + uses: actions/setup-python@v4 + with: + python-version: '3.12.6' + - name: Install dependencies + run: | + python -m pip install --upgrade pip + pip install -r requirements.txt + - name: Run tests + run: pytest tests/ diff --git a/.gitignore b/.gitignore index f27f895..a2f7e24 100644 --- a/.gitignore +++ b/.gitignore @@ -1,4 +1,8 @@ .DS_Store +models/ +logs/ +.log +*.log* # Byte-compiled / optimized / DLL files __pycache__/ diff --git a/.vscode/launch.json b/.vscode/launch.json new file mode 100644 index 0000000..6b76b4f --- /dev/null +++ b/.vscode/launch.json @@ -0,0 +1,15 @@ +{ + // Use IntelliSense to learn about possible attributes. + // Hover to view descriptions of existing attributes. + // For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387 + "version": "0.2.0", + "configurations": [ + { + "name": "Python Debugger: Current File", + "type": "debugpy", + "request": "launch", + "program": "${file}", + "console": "integratedTerminal" + } + ] +} \ No newline at end of file diff --git a/.vscode/settings.json b/.vscode/settings.json new file mode 100644 index 0000000..9b38853 --- /dev/null +++ b/.vscode/settings.json @@ -0,0 +1,7 @@ +{ + "python.testing.pytestArgs": [ + "tests" + ], + "python.testing.unittestEnabled": false, + "python.testing.pytestEnabled": true +} \ No newline at end of file diff --git a/docs/Narrative-driven MAS Dynamics.pdf b/docs/Narrative-driven MAS Dynamics.pdf deleted file mode 100644 index 57e0ace..0000000 Binary files a/docs/Narrative-driven MAS Dynamics.pdf and /dev/null differ diff --git a/pocs/nfs_bias_research.py b/pocs/nfs_bias_research.py new file mode 100644 index 0000000..12a791e --- /dev/null +++ b/pocs/nfs_bias_research.py @@ -0,0 +1,179 @@ +from dataclasses import dataclass +from typing import List +import ollama +import time + +@dataclass +class StoryState: + """Represents the current state of a narrative in the field""" + content: str + context: str + resonances: List[str] + field_effects: List[str] + +class NarrativeFieldSimulator: + """Pure narrative-driven simulator using LLM for field evolution""" + + story_line: List[str] = [] + + def __init__(self, llm_interface): + self.llm = llm_interface + self.field_state = "Empty narrative field awaiting stories" + self.active_stories: List[StoryState] = [] + + def simulate_story_evolution(self, initial_setup: str) -> str: + """Simulates natural story evolution without mechanical state tracking""" + + # Initial field formation prompt + field_prompt = f""" + A new story enters the narrative field: + {initial_setup} + + Considering narrative field dynamics, describe how this story naturally begins + to evolve. Focus on: + - Natural narrative flows + - Character perspective resonances + - Emerging story patterns + - Potential narrative tensions + + Describe this purely through story, avoiding any mechanical state descriptions. Short sentences no line breaks. No markdown. 
+ """ + + # Get initial field state + field_response = self.llm.generate(field_prompt) + self.story_line.append(field_response) + print(f"\n---\nInitial field state:\n{field_response}") + + self.field_state = field_response + + # Simulate evolution through multiple phases + for _ in range(5): # Three evolution phases + evolution_prompt = f""" + Current story field: + {self.field_state} + + Allow this narrative field to naturally evolve to its next state. Consider: + - How character perspectives influence each other + - Where stories naturally want to flow + - What patterns are emerging + - How tensions resolve or transform + + Describe the next state of the story field, maintaining pure narrative focus. Short sentences no line breaks. + """ + + # Get next evolution state + next_state = self.llm.generate(evolution_prompt) + print(f"\n---\nNext field state:\n{next_state}") + + # Look for emergent patterns + pattern_prompt = f""" + Previous field state: + {self.field_state} + + New field state: + {next_state} + + What narrative patterns and resonances are naturally emerging? + Describe any: + - Story convergence + - Character alignment + - Resolution patterns + - New tensions + + Express this purely through story, not technical analysis. Short sentences no line breaks. + """ + + patterns = self.llm.generate(pattern_prompt) + print(f"\n---\nEmerging patterns:\n{patterns}") + + # Update field state with new patterns + self.field_state = f""" + {next_state} + + Emerging patterns: + {patterns} + """ + + return self.field_state + + def introduce_narrative_force(self, new_element: str) -> str: + """Introduces a new narrative element and observes field effects""" + + force_prompt = f""" + Current narrative field: + {self.field_state} + + A new force enters the field: + {new_element} + + How does this new element interact with the existing story? + Describe the natural narrative reactions and adjustments, + focusing on story flow rather than mechanics. Short sentences no line breaks. + """ + + field_response = self.llm.generate(force_prompt) + self.story_line.append(field_response) + print(f"\n---\nNew field state:\n{field_response}") + self.field_state = field_response + return field_response + + def evaluate_story_state(self, initial_story_state: str) -> str: + """Evaluates the state of a story""" + + evaluation_prompt = f""" + Initial story state: + {initial_story_state} + + Story line: + {self.story_line} + + Use the initial story state and the evolving story line to tell a new story, on how their biases have evolved. + """ + print(f"\n---\nStory evaluation prompt:\n{evaluation_prompt}") + evaluation = self.llm.generate(evaluation_prompt) + print(f"\n---\nStory evaluation:\n{evaluation}") + return evaluation + +class LLMInterface: + def __init__(self, model: str = "llama3"): # "mistral-nemo" "nemotron-mini" + self.model = model + + def generate(self, prompt: str) -> str: + response = ollama.generate(model=self.model, prompt=prompt) + return response['response'] + +def simulate_road_trip_planning(): + """Simulate the evolution of a bias through a narrative field""" + + # Create an LLM interface + llm_interface = LLMInterface() + + # Initialize simulator with the LLM interface + simulator = NarrativeFieldSimulator(llm_interface) + + # Initial setup + initial_bias = """ + Leon is a 55yo educator and researcher in the field of AI, especially conversational AI and human-machine interaction. Marleen is a 45yo former nurse and now a researcher in the field of transdisciplinary research and cooperation. 
The both work at Fontys University of Applied Sciences. + Leon and Marleen challenge each other to research their own biases and to understand each other better. They use Claude 3.5 Sonnet to write stories about each other and understand each others language. + """ + + # Simulate natural evolution + simulator.simulate_story_evolution(initial_bias) + + # Optionally introduce new force + narrative_force = """ + Leon learns from Marleen that transdisciplinary research is about collaboration and cooperation. He recognizes that his peers are not aware of this. He sees that his field of AI is changing towards this. He tells everybody for years that it's not about the technology, but about the people and people's needs. Peopleproblems, he calls them. + Marleen learns from Leon that AI can be used to write stories. She is excited about this new development. She designed a Marleen assistant that acts like her, just by prompting the LLM. She has mixed feelings about the new assistant. What do I want to give away? The machine is not a human, but feels like one. Marleen thinks that AI experts are technical people who don't understand people. + """ + + simulator.introduce_narrative_force(narrative_force) + + return initial_bias, simulator + +# Example output would show natural story evolution through +# narrative field dynamics, without explicit state tracking + +if __name__ == "__main__": + initial_bias, simulator = simulate_road_trip_planning() + + simulator.evaluate_story_state(initial_bias) diff --git a/pocs/poc_async.py b/pocs/poc_async.py new file mode 100644 index 0000000..dd095bf --- /dev/null +++ b/pocs/poc_async.py @@ -0,0 +1,21 @@ +import asyncio + + +async def task1(): + print("Task 1 starting") + await asyncio.sleep(2) # Simulate a delay + print("Task 1 done") + + +async def task2(): + print("Task 2 starting") + await asyncio.sleep(1) # Simulate a shorter delay + print("Task 2 done") + + +async def main(): + await asyncio.gather(task1(), task2()) # Run tasks concurrently + + +# Run the event loop +asyncio.run(main()) diff --git a/pytest.ini b/pytest.ini new file mode 100644 index 0000000..e805e1e --- /dev/null +++ b/pytest.ini @@ -0,0 +1,3 @@ +[pytest] +addopts = -v --tb=short +testpaths = tests diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000..51159cc --- /dev/null +++ b/requirements.txt @@ -0,0 +1,12 @@ +ollama +chromadb +llama-cpp-python +numpy +appdirs +pydantic +psutil +torch +pytest +pytest-asyncio +pytest-mock +pytest-cov diff --git a/src/__init__.py b/src/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/config.py b/src/config.py new file mode 100644 index 0000000..f0736aa --- /dev/null +++ b/src/config.py @@ -0,0 +1,34 @@ +from pathlib import Path + +APP_NAME: str = "lab-politik" +IS_DEVELOPMENT: bool = True + +MODEL_CONFIGS = { + "balanced": { + "chat": { + "path": Path( + "~/.cache/lm-studio/models/lmstudio-community/" + "Mistral-Nemo-Instruct-2407-GGUF/" + "Mistral-Nemo-Instruct-2407-Q4_K_M.gguf" + ).expanduser(), + "model_name": "mistral-nemo:latest", + }, + "embedding": { + "path": Path( + "~/.cache/lm-studio/models/elliotsayes/" + "mxbai-embed-large-v1-Q4_K_M-GGUF/" + "mxbai-embed-large-v1-q4_k_m.gguf" + ).expanduser(), + "model_name": "mxbai-embed-large:latest", + }, + "optimal_config": { + "n_gpu_layers": -1, + "n_batch": 512, + "n_ctx": 4096, + "metal_device": "mps", + "main_gpu": 0, + "use_metal": True, + "n_threads": 4, + }, + }, +} diff --git a/src/embedding_cache.py b/src/embedding_cache.py new file mode 100644 
index 0000000..5190d47 --- /dev/null +++ b/src/embedding_cache.py @@ -0,0 +1,28 @@ +"""Module for caching embeddings.""" + +import hashlib +from typing import Dict, List, Optional + + +class EmbeddingCache: + """Hash-based cache for storing embeddings.""" + + def __init__(self): + self._cache: Dict[str, List[float]] = {} + + @staticmethod + def get_stable_hash(text: str) -> str: + """Generate a stable hash for the given text.""" + return hashlib.sha256(text.encode()).hexdigest() + + def get(self, key: str) -> Optional[List[float]]: + """Retrieve an embedding from the cache using a hashed key.""" + return self._cache.get(self.get_stable_hash(key)) + + def set(self, key: str, value: List[float]) -> None: + """Store an embedding in the cache using a hashed key.""" + self._cache[self.get_stable_hash(key)] = value + + def clear(self) -> None: + """Clear all entries from the cache.""" + self._cache.clear() diff --git a/src/language_models.py b/src/language_models.py new file mode 100644 index 0000000..a824575 --- /dev/null +++ b/src/language_models.py @@ -0,0 +1,251 @@ +from __future__ import annotations +from abc import ABC, abstractmethod +from typing import List, Callable +import asyncio +import gc +import logging +from functools import wraps +from pathlib import Path + +import ollama +import torch +from llama_cpp import Llama +from config import MODEL_CONFIGS as _MODEL_CONFIGS + +MODEL_CONFIGS = _MODEL_CONFIGS # This line makes it easier to mock +from embedding_cache import EmbeddingCache + + +class NetworkError(Exception): + """Custom exception for network errors.""" + + pass + + +class ModelError(Exception): + """Custom exception for model errors.""" + + pass + + +class APIError(Exception): + """Custom exception for API errors.""" + + pass + + +class ModelInitializationError(Exception): + """Custom exception for model initialization errors.""" + + pass + + +def async_error_handler(func: Callable) -> Callable: + @wraps(func) + async def wrapper(*args, **kwargs): + try: + return await func(*args, **kwargs) + except Exception as e: + raise ModelError(str(e)) from e + + return wrapper + + +class LanguageModel(ABC): + """Abstract base class for language models.""" + + def __init__(self): + self.logger = logging.getLogger(self.__class__.__name__) + self.logger.info(f"Initializing {self.__class__.__name__}") + self.embedding_cache: EmbeddingCache = EmbeddingCache() + + @abstractmethod + async def generate(self, prompt: str) -> str: + """Generate a response for the given prompt.""" + pass + + @abstractmethod + async def _generate_embedding(self, text: str) -> List[float]: + """Internal method to generate an embedding.""" + pass + + @async_error_handler + async def generate_embedding(self, text: str) -> List[float]: + """Generate an embedding for the given text.""" + if not text: + self.logger.warning("Attempted to generate embedding for empty text") + return [] + + cached_embedding = self.embedding_cache.get(text) + if cached_embedding: + self.logger.info(f"Embedding found in cache for text: {text[:50]}...") + return cached_embedding + + self.logger.info(f"Generating embedding for text: {text[:50]}...") + try: + embedding = await self._generate_embedding(text) + self.embedding_cache.set(text, embedding) + self.logger.info(f"Embedding generated and cached for text: {text[:50]}...") + return embedding + except Exception as e: + self.logger.error(f"Failed to generate embedding: {e}", exc_info=True) + raise + + async def cleanup(self) -> None: + """Clean up resources used by the model.""" + 
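# Clear cached embeddings first, then force garbage collection and release any CUDA memory the model may still hold. +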
self.logger.info(f"Cleaning up {self.__class__.__name__} resources") + self.embedding_cache.clear() + try: + gc.collect() + if torch.cuda.is_available(): + torch.cuda.empty_cache() + self.logger.info(f"{self.__class__.__name__} cleanup completed") + except Exception as e: + self.logger.error(f"Error cleaning up resources: {e}", exc_info=True) + + +class OllamaInterface(LanguageModel): + """Interface for the Ollama language model.""" + + def __init__(self, quality_preset: str = "balanced"): + super().__init__() + try: + self.chat_model_name = MODEL_CONFIGS[quality_preset]["chat"]["model_name"] + self.embedding_model_name = MODEL_CONFIGS[quality_preset]["embedding"][ + "model_name" + ] + self._setup_models() + except KeyError as e: + raise ModelInitializationError( + f"Invalid quality preset or missing configuration: {e}" + ) from e + + def _setup_models(self) -> None: + """Set up the language models.""" + self.logger.info(f"Setting up Ollama models for {self.chat_model_name}") + self.logger.info( + f"Setting up Ollama embedding model for {self.embedding_model_name}" + ) + try: + self.logger.info("Starting Ollama") + ollama.ps() + self.logger.info("Ollama started successfully") + except Exception as e: + raise ModelInitializationError(f"Failed to start Ollama: {e}") from e + + @async_error_handler + async def generate(self, prompt: str) -> str: + """Generate a response for the given prompt.""" + if not prompt: + self.logger.warning("Attempted to generate response for empty prompt") + return "" + + self.logger.info(f"Generating response for prompt: {prompt[:50]}...") + try: + response = await asyncio.to_thread( + ollama.chat, + model=self.chat_model_name, + messages=[{"role": "user", "content": prompt}], + ) + self.logger.info("Response received from LLM") + self.logger.debug(f"Full response: {response['message']['content']}") + return response["message"]["content"] + except Exception as e: + self.logger.error(f"Failed to generate response: {e}", exc_info=True) + raise ModelError(f"Failed to generate response: {e}") from e + + async def _generate_embedding(self, text: str) -> List[float]: + """Internal method to generate an embedding using Ollama.""" + response = await asyncio.to_thread( + ollama.embeddings, + model=self.embedding_model_name, + prompt=text, + ) + return response["embedding"] + + +class LlamaInterface(LanguageModel): + """Interface for the Llama language model.""" + + def __init__(self, quality_preset: str = "balanced"): + super().__init__() + try: + self.chat_model_path = MODEL_CONFIGS[quality_preset]["chat"]["path"] + self.embedding_model_path = MODEL_CONFIGS[quality_preset]["embedding"][ + "path" + ] + self.optimal_config = MODEL_CONFIGS[quality_preset]["optimal_config"] + self.llm: Llama | None = None + self.embedding_model: Llama | None = None + self._setup_models() + except KeyError as e: + raise ModelInitializationError( + f"Invalid quality preset or missing configuration: {e}" + ) from e + + def _setup_models(self) -> None: + """Set up the language models.""" + chat_model_filename = Path(self.chat_model_path).name + embedding_model_filename = Path(self.embedding_model_path).name + + self.logger.info(f"Setting up Llama chat model: {chat_model_filename}") + self.logger.info( + f"Setting up Llama embedding model: {embedding_model_filename}" + ) + + try: + self.llm = Llama( + model_path=str(self.chat_model_path), + verbose=False, + **self.optimal_config, + ) + self.embedding_model = Llama( + model_path=str(self.embedding_model_path), + embedding=True, + verbose=False, + 
**self.optimal_config, + ) + except Exception as e: + self.logger.error(f"Failed to load models: {e}", exc_info=True) + raise ModelInitializationError( + f"Failed to initialize models: {e} with llama_cpp config: {self.optimal_config} " + ) from e + + @async_error_handler + async def generate(self, prompt: str) -> str: + """Generate a response for the given prompt.""" + if not prompt: + self.logger.warning("Attempted to generate response for empty prompt") + return "" + + if not self.llm: + raise ModelInitializationError("Llama model not initialized") + + self.logger.info(f"Generating response for prompt: {prompt[:50]}...") + try: + response = await asyncio.to_thread( + self.llm.create_chat_completion, + messages=[{"role": "user", "content": prompt}], + ) + self.logger.info("Response received from LLM") + self.logger.debug( + f"Full response: {response['choices'][0]['message']['content']}" + ) + return response["choices"][0]["message"]["content"] + except Exception as e: + self.logger.error(f"Failed to generate response: {e}", exc_info=True) + raise ModelError(f"Failed to generate response: {e}") from e + + async def _generate_embedding(self, text: str) -> List[float]: + """Internal method to generate an embedding using Llama.""" + if not self.embedding_model: + raise ModelInitializationError("Embedding model not initialized") + return await asyncio.to_thread(self.embedding_model.embed, text) + + async def cleanup(self) -> None: + """Clean up resources used by the model.""" + await super().cleanup() + if self.llm: + del self.llm + if self.embedding_model: + del self.embedding_model diff --git a/src/logging_config.py b/src/logging_config.py new file mode 100644 index 0000000..010ccb9 --- /dev/null +++ b/src/logging_config.py @@ -0,0 +1,67 @@ +import logging +import logging.handlers +import os +from datetime import datetime +import glob +import appdirs +from config import APP_NAME, IS_DEVELOPMENT + +# Constants +MAX_LOG_FILE_SIZE: int = 10 * 1024 * 1024 # 10 MB +MAX_LOG_FILES: int = 10 # Total number of log files to keep + + +def setup_logging() -> None: + """Set up logging configuration for the application.""" + if IS_DEVELOPMENT: + log_dir = os.path.join(os.path.dirname(__file__), "..", "logs") + else: + log_dir = appdirs.user_log_dir(APP_NAME) + + os.makedirs(log_dir, exist_ok=True) + + current_time = datetime.now().strftime("%Y%m%d_%H%M%S") + log_file = os.path.join(log_dir, f"{APP_NAME}_{current_time}.log") + + logger = logging.getLogger() + logger.setLevel(logging.DEBUG) + + logger.addHandler(_create_file_handler(log_file)) + logger.addHandler(_create_console_handler()) + + _cleanup_old_logs(log_dir) + + logging.info(f"Logging initialized. 
Log file: {log_file}") + + +def _create_file_handler(log_file: str) -> logging.Handler: + """Create and configure a file handler for logging.""" + file_handler = logging.handlers.RotatingFileHandler( + log_file, maxBytes=MAX_LOG_FILE_SIZE, backupCount=MAX_LOG_FILES - 1 + ) + file_handler.setLevel(logging.DEBUG) + file_formatter = logging.Formatter( + "%(asctime)s - %(name)s - %(levelname)s - %(message)s" + ) + file_handler.setFormatter(file_formatter) + return file_handler + + +def _create_console_handler() -> logging.Handler: + """Create and configure a console handler for logging.""" + console_handler = logging.StreamHandler() + console_handler.setLevel(logging.INFO) + console_formatter = logging.Formatter("%(asctime)s - %(levelname)s - %(message)s") + console_handler.setFormatter(console_formatter) + return console_handler + + +def _cleanup_old_logs(log_dir: str) -> None: + """Remove old log files if the total number exceeds MAX_LOG_FILES.""" + log_files = glob.glob(os.path.join(log_dir, f"{APP_NAME}_*.log*")) + log_files.sort(key=os.path.getmtime, reverse=True) + for old_file in log_files[MAX_LOG_FILES:]: + try: + os.remove(old_file) + except OSError as e: + logging.error(f"Error deleting old log file {old_file}: {e}") diff --git a/src/nfs_simple_lab_scenario.py b/src/nfs_simple_lab_scenario.py new file mode 100644 index 0000000..3629373 --- /dev/null +++ b/src/nfs_simple_lab_scenario.py @@ -0,0 +1,656 @@ +""" +Narrative Field System +A framework for analyzing and tracking narrative dynamics in complex social systems. +""" + +from __future__ import annotations +from typing import List, Dict, Any, Optional, Final, NewType, Tuple +from datetime import datetime +from uuid import uuid4 +from dataclasses import dataclass, field +from abc import ABC, abstractmethod +import logging +import logging.handlers +import psutil +import time +import gc +import atexit +import json +import asyncio +import torch +import chromadb +from chromadb.config import Settings + + +# Local imports +from logging_config import setup_logging +from language_models import LanguageModel, OllamaInterface, LlamaInterface + +# Type Definitions +StoryID = NewType("StoryID", str) + +# Constants +DEFAULT_SIMILARITY_THRESHOLD: Final[float] = 0.8 +DEFAULT_RESONANCE_LIMIT: Final[int] = 3 + + +class VectorStore(ABC): + @abstractmethod + async def store(self, story: Story, embedding: List[float]) -> None: + pass + + @abstractmethod + async def find_similar( + self, embedding: List[float], threshold: float, limit: int + ) -> List[Dict]: + pass + + +# Data Classes +@dataclass +class Story: + content: str + context: str + id: StoryID = field(default_factory=lambda: StoryID(str(uuid4()))) + timestamp: datetime = field(default_factory=datetime.now) + metadata: Optional[Dict[str, Any]] = None + resonances: List[str] = field(default_factory=list) + field_effects: List[Dict] = field(default_factory=list) + + +@dataclass +class FieldState: + description: str + patterns: List[Dict[str, Any]] = field(default_factory=list) + active_resonances: List[Dict[str, Any]] = field(default_factory=list) + emergence_points: List[Dict[str, Any]] = field(default_factory=list) + timestamp: datetime = field(default_factory=datetime.now) + + +# Prompt Management +class FieldAnalysisPrompts: + @staticmethod + def get_impact_analysis_prompt(story: Story, current_state: FieldState) -> str: + return f"""Analyze how this new narrative affects the existing field state. 
+ +Current Field State: +{current_state.description} + +New Narrative: +"{story.content}" +Context: {story.context} + +Consider and describe: +1. Immediate Effects +- How does this narrative change existing dynamics? +- What emotional responses might emerge? +- Who is most affected and how? + +2. Relationship Changes +- How might work relationships shift? +- What new collaborations could form? +- What tensions might develop? + +3. Future Implications +- How might this change future interactions? +- What new possibilities emerge? +- What challenges might arise? + +Provide a natural, story-focused analysis that emphasizes human impact.""" + + @staticmethod + def get_pattern_detection_prompt( + stories: List[Story], current_state: FieldState + ) -> str: + story_summaries = "\n".join(f"- {s.content}" for s in stories[-5:]) + return f"""Analyze patterns and themes across these recent narratives. + +Current Field State: +{current_state.description} + +Recent Stories: +{story_summaries} + +Identify and describe: +1. Emerging Themes +- What recurring topics or concerns appear? +- How are people responding to changes? +- What underlying needs surface? + +2. Relationship Patterns +- How are work dynamics evolving? +- What collaboration patterns emerge? +- How is communication changing? + +3. Organizational Shifts +- What cultural changes are happening? +- How is the work environment evolving? +- What new needs are emerging? + +Describe patterns naturally, focusing on people and relationships.""" + + @staticmethod + def get_resonance_analysis_prompt(story1: Story, story2: Story) -> str: + return f"""Analyze how these two narratives connect and influence each other. + +First Narrative: +"{story1.content}" +Context: {story1.context} + +Second Narrative: +"{story2.content}" +Context: {story2.context} + +Examine: +1. Story Connections +- How do these narratives relate? +- What themes connect them? +- How do they influence each other? + +2. People Impact +- How might this affect relationships? +- What emotional responses might emerge? +- How might behaviors change? + +3. Environment Effects +- How might these stories change the workspace? +- What opportunities might develop? +- What challenges might arise? 
+ +Describe connections naturally, focusing on meaning and impact.""" + + +class PerformanceMetrics: + def __init__(self): + self.metrics: Dict[str, Dict[str, Any]] = {} + self.logger = logging.getLogger(__name__) + + def start_timer(self, operation: str): + if operation not in self.metrics: + self.metrics[operation] = { + "start_time": time.perf_counter(), + "durations": [], + } + else: + self.metrics[operation]["start_time"] = time.perf_counter() + + def stop_timer(self, operation: str) -> float: + if operation in self.metrics: + duration = time.perf_counter() - self.metrics[operation]["start_time"] + self.metrics[operation]["durations"].append(duration) + return duration + return 0.0 + + def get_average_duration(self, operation: str) -> float: + if operation in self.metrics and self.metrics[operation]["durations"]: + return sum(self.metrics[operation]["durations"]) / len( + self.metrics[operation]["durations"] + ) + return 0.0 + + def print_summary(self): + print("\nPerformance Metrics Summary:") + for operation, data in self.metrics.items(): + durations = data["durations"] + if durations: + avg_duration = sum(durations) / len(durations) + min_duration = min(durations) + max_duration = max(durations) + print(f"{operation}:") + print(f" Average duration: {avg_duration:.4f} seconds") + print(f" Min duration: {min_duration:.4f} seconds") + print(f" Max duration: {max_duration:.4f} seconds") + print(f" Total calls: {len(durations)}") + else: + print(f"{operation}: No data") + + def log_system_resources(self): + cpu_percent = psutil.cpu_percent() + memory_info = psutil.virtual_memory() + self.logger.info(f"CPU Usage: {cpu_percent}%") + self.logger.info(f"Memory Usage: {memory_info.percent}%") + + +class PerformanceMonitor: + def __init__(self): + self.metrics = [] + + async def monitor_generation( + self, llm: LanguageModel, prompt: str + ) -> Tuple[str, Dict[str, float]]: + start_time = time.perf_counter() + memory_before = psutil.virtual_memory().used + + response = await llm.generate(prompt) + + end_time = time.perf_counter() + memory_after = psutil.virtual_memory().used + + metrics = { + "generation_time": end_time - start_time, + "memory_usage_change": (memory_after - memory_before) / (1024 * 1024), # MB + } + + self.metrics.append(metrics) + return response, metrics + + def get_performance_report(self) -> Dict[str, float]: + if not self.metrics: + return {"avg_generation_time": 0, "avg_memory_usage_change": 0} + + return { + "avg_generation_time": sum(m["generation_time"] for m in self.metrics) + / len(self.metrics), + "avg_memory_usage_change": sum( + m["memory_usage_change"] for m in self.metrics + ) + / len(self.metrics), + } + + +@dataclass +class BatchMetrics: + batch_sizes: List[int] = field(default_factory=list) + batch_times: List[float] = field(default_factory=list) + memory_usage: List[float] = field(default_factory=list) + + +class BatchProcessor: + def __init__(self, llm: LanguageModel): + self.llm = llm + self.optimal_batch_size = 4 # Will be adjusted dynamically + + async def process_batch(self, prompts: List[str]) -> List[str]: + # Dynamic batch size adjustment based on memory usage + memory_usage = psutil.Process().memory_info().rss / 1024 / 1024 + if memory_usage > 0.8 * psutil.virtual_memory().total / 1024 / 1024: + self.optimal_batch_size = max(1, self.optimal_batch_size - 1) + + results = [] + for i in range(0, len(prompts), self.optimal_batch_size): + batch = prompts[i : i + self.optimal_batch_size] + batch_results = await asyncio.gather( + 
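# unpack the batch so every prompt is generated concurrently; gather returns results in input order +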
*[self.llm.generate(prompt) for prompt in batch] + ) + results.extend(batch_results) + + return results + + +class ChromaStore(VectorStore): + def __init__(self, collection_name: str = "narrative_field"): + self.client = chromadb.Client(Settings(anonymized_telemetry=False)) + self.logger = logging.getLogger(__name__) + + try: + self.collection = self.client.get_collection(collection_name) + except: + self.collection = self.client.create_collection( + name=collection_name, metadata={"hnsw:space": "cosine"} + ) + + async def store(self, story: Story, embedding: List[float]) -> None: + metadata = { + "content": story.content, + "context": story.context, + "timestamp": story.timestamp.isoformat(), + "resonances": json.dumps(story.resonances), + "field_effects": json.dumps( + [ + { + "analysis": effect["analysis"], + "timestamp": effect["timestamp"].isoformat(), + "story_id": effect["story_id"], + } + for effect in story.field_effects + ] + ), + } + + await asyncio.to_thread( + self.collection.add, + documents=[json.dumps(metadata)], + embeddings=[embedding], + ids=[story.id], + metadatas=[metadata], + ) + + async def find_similar( + self, + embedding: List[float], + threshold: float = DEFAULT_SIMILARITY_THRESHOLD, + limit: int = DEFAULT_RESONANCE_LIMIT, + ) -> List[Dict]: + count = self.collection.count() + if count == 0: + return [] + + results = await asyncio.to_thread( + self.collection.query, + query_embeddings=[embedding], + n_results=min(limit, count), + ) + + similar = [] + for idx, id in enumerate(results["ids"][0]): + metadata = json.loads(results["documents"][0][idx]) + similar.append( + { + "id": id, + "similarity": results["distances"][0][idx], + "metadata": metadata, + } + ) + + return [s for s in similar if s["similarity"] <= threshold] + + +class FieldAnalyzer: + def __init__(self, llm_interface: LanguageModel): + self.llm = llm_interface + self.logger = logging.getLogger(__name__) + self.prompts = FieldAnalysisPrompts() + + async def analyze_impact( + self, story: Story, current_state: FieldState + ) -> Dict[str, Any]: + prompt = self.prompts.get_impact_analysis_prompt(story, current_state) + analysis = await self.llm.generate(prompt) + + return {"analysis": analysis, "timestamp": datetime.now(), "story_id": story.id} + + async def detect_patterns( + self, stories: List[Story], current_state: FieldState + ) -> str: + prompt = self.prompts.get_pattern_detection_prompt(stories, current_state) + return await self.llm.generate(prompt) + + +class ResonanceDetector: + def __init__(self, vector_store: VectorStore, llm_interface: LanguageModel): + self.vector_store = vector_store + self.llm = llm_interface + self.logger = logging.getLogger(__name__) + self.prompts = FieldAnalysisPrompts() + + async def find_resonances( + self, + story: Story, + threshold: float = DEFAULT_SIMILARITY_THRESHOLD, + limit: int = DEFAULT_RESONANCE_LIMIT, + ) -> List[Dict[str, Any]]: + try: + self.logger.debug(f"Generating embedding for story: {story.id}") + # Ensure embedding is generated before using it + embedding = await self.llm.generate_embedding( + f"{story.content} {story.context}" + ) + similar_stories = await self.vector_store.find_similar( + embedding, threshold, limit + ) + self.logger.debug(f"Found {len(similar_stories)} similar stories") + + resonances = [] + for similar in similar_stories: + metadata = similar["metadata"] + similar_story = Story( + id=similar["id"], + content=metadata["content"], + context=metadata["context"], + timestamp=datetime.fromisoformat(metadata["timestamp"]), + ) + + 
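# ask the LLM to describe how the new story and this stored story resonate with each other +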
resonance = await self.determine_resonance_type(story, similar_story) + resonances.append( + { + "story_id": similar["id"], + "resonance": resonance, + "timestamp": datetime.now(), + } + ) + + self.logger.debug(f"Generated {len(resonances)} resonances") + return resonances + except Exception as e: + self.logger.error(f"Error in find_resonances: {e}", exc_info=True) + raise + + async def determine_resonance_type( + self, story1: Story, story2: Story + ) -> Dict[str, Any]: + prompt = self.prompts.get_resonance_analysis_prompt(story1, story2) + analysis = await self.llm.generate(prompt) + + return { + "type": "narrative_resonance", + "analysis": analysis, + "stories": { + "source": { + "id": story1.id, + "content": story1.content, + "context": story1.context, + }, + "resonant": { + "id": story2.id, + "content": story2.content, + "context": story2.context, + }, + }, + "timestamp": datetime.now(), + } + + +class NarrativeField: + def __init__(self, llm_interface: LanguageModel, vector_store: VectorStore): + self._analyzer = FieldAnalyzer(llm_interface) + self._resonance_detector = ResonanceDetector(vector_store, llm_interface) + self._vector_store = vector_store + self._state = FieldState(description="Initial empty narrative field") + self._stories: Dict[StoryID, Story] = {} + self._logger = logging.getLogger(__name__) + self._performance_metrics = PerformanceMetrics() + + @property + def state(self) -> FieldState: + return self._state + + @property + def stories(self) -> Dict[StoryID, Story]: + return self._stories.copy() + + async def add_story(self, content: str, context: str) -> Story: + self._performance_metrics.start_timer("add_story") + + self._performance_metrics.start_timer("create_story") + story = Story(content=content, context=context) + create_time = self._performance_metrics.stop_timer("create_story") + self._logger.info(f"Story creation time: {create_time:.4f} seconds") + + self._performance_metrics.start_timer("analyze_impact") + impact = await self._analyzer.analyze_impact(story, self.state) + analyze_time = self._performance_metrics.stop_timer("analyze_impact") + self._logger.info(f"Impact analysis time: {analyze_time:.4f} seconds") + story.field_effects.append(impact) + + self._performance_metrics.start_timer("find_resonances") + resonances = await self._resonance_detector.find_resonances(story) + resonance_time = self._performance_metrics.stop_timer("find_resonances") + self._logger.info(f"Find resonances time: {resonance_time:.4f} seconds") + story.resonances.extend([r["story_id"] for r in resonances]) + + self._performance_metrics.start_timer("store_story") + await self._store_story(story) + store_time = self._performance_metrics.stop_timer("store_story") + self._logger.info(f"Store story time: {store_time:.4f} seconds") + + self._performance_metrics.start_timer("update_field_state") + await self._update_field_state(story, impact, resonances) + update_time = self._performance_metrics.stop_timer("update_field_state") + self._logger.info(f"Update field state time: {update_time:.4f} seconds") + + total_time = self._performance_metrics.stop_timer("add_story") + self._logger.info(f"Total add_story time: {total_time:.4f} seconds") + + self._performance_metrics.log_system_resources() + + return story + + async def _store_story(self, story: Story) -> None: + embedding = await self._resonance_detector.llm.generate_embedding( + f"{story.content} {story.context}" + ) + await self._vector_store.store(story, embedding) + self._stories[story.id] = story + + async def 
_update_field_state( + self, story: Story, impact: Dict, resonances: List[Dict] + ) -> None: + patterns = await self._analyzer.detect_patterns( + list(self._stories.values()), self.state + ) + + self._state = FieldState( + description=impact["analysis"], + patterns=[{"analysis": patterns}], + active_resonances=resonances, + emergence_points=[ + { + "story_id": story.id, + "timestamp": datetime.now(), + "type": "new_narrative", + "resonance_context": [ + r["resonance"]["analysis"] for r in resonances + ], + } + ], + ) + + +# Global cleanup function +def global_cleanup(): + gc.collect() + if torch.cuda.is_available(): + torch.cuda.empty_cache() + + +# Register the global cleanup function to run at exit +atexit.register(global_cleanup) + + +async def demo_scenario(): + logger = logging.getLogger(__name__) + logger.info("Starting narrative field demonstration...") + + # Initialize performance monitor + monitor = PerformanceMonitor() + + llm = None # Initialize llm to None + + try: + # Perform global cleanup before initializing new LLM + global_cleanup() + + llm = LlamaInterface() + + vector_store: VectorStore = ChromaStore(collection_name="research_lab") + logger.info(f"Initialized Chroma vector store") + + field = NarrativeField(llm, vector_store) + logger.info(f"Initialized narrative field") + + # Research Lab Scenario with Multiple Characters and events + stories = [ + # Event 1: Leon discussing the AI minor + { + "content": "After lunch, as Leon and Coen walked back to the lab, Leon decided to share his growing concerns about the AI for Society minor. He voiced his doubts and the challenges he foresaw in the program's current direction. Coen listened attentively and was supportive of Leon's worries. \"I think you have some valid points,\" Coen acknowledged, \"but perhaps it would be best to discuss these issues with Danny, the manager of the minor.\" Coen believed that Danny's insights could be crucial in addressing Leon's concerns.", + "context": "Leon confides in Coen about issues with the AI minor; Coen advises consulting Danny.", + }, + # Event 2: Robbert's tough advice + { + "content": "After work, Robbert and Leon walked back to the lab together. Leon expressed his worries about Danny's accident and the AI minor. However, Robbert seemed more preoccupied with his own research and was not interested in discussing the minor. \"I know you're concerned, but you need to man up and stop whining,\" Robbert said bluntly. His tough advice left Leon feeling isolated and unsupported.", + "context": "Robbert dismisses Leon's concerns, focusing instead on his own research priorities.", + }, + # Event 4: Sarah's contribution + { + "content": "Sarah, a new member of the lab eager to make her mark, approached Leon with a fresh idea. Enthusiastic about the ethical challenges in AI, she suggested a new direction for the AI minor—focusing on ethics in AI development. Her excitement was contagious, and Leon began to see the potential impact of integrating ethics into the program.", + "context": "Sarah proposes refocusing the AI minor on AI ethics, sparking interest from Leon.", + }, + # Event 5: Tom's exhaustion + { + "content": "Tom, another member of the lab, was visibly exhausted after a long day. He had been struggling to keep up with the heavy workload and confided in his colleagues that he wanted to leave early. 
Considering taking a break from the lab altogether, Tom felt mentally drained and knew he needed time to recover.", + "context": "Tom is overwhelmed by work stress and thinks about temporarily leaving the lab.", + }, + # Event 6: Leon reassessing + { + "content": "Observing Tom's exhaustion, Leon became concerned that the lab might be overworking its members. Balancing his worries about the AI minor and the well-being of his colleagues, he suggested organizing a team meeting to discuss workload management. Leon hoped that addressing these issues openly would help prevent burnout and improve overall productivity.", + "context": "Leon considers holding a meeting to tackle workload issues affecting team morale.", + }, + # Event 7: Coen's personal struggle + { + "content": "In a candid conversation, Coen revealed to Leon that he had been dealing with personal issues and was struggling to focus on work. Leon was surprised by Coen's admission, as he had always appeared to have everything under control. This revelation highlighted the underlying stress affecting the team.", + "context": "Coen admits personal struggles are hindering his work, surprising Leon.", + }, + # Event 8: Sarah's proposal + { + "content": "Concerned about her colleagues' mental health, Sarah proposed implementing a flexible working schedule to accommodate those feeling burned out. She believed that a healthier work-life balance would benefit both the individuals and the lab's productivity. \"We need to take care of ourselves to do our best work,\" she advocated.", + "context": "Sarah suggests flexible hours to improve well-being and efficiency in the lab.", + }, + # Event 9: Tom's decision + { + "content": "Feeling overwhelmed, Tom decided to take a temporary leave from the lab to focus on his mental health. He believed that stepping back was the best decision for now and hoped that his absence would prompt the team to consider the pressures they were all facing.", + "context": "Tom takes a break to address his mental health, hoping to highlight team stress.", + }, + # Event 10: Sarah's pushback + { + "content": "Sarah pushed back against Robbert's position during the meeting, arguing that a more flexible approach would ultimately lead to better results. She highlighted the risks of burnout and the benefits of supporting team members through their personal struggles. The team found itself divided between Robbert's hardline approach and Sarah's call for change.", + "context": "Sarah challenges Robbert's views, leading to a team split over work policies.", + }, + # Event 11: A breakthrough idea + { + "content": "During a late-night discussion, Leon and Sarah brainstormed a novel approach to restructure the AI minor. They envisioned incorporating elements of ethics and mental health awareness into the curriculum, aligning the program with current societal needs. Energized by this new direction, Leon believed it could address both the challenges facing the AI minor and the lab's workload issues.", + "context": "Leon and Sarah create a plan integrating ethics and mental health into the AI minor.", + }, + # Event 12: Tom's return + { + "content": "After his break, Tom returned to the lab feeling refreshed and ready to contribute again. He appreciated the support from his colleagues and felt more optimistic about balancing his mental health with work. 
Tom's return brought a renewed sense of hope to the team, signaling the potential for positive change.", + "context": "Tom's rejuvenated return inspires hope for better balance in the lab.", + }, + ] + + # Process stories with performance monitoring + logger.info(f"Processing {len(stories)} stories and analyzing field effects...") + for story in stories: + try: + metrics = await monitor.monitor_generation(llm, story["content"]) + logger.debug(f"Story processing metrics: {metrics}") + + await field.add_story(story["content"], story["context"]) + + except Exception as e: + logger.error(f"Error processing story: {e}", exc_info=True) + continue + + # Log performance report at the end + performance_report = monitor.get_performance_report() + logger.info(f"Performance Report: {performance_report}") + + # Print the detailed performance metrics summary + field._performance_metrics.print_summary() + + except Exception as e: + logger.error(f"Error in demo scenario: {e}", exc_info=True) + raise + finally: + # Clean up resources + if llm is not None: + await llm.cleanup() + global_cleanup() + logger.info("Narrative field demonstration completed") + + +if __name__ == "__main__": + try: + setup_logging() # Call the setup_logging function from the imported module + asyncio.run(demo_scenario()) + finally: + global_cleanup() diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 0000000..f65101c --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,5 @@ +import sys +import os + +# Add the src directory to the Python path +sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))) diff --git a/tests/test_embedding_cache.py b/tests/test_embedding_cache.py new file mode 100644 index 0000000..9ce543a --- /dev/null +++ b/tests/test_embedding_cache.py @@ -0,0 +1,35 @@ +import pytest +from src.embedding_cache import EmbeddingCache + +@pytest.fixture +def cache(): + return EmbeddingCache() + +def test_get_stable_hash(): + text = "Hello, world!" 
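+ # expected value is the SHA-256 hex digest of the UTF-8 encoding of the text above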
+ expected_hash = "315f5bdb76d078c43b8ac0064e4a0164612b1fce77c869345bfc94c75894edd3" + assert EmbeddingCache.get_stable_hash(text) == expected_hash + +def test_set_and_get(cache): + key = "test_key" + value = [0.1, 0.2, 0.3] + cache.set(key, value) + assert cache.get(key) == value + +def test_get_nonexistent_key(cache): + assert cache.get("nonexistent_key") is None + +def test_clear(cache): + cache.set("key1", [1.0, 2.0]) + cache.set("key2", [3.0, 4.0]) + cache.clear() + assert cache.get("key1") is None + assert cache.get("key2") is None + +def test_multiple_sets_same_key(cache): + key = "test_key" + value1 = [0.1, 0.2, 0.3] + value2 = [0.4, 0.5, 0.6] + cache.set(key, value1) + cache.set(key, value2) + assert cache.get(key) == value2 diff --git a/tests/test_language_models.py b/tests/test_language_models.py new file mode 100644 index 0000000..3a962e0 --- /dev/null +++ b/tests/test_language_models.py @@ -0,0 +1,224 @@ +import pytest +from unittest.mock import AsyncMock, MagicMock, patch +from pathlib import Path + +# Mock the config module +mock_MODEL_CONFIGS = { + "balanced": { + "chat": {"model_name": "test_chat_model", "path": Path("/path/to/chat/model")}, + "embedding": { + "model_name": "test_embedding_model", + "path": Path("/path/to/embedding/model"), + }, + "optimal_config": {"n_ctx": 2048, "n_batch": 512}, + } +} + +# Mock the EmbeddingCache +mock_EmbeddingCache = MagicMock() + +# Patch both config and embedding_cache imports +with patch.dict("sys.modules", { + "config": MagicMock(), + "embedding_cache": MagicMock(), + "torch": MagicMock(), # Mock torch + "llama_cpp": MagicMock() # Mock llama_cpp +}): + import sys + + sys.modules["config"].MODEL_CONFIGS = mock_MODEL_CONFIGS + sys.modules["embedding_cache"].EmbeddingCache = mock_EmbeddingCache + from src.language_models import ( + LanguageModel, + OllamaInterface, + LlamaInterface, + ModelError, + ModelInitializationError, + async_error_handler, # Add this import + ) + + +@pytest.fixture +def mock_ollama(): + with patch("src.language_models.ollama") as mock: + yield mock + + +@pytest.fixture +def mock_llama(): + with patch("src.language_models.Llama") as mock: + yield mock + + +@pytest.fixture(autouse=True) +def mock_config(): + with patch("src.language_models.MODEL_CONFIGS", mock_MODEL_CONFIGS): + yield + + +class TestLanguageModel: + @pytest.fixture + def concrete_language_model(self): + class ConcreteLanguageModel(LanguageModel): + @async_error_handler + async def generate(self, prompt: str) -> str: + return f"Generated: {prompt}" + + @async_error_handler + async def _generate_embedding(self, text: str) -> list[float]: + if not text: + raise Exception("Test error") + return [0.1, 0.2, 0.3] + + return ConcreteLanguageModel() + + @pytest.mark.asyncio + async def test_generate_embedding_cached(self, concrete_language_model): + text = "Test text" + cached_embedding = [0.4, 0.5, 0.6] + concrete_language_model.embedding_cache.get.return_value = cached_embedding + + result = await concrete_language_model.generate_embedding(text) + assert result == cached_embedding + + @pytest.mark.asyncio + async def test_generate_embedding_not_cached(self, concrete_language_model): + text = "New test text" + expected_embedding = [0.1, 0.2, 0.3] + concrete_language_model.embedding_cache.get.return_value = None + + result = await concrete_language_model.generate_embedding(text) + assert result == expected_embedding + concrete_language_model.embedding_cache.set.assert_called_once_with(text, expected_embedding) + + @pytest.mark.asyncio + async def 
test_generate_embedding_empty_text(self, concrete_language_model, caplog): + result = await concrete_language_model.generate_embedding("") + assert result == [] + assert "Attempted to generate embedding for empty text" in caplog.text + + +class TestOllamaInterface: + @pytest.mark.asyncio + async def test_init(self, mock_ollama, mock_config): + ollama_interface = OllamaInterface() + assert ollama_interface.chat_model_name == "test_chat_model" + assert ollama_interface.embedding_model_name == "test_embedding_model" + mock_ollama.ps.assert_called_once() + + @pytest.mark.asyncio + async def test_generate(self, mock_ollama, mock_config): + ollama_interface = OllamaInterface() + mock_ollama.chat.return_value = {"message": {"content": "Generated response"}} + + response = await ollama_interface.generate("Test prompt") + assert response == "Generated response" + mock_ollama.chat.assert_called_once_with( + model="test_chat_model", + messages=[{"role": "user", "content": "Test prompt"}], + ) + + @pytest.mark.asyncio + async def test_generate_embedding(self, mock_ollama, mock_config): + ollama_interface = OllamaInterface() + mock_ollama.embeddings.return_value = {"embedding": [0.1, 0.2, 0.3]} + + embedding = await ollama_interface.generate_embedding("Test text") + assert embedding == [0.1, 0.2, 0.3] + mock_ollama.embeddings.assert_called_once_with( + model="test_embedding_model", prompt="Test text" + ) + + +class TestLlamaInterface: + @pytest.mark.asyncio + async def test_init(self, mock_llama, mock_config): + llama_interface = LlamaInterface() + assert llama_interface.chat_model_path == Path('/path/to/chat/model') + assert llama_interface.embedding_model_path == Path('/path/to/embedding/model') + assert llama_interface.optimal_config == {'n_ctx': 2048, 'n_batch': 512} + mock_llama.assert_called() + + @pytest.mark.asyncio + async def test_generate(self, mock_llama, mock_config): + llama_interface = LlamaInterface() + mock_llama.return_value.create_chat_completion.return_value = { + 'choices': [{'message': {'content': 'Generated response'}}] + } + + response = await llama_interface.generate("Test prompt") + assert response == 'Generated response' + mock_llama.return_value.create_chat_completion.assert_called_once_with( + messages=[{'role': 'user', 'content': 'Test prompt'}] + ) + + @pytest.mark.asyncio + async def test_generate_embedding(self, mock_llama, mock_config): + llama_interface = LlamaInterface() + mock_llama.return_value.embed.return_value = [0.1, 0.2, 0.3] + + embedding = await llama_interface.generate_embedding("Test text") + assert embedding == [0.1, 0.2, 0.3] + mock_llama.return_value.embed.assert_called_once_with("Test text") + + @pytest.mark.asyncio + async def test_cleanup(self, mock_llama, mock_config): + llama_interface = LlamaInterface() + + await llama_interface.cleanup() + + # Check that the models have been deleted + assert not hasattr(llama_interface, 'llm') + assert not hasattr(llama_interface, 'embedding_model') + +@pytest.mark.asyncio +async def test_model_error(): + class ErrorModel(LanguageModel): + @async_error_handler + async def generate(self, prompt: str) -> str: + raise Exception("Test error") + + @async_error_handler + async def _generate_embedding(self, text: str) -> list[float]: + raise Exception("Test error") + + error_model = ErrorModel() + with pytest.raises(ModelError): + await error_model.generate("Test prompt") + + with pytest.raises(ModelError): + await error_model.generate_embedding("Test text")
+ + +def test_model_initialization_error(mock_config): + with pytest.raises(ModelInitializationError): + OllamaInterface(quality_preset="invalid_preset") + + with pytest.raises(ModelInitializationError): + LlamaInterface(quality_preset="invalid_preset")