Skip to content

Commit e1b8fcd

Browse files
authored
feat/use latest astradb client version (#457)
* use latest astradb client version * update _collection_exists method * fix everything that broke with version update * revert bae chang * revert change
1 parent 41d4d87 commit e1b8fcd

File tree

8 files changed

+77
-56
lines changed

8 files changed

+77
-56
lines changed

CHANGELOG.md

+6
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,9 @@
1+
## 1.0.2
2+
3+
### Features
4+
5+
* **Update astra source connector to use new astrapy client**
6+
17
## 1.0.1
28

39
### Features

pyproject.toml

-1
Original file line numberDiff line numberDiff line change
@@ -172,7 +172,6 @@ ci = [
172172
"lancedb<=0.15.0",
173173
# TODO: versions higher than this are missing the macos wheel
174174
"pykx==2.5.3",
175-
"astrapy<2.0.0"
176175
]
177176

178177
[project.scripts]

requirements/connectors/astradb.txt

+1-1
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
astrapy
1+
astrapy>2.0.0

test/integration/connectors/databricks/test_volumes_native.py

+1-4
Original file line numberDiff line numberDiff line change
@@ -260,10 +260,7 @@ async def test_volumes_native_destination(upload_file: Path):
260260
),
261261
)
262262
uploader.precheck()
263-
if uploader.is_async():
264-
await uploader.run_async(path=upload_file, file_data=file_data)
265-
else:
266-
uploader.run(path=upload_file, file_data=file_data)
263+
uploader.run(path=upload_file, file_data=file_data)
267264

