Home Assistant Unofficial Reference 2024.12.1
util.py
Go to the documentation of this file.
1 """Utility functions for the MQTT integration."""
2 
3 from __future__ import annotations
4 
5 import asyncio
6 from collections.abc import Callable, Coroutine
7 from functools import lru_cache
8 import logging
9 import os
10 from pathlib import Path
11 import tempfile
12 from typing import Any
13 
14 import voluptuous as vol
15 
16 from homeassistant.config_entries import ConfigEntry, ConfigEntryState
17 from homeassistant.const import MAX_LENGTH_STATE_STATE, STATE_UNKNOWN, Platform
18 from homeassistant.core import HomeAssistant, callback
19 from homeassistant.exceptions import HomeAssistantError
20 from homeassistant.helpers import config_validation as cv, template
21 from homeassistant.helpers.typing import ConfigType
22 from homeassistant.util.async_ import create_eager_task
23 
24 from .const import (
25  ATTR_PAYLOAD,
26  ATTR_QOS,
27  ATTR_RETAIN,
28  ATTR_TOPIC,
29  CONF_CERTIFICATE,
30  CONF_CLIENT_CERT,
31  CONF_CLIENT_KEY,
32  DEFAULT_ENCODING,
33  DEFAULT_QOS,
34  DEFAULT_RETAIN,
35  DOMAIN,
36 )
37 from .models import DATA_MQTT, DATA_MQTT_AVAILABLE, ReceiveMessage
38 
# Upper bound, in seconds, that async_wait_for_mqtt_client waits for the client.
AVAILABILITY_TIMEOUT = 50.0

# Directory below the system temp dir that holds the certificate files.
TEMP_DIR_NAME = f"home-assistant-{DOMAIN}"

# Accept any value coercible to int that is one of the QoS levels 0, 1 or 2.
_VALID_QOS_SCHEMA = vol.All(vol.Coerce(int), vol.In([0, 1, 2]))

_LOGGER = logging.getLogger(__name__)
46 
47 
49  """Ensure a cool down period before executing a job.
50 
51  When a new execute request arrives we cancel the current request
52  and start a new one.
53 
54  We allow patching this util, as we generally have exceptions
55  for sleeps/waits/debouncers/timers causing long run times in tests.
56  """
57 
58  def __init__(
59  self, timeout: float, callback_job: Callable[[], Coroutine[Any, None, None]]
60  ) -> None:
61  """Initialize the timer."""
62  self._loop_loop = asyncio.get_running_loop()
63  self._timeout_timeout = timeout
64  self._callback_callback = callback_job
65  self._task_task: asyncio.Task | None = None
66  self._timer_timer: asyncio.TimerHandle | None = None
67  self._next_execute_time_next_execute_time = 0.0
68 
69  def set_timeout(self, timeout: float) -> None:
70  """Set a new timeout period."""
71  self._timeout_timeout = timeout
72 
73  async def _async_job(self) -> None:
74  """Execute after a cooldown period."""
75  try:
76  await self._callback_callback()
77  except HomeAssistantError as ha_error:
78  _LOGGER.error("%s", ha_error)
79 
80  @callback
81  def _async_task_done(self, task: asyncio.Task) -> None:
82  """Handle task done."""
83  self._task_task = None
84 
85  @callback
86  def async_execute(self) -> asyncio.Task:
87  """Execute the job."""
88  if self._task_task:
89  # Task already running,
90  # so we schedule another run
91  self.async_scheduleasync_schedule()
92  return self._task_task
93 
94  self._async_cancel_timer_async_cancel_timer()
95  self._task_task = create_eager_task(self._async_job_async_job())
96  self._task_task.add_done_callback(self._async_task_done_async_task_done)
97  return self._task_task
98 
99  @callback
100  def _async_cancel_timer(self) -> None:
101  """Cancel any pending task."""
102  if self._timer_timer:
103  self._timer_timer.cancel()
104  self._timer_timer = None
105 
106  @callback
107  def async_schedule(self) -> None:
108  """Ensure we execute after a cooldown period."""
109  # We want to reschedule the timer in the future
110  # every time this is called.
111  next_when = self._loop_loop.time() + self._timeout_timeout
112  if not self._timer_timer:
113  self._timer_timer = self._loop_loop.call_at(next_when, self._async_timer_reached_async_timer_reached)
114  return
115 
116  if self._timer_timer.when() < next_when:
117  # Timer already running, set the next execute time
118  # if it fires too early, it will get rescheduled
119  self._next_execute_time_next_execute_time = next_when
120 
121  @callback
122  def _async_timer_reached(self) -> None:
123  """Handle timer fire."""
124  self._timer_timer = None
125  if self._loop_loop.time() >= self._next_execute_time_next_execute_time:
126  self.async_executeasync_execute()
127  return
128  # Timer fired too early because there were multiple
129  # calls async_schedule. Reschedule the timer.
130  self._timer_timer = self._loop_loop.call_at(
131  self._next_execute_time_next_execute_time, self._async_timer_reached_async_timer_reached
132  )
133 
134  async def async_cleanup(self) -> None:
135  """Cleanup any pending task."""
136  self._async_cancel_timer_async_cancel_timer()
137  if not self._task_task:
138  return
139  self._task_task.cancel()
140  try:
141  await self._task_task
142  except asyncio.CancelledError:
143  pass
144  except Exception:
145  _LOGGER.exception("Error cleaning up task")
146 
147 
def platforms_from_config(config: list[ConfigType]) -> set[Platform | str]:
    """Return the platforms to be set up.

    Each item of ``config`` is iterated for its keys (the platform names);
    the result is the union of those keys over all items.
    """
    platforms: set[Platform | str] = set()
    for platform_config in config:
        platforms.update(platform_config)
    return platforms
