diff --git a/src/chroma_mcp/server.py b/src/chroma_mcp/server.py index b2f728f..dbe8ddf 100644 --- a/src/chroma_mcp/server.py +++ b/src/chroma_mcp/server.py @@ -1,4 +1,4 @@ -from typing import Dict, List, Optional, TypedDict +from typing import Dict, List, TypedDict from enum import Enum import chromadb from mcp.server.fastmcp import FastMCP @@ -145,8 +145,8 @@ def get_chroma_client(args=None): @mcp.tool() async def chroma_list_collections( - limit: Optional[int] = None, - offset: Optional[int] = None + limit: int = None, + offset: int = None ) -> List[str]: """List all collection names in the Chroma database with pagination support. @@ -177,16 +177,16 @@ async def chroma_list_collections( @mcp.tool() async def chroma_create_collection( collection_name: str, - embedding_function_name: Optional[str] = "default", - metadata: Optional[Dict] = None, - space: Optional[str] = None, - ef_construction: Optional[int] = None, - ef_search: Optional[int] = None, - max_neighbors: Optional[int] = None, - num_threads: Optional[int] = None, - batch_size: Optional[int] = None, - sync_threshold: Optional[int] = None, - resize_factor: Optional[float] = None, + embedding_function_name: str = "default", + metadata: Dict = None, + space: str = None, + ef_construction: int = None, + ef_search: int = None, + max_neighbors: int = None, + num_threads: int = None, + batch_size: int = None, + sync_threshold: int = None, + resize_factor: float = None, ) -> str: """Create a new Chroma collection with configurable HNSW parameters. @@ -305,13 +305,13 @@ async def chroma_get_collection_count(collection_name: str) -> int: @mcp.tool() async def chroma_modify_collection( collection_name: str, - new_name: Optional[str] = None, - new_metadata: Optional[Dict] = None, - ef_search: Optional[int] = None, - num_threads: Optional[int] = None, - batch_size: Optional[int] = None, - sync_threshold: Optional[int] = None, - resize_factor: Optional[float] = None, + new_name: str = None, + new_metadata: Dict = None, + ef_search: int = None, + num_threads: int = None, + batch_size: int = None, + sync_threshold: int = None, + resize_factor: float = None, ) -> str: """Modify a Chroma collection's name or metadata. @@ -377,35 +377,62 @@ async def chroma_delete_collection(collection_name: str) -> str: async def chroma_add_documents( collection_name: str, documents: List[str], - metadatas: Optional[List[Dict]] = None, - ids: Optional[List[str]] = None + ids: List[str], + metadatas: List[Dict] = None ) -> str: """Add documents to a Chroma collection. Args: collection_name: Name of the collection to add documents to documents: List of text documents to add + ids: List of IDs for the documents (required) metadatas: Optional list of metadata dictionaries for each document - ids: Optional list of IDs for the documents """ if not documents: raise ValueError("The 'documents' list cannot be empty.") + + if not ids: + raise ValueError("The 'ids' list is required and cannot be empty.") + + # Check if there are empty strings in the ids list + if any(not id.strip() for id in ids): + raise ValueError("IDs cannot be empty strings.") + + if len(ids) != len(documents): + raise ValueError(f"Number of ids ({len(ids)}) must match number of documents ({len(documents)}).") client = get_chroma_client() try: collection = client.get_or_create_collection(collection_name) - # Generate sequential IDs if none provided - if ids is None: - ids = [str(i) for i in range(len(documents))] + # Check for duplicate IDs + existing_ids = collection.get(include=[])["ids"] + duplicate_ids = [id for id in ids if id in existing_ids] + + if duplicate_ids: + raise ValueError( + f"The following IDs already exist in collection '{collection_name}': {duplicate_ids}. " + f"Use 'chroma_update_documents' to update existing documents." + ) - collection.add( + result = collection.add( documents=documents, metadatas=metadatas, ids=ids ) - return f"Successfully added {len(documents)} documents to collection {collection_name}" + # Check the return value + if result and isinstance(result, dict): + # If the return value is a dictionary, it may contain success information + if 'success' in result and not result['success']: + raise Exception(f"Failed to add documents: {result.get('error', 'Unknown error')}") + + # If the return value contains the actual number added + if 'count' in result: + return f"Successfully added {result['count']} documents to collection {collection_name}" + + # Default return + return f"Successfully added {len(documents)} documents to collection {collection_name}, result is {result}" except Exception as e: raise Exception(f"Failed to add documents to collection '{collection_name}': {str(e)}") from e @@ -414,8 +441,8 @@ async def chroma_query_documents( collection_name: str, query_texts: List[str], n_results: int = 5, - where: Optional[Dict] = None, - where_document: Optional[Dict] = None, + where: Dict = None, + where_document: Dict = None, include: List[str] = ["documents", "metadatas", "distances"] ) -> Dict: """Query documents from a Chroma collection with advanced filtering. @@ -452,12 +479,12 @@ async def chroma_query_documents( @mcp.tool() async def chroma_get_documents( collection_name: str, - ids: Optional[List[str]] = None, - where: Optional[Dict] = None, - where_document: Optional[Dict] = None, + ids: List[str] = None, + where: Dict = None, + where_document: Dict = None, include: List[str] = ["documents", "metadatas"], - limit: Optional[int] = None, - offset: Optional[int] = None + limit: int = None, + offset: int = None ) -> Dict: """Get documents from a Chroma collection with optional filtering. @@ -496,9 +523,9 @@ async def chroma_get_documents( async def chroma_update_documents( collection_name: str, ids: List[str], - embeddings: Optional[List[List[float]]] = None, - metadatas: Optional[List[Dict]] = None, - documents: Optional[List[str]] = None + embeddings: List[List[float]] = None, + metadatas: List[Dict] = None, + documents: List[str] = None ) -> str: """Update documents in a Chroma collection.