Skip to content

Commit b31ffee

Browse files
jxnl and cursoragent authored
GenAI config labels loss (#2005)
Co-authored-by: Cursor Agent <cursoragent@cursor.com>
1 parent 3b86812 commit b31ffee

File tree

2 files changed

+151
-14
lines changed

2 files changed

+151
-14
lines changed

instructor/providers/gemini/utils.py

Lines changed: 30 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -330,11 +330,13 @@ def update_genai_kwargs(
330330
}
331331
)
332332

333-
# Extract thinking_config from user's config object if provided
334-
# This ensures thinking_config inside config parameter is not ignored
333+
# Extract thinking_config from user's config if provided (dict or object)
334+
# This ensures thinking_config inside config parameter is not ignored.
335335
user_config = new_kwargs.get("config")
336336
user_thinking_config = None
337-
if user_config is not None and hasattr(user_config, "thinking_config"):
337+
if isinstance(user_config, dict):
338+
user_thinking_config = user_config.get("thinking_config")
339+
elif user_config is not None and hasattr(user_config, "thinking_config"):
338340
user_thinking_config = user_config.thinking_config
339341

340342
# Handle thinking_config parameter - prioritize kwarg over config.thinking_config
@@ -345,19 +347,25 @@ def update_genai_kwargs(
345347
if thinking_config is not None:
346348
base_config["thinking_config"] = thinking_config
347349

348-
# Extract other relevant fields from user's config object
349-
# This ensures fields like automatic_function_calling are not ignored
350+
# Extract other relevant fields from user's config (dict or object).
351+
# This ensures fields like automatic_function_calling / labels / cached_content
352+
# are not ignored when config is passed as a dict.
350353
if user_config is not None:
351354
config_fields_to_merge = [
352355
"automatic_function_calling",
353356
"labels",
354357
"cached_content",
355358
]
356359
for field in config_fields_to_merge:
357-
if hasattr(user_config, field):
360+
if isinstance(user_config, dict):
361+
field_value = user_config.get(field)
362+
elif hasattr(user_config, field):
358363
field_value = getattr(user_config, field)
359-
if field_value is not None and field not in base_config:
360-
base_config[field] = field_value
364+
else:
365+
field_value = None
366+
367+
if field_value is not None and field not in base_config:
368+
base_config[field] = field_value
361369

362370
return base_config
363371

@@ -882,12 +890,16 @@ def handle_genai_structured_outputs(
882890
if new_kwargs.get("stream", False) and not issubclass(response_model, PartialBase):
883891
response_model = Partial[response_model]
884892

885-
# Extract thinking_config from user-provided config object if present
886-
# This fixes issue #1966 where thinking_config inside config was ignored
893+
# Extract thinking_config and cached_content from user-provided config (dict or object).
894+
# This fixes issue #1966 (thinking_config ignored) and ensures cached_content
895+
# is detected even when config is provided as a dict.
887896
user_config = new_kwargs.get("config")
888897
user_thinking_config = None
889898
user_cached_content = None
890-
if user_config is not None:
899+
if isinstance(user_config, dict):
900+
user_thinking_config = user_config.get("thinking_config")
901+
user_cached_content = user_config.get("cached_content")
902+
elif user_config is not None:
891903
if hasattr(user_config, "thinking_config"):
892904
user_thinking_config = user_config.thinking_config
893905
if hasattr(user_config, "cached_content"):
@@ -965,12 +977,16 @@ def handle_genai_tools(
965977
if new_kwargs.get("stream", False) and not issubclass(response_model, PartialBase):
966978
response_model = Partial[response_model]
967979

968-
# Extract thinking_config and cached_content from user-provided config object if present
969-
# This fixes issue #1966 where thinking_config inside config was ignored
980+
# Extract thinking_config and cached_content from user-provided config (dict or object).
981+
# This fixes issue #1966 (thinking_config ignored) and ensures cached_content
982+
# is detected even when config is provided as a dict.
970983
user_config = new_kwargs.get("config")
971984
user_thinking_config = None
972985
user_cached_content = None
973-
if user_config is not None:
986+
if isinstance(user_config, dict):
987+
user_thinking_config = user_config.get("thinking_config")
988+
user_cached_content = user_config.get("cached_content")
989+
elif user_config is not None:
974990
if hasattr(user_config, "thinking_config"):
975991
user_thinking_config = user_config.thinking_config
976992
if hasattr(user_config, "cached_content"):

tests/test_genai_config_merging.py

Lines changed: 121 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -397,3 +397,124 @@ class TestModel(BaseModel):
397397
assert result_config.tools is not None
398398
assert result_config.tool_config is not None
399399
assert result_config.system_instruction is not None
400+
401+
402+
def test_update_genai_kwargs_config_dict_labels():
403+
"""Test that labels is merged when config is provided as a dict (issue #1759)."""
404+
kwargs = {"config": {"labels": {"env": "prod", "team": "ml"}}}
405+
base_config: dict[str, object] = {}
406+
407+
result = update_genai_kwargs(kwargs, base_config)
408+
409+
assert result["labels"] == {"env": "prod", "team": "ml"}
410+
411+
412+
def test_update_genai_kwargs_config_dict_cached_content():
413+
"""Test that cached_content is merged when config is provided as a dict."""
414+
kwargs = {"config": {"cached_content": "caches/dict123"}}
415+
base_config: dict[str, object] = {}
416+
417+
result = update_genai_kwargs(kwargs, base_config)
418+
419+
assert result["cached_content"] == "caches/dict123"
420+
421+
422+
def test_update_genai_kwargs_config_dict_thinking_config():
423+
"""Test that thinking_config is merged when config is provided as a dict."""
424+
thinking_config = {"thinking_budget": 1234}
425+
kwargs = {"config": {"thinking_config": thinking_config}}
426+
base_config: dict[str, object] = {}
427+
428+
result = update_genai_kwargs(kwargs, base_config)
429+
430+
assert result["thinking_config"] == thinking_config
431+
432+
433+
def test_handle_genai_structured_outputs_preserves_labels_from_config_dict():
434+
"""Test that labels are preserved when config is provided as a dict (issue #1759)."""
435+
from pydantic import BaseModel
436+
437+
from instructor.providers.gemini.utils import handle_genai_structured_outputs
438+
439+
class TestModel(BaseModel):
440+
name: str
441+
442+
new_kwargs = {
443+
"messages": [{"role": "user", "content": "Hello"}],
444+
"config": {"labels": {"tenant": "acme", "cost-center": "123"}},
445+
}
446+
447+
_, result_kwargs = handle_genai_structured_outputs(TestModel, new_kwargs)
448+
449+
result_config = result_kwargs["config"]
450+
assert result_config.labels == {"tenant": "acme", "cost-center": "123"}
451+
452+
453+
def test_handle_genai_tools_preserves_labels_from_config_dict():
454+
"""Test that labels are preserved in tools mode when config is a dict (issue #1759)."""
455+
from pydantic import BaseModel
456+
457+
from instructor.providers.gemini.utils import handle_genai_tools
458+
459+
class TestModel(BaseModel):
460+
name: str
461+
462+
new_kwargs = {
463+
"messages": [{"role": "user", "content": "Hello"}],
464+
"config": {"labels": {"tenant": "acme", "cost-center": "123"}},
465+
}
466+
467+
_, result_kwargs = handle_genai_tools(TestModel, new_kwargs)
468+
469+
result_config = result_kwargs["config"]
470+
assert result_config.labels == {"tenant": "acme", "cost-center": "123"}
471+
472+
473+
def test_handle_genai_structured_outputs_skips_system_instruction_with_cached_content_dict():
474+
"""Test cached_content dict config disables system_instruction in structured outputs."""
475+
from pydantic import BaseModel
476+
477+
from instructor.providers.gemini.utils import handle_genai_structured_outputs
478+
479+
class TestModel(BaseModel):
480+
name: str
481+
482+
new_kwargs = {
483+
"messages": [
484+
{"role": "system", "content": "You are a helpful assistant."},
485+
{"role": "user", "content": "Hello"},
486+
],
487+
"config": {"cached_content": "caches/dict-cache-1"},
488+
}
489+
490+
_, result_kwargs = handle_genai_structured_outputs(TestModel, new_kwargs)
491+
492+
result_config = result_kwargs["config"]
493+
assert result_config.cached_content == "caches/dict-cache-1"
494+
assert result_config.system_instruction is None
495+
496+
497+
def test_handle_genai_tools_skips_tools_and_system_instruction_with_cached_content_dict():
498+
"""Test cached_content dict config disables tools/tool_config/system_instruction in tools mode."""
499+
from pydantic import BaseModel
500+
501+
from instructor.providers.gemini.utils import handle_genai_tools
502+
503+
class TestModel(BaseModel):
504+
name: str
505+
506+
new_kwargs = {
507+
"messages": [
508+
{"role": "system", "content": "You are a helpful assistant."},
509+
{"role": "user", "content": "Hello"},
510+
],
511+
"config": {"cached_content": "caches/dict-cache-2"},
512+
}
513+
514+
_, result_kwargs = handle_genai_tools(TestModel, new_kwargs)
515+
516+
result_config = result_kwargs["config"]
517+
assert result_config.cached_content == "caches/dict-cache-2"
518+
assert result_config.system_instruction is None
519+
assert result_config.tools is None
520+
assert result_config.tool_config is None

0 commit comments

Comments
 (0)