Home Assistant Unofficial Reference 2024.12.1
auth_store.py
Go to the documentation of this file.
1 """Storage for auth models."""
2 
3 from __future__ import annotations
4 
5 from datetime import timedelta
6 import hmac
7 import itertools
8 from logging import getLogger
9 from typing import Any
10 
11 from homeassistant.core import HomeAssistant, callback
12 from homeassistant.helpers import device_registry as dr, entity_registry as er
13 from homeassistant.helpers.storage import Store
14 from homeassistant.util import dt as dt_util
15 
16 from . import models
17 from .const import (
18  ACCESS_TOKEN_EXPIRATION,
19  GROUP_ID_ADMIN,
20  GROUP_ID_READ_ONLY,
21  GROUP_ID_USER,
22  REFRESH_TOKEN_EXPIRATION,
23 )
24 from .permissions import system_policies
25 from .permissions.models import PermissionLookup
26 from .permissions.types import PolicyType
27 
# Storage key and version handed to Store for the auth data file.
STORAGE_KEY = "auth"
STORAGE_VERSION = 1

# Names assigned to the built-in groups whenever they are created or reloaded.
GROUP_NAME_ADMIN = "Administrators"
GROUP_NAME_USER = "Users"
GROUP_NAME_READ_ONLY = "Read Only"

# Delay (in seconds) of the save scheduled right after loading. Loading may
# migrate the stored data, so the store is always written back afterwards, but
# not during startup: 300 seconds puts that write 5 minutes out. If something
# else schedules a save sooner, Storage honors the lower of the two delays and
# the data is written earlier.
INITIAL_LOAD_SAVE_DELAY = 300

