Home Assistant Unofficial Reference 2024.12.1
conversation.py
Go to the documentation of this file.
1 """Conversation support for OpenAI."""
2 
3 from collections.abc import Callable
4 import json
5 from typing import Any, Literal
6 
7 import openai
8 from openai._types import NOT_GIVEN
9 from openai.types.chat import (
10  ChatCompletionAssistantMessageParam,
11  ChatCompletionMessage,
12  ChatCompletionMessageParam,
13  ChatCompletionMessageToolCallParam,
14  ChatCompletionSystemMessageParam,
15  ChatCompletionToolMessageParam,
16  ChatCompletionToolParam,
17  ChatCompletionUserMessageParam,
18 )
19 from openai.types.chat.chat_completion_message_tool_call_param import Function
20 from openai.types.shared_params import FunctionDefinition
21 import voluptuous as vol
22 from voluptuous_openapi import convert
23 
24 from homeassistant.components import assist_pipeline, conversation
26 from homeassistant.config_entries import ConfigEntry
27 from homeassistant.const import CONF_LLM_HASS_API, MATCH_ALL
28 from homeassistant.core import HomeAssistant
29 from homeassistant.exceptions import HomeAssistantError, TemplateError
30 from homeassistant.helpers import device_registry as dr, intent, llm, template
31 from homeassistant.helpers.entity_platform import AddEntitiesCallback
32 from homeassistant.util import ulid
33 
34 from . import OpenAIConfigEntry
35 from .const import (
36  CONF_CHAT_MODEL,
37  CONF_MAX_TOKENS,
38  CONF_PROMPT,
39  CONF_TEMPERATURE,
40  CONF_TOP_P,
41  DOMAIN,
42  LOGGER,
43  RECOMMENDED_CHAT_MODEL,
44  RECOMMENDED_MAX_TOKENS,
45  RECOMMENDED_TEMPERATURE,
46  RECOMMENDED_TOP_P,
47 )
48 
# Max number of back and forth with the LLM to generate a response
MAX_TOOL_ITERATIONS = 10
51 
52 
async def async_setup_entry(
    hass: HomeAssistant,
    config_entry: OpenAIConfigEntry,
    async_add_entities: AddEntitiesCallback,
) -> None:
    """Set up conversation entities for an OpenAI config entry."""
    agent = OpenAIConversationEntity(config_entry)
    async_add_entities([agent])
61 
62 
def _format_tool(
    tool: llm.Tool, custom_serializer: Callable[[Any], Any] | None
) -> ChatCompletionToolParam:
    """Convert a Home Assistant LLM tool into the OpenAI tool-call schema.

    The tool's voluptuous parameter schema is converted to an OpenAPI/JSON
    schema via voluptuous_openapi.convert; the description is only attached
    when the tool actually provides one.
    """
    tool_spec = FunctionDefinition(
        name=tool.name,
        parameters=convert(tool.parameters, custom_serializer=custom_serializer),
    )
    if tool.description:
        tool_spec["description"] = tool.description
    return ChatCompletionToolParam(type="function", function=tool_spec)
74 
75 
class OpenAIConversationEntity(
    conversation.ConversationEntity, conversation.AbstractConversationAgent
):
    """OpenAI conversation agent."""

    # Entity takes its name from the device entry; no per-entity name.
    _attr_has_entity_name = True
    _attr_name = None
