1 """Conversation support for the Google Generative AI Conversation integration."""
3 from __future__
import annotations
6 from collections.abc
import Callable
7 from typing
import Any, Literal
9 from google.api_core.exceptions
import GoogleAPIError
10 import google.generativeai
as genai
11 from google.generativeai
import protos
12 import google.generativeai.types
as genai_types
13 from google.protobuf.json_format
import MessageToDict
14 import voluptuous
as vol
15 from voluptuous_openapi
import convert
29 CONF_DANGEROUS_BLOCK_THRESHOLD,
30 CONF_HARASSMENT_BLOCK_THRESHOLD,
31 CONF_HATE_BLOCK_THRESHOLD,
34 CONF_SEXUAL_BLOCK_THRESHOLD,
40 RECOMMENDED_CHAT_MODEL,
41 RECOMMENDED_HARM_BLOCK_THRESHOLD,
42 RECOMMENDED_MAX_TOKENS,
43 RECOMMENDED_TEMPERATURE,
49 MAX_TOOL_ITERATIONS = 10
async def async_setup_entry(
    hass: HomeAssistant,
    config_entry: ConfigEntry,
    async_add_entities: AddEntitiesCallback,
) -> None:
    """Set up conversation entities for a config entry.

    Creates one conversation agent entity per entry and registers it with
    Home Assistant via the supplied callback.
    """
    agent = GoogleGenerativeAIConversationEntity(config_entry)
    async_add_entities([agent])
