1 """Conversation support for OpenAI."""
3 from collections.abc
import Callable
5 from typing
import Any, Literal
8 from openai._types
import NOT_GIVEN
9 from openai.types.chat
import (
10 ChatCompletionAssistantMessageParam,
11 ChatCompletionMessage,
12 ChatCompletionMessageParam,
13 ChatCompletionMessageToolCallParam,
14 ChatCompletionSystemMessageParam,
15 ChatCompletionToolMessageParam,
16 ChatCompletionToolParam,
17 ChatCompletionUserMessageParam,
19 from openai.types.chat.chat_completion_message_tool_call_param
import Function
20 from openai.types.shared_params
import FunctionDefinition
21 import voluptuous
as vol
22 from voluptuous_openapi
import convert
34 from .
import OpenAIConfigEntry
43 RECOMMENDED_CHAT_MODEL,
44 RECOMMENDED_MAX_TOKENS,
45 RECOMMENDED_TEMPERATURE,
50 MAX_TOOL_ITERATIONS = 10
55 config_entry: OpenAIConfigEntry,
56 async_add_entities: AddEntitiesCallback,
58 """Set up conversation entities."""
64 tool: llm.Tool, custom_serializer: Callable[[Any], Any] |
None
65 ) -> ChatCompletionToolParam:
66 """Format tool specification."""
67 tool_spec = FunctionDefinition(
69 parameters=convert(tool.parameters, custom_serializer=custom_serializer),
72 tool_spec[
"description"] = tool.description
73 return ChatCompletionToolParam(type=
"function", function=tool_spec)
79 """OpenAI conversation agent."""
81 _attr_has_entity_name =
True
84 def __init__(self, entry: OpenAIConfigEntry) ->
None:
85 """Initialize the agent."""
87 self.history: dict[str, list[ChatCompletionMessageParam]] = {}
90 identifiers={(DOMAIN, entry.entry_id)},
92 manufacturer=
"OpenAI",
94 entry_type=dr.DeviceEntryType.SERVICE,
96 if self.
entryentry.options.get(CONF_LLM_HASS_API):
98 conversation.ConversationEntityFeature.CONTROL
103 """Return a list of supported languages."""
107 """When entity is added to Home Assistant."""
109 assist_pipeline.async_migrate_engine(
112 conversation.async_set_agent(self.
hasshass, self.
entryentry, self)
113 self.
entryentry.async_on_unload(
118 """When entity will be removed from Home Assistant."""
119 conversation.async_unset_agent(self.
hasshass, self.
entryentry)
123 self, user_input: conversation.ConversationInput
125 """Process a sentence."""
126 options = self.
entryentry.options
127 intent_response = intent.IntentResponse(language=user_input.language)
128 llm_api: llm.APIInstance |
None =
None
129 tools: list[ChatCompletionToolParam] |
None =
None
130 user_name: str |
None =
None
131 llm_context = llm.LLMContext(
133 context=user_input.context,
134 user_prompt=user_input.text,
135 language=user_input.language,
136 assistant=conversation.DOMAIN,
137 device_id=user_input.device_id,
140 if options.get(CONF_LLM_HASS_API):
142 llm_api = await llm.async_get_api(
144 options[CONF_LLM_HASS_API],
147 except HomeAssistantError
as err:
148 LOGGER.error(
"Error getting LLM API: %s", err)
149 intent_response.async_set_error(
150 intent.IntentResponseErrorCode.UNKNOWN,
151 "Error preparing LLM API",
154 response=intent_response, conversation_id=user_input.conversation_id
157 _format_tool(tool, llm_api.custom_serializer)
for tool
in llm_api.tools
160 if user_input.conversation_id
is None:
161 conversation_id = ulid.ulid_now()
164 elif user_input.conversation_id
in self.history:
165 conversation_id = user_input.conversation_id
166 messages = self.history[conversation_id]
174 ulid.ulid_to_bytes(user_input.conversation_id)
175 conversation_id = ulid.ulid_now()
177 conversation_id = user_input.conversation_id
183 and user_input.context.user_id
185 user := await self.
hasshass.auth.async_get_user(user_input.context.user_id)
188 user_name = user.name
194 + options.get(CONF_PROMPT, llm.DEFAULT_INSTRUCTIONS_PROMPT),
198 "ha_name": self.
hasshass.config.location_name,
199 "user_name": user_name,
200 "llm_context": llm_context,
206 except TemplateError
as err:
207 LOGGER.error(
"Error rendering prompt: %s", err)
208 intent_response = intent.IntentResponse(language=user_input.language)
209 intent_response.async_set_error(
210 intent.IntentResponseErrorCode.UNKNOWN,
211 "Sorry, I had a problem with my template",
214 response=intent_response, conversation_id=conversation_id
218 prompt_parts.append(llm_api.api_prompt)
220 prompt =
"\n".join(prompt_parts)
224 ChatCompletionSystemMessageParam(role=
"system", content=prompt),
226 ChatCompletionUserMessageParam(role=
"user", content=user_input.text),
229 LOGGER.debug(
"Prompt: %s", messages)
230 LOGGER.debug(
"Tools: %s", tools)
231 trace.async_conversation_trace_append(
232 trace.ConversationTraceEventType.AGENT_DETAIL,
233 {
"messages": messages,
"tools": llm_api.tools
if llm_api
else None},
236 client = self.
entryentry.runtime_data
239 for _iteration
in range(MAX_TOOL_ITERATIONS):
241 result = await client.chat.completions.create(
242 model=options.get(CONF_CHAT_MODEL, RECOMMENDED_CHAT_MODEL),
244 tools=tools
or NOT_GIVEN,
245 max_tokens=options.get(CONF_MAX_TOKENS, RECOMMENDED_MAX_TOKENS),
246 top_p=options.get(CONF_TOP_P, RECOMMENDED_TOP_P),
247 temperature=options.get(CONF_TEMPERATURE, RECOMMENDED_TEMPERATURE),
248 user=conversation_id,
250 except openai.OpenAIError
as err:
251 LOGGER.error(
"Error talking to OpenAI: %s", err)
252 intent_response = intent.IntentResponse(language=user_input.language)
253 intent_response.async_set_error(
254 intent.IntentResponseErrorCode.UNKNOWN,
255 "Sorry, I had a problem talking to OpenAI",
258 response=intent_response, conversation_id=conversation_id
261 LOGGER.debug(
"Response %s", result)
262 response = result.choices[0].message
265 message: ChatCompletionMessage,
266 ) -> ChatCompletionMessageParam:
267 """Convert from class to TypedDict."""
268 tool_calls: list[ChatCompletionMessageToolCallParam] = []
269 if message.tool_calls:
271 ChatCompletionMessageToolCallParam(
274 arguments=tool_call.function.arguments,
275 name=tool_call.function.name,
279 for tool_call
in message.tool_calls
281 param = ChatCompletionAssistantMessageParam(
283 content=message.content,
286 param[
"tool_calls"] = tool_calls
289 messages.append(message_convert(response))
290 tool_calls = response.tool_calls
292 if not tool_calls
or not llm_api:
295 for tool_call
in tool_calls:
296 tool_input = llm.ToolInput(
297 tool_name=tool_call.function.name,
298 tool_args=json.loads(tool_call.function.arguments),
301 "Tool call: %s(%s)", tool_input.tool_name, tool_input.tool_args
305 tool_response = await llm_api.async_call_tool(tool_input)
306 except (HomeAssistantError, vol.Invalid)
as e:
307 tool_response = {
"error": type(e).__name__}
309 tool_response[
"error_text"] =
str(e)
311 LOGGER.debug(
"Tool response: %s", tool_response)
313 ChatCompletionToolMessageParam(
315 tool_call_id=tool_call.id,
316 content=json.dumps(tool_response),
320 self.history[conversation_id] = messages
322 intent_response = intent.IntentResponse(language=user_input.language)
323 intent_response.async_set_speech(response.content
or "")
325 response=intent_response, conversation_id=conversation_id
329 self, hass: HomeAssistant, entry: ConfigEntry
331 """Handle options update."""
333 await hass.config_entries.async_reload(entry.entry_id)
conversation.ConversationResult async_process(self, conversation.ConversationInput user_input)
list[str]|Literal["*"] supported_languages(self)
None async_added_to_hass(self)
None __init__(self, OpenAIConfigEntry entry)
None _async_entry_update_listener(self, HomeAssistant hass, ConfigEntry entry)
None async_will_remove_from_hass(self)
None async_setup_entry(HomeAssistant hass, OpenAIConfigEntry config_entry, AddEntitiesCallback async_add_entities)
ChatCompletionToolParam _format_tool(llm.Tool tool, Callable[[Any], Any]|None custom_serializer)