Home Assistant Unofficial Reference 2024.12.1
conversation.py
Go to the documentation of this file.
1 """The conversation platform for the Ollama integration."""
2 
from __future__ import annotations

from collections.abc import Callable
import json
import logging
import time
from typing import Any, Literal

import ollama
import voluptuous as vol
from voluptuous_openapi import convert

from homeassistant.components import assist_pipeline, conversation
from homeassistant.components.conversation import trace
from homeassistant.config_entries import ConfigEntry
from homeassistant.const import CONF_LLM_HASS_API, MATCH_ALL
from homeassistant.core import HomeAssistant
from homeassistant.exceptions import HomeAssistantError, TemplateError
from homeassistant.helpers import intent, llm, template
from homeassistant.helpers.entity_platform import AddEntitiesCallback
from homeassistant.util import ulid

from .const import (
    CONF_KEEP_ALIVE,
    CONF_MAX_HISTORY,
    CONF_MODEL,
    CONF_NUM_CTX,
    CONF_PROMPT,
    DEFAULT_KEEP_ALIVE,
    DEFAULT_MAX_HISTORY,
    DEFAULT_NUM_CTX,
    DOMAIN,
    MAX_HISTORY_SECONDS,
)
from .models import MessageHistory, MessageRole
38 
# Max number of back-and-forth exchanges with the LLM when generating a single
# response; bounds the tool-call loop in async_process so a model that keeps
# emitting tool calls cannot loop forever.
MAX_TOOL_ITERATIONS = 10

# Module-level logger named after this module, per logging convention.
_LOGGER = logging.getLogger(__name__)
43 
44 
async def async_setup_entry(
    hass: HomeAssistant,
    config_entry: ConfigEntry,
    async_add_entities: AddEntitiesCallback,
) -> None:
    """Set up conversation entities.

    Creates a single conversation agent entity for the config entry and
    registers it with Home Assistant.
    """
    agent = OllamaConversationEntity(config_entry)
    async_add_entities([agent])
53 
54 
def _format_tool(
    tool: llm.Tool, custom_serializer: Callable[[Any], Any] | None
) -> dict[str, Any]:
    """Format tool specification.

    Converts a Home Assistant LLM tool into the OpenAI-style function schema
    that the Ollama chat API expects. The description key is only included
    when the tool actually has one.
    """
    tool_spec = {
        "name": tool.name,
        "parameters": convert(tool.parameters, custom_serializer=custom_serializer),
    }
    if tool.description:
        tool_spec["description"] = tool.description
    return {"type": "function", "function": tool_spec}
66 
67 
68 def _fix_invalid_arguments(value: Any) -> Any:
69  """Attempt to repair incorrectly formatted json function arguments.
70 
71  Small models (for example llama3.1 8B) may produce invalid argument values
72  which we attempt to repair here.
73  """
74  if not isinstance(value, str):
75  return value
76  if (value.startswith("[") and value.endswith("]")) or (
77  value.startswith("{") and value.endswith("}")
78  ):
79  try:
80  return json.loads(value)
81  except json.decoder.JSONDecodeError:
82  pass
83  return value
84 
85 
def _parse_tool_args(arguments: dict[str, Any]) -> dict[str, Any]:
    """Rewrite ollama tool arguments.

    This function improves tool use quality by fixing common mistakes made by
    small local tool use models. This will repair invalid json arguments and
    omit unnecessary arguments with empty values that will fail intent parsing.
    """
    repaired: dict[str, Any] = {}
    for name, raw_value in arguments.items():
        # Drop falsy values (empty string, None, 0, ...) that would fail
        # intent parsing; repair the rest.
        if raw_value:
            repaired[name] = _fix_invalid_arguments(raw_value)
    return repaired
