Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 2 additions & 8 deletions integrations/snowflake/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ name = "snowflake-haystack"
dynamic = ["version"]
description = 'A Snowflake integration for the Haystack framework.'
readme = "README.md"
requires-python = ">=3.9"
requires-python = ">=3.10"
license = "Apache-2.0"
keywords = []
authors = [{ name = "deepset GmbH", email = "[email protected]" },
Expand All @@ -16,7 +16,6 @@ classifiers = [
"License :: OSI Approved :: Apache Software License",
"Development Status :: 4 - Beta",
"Programming Language :: Python",
"Programming Language :: Python :: 3.9",
"Programming Language :: Python :: 3.10",
"Programming Language :: Python :: 3.11",
"Programming Language :: Python :: 3.12",
Expand All @@ -25,7 +24,7 @@ classifiers = [
"Programming Language :: Python :: Implementation :: PyPy",
]
dependencies = [
"haystack-ai",
"haystack-ai>=2.22.0",
"adbc_driver_snowflake>=1.4.0",
"polars[pandas,pyarrow]>=1.23.0",
"snowflake-connector-python>=3.12.0",
Expand Down Expand Up @@ -82,7 +81,6 @@ check_untyped_defs = true
disallow_incomplete_defs = true

[tool.ruff]
target-version = "py39"
line-length = 120

[tool.ruff.lint]
Expand Down Expand Up @@ -130,10 +128,6 @@ ignore = [
# Unused method argument
"ARG002"
]
unfixable = [
# Don't touch unused imports
"F401",
]

[tool.ruff.lint.isort]
known-first-party = ["haystack_integrations"]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
# SPDX-License-Identifier: Apache-2.0

from pathlib import Path
from typing import Any, Literal, Optional
from typing import Any, Literal
from urllib.parse import quote

from haystack import logging
Expand Down Expand Up @@ -49,13 +49,13 @@ class SnowflakeAuthenticator:
def __init__(
self,
authenticator: Literal["SNOWFLAKE", "SNOWFLAKE_JWT", "OAUTH"],
api_key: Optional[Secret] = Secret.from_env_var("SNOWFLAKE_API_KEY", strict=False), # noqa: B008
private_key_file: Optional[Secret] = Secret.from_env_var("SNOWFLAKE_PRIVATE_KEY_FILE", strict=False), # noqa: B008
private_key_file_pwd: Optional[Secret] = Secret.from_env_var("SNOWFLAKE_PRIVATE_KEY_FILE_PWD", strict=False), # noqa: B008
oauth_client_id: Optional[Secret] = Secret.from_env_var("SNOWFLAKE_CLIENT_ID", strict=False), # noqa: B008
oauth_client_secret: Optional[Secret] = Secret.from_env_var("SNOWFLAKE_CLIENT_SECRET", strict=False), # noqa: B008
oauth_token_request_url: Optional[str] = None,
oauth_authorization_url: Optional[str] = None,
api_key: Secret | None = Secret.from_env_var("SNOWFLAKE_API_KEY", strict=False), # noqa: B008
private_key_file: Secret | None = Secret.from_env_var("SNOWFLAKE_PRIVATE_KEY_FILE", strict=False), # noqa: B008
private_key_file_pwd: Secret | None = Secret.from_env_var("SNOWFLAKE_PRIVATE_KEY_FILE_PWD", strict=False), # noqa: B008
oauth_client_id: Secret | None = Secret.from_env_var("SNOWFLAKE_CLIENT_ID", strict=False), # noqa: B008
oauth_client_secret: Secret | None = Secret.from_env_var("SNOWFLAKE_CLIENT_SECRET", strict=False), # noqa: B008
oauth_token_request_url: str | None = None,
oauth_authorization_url: str | None = None,
) -> None:
"""
Initialize the authenticator with the specified authentication method.
Expand Down Expand Up @@ -98,7 +98,7 @@ def validate_auth_params(self) -> None:
if not self.api_key:
raise ValueError(ERROR_API_KEY_REQUIRED)

def read_private_key_content(self) -> Optional[str]:
def read_private_key_content(self) -> str | None:
"""
Reads the private key file content for ADBC compatibility.

Expand All @@ -121,7 +121,7 @@ def read_private_key_content(self) -> Optional[str]:
msg = f"Failed to read private key file: {e!s}"
raise PrivateKeyReadError(msg) from e

def _build_jwt_auth_params(self, user: Optional[str] = None) -> list[str]:
def _build_jwt_auth_params(self, user: str | None = None) -> list[str]:
"""
Builds JWT authentication parameters for ADBC.

Expand Down Expand Up @@ -172,7 +172,7 @@ def _build_oauth_auth_params(self) -> list[str]:

return params

def build_auth_params(self, user: Optional[str] = None) -> list[str]:
def build_auth_params(self, user: str | None = None) -> list[str]:
"""
Builds authentication parameters for the connection URI.

Expand All @@ -186,7 +186,7 @@ def build_auth_params(self, user: Optional[str] = None) -> list[str]:
return self._build_oauth_auth_params()
return []

def get_password_for_uri(self) -> Optional[str]:
def get_password_for_uri(self) -> str | None:
"""
Gets the password for URI construction in SNOWFLAKE authentication.

Expand Down Expand Up @@ -221,7 +221,7 @@ def create_masked_params(self, params: list) -> list[str]:

return masked_params

def test_connection(self, user: str, account: str, database: Optional[str] = None) -> bool:
def test_connection(self, user: str, account: str, database: str | None = None) -> bool:
"""
Tests the connection with the provided credentials.

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
#
# SPDX-License-Identifier: Apache-2.0

from typing import Any, Literal, Optional, Union
from typing import Any, Literal
from urllib.parse import quote_plus

import polars as pl
Expand Down Expand Up @@ -102,18 +102,18 @@ def __init__(
user: str,
account: str,
authenticator: Literal["SNOWFLAKE", "SNOWFLAKE_JWT", "OAUTH"] = "SNOWFLAKE",
api_key: Optional[Secret] = Secret.from_env_var("SNOWFLAKE_API_KEY", strict=False), # noqa: B008
database: Optional[str] = None,
db_schema: Optional[str] = None,
warehouse: Optional[str] = None,
login_timeout: Optional[int] = 60,
api_key: Secret | None = Secret.from_env_var("SNOWFLAKE_API_KEY", strict=False), # noqa: B008
database: str | None = None,
db_schema: str | None = None,
warehouse: str | None = None,
login_timeout: int | None = 60,
return_markdown: bool = True,
private_key_file: Optional[Secret] = Secret.from_env_var("SNOWFLAKE_PRIVATE_KEY_FILE", strict=False), # noqa: B008
private_key_file_pwd: Optional[Secret] = Secret.from_env_var("SNOWFLAKE_PRIVATE_KEY_PWD", strict=False), # noqa: B008
oauth_client_id: Optional[Secret] = Secret.from_env_var("SNOWFLAKE_OAUTH_CLIENT_ID", strict=False), # noqa: B008
oauth_client_secret: Optional[Secret] = Secret.from_env_var("SNOWFLAKE_OAUTH_CLIENT_SECRET", strict=False), # noqa: B008
oauth_token_request_url: Optional[str] = None,
oauth_authorization_url: Optional[str] = None,
private_key_file: Secret | None = Secret.from_env_var("SNOWFLAKE_PRIVATE_KEY_FILE", strict=False), # noqa: B008
private_key_file_pwd: Secret | None = Secret.from_env_var("SNOWFLAKE_PRIVATE_KEY_PWD", strict=False), # noqa: B008
oauth_client_id: Secret | None = Secret.from_env_var("SNOWFLAKE_OAUTH_CLIENT_ID", strict=False), # noqa: B008
oauth_client_secret: Secret | None = Secret.from_env_var("SNOWFLAKE_OAUTH_CLIENT_SECRET", strict=False), # noqa: B008
oauth_token_request_url: str | None = None,
oauth_authorization_url: str | None = None,
) -> None:
"""
:param user: User's login.
Expand Down Expand Up @@ -153,7 +153,7 @@ def __init__(
self.oauth_client_secret = oauth_client_secret
self.oauth_token_request_url = oauth_token_request_url
self.oauth_authorization_url = oauth_authorization_url
self.authenticator_handler: Optional[SnowflakeAuthenticator] = None
self.authenticator_handler: SnowflakeAuthenticator | None = None
self._warmed_up = False

def warm_up(self) -> None:
Expand Down Expand Up @@ -336,7 +336,7 @@ def _polars_to_md(data: pl.DataFrame) -> str:
)
return ""

def _execute_query_with_connector(self, query: str) -> Optional[pl.DataFrame]:
def _execute_query_with_connector(self, query: str) -> pl.DataFrame | None:
"""
Executes a query using snowflake-connector-python directly (for JWT authentication).
This bypasses ADBC compatibility issues.
Expand Down Expand Up @@ -400,7 +400,7 @@ def _execute_query_with_connector(self, query: str) -> Optional[pl.DataFrame]:
return None

@staticmethod
def _empty_response() -> dict[str, Union[DataFrame, str]]:
def _empty_response() -> dict[str, DataFrame | str]:
"""Returns a standardized empty response.

:returns:
Expand All @@ -411,7 +411,7 @@ def _empty_response() -> dict[str, Union[DataFrame, str]]:
return {"dataframe": DataFrame(), "table": ""}

@component.output_types(dataframe=DataFrame, table=str)
def run(self, query: str, return_markdown: Optional[bool] = None) -> dict[str, Union[DataFrame, str]]:
def run(self, query: str, return_markdown: bool | None = None) -> dict[str, DataFrame | str]:
"""
Executes a SQL query against a Snowflake database using ADBC and Polars.

Expand Down
10 changes: 5 additions & 5 deletions integrations/snowflake/tests/test_snowflake_table_retriever.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

import os
from pathlib import Path
from typing import Any, Optional
from typing import Any
from unittest.mock import Mock
from urllib.parse import quote_plus

Expand Down Expand Up @@ -191,10 +191,10 @@ def test_snowflake_uri_constructor(
mocker: Mock,
user: str,
account: str,
db_name: Optional[str],
schema_name: Optional[str],
warehouse_name: Optional[str],
expected_uri: Optional[str],
db_name: str | None,
schema_name: str | None,
warehouse_name: str | None,
expected_uri: str | None,
should_raise: bool,
) -> None:
mocker.patch.dict(os.environ, {"SNOWFLAKE_API_KEY": "test_api_key"})
Expand Down