151 
152 
async def async_forward_entry_setup_and_setup_discovery(
    hass: HomeAssistant,
    config_entry: ConfigEntry,
    platforms: set[Platform | str],
    late: bool = False,
) -> None:
    """Set up the MQTT platforms of ``platforms`` that are not loaded yet.

    ``device_automation`` and ``tag`` are set up through their own modules.
    Every other platform is forwarded to the config entries manager in a
    single call. All setups run concurrently, and the platforms are recorded
    as loaded only after every setup has completed. ``late`` is not used by
    this implementation.
    """
    platforms_loaded = hass.data[DATA_MQTT].platforms_loaded
    new_platforms: set[Platform | str] = platforms - platforms_loaded
    if not new_platforms:
        return

    setup_tasks: list[asyncio.Task] = []
    if "device_automation" in new_platforms:
        # Local import to avoid circular dependencies
        # pylint: disable-next=import-outside-toplevel
        from . import device_automation

        setup_tasks.append(
            create_eager_task(device_automation.async_setup_entry(hass, config_entry))
        )
    if "tag" in new_platforms:
        # Local import to avoid circular dependencies
        # pylint: disable-next=import-outside-toplevel
        from . import tag

        setup_tasks.append(
            create_eager_task(tag.async_setup_entry(hass, config_entry))
        )

    entity_platforms = new_platforms - {"tag", "device_automation"}
    if entity_platforms:
        setup_tasks.append(
            create_eager_task(
                hass.config_entries.async_forward_entry_setups(
                    config_entry, entity_platforms
                )
            )
        )

    await asyncio.gather(*setup_tasks)
    platforms_loaded.update(new_platforms)
190 
191 
def mqtt_config_entry_enabled(hass: HomeAssistant) -> bool | None:
    """Return whether the MQTT config entry is enabled.

    A connected MQTT client is checked first because it is far cheaper than
    the config entry lookup, which is only done when the client is not known
    to be connected.
    """
    if DATA_MQTT in hass.data and (
        connected := hass.data[DATA_MQTT].client.connected
    ):
        return connected
    return hass.config_entries.async_has_entries(
        DOMAIN, include_disabled=False, include_ignore=False
    )