83 
84  def __init__(self, entry: OpenAIConfigEntry) -> None:
85  """Initialize the agent."""
86  self.entryentry = entry
87  self.history: dict[str, list[ChatCompletionMessageParam]] = {}
88  self._attr_unique_id_attr_unique_id = entry.entry_id
89  self._attr_device_info_attr_device_info = dr.DeviceInfo(
90  identifiers={(DOMAIN, entry.entry_id)},
91  name=entry.title,
92  manufacturer="OpenAI",
93  model="ChatGPT",
94  entry_type=dr.DeviceEntryType.SERVICE,
95  )
96  if self.entryentry.options.get(CONF_LLM_HASS_API):
97  self._attr_supported_features_attr_supported_features_attr_supported_features = (
98  conversation.ConversationEntityFeature.CONTROL
99  )
100 
101  @property
102  def supported_languages(self) -> list[str] | Literal["*"]:
103  """Return a list of supported languages."""
104  return MATCH_ALL
105 
106  async def async_added_to_hass(self) -> None:
107  """When entity is added to Home Assistant."""
108  await super().async_added_to_hass()
109  assist_pipeline.async_migrate_engine(
110  self.hasshass, "conversation", self.entryentry.entry_id, self.entity_identity_id
111  )
112  conversation.async_set_agent(self.hasshass, self.entryentry, self)
113  self.entryentry.async_on_unload(
114  self.entryentry.add_update_listener(self._async_entry_update_listener_async_entry_update_listener)
115  )
116 
117  async def async_will_remove_from_hass(self) -> None:
118  """When entity will be removed from Home Assistant."""
119  conversation.async_unset_agent(self.hasshass, self.entryentry)
120  await super().async_will_remove_from_hass()
121 
122  async def async_process(
123  self, user_input: conversation.ConversationInput
125  """Process a sentence."""
126  options = self.entryentry.options
127  intent_response = intent.IntentResponse(language=user_input.language)
128  llm_api: llm.APIInstance | None = None
129  tools: list[ChatCompletionToolParam] | None = None
130  user_name: str | None = None
131  llm_context = llm.LLMContext(
132  platform=DOMAIN,
133  context=user_input.context,
134  user_prompt=user_input.text,
135  language=user_input.language,
136  assistant=conversation.DOMAIN,
137  device_id=user_input.device_id,
138  )
139 
140  if options.get(CONF_LLM_HASS_API):
141  try:
142  llm_api = await llm.async_get_api(
143  self.hasshass,
144  options[CONF_LLM_HASS_API],
145  llm_context,
146  )
147  except HomeAssistantError as err:
148  LOGGER.error("Error getting LLM API: %s", err)
149  intent_response.async_set_error(
150  intent.IntentResponseErrorCode.UNKNOWN,
151  "Error preparing LLM API",
152  )
154  response=intent_response, conversation_id=user_input.conversation_id
155  )
156  tools = [
157  _format_tool(tool, llm_api.custom_serializer) for tool in llm_api.tools
158  ]
159 
160  if user_input.conversation_id is None:
161  conversation_id = ulid.ulid_now()
162  messages = []
163 
164  elif user_input.conversation_id in self.history:
165  conversation_id = user_input.conversation_id
166  messages = self.history[conversation_id]
167 
168  else:
169  # Conversation IDs are ULIDs. We generate a new one if not provided.
170  # If an old OLID is passed in, we will generate a new one to indicate
171  # a new conversation was started. If the user picks their own, they
172  # want to track a conversation and we respect it.
173  try:
174  ulid.ulid_to_bytes(user_input.conversation_id)
175  conversation_id = ulid.ulid_now()
176  except ValueError:
177  conversation_id = user_input.conversation_id
178 
179  messages = []
180 
181  if (
182  user_input.context
183  and user_input.context.user_id
184  and (
185  user := await self.hasshass.auth.async_get_user(user_input.context.user_id)
186  )
187  ):
188  user_name = user.name
189 
190  try:
191  prompt_parts = [
192  template.Template(
193  llm.BASE_PROMPT
194  + options.get(CONF_PROMPT, llm.DEFAULT_INSTRUCTIONS_PROMPT),
195  self.hasshass,
196  ).async_render(
197  {
198  "ha_name": self.hasshass.config.location_name,
199  "user_name": user_name,
200  "llm_context": llm_context,
201  },
202  parse_result=False,
203  )
204  ]
205 
206  except TemplateError as err:
207  LOGGER.error("Error rendering prompt: %s", err)
208  intent_response = intent.IntentResponse(language=user_input.language)
209  intent_response.async_set_error(
210  intent.IntentResponseErrorCode.UNKNOWN,
211  "Sorry, I had a problem with my template",
212  )
214  response=intent_response, conversation_id=conversation_id
215  )
216 
217  if llm_api:
218  prompt_parts.append(llm_api.api_prompt)
219 
220  prompt = "\n".join(prompt_parts)
221 
222  # Create a copy of the variable because we attach it to the trace
223  messages = [
224  ChatCompletionSystemMessageParam(role="system", content=prompt),
225  *messages[1:],
226  ChatCompletionUserMessageParam(role="user", content=user_input.text),
227  ]
228 
229  LOGGER.debug("Prompt: %s", messages)
230  LOGGER.debug("Tools: %s", tools)
231  trace.async_conversation_trace_append(
232  trace.ConversationTraceEventType.AGENT_DETAIL,
233  {"messages": messages, "tools": llm_api.tools if llm_api else None},
234  )
235 
236  client = self.entryentry.runtime_data
237 
238  # To prevent infinite loops, we limit the number of iterations
239  for _iteration in range(MAX_TOOL_ITERATIONS):
240  try:
241  result = await client.chat.completions.create(
242  model=options.get(CONF_CHAT_MODEL, RECOMMENDED_CHAT_MODEL),
243  messages=messages,
244  tools=tools or NOT_GIVEN,
245  max_tokens=options.get(CONF_MAX_TOKENS, RECOMMENDED_MAX_TOKENS),
246  top_p=options.get(CONF_TOP_P, RECOMMENDED_TOP_P),
247  temperature=options.get(CONF_TEMPERATURE, RECOMMENDED_TEMPERATURE),
248  user=conversation_id,
249  )
250  except openai.OpenAIError as err:
251  LOGGER.error("Error talking to OpenAI: %s", err)
252  intent_response = intent.IntentResponse(language=user_input.language)
253  intent_response.async_set_error(
254  intent.IntentResponseErrorCode.UNKNOWN,
255  "Sorry, I had a problem talking to OpenAI",
256  )
258  response=intent_response, conversation_id=conversation_id
259  )
260 
261  LOGGER.debug("Response %s", result)
262  response = result.choices[0].message
263 
264  def message_convert(
265  message: ChatCompletionMessage,
266  ) -> ChatCompletionMessageParam:
267  """Convert from class to TypedDict."""
268  tool_calls: list[ChatCompletionMessageToolCallParam] = []
269  if message.tool_calls:
270  tool_calls = [
271  ChatCompletionMessageToolCallParam(
272  id=tool_call.id,
273  function=Function(
274  arguments=tool_call.function.arguments,
275  name=tool_call.function.name,
276  ),
277  type=tool_call.type,
278  )
279  for tool_call in message.tool_calls
280  ]
281  param = ChatCompletionAssistantMessageParam(
282  role=message.role,
283  content=message.content,
284  )
285  if tool_calls:
286  param["tool_calls"] = tool_calls
287  return param
288 
289  messages.append(message_convert(response))
290  tool_calls = response.tool_calls
291 
292  if not tool_calls or not llm_api:
293  break
294 
295  for tool_call in tool_calls:
296  tool_input = llm.ToolInput(
297  tool_name=tool_call.function.name,
298  tool_args=json.loads(tool_call.function.arguments),
299  )
300  LOGGER.debug(
301  "Tool call: %s(%s)", tool_input.tool_name, tool_input.tool_args
302  )
303 
304  try:
305  tool_response = await llm_api.async_call_tool(tool_input)
306  except (HomeAssistantError, vol.Invalid) as e:
307  tool_response = {"error": type(e).__name__}
308  if str(e):
309  tool_response["error_text"] = str(e)
310 
311  LOGGER.debug("Tool response: %s", tool_response)
312  messages.append(
313  ChatCompletionToolMessageParam(
314  role="tool",
315  tool_call_id=tool_call.id,
316  content=json.dumps(tool_response),
317  )
318  )
319 
320  self.history[conversation_id] = messages
321 
322  intent_response = intent.IntentResponse(language=user_input.language)
323  intent_response.async_set_speech(response.content or "")
325  response=intent_response, conversation_id=conversation_id
326  )
327 
329  self, hass: HomeAssistant, entry: ConfigEntry
330  ) -> None:
331  """Handle options update."""
332  # Reload as we update device info + entity name + supported features
333  await hass.config_entries.async_reload(entry.entry_id)
conversation.ConversationResult async_process(self, conversation.ConversationInput user_input)
None _async_entry_update_listener(self, HomeAssistant hass, ConfigEntry entry)
None async_setup_entry(HomeAssistant hass, OpenAIConfigEntry config_entry, AddEntitiesCallback async_add_entities)
Definition: conversation.py:57
ChatCompletionToolParam _format_tool(llm.Tool tool, Callable[[Any], Any]|None custom_serializer)
Definition: conversation.py:65