Skip to content

Commit cb309ae

Browse files
authored
feat/add code to weaviate uploader to create default collection (#355)
* Add code to weaviate uploader to create default collection
* update changelog
* support functionality without having to instantiate the uploader
* don't create collection if it already exists
* change signature
* limit default indexed fields
* add unit test
* add optional init step
* always check if collection exists, create if needed
* add collection check to precheck only if it's not None
* add custom implementation of uploader for vectordbs
* use exists() rather than list_all()
* fix int test
1 parent 034ae3b commit cb309ae

File tree

11 files changed

+136
-19
lines changed

11 files changed

+136
-19
lines changed

CHANGELOG.md

+2-1
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,9 @@
1-
## 0.4.3-dev0
1+
## 0.4.3-dev1
22

33
### Enhancements
44

55
* **Add support for allow list when downloading from raw html**
6+
* **Add support for setting up destination as part of uploader**
67

78
## 0.4.2
89

test/integration/connectors/weaviate/test_local.py

+27-6
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ def wait_for_container(timeout: int = 10, interval: int = 1) -> None:
2525
start_time = time.time()
2626
while time.time() - start_time < timeout:
2727
try:
28-
requests.get("http://localhost:8080/v1/.well-known/read")
28+
requests.get("http://localhost:8080/v1/.well-known/read", timeout=1)
2929
return
3030
except Exception as e:
3131
print(f"Failed to validate container healthy, sleeping for {interval} seconds: {e}")
@@ -34,15 +34,20 @@ def wait_for_container(timeout: int = 10, interval: int = 1) -> None:
3434

3535

3636
@pytest.fixture
37-
def collection(collections_schema_config: dict) -> str:
37+
def weaviate_instance():
3838
with container_context(
3939
image="semitechnologies/weaviate:1.27.3",
4040
ports={8080: 8080, 50051: 50051},
41-
):
41+
) as ctx:
4242
wait_for_container()
43-
with weaviate.connect_to_local() as weaviate_client:
44-
weaviate_client.collections.create_from_dict(config=collections_schema_config)
45-
yield COLLECTION_NAME
43+
yield ctx
44+
45+
46+
@pytest.fixture
47+
def collection(weaviate_instance, collections_schema_config: dict) -> str:
48+
with weaviate.connect_to_local() as weaviate_client:
49+
weaviate_client.collections.create_from_dict(config=collections_schema_config)
50+
return COLLECTION_NAME
4651

4752

4853
def get_count(client: WeaviateClient) -> int:
@@ -129,3 +134,19 @@ def test_weaviate_local_destination(upload_file: Path, collection: str, tmp_path
129134
file_data=file_data,
130135
expected_count=expected_count,
131136
)
137+
138+
139+
@pytest.mark.tags(CONNECTOR_TYPE, DESTINATION_TAG, VECTOR_DB_TAG)
140+
def test_weaviate_local_create_destination(weaviate_instance):
141+
uploader = LocalWeaviateUploader(
142+
upload_config=LocalWeaviateUploaderConfig(),
143+
connection_config=LocalWeaviateConnectionConfig(),
144+
)
145+
collection_name = "system_created"
146+
created = uploader.create_destination(destination_name=collection_name)
147+
assert created
148+
with uploader.connection_config.get_client() as weaviate_client:
149+
assert weaviate_client.collections.exists(name=collection_name)
150+
151+
created = uploader.create_destination(destination_name=collection_name)
152+
assert not created

unstructured_ingest/__version__.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
__version__ = "0.4.3-dev0" # pragma: no cover
1+
__version__ = "0.4.3-dev1" # pragma: no cover

unstructured_ingest/v2/interfaces/__init__.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
from .process import BaseProcess
66
from .processor import ProcessorConfig
77
from .upload_stager import UploadStager, UploadStagerConfig
8-
from .uploader import UploadContent, Uploader, UploaderConfig
8+
from .uploader import UploadContent, Uploader, UploaderConfig, VectorDBUploader
99

1010
__all__ = [
1111
"DownloadResponse",
@@ -29,4 +29,5 @@
2929
"FileDataSourceMetadata",
3030
"BatchFileData",
3131
"BatchItem",
32+
"VectorDBUploader",
3233
]

unstructured_ingest/v2/interfaces/process.py

+3
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,9 @@ class BaseProcess(ABC):
88
def is_async(self) -> bool:
99
return False
1010

11+
def init(self, *kwargs: Any) -> None:
12+
pass
13+
1114
def precheck(self) -> None:
1215
pass
1316

unstructured_ingest/v2/interfaces/uploader.py

+14-1
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
from abc import ABC
22
from dataclasses import dataclass
33
from pathlib import Path
4-
from typing import Any, TypeVar
4+
from typing import Any, Optional, TypeVar
55

