Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/huggingface_hub/inference/_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,7 @@ class InferenceClient:
Note: for better compatibility with OpenAI's client, `model` has been aliased as `base_url`. Those 2
arguments are mutually exclusive. If a URL is passed as `model` or `base_url` for chat completion, the `(/v1)/chat/completions` suffix path will be appended to the URL.
provider (`str`, *optional*):
Name of the provider to use for inference. Can be `"black-forest-labs"`, `"cerebras"`, `"clarifai"`, `"cohere"`, `"fal-ai"`, `"featherless-ai"`, `"fireworks-ai"`, `"groq"`, `"hf-inference"`, `"hyperbolic"`, `"nebius"`, `"novita"`, `"nscale"`, `"nvidia"`, `"openai"`, `"ovhcloud"`, `"publicai"`, `"replicate"`, `"sambanova"`, `"scaleway"`, `"together"`, `"wavespeed"` or `"zai-org"`.
Name of the provider to use for inference. Can be `"black-forest-labs"`, `"cerebras"`, `"clarifai"`, `"cohere"`, `"fal-ai"`, `"featherless-ai"`, `"fireworks-ai"`, `"groq"`, `"hf-inference"`, `"hyperbolic"`, `"nebius"`, `"novita"`, `"nscale"`, `"nvidia"`, `"openai"`, `"ovhcloud"`, `"publicai"`, `"replicate"`, `"sambanova"`, `"scaleway"`, `"textclf"`, `"together"`, `"wavespeed"` or `"zai-org"`.
Defaults to "auto" i.e. the first of the providers available for the model, sorted by the user's order in https://hf.co/settings/inference-providers.
If model is a URL or `base_url` is passed, then `provider` is not used.
token (`str`, *optional*):
Expand Down
2 changes: 1 addition & 1 deletion src/huggingface_hub/inference/_generated/_async_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,7 @@ class AsyncInferenceClient:
Note: for better compatibility with OpenAI's client, `model` has been aliased as `base_url`. Those 2
arguments are mutually exclusive. If a URL is passed as `model` or `base_url` for chat completion, the `(/v1)/chat/completions` suffix path will be appended to the URL.
provider (`str`, *optional*):
Name of the provider to use for inference. Can be `"black-forest-labs"`, `"cerebras"`, `"clarifai"`, `"cohere"`, `"fal-ai"`, `"featherless-ai"`, `"fireworks-ai"`, `"groq"`, `"hf-inference"`, `"hyperbolic"`, `"nebius"`, `"novita"`, `"nscale"`, `"nvidia"`, `"openai"`, `"ovhcloud"`, `"publicai"`, `"replicate"`, `"sambanova"`, `"scaleway"`, `"together"`, `"wavespeed"` or `"zai-org"`.
Name of the provider to use for inference. Can be `"black-forest-labs"`, `"cerebras"`, `"clarifai"`, `"cohere"`, `"fal-ai"`, `"featherless-ai"`, `"fireworks-ai"`, `"groq"`, `"hf-inference"`, `"hyperbolic"`, `"nebius"`, `"novita"`, `"nscale"`, `"nvidia"`, `"openai"`, `"ovhcloud"`, `"publicai"`, `"replicate"`, `"sambanova"`, `"scaleway"`, `"textclf"`, `"together"`, `"wavespeed"` or `"zai-org"`.
Defaults to "auto" i.e. the first of the providers available for the model, sorted by the user's order in https://hf.co/settings/inference-providers.
If model is a URL or `base_url` is passed, then `provider` is not used.
token (`str`, *optional*):
Expand Down
6 changes: 6 additions & 0 deletions src/huggingface_hub/inference/_providers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@
)
from .sambanova import SambanovaConversationalTask, SambanovaFeatureExtractionTask
from .scaleway import ScalewayConversationalTask, ScalewayFeatureExtractionTask
from .textclf import TextCLFConversationalTask, TextCLFTextGenerationTask
from .together import TogetherConversationalTask, TogetherTextGenerationTask, TogetherTextToImageTask
from .wavespeed import (
WavespeedAIImageToImageTask,
Expand Down Expand Up @@ -84,6 +85,7 @@
"replicate",
"sambanova",
"scaleway",
"textclf",
"together",
"wavespeed",
"zai-org",
Expand Down Expand Up @@ -200,6 +202,10 @@
"conversational": ScalewayConversationalTask(),
"feature-extraction": ScalewayFeatureExtractionTask(),
},
"textclf": {
"text-generation": TextCLFTextGenerationTask(),
"conversational": TextCLFConversationalTask(),
},
"together": {
"text-to-image": TogetherTextToImageTask(),
"conversational": TogetherConversationalTask(),
Expand Down
1 change: 1 addition & 0 deletions src/huggingface_hub/inference/_providers/_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
"replicate": {},
"sambanova": {},
"scaleway": {},
"textclf": {},
"together": {},
"wavespeed": {},
"zai-org": {},
Expand Down
36 changes: 36 additions & 0 deletions src/huggingface_hub/inference/_providers/textclf.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
from typing import Any, Optional, Union

from huggingface_hub.inference._common import RequestParameters, _as_dict
from huggingface_hub.inference._providers._common import (
BaseConversationalTask,
BaseTextGenerationTask,
)

_PROVIDER = "textclf"
_BASE_URL = "https://api.textclf.com"


class TextCLFTextGenerationTask(BaseTextGenerationTask):
    """`text-generation` task handler for the TextCLF provider.

    Routes requests to TextCLF's OpenAI-compatible chat-completions endpoint and
    converts the provider payload back into the `text_generation` output format.
    """

    def __init__(self):
        super().__init__(provider=_PROVIDER, base_url=_BASE_URL)

    def _prepare_route(self, mapped_model: str, api_key: str) -> str:
        # Fixed route for every model/key: TextCLF only exposes the
        # OpenAI-compatible chat-completions endpoint.
        return "/v1/chat/completions"

    def get_response(self, response: Union[bytes, dict], request_params: Optional[RequestParameters] = None) -> Any:
        """Map the provider response to a `text_generation`-style dict.

        The route is an OpenAI-style chat-completions endpoint, where generated
        text normally lives under `choices[0].message.content`; the original code
        read the completions-style `choices[0].text` field, which would raise
        `KeyError` on a standard chat-completions payload. Prefer `text` when
        present (preserves existing behavior), fall back to `message.content`.
        NOTE(review): confirm which payload shape the TextCLF API actually returns.
        """
        output = _as_dict(response)["choices"][0]
        if "text" in output:
            generated_text = output["text"]
        else:
            generated_text = output.get("message", {}).get("content")
        return {
            "generated_text": generated_text,
            "details": {
                "finish_reason": output.get("finish_reason"),
                "seed": output.get("seed"),
            },
        }


class TextCLFConversationalTask(BaseConversationalTask):
    """`conversational` (chat-completion) task handler for the TextCLF provider."""

    # Fixed OpenAI-compatible endpoint; independent of model and API key.
    _ROUTE = "/v1/chat/completions"

    def __init__(self):
        super().__init__(provider=_PROVIDER, base_url=_BASE_URL)

    def _prepare_route(self, mapped_model: str, api_key: str) -> str:
        return self._ROUTE
4 changes: 4 additions & 0 deletions tests/test_inference_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,10 @@
"sambanova": {
"conversational": "meta-llama/Llama-3.1-8B-Instruct",
},
"textclf": {
"text-generation": "meta-llama/Llama-3.1-8B-Instruct",
"conversational": "meta-llama/Llama-3.1-8B-Instruct",
},
}

