Skip to content

Commit d8b9400

Browse files
azibekmdrxyCopilot
authored
fix(huggingface): pass llm params to ChatHuggingFace (#32368)
This PR fixes #32234 and improves the HuggingFace chat model integration by: (1) ensuring ChatHuggingFace inherits key parameters (temperature, max_tokens, top_p, streaming, etc.) from the underlying LLM when they are not explicitly set; and (2) adding and updating unit tests to verify property inheritance. There are no breaking changes; these updates enhance reliability and maintainability. --------- Co-authored-by: Mason Daugherty <mason@langchain.dev> Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> Co-authored-by: Mason Daugherty <github@mdrxy.com>
1 parent cf595dc commit d8b9400

2 files changed

Lines changed: 165 additions & 0 deletions

File tree

libs/partners/huggingface/langchain_huggingface/chat_models/huggingface.py

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -513,8 +513,57 @@ class GetPopulation(BaseModel):
513513

514514
def __init__(self, **kwargs: Any):
    """Initialize the chat model, then backfill settings from the wrapped LLM.

    After the base initialization, properties the caller did not set are
    inherited from ``self.llm`` (see ``_inherit_llm_properties``), and the
    model id is resolved last.
    """
    super().__init__(**kwargs)
    # Inherit properties from the LLM if they weren't explicitly set
    self._inherit_llm_properties()
    self._resolve_model_id()

522+
def _inherit_llm_properties(self) -> None:
    """Inherit properties from the wrapped LLM instance if not explicitly set.

    Backfills generation parameters (temperature, max_tokens, top_p, seed,
    streaming, stop) on this chat model from ``self.llm`` so that callers
    who configured only the underlying LLM get consistent behavior.
    """
    # Nothing to inherit from: attribute absent or no LLM attached.
    if not hasattr(self, "llm") or self.llm is None:
        return

    # Map of ChatHuggingFace properties to LLM properties
    property_mappings = {
        "temperature": "temperature",
        "max_tokens": "max_new_tokens",  # Different naming convention
        "top_p": "top_p",
        "seed": "seed",
        "streaming": "streaming",
        "stop": "stop_sequences",
    }

    # Inherit properties from LLM and not explicitly set here.
    # NOTE(review): this loop uses truthiness, so an explicit falsy setting
    # on the chat model (e.g. temperature=0.0 or streaming=False) is
    # overwritten by the LLM's value, and a falsy LLM value is never
    # propagated. The endpoint branch below uses `is None` checks instead —
    # confirm the asymmetry is intended.
    for chat_prop, llm_prop in property_mappings.items():
        if hasattr(self.llm, llm_prop):
            llm_value = getattr(self.llm, llm_prop)
            chat_value = getattr(self, chat_prop, None)
            if not chat_value and llm_value:
                setattr(self, chat_prop, llm_value)

    # Handle special cases for HuggingFaceEndpoint
    if _is_huggingface_endpoint(self.llm):
        # Inherit additional HuggingFaceEndpoint specific properties
        endpoint_mappings = {
            "frequency_penalty": "repetition_penalty",
        }

        for chat_prop, llm_prop in endpoint_mappings.items():
            if hasattr(self.llm, llm_prop):
                llm_value = getattr(self.llm, llm_prop)
                chat_value = getattr(self, chat_prop, None)
                # Strict None-checks: only fill values that are truly unset.
                if chat_value is None and llm_value is not None:
                    setattr(self, chat_prop, llm_value)

    # Inherit model_kwargs if not explicitly set; .copy() so later mutation
    # of one object's kwargs does not leak into the other.
    if (
        not self.model_kwargs
        and hasattr(self.llm, "model_kwargs")
        and isinstance(self.llm.model_kwargs, dict)
    ):
        self.model_kwargs = self.llm.model_kwargs.copy()
518567
@model_validator(mode="after")
519568
def validate_llm(self) -> Self:
520569
if (

libs/partners/huggingface/tests/unit_tests/test_chat_models.py

Lines changed: 116 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,17 @@
2323
def mock_llm() -> Mock:
    """Build a HuggingFaceEndpoint mock carrying every inheritable parameter.

    The attribute values are the ones the inheritance tests assert against.
    """
    llm = Mock(spec=HuggingFaceEndpoint)
    attrs = {
        "inference_server_url": "test endpoint url",
        "temperature": 0.7,
        "max_new_tokens": 512,
        "top_p": 0.9,
        "seed": 42,
        "streaming": True,
        "repetition_penalty": 1.1,
        "stop_sequences": ["</s>", "<|end|>"],
        "model_kwargs": {"do_sample": True, "top_k": 50},
        "server_kwargs": {"timeout": 120},
        "repo_id": "test/model",
        "model": "test/model",
    }
    for name, value in attrs.items():
        setattr(llm, name, value)
    return llm

2839

@@ -209,3 +220,108 @@ def test_bind_tools(chat_hugging_face: Any) -> None:
209220
_, kwargs = mock_super_bind.call_args
210221
assert kwargs["tools"] == tools
211222
assert kwargs["tool_choice"] == "auto"
223+
224+
225+
def test_property_inheritance_integration(chat_hugging_face: Any) -> None:
    """ChatHuggingFace exposes the generation params of its wrapped LLM."""
    inherited = {
        "temperature": 0.7,
        "max_tokens": 512,
        "top_p": 0.9,
    }
    for name, expected in inherited.items():
        assert getattr(chat_hugging_face, name, None) == expected
    # streaming must be the actual True object, not merely truthy
    assert getattr(chat_hugging_face, "streaming", None) is True
233+
def test_default_params_includes_inherited_values(chat_hugging_face: Any) -> None:
    """_default_params reflects values inherited from the wrapped LLM."""
    defaults = chat_hugging_face._default_params
    # The LLM's max_new_tokens surfaces here under the max_tokens key.
    assert defaults["max_tokens"] == 512
    assert defaults["temperature"] == 0.7
    # The LLM's streaming flag maps onto the OpenAI-style "stream" key.
    assert defaults["stream"] is True
241+
def test_create_message_dicts_includes_inherited_params(chat_hugging_face: Any) -> None:
    """Inherited parameters flow into the payload built for the API call."""
    history = [HumanMessage(content="test message")]
    message_dicts, params = chat_hugging_face._create_message_dicts(history, None)

    # Inherited generation parameters appear in the request params.
    assert params["max_tokens"] == 512
    assert params["temperature"] == 0.7
    assert params["stream"] is True

    # The single human message converts to one OpenAI-style user dict.
    assert len(message_dicts) == 1
    first = message_dicts[0]
    assert first["role"] == "user"
    assert first["content"] == "test message"
257+
def test_model_kwargs_inheritance(mock_llm: Any) -> None:
    """model_kwargs are copied from the LLM when not set on the chat model."""
    resolver = patch(
        "langchain_huggingface.chat_models.huggingface.ChatHuggingFace._resolve_model_id"
    )
    with resolver:
        chat = ChatHuggingFace(llm=mock_llm)
        assert chat.model_kwargs == {"do_sample": True, "top_k": 50}
266+
def test_huggingface_endpoint_specific_inheritance(mock_llm: Any) -> None:
    """Endpoint-only parameters map onto their chat-model equivalents."""
    resolver = patch(
        "langchain_huggingface.chat_models.huggingface.ChatHuggingFace._resolve_model_id"
    )
    endpoint_check = patch(
        "langchain_huggingface.chat_models.huggingface._is_huggingface_endpoint",
        return_value=True,
    )
    with resolver, endpoint_check:
        chat = ChatHuggingFace(llm=mock_llm)
        # repetition_penalty=1.1 on the endpoint surfaces as frequency_penalty.
        assert getattr(chat, "frequency_penalty", None) == 1.1
283+
def test_parameter_precedence_explicit_over_inherited(mock_llm: Any) -> None:
    """Constructor arguments win over values carried by the wrapped LLM."""
    resolver = patch(
        "langchain_huggingface.chat_models.huggingface.ChatHuggingFace._resolve_model_id"
    )
    with resolver:
        chat = ChatHuggingFace(llm=mock_llm, max_tokens=256, temperature=0.5)
        # Explicit 256 / 0.5 beat the LLM's 512 / 0.7.
        assert chat.max_tokens == 256
        assert chat.temperature == 0.5
294+
def test_inheritance_with_no_llm_properties(mock_llm: Any) -> None:
    """Missing LLM attributes are skipped; the available ones still inherit."""
    # Strip two attributes so the mock no longer advertises them.
    del mock_llm.temperature
    del mock_llm.top_p

    resolver = patch(
        "langchain_huggingface.chat_models.huggingface.ChatHuggingFace._resolve_model_id"
    )
    with resolver:
        chat = ChatHuggingFace(llm=mock_llm)
        # max_new_tokens survived on the mock, so max_tokens still inherits.
        assert chat.max_tokens == 512
        # The deleted attributes stay at their None/default values.
        assert getattr(chat, "temperature", None) is None
        assert getattr(chat, "top_p", None) is None
311+
def test_inheritance_with_empty_llm() -> None:
    """An LLM without generation attributes leaves the defaults untouched."""
    resolver = patch(
        "langchain_huggingface.chat_models.huggingface.ChatHuggingFace._resolve_model_id"
    )
    with resolver:
        # Minimal mock: passes validation but carries no inheritable params
        # (a spec'd Mock has no concrete attribute values by default).
        bare_llm = Mock(spec=HuggingFaceEndpoint)
        bare_llm.repo_id = "test/model"
        bare_llm.model = "test/model"

        chat = ChatHuggingFace(llm=bare_llm)
        # With nothing to inherit, the chat model keeps its defaults.
        assert chat.max_tokens is None
        assert chat.temperature is None

0 commit comments

Comments
 (0)