Home Assistant Unofficial Reference 2024.12.1
conversation.py
Go to the documentation of this file.
1 """Conversation support for the Google Generative AI Conversation integration."""
2 
3 from __future__ import annotations
4 
5 import codecs
6 from collections.abc import Callable
7 from typing import Any, Literal
8 
9 from google.api_core.exceptions import GoogleAPIError
10 import google.generativeai as genai
11 from google.generativeai import protos
12 import google.generativeai.types as genai_types
13 from google.protobuf.json_format import MessageToDict
14 import voluptuous as vol
15 from voluptuous_openapi import convert
16 
17 from homeassistant.components import assist_pipeline, conversation
19 from homeassistant.config_entries import ConfigEntry
20 from homeassistant.const import CONF_LLM_HASS_API, MATCH_ALL
21 from homeassistant.core import HomeAssistant
22 from homeassistant.exceptions import HomeAssistantError, TemplateError
23 from homeassistant.helpers import device_registry as dr, intent, llm, template
24 from homeassistant.helpers.entity_platform import AddEntitiesCallback
25 from homeassistant.util import ulid
26 
27 from .const import (
28  CONF_CHAT_MODEL,
29  CONF_DANGEROUS_BLOCK_THRESHOLD,
30  CONF_HARASSMENT_BLOCK_THRESHOLD,
31  CONF_HATE_BLOCK_THRESHOLD,
32  CONF_MAX_TOKENS,
33  CONF_PROMPT,
34  CONF_SEXUAL_BLOCK_THRESHOLD,
35  CONF_TEMPERATURE,
36  CONF_TOP_K,
37  CONF_TOP_P,
38  DOMAIN,
39  LOGGER,
40  RECOMMENDED_CHAT_MODEL,
41  RECOMMENDED_HARM_BLOCK_THRESHOLD,
42  RECOMMENDED_MAX_TOKENS,
43  RECOMMENDED_TEMPERATURE,
44  RECOMMENDED_TOP_K,
45  RECOMMENDED_TOP_P,
46 )
47 
# Max number of back and forth with the LLM to generate a response.
# Used as the bound of the send-message / tool-call loop in async_process so a
# model that keeps requesting tool calls cannot loop forever.
MAX_TOOL_ITERATIONS = 10
50 
51 
53  hass: HomeAssistant,
54  config_entry: ConfigEntry,
55  async_add_entities: AddEntitiesCallback,
56 ) -> None:
57  """Set up conversation entities."""
58  agent = GoogleGenerativeAIConversationEntity(config_entry)
59  async_add_entities([agent])
60 
61 
# Schema keys that _format_schema copies into its result; every other key in
# the incoming schema is dropped.
SUPPORTED_SCHEMA_KEYS = {
    "type",
    "format",
    "description",
    "nullable",
    "enum",
    "items",
    "properties",
    "required",
}


def _format_schema(schema: dict[str, Any]) -> dict[str, Any]:
    """Format the schema to protobuf.

    Takes an OpenAPI/JSON-schema style dict and returns a new dict restricted
    to SUPPORTED_SCHEMA_KEYS, with ``type`` renamed to ``type_`` (upper-cased)
    and ``format`` renamed to ``format_``. Nested ``items`` and ``properties``
    are converted recursively. The input dict is not modified.

    Fallbacks applied, in order:
    * ``anyOf`` / ``allOf``: replaced by the first subschema that has a
      ``type`` key, or by the first subschema if none has one.
    * ``enum`` on a non-string type: the type becomes STRING and the enum
      values are stringified; a leftover non-"enum" format is removed.
    * OBJECT without properties: a single string property named ``json`` is
      substituted.
    """
    # Gemini API does not support anyOf and allOf keys
    if (subschemas := schema.get("anyOf")) or (subschemas := schema.get("allOf")):
        for subschema in subschemas:
            if "type" in subschema:  # Fallback to first subschema with 'type' field
                return _format_schema(subschema)
        # Or, if not found, to any of the subschemas
        return _format_schema(subschemas[0])

    result: dict[str, Any] = {}
    for key, val in schema.items():
        if key not in SUPPORTED_SCHEMA_KEYS:
            continue
        if key == "type":
            key = "type_"
            val = val.upper()
        elif key == "format":
            # Strings may only carry the "enum" format; formats on
            # non-primitive types are dropped altogether.
            if schema.get("type") == "string" and val != "enum":
                continue
            if schema.get("type") not in ("number", "integer", "string"):
                continue
            key = "format_"
        elif key == "items":
            val = _format_schema(val)
        elif key == "properties":
            val = {k: _format_schema(v) for k, v in val.items()}
        result[key] = val

    if result.get("enum") and result.get("type_") != "STRING":
        # enum is only allowed for STRING type. This is safe as long as the schema
        # contains vol.Coerce for the respective type, for example:
        # vol.All(vol.Coerce(int), vol.In([1, 2, 3]))
        result["type_"] = "STRING"
        result["enum"] = [str(item) for item in result["enum"]]
        # The schema is a string now, so the string rule from the loop above
        # applies: a numeric format such as "int32" must not survive.
        if result.get("format_") != "enum":
            result.pop("format_", None)

    if result.get("type_") == "OBJECT" and not result.get("properties"):
        # An object with undefined properties is not supported by Gemini API.
        # Fallback to JSON string. This will probably fail for most tools that want it,
        # but we don't have a better fallback strategy so far.
        result["properties"] = {"json": {"type_": "STRING"}}
        result["required"] = []
    return result
