1 """The conversation platform for the Ollama integration."""
3 from __future__
import annotations
5 from collections.abc
import Callable
9 from typing
import Any, Literal
12 import voluptuous
as vol
13 from voluptuous_openapi
import convert
37 from .models
import MessageHistory, MessageRole
# Upper bound on chat -> tool-call round trips in a single conversation turn,
# so a model that keeps emitting tool calls cannot loop forever.
MAX_TOOL_ITERATIONS = 10

_LOGGER = logging.getLogger(__name__)
async def async_setup_entry(
    hass: HomeAssistant,
    config_entry: ConfigEntry,
    async_add_entities: AddEntitiesCallback,
) -> None:
    """Set up conversation entities.

    Creates one conversation agent entity per Ollama config entry and
    registers it with Home Assistant.
    """
    agent = OllamaConversationEntity(config_entry)
    async_add_entities([agent])
def _format_tool(
    tool: llm.Tool, custom_serializer: Callable[[Any], Any] | None
) -> dict[str, Any]:
    """Format tool specification for the Ollama chat API.

    Converts a Home Assistant LLM tool into the OpenAI-style
    ``{"type": "function", "function": {...}}`` shape Ollama expects.
    The voluptuous schema is converted to a JSON-schema dict via ``convert``.
    """
    tool_spec = {
        "name": tool.name,
        "parameters": convert(tool.parameters, custom_serializer=custom_serializer),
    }
    # Description is optional on llm.Tool; only emit it when present.
    if tool.description:
        tool_spec["description"] = tool.description
    return {"type": "function", "function": tool_spec}
