Home Assistant Unofficial Reference 2024.12.1
pool.py
Go to the documentation of this file.
1 """A pool for sqlite connections."""
2 
3 from __future__ import annotations
4 
5 import asyncio
6 import logging
7 import threading
8 import traceback
9 from typing import Any
10 
11 from sqlalchemy.exc import SQLAlchemyError
12 from sqlalchemy.pool import (
13  ConnectionPoolEntry,
14  NullPool,
15  SingletonThreadPool,
16  StaticPool,
17 )
18 
19 from homeassistant.helpers.frame import ReportBehavior, report_usage
20 from homeassistant.util.loop import raise_for_blocking_call
21 
_LOGGER = logging.getLogger(__name__)

# For debugging the MutexPool:
# - DEBUG_MUTEX_POOL: when True, MutexPool logs every connection wait/get/return
#   together with a running reference counter.
# - DEBUG_MUTEX_POOL_TRACE: when True, those log lines also include the caller's
#   formatted stack trace (expensive; off by default).
DEBUG_MUTEX_POOL = True
DEBUG_MUTEX_POOL_TRACE = False

# Forced into the ``pool_size`` keyword by RecorderPool.__init__, overriding any
# value the caller passed.
POOL_SIZE = 5

# Advice appended to the blocking-call error and the usage report emitted when the
# database is accessed outside the recorder's executor.
ADVISE_MSG = (
    "Use homeassistant.components.recorder.get_instance(hass).async_add_executor_job()"
)
33 
34 
class RecorderPool(SingletonThreadPool, NullPool):
    """A hybrid of NullPool and SingletonThreadPool.

    When called from the creating thread or db executor acts like SingletonThreadPool
    When called from any other thread, acts like NullPool
    """

    def __init__(  # pylint: disable=super-init-not-called
        self,
        creator: Any,
        recorder_and_worker_thread_ids: set[int] | None = None,
        **kw: Any,
    ) -> None:
        """Create the pool.

        Args:
            creator: Connection factory handed to SingletonThreadPool.
            recorder_and_worker_thread_ids: Thread idents (as returned by
                ``threading.get_ident()``) that are allowed to use the pooled
                per-thread connection. Required despite the ``None`` default.
            **kw: Remaining pool options; ``pool_size`` is always overridden
                with ``POOL_SIZE``.
        """
        kw["pool_size"] = POOL_SIZE
        assert (
            recorder_and_worker_thread_ids is not None
        ), "recorder_and_worker_thread_ids is required"
        # Stored by reference, not copied, so a pool built by recreate() shares the
        # same set. Presumably the owner adds thread ids to it later; confirm.
        self.recorder_and_worker_thread_ids = recorder_and_worker_thread_ids
        # Only SingletonThreadPool is initialised; NullPool contributes behaviour
        # (e.g. _create_connection) but not state.
        SingletonThreadPool.__init__(self, creator, **kw)

    def _in_recorder_or_worker_thread(self) -> bool:
        """Return True if the calling thread may use the singleton connection."""
        return threading.get_ident() in self.recorder_and_worker_thread_ids

    def recreate(self) -> RecorderPool:
        """Recreate the pool with the same options and shared thread-id set."""
        self.logger.info("Pool recreating")
        return self.__class__(
            self._creator,
            pool_size=self.size,
            recycle=self._recycle,
            echo=self.echo,
            pre_ping=self._pre_ping,
            logging_name=self._orig_logging_name,
            reset_on_return=self._reset_on_return,
            _dispatch=self.dispatch,
            dialect=self._dialect,
            recorder_and_worker_thread_ids=self.recorder_and_worker_thread_ids,
        )

    def _do_return_conn(self, record: ConnectionPoolEntry) -> None:
        """Return a connection: keep it for allowed threads, close it otherwise."""
        if self._in_recorder_or_worker_thread():
            super()._do_return_conn(record)
            return
        # NullPool-style: connections checked out by other threads are not reused.
        record.close()

    def shutdown(self) -> None:
        """Close the connection."""
        # Only the calling thread's own singleton connection is closed, and only
        # from an allowed thread. `self._conn.current` appears to be a weakref-style
        # accessor that may yield None; confirm against the SQLAlchemy version.
        if (
            self._in_recorder_or_worker_thread()
            and self._conn
            and hasattr(self._conn, "current")
            and (conn := self._conn.current())
        ):
            conn.close()

    def dispose(self) -> None:
        """Dispose of the connection."""
        # Calls from any other thread are deliberately a no-op.
        if self._in_recorder_or_worker_thread():
            super().dispose()

    def _do_get(self) -> ConnectionPoolEntry:  # type: ignore[return]
        """Check out a connection, guarding against use from the event loop."""
        if self._in_recorder_or_worker_thread():
            return super()._do_get()
        try:
            asyncio.get_running_loop()
        except RuntimeError:
            # Not in an event loop but not in the recorder or worker thread
            # which is allowed but discouraged since its much slower
            return self._do_get_db_connection_protected()
        # In the event loop, raise an exception
        raise_for_blocking_call(  # noqa: RET503
            self._do_get_db_connection_protected,
            strict=True,
            advise_msg=ADVISE_MSG,
        )
        # raise_for_blocking_call will raise an exception

    def _do_get_db_connection_protected(self) -> ConnectionPoolEntry:
        """Report the slow access path, then open a fresh NullPool-style connection."""
        report_usage(
            (
                "accesses the database without the database executor; "
                f"{ADVISE_MSG} "
                "for faster database operations"
            ),
            exclude_integrations={"recorder"},
            core_behavior=ReportBehavior.LOG,
        )
        return NullPool._create_connection(self)  # noqa: SLF001
