Skip to content
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
## [1.4.5]

* **fix: add capability to use opensearch serverless**

## [1.4.4]

* **fix: add table precheck to teradata source**
Expand Down
117 changes: 117 additions & 0 deletions test/integration/connectors/elasticsearch/test_opensearch.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import pytest
from _pytest.fixtures import TopRequest
from opensearchpy import Document, Keyword, OpenSearch, Text
from opensearchpy.exceptions import TransportError

from test.integration.connectors.utils.constants import DESTINATION_TAG, NOSQL_TAG, SOURCE_TAG
from test.integration.connectors.utils.docker import HealthCheck, container_context
Expand Down Expand Up @@ -618,10 +619,20 @@ def test_opensearch_destination_iam_precheck_fail_invalid_credentials():
("https://abc123xyz.us-east-1.aoss.amazonaws.com", "us-east-1", "aoss"),
("https://abc123xyz.eu-west-1.aoss.amazonaws.com", "eu-west-1", "aoss"),
("https://abc123xyz.us-gov-west-1.aoss.amazonaws.com", "us-gov-west-1", "aoss"),
# FIPS-compliant endpoints
("https://abc123xyz.us-east-1.aoss-fips.amazonaws.com", "us-east-1", "aoss"),
("https://abc123xyz.us-east-2.aoss-fips.amazonaws.com", "us-east-2", "aoss"),
("https://abc123xyz.us-gov-west-1.aoss-fips.amazonaws.com", "us-gov-west-1", "aoss"),
("https://abc123xyz.ca-central-1.aoss-fips.amazonaws.com", "ca-central-1", "aoss"),
("https://search-domain.us-east-1.es-fips.amazonaws.com", "us-east-1", "es"),
("https://search-domain.us-east-2.es-fips.amazonaws.com", "us-east-2", "es"),
("https://search-domain.us-gov-west-1.es-fips.amazonaws.com", "us-gov-west-1", "es"),
("https://search-domain.ca-central-1.es-fips.amazonaws.com", "ca-central-1", "es"),
# Without https://
("search-domain.us-east-1.es.amazonaws.com", "us-east-1", "es"),
# With port
("https://search-domain.us-east-1.es.amazonaws.com:443", "us-east-1", "es"),
("https://abc123xyz.us-east-1.aoss-fips.amazonaws.com:443", "us-east-1", "aoss"),
],
)
def test_detect_aws_opensearch_config_valid(hostname, expected_region, expected_service):
Expand Down Expand Up @@ -668,6 +679,112 @@ def test_opensearch_uploader_config_batch_size_default():
)


@pytest.mark.asyncio
@pytest.mark.tags(CONNECTOR_TYPE, SOURCE_TAG)
@pytest.mark.parametrize(
    ("status_code", "error_class"),
    [
        (403, "AuthorizationException"),
        (400, "RequestError"),
        (404, "NotFoundError"),
    ],
)
async def test_opensearch_indexer_pit_fallback_to_scroll(status_code, error_class):
    """Test that _get_doc_ids_async falls back to scroll when create_pit fails.

    Covers: 403 (missing permissions), 400 (unsupported endpoint), 404 (pre-2.4 OpenSearch).
    The fallback is scoped to the create_pit call only.
    """
    from unittest.mock import AsyncMock, patch

    # NOTE: TransportError is imported at module level; no local re-import needed.
    connection_config = OpenSearchConnectionConfig(
        access_config=OpenSearchAccessConfig(password="admin"),
        username="admin",
        hosts=["http://localhost:9200"],
        use_ssl=True,
    )
    indexer = OpenSearchIndexer(
        connection_config=connection_config,
        index_config=OpenSearchIndexerConfig(index_name="test_index"),
    )

    expected_ids = {"id1", "id2", "id3"}

    # Async client mock: create_pit raises the parametrized transport error,
    # and the async-context-manager protocol returns the mock itself.
    mock_client = AsyncMock()
    mock_client.create_pit = AsyncMock(side_effect=TransportError(status_code, error_class))
    mock_client.__aenter__ = AsyncMock(return_value=mock_client)
    mock_client.__aexit__ = AsyncMock(return_value=False)

    with (
        patch("opensearchpy.AsyncOpenSearch", return_value=mock_client),
        patch.object(
            indexer, "_get_doc_ids_scroll", new_callable=AsyncMock, return_value=expected_ids
        ) as mock_scroll,
        patch.object(
            connection_config,
            "get_async_client_kwargs",
            new_callable=AsyncMock,
            return_value={"hosts": ["http://localhost:9200"]},
        ),
    ):
        result = await indexer._get_doc_ids_async()

        # PIT was attempted once, then the scroll fallback produced the IDs.
        mock_client.create_pit.assert_called_once()
        mock_scroll.assert_called_once()
        assert result == expected_ids


