diff --git a/CLAUDE.md b/CLAUDE.md index 942c3e3..eee4e1a 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -194,6 +194,7 @@ The codebase uses a comprehensive multi-layered testing approach: - **Production supports both backends**: While tests use SQLite, production code fully supports both SQLite and PostgreSQL - **No PostgreSQL required for testing**: Developers can run the full test suite without installing PostgreSQL - **Backend-agnostic implementation**: Repository code works identically with both backends +- **Real server tests required**: When writing unit tests, always add corresponding real server integration tests in `tests/test_real_server.py` to verify tool behavior via the actual MCP protocol #### Test Database Protection: - **Automatic Isolation**: All tests use `prevent_default_db_pollution` fixture (session-scoped, autouse) diff --git a/app/repositories/context_repository.py b/app/repositories/context_repository.py index bf951f5..d47ed79 100644 --- a/app/repositories/context_repository.py +++ b/app/repositories/context_repository.py @@ -936,3 +936,423 @@ def row_to_dict(row: sqlite3.Row) -> dict[str, Any]: entry['metadata'] = None return entry + + async def store_contexts_batch( + self, + entries: list[dict[str, Any]], + ) -> list[tuple[int, int | None, str | None]]: + """Store multiple context entries in a single transaction. + + Each entry is processed with deduplication logic: if an entry with + identical thread_id, source, and text_content exists, it is updated + rather than creating a duplicate. 
+ + Args: + entries: List of entry dictionaries with keys: + - thread_id: str + - source: str + - text_content: str + - metadata: str | None (JSON string) + - content_type: str ('text' or 'multimodal') + + Returns: + List of tuples: (index, context_id or None, error or None) + On success: (0, 123, None) + On failure: (0, None, 'Error message') + """ + if self.backend.backend_type == 'sqlite': + + def _store_batch_sqlite(conn: sqlite3.Connection) -> list[tuple[int, int | None, str | None]]: + cursor = conn.cursor() + results: list[tuple[int, int | None, str | None]] = [] + + for idx, entry in enumerate(entries): + try: + thread_id = entry['thread_id'] + source = entry['source'] + text_content = entry['text_content'] + metadata = entry.get('metadata') + content_type = entry.get('content_type', 'text') + + # Check for deduplication - find latest entry with same thread_id, source, text_content + cursor.execute( + f''' + SELECT id FROM context_entries + WHERE thread_id = {self._placeholder(1)} + AND source = {self._placeholder(2)} + AND text_content = {self._placeholder(3)} + ORDER BY id DESC + LIMIT 1 + ''', + (thread_id, source, text_content), + ) + existing = cursor.fetchone() + + if existing: + # Update existing entry + existing_id = existing['id'] + cursor.execute( + f''' + UPDATE context_entries + SET metadata = {self._placeholder(1)}, + content_type = {self._placeholder(2)}, + updated_at = CURRENT_TIMESTAMP + WHERE id = {self._placeholder(3)} + ''', + (metadata, content_type, existing_id), + ) + results.append((idx, existing_id, None)) + logger.debug(f'Batch: updated existing context entry {existing_id}') + else: + # Insert new entry + cursor.execute( + f''' + INSERT INTO context_entries + (thread_id, source, content_type, text_content, metadata) + VALUES ({self._placeholders(5)}) + ''', + (thread_id, source, content_type, text_content, metadata), + ) + new_id = cursor.lastrowid or 0 + results.append((idx, new_id, None)) + logger.debug(f'Batch: inserted new 
context entry {new_id}') + + except Exception as e: + results.append((idx, None, str(e))) + logger.warning(f'Batch store failed for entry {idx}: {e}') + + return results + + return await self.backend.execute_write(_store_batch_sqlite) + + # PostgreSQL + async def _store_batch_postgresql(conn: asyncpg.Connection) -> list[tuple[int, int | None, str | None]]: + results: list[tuple[int, int | None, str | None]] = [] + + for idx, entry in enumerate(entries): + try: + thread_id = entry['thread_id'] + source = entry['source'] + text_content = entry['text_content'] + metadata = entry.get('metadata') + content_type = entry.get('content_type', 'text') + + # Check for deduplication + existing = await conn.fetchrow( + f''' + SELECT id FROM context_entries + WHERE thread_id = {self._placeholder(1)} + AND source = {self._placeholder(2)} + AND text_content = {self._placeholder(3)} + ORDER BY id DESC + LIMIT 1 + ''', + thread_id, + source, + text_content, + ) + + if existing: + # Update existing entry + existing_id = existing['id'] + await conn.execute( + f''' + UPDATE context_entries + SET metadata = {self._placeholder(1)}, + content_type = {self._placeholder(2)}, + updated_at = CURRENT_TIMESTAMP + WHERE id = {self._placeholder(3)} + ''', + metadata, + content_type, + existing_id, + ) + results.append((idx, existing_id, None)) + logger.debug(f'Batch: updated existing context entry {existing_id}') + else: + # Insert new entry + new_id_result = await conn.fetchval( + f''' + INSERT INTO context_entries + (thread_id, source, content_type, text_content, metadata) + VALUES ({self._placeholders(5)}) + RETURNING id + ''', + thread_id, + source, + content_type, + text_content, + metadata, + ) + new_id = cast(int, new_id_result) + results.append((idx, new_id, None)) + logger.debug(f'Batch: inserted new context entry {new_id}') + + except Exception as e: + results.append((idx, None, str(e))) + logger.warning(f'Batch store failed for entry {idx}: {e}') + + return results + + return await 
self.backend.execute_write(_store_batch_postgresql) + + async def update_contexts_batch( + self, + updates: list[dict[str, Any]], + ) -> list[tuple[int, int, list[str] | None, str | None]]: + """Update multiple context entries in a single transaction. + + Args: + updates: List of update dictionaries with keys: + - context_id: int (required) + - text_content: str | None (optional) + - metadata: str | None (JSON string, full replacement) + - content_type: str | None (optional) + + Returns: + List of tuples: (index, context_id, updated_fields or None, error or None) + """ + if self.backend.backend_type == 'sqlite': + + def _update_batch_sqlite(conn: sqlite3.Connection) -> list[tuple[int, int, list[str] | None, str | None]]: + cursor = conn.cursor() + results: list[tuple[int, int, list[str] | None, str | None]] = [] + + for idx, update in enumerate(updates): + try: + context_id = update['context_id'] + + # Check if entry exists + cursor.execute( + f'SELECT id FROM context_entries WHERE id = {self._placeholder(1)}', + (context_id,), + ) + if not cursor.fetchone(): + results.append((idx, context_id, None, f'Context entry {context_id} not found')) + continue + + # Build dynamic update + update_parts: list[str] = [] + params: list[Any] = [] + updated_fields: list[str] = [] + + if 'text_content' in update and update['text_content'] is not None: + update_parts.append(f'text_content = {self._placeholder(len(params) + 1)}') + params.append(update['text_content']) + updated_fields.append('text_content') + + if 'metadata' in update: + update_parts.append(f'metadata = {self._placeholder(len(params) + 1)}') + params.append(update['metadata']) + updated_fields.append('metadata') + + if 'content_type' in update and update['content_type'] is not None: + update_parts.append(f'content_type = {self._placeholder(len(params) + 1)}') + params.append(update['content_type']) + updated_fields.append('content_type') + + if not update_parts: + results.append((idx, context_id, None, 'No fields 
to update')) + continue + + # Always update timestamp + update_parts.append('updated_at = CURRENT_TIMESTAMP') + + id_placeholder = self._placeholder(len(params) + 1) + query = f"UPDATE context_entries SET {', '.join(update_parts)} WHERE id = {id_placeholder}" + params.append(context_id) + cursor.execute(query, tuple(params)) + + if cursor.rowcount > 0: + results.append((idx, context_id, updated_fields, None)) + logger.debug(f'Batch: updated context entry {context_id}, fields: {updated_fields}') + else: + results.append((idx, context_id, None, 'Update had no effect')) + + except Exception as e: + results.append((idx, update.get('context_id', 0), None, str(e))) + logger.warning(f'Batch update failed for entry {idx}: {e}') + + return results + + return await self.backend.execute_write(_update_batch_sqlite) + + # PostgreSQL + async def _update_batch_postgresql( + conn: asyncpg.Connection, + ) -> list[tuple[int, int, list[str] | None, str | None]]: + results: list[tuple[int, int, list[str] | None, str | None]] = [] + + for idx, update in enumerate(updates): + try: + context_id = update['context_id'] + + # Check if entry exists + row = await conn.fetchrow( + f'SELECT id FROM context_entries WHERE id = {self._placeholder(1)}', + context_id, + ) + if not row: + results.append((idx, context_id, None, f'Context entry {context_id} not found')) + continue + + # Build dynamic update + update_parts: list[str] = [] + params: list[Any] = [] + updated_fields: list[str] = [] + + if 'text_content' in update and update['text_content'] is not None: + update_parts.append(f'text_content = {self._placeholder(len(params) + 1)}') + params.append(update['text_content']) + updated_fields.append('text_content') + + if 'metadata' in update: + update_parts.append(f'metadata = {self._placeholder(len(params) + 1)}') + params.append(update['metadata']) + updated_fields.append('metadata') + + if 'content_type' in update and update['content_type'] is not None: + update_parts.append(f'content_type = 
{self._placeholder(len(params) + 1)}') + params.append(update['content_type']) + updated_fields.append('content_type') + + if not update_parts: + results.append((idx, context_id, None, 'No fields to update')) + continue + + # Always update timestamp + update_parts.append('updated_at = CURRENT_TIMESTAMP') + + id_placeholder = self._placeholder(len(params) + 1) + set_clause = ', '.join(update_parts) + query = f'UPDATE context_entries SET {set_clause} WHERE id = {id_placeholder}' + params.append(context_id) + result = await conn.execute(query, *params) + + # asyncpg returns "UPDATE N" where N is the count + rows_affected = int(result.split()[-1]) if result else 0 + if rows_affected > 0: + results.append((idx, context_id, updated_fields, None)) + logger.debug(f'Batch: updated context entry {context_id}, fields: {updated_fields}') + else: + results.append((idx, context_id, None, 'Update had no effect')) + + except Exception as e: + results.append((idx, update.get('context_id', 0), None, str(e))) + logger.warning(f'Batch update failed for entry {idx}: {e}') + + return results + + return await self.backend.execute_write(_update_batch_postgresql) + + async def delete_contexts_batch( + self, + context_ids: list[int] | None = None, + thread_ids: list[str] | None = None, + source: str | None = None, + older_than_days: int | None = None, + ) -> tuple[int, list[str]]: + """Delete multiple context entries by various criteria. + + At least one criterion must be provided. Criteria can be combined + for more targeted deletion. Cascading delete removes associated + tags, images, and embeddings. 
+ + Args: + context_ids: Specific context entry IDs to delete + thread_ids: Delete all entries in these threads + source: Filter by source ('user' or 'agent') - combine with other criteria + older_than_days: Delete entries older than N days + + Returns: + Tuple of (deleted_count, list_of_criteria_used) + """ + criteria_used: list[str] = [] + + if self.backend.backend_type == 'sqlite': + + def _delete_batch_sqlite(conn: sqlite3.Connection) -> tuple[int, list[str]]: + cursor = conn.cursor() + conditions: list[str] = [] + params: list[Any] = [] + + if context_ids: + placeholders = ','.join([self._placeholder(len(params) + i + 1) for i in range(len(context_ids))]) + conditions.append(f'id IN ({placeholders})') + params.extend(context_ids) + criteria_used.append(f'context_ids: {len(context_ids)} IDs') + + if thread_ids: + placeholders = ','.join([ + self._placeholder(len(params) + i + 1) for i in range(len(thread_ids)) + ]) + conditions.append(f'thread_id IN ({placeholders})') + params.extend(thread_ids) + criteria_used.append(f'thread_ids: {len(thread_ids)} threads') + + if source: + conditions.append(f'source = {self._placeholder(len(params) + 1)}') + params.append(source) + criteria_used.append(f'source: {source}') + + if older_than_days is not None: + conditions.append( + f"created_at < datetime('now', {self._placeholder(len(params) + 1)})", + ) + params.append(f'-{older_than_days} days') + criteria_used.append(f'older_than_days: {older_than_days}') + + if not conditions: + return 0, criteria_used + + where_clause = ' AND '.join(conditions) + query = f'DELETE FROM context_entries WHERE {where_clause}' + cursor.execute(query, tuple(params)) + + deleted_count = cursor.rowcount + logger.info(f'Batch delete: removed {deleted_count} entries using criteria: {criteria_used}') + return deleted_count, criteria_used + + return await self.backend.execute_write(_delete_batch_sqlite) + + # PostgreSQL + async def _delete_batch_postgresql(conn: asyncpg.Connection) -> tuple[int, 
list[str]]: + conditions: list[str] = [] + params: list[Any] = [] + + if context_ids: + placeholders = ','.join([self._placeholder(len(params) + i + 1) for i in range(len(context_ids))]) + conditions.append(f'id IN ({placeholders})') + params.extend(context_ids) + criteria_used.append(f'context_ids: {len(context_ids)} IDs') + + if thread_ids: + placeholders = ','.join([self._placeholder(len(params) + i + 1) for i in range(len(thread_ids))]) + conditions.append(f'thread_id IN ({placeholders})') + params.extend(thread_ids) + criteria_used.append(f'thread_ids: {len(thread_ids)} threads') + + if source: + conditions.append(f'source = {self._placeholder(len(params) + 1)}') + params.append(source) + criteria_used.append(f'source: {source}') + + if older_than_days is not None: + conditions.append( + f"created_at < (NOW() - INTERVAL '{older_than_days} days')", + ) + criteria_used.append(f'older_than_days: {older_than_days}') + + if not conditions: + return 0, criteria_used + + where_clause = ' AND '.join(conditions) + query = f'DELETE FROM context_entries WHERE {where_clause}' + result = await conn.execute(query, *params) + + # asyncpg returns "DELETE N" where N is the count + deleted_count = int(result.split()[-1]) if result else 0 + logger.info(f'Batch delete: removed {deleted_count} entries using criteria: {criteria_used}') + return deleted_count, criteria_used + + return await self.backend.execute_write(_delete_batch_postgresql) diff --git a/app/server.py b/app/server.py index f087cac..1497731 100644 --- a/app/server.py +++ b/app/server.py @@ -9,6 +9,7 @@ import contextlib import json import logging +import operator import sqlite3 from collections.abc import AsyncGenerator from contextlib import asynccontextmanager @@ -29,6 +30,11 @@ from app.logger_config import config_logger from app.repositories import RepositoryContainer from app.settings import get_settings +from app.types import BulkDeleteResponseDict +from app.types import BulkStoreResponseDict +from app.types 
def _batch_store_image_error(images: Any) -> str | None:
    """Return the first problem found in an entry's images, or None when all pass."""
    if not isinstance(images, list):
        return 'images must be a list'
    total_size_mb = 0.0
    for img_idx, img in enumerate(images):
        # Non-dict items cannot carry a "data" key; report them instead of crashing
        if not isinstance(img, dict) or 'data' not in img:
            return f'Image {img_idx} is missing required "data" field'
        try:
            img_size_mb = len(base64.b64decode(img['data'])) / (1024 * 1024)
        except Exception:
            return f'Image {img_idx} has invalid base64 encoding'
        if img_size_mb > MAX_IMAGE_SIZE_MB:
            return f'Image {img_idx} exceeds {MAX_IMAGE_SIZE_MB}MB limit'
        total_size_mb += img_size_mb
        if total_size_mb > MAX_TOTAL_SIZE_MB:
            return f'Total image size exceeds {MAX_TOTAL_SIZE_MB}MB limit'
    return None