201 
202 
async def async_wait_for_mqtt_client(hass: HomeAssistant) -> bool:
    """Wait for the MQTT client to become available.

    Only waits while MQTT setup is in progress; the client does not need to
    be connected. Returns True when the config entry is already loaded or the
    shared availability future resolves to True. Returns False when the entry
    is not enabled, the future resolves to False, or AVAILABILITY_TIMEOUT
    expires.
    """
    if not mqtt_config_entry_enabled(hass):
        return False

    mqtt_entry = hass.config_entries.async_entries(DOMAIN)[0]
    if mqtt_entry.state == ConfigEntryState.LOADED:
        return True

    # All waiters share one future stored in hass.data; create it on first use.
    state_reached_future: asyncio.Future[bool]
    if DATA_MQTT_AVAILABLE in hass.data:
        state_reached_future = hass.data[DATA_MQTT_AVAILABLE]
    else:
        state_reached_future = hass.loop.create_future()
        hass.data[DATA_MQTT_AVAILABLE] = state_reached_future

    try:
        async with asyncio.timeout(AVAILABILITY_TIMEOUT):
            # Resolves when client setup finishes or an error state is received.
            client_available = await state_reached_future
    except TimeoutError:
        return False
    return client_available
231 
232 
def valid_topic(topic: Any) -> str:
    """Validate an MQTT topic name or topic filter and return it as a string.

    This function is not cached and is not meant to be called directly from
    outside this module; it is public only because it is tested directly in
    test_util.py. New callers other than valid_subscribe_topic and
    valid_publish_topic may need an lru_cache, either here or on the calling
    function.

    Raises:
        vol.Invalid: if cv.string rejects the value, or the topic is not
            encodable as UTF-8, is empty, exceeds 65535 encoded bytes, or
            contains a null, control or non-character code point.
    """
    validated_topic = cv.string(topic)
    try:
        encoded_topic = validated_topic.encode("utf-8")
    except UnicodeError as err:
        raise vol.Invalid(
            "MQTT topic name/filter must be valid UTF-8 string."
        ) from err
    if not encoded_topic:
        raise vol.Invalid("MQTT topic name/filter must not be empty.")
    if len(encoded_topic) > 65535:
        raise vol.Invalid(
            "MQTT topic name/filter must not be longer than 65535 encoded bytes."
        )

    for code_point in map(ord, validated_topic):
        if code_point == 0:
            raise vol.Invalid(
                "MQTT topic name/filter must not contain null character."
            )
        # C0 controls (U+0001..U+001F), DEL and C1 controls (U+007F..U+009F).
        if code_point <= 0x1F or 0x7F <= code_point <= 0x9F:
            raise vol.Invalid(
                "MQTT topic name/filter must not contain control characters."
            )
        # Non-characters: U+FDD0..U+FDEF and the last two code points of
        # every plane (U+xxFFFE / U+xxFFFF).
        if 0xFDD0 <= code_point <= 0xFDEF or (code_point & 0xFFFF) in (
            0xFFFE,
            0xFFFF,
        ):
            raise vol.Invalid(
                "MQTT topic name/filter must not contain non-characters."
            )

    return validated_topic
267 
268 
@lru_cache
def valid_subscribe_topic(topic: Any) -> str:
    """Validate that we can subscribe using this MQTT topic filter.

    Besides the generic checks of valid_topic, this enforces the wildcard
    rules:
    - A single-level wildcard ``+`` must make up a whole topic level.
    - A multi-level wildcard ``#`` may appear only once, as the final
      character, either alone or directly after a ``/`` separator.
    """
    validated_topic = valid_topic(topic)

    # A "+" is valid only when its level is exactly "+": any other level that
    # contains "+" gives that wildcard a neighbour which is not a separator.
    if "+" in validated_topic and any(
        level != "+" and "+" in level for level in validated_topic.split("/")
    ):
        raise vol.Invalid(
            "Single-level wildcard must occupy an entire level of the filter"
        )

    hash_index = validated_topic.find("#")
    if hash_index == -1:
        return validated_topic
    if hash_index != len(validated_topic) - 1:
        # Also covers filters containing more than one "#".
        raise vol.Invalid(
            "Multi-level wildcard must be the last character in the topic filter."
        )
    if hash_index > 0 and validated_topic[hash_index - 1] != "/":
        raise vol.Invalid(
            "Multi-level wildcard must be after a topic level separator."
        )
    return validated_topic