62 SUPPORTED_SCHEMA_KEYS = {
75 """Format the schema to protobuf."""
76 if (subschemas := schema.get(
"anyOf"))
or (subschemas := schema.get(
"allOf")):
77 for subschema
in subschemas:
78 if "type" in subschema:
85 for key, val
in schema.items():
86 if key
not in SUPPORTED_SCHEMA_KEYS:
92 if schema.get(
"type") ==
"string" and val !=
"enum":
94 if schema.get(
"type")
not in (
"number",
"integer",
"string"):
99 elif key ==
"properties":
103 if result.get(
"enum")
and result.get(
"type_") !=
"STRING":
107 result[
"type_"] =
"STRING"
108 result[
"enum"] = [
str(item)
for item
in result[
"enum"]]
110 if result.get(
"type_") ==
"OBJECT" and not result.get(
"properties"):
114 result[
"properties"] = {
"json": {
"type_":
"STRING"}}
115 result[
"required"] = []
def _format_tool(
    tool: llm.Tool, custom_serializer: Callable[[Any], Any] | None
) -> dict[str, Any]:
    """Format tool specification.

    Converts a Home Assistant LLM tool into a Gemini ``protos.Tool`` with a
    single function declaration. Tools without a parameter schema get
    ``parameters=None``.
    """
    if tool.parameters.schema:
        parameters = _format_schema(
            convert(tool.parameters, custom_serializer=custom_serializer)
        )
    else:
        parameters = None

    return protos.Tool(
        {
            "function_declarations": [
                {
                    "name": tool.name,
                    "description": tool.description,
                    "parameters": parameters,
                }
            ]
        }
    )
145 """Recursively call codecs.escape_decode on all values."""
146 if isinstance(value, str):
147 return codecs.escape_decode(bytes(value,
"utf-8"))[0].decode(
"utf-8")
148 if isinstance(value, list):
150 if isinstance(value, dict):
158 """Google Generative AI conversation agent."""
160 _attr_has_entity_name =
True
164 """Initialize the agent."""
166 self.history: dict[str, list[genai_types.ContentType]] = {}
169 identifiers={(DOMAIN, entry.entry_id)},
171 manufacturer=
"Google",
172 model=
"Generative AI",
173 entry_type=dr.DeviceEntryType.SERVICE,
175 if self.
entryentry.options.get(CONF_LLM_HASS_API):
177 conversation.ConversationEntityFeature.CONTROL
182 """Return a list of supported languages."""
186 """When entity is added to Home Assistant."""
188 assist_pipeline.async_migrate_engine(
191 conversation.async_set_agent(self.
hasshass, self.
entryentry, self)
192 self.
entryentry.async_on_unload(
197 """When entity will be removed from Home Assistant."""
198 conversation.async_unset_agent(self.
hasshass, self.
entryentry)
202 self, user_input: conversation.ConversationInput
204 """Process a sentence."""
206 response=intent.IntentResponse(language=user_input.language),
207 conversation_id=user_input.conversation_id
208 if user_input.conversation_id
in self.history
209 else ulid.ulid_now(),
211 assert result.conversation_id
213 llm_context = llm.LLMContext(
215 context=user_input.context,
216 user_prompt=user_input.text,
217 language=user_input.language,
218 assistant=conversation.DOMAIN,
219 device_id=user_input.device_id,
221 llm_api: llm.APIInstance |
None =
None
222 tools: list[dict[str, Any]] |
None =
None
223 if self.
entryentry.options.get(CONF_LLM_HASS_API):
225 llm_api = await llm.async_get_api(
227 self.
entryentry.options[CONF_LLM_HASS_API],
230 except HomeAssistantError
as err:
231 LOGGER.error(
"Error getting LLM API: %s", err)
232 result.response.async_set_error(
233 intent.IntentResponseErrorCode.UNKNOWN,
234 f
"Error preparing LLM API: {err}",
238 _format_tool(tool, llm_api.custom_serializer)
for tool
in llm_api.tools
243 except TemplateError
as err:
244 LOGGER.error(
"Error rendering prompt: %s", err)
245 result.response.async_set_error(
246 intent.IntentResponseErrorCode.UNKNOWN,
247 f
"Sorry, I had a problem with my template: {err}",
251 model_name = self.
entryentry.options.get(CONF_CHAT_MODEL, RECOMMENDED_CHAT_MODEL)
255 supports_system_instruction = (
256 "gemini-1.0" not in model_name
and "gemini-pro" not in model_name
259 model = genai.GenerativeModel(
260 model_name=model_name,
262 "temperature": self.
entryentry.options.get(
263 CONF_TEMPERATURE, RECOMMENDED_TEMPERATURE
265 "top_p": self.
entryentry.options.get(CONF_TOP_P, RECOMMENDED_TOP_P),
266 "top_k": self.
entryentry.options.get(CONF_TOP_K, RECOMMENDED_TOP_K),
267 "max_output_tokens": self.
entryentry.options.get(
268 CONF_MAX_TOKENS, RECOMMENDED_MAX_TOKENS
272 "HARASSMENT": self.
entryentry.options.get(
273 CONF_HARASSMENT_BLOCK_THRESHOLD, RECOMMENDED_HARM_BLOCK_THRESHOLD
275 "HATE": self.
entryentry.options.get(
276 CONF_HATE_BLOCK_THRESHOLD, RECOMMENDED_HARM_BLOCK_THRESHOLD
278 "SEXUAL": self.
entryentry.options.get(
279 CONF_SEXUAL_BLOCK_THRESHOLD, RECOMMENDED_HARM_BLOCK_THRESHOLD
281 "DANGEROUS": self.
entryentry.options.get(
282 CONF_DANGEROUS_BLOCK_THRESHOLD, RECOMMENDED_HARM_BLOCK_THRESHOLD
286 system_instruction=prompt
if supports_system_instruction
else None,
289 messages = self.history.
get(result.conversation_id, [])
290 if not supports_system_instruction:
292 messages = [{}, {
"role":
"model",
"parts":
"Ok"}]
293 messages[0] = {
"role":
"user",
"parts": prompt}
295 LOGGER.debug(
"Input: '%s' with history: %s", user_input.text, messages)
296 trace.async_conversation_trace_append(
297 trace.ConversationTraceEventType.AGENT_DETAIL,
300 "messages": messages[:]
301 if supports_system_instruction
304 "tools": [*llm_api.tools]
if llm_api
else None,
308 chat = model.start_chat(history=messages)
309 chat_request = user_input.text
311 for _iteration
in range(MAX_TOOL_ITERATIONS):
313 chat_response = await chat.send_message_async(chat_request)
317 genai_types.BlockedPromptException,
318 genai_types.StopCandidateException,
320 LOGGER.error(
"Error sending message: %s %s", type(err), err)
323 err, genai_types.StopCandidateException
324 )
and "finish_reason: SAFETY\n" in str(err):
325 error =
"The message got blocked by your safety settings"
328 f
"Sorry, I had a problem talking to Google Generative AI: {err}"
331 result.response.async_set_error(
332 intent.IntentResponseErrorCode.UNKNOWN,
337 LOGGER.debug(
"Response: %s", chat_response.parts)
338 if not chat_response.parts:
339 result.response.async_set_error(
340 intent.IntentResponseErrorCode.UNKNOWN,
341 "Sorry, I had a problem getting a response from Google Generative AI.",
344 self.history[result.conversation_id] = chat.history
346 part.function_call
for part
in chat_response.parts
if part.function_call
348 if not function_calls
or not llm_api:
352 for function_call
in function_calls:
353 tool_call = MessageToDict(function_call._pb)
354 tool_name = tool_call[
"name"]
356 LOGGER.debug(
"Tool call: %s(%s)", tool_name, tool_args)
357 tool_input = llm.ToolInput(tool_name=tool_name, tool_args=tool_args)
359 function_response = await llm_api.async_call_tool(tool_input)
360 except (HomeAssistantError, vol.Invalid)
as e:
361 function_response = {
"error": type(e).__name__}
363 function_response[
"error_text"] =
str(e)
365 LOGGER.debug(
"Tool response: %s", function_response)
366 tool_responses.append(
368 function_response=protos.FunctionResponse(
369 name=tool_name, response=function_response
373 chat_request = protos.Content(parts=tool_responses)
375 result.response.async_set_speech(
376 " ".join([part.text.strip()
for part
in chat_response.parts
if part.text])
382 user_input: conversation.ConversationInput,
383 llm_api: llm.APIInstance |
None,
384 llm_context: llm.LLMContext,
386 user_name: str |
None =
None
389 and user_input.context.user_id
391 user := await self.
hasshass.auth.async_get_user(user_input.context.user_id)
394 user_name = user.name
399 + self.
entryentry.options.get(CONF_PROMPT, llm.DEFAULT_INSTRUCTIONS_PROMPT),
403 "ha_name": self.
hasshass.config.location_name,
404 "user_name": user_name,
405 "llm_context": llm_context,
412 parts.append(llm_api.api_prompt)
414 return "\n".join(parts)
417 self, hass: HomeAssistant, entry: ConfigEntry
419 """Handle options update."""
421 await hass.config_entries.async_reload(entry.entry_id)
list[str]|Literal["*"] supported_languages(self)
str _async_render_prompt(self, conversation.ConversationInput user_input, llm.APIInstance|None llm_api, llm.LLMContext llm_context)
conversation.ConversationResult async_process(self, conversation.ConversationInput user_input)
None _async_entry_update_listener(self, HomeAssistant hass, ConfigEntry entry)
None __init__(self, ConfigEntry entry)
None async_will_remove_from_hass(self)
None async_added_to_hass(self)
web.Response get(self, web.Request request, str config_key)
Any _escape_decode(Any value)
None async_setup_entry(HomeAssistant hass, ConfigEntry config_entry, AddEntitiesCallback async_add_entities)
dict[str, Any] _format_schema(dict[str, Any] schema)
dict[str, Any] _format_tool(llm.Tool tool, Callable[[Any], Any]|None custom_serializer)