# state_attributes.py — Home Assistant recorder (reference copy, 2024.12.1)
1 """Support managing StateAttributes."""
2 
3 from __future__ import annotations
4 
5 from collections.abc import Collection, Iterable
6 import logging
7 from typing import TYPE_CHECKING, cast
8 
9 from sqlalchemy.orm.session import Session
10 
11 from homeassistant.core import Event, EventStateChangedData
12 from homeassistant.util.collection import chunked_or_all
13 from homeassistant.util.json import JSON_ENCODE_EXCEPTIONS
14 
15 from ..db_schema import StateAttributes
16 from ..queries import get_shared_attributes
17 from ..util import execute_stmt_lambda_element
18 from . import BaseLRUTableManager
19 
20 if TYPE_CHECKING:
21  from ..core import Recorder
22 
23 # The number of attribute ids to cache in memory
24 #
25 # Based on:
26 # - The number of overlapping attributes
27 # - How frequently states with overlapping attributes will change
28 # - How much memory our low end hardware has
29 CACHE_SIZE = 2048
30 
31 _LOGGER = logging.getLogger(__name__)
32 
33 
class StateAttributesManager(BaseLRUTableManager[StateAttributes]):
    """Manage the StateAttributes table.

    Maintains an LRU cache mapping serialized shared attributes (str) to
    their database attributes_id, plus a pending map of rows awaiting
    commit. All methods except __init__ must run on the recorder thread.
    """

    def __init__(self, recorder: Recorder) -> None:
        """Initialize the state attributes manager."""
        super().__init__(recorder, CACHE_SIZE)

    def serialize_from_event(
        self, event: Event[EventStateChangedData]
    ) -> bytes | None:
        """Serialize the event's new state attributes to JSON bytes.

        Returns None (and logs a warning) when the state attributes are
        not JSON serializable, so a single bad state cannot abort the
        recorder's commit cycle.
        """
        try:
            return StateAttributes.shared_attrs_bytes_from_event(
                event, self.recorder.dialect_name
            )
        except JSON_ENCODE_EXCEPTIONS as ex:
            _LOGGER.warning(
                "State is not JSON serializable: %s: %s",
                event.data["new_state"],
                ex,
            )
            return None

    def load(
        self, events: list[Event[EventStateChangedData]], session: Session
    ) -> None:
        """Load the shared_attrs to attributes_ids mapping into memory from events.

        This call is not thread-safe and must be called from the
        recorder thread.
        """
        # Events whose attributes fail to serialize yield None and are
        # skipped by the walrus-guarded condition below.
        if hashes := {
            StateAttributes.hash_shared_attrs_bytes(shared_attrs_bytes)
            for event in events
            if (shared_attrs_bytes := self.serialize_from_event(event))
        }:
            self._load_from_hashes(hashes, session)

    def get(self, shared_attr: str, data_hash: int, session: Session) -> int | None:
        """Resolve shared_attrs to the attributes_id.

        This call is not thread-safe and must be called from the
        recorder thread.
        """
        # Delegate to the batched path with a single-element tuple.
        return self.get_many(((shared_attr, data_hash),), session)[shared_attr]

    def get_many(
        self, shared_attrs_data_hashes: Iterable[tuple[str, int]], session: Session
    ) -> dict[str, int | None]:
        """Resolve shared_attrs to attributes_ids.

        This call is not thread-safe and must be called from the
        recorder thread.
        """
        results: dict[str, int | None] = {}
        missing_hashes: set[int] = set()
        for shared_attrs, data_hash in shared_attrs_data_hashes:
            if (attributes_id := self._id_map.get(shared_attrs)) is None:
                missing_hashes.add(data_hash)

            results[shared_attrs] = attributes_id

        if not missing_hashes:
            return results

        # The DB lookup may resolve some misses; merged dict keeps None
        # for any shared_attrs still unknown (to be inserted later).
        return results | self._load_from_hashes(missing_hashes, session)

    def _load_from_hashes(
        self, hashes: Collection[int], session: Session
    ) -> dict[str, int | None]:
        """Load the shared_attrs to attributes_ids mapping into memory from a list of hashes.

        This call is not thread-safe and must be called from the
        recorder thread.
        """
        results: dict[str, int | None] = {}
        with session.no_autoflush:
            # Chunk the hashes so each statement stays under the
            # database's bind-parameter limit.
            for hashs_chunk in chunked_or_all(hashes, self.recorder.max_bind_vars):
                for attributes_id, shared_attrs in execute_stmt_lambda_element(
                    session, get_shared_attributes(hashs_chunk), orm_rows=False
                ):
                    results[shared_attrs] = self._id_map[shared_attrs] = cast(
                        int, attributes_id
                    )

        return results

    def add_pending(self, db_state_attributes: StateAttributes) -> None:
        """Add a pending StateAttributes that will be committed at the next interval.

        This call is not thread-safe and must be called from the
        recorder thread.
        """
        assert db_state_attributes.shared_attrs is not None
        shared_attrs: str = db_state_attributes.shared_attrs
        self._pending[shared_attrs] = db_state_attributes

    def post_commit_pending(self) -> None:
        """Call after commit to load the attributes_ids of the new StateAttributes into the LRU.

        This call is not thread-safe and must be called from the
        recorder thread.
        """
        for shared_attrs, db_state_attributes in self._pending.items():
            self._id_map[shared_attrs] = db_state_attributes.attributes_id
        self._pending.clear()

    def evict_purged(self, attributes_ids: set[int]) -> None:
        """Evict purged attributes_ids from the cache when they are no longer used.

        This call is not thread-safe and must be called from the
        recorder thread.
        """
        id_map = self._id_map
        # Invert the cache so purged ids can be mapped back to keys.
        state_attributes_ids_reversed = {
            attributes_id: shared_attrs
            for shared_attrs, attributes_id in id_map.items()
        }
        # Evict any purged data from the cache
        for purged_attributes_id in attributes_ids.intersection(
            state_attributes_ids_reversed
        ):
            id_map.pop(state_attributes_ids_reversed[purged_attributes_id], None)
# Cross-references (from the reference index):
#   get_shared_attributes(hashes: list[int]) -> StatementLambdaElement — queries.py:38
#   execute_stmt_lambda_element(session, stmt, start_time=None, end_time=None,
#       yield_per=DEFAULT_YIELD_STATES_ROWS, orm_rows=True) -> Sequence[Row] | Result — util.py:179
#   chunked_or_all(iterable: Collection[Any], chunked_num: int) -> Iterable[Any] — collection.py:25