1 """Module to coordinate llm tools."""
3 from __future__
import annotations
5 from abc
import ABC, abstractmethod
6 from collections.abc
import Callable
7 from dataclasses
import dataclass
8 from decimal
import Decimal
10 from functools
import cache, partial
11 from typing
import Any
13 import slugify
as unicode_slug
14 import voluptuous
as vol
15 from voluptuous_openapi
import UNSUPPORTED, convert
19 ConversationTraceEventType,
20 async_conversation_trace_append,
30 EVENT_HOMEASSISTANT_CLOSE,
31 EVENT_SERVICE_REMOVED,
33 from homeassistant.core import Context, Event, HomeAssistant, callback, split_entity_id
41 config_validation
as cv,
42 device_registry
as dr,
43 entity_registry
as er,
49 from .singleton
import singleton
51 SCRIPT_PARAMETERS_CACHE: HassKey[dict[str, tuple[str |
None, vol.Schema]]] =
HassKey(
52 "llm_script_parameters_cache"
56 LLM_API_ASSIST =
"assist"
59 'Current time is {{ now().strftime("%H:%M:%S") }}. '
60 'Today\'s date is {{ now().strftime("%Y-%m-%d") }}.\n'
63 DEFAULT_INSTRUCTIONS_PROMPT =
"""You are a voice assistant for Home Assistant.
64 Answer questions about the world truthfully.
65 Answer in plain text. Keep it simple and to the point.
71 """Return the prompt to be used when no API is configured.
73 No longer used since Home Assistant 2024.7.
81 """Get all the LLM APIs."""
89 """Register an API to be exposed to LLMs."""
99 hass: HomeAssistant, api_id: str, llm_context: LLMContext
104 if api_id
not in apis:
107 return await apis[api_id].async_get_api_instance(llm_context)
112 """Get all the LLM APIs."""
116 @dataclass(slots=True)
118 """Tool input to be processed."""
121 context: Context |
None
122 user_prompt: str |
None
124 assistant: str |
None
125 device_id: str |
None
128 @dataclass(slots=True)
130 """Tool input to be processed."""
133 tool_args: dict[str, Any]
137 """LLM Tool base class."""
140 description: str |
None =
None
141 parameters: vol.Schema = vol.Schema({})
145 self, hass: HomeAssistant, tool_input: ToolInput, llm_context: LLMContext
148 raise NotImplementedError
150 def __repr__(self) -> str:
151 """Represent a string of a Tool."""
152 return f
"<{self.__class__.__name__} - {self.name}>"
157 """Instance of an API to be used by an LLM."""
161 llm_context: LLMContext
163 custom_serializer: Callable[[Any], Any] |
None =
None
166 """Call a LLM tool, validate args and return the response."""
168 ConversationTraceEventType.TOOL_CALL,
169 {
"tool_name": tool_input.tool_name,
"tool_args": tool_input.tool_args},
172 for tool
in self.tools:
173 if tool.name == tool_input.tool_name:
178 return await tool.async_call(self.api.hass, tool_input, self.llm_context)
181 @dataclass(slots=True, kw_only=True)
183 """An API to expose to LLMs."""
191 """Return the instance of the API."""
192 raise NotImplementedError
196 """LLM Tool representing an Intent."""
203 """Init the class."""
206 intent_handler.description
or f
"Execute Home Assistant {self.name} intent"
209 if not (slot_schema := intent_handler.slot_schema):
212 slot_schema = {**slot_schema}
215 for field
in (
"preferred_area_id",
"preferred_floor_id"):
216 if field
in slot_schema:
217 extra_slots.add(field)
218 del slot_schema[field]
225 self, hass: HomeAssistant, tool_input: ToolInput, llm_context: LLMContext
227 """Handle the intent."""
228 slots = {key: {
"value": val}
for key, val
in tool_input.tool_args.items()}
230 if self.
extra_slotsextra_slots
and llm_context.device_id:
231 device_reg = dr.async_get(hass)
232 device = device_reg.async_get(llm_context.device_id)
234 area: ar.AreaEntry |
None =
None
235 floor: fr.FloorEntry |
None =
None
237 area_reg = ar.async_get(hass)
238 if device.area_id
and (area := area_reg.async_get_area(device.area_id)):
240 floor_reg = fr.async_get(hass)
241 floor = floor_reg.async_get_floor(area.floor_id)
243 for slot_name, slot_value
in (
244 (
"preferred_area_id", area.id
if area
else None),
245 (
"preferred_floor_id", floor.floor_id
if floor
else None),
247 if slot_value
and slot_name
in self.
extra_slotsextra_slots:
248 slots[slot_name] = {
"value": slot_value}
250 intent_response = await intent.async_handle(
252 platform=llm_context.platform,
253 intent_type=self.
namename,
255 text_input=llm_context.user_prompt,
256 context=llm_context.context,
257 language=llm_context.language,
258 assistant=llm_context.assistant,
259 device_id=llm_context.device_id,
261 response = intent_response.as_dict()
262 del response[
"language"]
268 """API exposing Assist API to LLMs."""
271 INTENT_GET_TEMPERATURE,
275 intent.INTENT_GET_STATE,
276 intent.INTENT_NEVERMIND,
277 intent.INTENT_TOGGLE,
278 intent.INTENT_GET_CURRENT_DATE,
279 intent.INTENT_GET_CURRENT_TIME,
280 intent.INTENT_RESPOND,
284 """Init the class."""
291 partial(unicode_slug.slugify, separator=
"_", lowercase=
False)
295 """Return the instance of the API."""
296 if llm_context.assistant:
298 self.hass, llm_context.assistant
301 exposed_entities =
None
306 llm_context=llm_context,
308 custom_serializer=_selector_serializer,
313 self, llm_context: LLMContext, exposed_entities: dict |
None
315 """Return the prompt for the API."""
316 if not exposed_entities:
318 "Only if the user wants to control a device, tell them to expose entities "
319 "to their voice assistant in Home Assistant."
324 "When controlling Home Assistant always call the intent tools. "
325 "Use HassTurnOn to lock and HassTurnOff to unlock a lock. "
326 "When controlling a device, prefer passing just name and domain. "
327 "When controlling an area, prefer passing just area name and domain."
330 area: ar.AreaEntry |
None =
None
331 floor: fr.FloorEntry |
None =
None
332 if llm_context.device_id:
333 device_reg = dr.async_get(self.hass)
334 device = device_reg.async_get(llm_context.device_id)
337 area_reg = ar.async_get(self.hass)
338 if device.area_id
and (area := area_reg.async_get_area(device.area_id)):
339 floor_reg = fr.async_get(self.hass)
341 floor = floor_reg.async_get_floor(area.floor_id)
343 extra =
"and all generic commands like 'turn on the lights' should target this area."
346 prompt.append(f
"You are in area {area.name} (floor {floor.name}) {extra}")
348 prompt.append(f
"You are in area {area.name} {extra}")
351 "When a user asks to turn on all devices of a specific type, "
352 "ask user to specify an area, unless there is only one device of that type."
356 self.hass, llm_context.device_id
358 prompt.append(
"This device is not able to start timers.")
362 "An overview of the areas and the devices in this smart home:"
364 prompt.append(yaml.dump(
list(exposed_entities.values())))
366 return "\n".join(prompt)
370 self, llm_context: LLMContext, exposed_entities: dict |
None
372 """Return a list of LLM tools."""
375 self.hass, llm_context.device_id
377 ignore_intents = ignore_intents | {
378 intent.INTENT_START_TIMER,
379 intent.INTENT_CANCEL_TIMER,
380 intent.INTENT_INCREASE_TIMER,
381 intent.INTENT_DECREASE_TIMER,
382 intent.INTENT_PAUSE_TIMER,
383 intent.INTENT_UNPAUSE_TIMER,
384 intent.INTENT_TIMER_STATUS,
389 for intent_handler
in intent.async_get(self.hass)
390 if intent_handler.intent_type
not in ignore_intents
393 exposed_domains: set[str] |
None =
None
394 if exposed_entities
is not None:
400 for intent_handler
in intent_handlers
401 if intent_handler.platforms
is None
402 or intent_handler.platforms & exposed_domains
405 tools: list[Tool] = [
407 for intent_handler
in intent_handlers
410 if llm_context.assistant
is not None:
411 for state
in self.hass.states.async_all(SCRIPT_DOMAIN):
413 self.hass, llm_context.assistant, state.entity_id
417 tools.append(
ScriptTool(self.hass, state.entity_id))
423 hass: HomeAssistant, assistant: str
424 ) -> dict[str, dict[str, Any]]:
425 """Get exposed entities."""
426 area_registry = ar.async_get(hass)
427 entity_registry = er.async_get(hass)
428 device_registry = dr.async_get(hass)
429 interesting_attributes = {
431 "current_temperature",
435 "unit_of_measurement",
447 for state
in hass.states.async_all():
450 or state.domain == SCRIPT_DOMAIN
454 description: str |
None =
None
455 entity_entry = entity_registry.async_get(state.entity_id)
459 if entity_entry
is not None:
460 names.extend(entity_entry.aliases)
461 if entity_entry.area_id
and (
462 area := area_registry.async_get_area(entity_entry.area_id)
465 area_names.append(area.name)
466 area_names.extend(area.aliases)
467 elif entity_entry.device_id
and (
468 device := device_registry.async_get(entity_entry.device_id)
471 if device.area_id
and (
472 area := area_registry.async_get_area(device.area_id)
474 area_names.append(area.name)
475 area_names.extend(area.aliases)
477 info: dict[str, Any] = {
478 "names":
", ".join(names),
479 "domain": state.domain,
480 "state": state.state,
484 info[
"description"] = description
487 info[
"areas"] =
", ".join(area_names)
490 attr_name:
str(attr_value)
491 if isinstance(attr_value, (Enum, Decimal, int))
493 for attr_name, attr_value
in state.attributes.items()
494 if attr_name
in interesting_attributes
496 info[
"attributes"] = attributes
498 entities[state.entity_id] = info
504 """Convert selectors into OpenAPI schema."""
509 return {
"type":
"string",
"pattern":
"^(?:\\/backup|\\w+)$"}
512 return {
"type":
"boolean"}
517 "items": {
"type":
"number"},
524 return convert(cv.CONDITIONS_SCHEMA)
527 return convert(vol.Schema(schema.config[
"value"]))
529 result: dict[str, Any]
531 result = {
"type":
"number"}
532 if "min" in schema.config:
533 result[
"minimum"] = schema.config[
"min"]
534 elif "min_mireds" in schema.config:
535 result[
"minimum"] = schema.config[
"min_mireds"]
536 if "max" in schema.config:
537 result[
"maximum"] = schema.config[
"max"]
538 elif "max_mireds" in schema.config:
539 result[
"maximum"] = schema.config[
"max_mireds"]
543 if schema.config.get(
"countries"):
544 return {
"type":
"string",
"enum": schema.config[
"countries"]}
545 return {
"type":
"string",
"format":
"ISO 3166-1 alpha-2"}
548 return {
"type":
"string",
"format":
"date"}
551 return {
"type":
"string",
"format":
"date-time"}
554 return convert(cv.time_period_dict)
557 if schema.config.get(
"multiple"):
558 return {
"type":
"array",
"items": {
"type":
"string",
"format":
"entity_id"}}
560 return {
"type":
"string",
"format":
"entity_id"}
563 if schema.config.get(
"languages"):
564 return {
"type":
"string",
"enum": schema.config[
"languages"]}
565 return {
"type":
"string",
"format":
"RFC 5646"}
568 return convert(schema.DATA_SCHEMA)
571 result = {
"type":
"number"}
572 if "min" in schema.config:
573 result[
"minimum"] = schema.config[
"min"]
574 if "max" in schema.config:
575 result[
"maximum"] = schema.config[
"max"]
579 return {
"type":
"object",
"additionalProperties":
True}
583 x[
"value"]
if isinstance(x, dict)
else x
for x
in schema.config[
"options"]
585 if schema.config.get(
"multiple"):
588 "items": {
"type":
"string",
"enum": options},
591 return {
"type":
"string",
"enum": options}
594 return convert(cv.TARGET_SERVICE_FIELDS)
597 return {
"type":
"string",
"format":
"jinja2"}
600 return {
"type":
"string",
"format":
"time"}
603 return {
"type":
"array",
"items": {
"type":
"string"}}
605 if schema.config.get(
"multiple"):
606 return {
"type":
"array",
"items": {
"type":
"string"}}
608 return {
"type":
"string"}
612 hass: HomeAssistant, entity_id: str
613 ) -> tuple[str |
None, vol.Schema]:
614 """Get script description and schema."""
615 entity_registry = er.async_get(hass)
618 parameters = vol.Schema({})
619 entity_entry = entity_registry.async_get(entity_id)
620 if entity_entry
and entity_entry.unique_id:
621 parameters_cache = hass.data.get(SCRIPT_PARAMETERS_CACHE)
623 if parameters_cache
is None:
624 parameters_cache = hass.data[SCRIPT_PARAMETERS_CACHE] = {}
627 def clear_cache(event: Event) ->
None:
628 """Clear script parameter cache on script reload or delete."""
630 event.data[ATTR_DOMAIN] == SCRIPT_DOMAIN
631 and event.data[ATTR_SERVICE]
in parameters_cache
633 parameters_cache.pop(event.data[ATTR_SERVICE])
635 cancel = hass.bus.async_listen(EVENT_SERVICE_REMOVED, clear_cache)
638 def on_homeassistant_close(event: Event) ->
None:
642 hass.bus.async_listen_once(
643 EVENT_HOMEASSISTANT_CLOSE, on_homeassistant_close
646 if entity_entry.unique_id
in parameters_cache:
647 return parameters_cache[entity_entry.unique_id]
649 if service_desc := service.async_get_cached_service_description(
650 hass, SCRIPT_DOMAIN, entity_entry.unique_id
652 description = service_desc.get(
"description")
653 schema: dict[vol.Marker, Any] = {}
654 fields = service_desc.get(
"fields", {})
656 for field, config
in fields.items():
657 field_description = config.get(
"description")
658 if not field_description:
659 field_description = config.get(
"name")
661 if config.get(
"required"):
662 key = vol.Required(field, description=field_description)
664 key = vol.Optional(field, description=field_description)
665 if "selector" in config:
666 schema[key] = selector.selector(config[
"selector"])
668 schema[key] = cv.string
670 parameters = vol.Schema(schema)
672 aliases: list[str] = []
673 if entity_entry.name:
674 aliases.append(entity_entry.name)
675 if entity_entry.aliases:
676 aliases.extend(entity_entry.aliases)
679 description = description +
". Aliases: " +
str(
list(aliases))
681 description =
"Aliases: " +
str(
list(aliases))
683 parameters_cache[entity_entry.unique_id] = (description, parameters)
685 return description, parameters
689 """LLM Tool representing a Script."""
694 script_entity_id: str,
696 """Init the class."""
698 if self.
namename[0].isdigit():
702 hass, script_entity_id
706 self, hass: HomeAssistant, tool_input: ToolInput, llm_context: LLMContext
708 """Run the script."""
710 for field, validator
in self.
parametersparameters.schema.items():
711 if field
not in tool_input.tool_args:
714 area_reg = ar.async_get(hass)
715 if validator.config.get(
"multiple"):
716 areas: list[ar.AreaEntry] = []
717 for area
in tool_input.tool_args[field]:
718 areas.extend(intent.find_areas(area, area_reg))
719 tool_input.tool_args[field] =
list({area.id
for area
in areas})
721 area = tool_input.tool_args[field]
722 area =
list(intent.find_areas(area, area_reg))[0].id
723 tool_input.tool_args[field] = area
726 floor_reg = fr.async_get(hass)
727 if validator.config.get(
"multiple"):
728 floors: list[fr.FloorEntry] = []
729 for floor
in tool_input.tool_args[field]:
730 floors.extend(intent.find_floors(floor, floor_reg))
731 tool_input.tool_args[field] =
list(
732 {floor.floor_id
for floor
in floors}
735 floor = tool_input.tool_args[field]
736 floor =
list(intent.find_floors(floor, floor_reg))[0].floor_id
737 tool_input.tool_args[field] = floor
739 result = await hass.services.async_call(
742 tool_input.tool_args,
743 context=llm_context.context,
745 return_response=
True,
748 return {
"success":
True,
"result": result}
JsonObjectType async_call_tool(self, ToolInput tool_input)
APIInstance async_get_api_instance(self, LLMContext llm_context)
str _async_get_api_prompt(self, LLMContext llm_context, dict|None exposed_entities)
dictionary IGNORE_INTENTS
APIInstance async_get_api_instance(self, LLMContext llm_context)
list[Tool] _async_get_tools(self, LLMContext llm_context, dict|None exposed_entities)
None __init__(self, HomeAssistant hass)
None async_conversation_trace_append(ConversationTraceEventType event_type, dict[str, Any] event_data)
bool async_should_expose(HomeAssistant hass, str assistant, str entity_id)
bool async_device_supports_timers(HomeAssistant hass, str device_id)
tuple[str, str] split_entity_id(str entity_id)
str async_render_no_api_prompt(HomeAssistant hass)
Any _selector_serializer(Any schema)
dict[str, API] _async_get_apis(HomeAssistant hass)
tuple[str|None, vol.Schema] _get_cached_script_parameters(HomeAssistant hass, str entity_id)
list[API] async_get_apis(HomeAssistant hass)
APIInstance async_get_api(HomeAssistant hass, str api_id, LLMContext llm_context)
dict[str, dict[str, Any]] _get_exposed_entities(HomeAssistant hass, str assistant)
None async_register_api(HomeAssistant hass, API api)