CHAT_COMPLETION_MODEL = "HuggingFaceH4/zephyr-7b-beta"
Expand Down
11 changes: 11 additions & 0 deletions tests/test_inference_providers.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@
)
from huggingface_hub.inference._providers.sambanova import SambanovaConversationalTask, SambanovaFeatureExtractionTask
from huggingface_hub.inference._providers.scaleway import ScalewayConversationalTask, ScalewayFeatureExtractionTask
from huggingface_hub.inference._providers.textclf import TextCLFConversationalTask, TextCLFTextGenerationTask
from huggingface_hub.inference._providers.together import TogetherTextToImageTask
from huggingface_hub.inference._providers.wavespeed import (
WavespeedAIImageToImageTask,
Expand Down Expand Up @@ -1728,6 +1729,16 @@ def test_prepare_url_feature_extraction(self):
== "https://router.huggingface.co/sambanova/v1/embeddings"
)

class TestTextCLFProvider:
    # Both tasks resolve to the same fixed chat-completions endpoint.
    EXPECTED_URL = "https://api.textclf.com/v1/chat/completions"

    def test_prepare_url_text_generation(self):
        task = TextCLFTextGenerationTask()
        assert task._prepare_url("textclf_token", "username/repo_name") == self.EXPECTED_URL

    def test_prepare_url_conversational(self):
        task = TextCLFConversationalTask()
        assert task._prepare_url("textclf_token", "username/repo_name") == self.EXPECTED_URL

class TestTogetherProvider:
def test_prepare_route_text_to_image(self):
Expand Down