Skip to content

Gemini: Add option to specify config_entry in generate_content service #143776

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 2 commits into
base: dev
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
ConfigEntryError,
ConfigEntryNotReady,
HomeAssistantError,
ServiceValidationError,
)
from homeassistant.helpers import config_validation as cv
from homeassistant.helpers.issue_registry import IssueSeverity, async_create_issue
Expand All @@ -35,10 +36,12 @@
RECOMMENDED_CHAT_MODEL,
TIMEOUT_MILLIS,
)
from .model_setup import get_content_config

SERVICE_GENERATE_CONTENT = "generate_content"
CONF_IMAGE_FILENAME = "image_filename"
CONF_FILENAMES = "filenames"
CONF_CONFIG_ENTRY = "config_entry"

CONFIG_SCHEMA = cv.config_entry_only_config_schema(DOMAIN)
PLATFORMS = (Platform.CONVERSATION,)
Expand All @@ -64,11 +67,31 @@
translation_key="deprecated_image_filename_parameter",
)

prompt_parts = [call.data[CONF_PROMPT]]
config_entry: GoogleGenerativeAIConfigEntry
if CONF_CONFIG_ENTRY in call.data:
entry_id = call.data[CONF_CONFIG_ENTRY]
found_entry = hass.config_entries.async_get_entry(entry_id)
if found_entry is None or found_entry.domain != DOMAIN:
raise ServiceValidationError(

Check warning on line 75 in homeassistant/components/google_generative_ai_conversation/__init__.py

View check run for this annotation

Codecov / codecov/patch

homeassistant/components/google_generative_ai_conversation/__init__.py#L75

Added line #L75 was not covered by tests
translation_domain=DOMAIN,
translation_key="invalid_config_entry",
translation_placeholders={"config_entry": entry_id},
)
config_entry = found_entry
else:
# Deprecated in 2025.6, to remove in 2025.10
async_create_issue(
hass,
DOMAIN,
"missing_config_entry_parameter",
breaks_in_ha_version="2025.10.0",
is_fixable=False,
severity=IssueSeverity.WARNING,
translation_key="missing_config_entry_parameter",
)
config_entry = hass.config_entries.async_loaded_entries(DOMAIN)[0]

config_entry: GoogleGenerativeAIConfigEntry = (
hass.config_entries.async_loaded_entries(DOMAIN)[0]
)
prompt_parts = [call.data[CONF_PROMPT]]

client = config_entry.runtime_data

Expand All @@ -93,9 +116,15 @@

await hass.async_add_executor_job(append_files_to_prompt)

model_name = config_entry.options.get(CONF_CHAT_MODEL, RECOMMENDED_CHAT_MODEL)