@pytest.mark.asyncio
@pytest.mark.tags(CONNECTOR_TYPE, SOURCE_TAG)
@pytest.mark.parametrize(
    ("exception",),
    [
        (TransportError(500, "internal_server_error"),),
        (ConnectionError("cluster unreachable"),),
    ],
)
async def test_opensearch_indexer_pit_no_fallback_on_other_errors(exception):
    """Test that _get_doc_ids_async re-raises non-fallback errors without trying scroll.

    Covers: 500 (server error), ConnectionError (network). These should NOT fall back.
    """
    from unittest.mock import AsyncMock, patch

    config = OpenSearchConnectionConfig(
        access_config=OpenSearchAccessConfig(password="admin"),
        username="admin",
        hosts=["http://localhost:9200"],
        use_ssl=True,
    )
    indexer = OpenSearchIndexer(
        connection_config=config,
        index_config=OpenSearchIndexerConfig(index_name="test_index"),
    )

    # Client mock whose create_pit raises the parametrized error; the
    # async-context-manager protocol hands back the mock itself.
    failing_client = AsyncMock()
    failing_client.create_pit = AsyncMock(side_effect=exception)
    failing_client.__aenter__ = AsyncMock(return_value=failing_client)
    failing_client.__aexit__ = AsyncMock(return_value=False)

    client_patch = patch("opensearchpy.AsyncOpenSearch", return_value=failing_client)
    scroll_patch = patch.object(indexer, "_get_doc_ids_scroll", new_callable=AsyncMock)
    kwargs_patch = patch.object(
        config,
        "get_async_client_kwargs",
        new_callable=AsyncMock,
        return_value={"hosts": ["http://localhost:9200"]},
    )
    with client_patch, scroll_patch as scroll_mock, kwargs_patch:
        with pytest.raises(type(exception)):
            await indexer._get_doc_ids_async()

    # Exactly one PIT attempt, and the scroll fallback was never consulted.
    failing_client.create_pit.assert_called_once()
    scroll_mock.assert_not_called()


@pytest.mark.asyncio
@pytest.mark.tags(CONNECTOR_TYPE, DESTINATION_TAG)
async def test_opensearch_connection_config_retry_settings():
Expand Down
2 changes: 1 addition & 1 deletion unstructured_ingest/__version__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = "1.4.4" # pragma: no cover
__version__ = "1.4.5" # pragma: no cover
126 changes: 90 additions & 36 deletions unstructured_ingest/processes/connectors/elasticsearch/opensearch.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,9 +50,11 @@

"""OpenSearch connector - inherits from Elasticsearch connector (OpenSearch is an ES fork)."""

# Precompiled regex patterns for AWS hostname detection (GovCloud, China, standard)
_ES_PATTERN = re.compile(r"\.([a-z]{2}(?:-[a-z]+)+-\d+)\.es\.amazonaws\.com$")
_AOSS_PATTERN = re.compile(r"^[a-z0-9]+\.([a-z]{2}(?:-[a-z]+)+-\d+)\.aoss\.amazonaws\.com$")
# Precompiled regex patterns for AWS hostname detection (GovCloud, China, standard, FIPS)
_ES_PATTERN = re.compile(r"\.([a-z]{2}(?:-[a-z]+)+-\d+)\.es(?:-fips)?\.amazonaws\.com$")
_AOSS_PATTERN = re.compile(
r"^[a-z0-9]+\.([a-z]{2}(?:-[a-z]+)+-\d+)\.aoss(?:-fips)?\.amazonaws\.com$"
)