117 
118 
120  tool: llm.Tool, custom_serializer: Callable[[Any], Any] | None
121 ) -> dict[str, Any]:
122  """Format tool specification."""
123 
124  if tool.parameters.schema:
125  parameters = _format_schema(
126  convert(tool.parameters, custom_serializer=custom_serializer)
127  )
128  else:
129  parameters = None
130 
131  return protos.Tool(
132  {
133  "function_declarations": [
134  {
135  "name": tool.name,
136  "description": tool.description,
137  "parameters": parameters,
138  }
139  ]
140  }
141  )
142 
143 
def _escape_decode(value: Any) -> Any:
    """Recursively call codecs.escape_decode on all values.

    Strings have their backslash escape sequences (``\\n``, ``\\t``, ``\\'`` ...)
    decoded; lists and dict values are processed element by element; dict keys
    and any other type are returned unchanged.

    A string that cannot be decoded (for example one ending in a lone
    backslash, a malformed ``\\x`` escape, or escapes that produce bytes that
    are not valid UTF-8) is returned as received instead of raising, so one
    odd argument does not abort the whole tool call.
    """
    if isinstance(value, str):
        try:
            return codecs.escape_decode(bytes(value, "utf-8"))[0].decode("utf-8")  # type: ignore[attr-defined]
        except ValueError:
            # Covers ValueError from escape_decode as well as
            # UnicodeEncodeError / UnicodeDecodeError (both subclasses).
            return value
    if isinstance(value, list):
        return [_escape_decode(item) for item in value]
    if isinstance(value, dict):
        return {k: _escape_decode(v) for k, v in value.items()}
    return value
153 
154 
157 ):
158  """Google Generative AI conversation agent."""
159 
160  _attr_has_entity_name = True
161  _attr_name = None
162 
def __init__(self, entry: ConfigEntry) -> None:
    """Initialize the agent.

    Stores the config entry, starts with an empty per-conversation history
    and derives the unique id and device info from the entry. The CONTROL
    feature is only advertised when the entry options select an LLM API.
    """
    self.entry = entry
    # Chat history keyed by conversation id.
    self.history: dict[str, list[genai_types.ContentType]] = {}
    self._attr_unique_id = entry.entry_id
    self._attr_device_info = dr.DeviceInfo(
        identifiers={(DOMAIN, entry.entry_id)},
        name=entry.title,
        manufacturer="Google",
        model="Generative AI",
        entry_type=dr.DeviceEntryType.SERVICE,
    )
    if self.entry.options.get(CONF_LLM_HASS_API):
        self._attr_supported_features = (
            conversation.ConversationEntityFeature.CONTROL
        )
179 
@property
def supported_languages(self) -> list[str] | Literal["*"]:
    """Return a list of supported languages.

    Always returns the MATCH_ALL constant rather than an explicit list.
    """
    return MATCH_ALL
184 
async def async_added_to_hass(self) -> None:
    """When entity is added to Home Assistant.

    After the base-class hook has run, this calls the assist pipeline's
    engine migration helper with this entry id and entity id, sets this
    entity as the conversation agent for the config entry, and subscribes
    to entry updates. The subscription is handed to async_on_unload so it
    is dropped together with the entry.
    """
    await super().async_added_to_hass()
    assist_pipeline.async_migrate_engine(
        self.hass, "conversation", self.entry.entry_id, self.entity_id
    )
    conversation.async_set_agent(self.hass, self.entry, self)
    self.entry.async_on_unload(
        self.entry.add_update_listener(self._async_entry_update_listener)
    )