69 """Attempt to repair incorrectly formatted json function arguments.
71 Small models (for example llama3.1 8B) may produce invalid argument values
72 which we attempt to repair here.
74 if not isinstance(value, str):
76 if (value.startswith(
"[")
and value.endswith(
"]"))
or (
77 value.startswith(
"{")
and value.endswith(
"}")
80 return json.loads(value)
81 except json.decoder.JSONDecodeError:
def _parse_tool_args(arguments: dict[str, Any]) -> dict[str, Any]:
    """Rewrite ollama tool arguments.

    This function improves tool use quality by fixing common mistakes made by
    small local tool use models. This will repair invalid json arguments and
    omit unnecessary arguments with empty values that will fail intent parsing.
    """
    # Empty/falsy values are dropped entirely; string-encoded JSON is decoded.
    return {k: _fix_invalid_arguments(v) for k, v in arguments.items() if v}
class OllamaConversationEntity(
    conversation.ConversationEntity, conversation.AbstractConversationAgent
):
    """Ollama conversation agent."""

    _attr_has_entity_name = True

    def __init__(self, entry: ConfigEntry) -> None:
        """Initialize the agent.

        entry: the Ollama config entry this agent entity is bound to.
        """
        self.entry = entry

        # conversation id -> message history
        self._history: dict[str, MessageHistory] = {}
        self._attr_name = entry.title
        self._attr_unique_id = entry.entry_id
        if self.entry.options.get(CONF_LLM_HASS_API):
            # An LLM API is configured, so the agent is allowed to control
            # Home Assistant through tool calls.
            self._attr_supported_features = (
                conversation.ConversationEntityFeature.CONTROL
            )
117 """When entity is added to Home Assistant."""
119 assist_pipeline.async_migrate_engine(
122 conversation.async_set_agent(self.
hasshass, self.
entryentry, self)
123 self.
entryentry.async_on_unload(
128 """When entity will be removed from Home Assistant."""
129 conversation.async_unset_agent(self.
hasshass, self.
entryentry)
134 """Return a list of supported languages."""
138 self, user_input: conversation.ConversationInput
140 """Process a sentence."""
141 settings = {**self.
entryentry.data, **self.
entryentry.options}
143 client = self.
hasshass.data[DOMAIN][self.
entryentry.entry_id]
144 conversation_id = user_input.conversation_id
or ulid.ulid_now()
145 model = settings[CONF_MODEL]
146 intent_response = intent.IntentResponse(language=user_input.language)
147 llm_api: llm.APIInstance |
None =
None
148 tools: list[dict[str, Any]] |
None =
None
149 user_name: str |
None =
None
150 llm_context = llm.LLMContext(
152 context=user_input.context,
153 user_prompt=user_input.text,
154 language=user_input.language,
155 assistant=conversation.DOMAIN,
156 device_id=user_input.device_id,
159 if settings.get(CONF_LLM_HASS_API):
161 llm_api = await llm.async_get_api(
163 settings[CONF_LLM_HASS_API],
166 except HomeAssistantError
as err:
167 _LOGGER.error(
"Error getting LLM API: %s", err)
168 intent_response.async_set_error(
169 intent.IntentResponseErrorCode.UNKNOWN,
170 f
"Error preparing LLM API: {err}",
173 response=intent_response, conversation_id=user_input.conversation_id
176 _format_tool(tool, llm_api.custom_serializer)
for tool
in llm_api.tools
181 and user_input.context.user_id
183 user := await self.
hasshass.auth.async_get_user(user_input.context.user_id)
186 user_name = user.name
189 message_history: MessageHistory |
None =
None
190 message_history = self.
_history_history.
get(conversation_id)
191 if message_history
is None:
199 + settings.get(CONF_PROMPT, llm.DEFAULT_INSTRUCTIONS_PROMPT),
203 "ha_name": self.
hasshass.config.location_name,
204 "user_name": user_name,
205 "llm_context": llm_context,
211 except TemplateError
as err:
212 _LOGGER.error(
"Error rendering prompt: %s", err)
213 intent_response.async_set_error(
214 intent.IntentResponseErrorCode.UNKNOWN,
215 f
"Sorry, I had a problem generating my prompt: {err}",
218 response=intent_response, conversation_id=conversation_id
222 prompt_parts.append(llm_api.api_prompt)
224 prompt =
"\n".join(prompt_parts)
225 _LOGGER.debug(
"Prompt: %s", prompt)
226 _LOGGER.debug(
"Tools: %s", tools)
229 timestamp=time.monotonic(),
231 ollama.Message(role=MessageRole.SYSTEM.value, content=prompt)
234 self.
_history_history[conversation_id] = message_history
237 message_history.timestamp = time.monotonic()
243 max_messages =
int(settings.get(CONF_MAX_HISTORY, DEFAULT_MAX_HISTORY))
244 self.
_trim_history_trim_history(message_history, max_messages)
247 message_history.messages.append(
248 ollama.Message(role=MessageRole.USER.value, content=user_input.text)
251 trace.async_conversation_trace_append(
252 trace.ConversationTraceEventType.AGENT_DETAIL,
253 {
"messages": message_history.messages},
258 for _iteration
in range(MAX_TOOL_ITERATIONS):
260 response = await client.chat(
263 messages=
list(message_history.messages),
267 keep_alive=f
"{settings.get(CONF_KEEP_ALIVE, DEFAULT_KEEP_ALIVE)}s",
268 options={CONF_NUM_CTX: settings.get(CONF_NUM_CTX, DEFAULT_NUM_CTX)},
270 except (ollama.RequestError, ollama.ResponseError)
as err:
271 _LOGGER.error(
"Unexpected error talking to Ollama server: %s", err)
272 intent_response.async_set_error(
273 intent.IntentResponseErrorCode.UNKNOWN,
274 f
"Sorry, I had a problem talking to the Ollama server: {err}",
277 response=intent_response, conversation_id=conversation_id
280 response_message = response[
"message"]
281 message_history.messages.append(
283 role=response_message[
"role"],
284 content=response_message.get(
"content"),
285 tool_calls=response_message.get(
"tool_calls"),
289 tool_calls = response_message.get(
"tool_calls")
290 if not tool_calls
or not llm_api:
293 for tool_call
in tool_calls:
294 tool_input = llm.ToolInput(
295 tool_name=tool_call[
"function"][
"name"],
299 "Tool call: %s(%s)", tool_input.tool_name, tool_input.tool_args
303 tool_response = await llm_api.async_call_tool(tool_input)
304 except (HomeAssistantError, vol.Invalid)
as e:
305 tool_response = {
"error": type(e).__name__}
307 tool_response[
"error_text"] =
str(e)
309 _LOGGER.debug(
"Tool response: %s", tool_response)
310 message_history.messages.append(
312 role=MessageRole.TOOL.value,
313 content=json.dumps(tool_response),
318 intent_response.async_set_speech(response_message[
"content"])
320 response=intent_response, conversation_id=conversation_id
324 """Remove old message histories."""
325 now = time.monotonic()
327 conversation_id: message_history
328 for conversation_id, message_history
in self.
_history_history.items()
329 if (now - message_history.timestamp) <= MAX_HISTORY_SECONDS
332 def _trim_history(self, message_history: MessageHistory, max_messages: int) ->
None:
333 """Trims excess messages from a single history."""
338 if message_history.num_user_messages >= max_messages:
342 num_keep = 2 * max_messages
343 drop_index = len(message_history.messages) - num_keep
344 message_history.messages = [
345 message_history.messages[0]
346 ] + message_history.messages[drop_index:]
349 self, hass: HomeAssistant, entry: ConfigEntry
351 """Handle options update."""
353 await hass.config_entries.async_reload(entry.entry_id)
None async_will_remove_from_hass(self)
None __init__(self, ConfigEntry entry)
None _prune_old_histories(self)
None async_added_to_hass(self)
None _trim_history(self, MessageHistory message_history, int max_messages)
conversation.ConversationResult async_process(self, conversation.ConversationInput user_input)
None _async_entry_update_listener(self, HomeAssistant hass, ConfigEntry entry)
list[str]|Literal["*"] supported_languages(self)
web.Response get(self, web.Request request, str config_key)
dict[str, Any] _parse_tool_args(dict[str, Any] arguments)
None async_setup_entry(HomeAssistant hass, ConfigEntry config_entry, AddEntitiesCallback async_add_entities)
dict[str, Any] _format_tool(llm.Tool tool, Callable[[Any], Any]|None custom_serializer)
Any _fix_invalid_arguments(Any value)