Skip to content

Commit 1f582da

Browse files
Merge pull request #1962 from SciPhi-AI/feature/add-get-collection-by-name-and-uniq-check
add uniq check and get collection by name
2 parents 10d2b6b + a371429 commit 1f582da

File tree

12 files changed

+247
-9
lines changed

12 files changed

+247
-9
lines changed

js/sdk/package.json

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
{
22
"name": "r2r-js",
3-
"version": "0.4.25",
3+
"version": "0.4.26",
44
"description": "",
55
"main": "dist/index.js",
66
"browser": "dist/index.browser.js",

js/sdk/src/v3/clients/collections.ts

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -322,4 +322,17 @@ export class CollectionsClient {
322322
downloadBlob(blob, options.filename);
323323
}
324324
}
325+
326+
/**
 * Retrieve a collection by its name.
 *
 * @param options.name The name of the collection to retrieve.
 * @param options.ownerId Optional owner ID; when provided it is sent as the
 *   `owner_id` query parameter.
 * @returns A promise that resolves with the collection details.
 */
async retrieveByName(options: {
  name: string;
  ownerId?: string;
}): Promise<WrappedCollectionResponse> {
  const queryParams: Record<string, any> = {};
  if (options.ownerId) {
    queryParams.owner_id = options.ownerId;
  }

  // Percent-encode the name so characters such as spaces, "?" or "#" stay
  // inside the path segment instead of being parsed as URL syntax.
  const encodedName = encodeURIComponent(options.name);

  return this.client.makeRequest("GET", `collections/name/${encodedName}`, {
    params: queryParams,
  });
}
325338
}

py/core/main/api/v3/collections_router.py

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1233,3 +1233,50 @@ async def extract(
12331233
"message": "Graph created successfully.",
12341234
"task_id": None,
12351235
}
1236+
1237+
@self.router.get(
    "/collections/name/{collection_name}",
    summary="Get a collection by name",
    dependencies=[Depends(self.rate_limit_dependency)],
)
@self.base_endpoint
async def get_collection_by_name(
    collection_name: str = Path(
        ..., description="The name of the collection"
    ),
    owner_id: Optional[UUID] = Query(
        None,
        description="(Superuser only) Specify the owner_id to retrieve a collection by name",
    ),
    auth_user=Depends(self.providers.auth.auth_wrapper()),
) -> WrappedCollectionResponse:
    """
    Retrieve a collection by its (owner_id, name) combination.

    Non-superusers always look up collections under their own user ID;
    any `owner_id` they pass is ignored. Superusers may pass `owner_id`
    to look up another user's collection, and fall back to their own
    ID when it is omitted.

    Raises:
        R2RException: 404 if no collection matches (owner_id, name).
    """
    # Only a superuser-supplied owner_id is honoured; every other case
    # resolves to the authenticated user's own ID.
    if not auth_user.is_superuser or owner_id is None:
        owner_id = auth_user.id

    collection = await self.providers.database.collections_handler.get_collection_by_name(
        owner_id, collection_name
    )
    if not collection:
        raise R2RException("Collection not found.", 404)

    # TODO: consider an explicit 'view' authorization check on
    # collection.id here, in addition to the ownership scoping above.
    return collection