def _batch_store_entry_error(entry: dict[str, Any]) -> str | None:
    """Return the first validation error for one store entry, or None when it is valid."""
    if not entry.get('thread_id'):
        return 'Missing required field: thread_id'
    if entry.get('source') not in ('user', 'agent'):
        return 'Missing or invalid source (must be "user" or "agent")'
    if not entry.get('text'):
        return 'Missing required field: text'
    if not str(entry['thread_id']).strip():
        return 'thread_id cannot be empty or whitespace'
    if not str(entry['text']).strip():
        return 'text cannot be empty or whitespace'
    images = entry.get('images', [])
    return _batch_store_image_error(images) if images else None


@mcp.tool()
async def store_context_batch(
    entries: Annotated[
        list[dict[str, Any]],
        Field(
            description='List of context entries to store. Each entry must have: '
            'thread_id (str), source ("user" or "agent"), text (str). '
            'Optional: metadata (dict), tags (list[str]), images (list[dict]).',
            min_length=1,
            max_length=100,
        ),
    ],
    atomic: Annotated[
        bool,
        Field(
            description='If true, ALL entries must succeed or NONE are stored (transaction rollback). '
            'If false, partial success is allowed with per-item error reporting.',
        ),
    ] = True,
    ctx: Context | None = None,
) -> BulkStoreResponseDict:
    """Store multiple context entries in a single batch operation.

    Batch processing is significantly faster than individual store_context calls
    when storing many entries. Use for migrations, imports, or bulk operations.

    Atomicity modes:
    - atomic=True (default): All-or-nothing. If ANY entry fails, ALL are rolled back.
    - atomic=False: Best-effort. Each entry processed independently; partial success possible.

    Returns: {
        success: bool (true if ALL succeeded),
        total: int,
        succeeded: int,
        failed: int,
        results: [{index, success, context_id, error}, ...],
        message: str
    }

    Size limits:
    - Maximum 100 entries per batch
    - Image limits per entry: 10MB each, 100MB total
    - Standard tag normalization (lowercase)

    Args:
        entries: List of context entry dicts with thread_id, source, text, optional metadata/tags/images
        atomic: If true, all succeed or all fail; if false, partial success allowed
        ctx: FastMCP context object

    Returns:
        BulkStoreResponseDict: Detailed results for each entry

    Raises:
        ToolError: If atomic mode and any validation fails
    """
    try:
        if ctx:
            await ctx.info(f'Batch storing {len(entries)} context entries (atomic={atomic})')

        repos = await _ensure_repositories()

        # Phase 1: Validate all entries before anything is written
        validated_entries: list[dict[str, Any]] = []
        validation_errors: list[tuple[int, str]] = []

        for idx, entry in enumerate(entries):
            problem = _batch_store_entry_error(entry)
            if problem is not None:
                validation_errors.append((idx, problem))
                continue

            images = entry.get('images', [])
            metadata = entry.get('metadata')
            validated_entries.append({
                'index': idx,
                'thread_id': str(entry['thread_id']).strip(),
                'source': entry['source'],
                'text_content': str(entry['text']).strip(),
                'metadata': json.dumps(metadata, ensure_ascii=False) if metadata else None,
                'content_type': 'multimodal' if images else 'text',
                'tags': entry.get('tags', []),
                'images': images,
            })

        # Phase 2: In atomic mode, fail fast if any validation errors
        if atomic and validation_errors:
            first_error = validation_errors[0]
            raise ToolError(
                f'Validation failed for {len(validation_errors)} entries. '
                f'First error at index {first_error[0]}: {first_error[1]}',
            )

        # Phase 3: Process validated entries through the repository.
        # Validation failures are reported alongside repository results.
        results: list[BulkStoreResultItemDict] = [
            BulkStoreResultItemDict(index=idx, success=False, context_id=None, error=error)
            for idx, error in validation_errors
        ]

        if validated_entries:
            repo_entries = [
                {
                    'thread_id': e['thread_id'],
                    'source': e['source'],
                    'text_content': e['text_content'],
                    'metadata': e['metadata'],
                    'content_type': e['content_type'],
                }
                for e in validated_entries
            ]

            # NOTE(review): atomic=True is only enforced for validation above. The
            # repository reports per-entry failures instead of raising, so a failure
            # here does not undo entries that were already written - confirm whether
            # a real rollback is required.
            batch_results = await repos.context.store_contexts_batch(repo_entries)

            for repo_idx, ctx_id, repo_error in batch_results:
                original_entry = validated_entries[repo_idx]
                original_idx = original_entry['index']

                if ctx_id is None or repo_error is not None:
                    results.append(BulkStoreResultItemDict(
                        index=original_idx,
                        success=False,
                        context_id=None,
                        error=repo_error or 'Unknown error',
                    ))
                    continue

                try:
                    if original_entry.get('tags'):
                        await repos.tags.store_tags(ctx_id, original_entry['tags'])
                    if original_entry.get('images'):
                        await repos.images.store_images(ctx_id, original_entry['images'])
                except Exception as attach_err:
                    if atomic:
                        raise
                    # Best-effort mode promises per-item reporting, so one entry's
                    # tag/image failure must not abort the whole response.
                    logger.warning(f'Failed to store tags/images for context {ctx_id}: {attach_err}')
                    results.append(BulkStoreResultItemDict(
                        index=original_idx,
                        success=False,
                        context_id=ctx_id,
                        error=f'Entry stored but tags/images failed: {attach_err}',
                    ))
                    continue

                # Generate embedding if semantic search enabled (failure is non-fatal)
                if _embedding_service is not None:
                    try:
                        embedding = await _embedding_service.generate_embedding(original_entry['text_content'])
                        await repos.embeddings.store(
                            context_id=ctx_id,
                            embedding=embedding,
                            model=settings.embedding_model,
                        )
                    except Exception as emb_err:
                        logger.warning(f'Failed to generate embedding for context {ctx_id}: {emb_err}')

                results.append(BulkStoreResultItemDict(
                    index=original_idx,
                    success=True,
                    context_id=ctx_id,
                    error=None,
                ))

        # Sort results by index for consistent ordering
        results.sort(key=operator.itemgetter('index'))

        succeeded = sum(1 for r in results if r['success'])
        failed = len(entries) - succeeded

        logger.info(f'Batch store completed: {succeeded}/{len(entries)} succeeded')

        return BulkStoreResponseDict(
            success=failed == 0,
            total=len(entries),
            succeeded=succeeded,
            failed=failed,
            results=results,
            message=f'Stored {succeeded}/{len(entries)} entries successfully',
        )

    except ToolError:
        raise
    except Exception as e:
        logger.error(f'Error in batch store: {e}')
        raise ToolError(f'Batch store failed: {str(e)}') from e
