|
3 | 3 | from contextlib import contextmanager
|
4 | 4 | from dataclasses import dataclass, field
|
5 | 5 | from datetime import date, datetime
|
| 6 | +from pathlib import Path |
6 | 7 | from typing import TYPE_CHECKING, Any, Generator, Optional
|
7 | 8 |
|
8 | 9 | from dateutil import parser
|
|
15 | 16 | AccessConfig,
|
16 | 17 | ConnectionConfig,
|
17 | 18 | FileData,
|
18 |
| - Uploader, |
19 | 19 | UploaderConfig,
|
20 | 20 | UploadStager,
|
21 | 21 | UploadStagerConfig,
|
| 22 | + VectorDBUploader, |
22 | 23 | )
|
23 | 24 | from unstructured_ingest.v2.logger import logger
|
24 | 25 |
|
@@ -160,7 +161,9 @@ def conform_dict(self, element_dict: dict, file_data: FileData) -> dict:
|
160 | 161 |
|
161 | 162 |
|
162 | 163 | class WeaviateUploaderConfig(UploaderConfig):
|
163 |
| - collection: str = Field(description="The name of the collection this object belongs to") |
| 164 | + collection: Optional[str] = Field( |
| 165 | + description="The name of the collection this object belongs to", default=None |
| 166 | + ) |
164 | 167 | batch_size: Optional[int] = Field(default=None, description="Number of records per batch")
|
165 | 168 | requests_per_minute: Optional[int] = Field(default=None, description="Rate limit for upload")
|
166 | 169 | dynamic_batch: bool = Field(default=True, description="Whether to use dynamic batch")
|
@@ -205,17 +208,50 @@ def get_batch_client(self, client: "WeaviateClient") -> Generator["BatchClient",
|
205 | 208 |
|
206 | 209 |
|
207 | 210 | @dataclass
|
208 |
| -class WeaviateUploader(Uploader, ABC): |
| 211 | +class WeaviateUploader(VectorDBUploader, ABC): |
209 | 212 | upload_config: WeaviateUploaderConfig
|
210 | 213 | connection_config: WeaviateConnectionConfig
|
211 | 214 |
|
| 215 | + def _collection_exists(self, collection_name: Optional[str] = None): |
| 216 | + collection_name = collection_name or self.upload_config.collection |
| 217 | + with self.connection_config.get_client() as weaviate_client: |
| 218 | + return weaviate_client.collections.exists(name=collection_name) |
| 219 | + |
212 | 220 | def precheck(self) -> None:
|
213 | 221 | try:
|
214 | 222 | self.connection_config.get_client()
|
| 223 | + # only if collection name populated should we check that it exists |
| 224 | + if self.upload_config.collection and not self._collection_exists(): |
| 225 | + raise DestinationConnectionError( |
| 226 | + f"collection '{self.upload_config.collection}' does not exist" |
| 227 | + ) |
215 | 228 | except Exception as e:
|
216 | 229 | logger.error(f"Failed to validate connection {e}", exc_info=True)
|
217 | 230 | raise DestinationConnectionError(f"failed to validate connection: {e}")
|
218 | 231 |
|
| 232 | + def init(self, *kwargs: Any) -> None: |
| 233 | + self.create_destination() |
| 234 | + |
| 235 | + def create_destination( |
| 236 | + self, destination_name: str = "elements", vector_length: Optional[int] = None, **kwargs: Any |
| 237 | + ) -> bool: |
| 238 | + collection_name = self.upload_config.collection or destination_name |
| 239 | + self.upload_config.collection = collection_name |
| 240 | + connectors_dir = Path(__file__).parents[1] |
| 241 | + collection_config_file = connectors_dir / "assets" / "weaviate_collection_config.json" |
| 242 | + with collection_config_file.open() as f: |
| 243 | + collection_config = json.load(f) |
| 244 | + collection_config["class"] = collection_name |
| 245 | + if not self._collection_exists(): |
| 246 | + logger.info( |
| 247 | + f"creating default weaviate collection '{collection_name}' with default configs" |
| 248 | + ) |
| 249 | + with self.connection_config.get_client() as weaviate_client: |
| 250 | + weaviate_client.collections.create_from_dict(config=collection_config) |
| 251 | + return True |
| 252 | + logger.debug(f"collection with name '{collection_name}' already exists, skipping creation") |
| 253 | + return False |
| 254 | + |
219 | 255 | def check_for_errors(self, client: "WeaviateClient") -> None:
|
220 | 256 | failed_uploads = client.batch.failed_objects
|
221 | 257 | if failed_uploads:
|
@@ -253,6 +289,8 @@ def run_data(self, data: list[dict], file_data: FileData, **kwargs: Any) -> None
|
253 | 289 | f"writing {len(data)} objects to destination "
|
254 | 290 | f"class {self.connection_config.access_config} "
|
255 | 291 | )
|
| 292 | + if not self.upload_config.collection: |
| 293 | + raise ValueError("No collection specified") |
256 | 294 |
|
257 | 295 | with self.connection_config.get_client() as weaviate_client:
|
258 | 296 | self.delete_by_record_id(client=weaviate_client, file_data=file_data)
|
|
0 commit comments