Home Assistant Unofficial Reference 2024.12.1
connection.py
Go to the documentation of this file.
1 """Connection session."""
2 
3 from __future__ import annotations
4 
5 from collections.abc import Callable, Hashable
6 from contextvars import ContextVar
7 from typing import TYPE_CHECKING, Any, Literal
8 
9 from aiohttp import web
10 import voluptuous as vol
11 
12 from homeassistant.auth.models import RefreshToken, User
13 from homeassistant.core import Context, HomeAssistant, callback
14 from homeassistant.exceptions import HomeAssistantError, Unauthorized
15 from homeassistant.helpers.http import current_request
16 from homeassistant.util.json import JsonValueType
17 
18 from . import const, messages
19 from .messages import (
20  error_message,
21  event_message,
22  message_to_json_bytes,
23  result_message,
24 )
25 from .util import describe_request
26 
27 if TYPE_CHECKING:
28  from .http import WebSocketAdapter
29 
30 
31 current_connection = ContextVar["ActiveConnection | None"](
32  "current_connection", default=None
33 )
34 
35 type MessageHandler = Callable[[HomeAssistant, ActiveConnection, dict[str, Any]], None]
36 type BinaryHandler = Callable[[HomeAssistant, ActiveConnection, bytes], None]
37 
38 
class ActiveConnection:
    """Handle an active websocket client connection.

    Dispatches incoming JSON commands and binary frames to registered
    handlers, tracks per-connection subscriptions, and sends result, event
    and error messages back to the client through ``send_message``.
    """

    __slots__ = (
        "logger",
        "hass",
        "send_message",
        "user",
        "refresh_token_id",
        "subscriptions",
        "last_id",
        "can_coalesce",
        "supported_features",
        "handlers",
        "binary_handlers",
    )

    def __init__(
        self,
        logger: WebSocketAdapter,
        hass: HomeAssistant,
        send_message: Callable[[bytes | str | dict[str, Any]], None],
        user: User,
        refresh_token: RefreshToken,
    ) -> None:
        """Initialize an active connection.

        Args:
            logger: Logger adapter used for all connection-level logging.
            hass: The Home Assistant instance.
            send_message: Callable that delivers a message to the client.
            user: The user this connection is authenticated as.
            refresh_token: Token the connection authenticated with; only its
                id is retained.
        """
        self.logger = logger
        self.hass = hass
        self.send_message = send_message
        self.user = user
        self.refresh_token_id = refresh_token.id
        # Cleanup callbacks for active subscriptions, keyed by a hashable id;
        # every one of them is invoked when the connection closes.
        self.subscriptions: dict[Hashable, Callable[[], Any]] = {}
        # Highest command id accepted so far; new ids must be strictly larger.
        self.last_id = 0
        self.can_coalesce = False
        self.supported_features: dict[str, float] = {}
        # Command registry stored in hass.data: type -> (handler, schema).
        # A schema of False means the message may carry no keys beyond
        # "id" and "type".
        self.handlers: dict[str, tuple[MessageHandler, vol.Schema | Literal[False]]] = (
            self.hass.data[const.DOMAIN]
        )
        # Slot i holds the handler for binary handler id i + 1; None marks a
        # free slot that may be reused.
        self.binary_handlers: list[BinaryHandler | None] = []
        # Make this connection visible to code running in the current context.
        current_connection.set(self)

    def __repr__(self) -> str:
        """Return the representation."""
        return f"<ActiveConnection {self.get_description(None)}>"

    def set_supported_features(self, features: dict[str, float]) -> None:
        """Set supported features.

        Message coalescing is enabled only when ``features`` contains
        ``const.FEATURE_COALESCE_MESSAGES``.
        """
        self.supported_features = features
        self.can_coalesce = const.FEATURE_COALESCE_MESSAGES in features

    def get_description(self, request: web.Request | None) -> str:
        """Return a description of the connection.

        The description is the user's name (empty if unset), followed by a
        description of ``request`` when one is given.
        """
        description = self.user.name or ""
        if request:
            description += " " + describe_request(request)
        return description

    def context(self, msg: dict[str, Any]) -> Context:
        """Return a context carrying this connection's user id.

        ``msg`` is accepted for interface compatibility but not used.
        """
        return Context(user_id=self.user.id)

    @callback
    def async_register_binary_handler(
        self, handler: BinaryHandler
    ) -> tuple[int, Callable[[], None]]:
        """Register a temporary binary handler for this connection.

        Returns a binary handler_id (1 byte) and a callback to unregister the handler.

        Raises:
            RuntimeError: If all 255 handler slots are occupied.
        """
        if len(self.binary_handlers) < 255:
            index = len(self.binary_handlers)
            self.binary_handlers.append(None)
        else:
            # Once the list is full, we search for a None entry to reuse.
            index = None
            for idx, existing in enumerate(self.binary_handlers):
                if existing is None:
                    index = idx
                    break

        if index is None:
            raise RuntimeError("Too many binary handlers registered")

        self.binary_handlers[index] = handler

        @callback
        def unsub() -> None:
            """Unregister the handler."""
            assert index is not None
            # Slots are reused once the list is full, so a stale or repeated
            # unsub (e.g. after async_handle_binary freed the slot on error)
            # must not remove a handler registered later in the same slot.
            if self.binary_handlers[index] is handler:
                self.binary_handlers[index] = None

        # Handler ids are 1-based: the id is the slot index plus one.
        return index + 1, unsub

    @callback
    def send_result(self, msg_id: int, result: Any | None = None) -> None:
        """Send a result message."""
        self.send_message(message_to_json_bytes(result_message(msg_id, result)))

    @callback
    def send_event(self, msg_id: int, event: Any | None = None) -> None:
        """Send an event message."""
        self.send_message(message_to_json_bytes(event_message(msg_id, event)))

    @callback
    def send_error(
        self,
        msg_id: int,
        code: str,
        message: str,
        translation_key: str | None = None,
        translation_domain: str | None = None,
        translation_placeholders: dict[str, Any] | None = None,
    ) -> None:
        """Send an error message, optionally with translation information."""
        self.send_message(
            message_to_json_bytes(
                error_message(
                    msg_id,
                    code,
                    message,
                    translation_key=translation_key,
                    translation_domain=translation_domain,
                    translation_placeholders=translation_placeholders,
                )
            )
        )

    @callback
    def async_handle_binary(self, handler_id: int, payload: bytes) -> None:
        """Handle a single incoming binary message.

        ``handler_id`` is the 1-based id returned by
        ``async_register_binary_handler``. Unknown ids are logged and ignored.
        A handler that raises is logged and then unregistered.
        """
        index = handler_id - 1
        if (
            index < 0
            or index >= len(self.binary_handlers)
            or (handler := self.binary_handlers[index]) is None
        ):
            self.logger.error(
                "Received binary message for non-existing handler %s", handler_id
            )
            return

        try:
            handler(self.hass, self, payload)
        except Exception:
            self.logger.exception("Error handling binary message")
            self.binary_handlers[index] = None

    @callback
    def async_handle(self, msg: JsonValueType) -> None:
        """Handle a single incoming message.

        The message must be a dict with a positive integer "id" and a
        non-empty string "type". The id must be greater than the last
        accepted id, and the type must be a registered command. Otherwise an
        error message is sent back and nothing else happens.
        """
        if (
            # Not using isinstance as we don't care about children
            # as these are always coming from JSON
            type(msg) is not dict  # noqa: E721
            or (
                # id 0 is falsy and therefore rejected along with missing ids.
                not (cur_id := msg.get("id"))
                or type(cur_id) is not int  # noqa: E721
                or cur_id < 0
                or not (type_ := msg.get("type"))
                or type(type_) is not str  # noqa: E721
            )
        ):
            self.logger.error("Received invalid command: %s", msg)
            id_ = msg.get("id") if isinstance(msg, dict) else 0
            self.send_message(
                messages.error_message(
                    id_,  # type: ignore[arg-type]
                    const.ERR_INVALID_FORMAT,
                    "Message incorrectly formatted.",
                )
            )
            return

        if cur_id <= self.last_id:
            self.send_message(
                messages.error_message(
                    cur_id, const.ERR_ID_REUSE, "Identifier values have to increase."
                )
            )
            return

        if not (handler_schema := self.handlers.get(type_)):
            self.logger.info("Received unknown command: %s", type_)
            self.send_message(
                messages.error_message(
                    cur_id, const.ERR_UNKNOWN_COMMAND, "Unknown command."
                )
            )
            return

        handler, schema = handler_schema

        try:
            if schema is False:
                # No schema: only "id" and "type" are allowed.
                if len(msg) > 2:
                    raise vol.Invalid("extra keys not allowed")  # noqa: TRY301
                handler(self.hass, self, msg)
            else:
                handler(self.hass, self, schema(msg))
        except Exception as err:  # noqa: BLE001
            self.async_handle_exception(msg, err)

        # The id is consumed even if the handler raised.
        self.last_id = cur_id

    @callback
    def async_handle_close(self) -> None:
        """Handle closing down connection.

        Calls every subscription cleanup callback, clears the subscriptions,
        redirects further sends to a debug-logging no-op, and resets the
        current request/connection context variables.
        """
        # Iterate over a snapshot: if an unsubscribe callback mutates
        # self.subscriptions, iterating the live view would raise a
        # RuntimeError outside the try below and abort the cleanup.
        for unsub in list(self.subscriptions.values()):
            try:
                unsub()
            except Exception:
                # If one fails, make sure we still try the rest
                self.logger.exception(
                    "Error unsubscribing from subscription: %s", unsub
                )
        self.subscriptions.clear()
        self.send_message = self._connect_closed_error
        current_request.set(None)
        current_connection.set(None)

    @callback
    def _connect_closed_error(
        self, msg: bytes | str | dict[str, Any] | Callable[[], str]
    ) -> None:
        """Send a message when the connection is closed.

        Installed as ``send_message`` by ``async_handle_close``; the message
        is dropped and only logged at debug level.
        """
        self.logger.debug("Tried to send message %s on closed connection", msg)

    @callback
    def async_handle_exception(self, msg: dict[str, Any], err: Exception) -> None:
        """Handle an exception while processing a handler.

        Maps the exception to an error code and message, sends an error
        message for ``msg["id"]``, and logs the failure. A stack trace is
        logged when no specific error message could be derived.
        """
        log_handler = self.logger.error

        code = const.ERR_UNKNOWN_ERROR
        err_message: str | None = None
        translation_domain: str | None = None
        translation_key: str | None = None
        translation_placeholders: dict[str, Any] | None = None

        if isinstance(err, Unauthorized):
            code = const.ERR_UNAUTHORIZED
            err_message = "Unauthorized"
        elif isinstance(err, vol.Invalid):
            code = const.ERR_INVALID_FORMAT
            err_message = vol.humanize.humanize_error(msg, err)
        elif isinstance(err, TimeoutError):
            code = const.ERR_TIMEOUT
            err_message = "Timeout"
        elif isinstance(err, HomeAssistantError):
            err_message = str(err)
            code = const.ERR_HOME_ASSISTANT_ERROR
            translation_domain = err.translation_domain
            translation_key = err.translation_key
            translation_placeholders = err.translation_placeholders

        # This if-check matches all other errors but also matches errors which
        # result in an empty message. In that case we will also log the stack
        # trace so it can be fixed.
        if not err_message:
            err_message = "Unknown error"
            log_handler = self.logger.exception

        self.send_message(
            messages.error_message(
                msg["id"],
                code,
                err_message,
                translation_domain=translation_domain,
                translation_key=translation_key,
                translation_placeholders=translation_placeholders,
            )
        )

        if code:
            err_message += f" ({code})"
        err_message += " " + self.get_description(current_request.get())

        log_handler("Error handling message: %s", err_message)