295 
296 
def valid_subscribe_topic_template(value: Any) -> template.Template:
    """Validate either a Jinja2 template or a static MQTT subscription topic.

    A static template is also checked with valid_subscribe_topic. A dynamic
    template is only validated as a template here.
    """
    topic_template = cv.template(value)
    if not topic_template.is_static:
        return topic_template
    # A static template is a plain topic, so it must be subscribable.
    valid_subscribe_topic(value)
    return topic_template
305 
306 
@lru_cache
def valid_publish_topic(topic: Any) -> str:
    """Validate that we can publish to this MQTT topic name.

    Topic names used for publishing must not contain the ``+`` or ``#``
    wildcards.
    """
    validated_topic = valid_topic(topic)
    if any(wildcard in validated_topic for wildcard in ("+", "#")):
        raise vol.Invalid("Wildcards cannot be used in topic names")
    return validated_topic
314 
315 
def valid_qos_schema(qos: Any) -> int:
    """Validate a QoS value: coerce it to int and require 0, 1 or 2."""
    qos_level: int = _VALID_QOS_SCHEMA(qos)
    return qos_level
320 
321 
# Schema for a birth or will message, applied by valid_birth_will.
# Topic and payload are mandatory; QoS and retain fall back to the MQTT
# integration defaults. ``required=True`` makes every key required unless it
# is explicitly wrapped in ``vol.Optional``.
_MQTT_WILL_BIRTH_SCHEMA = vol.Schema(
    {
        # A publish topic: wildcards are rejected.
        vol.Required(ATTR_TOPIC): valid_publish_topic,
        vol.Required(ATTR_PAYLOAD): cv.string,
        vol.Optional(ATTR_QOS, default=DEFAULT_QOS): valid_qos_schema,
        vol.Optional(ATTR_RETAIN, default=DEFAULT_RETAIN): cv.boolean,
    },
    required=True,
)
331 
332 
def valid_birth_will(config: ConfigType) -> ConfigType:
    """Validate an optional birth or will message configuration.

    An empty (falsy) config is returned unchanged. Otherwise it must match
    _MQTT_WILL_BIRTH_SCHEMA, which requires a topic and a payload.
    """
    return _MQTT_WILL_BIRTH_SCHEMA(config) if config else config
338 
339 
async def async_create_certificate_temp_files(
    hass: HomeAssistant, config: ConfigType
) -> None:
    """Create certificate temporary files for the MQTT client.

    For each of CONF_CERTIFICATE, CONF_CLIENT_CERT and CONF_CLIENT_KEY, the
    configured content is written to a file of that name in the TEMP_DIR_NAME
    directory under the system temp dir. A missing value or ``"auto"``
    removes any stale file instead. The filesystem work runs in the executor.
    """

    def _create_temp_file(temp_file: Path, data: str | None) -> None:
        """Write ``data`` to ``temp_file``, or remove the file when unset."""
        if data is None or data == "auto":
            if temp_file.exists():
                # missing_ok: the file may vanish between exists() and unlink().
                temp_file.unlink(missing_ok=True)
            return
        # The file may hold a private key: create it owner-only and tighten
        # the mode of a pre-existing file before writing any content into it.
        temp_file.touch(mode=0o600)
        temp_file.chmod(0o600)
        temp_file.write_text(data, encoding=DEFAULT_ENCODING)

    def _create_temp_dir_and_files() -> None:
        """Create the temporary directory if needed, then the files."""
        temp_dir = Path(tempfile.gettempdir()) / TEMP_DIR_NAME

        if (
            config.get(CONF_CERTIFICATE)
            or config.get(CONF_CLIENT_CERT)
            or config.get(CONF_CLIENT_KEY)
        ) and not temp_dir.exists():
            # Owner-only directory; exist_ok tolerates a concurrent creation.
            # NOTE(review): an existing directory is reused as-is; its
            # ownership and mode are not verified.
            temp_dir.mkdir(0o700, exist_ok=True)

        _create_temp_file(temp_dir / CONF_CERTIFICATE, config.get(CONF_CERTIFICATE))
        _create_temp_file(temp_dir / CONF_CLIENT_CERT, config.get(CONF_CLIENT_CERT))
        _create_temp_file(temp_dir / CONF_CLIENT_KEY, config.get(CONF_CLIENT_KEY))

    await hass.async_add_executor_job(_create_temp_dir_and_files)