121 
122 
class MutexPool(StaticPool):
    """A pool which prevents concurrent accesses from multiple threads.

    This is used in tests to prevent unsafe concurrent accesses to in-memory SQLite
    databases.

    A connection checkout acquires the class-wide ``pool_lock``; returning the
    connection releases it.
    """

    # Class-level default; the first `+=`/`-=` through `self` creates a
    # per-instance counter that shadows this value.
    _reference_counter = 0
    # Declared but not assigned here. It must be set on the class (e.g.
    # ``MutexPool.pool_lock = threading.RLock()``) before use; presumably the test
    # harness does this, so confirm.
    pool_lock: threading.RLock
    # Seconds _do_get waits for pool_lock before giving up with SQLAlchemyError.
    lock_timeout: float = 10

    @staticmethod
    def _trace_msg() -> str:
        """Return the caller's formatted stack when tracing is enabled, else ''."""
        if not DEBUG_MUTEX_POOL_TRACE:
            return ""
        # Drop this helper's frame and the pool method that called it.
        trace = traceback.extract_stack()[:-2]
        return "\n" + "".join(traceback.format_list(trace))

    def _do_return_conn(self, record: ConnectionPoolEntry) -> None:
        """Return the connection and release the pool lock (even on error)."""
        trace_msg = self._trace_msg()
        try:
            super()._do_return_conn(record)
            if DEBUG_MUTEX_POOL:
                self._reference_counter -= 1
                _LOGGER.debug(
                    "%s return conn ctr: %s%s",
                    threading.current_thread().name,
                    self._reference_counter,
                    trace_msg,
                )
        finally:
            MutexPool.pool_lock.release()

    def _do_get(self) -> ConnectionPoolEntry:
        """Acquire the pool lock, then check out the connection.

        Raises:
            SQLAlchemyError: if the lock is not acquired within ``lock_timeout``.
        """
        trace_msg = self._trace_msg()
        if DEBUG_MUTEX_POOL:
            _LOGGER.debug("%s wait conn%s", threading.current_thread().name, trace_msg)
        # pylint: disable-next=consider-using-with
        if not MutexPool.pool_lock.acquire(timeout=self.lock_timeout):
            raise SQLAlchemyError("Timed out waiting for the MutexPool lock")
        try:
            conn = super()._do_get()
        except BaseException:
            # No connection was handed out, so nothing will ever return it;
            # release now instead of leaking the lock.
            MutexPool.pool_lock.release()
            raise
        if DEBUG_MUTEX_POOL:
            self._reference_counter += 1
            _LOGGER.debug(
                "%s get conn: ctr: %s",
                threading.current_thread().name,
                self._reference_counter,
            )
        return conn
None _do_return_conn(self, ConnectionPoolEntry record)
Definition: pool.py:133
ConnectionPoolEntry _do_get(self)
Definition: pool.py:151
None _do_return_conn(self, ConnectionPoolEntry record)
Definition: pool.py:72
ConnectionPoolEntry _do_get_db_connection_protected(self)
Definition: pool.py:110
None __init__(self, Any creator, set[int]|None recorder_and_worker_thread_ids=None, **Any kw)
Definition: pool.py:47
None report_usage(str what, *str|None breaks_in_ha_version=None, ReportBehavior core_behavior=ReportBehavior.ERROR, ReportBehavior core_integration_behavior=ReportBehavior.LOG, ReportBehavior custom_integration_behavior=ReportBehavior.LOG, set[str]|None exclude_integrations=None, str|None integration_domain=None, int level=logging.WARNING)
Definition: frame.py:195
None raise_for_blocking_call(Callable[..., Any] func, Callable[[dict[str, Any]], bool]|None check_allowed=None, bool strict=True, bool strict_core=True, **Any mapped_args)
Definition: loop.py:41