try:
response = await client.aio.models.generate_content(
model=RECOMMENDED_CHAT_MODEL, contents=prompt_parts
model=model_name,
# Features like tools and custom prompts are powered by
# HA's Conversation infra, so we cannot use them here.
config=get_content_config(config_entry),
contents=prompt_parts,
)
except (
APIError,
Expand All @@ -120,6 +149,7 @@
schema=vol.Schema(
{
vol.Required(CONF_PROMPT): cv.string,
vol.Optional(CONF_CONFIG_ENTRY): cv.string,
vol.Optional(CONF_IMAGE_FILENAME, default=[]): vol.All(
cv.ensure_list, [cv.string]
),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,8 @@
AutomaticFunctionCallingConfig,
Content,
FunctionDeclaration,
GenerateContentConfig,
GoogleSearch,
HarmCategory,
Part,
SafetySetting,
Schema,
Tool,
)
Expand All @@ -32,25 +29,13 @@

from .const import (
CONF_CHAT_MODEL,
CONF_DANGEROUS_BLOCK_THRESHOLD,
CONF_HARASSMENT_BLOCK_THRESHOLD,
CONF_HATE_BLOCK_THRESHOLD,
CONF_MAX_TOKENS,
CONF_PROMPT,
CONF_SEXUAL_BLOCK_THRESHOLD,
CONF_TEMPERATURE,
CONF_TOP_K,
CONF_TOP_P,
CONF_USE_GOOGLE_SEARCH_TOOL,
DOMAIN,
LOGGER,
RECOMMENDED_CHAT_MODEL,
RECOMMENDED_HARM_BLOCK_THRESHOLD,
RECOMMENDED_MAX_TOKENS,
RECOMMENDED_TEMPERATURE,
RECOMMENDED_TOP_K,
RECOMMENDED_TOP_P,
)
from .model_setup import get_content_config

# Max number of back and forth with the LLM to generate a response
MAX_TOOL_ITERATIONS = 10
Expand Down Expand Up @@ -371,50 +356,16 @@ async def _async_handle_message(

if tool_results:
messages.append(_create_google_tool_response_content(tool_results))
generateContentConfig = GenerateContentConfig(
temperature=self.entry.options.get(
CONF_TEMPERATURE, RECOMMENDED_TEMPERATURE
),
top_k=self.entry.options.get(CONF_TOP_K, RECOMMENDED_TOP_K),
top_p=self.entry.options.get(CONF_TOP_P, RECOMMENDED_TOP_P),
max_output_tokens=self.entry.options.get(
CONF_MAX_TOKENS, RECOMMENDED_MAX_TOKENS
),
safety_settings=[
SafetySetting(
category=HarmCategory.HARM_CATEGORY_HATE_SPEECH,
threshold=self.entry.options.get(
CONF_HATE_BLOCK_THRESHOLD, RECOMMENDED_HARM_BLOCK_THRESHOLD
),
),
SafetySetting(
category=HarmCategory.HARM_CATEGORY_HARASSMENT,
threshold=self.entry.options.get(
CONF_HARASSMENT_BLOCK_THRESHOLD,
RECOMMENDED_HARM_BLOCK_THRESHOLD,
),
),
SafetySetting(
category=HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT,
threshold=self.entry.options.get(
CONF_DANGEROUS_BLOCK_THRESHOLD, RECOMMENDED_HARM_BLOCK_THRESHOLD
),
),
SafetySetting(
category=HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT,
threshold=self.entry.options.get(
CONF_SEXUAL_BLOCK_THRESHOLD, RECOMMENDED_HARM_BLOCK_THRESHOLD
),
),
],
tools=tools or None,
system_instruction=prompt if supports_system_instruction else None,
automatic_function_calling=AutomaticFunctionCallingConfig(
disable=True, maximum_remote_calls=None
),
)

if not supports_system_instruction:
generateContentConfig = get_content_config(self.entry)
# Set additional config options that are only supported in conversation.
generateContentConfig.tools = tools
generateContentConfig.automatic_function_calling = (
AutomaticFunctionCallingConfig(disable=True, maximum_remote_calls=None)
)
if supports_system_instruction:
generateContentConfig.system_instruction = prompt
else:
messages = [
Content(role="user", parts=[Part.from_text(text=prompt)]),
Content(role="model", parts=[Part.from_text(text="Ok")]),
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
"""Generic helper functions for setting up Gemini models."""

from google.genai import Client
from google.genai.types import GenerateContentConfig, HarmCategory, SafetySetting

from homeassistant.config_entries import ConfigEntry

from .const import (
CONF_DANGEROUS_BLOCK_THRESHOLD,
CONF_HARASSMENT_BLOCK_THRESHOLD,
CONF_HATE_BLOCK_THRESHOLD,
CONF_MAX_TOKENS,
CONF_SEXUAL_BLOCK_THRESHOLD,
CONF_TEMPERATURE,
CONF_TOP_K,
CONF_TOP_P,
RECOMMENDED_HARM_BLOCK_THRESHOLD,
RECOMMENDED_MAX_TOKENS,
RECOMMENDED_TEMPERATURE,
RECOMMENDED_TOP_K,
RECOMMENDED_TOP_P,
)

type GoogleGenerativeAIConfigEntry = ConfigEntry[Client]


def get_content_config(
entry: GoogleGenerativeAIConfigEntry,
) -> GenerateContentConfig:
"""Create parameters for Gemini model inputs from a config entry."""

return GenerateContentConfig(
temperature=entry.options.get(CONF_TEMPERATURE, RECOMMENDED_TEMPERATURE),
top_k=entry.options.get(CONF_TOP_K, RECOMMENDED_TOP_K),
top_p=entry.options.get(CONF_TOP_P, RECOMMENDED_TOP_P),
max_output_tokens=entry.options.get(CONF_MAX_TOKENS, RECOMMENDED_MAX_TOKENS),
safety_settings=[
SafetySetting(
category=HarmCategory.HARM_CATEGORY_HATE_SPEECH,
threshold=entry.options.get(
CONF_HATE_BLOCK_THRESHOLD, RECOMMENDED_HARM_BLOCK_THRESHOLD
),
),
SafetySetting(
category=HarmCategory.HARM_CATEGORY_HARASSMENT,
threshold=entry.options.get(
CONF_HARASSMENT_BLOCK_THRESHOLD,
RECOMMENDED_HARM_BLOCK_THRESHOLD,
),
),
SafetySetting(
category=HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT,
threshold=entry.options.get(
CONF_DANGEROUS_BLOCK_THRESHOLD, RECOMMENDED_HARM_BLOCK_THRESHOLD
),
),
SafetySetting(
category=HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT,
threshold=entry.options.get(
CONF_SEXUAL_BLOCK_THRESHOLD, RECOMMENDED_HARM_BLOCK_THRESHOLD
),
),
],
)
Original file line number Diff line number Diff line change
@@ -1,5 +1,10 @@
generate_content:
fields:
config_entry:
required: true
selector:
config_entry:
integration: google_generative_ai_conversation
prompt:
required: true
selector:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,10 @@
"name": "Generate content",
"description": "Generate content from a prompt consisting of text and optionally images",
"fields": {
"config_entry": {
"name": "Config entry",
"description": "The config entry specifying model settings to use for this action"
},
"prompt": {
"name": "Prompt",
"description": "The prompt",
Expand All @@ -76,6 +80,15 @@
"deprecated_image_filename_parameter": {
"title": "Deprecated 'image_filename' parameter",
"description": "The 'image_filename' parameter in Google Generative AI actions is deprecated. Please edit scripts and automations to use 'filenames' instead."
},
"missing_config_entry_parameter": {
"title": "Missing 'Config entry' parameter",
"description": "The 'Config entry' parameter in Google Generative AI actions is required. Please edit scripts and automations to specify a config entry."
}
},
"exceptions": {
"invalid_config_entry": {
"message": "Invalid config entry `{config_entry}`"
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
tuple(
),
dict({
'config': GenerateContentConfig(http_options=None, system_instruction=None, temperature=1.0, top_p=0.95, top_k=64.0, candidate_count=None, max_output_tokens=1500, stop_sequences=None, response_logprobs=None, logprobs=None, presence_penalty=None, frequency_penalty=None, seed=None, response_mime_type=None, response_schema=None, routing_config=None, safety_settings=[SafetySetting(method=None, category=<HarmCategory.HARM_CATEGORY_HATE_SPEECH: 'HARM_CATEGORY_HATE_SPEECH'>, threshold=<HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE: 'BLOCK_MEDIUM_AND_ABOVE'>), SafetySetting(method=None, category=<HarmCategory.HARM_CATEGORY_HARASSMENT: 'HARM_CATEGORY_HARASSMENT'>, threshold=<HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE: 'BLOCK_MEDIUM_AND_ABOVE'>), SafetySetting(method=None, category=<HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: 'HARM_CATEGORY_DANGEROUS_CONTENT'>, threshold=<HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE: 'BLOCK_MEDIUM_AND_ABOVE'>), SafetySetting(method=None, category=<HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT: 'HARM_CATEGORY_SEXUALLY_EXPLICIT'>, threshold=<HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE: 'BLOCK_MEDIUM_AND_ABOVE'>)], tools=None, tool_config=None, labels=None, cached_content=None, response_modalities=None, media_resolution=None, speech_config=None, audio_timestamp=None, automatic_function_calling=None, thinking_config=None),
'contents': list([
'Describe this image from my doorbell camera',
b'some file',
Expand All @@ -23,10 +24,11 @@
tuple(
),
dict({
'config': GenerateContentConfig(http_options=None, system_instruction=None, temperature=1.0, top_p=0.95, top_k=64.0, candidate_count=None, max_output_tokens=1500, stop_sequences=None, response_logprobs=None, logprobs=None, presence_penalty=None, frequency_penalty=None, seed=None, response_mime_type=None, response_schema=None, routing_config=None, safety_settings=[SafetySetting(method=None, category=<HarmCategory.HARM_CATEGORY_HATE_SPEECH: 'HARM_CATEGORY_HATE_SPEECH'>, threshold=<HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE: 'BLOCK_MEDIUM_AND_ABOVE'>), SafetySetting(method=None, category=<HarmCategory.HARM_CATEGORY_HARASSMENT: 'HARM_CATEGORY_HARASSMENT'>, threshold=<HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE: 'BLOCK_MEDIUM_AND_ABOVE'>), SafetySetting(method=None, category=<HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: 'HARM_CATEGORY_DANGEROUS_CONTENT'>, threshold=<HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE: 'BLOCK_MEDIUM_AND_ABOVE'>), SafetySetting(method=None, category=<HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT: 'HARM_CATEGORY_SEXUALLY_EXPLICIT'>, threshold=<HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE: 'BLOCK_MEDIUM_AND_ABOVE'>)], tools=None, tool_config=None, labels=None, cached_content=None, response_modalities=None, media_resolution=None, speech_config=None, audio_timestamp=None, automatic_function_calling=None, thinking_config=None),
'contents': list([
'Write an opening speech for a Home Assistant release party',
]),
'model': 'models/gemini-2.0-flash',
'model': 'fake-test-model',
}),
),
])
Expand All @@ -38,6 +40,7 @@
tuple(
),
dict({
'config': GenerateContentConfig(http_options=None, system_instruction=None, temperature=1.0, top_p=0.95, top_k=64.0, candidate_count=None, max_output_tokens=1500, stop_sequences=None, response_logprobs=None, logprobs=None, presence_penalty=None, frequency_penalty=None, seed=None, response_mime_type=None, response_schema=None, routing_config=None, safety_settings=[SafetySetting(method=None, category=<HarmCategory.HARM_CATEGORY_HATE_SPEECH: 'HARM_CATEGORY_HATE_SPEECH'>, threshold=<HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE: 'BLOCK_MEDIUM_AND_ABOVE'>), SafetySetting(method=None, category=<HarmCategory.HARM_CATEGORY_HARASSMENT: 'HARM_CATEGORY_HARASSMENT'>, threshold=<HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE: 'BLOCK_MEDIUM_AND_ABOVE'>), SafetySetting(method=None, category=<HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: 'HARM_CATEGORY_DANGEROUS_CONTENT'>, threshold=<HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE: 'BLOCK_MEDIUM_AND_ABOVE'>), SafetySetting(method=None, category=<HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT: 'HARM_CATEGORY_SEXUALLY_EXPLICIT'>, threshold=<HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE: 'BLOCK_MEDIUM_AND_ABOVE'>)], tools=None, tool_config=None, labels=None, cached_content=None, response_modalities=None, media_resolution=None, speech_config=None, audio_timestamp=None, automatic_function_calling=None, thinking_config=None),
'contents': list([
'Write an opening speech for a Home Assistant release party',
]),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from requests.exceptions import Timeout
from syrupy.assertion import SnapshotAssertion

from homeassistant.components.google_generative_ai_conversation import CONF_CHAT_MODEL
from homeassistant.config_entries import ConfigEntryState
from homeassistant.core import HomeAssistant
from homeassistant.exceptions import HomeAssistantError
Expand All @@ -20,6 +21,23 @@ async def test_generate_content_service_without_images(
hass: HomeAssistant, snapshot: SnapshotAssertion
) -> None:
"""Test generate content service."""
SAMPLE_MODEL = "fake-test-model"

second_entry = MockConfigEntry(
domain="google_generative_ai_conversation",
title="Google Generative AI Conversation",
data={
"api_key": "bla",
},
options={
CONF_CHAT_MODEL: SAMPLE_MODEL,
},
)
second_entry.add_to_hass(hass)
with patch("google.genai.models.AsyncModels.get"):
await hass.config_entries.async_setup(second_entry.entry_id)
await hass.async_block_till_done()

stubbed_generated_content = (
"I'm thrilled to welcome you all to the release "
"party for the latest version of Home Assistant!"
Expand All @@ -36,14 +54,18 @@ async def test_generate_content_service_without_images(
response = await hass.services.async_call(
"google_generative_ai_conversation",
"generate_content",
{"prompt": "Write an opening speech for a Home Assistant release party"},
{
"config_entry": second_entry.entry_id,
"prompt": "Write an opening speech for a Home Assistant release party",
},
blocking=True,
return_response=True,
)

assert response == {
"text": stubbed_generated_content,
}
assert mock_generate.call_args.kwargs["model"] == SAMPLE_MODEL
assert [tuple(mock_call) for mock_call in mock_generate.mock_calls] == snapshot


Expand Down