1 """Connection session."""
3 from __future__
import annotations
5 from collections.abc
import Callable, Hashable
6 from contextvars
import ContextVar
7 from typing
import TYPE_CHECKING, Any, Literal
9 from aiohttp
import web
10 import voluptuous
as vol
18 from .
import const, messages
19 from .messages
import (
22 message_to_json_bytes,
25 from .util
import describe_request
28 from .http
import WebSocketAdapter
31 current_connection = ContextVar[
"ActiveConnection | None"](
32 "current_connection", default=
None
35 type MessageHandler = Callable[[HomeAssistant, ActiveConnection, dict[str, Any]],
None]
36 type BinaryHandler = Callable[[HomeAssistant, ActiveConnection, bytes],
None]
40 """Handle an active websocket client connection."""
58 logger: WebSocketAdapter,
60 send_message: Callable[[bytes | str | dict[str, Any]],
None],
62 refresh_token: RefreshToken,
64 """Initialize an active connection."""
70 self.subscriptions: dict[Hashable, Callable[[], Any]] = {}
74 self.handlers: dict[str, tuple[MessageHandler, vol.Schema | Literal[
False]]] = (
75 self.
hasshass.data[const.DOMAIN]
77 self.binary_handlers: list[BinaryHandler |
None] = []
78 current_connection.set(self)
81 """Return the representation."""
82 return f
"<ActiveConnection {self.get_description(None)}>"
85 """Set supported features."""
87 self.
can_coalescecan_coalesce = const.FEATURE_COALESCE_MESSAGES
in features
90 """Return a description of the connection."""
91 description = self.
useruser.name
or ""
96 def context(self, msg: dict[str, Any]) -> Context:
97 """Return a context."""
102 self, handler: BinaryHandler
103 ) -> tuple[int, Callable[[],
None]]:
104 """Register a temporary binary handler for this connection.
106 Returns a binary handler_id (1 byte) and a callback to unregister the handler.
108 if len(self.binary_handlers) < 255:
109 index = len(self.binary_handlers)
110 self.binary_handlers.append(
None)
114 for idx, existing
in enumerate(self.binary_handlers):
120 raise RuntimeError(
"Too many binary handlers registered")
122 self.binary_handlers[index] = handler
126 """Unregister the handler."""
127 assert index
is not None
128 self.binary_handlers[index] =
None
130 return index + 1, unsub
133 def send_result(self, msg_id: int, result: Any |
None =
None) ->
None:
134 """Send a result message."""
138 def send_event(self, msg_id: int, event: Any |
None =
None) ->
None:
139 """Send a event message."""
148 translation_key: str |
None =
None,
149 translation_domain: str |
None =
None,
150 translation_placeholders: dict[str, Any] |
None =
None,
152 """Send an error message."""
159 translation_key=translation_key,
160 translation_domain=translation_domain,
161 translation_placeholders=translation_placeholders,
168 """Handle a single incoming binary message."""
169 index = handler_id - 1
172 or index >= len(self.binary_handlers)
173 or (handler := self.binary_handlers[index])
is None
176 "Received binary message for non-existing handler %s", handler_id
181 handler(self.
hasshass, self, payload)
183 self.
loggerlogger.exception(
"Error handling binary message")
184 self.binary_handlers[index] =
None
188 """Handle a single incoming message."""
192 type(msg)
is not dict
194 not (cur_id := msg.get(
"id"))
195 or type(cur_id)
is not int
197 or not (type_ := msg.get(
"type"))
198 or type(type_)
is not str
201 self.
loggerlogger.error(
"Received invalid command: %s", msg)
202 id_ = msg.get(
"id")
if isinstance(msg, dict)
else 0
204 messages.error_message(
206 const.ERR_INVALID_FORMAT,
207 "Message incorrectly formatted.",
212 if cur_id <= self.
last_idlast_id:
214 messages.error_message(
215 cur_id, const.ERR_ID_REUSE,
"Identifier values have to increase."
220 if not (handler_schema := self.handlers.
get(type_)):
221 self.
loggerlogger.info(
"Received unknown command: %s", type_)
223 messages.error_message(
224 cur_id, const.ERR_UNKNOWN_COMMAND,
"Unknown command."
229 handler, schema = handler_schema
234 raise vol.Invalid(
"extra keys not allowed")
235 handler(self.
hasshass, self, msg)
237 handler(self.
hasshass, self, schema(msg))
238 except Exception
as err:
245 """Handle closing down connection."""
246 for unsub
in self.subscriptions.values():
251 self.
loggerlogger.exception(
252 "Error unsubscribing from subscription: %s", unsub
254 self.subscriptions.clear()
256 current_request.set(
None)
257 current_connection.set(
None)
261 self, msg: bytes | str | dict[str, Any] | Callable[[], str]
263 """Send a message when the connection is closed."""
264 self.
loggerlogger.debug(
"Tried to send message %s on closed connection", msg)
268 """Handle an exception while processing a handler."""
269 log_handler = self.
loggerlogger.error
271 code = const.ERR_UNKNOWN_ERROR
272 err_message: str |
None =
None
273 translation_domain: str |
None =
None
274 translation_key: str |
None =
None
275 translation_placeholders: dict[str, Any] |
None =
None
277 if isinstance(err, Unauthorized):
278 code = const.ERR_UNAUTHORIZED
279 err_message =
"Unauthorized"
280 elif isinstance(err, vol.Invalid):
281 code = const.ERR_INVALID_FORMAT
282 err_message = vol.humanize.humanize_error(msg, err)
283 elif isinstance(err, TimeoutError):
284 code = const.ERR_TIMEOUT
285 err_message =
"Timeout"
286 elif isinstance(err, HomeAssistantError):
287 err_message =
str(err)
288 code = const.ERR_HOME_ASSISTANT_ERROR
289 translation_domain = err.translation_domain
290 translation_key = err.translation_key
291 translation_placeholders = err.translation_placeholders
297 err_message =
"Unknown error"
298 log_handler = self.
loggerlogger.exception
301 messages.error_message(
305 translation_domain=translation_domain,
306 translation_key=translation_key,
307 translation_placeholders=translation_placeholders,
312 err_message += f
" ({code})"
313 err_message +=
" " + self.
get_descriptionget_description(current_request.get())
315 log_handler(
"Error handling message: %s", err_message)
None __init__(self, WebSocketAdapter logger, HomeAssistant hass, Callable[[bytes|str|dict[str, Any]], None] send_message, User user, RefreshToken refresh_token)
None send_result(self, int msg_id, Any|None result=None)
None async_handle_close(self)
None set_supported_features(self, dict[str, float] features)
None _connect_closed_error(self, bytes|str|dict[str, Any]|Callable[[], str] msg)
Context context(self, dict[str, Any] msg)
None async_handle_exception(self, dict[str, Any] msg, Exception err)
None send_error(self, int msg_id, str code, str message, str|None translation_key=None, str|None translation_domain=None, dict[str, Any]|None translation_placeholders=None)
str get_description(self, web.Request|None request)
None async_handle(self, JsonValueType msg)
tuple[int, Callable[[], None]] async_register_binary_handler(self, BinaryHandler handler)
None async_handle_binary(self, int handler_id, bytes payload)
None send_event(self, int msg_id, Any|None event=None)
web.Response get(self, web.Request request, str config_key)
dict[str, Any] result_message(int iden, Any result=None)
dict[str, Any] error_message(int|None iden, str code, str message, str|None translation_key=None, str|None translation_domain=None, dict[str, Any]|None translation_placeholders=None)
bytes message_to_json_bytes(dict[str, Any] message)
dict[str, Any] event_message(int iden, Any event)
str describe_request(web.Request request)