Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,15 @@

logger = logging.getLogger(__name__)


def _is_huggingface_hosted_url(url: str | None) -> bool:
"""True if url is HF-hosted (huggingface.co or hf.space)."""
if not url:
return False
url_lower = url.lower().strip()
return "huggingface.co" in url_lower or "hf.space" in url_lower


VALID_TASKS = (
"text2text-generation",
"text-generation",
Expand Down Expand Up @@ -234,6 +243,11 @@ def validate_environment(self) -> Self:
huggingfacehub_api_token = self.huggingfacehub_api_token or os.getenv(
"HF_TOKEN"
)
# Local/custom endpoint URL -> don't pass HF token (avoids 401s and egress).
if self.endpoint_url and not _is_huggingface_hosted_url(self.endpoint_url):
client_api_key: str | None = None
else:
client_api_key = huggingfacehub_api_token

from huggingface_hub import ( # type: ignore[import]
AsyncInferenceClient, # type: ignore[import]
Expand All @@ -245,7 +259,7 @@ def validate_environment(self) -> Self:
self.client = InferenceClient(
model=self.model,
timeout=self.timeout,
api_key=huggingfacehub_api_token,
api_key=client_api_key,
provider=self.provider, # type: ignore[arg-type]
**{
key: value
Expand All @@ -258,7 +272,7 @@ def validate_environment(self) -> Self:
self.async_client = AsyncInferenceClient(
model=self.model,
timeout=self.timeout,
api_key=huggingfacehub_api_token,
api_key=client_api_key,
provider=self.provider, # type: ignore[arg-type]
**{
key: value
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
"""Tests for HuggingFaceEndpoint with local/custom endpoint_url (no HF API calls)."""

from unittest.mock import MagicMock, patch

import pytest

from langchain_huggingface.llms.huggingface_endpoint import (
HuggingFaceEndpoint,
_is_huggingface_hosted_url,
)


@pytest.mark.parametrize(
    ("url", "expected"),
    [
        # Missing or empty values are never HF-hosted.
        (None, False),
        ("", False),
        # Local and self-hosted endpoints.
        ("http://localhost:8010/", False),
        ("http://127.0.0.1:8080", False),
        ("http://my-tgi.internal/", False),
        # Third-party hosted endpoint.
        ("https://api.inference-api.azure-api.net/", False),
        # Hugging Face domains.
        ("https://abc.huggingface.co/inference", True),
        ("https://xyz.hf.space/", True),
    ],
)
def test_is_huggingface_hosted_url(
    url: str | None,
    expected: bool,  # noqa: FBT001
) -> None:
    """The URL helper separates HF-hosted endpoints from local/custom ones."""
    is_hosted = _is_huggingface_hosted_url(url)
    assert is_hosted is expected


@patch("huggingface_hub.AsyncInferenceClient")
@patch("huggingface_hub.InferenceClient")
def test_local_endpoint_does_not_pass_api_key(
    mock_inference_client: MagicMock,
    mock_async_client: MagicMock,
) -> None:
    """A local endpoint_url must not receive the HF token, even when one is set.

    A token is supplied explicitly so the test fails if the token is ever
    forwarded to a non-HF host; without it, ``api_key`` would be ``None``
    whenever ``HF_TOKEN`` happens to be unset, and the test would prove nothing.
    """
    mock_inference_client.return_value = MagicMock()
    mock_async_client.return_value = MagicMock()

    HuggingFaceEndpoint(  # type: ignore[call-arg]
        endpoint_url="http://localhost:8010/",
        max_new_tokens=64,
        huggingfacehub_api_token="hf_xxx",  # noqa: S106
    )

    # Sync client: built once, pointed at the local URL, without credentials.
    mock_inference_client.assert_called_once()
    call_kwargs = mock_inference_client.call_args[1]
    assert call_kwargs.get("api_key") is None
    assert call_kwargs.get("model") == "http://localhost:8010/"

    # Async client: same guarantee.
    mock_async_client.assert_called_once()
    async_call_kwargs = mock_async_client.call_args[1]
    assert async_call_kwargs.get("api_key") is None


@patch("huggingface_hub.AsyncInferenceClient")
@patch("huggingface_hub.InferenceClient")
def test_huggingface_hosted_endpoint_keeps_api_key(
    mock_inference_client: MagicMock,
    mock_async_client: MagicMock,
) -> None:
    """An HF-hosted endpoint_url still gets the token on both clients."""
    mock_inference_client.return_value = MagicMock()
    mock_async_client.return_value = MagicMock()

    HuggingFaceEndpoint(  # type: ignore[call-arg]
        endpoint_url="https://abc.huggingface.co/inference",
        max_new_tokens=64,
        huggingfacehub_api_token="hf_xxx",  # noqa: S106
    )

    # Sync client receives the token.
    mock_inference_client.assert_called_once()
    call_kwargs = mock_inference_client.call_args[1]
    assert call_kwargs.get("api_key") == "hf_xxx"

    # Async client is checked too, mirroring the local-endpoint test.
    mock_async_client.assert_called_once()
    async_call_kwargs = mock_async_client.call_args[1]
    assert async_call_kwargs.get("api_key") == "hf_xxx"