94 
95 
class OllamaConversationEntity(
    conversation.ConversationEntity, conversation.AbstractConversationAgent
):
    """Ollama conversation agent."""

    # Entity name is derived from the config entry title (set in __init__).
    _attr_has_entity_name = True

    def __init__(self, entry: ConfigEntry) -> None:
        """Initialize the agent."""
        self.entry = entry

        # conversation id -> message history
        self._history: dict[str, MessageHistory] = {}
        self._attr_name = entry.title
        self._attr_unique_id = entry.entry_id
        if self.entry.options.get(CONF_LLM_HASS_API):
            # Only advertise device control when an LLM API is configured.
            self._attr_supported_features = (
                conversation.ConversationEntityFeature.CONTROL
            )

    async def async_added_to_hass(self) -> None:
        """When entity is added to Home Assistant."""
        await super().async_added_to_hass()
        assist_pipeline.async_migrate_engine(
            self.hass, "conversation", self.entry.entry_id, self.entity_id
        )
        conversation.async_set_agent(self.hass, self.entry, self)
        self.entry.async_on_unload(
            self.entry.add_update_listener(self._async_entry_update_listener)
        )

    async def async_will_remove_from_hass(self) -> None:
        """When entity will be removed from Home Assistant."""
        conversation.async_unset_agent(self.hass, self.entry)
        await super().async_will_remove_from_hass()

    @property
    def supported_languages(self) -> list[str] | Literal["*"]:
        """Return a list of supported languages."""
        return MATCH_ALL

    async def async_process(
        self, user_input: conversation.ConversationInput
    ) -> conversation.ConversationResult:
        """Process a sentence."""
        settings = {**self.entry.data, **self.entry.options}

        client = self.hass.data[DOMAIN][self.entry.entry_id]
        conversation_id = user_input.conversation_id or ulid.ulid_now()
        model = settings[CONF_MODEL]
        intent_response = intent.IntentResponse(language=user_input.language)
        llm_api: llm.APIInstance | None = None
        tools: list[dict[str, Any]] | None = None
        user_name: str | None = None
        llm_context = llm.LLMContext(
            platform=DOMAIN,
            context=user_input.context,
            user_prompt=user_input.text,
            language=user_input.language,
            assistant=conversation.DOMAIN,
            device_id=user_input.device_id,
        )

        if settings.get(CONF_LLM_HASS_API):
            try:
                llm_api = await llm.async_get_api(
                    self.hass,
                    settings[CONF_LLM_HASS_API],
                    llm_context,
                )
            except HomeAssistantError as err:
                _LOGGER.error("Error getting LLM API: %s", err)
                intent_response.async_set_error(
                    intent.IntentResponseErrorCode.UNKNOWN,
                    f"Error preparing LLM API: {err}",
                )
                return conversation.ConversationResult(
                    response=intent_response,
                    conversation_id=user_input.conversation_id,
                )
            tools = [
                _format_tool(tool, llm_api.custom_serializer) for tool in llm_api.tools
            ]

        if (
            user_input.context
            and user_input.context.user_id
            and (
                user := await self.hass.auth.async_get_user(user_input.context.user_id)
            )
        ):
            user_name = user.name

        # Look up message history
        message_history: MessageHistory | None = None
        message_history = self._history.get(conversation_id)
        if message_history is None:
            # New history
            #
            # Render prompt and error out early if there's a problem
            try:
                prompt_parts = [
                    template.Template(
                        llm.BASE_PROMPT
                        + settings.get(CONF_PROMPT, llm.DEFAULT_INSTRUCTIONS_PROMPT),
                        self.hass,
                    ).async_render(
                        {
                            "ha_name": self.hass.config.location_name,
                            "user_name": user_name,
                            "llm_context": llm_context,
                        },
                        parse_result=False,
                    )
                ]

            except TemplateError as err:
                _LOGGER.error("Error rendering prompt: %s", err)
                intent_response.async_set_error(
                    intent.IntentResponseErrorCode.UNKNOWN,
                    f"Sorry, I had a problem generating my prompt: {err}",
                )
                return conversation.ConversationResult(
                    response=intent_response, conversation_id=conversation_id
                )

            if llm_api:
                prompt_parts.append(llm_api.api_prompt)

            prompt = "\n".join(prompt_parts)
            _LOGGER.debug("Prompt: %s", prompt)
            _LOGGER.debug("Tools: %s", tools)

            message_history = MessageHistory(
                timestamp=time.monotonic(),
                messages=[
                    ollama.Message(role=MessageRole.SYSTEM.value, content=prompt)
                ],
            )
            self._history[conversation_id] = message_history
        else:
            # Bump timestamp so this conversation won't get cleaned up
            message_history.timestamp = time.monotonic()

        # Clean up old histories
        self._prune_old_histories()

        # Trim this message history to keep a maximum number of *user* messages
        max_messages = int(settings.get(CONF_MAX_HISTORY, DEFAULT_MAX_HISTORY))
        self._trim_history(message_history, max_messages)

        # Add new user message
        message_history.messages.append(
            ollama.Message(role=MessageRole.USER.value, content=user_input.text)
        )

        trace.async_conversation_trace_append(
            trace.ConversationTraceEventType.AGENT_DETAIL,
            {"messages": message_history.messages},
        )

        # Get response
        # To prevent infinite loops, we limit the number of iterations
        for _iteration in range(MAX_TOOL_ITERATIONS):
            try:
                response = await client.chat(
                    model=model,
                    # Make a copy of the messages because we mutate the list later
                    messages=list(message_history.messages),
                    tools=tools,
                    stream=False,
                    # keep_alive requires specifying unit. In this case, seconds
                    keep_alive=f"{settings.get(CONF_KEEP_ALIVE, DEFAULT_KEEP_ALIVE)}s",
                    options={CONF_NUM_CTX: settings.get(CONF_NUM_CTX, DEFAULT_NUM_CTX)},
                )
            except (ollama.RequestError, ollama.ResponseError) as err:
                _LOGGER.error("Unexpected error talking to Ollama server: %s", err)
                intent_response.async_set_error(
                    intent.IntentResponseErrorCode.UNKNOWN,
                    f"Sorry, I had a problem talking to the Ollama server: {err}",
                )
                return conversation.ConversationResult(
                    response=intent_response, conversation_id=conversation_id
                )

            response_message = response["message"]
            message_history.messages.append(
                ollama.Message(
                    role=response_message["role"],
                    content=response_message.get("content"),
                    tool_calls=response_message.get("tool_calls"),
                )
            )

            tool_calls = response_message.get("tool_calls")
            if not tool_calls or not llm_api:
                # Plain text answer (or no LLM API to run tools against):
                # stop iterating and use this message as the final response.
                break

            for tool_call in tool_calls:
                tool_input = llm.ToolInput(
                    tool_name=tool_call["function"]["name"],
                    tool_args=_parse_tool_args(tool_call["function"]["arguments"]),
                )
                _LOGGER.debug(
                    "Tool call: %s(%s)", tool_input.tool_name, tool_input.tool_args
                )

                try:
                    tool_response = await llm_api.async_call_tool(tool_input)
                except (HomeAssistantError, vol.Invalid) as e:
                    # Feed the error back to the model so it can recover.
                    tool_response = {"error": type(e).__name__}
                    if str(e):
                        tool_response["error_text"] = str(e)

                _LOGGER.debug("Tool response: %s", tool_response)
                message_history.messages.append(
                    ollama.Message(
                        role=MessageRole.TOOL.value,
                        content=json.dumps(tool_response),
                    )
                )

        # Create intent response
        intent_response.async_set_speech(response_message["content"])
        return conversation.ConversationResult(
            response=intent_response, conversation_id=conversation_id
        )

    def _prune_old_histories(self) -> None:
        """Remove old message histories."""
        now = time.monotonic()
        self._history = {
            conversation_id: message_history
            for conversation_id, message_history in self._history.items()
            if (now - message_history.timestamp) <= MAX_HISTORY_SECONDS
        }

    def _trim_history(self, message_history: MessageHistory, max_messages: int) -> None:
        """Trims excess messages from a single history."""
        if max_messages < 1:
            # Keep all messages
            return

        if message_history.num_user_messages >= max_messages:
            # Trim history but keep system prompt (first message).
            # Every other message should be an assistant message, so keep 2x
            # message objects.
            num_keep = 2 * max_messages
            drop_index = len(message_history.messages) - num_keep
            message_history.messages = [
                message_history.messages[0]
            ] + message_history.messages[drop_index:]

    async def _async_entry_update_listener(
        self, hass: HomeAssistant, entry: ConfigEntry
    ) -> None:
        """Handle options update."""
        # Reload as we update device info + entity name + supported features
        await hass.config_entries.async_reload(entry.entry_id)
None _trim_history(self, MessageHistory message_history, int max_messages)
conversation.ConversationResult async_process(self, conversation.ConversationInput user_input)
None _async_entry_update_listener(self, HomeAssistant hass, ConfigEntry entry)
web.Response get(self, web.Request request, str config_key)
Definition: view.py:88
dict[str, Any] _parse_tool_args(dict[str, Any] arguments)
Definition: conversation.py:86
None async_setup_entry(HomeAssistant hass, ConfigEntry config_entry, AddEntitiesCallback async_add_entities)
Definition: conversation.py:49
dict[str, Any] _format_tool(llm.Tool tool, Callable[[Any], Any]|None custom_serializer)
Definition: conversation.py:57