py/core/main/api/v3/retrieval_router.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -652,6 +652,7 @@ async def agent_app(
652652
The agent uses both vector search and knowledge graph capabilities to find and synthesize
653653
information, providing detailed, factual responses with proper attribution to source documents.
654654
"""
655+
print("in app..")
655656
if "model" not in rag_generation_config.__fields_set__:
656657
rag_generation_config.model = self.config.app.quality_llm
657658

py/core/providers/database/collections.py

Lines changed: 89 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,8 @@ def __init__(
4141
super().__init__(project_name, connection_manager)
4242

4343
async def create_tables(self) -> None:
44-
query = f"""
44+
# 1. Create the table if it does not exist.
45+
create_table_query = f"""
4546
CREATE TABLE IF NOT EXISTS {self._get_table_name(PostgresCollectionsHandler.TABLE_NAME)} (
4647
id UUID PRIMARY KEY DEFAULT uuid_generate_v4(),
4748
owner_id UUID,
@@ -55,7 +56,57 @@ async def create_tables(self) -> None:
5556
document_count INT DEFAULT 0
5657
);
5758
"""
58-
await self.connection_manager.execute_query(query)
59+
await self.connection_manager.execute_query(create_table_query)
60+
61+
# 2. Check for duplicate rows that would violate the uniqueness constraint.
62+
check_duplicates_query = f"""
63+
SELECT owner_id, name, COUNT(*) AS cnt
64+
FROM {self._get_table_name(PostgresCollectionsHandler.TABLE_NAME)}
65+
GROUP BY owner_id, name
66+
HAVING COUNT(*) > 1
67+
"""
68+
duplicates = await self.connection_manager.fetch_query(
69+
check_duplicates_query
70+
)
71+
if duplicates:
72+
logger.warning(
73+
"Cannot add unique constraint (owner_id, name) because duplicates exist. "
74+
"Please resolve duplicates first. Found duplicates: %s",
75+
duplicates,
76+
)
77+
return # or raise an exception, depending on your use case
78+
79+
# 3. Parse the qualified table name into schema and table.
80+
qualified_table = self._get_table_name(
81+
PostgresCollectionsHandler.TABLE_NAME
82+
)
83+
if "." in qualified_table:
84+
schema, table = qualified_table.split(".", 1)
85+
else:
86+
schema = "public"
87+
table = qualified_table
88+
89+
# 4. Add the unique constraint if it does not already exist.
90+
alter_table_constraint = f"""
91+
DO $$
92+
BEGIN
93+
IF NOT EXISTS (
94+
SELECT 1
95+
FROM pg_constraint c
96+
JOIN pg_class t ON c.conrelid = t.oid
97+
JOIN pg_namespace n ON n.oid = t.relnamespace
98+
WHERE t.relname = '{table}'
99+
AND n.nspname = '{schema}'
100+
AND c.conname = 'unique_owner_collection_name'
101+
) THEN
102+
ALTER TABLE {qualified_table}
103+
ADD CONSTRAINT unique_owner_collection_name
104+
UNIQUE (owner_id, name);
105+
END IF;
106+
END;
107+
$$;
108+
"""
109+
await self.connection_manager.execute_query(alter_table_constraint)
59110

60111
async def collection_exists(self, collection_id: UUID) -> bool:
61112
"""Check if a collection exists."""
@@ -621,3 +672,39 @@ async def export_to_csv(
621672
status_code=500,
622673
detail=f"Failed to export data: {str(e)}",
623674
) from e
675+
676+
async def get_collection_by_name(
    self, owner_id: UUID, name: str
) -> Optional[CollectionResponse]:
    """
    Fetch a collection by its (owner_id, name) combination.

    Args:
        owner_id: ID of the user who owns the collection.
        name: Name of the collection to look up.

    Returns:
        The matching CollectionResponse, or None if no row matches.
        Callers are responsible for turning a miss into an error.
    """
    query = f"""
        SELECT
            id, owner_id, name, description, graph_sync_status,
            graph_cluster_status, created_at, updated_at, user_count, document_count
        FROM {self._get_table_name(PostgresCollectionsHandler.TABLE_NAME)}
        WHERE owner_id = $1 AND name = $2
        LIMIT 1
    """
    result = await self.connection_manager.fetchrow_query(
        query, [owner_id, name]
    )
    if not result:
        # Honour the documented Optional contract instead of raising.
        return None

    return CollectionResponse(
        id=result["id"],
        owner_id=result["owner_id"],
        name=result["name"],
        description=result["description"],
        graph_sync_status=result["graph_sync_status"],
        graph_cluster_status=result["graph_cluster_status"],
        created_at=result["created_at"],
        updated_at=result["updated_at"],
        user_count=result["user_count"],
        document_count=result["document_count"],
    )

py/r2r/r2r.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,8 +23,8 @@ audio_lm = "openai/whisper-1"
2323
[agent]
2424
agent_static_prompt = "static_rag_agent"
2525
agent_dynamic_prompt = "dynamic_rag_agent"
26-
tools = ["local_search", "content", "web_search"]
27-
# tools = ["local_search", "content"]
26+
# tools = ["local_search", "content", "web_search"]
27+
tools = ["local_search", "content"]
2828

2929
[auth]
3030
provider = "r2r"

py/sdk/asnyc_methods/collections.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -326,3 +326,31 @@ async def extract(
326326
)
327327

328328
return WrappedGenericMessageResponse(**response_dict)
329+
330+
async def retrieve_by_name(
    self, name: str, owner_id: Optional[str] = None
) -> WrappedCollectionResponse:
    """
    Retrieve a collection by its name.

    For non-superusers, the backend uses the authenticated user's ID.
    Superusers may supply an owner_id to look up another user's
    collection; when omitted, the backend falls back to the caller's ID.

    Args:
        name (str): The name of the collection to retrieve.
        owner_id (Optional[str]): The owner ID to restrict the search.
            Only sent when not None.

    Returns:
        WrappedCollectionResponse
    """
    from urllib.parse import quote

    query_params: dict[str, Any] = {}
    if owner_id is not None:
        query_params["owner_id"] = owner_id

    # Percent-encode the name so characters such as spaces, "?" or "#"
    # stay inside the path segment instead of being parsed as URL syntax.
    encoded_name = quote(name, safe="")

    response_dict = await self.client._make_request(
        "GET",
        f"collections/name/{encoded_name}",
        params=query_params,
        version="v3",
    )
    return WrappedCollectionResponse(**response_dict)

py/sdk/asnyc_methods/retrieval.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -197,7 +197,9 @@ async def agent(
197197
"search_settings": search_settings,
198198
"task_prompt_override": task_prompt_override,
199199
"include_title_if_available": include_title_if_available,
200-
"conversation_id": conversation_id,
200+
"conversation_id": (
201+
str(conversation_id) if conversation_id else None
202+
),
201203
"tools": tools,
202204
"max_tool_context_length": max_tool_context_length,
203205
"use_system_context": use_system_context,
@@ -256,7 +258,9 @@ async def reasoning_agent(
256258

257259
data: dict[str, Any] = {
258260
"rag_generation_config": rag_generation_config or {},
259-
"conversation_id": conversation_id,
261+
"conversation_id": (
262+
str(conversation_id) if conversation_id else None
263+
),
260264
"tools": tools,
261265
"max_tool_context_length": max_tool_context_length,
262266
}

py/sdk/sync_client.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,8 @@ def _make_request(
6161

6262
try:
6363
response = self.client.request(method, url, **request_args)
64+
print("response =", response)
65+
6466
self._handle_response(response)
6567

6668
if "application/json" in response.headers.get("Content-Type", ""):

py/sdk/sync_methods/collections.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -326,3 +326,31 @@ def extract(
326326
)
327327

328328
return WrappedGenericMessageResponse(**response_dict)
329+
330+
def retrieve_by_name(
    self, name: str, owner_id: Optional[str] = None
) -> WrappedCollectionResponse:
    """
    Retrieve a collection by its name.

    For non-superusers, the backend uses the authenticated user's ID.
    Superusers may supply an owner_id to look up another user's
    collection; when omitted, the backend falls back to the caller's ID.

    Args:
        name (str): The name of the collection to retrieve.
        owner_id (Optional[str]): The owner ID to restrict the search.
            Only sent when not None.

    Returns:
        WrappedCollectionResponse
    """
    from urllib.parse import quote

    query_params: dict[str, Any] = {}
    if owner_id is not None:
        query_params["owner_id"] = owner_id

    # Percent-encode the name so characters such as spaces, "?" or "#"
    # stay inside the path segment instead of being parsed as URL syntax.
    encoded_name = quote(name, safe="")

    response_dict = self.client._make_request(
        "GET",
        f"collections/name/{encoded_name}",
        params=query_params,
        version="v3",
    )
    return WrappedCollectionResponse(**response_dict)

py/sdk/sync_methods/retrieval.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -198,7 +198,9 @@ def agent(
198198
"search_settings": search_settings,
199199
"task_prompt_override": task_prompt_override,
200200
"include_title_if_available": include_title_if_available,
201-
"conversation_id": str(conversation_id),
201+
"conversation_id": (
202+
str(conversation_id) if conversation_id else None
203+
),
202204
"tools": tools,
203205
"max_tool_context_length": max_tool_context_length,
204206
"use_extended_prompt": use_extended_prompt,
@@ -257,7 +259,9 @@ def reasoning_agent(
257259

258260
data: dict[str, Any] = {
259261
"rag_generation_config": rag_generation_config or {},
260-
"conversation_id": conversation_id,
262+
"conversation_id": (
263+
str(conversation_id) if conversation_id else None
264+
),
261265
"tools": tools,
262266
"max_tool_context_length": max_tool_context_length,
263267
}

py/tests/integration/test_collections.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -242,3 +242,27 @@ def test_delete_non_existent_collection(client: R2RClient):
242242
assert (
243243
exc_info.value.status_code == 404
244244
), "Expected 404 when deleting non-existent collection"
245+
246+
247+
def test_retrieve_collection_by_name(client: R2RClient):
    """Create a uniquely named collection and fetch it back by name."""
    # A UUID suffix keeps the name unique across test runs.
    unique_name = f"TestRetrieveByName-{uuid.uuid4()}"

    created_resp = client.collections.create(
        name=unique_name, description="Collection for retrieval by name test"
    )
    created = created_resp.results
    assert (
        created.id is not None
    ), "Creation did not return a valid collection ID"

    try:
        # Retrieve the collection by its name
        retrieved_resp = client.collections.retrieve_by_name(unique_name)
        retrieved = retrieved_resp.results
        assert (
            retrieved.id == created.id
        ), "Retrieved collection does not match the created collection"
    finally:
        # Always clean up, even when retrieval or the assertion fails,
        # so a failing run does not leave the collection behind.
        client.collections.delete(created.id)

0 commit comments

Comments
 (0)