def _batch_update_image_error(images: Any) -> str | None:
    """Return the first problem found in an update's replacement images, or None."""
    if not isinstance(images, list):
        return 'images must be a list'
    total_size_mb = 0.0
    for img_idx, img in enumerate(images):
        # Non-dict items cannot carry a "data" key; report them instead of crashing
        if not isinstance(img, dict) or 'data' not in img:
            return f'Image {img_idx} missing "data" field'
        try:
            img_size_mb = len(base64.b64decode(img['data'])) / (1024 * 1024)
        except Exception:
            return f'Image {img_idx} has invalid base64'
        if img_size_mb > MAX_IMAGE_SIZE_MB:
            return f'Image {img_idx} exceeds {MAX_IMAGE_SIZE_MB}MB'
        total_size_mb += img_size_mb
        if total_size_mb > MAX_TOTAL_SIZE_MB:
            return f'Total size exceeds {MAX_TOTAL_SIZE_MB}MB'
    return None


@mcp.tool()
async def update_context_batch(
    updates: Annotated[
        list[dict[str, Any]],
        Field(
            description='List of update operations. Each must have context_id (int). '
            'Optional: text (str), metadata (dict - full replace), '
            'metadata_patch (dict - RFC 7396 merge), tags (list[str]), images (list[dict]).',
            min_length=1,
            max_length=100,
        ),
    ],
    atomic: Annotated[
        bool,
        Field(
            description='If true, ALL updates succeed or NONE are applied. '
            'If false, partial success allowed.',
        ),
    ] = True,
    ctx: Context | None = None,
) -> BulkUpdateResponseDict:
    """Update multiple context entries in a single batch operation.

    Similar semantics to update_context but for multiple entries:
    - Each update is identified by context_id
    - Only provided fields are modified
    - metadata vs metadata_patch are mutually exclusive per entry
    - Tags and images use replacement semantics

    Atomicity modes:
    - atomic=True (default): All-or-nothing transaction
    - atomic=False: Best-effort with per-item error reporting

    Args:
        updates: List of update dicts with context_id (required) and optional fields
        atomic: If true, all succeed or all fail; if false, partial success allowed
        ctx: FastMCP context object

    Returns:
        BulkUpdateResponseDict: Detailed results for each update including updated_fields

    Raises:
        ToolError: If atomic mode and any validation fails
    """
    try:
        if ctx:
            await ctx.info(f'Batch updating {len(updates)} context entries (atomic={atomic})')

        repos = await _ensure_repositories()

        # Phase 1: Validate all updates before processing
        validated_updates: list[dict[str, Any]] = []
        validation_errors: list[tuple[int, int, str]] = []  # (index, context_id, error)

        for idx, update in enumerate(updates):
            if 'context_id' not in update:
                validation_errors.append((idx, 0, 'Missing required field: context_id'))
                continue

            context_id = update['context_id']
            # bool is a subclass of int, so True would otherwise pass as context id 1
            if isinstance(context_id, bool) or not isinstance(context_id, int) or context_id <= 0:
                validation_errors.append((idx, 0, 'context_id must be a positive integer'))
                continue

            if update.get('metadata') is not None and update.get('metadata_patch') is not None:
                validation_errors.append((
                    idx,
                    context_id,
                    'Cannot use both metadata and metadata_patch. Use one or the other.',
                ))
                continue

            text = update.get('text')
            if text is not None:
                text = str(text).strip()
                if not text:
                    validation_errors.append((idx, context_id, 'text cannot be empty or whitespace'))
                    continue

            has_update = any(
                update.get(field) is not None
                for field in ('text', 'metadata', 'metadata_patch', 'tags', 'images')
            )
            if not has_update:
                validation_errors.append((idx, context_id, 'At least one field must be provided for update'))
                continue

            images = update.get('images')
            if images is not None:
                image_problem = _batch_update_image_error(images)
                if image_problem is not None:
                    validation_errors.append((idx, context_id, image_problem))
                    continue

            validated_updates.append({
                'index': idx,
                'context_id': context_id,
                'text': text,
                'metadata': update.get('metadata'),
                'metadata_patch': update.get('metadata_patch'),
                'tags': update.get('tags'),
                'images': images,
            })

        # Phase 2: In atomic mode, fail fast if any validation errors
        if atomic and validation_errors:
            first_error = validation_errors[0]
            raise ToolError(
                f'Validation failed for {len(validation_errors)} entries. '
                f'First error at context_id {first_error[1]}: {first_error[2]}',
            )

        # Phase 3: Process validated updates; validation failures are reported alongside
        results: list[BulkUpdateResultItemDict] = [
            BulkUpdateResultItemDict(
                index=idx,
                context_id=context_id,
                success=False,
                updated_fields=None,
                error=error,
            )
            for idx, context_id, error in validation_errors
        ]

        # NOTE(review): each update below is applied through separate repository
        # calls, so atomic=True is only enforced for validation above - confirm
        # whether a real all-or-nothing transaction is required.
        for update in validated_updates:
            original_idx = update['index']
            context_id = update['context_id']
            updated_fields: list[str] = []

            try:
                if not await repos.context.check_entry_exists(context_id):
                    results.append(BulkUpdateResultItemDict(
                        index=original_idx,
                        context_id=context_id,
                        success=False,
                        updated_fields=None,
                        error=f'Context entry {context_id} not found',
                    ))
                    continue

                # Update text and/or metadata (full replacement)
                if update.get('text') is not None or update.get('metadata') is not None:
                    metadata_str = None
                    if update.get('metadata') is not None:
                        metadata_str = json.dumps(update['metadata'], ensure_ascii=False)

                    success, fields = await repos.context.update_context_entry(
                        context_id=context_id,
                        text_content=update.get('text'),
                        metadata=metadata_str,
                    )
                    if success:
                        updated_fields.extend(fields)

                # Apply metadata patch if provided
                if update.get('metadata_patch') is not None:
                    success, fields = await repos.context.patch_metadata(
                        context_id=context_id,
                        patch=update['metadata_patch'],
                    )
                    if success:
                        updated_fields.extend(fields)

                # Replace tags if provided
                if update.get('tags') is not None:
                    await repos.tags.replace_tags_for_context(context_id, update['tags'])
                    updated_fields.append('tags')

                # Replace images if provided; an empty list turns the entry back into plain text
                if update.get('images') is not None:
                    new_images = update['images'] or []
                    await repos.images.replace_images_for_context(context_id, new_images)
                    await repos.context.update_content_type(
                        context_id,
                        'multimodal' if new_images else 'text',
                    )
                    updated_fields.extend(['images', 'content_type'])

                # Regenerate embedding if text changed and semantic search available
                if update.get('text') is not None and _embedding_service is not None:
                    try:
                        new_embedding = await _embedding_service.generate_embedding(update['text'])
                        if await repos.embeddings.exists(context_id):
                            await repos.embeddings.update(context_id=context_id, embedding=new_embedding)
                        else:
                            await repos.embeddings.store(
                                context_id=context_id,
                                embedding=new_embedding,
                                model=settings.embedding_model,
                            )
                        updated_fields.append('embedding')
                    except Exception as emb_err:
                        # Embedding refresh is best-effort and never fails the update
                        logger.warning(f'Failed to update embedding for context {context_id}: {emb_err}')

                results.append(BulkUpdateResultItemDict(
                    index=original_idx,
                    context_id=context_id,
                    success=True,
                    updated_fields=updated_fields,
                    error=None,
                ))

            except Exception as e:
                results.append(BulkUpdateResultItemDict(
                    index=original_idx,
                    context_id=context_id,
                    success=False,
                    updated_fields=None,
                    error=str(e),
                ))

        # Sort results by index for consistent ordering
        results.sort(key=operator.itemgetter('index'))

        succeeded = sum(1 for r in results if r['success'])
        failed = len(updates) - succeeded

        logger.info(f'Batch update completed: {succeeded}/{len(updates)} succeeded')

        return BulkUpdateResponseDict(
            success=failed == 0,
            total=len(updates),
            succeeded=succeeded,
            failed=failed,
            results=results,
            message=f'Updated {succeeded}/{len(updates)} entries successfully',
        )

    except ToolError:
        raise
    except Exception as e:
        logger.error(f'Error in batch update: {e}')
        raise ToolError(f'Batch update failed: {str(e)}') from e