66
from pydantic import BaseModel
77

@@ -38,6 +38,11 @@ def is_batch(self) -> bool:
3838
def run_batch(self, contents: list[UploadContent], **kwargs: Any) -> None:
3939
raise NotImplementedError()
4040

41+
def create_destination(self, destination_name: str = "elements", **kwargs: Any) -> bool:
42+
# Update the uploader config if needed with a new destination that gets created.
43+
# Return a flag on if anything was created or not.
44+
return False
45+
4146
def run(self, path: Path, file_data: FileData, **kwargs: Any) -> None:
4247
data = get_data(path=path)
4348
self.run_data(data=data, file_data=file_data, **kwargs)
@@ -51,3 +56,11 @@ def run_data(self, data: list[dict], file_data: FileData, **kwargs: Any) -> None
5156

5257
async def run_data_async(self, data: list[dict], file_data: FileData, **kwargs: Any) -> None:
5358
return self.run_data(data=data, file_data=file_data, **kwargs)
59+
60+
61+
@dataclass
62+
class VectorDBUploader(Uploader, ABC):
63+
def create_destination(
64+
self, destination_name: str = "elements", vector_length: Optional[int] = None, **kwargs: Any
65+
) -> bool:
66+
return False

unstructured_ingest/v2/pipeline/pipeline.py

+20-6
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
from unstructured_ingest.v2.interfaces import ProcessorConfig, Uploader
1212
from unstructured_ingest.v2.logger import logger, make_default_logger
1313
from unstructured_ingest.v2.otel import OtelHandler
14+
from unstructured_ingest.v2.pipeline.interfaces import PipelineStep
1415
from unstructured_ingest.v2.pipeline.steps.chunk import Chunker, ChunkStep
1516
from unstructured_ingest.v2.pipeline.steps.download import DownloaderT, DownloadStep
1617
from unstructured_ingest.v2.pipeline.steps.embed import Embedder, EmbedStep
@@ -91,10 +92,6 @@ def __post_init__(
9192
self.chunker_step = ChunkStep(process=chunker, context=self.context) if chunker else None
9293

9394
self.embedder_step = EmbedStep(process=embedder, context=self.context) if embedder else None
94-
# TODO: support initialize() call from each step process
95-
# Potential long call to download embedder models, run before any fanout:
96-
if embedder and embedder.config:
97-
embedder.config.get_embedder().initialize()
9895

9996
self.stager_step = UploadStageStep(process=stager, context=self.context) if stager else None
10097
self.uploader_step = UploadStep(process=uploader, context=self.context)
@@ -135,6 +132,7 @@ def run(self):
135132
with otel_handler.get_tracer().start_as_current_span(
136133
"ingest process", record_exception=True
137134
):
135+
self._run_inits()
138136
self._run_prechecks()
139137
self._run()
140138
finally:
@@ -156,7 +154,7 @@ def clean_results(self, results: list[Any | list[Any]] | None) -> list[Any] | No
156154
final = [f for f in flat if f]
157155
return final or None
158156

159-
def _run_prechecks(self):
157+
def _get_all_steps(self) -> list[PipelineStep]:
160158
steps = [self.indexer_step, self.downloader_step, self.partitioner_step, self.uploader_step]
161159
if self.chunker_step:
162160
steps.append(self.chunker_step)
@@ -166,8 +164,24 @@ def _run_prechecks(self):
166164
steps.append(self.uncompress_step)
167165
if self.stager_step:
168166
steps.append(self.stager_step)
167+
return steps
168+
169+
def _run_inits(self):
170+
failures = {}
171+
172+
for step in self._get_all_steps():
173+
try:
174+
step.process.init()
175+
except Exception as e:
176+
failures[step.process.__class__.__name__] = f"[{type(e).__name__}] {e}"
177+
if failures:
178+
for k, v in failures.items():
179+
logger.error(f"Step init failure: {k}: {v}")
180+
raise PipelineError("Init failed")
181+
182+
def _run_prechecks(self):
169183
failures = {}
170-
for step in steps:
184+
for step in self._get_all_steps():
171185
try:
172186
step.process.precheck()
173187
except Exception as e:

unstructured_ingest/v2/processes/connectors/assets/__init__.py

Whitespace-only changes.

unstructured_ingest/v2/processes/connectors/assets/weaviate_collection_config.json

+23

Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
{
2+
"properties": [
3+
{
4+
"dataType": [
5+
"text"
6+
],
7+
"indexFilterable": true,
8+
"indexSearchable": true,
9+
"name": "record_id",
10+
"tokenization": "word"
11+
},
12+
{
13+
"dataType": [
14+
"text"
15+
],
16+
"indexFilterable": true,
17+
"indexSearchable": true,
18+
"name": "text",
19+
"tokenization": "word"
20+
}
21+
],
22+
"vectorizer": "none"
23+
}

unstructured_ingest/v2/processes/connectors/weaviate/weaviate.py

+41-3
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
from contextlib import contextmanager
44
from dataclasses import dataclass, field
55
from datetime import date, datetime
6+
from pathlib import Path
67
from typing import TYPE_CHECKING, Any, Generator, Optional
78

89
from dateutil import parser
@@ -15,10 +16,10 @@
1516
AccessConfig,
1617
ConnectionConfig,
1718
FileData,
18-
Uploader,
1919
UploaderConfig,
2020
UploadStager,
2121
UploadStagerConfig,
22+
VectorDBUploader,
2223
)
2324
from unstructured_ingest.v2.logger import logger
2425

