Skip to content

Commit

Permalink
Merge pull request #1962 from SciPhi-AI/feature/add-get-collection-by…
Browse files Browse the repository at this point in the history
…-name-and-uniq-check

add uniq check and get collection by name
  • Loading branch information
emrgnt-cmplxty authored Feb 11, 2025
2 parents 10d2b6b + a371429 commit 1f582da
Show file tree
Hide file tree
Showing 12 changed files with 247 additions and 9 deletions.
2 changes: 1 addition & 1 deletion js/sdk/package.json
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
{
"name": "r2r-js",
"version": "0.4.25",
"version": "0.4.26",
"description": "",
"main": "dist/index.js",
"browser": "dist/index.browser.js",
Expand Down
13 changes: 13 additions & 0 deletions js/sdk/src/v3/clients/collections.ts
Original file line number Diff line number Diff line change
Expand Up @@ -322,4 +322,17 @@ export class CollectionsClient {
downloadBlob(blob, options.filename);
}
}

/**
* Retrieve a collection by its name.
* @param name The name of the collection to retrieve.
* @returns A promise that resolves with the collection details.
*/
async retrieveByName(options: { name: string; ownerId?: string }): Promise<WrappedCollectionResponse> {
const queryParams: Record<string, any> = {};
if (options.ownerId) {
queryParams.owner_id = options.ownerId;
}
return this.client.makeRequest("GET", `collections/name/${options.name}`, { params: queryParams });
}
}
47 changes: 47 additions & 0 deletions py/core/main/api/v3/collections_router.py
Original file line number Diff line number Diff line change
Expand Up @@ -1233,3 +1233,50 @@ async def extract(
"message": "Graph created successfully.",
"task_id": None,
}

@self.router.get(
"/collections/name/{collection_name}",
summary="Get a collection by name",
dependencies=[Depends(self.rate_limit_dependency)],
)
@self.base_endpoint
async def get_collection_by_name(
collection_name: str = Path(
..., description="The name of the collection"
),
owner_id: Optional[UUID] = Query(
None,
description="(Superuser only) Specify the owner_id to retrieve a collection by name",
),
auth_user=Depends(self.providers.auth.auth_wrapper()),
) -> WrappedCollectionResponse:
"""
Retrieve a collection by its (owner_id, name) combination.
The authenticated user can only fetch collections they own,
or, if superuser, from anyone.
"""
if auth_user.is_superuser:
if not owner_id:
owner_id = auth_user.id
else:
owner_id = auth_user.id

# If not superuser, fetch by (owner_id, name). Otherwise, maybe pass `owner_id=None`.
# Decide on the logic for superusers.
if not owner_id: # is_superuser
# If you want superusers to do /collections/name/<string>?owner_id=...
# just parse it from the query. For now, let's say it's not implemented.
raise R2RException(
"Superuser must specify an owner_id to fetch by name.", 400
)

collection = await self.providers.database.collections_handler.get_collection_by_name(
owner_id, collection_name
)
if not collection:
raise R2RException("Collection not found.", 404)

# Now, authorize the 'view' action just in case:
# e.g. await authorize_collection_action(auth_user, collection.id, CollectionAction.VIEW, self.services)