368 
369 
def check_state_too_long(
    logger: logging.Logger, proposed_state: str, entity_id: str, msg: ReceiveMessage
) -> bool:
    """Return True, after logging a warning, if the proposed state is too long.

    The limit is MAX_LENGTH_STATE_STATE characters. The warning names the
    entity and the topic of ``msg``, mentions the STATE_UNKNOWN fallback, and
    includes at most the first 8192 characters of the rejected state.
    """
    state_length = len(proposed_state)
    if state_length <= MAX_LENGTH_STATE_STATE:
        return False

    logger.warning(
        "Cannot update state for entity %s after processing "
        "payload on topic %s. The requested state (%s) exceeds "
        "the maximum allowed length (%s). Fall back to "
        "%s, failed state: %s",
        entity_id,
        msg.topic,
        state_length,
        MAX_LENGTH_STATE_STATE,
        STATE_UNKNOWN,
        proposed_state[:8192],
    )
    return True
390 
391 
def get_file_path(option: str, default: str | None = None) -> str | None:
    """Return the path of a stored certificate file, or ``default``.

    ``option`` names the file inside the TEMP_DIR_NAME temporary directory.
    ``default`` is returned when that directory or the file does not exist.
    """
    temp_dir = Path(tempfile.gettempdir()) / TEMP_DIR_NAME
    if temp_dir.exists() and (file_path := temp_dir / option).exists():
        return str(file_path)
    return default
403 
404 
def migrate_certificate_file_to_content(file_name_or_auto: str) -> str | None:
    """Convert a certificate file setting into config entry content.

    ``"auto"`` is passed through unchanged. Any other value is treated as a
    file path: the file's text, read with DEFAULT_ENCODING, is returned, or
    None when the file cannot be opened or read.
    """
    if file_name_or_auto != "auto":
        try:
            with open(file_name_or_auto, encoding=DEFAULT_ENCODING) as cert_file:
                content = cert_file.read()
        except OSError:
            return None
        return content
    return "auto"
None _async_task_done(self, asyncio.Task task)
Definition: util.py:81
None __init__(self, float timeout, Callable[[], Coroutine[Any, None, None]] callback_job)
Definition: util.py:60
ConfigType valid_birth_will(ConfigType config)
Definition: util.py:333
None async_create_certificate_temp_files(HomeAssistant hass, ConfigType config)
Definition: util.py:342
bool check_state_too_long(logging.Logger logger, str proposed_state, str entity_id, ReceiveMessage msg)
Definition: util.py:372
str|None migrate_certificate_file_to_content(str file_name_or_auto)
Definition: util.py:405
bool async_wait_for_mqtt_client(HomeAssistant hass)
Definition: util.py:203
str|None get_file_path(str option, str|None default=None)
Definition: util.py:392
str valid_subscribe_topic(Any topic)
Definition: util.py:270
bool|None mqtt_config_entry_enabled(HomeAssistant hass)
Definition: util.py:192
str valid_topic(Any topic)
Definition: util.py:233
set[Platform|str] platforms_from_config(list[ConfigType] config)
Definition: util.py:148
template.Template valid_subscribe_topic_template(Any value)
Definition: util.py:297
str valid_publish_topic(Any topic)
Definition: util.py:308
None async_forward_entry_setup_and_setup_discovery(HomeAssistant hass, ConfigEntry config_entry, set[Platform|str] platforms, bool late=False)
Definition: util.py:158
None open(self, **Any kwargs)
Definition: lock.py:86
bool time(HomeAssistant hass, dt_time|str|None before=None, dt_time|str|None after=None, str|Container[str]|None weekday=None)
Definition: condition.py:802