# Default delay (in seconds) for every other scheduled save of the auth store.
DEFAULT_SAVE_DELAY = 1
44 
45 
class AuthStore:
    """Stores authentication info.

    Any mutation to an object should happen inside the auth store.

    Nothing is read from disk on construction: ``async_load`` must be awaited
    once before the other methods are used, because the user, group and
    permission-lookup attributes start out as ``None``.
    """

    def __init__(self, hass: HomeAssistant) -> None:
        """Initialize the auth store."""
        self.hass = hass
        self._loaded = False
        # Populated by async_load() (or by _set_defaults() when nothing is stored).
        self._users: dict[str, models.User] = None  # type: ignore[assignment]
        self._groups: dict[str, models.Group] = None  # type: ignore[assignment]
        self._perm_lookup: PermissionLookup = None  # type: ignore[assignment]
        self._store = Store[dict[str, list[dict[str, Any]]]](
            hass, STORAGE_VERSION, STORAGE_KEY, private=True, atomic_writes=True
        )
        # Reverse index: refresh token id -> id of the user owning that token.
        self._token_id_to_user_id: dict[str, str] = {}

    async def async_get_groups(self) -> list[models.Group]:
        """Retrieve all groups."""
        return list(self._groups.values())

    async def async_get_group(self, group_id: str) -> models.Group | None:
        """Retrieve a group by id, or None if it does not exist."""
        return self._groups.get(group_id)

    async def async_get_users(self) -> list[models.User]:
        """Retrieve all users."""
        return list(self._users.values())

    async def async_get_user(self, user_id: str) -> models.User | None:
        """Retrieve a user by id, or None if it does not exist."""
        return self._users.get(user_id)

    async def async_create_user(
        self,
        name: str | None,
        is_owner: bool | None = None,
        is_active: bool | None = None,
        system_generated: bool | None = None,
        credentials: models.Credentials | None = None,
        group_ids: list[str] | None = None,
        local_only: bool | None = None,
    ) -> models.User:
        """Create a new user.

        Optional flags left as None are not passed to ``models.User``, so the
        model's own defaults apply. When ``credentials`` is given they are
        linked to the new user.

        Raises ValueError if any id in ``group_ids`` is not a known group.
        """
        groups = []
        for group_id in group_ids or []:
            if (group := self._groups.get(group_id)) is None:
                raise ValueError(f"Invalid group specified {group_id}")
            groups.append(group)

        kwargs: dict[str, Any] = {
            "name": name,
            "groups": groups,
            "perm_lookup": self._perm_lookup,
        }

        # Only forward flags the caller actually set.
        kwargs.update(
            {
                attr_name: value
                for attr_name, value in (
                    ("is_owner", is_owner),
                    ("is_active", is_active),
                    ("local_only", local_only),
                    ("system_generated", system_generated),
                )
                if value is not None
            }
        )

        new_user = models.User(**kwargs)

        self._users[new_user.id] = new_user

        if credentials is None:
            self._async_schedule_save()
            return new_user

        # Saving is done inside the link.
        await self.async_link_user(new_user, credentials)
        return new_user

    async def async_link_user(
        self, user: models.User, credentials: models.Credentials
    ) -> None:
        """Add credentials to an existing user and schedule a save."""
        user.credentials.append(credentials)
        self._async_schedule_save()
        credentials.is_new = False

    async def async_remove_user(self, user: models.User) -> None:
        """Remove a user together with all of its refresh tokens.

        Raises KeyError if the user is not in the store.
        """
        user = self._users.pop(user.id)
        for refresh_token_id in user.refresh_tokens:
            del self._token_id_to_user_id[refresh_token_id]
        user.refresh_tokens.clear()
        self._async_schedule_save()

    async def async_update_user(
        self,
        user: models.User,
        name: str | None = None,
        is_active: bool | None = None,
        group_ids: list[str] | None = None,
        local_only: bool | None = None,
    ) -> None:
        """Update a user.

        Only arguments that are not None are applied. ``group_ids`` replaces the
        user's groups. All ids are validated before anything is changed, and
        ValueError is raised if one is unknown.
        """
        if group_ids is not None:
            groups = []
            for grid in group_ids:
                if (group := self._groups.get(grid)) is None:
                    raise ValueError("Invalid group specified.")
                groups.append(group)

            user.groups = groups

        for attr_name, value in (
            ("name", name),
            ("is_active", is_active),
            ("local_only", local_only),
        ):
            if value is not None:
                setattr(user, attr_name, value)

        self._async_schedule_save()

    async def async_activate_user(self, user: models.User) -> None:
        """Activate a user."""
        user.is_active = True
        self._async_schedule_save()

    async def async_deactivate_user(self, user: models.User) -> None:
        """Deactivate a user."""
        user.is_active = False
        self._async_schedule_save()

    async def async_remove_credentials(self, credentials: models.Credentials) -> None:
        """Remove credentials from whichever user holds them.

        Matching is by identity (``is``), not equality. A save is scheduled
        even when no user holds the credentials.
        """
        for user in self._users.values():
            found = next(
                (
                    index
                    for index, cred in enumerate(user.credentials)
                    if cred is credentials
                ),
                None,
            )
            if found is not None:
                user.credentials.pop(found)
                break

        self._async_schedule_save()

    async def async_create_refresh_token(
        self,
        user: models.User,
        client_id: str | None = None,
        client_name: str | None = None,
        client_icon: str | None = None,
        token_type: str = models.TOKEN_TYPE_NORMAL,
        access_token_expiration: timedelta = ACCESS_TOKEN_EXPIRATION,
        expire_at: float | None = None,
        credential: models.Credentials | None = None,
    ) -> models.RefreshToken:
        """Create a new token for a user.

        ``client_name`` and ``client_icon`` are only forwarded when truthy, so
        the model defaults apply otherwise. The token is attached to the user,
        indexed by id, and a save is scheduled.
        """
        kwargs: dict[str, Any] = {
            "user": user,
            "client_id": client_id,
            "token_type": token_type,
            "access_token_expiration": access_token_expiration,
            "expire_at": expire_at,
            "credential": credential,
        }
        if client_name:
            kwargs["client_name"] = client_name
        if client_icon:
            kwargs["client_icon"] = client_icon

        refresh_token = models.RefreshToken(**kwargs)
        token_id = refresh_token.id
        user.refresh_tokens[token_id] = refresh_token
        self._token_id_to_user_id[token_id] = user.id

        self._async_schedule_save()
        return refresh_token

    @callback
    def async_remove_refresh_token(self, refresh_token: models.RefreshToken) -> None:
        """Remove a refresh token; unknown tokens are ignored."""
        refresh_token_id = refresh_token.id
        if user_id := self._token_id_to_user_id.get(refresh_token_id):
            del self._users[user_id].refresh_tokens[refresh_token_id]
            del self._token_id_to_user_id[refresh_token_id]
            self._async_schedule_save()

    @callback
    def async_get_refresh_token(self, token_id: str) -> models.RefreshToken | None:
        """Get refresh token by id, or None if it is unknown."""
        if user_id := self._token_id_to_user_id.get(token_id):
            return self._users[user_id].refresh_tokens.get(token_id)
        return None

    @callback
    def async_get_refresh_token_by_token(
        self, token: str
    ) -> models.RefreshToken | None:
        """Get refresh token by its token value, or None if none matches."""
        found = None

        # No early exit on a match. Every stored token is compared with
        # hmac.compare_digest, so the work done does not depend on whether, or
        # where, the token is found. This avoids a timing side channel.
        for user in self._users.values():
            for refresh_token in user.refresh_tokens.values():
                if hmac.compare_digest(refresh_token.token, token):
                    found = refresh_token

        return found

    @callback
    def async_get_refresh_tokens(self) -> list[models.RefreshToken]:
        """Get all refresh tokens of all users."""
        return list(
            itertools.chain.from_iterable(
                user.refresh_tokens.values() for user in self._users.values()
            )
        )

    @callback
    def async_log_refresh_token_usage(
        self, refresh_token: models.RefreshToken, remote_ip: str | None = None
    ) -> None:
        """Update refresh token last used information.

        If the token has an expiry, that expiry is pushed to
        REFRESH_TOKEN_EXPIRATION past the new last-used time.
        """
        refresh_token.last_used_at = dt_util.utcnow()
        refresh_token.last_used_ip = remote_ip
        if refresh_token.expire_at:
            refresh_token.expire_at = (
                refresh_token.last_used_at.timestamp() + REFRESH_TOKEN_EXPIRATION
            )
        self._async_schedule_save()

    @callback
    def async_set_expiry(
        self, refresh_token: models.RefreshToken, *, enable_expiry: bool
    ) -> None:
        """Enable or disable expiry of a refresh token.

        Enabling only sets an expiry when none exists yet. It is based on the
        last-used time, or on now if the token was never used.
        """
        if enable_expiry:
            if refresh_token.expire_at is None:
                refresh_token.expire_at = (
                    refresh_token.last_used_at or dt_util.utcnow()
                ).timestamp() + REFRESH_TOKEN_EXPIRATION
                self._async_schedule_save()
        else:
            refresh_token.expire_at = None
            self._async_schedule_save()

    @callback
    def async_update_user_credentials_data(
        self, credentials: models.Credentials, data: dict[str, Any]
    ) -> None:
        """Replace the data stored on credentials and schedule a save."""
        credentials.data = data
        self._async_schedule_save()

    async def async_load(self) -> None:  # noqa: C901
        """Load users, groups, credentials and refresh tokens from storage.

        Stored data is soft-migrated while loading (see the inline comments).
        A save is then scheduled INITIAL_LOAD_SAVE_DELAY seconds out so the
        migrated form gets persisted.

        Raises RuntimeError if called more than once.
        """
        if self._loaded:
            raise RuntimeError("Auth storage is already loaded")
        self._loaded = True

        dev_reg = dr.async_get(self.hass)
        ent_reg = er.async_get(self.hass)
        data = await self._store.async_load()

        perm_lookup = PermissionLookup(ent_reg, dev_reg)
        self._perm_lookup = perm_lookup

        # Nothing usable stored (None covered too): start from the system groups.
        if not isinstance(data, dict):
            self._set_defaults()
            return

        users: dict[str, models.User] = {}
        groups: dict[str, models.Group] = {}
        credentials: dict[str, models.Credentials] = {}

        # Soft-migrating data as we load. We are going to make sure we have a
        # read only group and an admin group. There are two states that we can
        # migrate from:
        # 1. Data from a recent version which has a single group without policy
        # 2. Data from old version which has no groups
        has_admin_group = False
        has_user_group = False
        has_read_only_group = False
        group_without_policy = None

        # When creating objects we mention each attribute explicitly. This
        # prevents crashing if user rolls back HA version after a new property
        # was added.

        for group_dict in data.get("groups", []):
            policy: PolicyType | None = None

            if group_dict["id"] == GROUP_ID_ADMIN:
                has_admin_group = True

                name = GROUP_NAME_ADMIN
                policy = system_policies.ADMIN_POLICY
                system_generated = True

            elif group_dict["id"] == GROUP_ID_USER:
                has_user_group = True

                name = GROUP_NAME_USER
                policy = system_policies.USER_POLICY
                system_generated = True

            elif group_dict["id"] == GROUP_ID_READ_ONLY:
                has_read_only_group = True

                name = GROUP_NAME_READ_ONLY
                policy = system_policies.READ_ONLY_POLICY
                system_generated = True

            else:
                name = group_dict["name"]
                policy = group_dict.get("policy")
                system_generated = False

                # We don't want groups without a policy that are not system groups
                # This is part of migrating from state 1
                if policy is None:
                    group_without_policy = group_dict["id"]
                    continue

            groups[group_dict["id"]] = models.Group(
                id=group_dict["id"],
                name=name,
                policy=policy,
                system_generated=system_generated,
            )

        # If there are no groups, add all existing users to the admin group.
        # This is part of migrating from state 2
        migrate_users_to_admin_group = not groups and group_without_policy is None

        # If we find a no_policy_group, we need to migrate all users to the
        # admin group. We only do this if there are no other groups, as is
        # the expected state. If not expected state, not marking people admin.
        # This is part of migrating from state 1
        if groups and group_without_policy is not None:
            group_without_policy = None

        # This is part of migrating from state 1 and 2
        if not has_admin_group:
            admin_group = _system_admin_group()
            groups[admin_group.id] = admin_group

        # This is part of migrating from state 1 and 2
        if not has_read_only_group:
            read_only_group = _system_read_only_group()
            groups[read_only_group.id] = read_only_group

        if not has_user_group:
            user_group = _system_user_group()
            groups[user_group.id] = user_group

        # Missing sections are treated as empty, like "groups" above, rather
        # than failing the whole load with a KeyError.
        for user_dict in data.get("users", []):
            # Collect the users group.
            user_groups = []
            for group_id in user_dict.get("group_ids", []):
                # This is part of migrating from state 1
                if group_id == group_without_policy:
                    group_id = GROUP_ID_ADMIN
                user_groups.append(groups[group_id])

            # This is part of migrating from state 2
            if not user_dict["system_generated"] and migrate_users_to_admin_group:
                user_groups.append(groups[GROUP_ID_ADMIN])

            users[user_dict["id"]] = models.User(
                name=user_dict["name"],
                groups=user_groups,
                id=user_dict["id"],
                is_owner=user_dict["is_owner"],
                is_active=user_dict["is_active"],
                system_generated=user_dict["system_generated"],
                perm_lookup=perm_lookup,
                # New in 2021.11
                local_only=user_dict.get("local_only", False),
            )

        for cred_dict in data.get("credentials", []):
            credential = models.Credentials(
                id=cred_dict["id"],
                is_new=False,
                auth_provider_type=cred_dict["auth_provider_type"],
                auth_provider_id=cred_dict["auth_provider_id"],
                data=cred_dict["data"],
            )
            credentials[cred_dict["id"]] = credential
            users[cred_dict["user_id"]].credentials.append(credential)

        for rt_dict in data.get("refresh_tokens", []):
            # Filter out the old keys that don't have jwt_key (pre-0.76)
            if "jwt_key" not in rt_dict:
                continue

            created_at = dt_util.parse_datetime(rt_dict["created_at"])
            if created_at is None:
                getLogger(__name__).error(
                    (
                        "Ignoring refresh token %(id)s with invalid created_at "
                        "%(created_at)s for user_id %(user_id)s"
                    ),
                    rt_dict,
                )
                continue

            if (token_type := rt_dict.get("token_type")) is None:
                if rt_dict["client_id"] is None:
                    token_type = models.TOKEN_TYPE_SYSTEM
                else:
                    token_type = models.TOKEN_TYPE_NORMAL

            # old refresh_token don't have last_used_at (pre-0.78)
            if last_used_at_str := rt_dict.get("last_used_at"):
                last_used_at = dt_util.parse_datetime(last_used_at_str)
            else:
                last_used_at = None

            token = models.RefreshToken(
                id=rt_dict["id"],
                user=users[rt_dict["user_id"]],
                client_id=rt_dict["client_id"],
                # use dict.get to keep backward compatibility
                client_name=rt_dict.get("client_name"),
                client_icon=rt_dict.get("client_icon"),
                token_type=token_type,
                created_at=created_at,
                access_token_expiration=timedelta(
                    seconds=rt_dict["access_token_expiration"]
                ),
                token=rt_dict["token"],
                jwt_key=rt_dict["jwt_key"],
                last_used_at=last_used_at,
                last_used_ip=rt_dict.get("last_used_ip"),
                expire_at=rt_dict.get("expire_at"),
                version=rt_dict.get("version"),
            )
            if "credential_id" in rt_dict:
                token.credential = credentials.get(rt_dict["credential_id"])
            users[rt_dict["user_id"]].refresh_tokens[token.id] = token

        self._groups = groups
        self._users = users
        self._build_token_id_to_user_id()
        self._async_schedule_save(INITIAL_LOAD_SAVE_DELAY)

    @callback
    def _build_token_id_to_user_id(self) -> None:
        """Rebuild the refresh-token-id -> user-id index from self._users."""
        self._token_id_to_user_id = {
            token_id: user_id
            for user_id, user in self._users.items()
            for token_id in user.refresh_tokens
        }

    @callback
    def _async_schedule_save(self, delay: float = DEFAULT_SAVE_DELAY) -> None:
        """Schedule a delayed save of the whole auth store."""
        self._store.async_delay_save(self._data_to_save, delay)

    @callback
    def _data_to_save(self) -> dict[str, list[dict[str, Any]]]:
        """Return the data to store."""
        users = [
            {
                "id": user.id,
                "group_ids": [group.id for group in user.groups],
                "is_owner": user.is_owner,
                "is_active": user.is_active,
                "name": user.name,
                "system_generated": user.system_generated,
                "local_only": user.local_only,
            }
            for user in self._users.values()
        ]

        groups = []
        for group in self._groups.values():
            g_dict: dict[str, Any] = {
                "id": group.id,
                # Name not read for sys groups. Kept here for backwards compat
                "name": group.name,
            }

            # System group policies are rebuilt from code on load, not stored.
            if not group.system_generated:
                g_dict["policy"] = group.policy

            groups.append(g_dict)

        credentials = [
            {
                "id": credential.id,
                "user_id": user.id,
                "auth_provider_type": credential.auth_provider_type,
                "auth_provider_id": credential.auth_provider_id,
                "data": credential.data,
            }
            for user in self._users.values()
            for credential in user.credentials
        ]

        refresh_tokens = [
            {
                "id": refresh_token.id,
                "user_id": user.id,
                "client_id": refresh_token.client_id,
                "client_name": refresh_token.client_name,
                "client_icon": refresh_token.client_icon,
                "token_type": refresh_token.token_type,
                "created_at": refresh_token.created_at.isoformat(),
                "access_token_expiration": (
                    refresh_token.access_token_expiration.total_seconds()
                ),
                "token": refresh_token.token,
                "jwt_key": refresh_token.jwt_key,
                "last_used_at": refresh_token.last_used_at.isoformat()
                if refresh_token.last_used_at
                else None,
                "last_used_ip": refresh_token.last_used_ip,
                "expire_at": refresh_token.expire_at,
                "credential_id": refresh_token.credential.id
                if refresh_token.credential
                else None,
                "version": refresh_token.version,
            }
            for user in self._users.values()
            for refresh_token in user.refresh_tokens.values()
        ]

        return {
            "users": users,
            "groups": groups,
            "credentials": credentials,
            "refresh_tokens": refresh_tokens,
        }

    def _set_defaults(self) -> None:
        """Set default values: no users and only the three system groups."""
        self._users = {}

        groups: dict[str, models.Group] = {}
        admin_group = _system_admin_group()
        groups[admin_group.id] = admin_group
        user_group = _system_user_group()
        groups[user_group.id] = user_group
        read_only_group = _system_read_only_group()
        groups[read_only_group.id] = read_only_group
        self._groups = groups
        self._build_token_id_to_user_id()