268265
validate_upload(
269266
client=workspace_client,

test/integration/connectors/test_astradb.py

+10-4
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
from _pytest.fixtures import TopRequest
1010
from astrapy import Collection
1111
from astrapy import DataAPIClient as AstraDBClient
12+
from astrapy.info import CollectionDefinition
1213

1314
from test.integration.connectors.utils.constants import DESTINATION_TAG, SOURCE_TAG, VECTOR_DB_TAG
1415
from test.integration.connectors.utils.validation.destination import (
@@ -127,11 +128,16 @@ def collection(upload_file: Path) -> Collection:
127128
api_endpoint=env_data.api_endpoint,
128129
token=env_data.token,
129130
)
130-
collection = astra_db.create_collection(collection_name, dimension=embedding_dimension)
131+
collection = astra_db.create_collection(
132+
collection_name,
133+
definition=CollectionDefinition.builder()
134+
.set_vector_dimension(dimension=embedding_dimension)
135+
.build(),
136+
)
131137
try:
132138
yield collection
133139
finally:
134-
astra_db.drop_collection(collection)
140+
astra_db.drop_collection(collection.name)
135141

136142

137143
@pytest.mark.asyncio
@@ -198,7 +204,7 @@ async def test_astra_search_destination(
198204
output_filename=upload_file.name,
199205
)
200206
uploader.precheck()
201-
uploader.run(path=staged_filepath, file_data=file_data)
207+
await uploader.run_async(path=staged_filepath, file_data=file_data)
202208

203209
# Run validation
204210
with staged_filepath.open() as f:
@@ -211,7 +217,7 @@ async def test_astra_search_destination(
211217
)
212218

213219
# Rerun and make sure the same documents get updated
214-
uploader.run(path=staged_filepath, file_data=file_data)
220+
await uploader.run_async(path=staged_filepath, file_data=file_data)
215221
current_count = collection.count_documents(filter={}, upper_bound=expected_count * 2)
216222
assert current_count == expected_count, (
217223
f"Expected count ({expected_count}) doesn't match how "

unstructured_ingest/__version__.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
__version__ = "1.0.1" # pragma: no cover
1+
__version__ = "1.0.2" # pragma: no cover

unstructured_ingest/processes/connectors/astradb.py

+47-33
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
1+
import asyncio
12
import csv
23
import hashlib
4+
import os
35
import re
46
from dataclasses import dataclass, field
57
from pathlib import Path
@@ -8,7 +10,6 @@
810

911
from pydantic import BaseModel, Field, Secret
1012

11-
from unstructured_ingest import __name__ as integration_name
1213
from unstructured_ingest.__version__ import __version__ as integration_version
1314
from unstructured_ingest.data_types.file_data import (
1415
BatchFileData,
@@ -83,10 +84,8 @@ def get_client(self) -> "AstraDBClient":
8384

8485
# Create a client object to interact with the Astra DB
8586
# caller_name/version for Astra DB tracking
86-
return AstraDBClient(
87-
caller_name=integration_name,
88-
caller_version=integration_version,
89-
)
87+
user_agent = os.getenv("UNSTRUCTURED_USER_AGENT", "unstructuredio_oss")
88+
return AstraDBClient(callers=[(user_agent, integration_version)])
9089

9190

9291
def get_astra_db(
@@ -141,7 +140,7 @@ async def get_async_astra_collection(
141140
)
142141

143142
# Get async collection from AsyncDatabase
144-
async_astra_db_collection = await async_astra_db.get_collection(name=collection_name)
143+
async_astra_db_collection = async_astra_db.get_collection(name=collection_name)
145144
return async_astra_db_collection
146145

147146

@@ -360,13 +359,22 @@ class AstraDBUploader(Uploader):
360359
upload_config: AstraDBUploaderConfig
361360
connector_type: str = CONNECTOR_TYPE
362361

362+
def is_async(self) -> bool:
363+
return True
364+
363365
def init(self, **kwargs: Any) -> None:
364366
self.create_destination(**kwargs)
365367

368+
@requires_dependencies(["astrapy"], extras="astradb")
366369
def precheck(self) -> None:
367370
try:
368371
if self.upload_config.collection_name:
369-
self.get_collection(collection_name=self.upload_config.collection_name).options()
372+
collection = get_astra_collection(
373+
connection_config=self.connection_config,
374+
collection_name=self.upload_config.collection_name,
375+
keyspace=self.upload_config.keyspace,
376+
)
377+
collection.options()
370378
else:
371379
# check for db connection only if collection name is not provided
372380
get_astra_db(
@@ -377,17 +385,7 @@ def precheck(self) -> None:
377385
logger.error(f"Failed to validate connection {e}", exc_info=True)
378386
raise DestinationConnectionError(f"failed to validate connection: {e}")
379387

380-
@requires_dependencies(["astrapy"], extras="astradb")
381-
def get_collection(self, collection_name: Optional[str] = None) -> "AstraDBCollection":
382-
return get_astra_collection(
383-
connection_config=self.connection_config,
384-
collection_name=collection_name or self.upload_config.collection_name,
385-
keyspace=self.upload_config.keyspace,
386-
)
387-
388388
def _collection_exists(self, collection_name: str):
389-
from astrapy.exceptions import CollectionNotFoundException
390-
391389
collection = get_astra_collection(
392390
connection_config=self.connection_config,
393391
collection_name=collection_name,
@@ -397,8 +395,10 @@ def _collection_exists(self, collection_name: str):
397395
try:
398396
collection.options()
399397
return True
400-
except CollectionNotFoundException:
401-
return False
398+
except RuntimeError as e:
399+
if "not found" in str(e):
400+
return False
401+
raise DestinationConnectionError(f"failed to check if astra collection exists : {e}")
402402
except Exception as e:
403403
logger.error(f"failed to check if astra collection exists : {e}")
404404
raise DestinationConnectionError(f"failed to check if astra collection exists : {e}")
@@ -422,51 +422,65 @@ def create_destination(
422422
self.upload_config.collection_name = collection_name
423423

424424
if not self._collection_exists(collection_name):
425+
from astrapy.info import CollectionDefinition
426+
425427
astra_db = get_astra_db(
426428
connection_config=self.connection_config, keyspace=self.upload_config.keyspace
427429
)
428430
logger.info(
429431
f"creating default astra collection '{collection_name}' with dimension "
430432
f"{vector_length} and metric {similarity_metric}"
431433
)
432-
astra_db.create_collection(
433-
collection_name,
434-
dimension=vector_length,
435-
metric=similarity_metric,
434+
definition = (
435+
CollectionDefinition.builder()
436+
.set_vector_dimension(dimension=vector_length)
437+
.set_vector_metric(similarity_metric)
438+
.build()
436439
)
440+
(astra_db.create_collection(collection_name, definition=definition),)
437441
return True
438442
logger.debug(f"collection with name '{collection_name}' already exists, skipping creation")
439443
return False
440444

441-
def delete_by_record_id(self, collection: "AstraDBCollection", file_data: FileData):
445+
async def delete_by_record_id(self, collection: "AstraDBAsyncCollection", file_data: FileData):
442446
logger.debug(
443447
f"deleting records from collection {collection.name} "
444448
f"with {self.upload_config.record_id_key} "
445449
f"set to {file_data.identifier}"
446450
)
447451
delete_filter = {self.upload_config.record_id_key: {"$eq": file_data.identifier}}
448-
delete_resp = collection.delete_many(filter=delete_filter)
452+
delete_resp = await collection.delete_many(filter=delete_filter)
449453
logger.debug(
450454
f"deleted {delete_resp.deleted_count} records from collection {collection.name}"
451455
)
452456

453-
def run_data(self, data: list[dict], file_data: FileData, **kwargs: Any) -> None:
457+
async def run_data(self, data: list[dict], file_data: FileData, **kwargs: Any) -> None:
454458
logger.info(
455459
f"writing {len(data)} objects to destination "
456460
f"collection {self.upload_config.collection_name}"
457461
)
458462

459463
astra_db_batch_size = self.upload_config.batch_size
460-
collection = self.get_collection()
461-
462-
self.delete_by_record_id(collection=collection, file_data=file_data)
464+
async_astra_collection = await get_async_astra_collection(
465+
connection_config=self.connection_config,
466+
collection_name=self.upload_config.collection_name,
467+
keyspace=self.upload_config.keyspace,
468+
)
463469

464-
for chunk in batch_generator(data, astra_db_batch_size):
465-
collection.insert_many(chunk)
470+
await self.delete_by_record_id(collection=async_astra_collection, file_data=file_data)
471+
await asyncio.gather(
472+
*[
473+
async_astra_collection.insert_many(chunk)
474+
for chunk in batch_generator(data, astra_db_batch_size)
475+
]
476+
)
466477

467-
def run(self, path: Path, file_data: FileData, **kwargs: Any) -> None:
478+
async def run_async(self, path: Path, file_data: FileData, **kwargs: Any) -> None:
468479
data = get_json_data(path=path)
469-
self.run_data(data=data, file_data=file_data, **kwargs)
480+
await self.run_data(data=data, file_data=file_data)
481+
482+
def run(self, **kwargs: Any) -> Any:
483+
raise NotImplementedError("Use astradb run_async instead")
470484

471485

472486
astra_db_source_entry = SourceRegistryEntry(

uv.lock

+11-12
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)