@mcp.tool()
async def delete_context_batch(
    context_ids: Annotated[
        list[int] | None,
        Field(description='Specific context IDs to delete'),
    ] = None,
    thread_ids: Annotated[
        list[str] | None,
        Field(description='Delete ALL entries in these threads'),
    ] = None,
    source: Annotated[
        Literal['user', 'agent'] | None,
        Field(description='Delete only entries from this source (combine with thread_ids)'),
    ] = None,
    older_than_days: Annotated[
        int | None,
        Field(description='Delete entries older than N days', gt=0),
    ] = None,
    ctx: Context | None = None,
) -> BulkDeleteResponseDict:
    """Delete multiple context entries by various criteria. IRREVERSIBLE.

    Criteria can be combined for targeted deletion:
    - context_ids: Delete specific entries by ID
    - thread_ids: Delete all entries in specified threads
    - source: Filter by source ('user' or 'agent')
    - older_than_days: Delete entries created more than N days ago

    At least one criterion must be provided.
    Cascading delete removes associated tags, images, and embeddings.

    WARNING: This operation cannot be undone. Verify criteria before deletion.

    Args:
        context_ids: Specific context entry IDs to delete
        thread_ids: Delete all entries in these threads
        source: Filter by source type (combine with other criteria)
        older_than_days: Delete entries older than N days
        ctx: FastMCP context object

    Returns:
        BulkDeleteResponseDict: {success, deleted_count, criteria_used, message}

    Raises:
        ToolError: If no criteria provided or deletion fails
    """
    try:
        # Validate at least one criterion is provided
        if not any([context_ids, thread_ids, source, older_than_days]):
            raise ToolError(
                'At least one deletion criterion must be provided: '
                'context_ids, thread_ids, source, or older_than_days',
            )

        # source alone would wipe every entry of that source, so it must narrow something else
        if source and not any([context_ids, thread_ids, older_than_days]):
            raise ToolError(
                'source filter must be combined with another criterion '
                '(context_ids, thread_ids, or older_than_days)',
            )

        if ctx:
            criteria_summary = []
            if context_ids:
                criteria_summary.append(f'{len(context_ids)} IDs')
            if thread_ids:
                criteria_summary.append(f'{len(thread_ids)} threads')
            if source:
                criteria_summary.append(f'source={source}')
            if older_than_days:
                criteria_summary.append(f'older_than={older_than_days}d')
            await ctx.info(f'Batch delete with criteria: {", ".join(criteria_summary)}')

        repos = await _ensure_repositories()

        # Criteria are ANDed by the repository, so when any other filter is present
        # only a subset of context_ids is actually removed. Deleting embeddings for
        # every listed id up front would strip them from entries that survive, so
        # the explicit cleanup runs only when context_ids is the sole criterion.
        # NOTE(review): filtered deletes then rely on the cascading delete described
        # in the docstring to remove embeddings - confirm it covers them.
        narrowed = bool(thread_ids or source or older_than_days)
        if settings.enable_semantic_search and context_ids and not narrowed:
            for cid in context_ids:
                try:
                    await repos.embeddings.delete(cid)
                except Exception as e:
                    logger.warning(f'Failed to delete embedding for context {cid}: {e}')

        # Execute batch delete through repository
        deleted_count, criteria_used = await repos.context.delete_contexts_batch(
            context_ids=context_ids,
            thread_ids=thread_ids,
            source=source,
            older_than_days=older_than_days,
        )

        logger.info(f'Batch delete completed: {deleted_count} entries removed')

        return BulkDeleteResponseDict(
            success=True,
            deleted_count=deleted_count,
            criteria_used=criteria_used,
            message=f'Successfully deleted {deleted_count} context entries',
        )

    except ToolError:
        raise
    except Exception as e:
        logger.error(f'Error in batch delete: {e}')
        raise ToolError(f'Batch delete failed: {str(e)}') from e
