1
+ import asyncio
1
2
import csv
2
3
import hashlib
4
+ import os
3
5
import re
4
6
from dataclasses import dataclass , field
5
7
from pathlib import Path
8
10
9
11
from pydantic import BaseModel , Field , Secret
10
12
11
- from unstructured_ingest import __name__ as integration_name
12
13
from unstructured_ingest .__version__ import __version__ as integration_version
13
14
from unstructured_ingest .data_types .file_data import (
14
15
BatchFileData ,
@@ -83,10 +84,8 @@ def get_client(self) -> "AstraDBClient":
83
84
84
85
# Create a client object to interact with the Astra DB
85
86
# caller_name/version for Astra DB tracking
86
- return AstraDBClient (
87
- caller_name = integration_name ,
88
- caller_version = integration_version ,
89
- )
87
+ user_agent = os .getenv ("UNSTRUCTURED_USER_AGENT" , "unstructuredio_oss" )
88
+ return AstraDBClient (callers = [(user_agent , integration_version )])
90
89
91
90
92
91
def get_astra_db (
@@ -141,7 +140,7 @@ async def get_async_astra_collection(
141
140
)
142
141
143
142
# Get async collection from AsyncDatabase
144
- async_astra_db_collection = await async_astra_db .get_collection (name = collection_name )
143
+ async_astra_db_collection = async_astra_db .get_collection (name = collection_name )
145
144
return async_astra_db_collection
146
145
147
146
@@ -360,13 +359,22 @@ class AstraDBUploader(Uploader):
360
359
upload_config : AstraDBUploaderConfig
361
360
connector_type : str = CONNECTOR_TYPE
362
361
362
def is_async(self) -> bool:
    """Report that this uploader is asynchronous.

    Returns:
        Always ``True``.
    """
    return True
364
+
363
365
def init(self, **kwargs: Any) -> None:
    """Initialise the uploader by delegating to ``create_destination``.

    Args:
        **kwargs: Passed through unchanged to ``create_destination``; its
            return value is discarded.
    """
    self.create_destination(**kwargs)
365
367
368
+ @requires_dependencies (["astrapy" ], extras = "astradb" )
366
369
def precheck (self ) -> None :
367
370
try :
368
371
if self .upload_config .collection_name :
369
- self .get_collection (collection_name = self .upload_config .collection_name ).options ()
372
+ collection = get_astra_collection (
373
+ connection_config = self .connection_config ,
374
+ collection_name = self .upload_config .collection_name ,
375
+ keyspace = self .upload_config .keyspace ,
376
+ )
377
+ collection .options ()
370
378
else :
371
379
# check for db connection only if collection name is not provided
372
380
get_astra_db (
@@ -377,17 +385,7 @@ def precheck(self) -> None:
377
385
logger .error (f"Failed to validate connection { e } " , exc_info = True )
378
386
raise DestinationConnectionError (f"failed to validate connection: { e } " )
379
387
380
- @requires_dependencies (["astrapy" ], extras = "astradb" )
381
- def get_collection (self , collection_name : Optional [str ] = None ) -> "AstraDBCollection" :
382
- return get_astra_collection (
383
- connection_config = self .connection_config ,
384
- collection_name = collection_name or self .upload_config .collection_name ,
385
- keyspace = self .upload_config .keyspace ,
386
- )
387
-
388
388
def _collection_exists (self , collection_name : str ):
389
- from astrapy .exceptions import CollectionNotFoundException
390
-
391
389
collection = get_astra_collection (
392
390
connection_config = self .connection_config ,
393
391
collection_name = collection_name ,
@@ -397,8 +395,10 @@ def _collection_exists(self, collection_name: str):
397
395
try :
398
396
collection .options ()
399
397
return True
400
- except CollectionNotFoundException :
401
- return False
398
+ except RuntimeError as e :
399
+ if "not found" in str (e ):
400
+ return False
401
+ raise DestinationConnectionError (f"failed to check if astra collection exists : { e } " )
402
402
except Exception as e :
403
403
logger .error (f"failed to check if astra collection exists : { e } " )
404
404
raise DestinationConnectionError (f"failed to check if astra collection exists : { e } " )
@@ -422,51 +422,65 @@ def create_destination(
422
422
self .upload_config .collection_name = collection_name
423
423
424
424
if not self ._collection_exists (collection_name ):
425
+ from astrapy .info import CollectionDefinition
426
+
425
427
astra_db = get_astra_db (
426
428
connection_config = self .connection_config , keyspace = self .upload_config .keyspace
427
429
)
428
430
logger .info (
429
431
f"creating default astra collection '{ collection_name } ' with dimension "
430
432
f"{ vector_length } and metric { similarity_metric } "
431
433
)
432
- astra_db .create_collection (
433
- collection_name ,
434
- dimension = vector_length ,
435
- metric = similarity_metric ,
434
+ definition = (
435
+ CollectionDefinition .builder ()
436
+ .set_vector_dimension (dimension = vector_length )
437
+ .set_vector_metric (similarity_metric )
438
+ .build ()
436
439
)
440
+ (astra_db .create_collection (collection_name , definition = definition ),)
437
441
return True
438
442
logger .debug (f"collection with name '{ collection_name } ' already exists, skipping creation" )
439
443
return False
440
444
441
- def delete_by_record_id (self , collection : "AstraDBCollection " , file_data : FileData ):
445
async def delete_by_record_id(self, collection: "AstraDBAsyncCollection ", file_data: FileData):
    """Delete the documents in ``collection`` that belong to ``file_data``.

    The filter matches documents whose ``upload_config.record_id_key`` field
    equals ``file_data.identifier``. The number of deleted documents reported
    by the collection is logged at debug level.

    Args:
        collection: Async collection whose ``delete_many`` is awaited.
        file_data: Supplies the identifier used in the delete filter.
    """
    logger.debug(
        f"deleting records from collection { collection .name } "
        f"with { self .upload_config .record_id_key } "
        f"set to { file_data .identifier } "
    )

    # Equality match on the configured record-id field.
    id_field = self.upload_config.record_id_key
    match_identifier = {"$eq": file_data.identifier}
    delete_resp = await collection.delete_many(filter={id_field: match_identifier})

    logger.debug(
        f"deleted { delete_resp .deleted_count } records from collection { collection .name } "
    )
452
456
453
- def run_data (self , data : list [dict ], file_data : FileData , ** kwargs : Any ) -> None :
457
async def run_data(self, data: list[dict], file_data: FileData, **kwargs: Any) -> None:
    """Replace the records for ``file_data`` in the configured collection.

    Existing documents tied to ``file_data`` are deleted first (see
    ``delete_by_record_id``), then ``data`` is split into batches of
    ``upload_config.batch_size`` and every batch is inserted through a
    single ``asyncio.gather`` call.

    Args:
        data: Documents to insert.
        file_data: Identifies which existing records to delete beforehand.
        **kwargs: Accepted for interface compatibility; not used here.
    """
    logger.info(
        f"writing { len (data )} objects to destination "
        f"collection { self .upload_config .collection_name } "
    )

    batch_size = self.upload_config.batch_size
    target_collection = await get_async_astra_collection(
        connection_config=self.connection_config,
        collection_name=self.upload_config.collection_name,
        keyspace=self.upload_config.keyspace,
    )

    # Stale records must be gone before any new batch is written.
    await self.delete_by_record_id(collection=target_collection, file_data=file_data)

    pending_inserts = []
    for batch in batch_generator(data, batch_size):
        pending_inserts.append(target_collection.insert_many(batch))
    await asyncio.gather(*pending_inserts)
466
477
467
- def run (self , path : Path , file_data : FileData , ** kwargs : Any ) -> None :
478
async def run_async(self, path: Path, file_data: FileData, **kwargs: Any) -> None:
    """Load the JSON payload at ``path`` and upload it with ``run_data``.

    Args:
        path: File read via ``get_json_data``.
        file_data: Passed on to ``run_data``.
        **kwargs: Forwarded to ``run_data`` so that options supplied by the
            caller are not silently dropped.
    """
    data = get_json_data(path=path)
    await self.run_data(data=data, file_data=file_data, **kwargs)
481
+
482
def run(self, **kwargs: Any) -> Any:
    """Reject synchronous use of this uploader.

    Raises:
        NotImplementedError: Always; the message points to ``run_async``.
    """
    message = "Use astradb run_async instead"
    raise NotImplementedError(message)
470
484
471
485
472
486
astra_db_source_entry = SourceRegistryEntry (
0 commit comments