@@ -160,7 +161,9 @@ def conform_dict(self, element_dict: dict, file_data: FileData) -> dict:
160161

161162

162163
class WeaviateUploaderConfig(UploaderConfig):
163-
collection: str = Field(description="The name of the collection this object belongs to")
164+
collection: Optional[str] = Field(
165+
description="The name of the collection this object belongs to", default=None
166+
)
164167
batch_size: Optional[int] = Field(default=None, description="Number of records per batch")
165168
requests_per_minute: Optional[int] = Field(default=None, description="Rate limit for upload")
166169
dynamic_batch: bool = Field(default=True, description="Whether to use dynamic batch")
@@ -205,17 +208,50 @@ def get_batch_client(self, client: "WeaviateClient") -> Generator["BatchClient",
205208

206209

207210
@dataclass
208-
class WeaviateUploader(Uploader, ABC):
211+
class WeaviateUploader(VectorDBUploader, ABC):
209212
upload_config: WeaviateUploaderConfig
210213
connection_config: WeaviateConnectionConfig
211214

215+
def _collection_exists(self, collection_name: Optional[str] = None):
216+
collection_name = collection_name or self.upload_config.collection
217+
with self.connection_config.get_client() as weaviate_client:
218+
return weaviate_client.collections.exists(name=collection_name)
219+
212220
def precheck(self) -> None:
213221
try:
214222
self.connection_config.get_client()
223+
# only if collection name populated should we check that it exists
224+
if self.upload_config.collection and not self._collection_exists():
225+
raise DestinationConnectionError(
226+
f"collection '{self.upload_config.collection}' does not exist"
227+
)
215228
except Exception as e:
216229
logger.error(f"Failed to validate connection {e}", exc_info=True)
217230
raise DestinationConnectionError(f"failed to validate connection: {e}")
218231

232+
def init(self, *kwargs: Any) -> None:
233+
self.create_destination()
234+
235+
def create_destination(
236+
self, destination_name: str = "elements", vector_length: Optional[int] = None, **kwargs: Any
237+
) -> bool:
238+
collection_name = self.upload_config.collection or destination_name
239+
self.upload_config.collection = collection_name
240+
connectors_dir = Path(__file__).parents[1]
241+
collection_config_file = connectors_dir / "assets" / "weaviate_collection_config.json"
242+
with collection_config_file.open() as f:
243+
collection_config = json.load(f)
244+
collection_config["class"] = collection_name
245+
if not self._collection_exists():
246+
logger.info(
247+
f"creating default weaviate collection '{collection_name}' with default configs"
248+
)
249+
with self.connection_config.get_client() as weaviate_client:
250+
weaviate_client.collections.create_from_dict(config=collection_config)
251+
return True
252+
logger.debug(f"collection with name '{collection_name}' already exists, skipping creation")
253+
return False
254+
219255
def check_for_errors(self, client: "WeaviateClient") -> None:
220256
failed_uploads = client.batch.failed_objects
221257
if failed_uploads:
@@ -253,6 +289,8 @@ def run_data(self, data: list[dict], file_data: FileData, **kwargs: Any) -> None
253289
f"writing {len(data)} objects to destination "
254290
f"class {self.connection_config.access_config} "
255291
)
292+
if not self.upload_config.collection:
293+
raise ValueError("No collection specified")
256294

257295
with self.connection_config.get_client() as weaviate_client:
258296
self.delete_by_record_id(client=weaviate_client, file_data=file_data)

unstructured_ingest/v2/processes/embedder.py

+3
Original file line numberDiff line numberDiff line change
@@ -184,6 +184,9 @@ def get_embedder(self) -> "BaseEmbeddingEncoder":
184184
class Embedder(BaseProcess, ABC):
185185
config: EmbedderConfig
186186

187+
def init(self, *kwargs: Any) -> None:
188+
self.config.get_embedder().initialize()
189+
187190
def run(self, elements_filepath: Path, **kwargs: Any) -> list[dict]:
188191
# TODO update base embedder classes to support async
189192
embedder = self.config.get_embedder()

0 commit comments

Comments
 (0)