Skip to content

Commit

Permalink
Make session auth overridable with provided params, add validation er…
Browse files Browse the repository at this point in the history
…rors (#7)

* Add session auth overrides and errors

* Bump version

* Fix

* More

* Format and lint
  • Loading branch information
sfc-gh-nmoiseyev authored Sep 27, 2024
1 parent 093daee commit dfeba3b
Show file tree
Hide file tree
Showing 3 changed files with 192 additions and 25 deletions.
73 changes: 63 additions & 10 deletions libs/snowflake/langchain_snowflake/retrievers.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,7 @@ def format_docs(docs):
""" # noqa: E501

sp_session: Optional[Session] = Field(alias="sp_session")
sp_session: Session = Field(alias="sp_session")
"""Snowpark session object."""

_sp_root: Root
Expand Down Expand Up @@ -188,19 +188,18 @@ def validate_environment(cls, values: Dict) -> Dict:
"""Validate the environment needed to establish a Snowflake session or obtain
an API root from a provided Snowflake session."""

if "sp_session" not in values:
values["database"] = get_from_dict_or_env(
values, "database", "SNOWFLAKE_DATABASE"
)
values["schema"] = get_from_dict_or_env(values, "schema", "SNOWFLAKE_SCHEMA")

if "sp_session" not in values or values["sp_session"] is None:
values["username"] = get_from_dict_or_env(
values, "username", "SNOWFLAKE_USERNAME"
)
values["account"] = get_from_dict_or_env(
values, "account", "SNOWFLAKE_ACCOUNT"
)
values["database"] = get_from_dict_or_env(
values, "database", "SNOWFLAKE_DATABASE"
)
values["schema"] = get_from_dict_or_env(
values, "schema", "SNOWFLAKE_SCHEMA"
)
values["role"] = get_from_dict_or_env(values, "role", "SNOWFLAKE_ROLE")

# check whether to authenticate with password or authenticator
Expand Down Expand Up @@ -243,6 +242,40 @@ def validate_environment(cls, values: Dict) -> Dict:
except Exception as e:
raise CortexSearchRetrieverError(f"Failed to create session: {e}")

else:
# If a session is provided, make sure other authentication parameters
# are not provided.
for param in [
"username",
"password",
"account",
"role",
"authenticator",
]:
if param in values:
raise CortexSearchRetrieverError(
f"Provided both a Snowflake session and a"
f"{'n' if param in ['account', 'authenticator'] else ''} "
f"`{param}`. If a Snowflake session is provided, do not "
f"provide any other authentication parameters (username, "
f"password, account, role, authenticator)."
)

# If overridable parameters are not provided, use the value from the session
for param, method in [
("database", "get_current_database"),
("schema", "get_current_schema"),
]:
if param not in values:
session_value = getattr(values["sp_session"], method)()
if session_value is None:
raise CortexSearchRetrieverError(
f"Snowflake {param} not set on the provided session. Pass "
f"the {param} as an argument, set it as an environment "
f"variable, or provide it in your session configuration."
)
values[param] = session_value

return values

def __init__(self, **kwargs: Any) -> None:
Expand All @@ -254,6 +287,26 @@ def _columns(self, cols: List[str] = []) -> List[str]:
override_cols = cols if cols else self.columns
return [self.search_column] + override_cols

@property
def _database(self) -> str:
"""The Snowflake database containing the Cortex Search Service."""
if self.snowflake_database is not None:
return self.snowflake_database
database = self.sp_session.get_current_database()
if database is None:
raise CortexSearchRetrieverError("Snowflake database not set on session")
return str(database)

@property
def _schema(self) -> str:
"""The Snowflake schema containing the Cortex Search Service."""
if self.snowflake_schema is not None:
return self.snowflake_schema
schema = self.sp_session.get_current_schema()
if schema is None:
raise CortexSearchRetrieverError("Snowflake schema not set on session")
return str(schema)

@property
def _default_params(self) -> Dict[str, Any]:
"""Default query parameters for the Cortex Search Service retriever. Can be
Expand All @@ -279,8 +332,8 @@ def _get_relevant_documents(
) -> List[Document]:
try:
response = (
self._sp_root.databases[self.snowflake_database]
.schemas[self.snowflake_schema]
self._sp_root.databases[self._database]
.schemas[self._schema]
.cortex_search_services[self.cortex_search_service]
.search(
query=str(query),
Expand Down
2 changes: 1 addition & 1 deletion libs/snowflake/pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "langchain-snowflake"
version = "0.1.1"
version = "0.1.2"
description = "An integration package connecting Snowflake and LangChain"
authors = []
readme = "README.md"
Expand Down
142 changes: 128 additions & 14 deletions libs/snowflake/tests/integration_tests/test_retrievers.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,13 @@
"""

import os
from typing import List
from unittest import mock

import pytest
from langchain_core.documents import Document
from pydantic import ValidationError
from snowflake.snowpark import Session

from langchain_snowflake import CortexSearchRetriever, CortexSearchRetrieverError

Expand All @@ -46,12 +48,7 @@ def test_snowflake_cortex_search_invoke() -> None:
assert len(documents) > 0

for doc in documents:
assert isinstance(doc, Document)
assert doc.page_content
for column in columns:
if column == search_column:
continue
assert column in doc.metadata
check_document(doc, columns, search_column)
# Validate the filter was passed through correctly
assert doc.metadata["era"] == "Jurassic"

Expand Down Expand Up @@ -193,18 +190,119 @@ def test_snowflake_cortex_search_invoke_columns() -> None:

retriever = CortexSearchRetriever(**kwargs)

documents = retriever.invoke("dinosaur with a large tail", columns=["description"])
override_columns = ["description"]
documents = retriever.invoke("dinosaur with a large tail", columns=override_columns)

assert len(documents) == 10

for doc in documents:
assert isinstance(doc, Document)
assert doc.page_content
check_document(doc, override_columns, search_column)
assert "era" not in doc.metadata


@pytest.mark.requires("snowflake.core")
def test_snowflake_cortex_search_session_auth() -> None:
"""Test authentication with a provided `snowlfake.snowpark.Session object`."""

columns = ["description", "era"]
search_column = "description"

kwargs = {
"search_service": "dinosaur_svc",
"columns": columns,
"search_column": search_column,
"limit": 10,
}

session_config = {
"account": os.environ["SNOWFLAKE_ACCOUNT"],
"user": os.environ["SNOWFLAKE_USERNAME"],
"password": os.environ["SNOWFLAKE_PASSWORD"],
"database": os.environ["SNOWFLAKE_DATABASE"],
"schema": os.environ["SNOWFLAKE_SCHEMA"],
"role": os.environ["SNOWFLAKE_ROLE"],
}

session = Session.builder.configs(session_config).create()

retriever = CortexSearchRetriever(sp_session=session, **kwargs)

documents = retriever.invoke("dinosaur with a large tail")
assert len(documents) > 0

for doc in documents:
check_document(doc, columns, search_column)


@pytest.mark.requires("snowflake.core")
def test_snowflake_cortex_search_session_auth_validation_error() -> None:
"""Test validation errors when both a `snowlfake.snowpark.Session object` and
another authentication paramter are provided."""

columns = ["name", "description", "era", "diet"]
search_column = "description"
kwargs = {
"search_service": "dinosaur_svc",
"columns": columns,
"search_column": search_column,
"limit": 10,
}

session_config = {
"account": os.environ["SNOWFLAKE_ACCOUNT"],
"user": os.environ["SNOWFLAKE_USERNAME"],
"password": os.environ["SNOWFLAKE_PASSWORD"],
"database": os.environ["SNOWFLAKE_DATABASE"],
"schema": os.environ["SNOWFLAKE_SCHEMA"],
"role": os.environ["SNOWFLAKE_ROLE"],
}

session = Session.builder.configs(session_config).create()

for param in ["account", "user", "password", "role", "authenticator"]:
with pytest.raises(CortexSearchRetrieverError):
kwargs[param] = "fake_value"
CortexSearchRetriever(sp_session=session, **kwargs)
del kwargs[param]


@pytest.mark.requires("snowflake.core")
def test_snowflake_cortex_search_session_auth_overrides() -> None:
"""Test overrides to the provided `snowlfake.snowpark.Session object`."""

columns = ["name", "description", "era", "diet"]
search_column = "description"
kwargs = {
"search_service": "dinosaur_svc",
"columns": columns,
"search_column": search_column,
"limit": 10,
}

session_config = {
"account": os.environ["SNOWFLAKE_ACCOUNT"],
"user": os.environ["SNOWFLAKE_USERNAME"],
"password": os.environ["SNOWFLAKE_PASSWORD"],
"database": os.environ["SNOWFLAKE_DATABASE"],
"schema": os.environ["SNOWFLAKE_SCHEMA"],
"role": os.environ["SNOWFLAKE_ROLE"],
}

for param in ["database", "schema"]:
session_config_copy = session_config.copy()
del session_config_copy[param]
session = Session.builder.configs(session_config_copy).create()

retriever = CortexSearchRetriever(sp_session=session, **kwargs)

documents = retriever.invoke("dinosaur with a large tail")
assert len(documents) > 0

check_documents(documents, columns, search_column)


@pytest.mark.skip(
"""This test requires a Snowflake account with externalbrowser authentication
"""This test requires a Snowflake account with externalbrowser authentication
enabled."""
)
@pytest.mark.requires("snowflake.core")
Expand All @@ -213,11 +311,11 @@ def test_snowflake_cortex_search_constructor_externalbrowser_authenticator() ->
"""Test the constructor with external browser authenticator."""

columns = ["name", "description", "era", "diet"]

search_column = "description"
kwargs = {
"search_service": "dinosaur_svc",
"columns": columns,
"search_column": "description",
"search_column": search_column,
"limit": 10,
"authenticator": "externalbrowser",
}
Expand All @@ -226,6 +324,22 @@ def test_snowflake_cortex_search_constructor_externalbrowser_authenticator() ->

documents = retriever.invoke("dinosaur with a large tail")
assert len(documents) > 0
check_documents(documents, columns, search_column)


def check_document(doc: Document, columns: List[str], search_column: str) -> None:
"""Check the document returned by the retriever."""
assert isinstance(doc, Document)
assert doc.page_content
for column in columns:
if column == search_column:
continue
assert column in doc.metadata


def check_documents(
documents: List[Document], columns: List[str], search_column: str
) -> None:
"""Check the documents returned by the retriever."""
for doc in documents:
assert isinstance(doc, Document)
assert doc.page_content
check_document(doc, columns, search_column)

0 comments on commit dfeba3b

Please sign in to comment.