deleted_count=deleted_count, + criteria_used=criteria_used, + message=f'Successfully deleted {deleted_count} context entries', + ) + + except ToolError: + raise + except Exception as e: + logger.error(f'Error in batch delete: {e}') + raise ToolError(f'Batch delete failed: {str(e)}') from e + + # Main entry point def main() -> None: """Main entry point for the MCP Context Server. diff --git a/app/types.py b/app/types.py index 41a35d9..7149b0d 100644 --- a/app/types.py +++ b/app/types.py @@ -88,3 +88,87 @@ class UpdateContextSuccessDict(TypedDict): context_id: int updated_fields: list[str] message: str + + +# Bulk operation TypedDicts + + +class BulkStoreItemDict(TypedDict, total=False): + """Type definition for a single item in bulk store request. + + Required fields: thread_id, source, text + Optional fields: metadata, tags, images + """ + + thread_id: str + source: str + text: str + metadata: MetadataDict | None + tags: list[str] | None + images: list[dict[str, str]] | None + + +class BulkStoreResultItemDict(TypedDict): + """Type definition for a single result in bulk store response.""" + + index: int + success: bool + context_id: int | None + error: str | None + + +class BulkStoreResponseDict(TypedDict): + """Type definition for bulk store response.""" + + success: bool + total: int + succeeded: int + failed: int + results: list[BulkStoreResultItemDict] + message: str + + +class BulkUpdateItemDict(TypedDict, total=False): + """Type definition for a single item in bulk update request. + + Required field: context_id + Optional fields: text, metadata, metadata_patch, tags, images + Note: metadata and metadata_patch are mutually exclusive per entry. 
class BulkUpdateItemDict(TypedDict, total=False):
    """Type definition for a single item in bulk update request.

    Required field: context_id
    Optional fields: text, metadata, metadata_patch, tags, images
    Note: metadata and metadata_patch are mutually exclusive per entry.

    Because the class is declared with ``total=False``, static type checkers
    treat every key (including ``context_id``) as optional; the "required"
    status above is a convention, not something this type enforces.
    """

    # NOTE(review): presumably the batch update tool rejects items without
    # context_id at runtime - confirm against its validation code.
    context_id: int
    text: str | None
    metadata: MetadataDict | None
    metadata_patch: MetadataDict | None
    tags: list[str] | None
    images: list[dict[str, str]] | None


class BulkUpdateResultItemDict(TypedDict):
    """Type definition for a single result in bulk update response."""

    # Position of the corresponding item in the original `updates` list.
    index: int
    context_id: int
    success: bool
    # List of changed field names on success; None when the item failed.
    updated_fields: list[str] | None
    # Error text when the item failed; None on success.
    error: str | None


class BulkUpdateResponseDict(TypedDict):
    """Type definition for bulk update response."""

    # True only when no item failed.
    success: bool
    total: int
    succeeded: int
    failed: int
    results: list[BulkUpdateResultItemDict]
    message: str


class BulkDeleteResponseDict(TypedDict):
    """Type definition for bulk delete response."""

    success: bool
    deleted_count: int
    # Descriptions of the criteria applied, as reported by the repository layer.
    criteria_used: list[str]
    message: str