608 
609 
def _system_admin_group() -> models.Group:
    """Build the built-in, system-generated administrators group."""
    admin_group = models.Group(
        id=GROUP_ID_ADMIN,
        name=GROUP_NAME_ADMIN,
        system_generated=True,
        policy=system_policies.ADMIN_POLICY,
    )
    return admin_group
618 
619 
def _system_user_group() -> models.Group:
    """Build the built-in, system-generated regular-users group."""
    user_group = models.Group(
        id=GROUP_ID_USER,
        name=GROUP_NAME_USER,
        system_generated=True,
        policy=system_policies.USER_POLICY,
    )
    return user_group
628 
629 
def _system_read_only_group() -> models.Group:
    """Build the built-in, system-generated read-only group."""
    read_only_group = models.Group(
        id=GROUP_ID_READ_ONLY,
        name=GROUP_NAME_READ_ONLY,
        system_generated=True,
        policy=system_policies.READ_ONLY_POLICY,
    )
    return read_only_group
models.Group|None async_get_group(self, str group_id)
Definition: auth_store.py:71
models.RefreshToken|None async_get_refresh_token(self, str token_id)
Definition: auth_store.py:246
None async_remove_credentials(self, models.Credentials credentials)
Definition: auth_store.py:187
models.User async_create_user(self, str|None name, bool|None is_owner=None, bool|None is_active=None, bool|None system_generated=None, models.Credentials|None credentials=None, list[str]|None group_ids=None, bool|None local_only=None)
Definition: auth_store.py:92
list[models.User] async_get_users(self)
Definition: auth_store.py:75
None async_remove_user(self, models.User user)
Definition: auth_store.py:141
None async_update_user_credentials_data(self, models.Credentials credentials, dict[str, Any] data)
Definition: auth_store.py:306
models.User|None async_get_user(self, str user_id)
Definition: auth_store.py:79
None async_activate_user(self, models.User user)
Definition: auth_store.py:177
None async_log_refresh_token_usage(self, models.RefreshToken refresh_token, str|None remote_ip=None)
Definition: auth_store.py:278
list[models.Group] async_get_groups(self)
Definition: auth_store.py:67
models.RefreshToken|None async_get_refresh_token_by_token(self, str token)
Definition: auth_store.py:255
None _async_schedule_save(self, float delay=DEFAULT_SAVE_DELAY)
Definition: auth_store.py:515
list[models.RefreshToken] async_get_refresh_tokens(self)
Definition: auth_store.py:267
None async_set_expiry(self, models.RefreshToken refresh_token, *bool enable_expiry)
Definition: auth_store.py:291
None async_remove_refresh_token(self, models.RefreshToken refresh_token)
Definition: auth_store.py:237
None async_deactivate_user(self, models.User user)
Definition: auth_store.py:182
None async_update_user(self, models.User user, str|None name=None, bool|None is_active=None, list[str]|None group_ids=None, bool|None local_only=None)
Definition: auth_store.py:156
None async_link_user(self, models.User user, models.Credentials credentials)
Definition: auth_store.py:135
dict[str, list[dict[str, Any]]] _data_to_save(self)
Definition: auth_store.py:520
models.RefreshToken async_create_refresh_token(self, models.User user, str|None client_id=None, str|None client_name=None, str|None client_icon=None, str token_type=models.TOKEN_TYPE_NORMAL, timedelta access_token_expiration=ACCESS_TOKEN_EXPIRATION, float|None expire_at=None, models.Credentials|None credential=None)
Definition: auth_store.py:213
None __init__(self, HomeAssistant hass)
Definition: auth_store.py:55
models.Group _system_admin_group()
Definition: auth_store.py:610
models.Group _system_read_only_group()
Definition: auth_store.py:630
models.Group _system_user_group()
Definition: auth_store.py:620
web.Response get(self, web.Request request, str config_key)
Definition: view.py:88
None async_delay_save(self, Callable[[], _T] data_func, float delay=0)
Definition: storage.py:444