return collection
1 change: 1 addition & 0 deletions py/core/main/api/v3/retrieval_router.py
Original file line number Diff line number Diff line change
Expand Up @@ -652,6 +652,7 @@ async def agent_app(
The agent uses both vector search and knowledge graph capabilities to find and synthesize
information, providing detailed, factual responses with proper attribution to source documents.
"""
print("in app..")
if "model" not in rag_generation_config.__fields_set__:
rag_generation_config.model = self.config.app.quality_llm

Expand Down
91 changes: 89 additions & 2 deletions py/core/providers/database/collections.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,8 @@ def __init__(
super().__init__(project_name, connection_manager)

async def create_tables(self) -> None:
query = f"""
# 1. Create the table if it does not exist.
create_table_query = f"""
CREATE TABLE IF NOT EXISTS {self._get_table_name(PostgresCollectionsHandler.TABLE_NAME)} (
id UUID PRIMARY KEY DEFAULT uuid_generate_v4(),
owner_id UUID,
Expand All @@ -55,7 +56,57 @@ async def create_tables(self) -> None:
document_count INT DEFAULT 0
);
"""
await self.connection_manager.execute_query(query)
await self.connection_manager.execute_query(create_table_query)

# 2. Check for duplicate rows that would violate the uniqueness constraint.
check_duplicates_query = f"""
SELECT owner_id, name, COUNT(*) AS cnt
FROM {self._get_table_name(PostgresCollectionsHandler.TABLE_NAME)}
GROUP BY owner_id, name
HAVING COUNT(*) > 1
"""
duplicates = await self.connection_manager.fetch_query(
check_duplicates_query
)
if duplicates:
logger.warning(
"Cannot add unique constraint (owner_id, name) because duplicates exist. "
"Please resolve duplicates first. Found duplicates: %s",
duplicates,
)
return # or raise an exception, depending on your use case

# 3. Parse the qualified table name into schema and table.
qualified_table = self._get_table_name(
PostgresCollectionsHandler.TABLE_NAME
)
if "." in qualified_table:
schema, table = qualified_table.split(".", 1)
else:
schema = "public"
table = qualified_table

# 4. Add the unique constraint if it does not already exist.
alter_table_constraint = f"""
DO $$
BEGIN
IF NOT EXISTS (
SELECT 1
FROM pg_constraint c
JOIN pg_class t ON c.conrelid = t.oid
JOIN pg_namespace n ON n.oid = t.relnamespace
WHERE t.relname = '{table}'
AND n.nspname = '{schema}'
AND c.conname = 'unique_owner_collection_name'
) THEN
ALTER TABLE {qualified_table}
ADD CONSTRAINT unique_owner_collection_name
UNIQUE (owner_id, name);
END IF;
END;
$$;
"""
await self.connection_manager.execute_query(alter_table_constraint)

async def collection_exists(self, collection_id: UUID) -> bool:
"""Check if a collection exists."""
Expand Down Expand Up @@ -621,3 +672,39 @@ async def export_to_csv(
status_code=500,
detail=f"Failed to export data: {str(e)}",
) from e

async def get_collection_by_name(
self, owner_id: UUID, name: str
) -> Optional[CollectionResponse]:
"""
Fetch a collection by owner_id + name combination.
Return None if not found.
"""
query = f"""
SELECT
id, owner_id, name, description, graph_sync_status,
graph_cluster_status, created_at, updated_at, user_count, document_count
FROM {self._get_table_name(PostgresCollectionsHandler.TABLE_NAME)}
WHERE owner_id = $1 AND name = $2
LIMIT 1
"""
result = await self.connection_manager.fetchrow_query(
query, [owner_id, name]
)
if not result:
raise R2RException(
status_code=404,
message="No collection found with the specified name",
)
return CollectionResponse(
id=result["id"],
owner_id=result["owner_id"],
name=result["name"],
description=result["description"],
graph_sync_status=result["graph_sync_status"],
graph_cluster_status=result["graph_cluster_status"],
created_at=result["created_at"],
updated_at=result["updated_at"],
user_count=result["user_count"],
document_count=result["document_count"],
)
4 changes: 2 additions & 2 deletions py/r2r/r2r.toml
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,8 @@ audio_lm = "openai/whisper-1"
[agent]
agent_static_prompt = "static_rag_agent"
agent_dynamic_prompt = "dynamic_rag_agent"
tools = ["local_search", "content", "web_search"]
# tools = ["local_search", "content"]
# tools = ["local_search", "content", "web_search"]
tools = ["local_search", "content"]

[auth]
provider = "r2r"
Expand Down
28 changes: 28 additions & 0 deletions py/sdk/asnyc_methods/collections.py
Original file line number Diff line number Diff line change
Expand Up @@ -326,3 +326,31 @@ async def extract(
)

return WrappedGenericMessageResponse(**response_dict)

async def retrieve_by_name(
self, name: str, owner_id: Optional[str] = None
) -> WrappedCollectionResponse:
"""
Retrieve a collection by its name.
For non-superusers, the backend will use the authenticated user's ID.
For superusers, the caller must supply an owner_id to restrict the search.
Args:
name (str): The name of the collection to retrieve.
owner_id (Optional[str]): The owner ID to restrict the search. Required for superusers.
Returns:
WrappedCollectionResponse
"""
query_params: dict[str, Any] = {}
if owner_id is not None:
query_params["owner_id"] = owner_id

response_dict = await self.client._make_request(
"GET",
f"collections/name/{name}",
params=query_params,
version="v3",
)
return WrappedCollectionResponse(**response_dict)
8 changes: 6 additions & 2 deletions py/sdk/asnyc_methods/retrieval.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,7 +197,9 @@ async def agent(
"search_settings": search_settings,
"task_prompt_override": task_prompt_override,
"include_title_if_available": include_title_if_available,
"conversation_id": conversation_id,
"conversation_id": (
str(conversation_id) if conversation_id else None
),
"tools": tools,
"max_tool_context_length": max_tool_context_length,
"use_system_context": use_system_context,
Expand Down Expand Up @@ -256,7 +258,9 @@ async def reasoning_agent(

data: dict[str, Any] = {
"rag_generation_config": rag_generation_config or {},
"conversation_id": conversation_id,
"conversation_id": (
str(conversation_id) if conversation_id else None
),
"tools": tools,
"max_tool_context_length": max_tool_context_length,
}
Expand Down
2 changes: 2 additions & 0 deletions py/sdk/sync_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,8 @@ def _make_request(

try:
response = self.client.request(method, url, **request_args)
print("response =", response)

self._handle_response(response)

if "application/json" in response.headers.get("Content-Type", ""):
Expand Down
28 changes: 28 additions & 0 deletions py/sdk/sync_methods/collections.py
Original file line number Diff line number Diff line change
Expand Up @@ -326,3 +326,31 @@ def extract(
)

return WrappedGenericMessageResponse(**response_dict)

def retrieve_by_name(
self, name: str, owner_id: Optional[str] = None
) -> WrappedCollectionResponse:
"""
Retrieve a collection by its name.
For non-superusers, the backend will use the authenticated user's ID.
For superusers, the caller must supply an owner_id to restrict the search.
Args:
name (str): The name of the collection to retrieve.
owner_id (Optional[str]): The owner ID to restrict the search. Required for superusers.
Returns:
WrappedCollectionResponse
"""
query_params: dict[str, Any] = {}
if owner_id is not None:
query_params["owner_id"] = owner_id

response_dict = self.client._make_request(
"GET",
f"collections/name/{name}",
params=query_params,
version="v3",
)
return WrappedCollectionResponse(**response_dict)
8 changes: 6 additions & 2 deletions py/sdk/sync_methods/retrieval.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,7 +198,9 @@ def agent(
"search_settings": search_settings,
"task_prompt_override": task_prompt_override,
"include_title_if_available": include_title_if_available,
"conversation_id": str(conversation_id),
"conversation_id": (
str(conversation_id) if conversation_id else None
),
"tools": tools,
"max_tool_context_length": max_tool_context_length,
"use_extended_prompt": use_extended_prompt,
Expand Down Expand Up @@ -257,7 +259,9 @@ def reasoning_agent(

data: dict[str, Any] = {
"rag_generation_config": rag_generation_config or {},
"conversation_id": conversation_id,
"conversation_id": (
str(conversation_id) if conversation_id else None
),
"tools": tools,
"max_tool_context_length": max_tool_context_length,
}
Expand Down
24 changes: 24 additions & 0 deletions py/tests/integration/test_collections.py
Original file line number Diff line number Diff line change
Expand Up @@ -242,3 +242,27 @@ def test_delete_non_existent_collection(client: R2RClient):
assert (
exc_info.value.status_code == 404
), "Expected 404 when deleting non-existent collection"


def test_retrieve_collection_by_name(client: R2RClient):
# Generate a unique collection name
unique_name = f"TestRetrieveByName-{uuid.uuid4()}"

# Create a collection with the unique name
created_resp = client.collections.create(
name=unique_name, description="Collection for retrieval by name test"
)
created = created_resp.results
assert (
created.id is not None
), "Creation did not return a valid collection ID"

# Retrieve the collection by its name
retrieved_resp = client.collections.retrieve_by_name(unique_name)
retrieved = retrieved_resp.results
assert (
retrieved.id == created.id
), "Retrieved collection does not match the created collection"

# Cleanup: Delete the created collection
client.collections.delete(created.id)

0 comments on commit 1f582da

Please sign in to comment.