@pytest.mark.usefixtures('initialized_server')
class TestStoreContextBatch:
    """Exercise store_context_batch: happy paths, deduplication and validation."""

    @pytest.mark.asyncio
    async def test_store_batch_success_atomic(self) -> None:
        """Three valid entries stored atomically all succeed and receive IDs."""
        payload = [
            {'thread_id': 'batch-test-1', 'source': 'user', 'text': 'First entry'},
            {'thread_id': 'batch-test-1', 'source': 'agent', 'text': 'Second entry'},
            {'thread_id': 'batch-test-2', 'source': 'user', 'text': 'Third entry'},
        ]

        response = await store_context_batch(entries=payload, atomic=True)

        assert response['success'] is True
        assert response['total'] == 3
        assert response['succeeded'] == 3
        assert response['failed'] == 0

        per_item = response['results']
        assert len(per_item) == 3

        # Every per-item result must report success and carry an ID.
        for outcome in per_item:
            assert outcome['success'] is True
            assert outcome['context_id'] is not None
            assert outcome['error'] is None

    @pytest.mark.asyncio
    async def test_store_batch_success_non_atomic(self) -> None:
        """Two valid entries stored non-atomically both succeed."""
        payload = [
            {'thread_id': 'batch-non-atomic-1', 'source': 'user', 'text': 'Entry one'},
            {'thread_id': 'batch-non-atomic-2', 'source': 'agent', 'text': 'Entry two'},
        ]

        response = await store_context_batch(entries=payload, atomic=False)

        assert response['success'] is True
        assert response['total'] == 2
        assert response['succeeded'] == 2
        assert response['failed'] == 0

    @pytest.mark.asyncio
    async def test_store_batch_with_metadata(self) -> None:
        """Entries carrying metadata are accepted."""
        user_entry = {
            'thread_id': 'batch-meta-1',
            'source': 'user',
            'text': 'Entry with metadata',
            'metadata': {'priority': 1, 'status': 'pending'},
        }
        agent_entry = {
            'thread_id': 'batch-meta-1',
            'source': 'agent',
            'text': 'Another entry',
            'metadata': {'agent_name': 'test-agent'},
        }

        response = await store_context_batch(entries=[user_entry, agent_entry], atomic=True)

        assert response['success'] is True
        assert response['succeeded'] == 2

    @pytest.mark.asyncio
    async def test_store_batch_with_tags(self) -> None:
        """Entries carrying tags are accepted."""
        user_entry = {
            'thread_id': 'batch-tags-1',
            'source': 'user',
            'text': 'Entry with tags',
            'tags': ['important', 'review'],
        }
        agent_entry = {
            'thread_id': 'batch-tags-1',
            'source': 'agent',
            'text': 'Another entry',
            'tags': ['processed'],
        }

        response = await store_context_batch(entries=[user_entry, agent_entry], atomic=True)

        assert response['success'] is True
        assert response['succeeded'] == 2

    @pytest.mark.asyncio
    async def test_store_batch_deduplication(self) -> None:
        """Two identical entries in one batch resolve to the same context ID."""
        twin = {'thread_id': 'batch-dedup-1', 'source': 'user', 'text': 'Duplicate entry'}
        payload = [dict(twin), dict(twin)]

        response = await store_context_batch(entries=payload, atomic=True)

        assert response['success'] is True
        assert response['succeeded'] == 2

        # Deduplication means the second item reuses the first item's row.
        seen_ids = [outcome['context_id'] for outcome in response['results']]
        assert seen_ids[0] == seen_ids[1]

    @pytest.mark.asyncio
    async def test_store_batch_validation_error_atomic(self) -> None:
        """Atomic mode raises as soon as one entry fails validation."""
        payload = [
            {'thread_id': 'batch-valid-1', 'source': 'user', 'text': 'Valid entry'},
            {'thread_id': 'batch-valid-1', 'source': 'invalid', 'text': 'Invalid source'},
            {'thread_id': 'batch-valid-1', 'source': 'user', 'text': 'Another valid'},
        ]

        with pytest.raises(ToolError) as raised:
            await store_context_batch(entries=payload, atomic=True)

        message = str(raised.value)
        assert 'Validation failed' in message
        assert 'source' in message.lower()

    @pytest.mark.asyncio
    async def test_store_batch_validation_error_non_atomic(self) -> None:
        """Non-atomic mode stores the valid entries and reports the invalid one."""
        payload = [
            {'thread_id': 'batch-partial-1', 'source': 'user', 'text': 'Valid entry'},
            {'thread_id': 'batch-partial-1', 'source': 'invalid', 'text': 'Invalid source'},
            {'thread_id': 'batch-partial-1', 'source': 'agent', 'text': 'Another valid'},
        ]

        response = await store_context_batch(entries=payload, atomic=False)

        assert response['success'] is False
        assert response['total'] == 3
        assert response['succeeded'] == 2
        assert response['failed'] == 1

        # Exactly the middle entry (index 1) should have been rejected.
        rejected = [outcome for outcome in response['results'] if not outcome['success']]
        assert len(rejected) == 1
        only_failure = rejected[0]
        assert only_failure['index'] == 1
        assert 'source' in only_failure['error'].lower()

    @pytest.mark.asyncio
    async def test_store_batch_missing_thread_id(self) -> None:
        """An entry without thread_id is reported as a validation failure."""
        payload = [{'source': 'user', 'text': 'Missing thread_id'}]

        response = await store_context_batch(entries=payload, atomic=False)

        assert response['success'] is False
        assert response['failed'] == 1
        first_error = response['results'][0]['error']
        assert 'thread_id' in first_error.lower()

    @pytest.mark.asyncio
    async def test_store_batch_missing_text(self) -> None:
        """An entry without text is reported as a validation failure."""
        payload = [{'thread_id': 'batch-no-text', 'source': 'user'}]

        response = await store_context_batch(entries=payload, atomic=False)

        assert response['success'] is False
        assert response['failed'] == 1
        first_error = response['results'][0]['error']
        assert 'text' in first_error.lower()

    @pytest.mark.asyncio
    async def test_store_batch_empty_text(self) -> None:
        """Whitespace-only text is rejected as empty."""
        payload = [{'thread_id': 'batch-empty-text', 'source': 'user', 'text': ' '}]

        response = await store_context_batch(entries=payload, atomic=False)

        assert response['success'] is False
        assert response['failed'] == 1
        first_error = response['results'][0]['error']
        assert 'empty' in first_error.lower()
