Skip to content
Open
6 changes: 6 additions & 0 deletions libs/langchain/langchain_classic/embeddings/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
"azure_openai": "langchain_openai",
"bedrock": "langchain_aws",
"cohere": "langchain_cohere",
"google_genai": "langchain_google_genai",
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would also expect this to be applied to libs/langchain_v1/langchain/embeddings/base.py

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good catch. The changes have now been applied to the langchain_v1 lib as well.

"google_vertexai": "langchain_google_vertexai",
"huggingface": "langchain_huggingface",
"mistralai": "langchain_mistralai",
Expand Down Expand Up @@ -155,6 +156,7 @@ def init_embeddings(
- `azure_openai` -> [`langchain-openai`](https://docs.langchain.com/oss/python/integrations/providers/openai)
- `bedrock` -> [`langchain-aws`](https://docs.langchain.com/oss/python/integrations/providers/aws)
- `cohere` -> [`langchain-cohere`](https://docs.langchain.com/oss/python/integrations/providers/cohere)
- `google_genai` -> [`langchain-google-genai`](https://docs.langchain.com/oss/python/integrations/providers/google)
- `google_vertexai` -> [`langchain-google-vertexai`](https://docs.langchain.com/oss/python/integrations/providers/google)
- `huggingface` -> [`langchain-huggingface`](https://docs.langchain.com/oss/python/integrations/providers/huggingface)
- `mistralai` -> [`langchain-mistralai`](https://docs.langchain.com/oss/python/integrations/providers/mistralai)
Expand Down Expand Up @@ -207,6 +209,10 @@ def init_embeddings(
from langchain_openai import AzureOpenAIEmbeddings

return AzureOpenAIEmbeddings(model=model_name, **kwargs)
if provider == "google_genai":
from langchain_google_genai import GoogleGenerativeAIEmbeddings

return GoogleGenerativeAIEmbeddings(model=model_name, **kwargs)
if provider == "google_vertexai":
from langchain_google_vertexai import VertexAIEmbeddings

Expand Down
27 changes: 15 additions & 12 deletions libs/langchain/tests/unit_tests/embeddings/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,19 +9,22 @@
)


def test_parse_model_string() -> None:
@pytest.mark.parametrize(
("model_string", "expected_provider", "expected_model"),
[
("openai:text-embedding-3-small", "openai", "text-embedding-3-small"),
("bedrock:amazon.titan-embed-text-v1", "bedrock", "amazon.titan-embed-text-v1"),
("huggingface:BAAI/bge-base-en:v1.5", "huggingface", "BAAI/bge-base-en:v1.5"),
("google_genai:gemini-embedding-001", "google_genai", "gemini-embedding-001"),
],
)
def test_parse_model_string(
model_string: str, expected_provider: str, expected_model: str
) -> None:
"""Test parsing model strings into provider and model components."""
assert _parse_model_string("openai:text-embedding-3-small") == (
"openai",
"text-embedding-3-small",
)
assert _parse_model_string("bedrock:amazon.titan-embed-text-v1") == (
"bedrock",
"amazon.titan-embed-text-v1",
)
assert _parse_model_string("huggingface:BAAI/bge-base-en:v1.5") == (
"huggingface",
"BAAI/bge-base-en:v1.5",
assert _parse_model_string(model_string) == (
expected_provider,
expected_model,
)


Expand Down
5 changes: 5 additions & 0 deletions libs/langchain_v1/langchain/embeddings/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
"azure_openai": "langchain_openai",
"bedrock": "langchain_aws",
"cohere": "langchain_cohere",
"google_genai": "langchain_google_genai",
"google_vertexai": "langchain_google_vertexai",
"huggingface": "langchain_huggingface",
"mistralai": "langchain_mistralai",
Expand Down Expand Up @@ -207,6 +208,10 @@ def init_embeddings(
from langchain_openai import AzureOpenAIEmbeddings

return AzureOpenAIEmbeddings(model=model_name, **kwargs)
if provider == "google_genai":
from langchain_google_genai import GoogleGenerativeAIEmbeddings

return GoogleGenerativeAIEmbeddings(model=model_name, **kwargs)
if provider == "google_vertexai":
from langchain_google_vertexai import VertexAIEmbeddings

Expand Down
25 changes: 13 additions & 12 deletions libs/langchain_v1/tests/unit_tests/embeddings/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,19 +9,20 @@
)


def test_parse_model_string() -> None:
@pytest.mark.parametrize(
("model_string", "expected_provider", "expected_model"),
[
("openai:text-embedding-3-small", "openai", "text-embedding-3-small"),
("bedrock:amazon.titan-embed-text-v1", "bedrock", "amazon.titan-embed-text-v1"),
("huggingface:BAAI/bge-base-en:v1.5", "huggingface", "BAAI/bge-base-en:v1.5"),
("google_genai:gemini-embedding-001", "google_genai", "gemini-embedding-001"),
],
)
def test_parse_model_string(model_string: str, expected_provider: str, expected_model: str) -> None:
"""Test parsing model strings into provider and model components."""
assert _parse_model_string("openai:text-embedding-3-small") == (
"openai",
"text-embedding-3-small",
)
assert _parse_model_string("bedrock:amazon.titan-embed-text-v1") == (
"bedrock",
"amazon.titan-embed-text-v1",
)
assert _parse_model_string("huggingface:BAAI/bge-base-en:v1.5") == (
"huggingface",
"BAAI/bge-base-en:v1.5",
assert _parse_model_string(model_string) == (
expected_provider,
expected_model,
)


Expand Down