None __init__(self, WebSocketAdapter logger, HomeAssistant hass, Callable[[bytes|str|dict[str, Any]], None] send_message, User user, RefreshToken refresh_token)
Definition: connection.py:63
None send_result(self, int msg_id, Any|None result=None)
Definition: connection.py:133
None set_supported_features(self, dict[str, float] features)
Definition: connection.py:84
None _connect_closed_error(self, bytes|str|dict[str, Any]|Callable[[], str] msg)
Definition: connection.py:262
None async_handle_exception(self, dict[str, Any] msg, Exception err)
Definition: connection.py:267
None send_error(self, int msg_id, str code, str message, str|None translation_key=None, str|None translation_domain=None, dict[str, Any]|None translation_placeholders=None)
Definition: connection.py:151
tuple[int, Callable[[], None]] async_register_binary_handler(self, BinaryHandler handler)
Definition: connection.py:103
None async_handle_binary(self, int handler_id, bytes payload)
Definition: connection.py:167
None send_event(self, int msg_id, Any|None event=None)
Definition: connection.py:138
web.Response get(self, web.Request request, str config_key)
Definition: view.py:88
dict[str, Any] result_message(int iden, Any result=None)
Definition: messages.py:63
dict[str, Any] error_message(int|None iden, str code, str message, str|None translation_key=None, str|None translation_domain=None, dict[str, Any]|None translation_placeholders=None)
Definition: messages.py:88
bytes message_to_json_bytes(dict[str, Any] message)
Definition: messages.py:255
dict[str, Any] event_message(int iden, Any event)
Definition: messages.py:107
str describe_request(web.Request request)
Definition: util.py:8