update.""" + entry = await store_context( + thread_id='update-meta-1', + source='user', + text='Entry for metadata update', + metadata={'original': True}, + ) + + updates = [ + { + 'context_id': entry['context_id'], + 'metadata': {'updated': True, 'version': 2}, + }, + ] + + result = await update_context_batch(updates=updates, atomic=True) + + assert result['success'] is True + assert result['succeeded'] == 1 + assert 'metadata' in result['results'][0]['updated_fields'] + + @pytest.mark.asyncio + async def test_update_batch_tags_success(self) -> None: + """Test successful batch tags update.""" + entry = await store_context( + thread_id='update-tags-1', + source='user', + text='Entry for tags update', + tags=['old-tag'], + ) + + updates = [ + { + 'context_id': entry['context_id'], + 'tags': ['new-tag-1', 'new-tag-2'], + }, + ] + + result = await update_context_batch(updates=updates, atomic=True) + + assert result['success'] is True + assert result['succeeded'] == 1 + assert 'tags' in result['results'][0]['updated_fields'] + + @pytest.mark.asyncio + async def test_update_batch_not_found(self) -> None: + """Test update of non-existent context entry.""" + updates = [ + {'context_id': 999999, 'text': 'This should fail'}, + ] + + result = await update_context_batch(updates=updates, atomic=False) + + assert result['success'] is False + assert result['failed'] == 1 + assert 'not found' in result['results'][0]['error'].lower() + + @pytest.mark.asyncio + async def test_update_batch_missing_context_id(self) -> None: + """Test validation error for missing context_id.""" + updates = [ + {'text': 'No context_id provided'}, + ] + + result = await update_context_batch(updates=updates, atomic=False) + + assert result['success'] is False + assert result['failed'] == 1 + assert 'context_id' in result['results'][0]['error'].lower() + + @pytest.mark.asyncio + async def test_update_batch_metadata_and_patch_conflict(self) -> None: + """Test validation error when both metadata and 
metadata_patch are provided.""" + entry = await store_context( + thread_id='update-conflict-1', + source='user', + text='Entry for conflict test', + ) + + updates = [ + { + 'context_id': entry['context_id'], + 'metadata': {'full': True}, + 'metadata_patch': {'partial': True}, + }, + ] + + result = await update_context_batch(updates=updates, atomic=False) + + assert result['success'] is False + assert result['failed'] == 1 + assert 'both' in result['results'][0]['error'].lower() + + @pytest.mark.asyncio + async def test_update_batch_no_fields_to_update(self) -> None: + """Test validation error when no fields are provided for update.""" + entry = await store_context( + thread_id='update-nofields-1', + source='user', + text='Entry for no-fields test', + ) + + updates = [ + {'context_id': entry['context_id']}, + ] + + result = await update_context_batch(updates=updates, atomic=False) + + assert result['success'] is False + assert result['failed'] == 1 + assert 'field' in result['results'][0]['error'].lower() + + @pytest.mark.asyncio + async def test_update_batch_partial_success(self) -> None: + """Test non-atomic mode with partial success.""" + entry = await store_context( + thread_id='update-partial-1', + source='user', + text='Valid entry', + ) + + updates = [ + {'context_id': entry['context_id'], 'text': 'Updated successfully'}, + {'context_id': 999999, 'text': 'This will fail'}, + ] + + result = await update_context_batch(updates=updates, atomic=False) + + assert result['success'] is False + assert result['total'] == 2 + assert result['succeeded'] == 1 + assert result['failed'] == 1 + + @pytest.mark.asyncio + async def test_update_batch_atomic_rollback(self) -> None: + """Test atomic mode fails fast on validation error.""" + updates = [ + {'context_id': 1, 'text': 'Valid update'}, + {'context_id': -1, 'text': 'Invalid context_id'}, + ] + + with pytest.raises(ToolError) as exc_info: + await update_context_batch(updates=updates, atomic=True) + + assert 'Validation 
@pytest.mark.usefixtures('initialized_server')
class TestDeleteContextBatch:
    """Exercise delete_context_batch: each criterion, combinations and guard rails."""

    @pytest.mark.asyncio
    async def test_delete_batch_by_ids(self) -> None:
        """Deleting by explicit IDs removes exactly those entries."""
        created_ids = []
        for origin, body in (('user', 'Entry to delete 1'), ('agent', 'Entry to delete 2')):
            stored = await store_context(thread_id='delete-batch-1', source=origin, text=body)
            created_ids.append(stored['context_id'])

        outcome = await delete_context_batch(context_ids=created_ids)

        assert outcome['success'] is True
        assert outcome['deleted_count'] == 2
        assert any('context_ids' in criterion for criterion in outcome['criteria_used'])

    @pytest.mark.asyncio
    async def test_delete_batch_by_thread_ids(self) -> None:
        """Deleting by thread IDs removes every entry in the listed threads."""
        seeds = (
            ('delete-thread-1', 'user', 'Entry in thread 1'),
            ('delete-thread-1', 'agent', 'Another entry in thread 1'),
            ('delete-thread-2', 'user', 'Entry in thread 2'),
        )
        for thread, origin, body in seeds:
            await store_context(thread_id=thread, source=origin, text=body)

        outcome = await delete_context_batch(
            thread_ids=['delete-thread-1', 'delete-thread-2'],
        )

        assert outcome['success'] is True
        assert outcome['deleted_count'] == 3
        assert any('thread_ids' in criterion for criterion in outcome['criteria_used'])

    @pytest.mark.asyncio
    async def test_delete_batch_by_thread_and_source(self) -> None:
        """A source filter narrows a thread delete to that source only."""
        for origin, body in (('user', 'User entry'), ('agent', 'Agent entry')):
            await store_context(thread_id='delete-source-1', source=origin, text=body)

        outcome = await delete_context_batch(
            thread_ids=['delete-source-1'],
            source='user',
        )

        assert outcome['success'] is True
        assert outcome['deleted_count'] == 1
        assert any('source' in criterion for criterion in outcome['criteria_used'])

    @pytest.mark.asyncio
    async def test_delete_batch_no_criteria_error(self) -> None:
        """Calling with no criteria at all is rejected."""
        with pytest.raises(ToolError) as raised:
            await delete_context_batch()

        message = str(raised.value)
        assert 'criterion' in message.lower()

    @pytest.mark.asyncio
    async def test_delete_batch_source_only_error(self) -> None:
        """A source filter without any other criterion is rejected."""
        with pytest.raises(ToolError) as raised:
            await delete_context_batch(source='user')

        message = str(raised.value)
        assert 'combined' in message.lower()

    @pytest.mark.asyncio
    async def test_delete_batch_no_matches(self) -> None:
        """Deleting IDs that do not exist succeeds with a zero count."""
        missing_ids = [999998, 999999]

        outcome = await delete_context_batch(context_ids=missing_ids)

        assert outcome['success'] is True
        assert outcome['deleted_count'] == 0
async def test_store_context_batch(self) -> bool:
    """Run store_context_batch through the real server in both modes.

    Stores three entries atomically, confirms via search_context that they
    landed in the thread, then stores two more entries non-atomically.
    The outcome is recorded in ``self.test_results``.

    Returns:
        bool: True if test passed.
    """
    test_name = 'store_context_batch'
    assert self.client is not None  # Type guard for Pyright

    def _fail(reason: str) -> bool:
        # Record a failure for this test and hand back the failing return value.
        self.test_results.append((test_name, False, reason))
        return False

    try:
        # Dedicated thread so counts are not affected by other tests.
        atomic_thread = f'{self.test_thread_id}_bulk_store'

        # Step 1: atomic store of three entries.
        atomic_payload = [
            {
                'thread_id': atomic_thread,
                'source': 'user',
                'text': 'First bulk entry',
                'metadata': {'priority': 1, 'type': 'test'},
                'tags': ['bulk', 'first'],
            },
            {
                'thread_id': atomic_thread,
                'source': 'agent',
                'text': 'Second bulk entry',
                'metadata': {'priority': 2, 'type': 'test'},
                'tags': ['bulk', 'second'],
            },
            {
                'thread_id': atomic_thread,
                'source': 'user',
                'text': 'Third bulk entry',
                'tags': ['bulk', 'third'],
            },
        ]
        raw_atomic = await self.client.call_tool(
            'store_context_batch',
            {'entries': atomic_payload, 'atomic': True},
        )
        result_data = self._extract_content(raw_atomic)

        if not result_data.get('success'):
            return _fail(f'Atomic batch store failed: {result_data}')

        if result_data.get('total') != 3 or result_data.get('succeeded') != 3:
            return _fail(
                f"Expected 3 stored, got {result_data.get('succeeded')}/{result_data.get('total')}",
            )

        item_results = result_data.get('results', [])
        every_item_has_id = all(item.get('context_id') for item in item_results)
        if len(item_results) != 3 or not every_item_has_id:
            return _fail('Missing context_ids in results')

        # Step 2: the entries must be visible through search.
        raw_search = await self.client.call_tool(
            'search_context',
            {'thread_id': atomic_thread},
        )
        search_data = self._extract_content(raw_search)

        if len(search_data.get('results', [])) != 3:
            return _fail(
                f"Expected 3 entries, found {len(search_data.get('results', []))}",
            )

        # Step 3: non-atomic store of two entries.
        loose_thread = f'{self.test_thread_id}_bulk_nonatomic'
        loose_payload = [
            {
                'thread_id': loose_thread,
                'source': 'agent',
                'text': 'Non-atomic entry 1',
            },
            {
                'thread_id': loose_thread,
                'source': 'user',
                'text': 'Non-atomic entry 2',
                'metadata': {'processed': True},
            },
        ]
        raw_loose = await self.client.call_tool(
            'store_context_batch',
            {'entries': loose_payload, 'atomic': False},
        )
        non_atomic_data = self._extract_content(raw_loose)

        if not (non_atomic_data.get('success') and non_atomic_data.get('succeeded') == 2):
            return _fail(f'Non-atomic batch store failed: {non_atomic_data}')

        stored_count = result_data.get('succeeded', 0) + non_atomic_data.get('succeeded', 0)
        self.test_results.append((test_name, True, f'Stored {stored_count} entries in batch'))
        return True

    except Exception as e:
        return _fail(f'Exception: {e}')
+ """ + test_name = 'update_context_batch' + assert self.client is not None # Type guard for Pyright + try: + # Create a separate thread for bulk update tests + bulk_update_thread = f'{self.test_thread_id}_bulk_update' + + # First, create entries to update + setup_entries = [ + { + 'thread_id': bulk_update_thread, + 'source': 'user', + 'text': 'Original text 1', + 'metadata': {'status': 'draft', 'version': 1}, + 'tags': ['original'], + }, + { + 'thread_id': bulk_update_thread, + 'source': 'agent', + 'text': 'Original text 2', + 'metadata': {'status': 'pending', 'version': 1}, + 'tags': ['original'], + }, + { + 'thread_id': bulk_update_thread, + 'source': 'user', + 'text': 'Original text 3', + 'tags': ['original'], + }, + ] + + setup_result = await self.client.call_tool( + 'store_context_batch', + {'entries': setup_entries, 'atomic': True}, + ) + + setup_data = self._extract_content(setup_result) + + if not setup_data.get('success') or setup_data.get('succeeded') != 3: + self.test_results.append((test_name, False, f'Failed to setup test entries: {setup_data}')) + return False + + # Get the context IDs + context_ids = [r['context_id'] for r in setup_data['results']] + + # Test 1: Batch update text for multiple entries + updates = [ + {'context_id': context_ids[0], 'text': 'Updated text 1'}, + {'context_id': context_ids[1], 'text': 'Updated text 2'}, + {'context_id': context_ids[2], 'text': 'Updated text 3'}, + ] + + update_result = await self.client.call_tool( + 'update_context_batch', + {'updates': updates, 'atomic': True}, + ) + + update_data = self._extract_content(update_result) + + if not update_data.get('success') or update_data.get('succeeded') != 3: + self.test_results.append((test_name, False, f'Batch text update failed: {update_data}')) + return False + + # Verify updated_fields contains text_content + for item in update_data.get('results', []): + if 'text_content' not in item.get('updated_fields', []): + self.test_results.append((test_name, False, 
async def test_delete_context_batch(self) -> bool:
    """Run delete_context_batch through the real server with three kinds of criteria.

    Seeds eight entries across three dedicated threads, then deletes by
    explicit IDs, by thread, and by thread combined with a source filter,
    checking each step through the read tools. The outcome is recorded in
    ``self.test_results``.

    Returns:
        bool: True if test passed.
    """
    test_name = 'delete_context_batch'
    assert self.client is not None  # Type guard for Pyright

    def _fail(reason: str) -> bool:
        # Record a failure for this test and hand back the failing return value.
        self.test_results.append((test_name, False, reason))
        return False

    try:
        # One thread per scenario so the expected counts stay independent.
        ids_thread = f'{self.test_thread_id}_bulk_del_ids'
        whole_thread = f'{self.test_thread_id}_bulk_del_thread'
        mixed_thread = f'{self.test_thread_id}_bulk_del_combined'

        # Order matters: the first two rows are the ones deleted by ID below.
        seed_plan = [
            (ids_thread, 'user', 'Delete by ID 1'),
            (ids_thread, 'agent', 'Delete by ID 2'),
            (whole_thread, 'user', 'Delete by thread 1'),
            (whole_thread, 'agent', 'Delete by thread 2'),
            (whole_thread, 'user', 'Delete by thread 3'),
            (mixed_thread, 'user', 'Combined user 1'),
            (mixed_thread, 'user', 'Combined user 2'),
            (mixed_thread, 'agent', 'Combined agent 1'),
        ]
        seed_entries = [
            {'thread_id': thread, 'source': origin, 'text': body}
            for thread, origin, body in seed_plan
        ]
        raw_seed = await self.client.call_tool(
            'store_context_batch',
            {'entries': seed_entries, 'atomic': True},
        )
        setup_data = self._extract_content(raw_seed)

        if not (setup_data.get('success') and setup_data.get('succeeded') == 8):
            return _fail(f'Failed to setup delete test entries: {setup_data}')

        first_row = setup_data['results'][0]
        second_row = setup_data['results'][1]
        ids_to_delete = [first_row['context_id'], second_row['context_id']]

        # Scenario 1: delete by context_ids.
        raw_by_ids = await self.client.call_tool(
            'delete_context_batch',
            {'context_ids': ids_to_delete},
        )
        delete_by_ids_data = self._extract_content(raw_by_ids)

        if not (delete_by_ids_data.get('success') and delete_by_ids_data.get('deleted_count') == 2):
            return _fail(
                f"Delete by IDs failed: expected 2, got {delete_by_ids_data.get('deleted_count')}",
            )

        if 'context_ids' not in str(delete_by_ids_data.get('criteria_used', [])):
            return _fail('context_ids not in criteria_used')

        raw_lookup = await self.client.call_tool(
            'get_context_by_ids',
            {'context_ids': ids_to_delete},
        )
        lookup_data = self._extract_content(raw_lookup)

        if len(lookup_data.get('results', [])) > 0:
            return _fail('Entries not deleted by IDs')

        # Scenario 2: delete a whole thread.
        raw_by_thread = await self.client.call_tool(
            'delete_context_batch',
            {'thread_ids': [whole_thread]},
        )
        delete_by_thread_data = self._extract_content(raw_by_thread)

        if not (delete_by_thread_data.get('success') and delete_by_thread_data.get('deleted_count') == 3):
            deleted = delete_by_thread_data.get('deleted_count')
            return _fail(f'Delete by thread failed: expected 3, got {deleted}')

        raw_thread_check = await self.client.call_tool(
            'search_context',
            {'thread_id': whole_thread},
        )
        thread_check_data = self._extract_content(raw_thread_check)

        if len(thread_check_data.get('results', [])) > 0:
            return _fail('Thread entries not deleted')

        # Scenario 3: thread narrowed by source; only the agent entry should survive.
        raw_combined = await self.client.call_tool(
            'delete_context_batch',
            {'thread_ids': [mixed_thread], 'source': 'user'},
        )
        delete_combined_data = self._extract_content(raw_combined)

        if not (delete_combined_data.get('success') and delete_combined_data.get('deleted_count') == 2):
            return _fail(
                f"Combined delete failed: expected 2, got {delete_combined_data.get('deleted_count')}",
            )

        raw_combined_check = await self.client.call_tool(
            'search_context',
            {'thread_id': mixed_thread},
        )
        combined_check_data = self._extract_content(raw_combined_check)

        remaining = combined_check_data.get('results', [])
        if len(remaining) != 1 or remaining[0].get('source') != 'agent':
            return _fail('Combined criteria did not filter correctly')

        total_deleted = sum(
            step.get('deleted_count', 0)
            for step in (delete_by_ids_data, delete_by_thread_data, delete_combined_data)
        )
        self.test_results.append((test_name, True, f'Deleted {total_deleted} entries with various criteria'))
        return True

    except Exception as e:
        return _fail(f'Exception: {e}')