def _run_coroutine(fn: Callable[..., Awaitable[Any]], *args: Any, **kwargs: Any) -> Any:
Expand Down Expand Up @@ -175,8 +177,10 @@ def _detect_and_validate_aws_config(self) -> Tuple[str, str]:
raise ValueError(
f"Could not auto-detect AWS region and service from host: {self.hosts[0]}. "
f"Ensure your host URL follows AWS OpenSearch format: "
f"https://search-domain-xxx.REGION.es.amazonaws.com (for OpenSearch Service) or "
f"https://xxx.REGION.aoss.amazonaws.com (for OpenSearch Serverless)"
f"https://search-domain-xxx.REGION.es.amazonaws.com (for OpenSearch Service), "
f"https://search-domain-xxx.REGION.es-fips.amazonaws.com (for Service with FIPS), "
f"https://xxx.REGION.aoss.amazonaws.com (for OpenSearch Serverless), or "
f"https://xxx.REGION.aoss-fips.amazonaws.com (for Serverless with FIPS)"
)

region, service = detected
Expand Down Expand Up @@ -325,24 +329,76 @@ async def run_async(self, **kwargs: Any) -> AsyncGenerator[ElasticsearchBatchFil

@requires_dependencies(["opensearchpy"], extras="opensearch")
async def _get_doc_ids_async(self) -> set[str]:
    """Fetch all document IDs, trying PIT + search_after first with scroll fallback.

    PIT is required for OpenSearch Serverless (AOSS) and preferred for
    OpenSearch Service. Falls back to scroll if PIT creation fails due to
    missing permissions (403) or unsupported version (400/404). The fallback
    is scoped to the create_pit call only; other errors propagate unchanged.
    """
    from opensearchpy import AsyncOpenSearch
    from opensearchpy.exceptions import TransportError

    async with AsyncOpenSearch(
        **await self.connection_config.get_async_client_kwargs()
    ) as client:
        try:
            pit = await client.create_pit(
                index=self.index_config.index_name, params={"keep_alive": "5m"}
            )
        except TransportError as e:
            # Only these statuses mean "PIT unavailable"; anything else
            # (e.g. 500, connection failures) must be re-raised to the caller.
            if e.status_code in (400, 403, 404):
                logger.warning(
                    f"PIT creation failed (HTTP {e.status_code}), "
                    "falling back to scroll. Note: scroll is not supported on AOSS."
                )
                return await self._get_doc_ids_scroll(client)
            raise
        return await self._get_doc_ids_pit(client, pit["pit_id"])

async def _get_doc_ids_pit(self, client: Any, pit_id: str) -> set[str]:
"""Paginate through all document IDs using an existing PIT context."""
try:
doc_ids: set[str] = set()
search_after = None
while True:
body: dict[str, Any] = {
"stored_fields": [],
"query": {"match_all": {}},
"pit": {"id": pit_id, "keep_alive": "5m"},
"sort": [{"_id": "asc"}],
"size": 1000,
}
Comment on lines 364 to 374
Copy link

Copilot AI Feb 19, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

_get_doc_ids_pit is only collecting _ids, but the PIT search request will still return full _source and compute total hits by default. This can make ID enumeration significantly slower / more bandwidth-heavy on large indices. Consider explicitly setting _source: false (or _source: []) and track_total_hits: false in the search body to minimize response payload and avoid unnecessary hit counting.

Copilot uses AI. Check for mistakes.
if search_after:
body["search_after"] = search_after
resp = await client.search(body=body)
if "pit_id" in resp:
pit_id = resp["pit_id"]
hits = resp["hits"]["hits"]
if not hits:
break
for hit in hits:
doc_ids.add(hit["_id"])
search_after = hits[-1]["sort"]
return doc_ids
finally:
try:
await client.delete_pit(body={"pit_id": [pit_id]})
except Exception:
logger.warning("Failed to delete PIT, it will expire automatically")

async def _get_doc_ids_scroll(self, client: Any) -> set[str]:
    """Fetch document IDs using scroll (fallback when PIT is unavailable)."""
    from opensearchpy.helpers import async_scan

    # Fetch no stored fields: only the _id of each hit is needed.
    id_only_query = {"stored_fields": [], "query": {"match_all": {}}}
    collected: set[str] = set()
    scanner = async_scan(
        client,
        query=id_only_query,
        scroll="1m",
        index=self.index_config.index_name,
    )
    async for match in scanner:
        collected.add(match["_id"])
    return collected


class OpenSearchDownloaderConfig(ElasticsearchDownloaderConfig):
Expand All @@ -357,34 +413,33 @@ class OpenSearchDownloader(ElasticsearchDownloader):

@requires_dependencies(["opensearchpy"], extras="opensearch")
async def run_async(self, file_data: BatchFileData, **kwargs: Any) -> download_responses:
"""Download documents from OpenSearch."""
"""Download documents from OpenSearch.

Uses a direct search by IDs instead of scroll, since the batch
is already bounded by the indexer's batch_size (typically 100).
"""
from opensearchpy import AsyncOpenSearch
from opensearchpy.helpers import async_scan

elasticsearch_filedata = ElasticsearchBatchFileData.cast(file_data=file_data)

index_name: str = elasticsearch_filedata.additional_metadata.index_name
ids: list[str] = [item.identifier for item in elasticsearch_filedata.batch_items]

scan_query = {
search_body: dict[str, Any] = {
"version": True,
"query": {"ids": {"values": ids}},
"size": len(ids),
}

# Only add _source if fields are explicitly specified (avoids AWS FGAC issues)
if self.download_config.fields:
scan_query["_source"] = self.download_config.fields
search_body["_source"] = self.download_config.fields

download_responses = []
async with AsyncOpenSearch(
**await self.connection_config.get_async_client_kwargs()
) as client:
async for result in async_scan(
client,
query=scan_query,
scroll="1m",
index=index_name,
):
resp = await client.search(body=search_body, index=index_name)
for result in resp["hits"]["hits"]:
download_responses.append(
self.generate_download_response(
result=result, index_name=index_name, file_data=elasticsearch_filedata
Expand Down Expand Up @@ -439,10 +494,10 @@ async def run_data_async(self, data: list[dict], file_data: FileData, **kwargs:
from opensearchpy.exceptions import TransportError
from opensearchpy.helpers import async_bulk

logger.debug(
f"writing {len(data)} elements to index {self.upload_config.index_name} "
f"at {self.connection_config.hosts} "
f"with batch size (bytes) {self.upload_config.batch_size_bytes}"
logger.info(
f"writing {len(data)} elements via document batches to destination "
f"index named {self.upload_config.index_name} at {self.connection_config.hosts} "
f"with batch size (in bytes) {self.upload_config.batch_size_bytes}"
)

async with AsyncOpenSearch(
Expand Down Expand Up @@ -491,12 +546,11 @@ async def run_data_async(self, data: list[dict], file_data: FileData, **kwargs:
f"Failed to upload {len(failed)} out of {len(batch)} documents"
)

logger.debug(
f"uploaded batch of {len(batch)} elements to {self.upload_config.index_name}"
logger.info(
f"uploaded batch of {len(batch)} elements to index "
f"{self.upload_config.index_name}"
)

logger.info(f"Upload complete: {len(data)} elements to {self.upload_config.index_name}")


class OpenSearchUploadStagerConfig(ElasticsearchUploadStagerConfig):
pass
Expand Down