195 
async def async_will_remove_from_hass(self) -> None:
    """When entity will be removed from Home Assistant.

    Unsets this entry's conversation agent first, then runs the base-class
    removal hook.
    """
    conversation.async_unset_agent(self.hass, self.entry)
    await super().async_will_remove_from_hass()
200 
201  async def async_process(
202  self, user_input: conversation.ConversationInput
204  """Process a sentence."""
206  response=intent.IntentResponse(language=user_input.language),
207  conversation_id=user_input.conversation_id
208  if user_input.conversation_id in self.history
209  else ulid.ulid_now(),
210  )
211  assert result.conversation_id
212 
213  llm_context = llm.LLMContext(
214  platform=DOMAIN,
215  context=user_input.context,
216  user_prompt=user_input.text,
217  language=user_input.language,
218  assistant=conversation.DOMAIN,
219  device_id=user_input.device_id,
220  )
221  llm_api: llm.APIInstance | None = None
222  tools: list[dict[str, Any]] | None = None
223  if self.entryentry.options.get(CONF_LLM_HASS_API):
224  try:
225  llm_api = await llm.async_get_api(
226  self.hasshass,
227  self.entryentry.options[CONF_LLM_HASS_API],
228  llm_context,
229  )
230  except HomeAssistantError as err:
231  LOGGER.error("Error getting LLM API: %s", err)
232  result.response.async_set_error(
233  intent.IntentResponseErrorCode.UNKNOWN,
234  f"Error preparing LLM API: {err}",
235  )
236  return result
237  tools = [
238  _format_tool(tool, llm_api.custom_serializer) for tool in llm_api.tools
239  ]
240 
241  try:
242  prompt = await self._async_render_prompt_async_render_prompt(user_input, llm_api, llm_context)
243  except TemplateError as err:
244  LOGGER.error("Error rendering prompt: %s", err)
245  result.response.async_set_error(
246  intent.IntentResponseErrorCode.UNKNOWN,
247  f"Sorry, I had a problem with my template: {err}",
248  )
249  return result
250 
251  model_name = self.entryentry.options.get(CONF_CHAT_MODEL, RECOMMENDED_CHAT_MODEL)
252  # Gemini 1.0 doesn't support system_instruction while 1.5 does.
253  # Assume future versions will support it (if not, the request fails with a
254  # clear message at which point we can fix).
255  supports_system_instruction = (
256  "gemini-1.0" not in model_name and "gemini-pro" not in model_name
257  )
258 
259  model = genai.GenerativeModel(
260  model_name=model_name,
261  generation_config={
262  "temperature": self.entryentry.options.get(
263  CONF_TEMPERATURE, RECOMMENDED_TEMPERATURE
264  ),
265  "top_p": self.entryentry.options.get(CONF_TOP_P, RECOMMENDED_TOP_P),
266  "top_k": self.entryentry.options.get(CONF_TOP_K, RECOMMENDED_TOP_K),
267  "max_output_tokens": self.entryentry.options.get(
268  CONF_MAX_TOKENS, RECOMMENDED_MAX_TOKENS
269  ),
270  },
271  safety_settings={
272  "HARASSMENT": self.entryentry.options.get(
273  CONF_HARASSMENT_BLOCK_THRESHOLD, RECOMMENDED_HARM_BLOCK_THRESHOLD
274  ),
275  "HATE": self.entryentry.options.get(
276  CONF_HATE_BLOCK_THRESHOLD, RECOMMENDED_HARM_BLOCK_THRESHOLD
277  ),
278  "SEXUAL": self.entryentry.options.get(
279  CONF_SEXUAL_BLOCK_THRESHOLD, RECOMMENDED_HARM_BLOCK_THRESHOLD
280  ),
281  "DANGEROUS": self.entryentry.options.get(
282  CONF_DANGEROUS_BLOCK_THRESHOLD, RECOMMENDED_HARM_BLOCK_THRESHOLD
283  ),
284  },
285  tools=tools or None,
286  system_instruction=prompt if supports_system_instruction else None,
287  )
288 
289  messages = self.history.get(result.conversation_id, [])
290  if not supports_system_instruction:
291  if not messages:
292  messages = [{}, {"role": "model", "parts": "Ok"}]
293  messages[0] = {"role": "user", "parts": prompt}
294 
295  LOGGER.debug("Input: '%s' with history: %s", user_input.text, messages)
296  trace.async_conversation_trace_append(
297  trace.ConversationTraceEventType.AGENT_DETAIL,
298  {
299  # Make a copy to attach it to the trace event.
300  "messages": messages[:]
301  if supports_system_instruction
302  else messages[2:],
303  "prompt": prompt,
304  "tools": [*llm_api.tools] if llm_api else None,
305  },
306  )
307 
308  chat = model.start_chat(history=messages)
309  chat_request = user_input.text
310  # To prevent infinite loops, we limit the number of iterations
311  for _iteration in range(MAX_TOOL_ITERATIONS):
312  try:
313  chat_response = await chat.send_message_async(chat_request)
314  except (
315  GoogleAPIError,
316  ValueError,
317  genai_types.BlockedPromptException,
318  genai_types.StopCandidateException,
319  ) as err:
320  LOGGER.error("Error sending message: %s %s", type(err), err)
321 
322  if isinstance(
323  err, genai_types.StopCandidateException
324  ) and "finish_reason: SAFETY\n" in str(err):
325  error = "The message got blocked by your safety settings"
326  else:
327  error = (
328  f"Sorry, I had a problem talking to Google Generative AI: {err}"
329  )
330 
331  result.response.async_set_error(
332  intent.IntentResponseErrorCode.UNKNOWN,
333  error,
334  )
335  return result
336 
337  LOGGER.debug("Response: %s", chat_response.parts)
338  if not chat_response.parts:
339  result.response.async_set_error(
340  intent.IntentResponseErrorCode.UNKNOWN,
341  "Sorry, I had a problem getting a response from Google Generative AI.",
342  )
343  return result
344  self.history[result.conversation_id] = chat.history
345  function_calls = [
346  part.function_call for part in chat_response.parts if part.function_call
347  ]
348  if not function_calls or not llm_api:
349  break
350 
351  tool_responses = []
352  for function_call in function_calls:
353  tool_call = MessageToDict(function_call._pb) # noqa: SLF001
354  tool_name = tool_call["name"]
355  tool_args = _escape_decode(tool_call["args"])
356  LOGGER.debug("Tool call: %s(%s)", tool_name, tool_args)
357  tool_input = llm.ToolInput(tool_name=tool_name, tool_args=tool_args)
358  try:
359  function_response = await llm_api.async_call_tool(tool_input)
360  except (HomeAssistantError, vol.Invalid) as e:
361  function_response = {"error": type(e).__name__}
362  if str(e):
363  function_response["error_text"] = str(e)
364 
365  LOGGER.debug("Tool response: %s", function_response)
366  tool_responses.append(
367  protos.Part(
368  function_response=protos.FunctionResponse(
369  name=tool_name, response=function_response
370  )
371  )
372  )
373  chat_request = protos.Content(parts=tool_responses)
374 
375  result.response.async_set_speech(
376  " ".join([part.text.strip() for part in chat_response.parts if part.text])
377  )
378  return result
379 
381  self,
382  user_input: conversation.ConversationInput,
383  llm_api: llm.APIInstance | None,
384  llm_context: llm.LLMContext,
385  ) -> str:
386  user_name: str | None = None
387  if (
388  user_input.context
389  and user_input.context.user_id
390  and (
391  user := await self.hasshass.auth.async_get_user(user_input.context.user_id)
392  )
393  ):
394  user_name = user.name
395 
396  parts = [
397  template.Template(
398  llm.BASE_PROMPT
399  + self.entryentry.options.get(CONF_PROMPT, llm.DEFAULT_INSTRUCTIONS_PROMPT),
400  self.hasshass,
401  ).async_render(
402  {
403  "ha_name": self.hasshass.config.location_name,
404  "user_name": user_name,
405  "llm_context": llm_context,
406  },
407  parse_result=False,
408  )
409  ]
410 
411  if llm_api:
412  parts.append(llm_api.api_prompt)
413 
414  return "\n".join(parts)
415 
417  self, hass: HomeAssistant, entry: ConfigEntry
418  ) -> None:
419  """Handle options update."""
420  # Reload as we update device info + entity name + supported features
421  await hass.config_entries.async_reload(entry.entry_id)
str _async_render_prompt(self, conversation.ConversationInput user_input, llm.APIInstance|None llm_api, llm.LLMContext llm_context)
conversation.ConversationResult async_process(self, conversation.ConversationInput user_input)
web.Response get(self, web.Request request, str config_key)
Definition: view.py:88
None async_setup_entry(HomeAssistant hass, ConfigEntry config_entry, AddEntitiesCallback async_add_entities)
Definition: conversation.py:56
dict[str, Any] _format_tool(llm.Tool tool, Callable[[Any], Any]|None custom_serializer)