1 """Provide an authentication layer for Home Assistant."""
3 from __future__
import annotations
6 from collections
import OrderedDict
7 from collections.abc
import Mapping
8 from datetime
import datetime, timedelta
9 from functools
import partial
11 from typing
import Any, cast
26 from .
import auth_store, jwt_wrapper, models
27 from .const
import ACCESS_TOKEN_EXPIRATION, GROUP_ID_ADMIN, REFRESH_TOKEN_EXPIRATION
28 from .mfa_modules
import MultiFactorAuthModule, auth_mfa_module_from_config
29 from .models
import AuthFlowContext, AuthFlowResult
30 from .providers
import AuthProvider, LoginFlow, auth_provider_from_config
31 from .providers.homeassistant
import HassAuthProvider
33 EVENT_USER_ADDED =
"user_added"
34 EVENT_USER_UPDATED =
"user_updated"
35 EVENT_USER_REMOVED =
"user_removed"
37 type _MfaModuleDict = dict[str, MultiFactorAuthModule]
38 type _ProviderKey = tuple[str, str |
None]
39 type _ProviderDict = dict[_ProviderKey, AuthProvider]
43 """Raised when a authentication error occurs."""
47 """Authentication provider not found."""
52 provider_configs: list[dict[str, Any]],
53 module_configs: list[dict[str, Any]],
55 """Initialize an auth manager from config.
57 CORE_CONFIG_SCHEMA will make sure no duplicated auth providers or
58 mfa modules exist in configs.
60 store = auth_store.AuthStore(hass)
61 await store.async_load()
63 providers = await asyncio.gather(
66 for config
in provider_configs
72 provider_hash: _ProviderDict = OrderedDict()
73 for provider
in providers:
74 key = (provider.type, provider.id)
75 provider_hash[key] = provider
77 if isinstance(provider, HassAuthProvider):
82 await provider.async_initialize()
85 modules = await asyncio.gather(
91 module_hash: _MfaModuleDict = OrderedDict()
92 for module
in modules:
93 module_hash[module.id] = module
95 manager = AuthManager(hass, store, provider_hash, module_hash)
96 await manager.async_setup()
101 FlowManager[AuthFlowContext, AuthFlowResult, tuple[str, str]]
103 """Manage authentication flows."""
105 _flow_result = AuthFlowResult
107 def __init__(self, hass: HomeAssistant, auth_manager: AuthManager) ->
None:
108 """Init auth manager flows."""
114 handler_key: tuple[str, str],
116 context: AuthFlowContext |
None =
None,
117 data: dict[str, Any] |
None =
None,
119 """Create a login flow."""
120 auth_provider = self.
auth_managerauth_manager.get_auth_provider(*handler_key)
121 if not auth_provider:
122 raise KeyError(f
"Unknown auth provider {handler_key}")
123 return await auth_provider.async_login_flow(context)
127 flow: FlowHandler[AuthFlowContext, AuthFlowResult, tuple[str, str]],
128 result: AuthFlowResult,
130 """Return a user as result of login flow.
132 This method is called when a flow step returns FlowResultType.ABORT or
133 FlowResultType.CREATE_ENTRY.
135 flow = cast(LoginFlow, flow)
137 if result[
"type"] != FlowResultType.CREATE_ENTRY:
142 result[
"result"] = result[
"data"]
145 auth_provider = self.
auth_managerauth_manager.get_auth_provider(*result[
"handler"])
146 if not auth_provider:
147 raise KeyError(f
"Unknown auth provider {result['handler']}")
149 credentials = await auth_provider.async_get_or_create_credentials(
150 cast(Mapping[str, str], result[
"data"]),
153 if flow.context.get(
"credential_only"):
154 result[
"result"] = credentials
159 if auth_provider.support_mfa
and not credentials.is_new:
160 user = await self.
auth_managerauth_manager.async_get_user_by_credentials(credentials)
162 modules = await self.
auth_managerauth_manager.async_get_enabled_mfa(user)
165 flow.credential = credentials
167 flow.available_mfa_modules = modules
168 return await flow.async_step_select_mfa_module()
170 result[
"result"] = credentials
175 """Manage the authentication for Home Assistant."""
181 providers: _ProviderDict,
182 mfa_modules: _MfaModuleDict,
184 """Initialize the auth manager."""
190 self._revoke_callbacks: dict[str, set[CALLBACK_TYPE]] = {}
197 """Set up the auth manager."""
199 hass.async_add_shutdown_job(
208 """Return a list of available auth providers."""
213 """Return a list of available auth modules."""
217 self, provider_type: str, provider_id: str |
None
218 ) -> AuthProvider |
None:
219 """Return an auth provider, None if not found."""
220 return self.
_providers_providers.
get((provider_type, provider_id))
223 """Return a List of auth provider of one type, Empty if not found."""
226 for (p_type, _), provider
in self.
_providers_providers.items()
227 if p_type == provider_type
231 """Return a multi-factor auth module, None if not found."""
235 """Retrieve all users."""
239 """Retrieve a user."""
243 """Retrieve the owner."""
245 return next((user
for user
in users
if user.is_owner),
None)
248 """Retrieve all groups."""
252 self, credentials: models.Credentials
254 """Get a user by credential, return None if not found."""
256 for creds
in user.credentials:
257 if creds.id == credentials.id:
266 group_ids: list[str] |
None =
None,
267 local_only: bool |
None =
None,
269 """Create a system user."""
272 system_generated=
True,
274 group_ids=group_ids
or [],
275 local_only=local_only,
278 self.
hasshass.bus.async_fire(EVENT_USER_ADDED, {
"user_id": user.id})
286 group_ids: list[str] |
None =
None,
287 local_only: bool |
None =
None,
290 kwargs: dict[str, Any] = {
293 "group_ids": group_ids
or [],
294 "local_only": local_only,
298 kwargs[
"is_owner"] =
True
302 self.
hasshass.bus.async_fire(EVENT_USER_ADDED, {
"user_id": user.id})
307 self, credentials: models.Credentials
309 """Get or create a user."""
310 if not credentials.is_new:
313 raise ValueError(
"Unable to find the user.")
318 if auth_provider
is None:
319 raise RuntimeError(
"Credential with unknown provider encountered")
321 info = await auth_provider.async_user_meta_for_credentials(credentials)
324 credentials=credentials,
326 is_active=info.is_active,
327 group_ids=[GROUP_ID_ADMIN
if info.group
is None else info.group],
328 local_only=info.local_only,
331 self.
hasshass.bus.async_fire(EVENT_USER_ADDED, {
"user_id": user.id})
338 """Link credentials to an existing user."""
340 if linked_user == user:
342 if linked_user
is not None:
343 raise ValueError(
"Credential is already linked to a user")
351 for credentials
in user.credentials
355 await asyncio.gather(*tasks)
359 self.
hasshass.bus.async_fire(EVENT_USER_REMOVED, {
"user_id": user.id})
364 name: str |
None =
None,
365 is_active: bool |
None =
None,
366 group_ids: list[str] |
None =
None,
367 local_only: bool |
None =
None,
370 kwargs: dict[str, Any] = {
372 for attr_name, value
in (
374 (
"group_ids", group_ids),
375 (
"local_only", local_only),
381 if is_active
is not None:
382 if is_active
is True:
387 self.
hasshass.bus.async_fire(EVENT_USER_UPDATED, {
"user_id": user.id})
391 self, credentials: models.Credentials, data: dict[str, Any]
393 """Update credentials data."""
397 """Activate a user."""
401 """Deactivate a user."""
403 raise ValueError(
"Unable to deactivate the owner")
407 """Remove credentials."""
410 if provider
is not None and hasattr(provider,
"async_will_remove_credentials"):
411 await provider.async_will_remove_credentials(credentials)
416 self, user: models.User, mfa_module_id: str, data: Any
418 """Enable a multi-factor auth module for user."""
419 if user.system_generated:
421 "System generated users cannot enable multi-factor auth module."
425 raise ValueError(f
"Unable find multi-factor auth module: {mfa_module_id}")
427 await module.async_setup_user(user.id, data)
430 self, user: models.User, mfa_module_id: str
432 """Disable a multi-factor auth module for user."""
433 if user.system_generated:
435 "System generated users cannot disable multi-factor auth module."
439 raise ValueError(f
"Unable find multi-factor auth module: {mfa_module_id}")
441 await module.async_depose_user(user.id)
444 """List enabled mfa modules for user."""
445 modules: dict[str, str] = OrderedDict()
446 for module_id, module
in self.
_mfa_modules_mfa_modules.items():
447 if await module.async_is_user_setup(user.id):
448 modules[module_id] = module.name
454 client_id: str |
None =
None,
455 client_name: str |
None =
None,
456 client_icon: str |
None =
None,
457 token_type: str |
None =
None,
458 access_token_expiration: timedelta = ACCESS_TOKEN_EXPIRATION,
461 """Create a new refresh token for a user."""
462 if not user.is_active:
463 raise ValueError(
"User is not active")
465 if user.system_generated
and client_id
is not None:
467 "System generated users cannot have refresh tokens connected "
471 if token_type
is None:
472 if user.system_generated:
473 token_type = models.TOKEN_TYPE_SYSTEM
475 token_type = models.TOKEN_TYPE_NORMAL
477 if token_type
is models.TOKEN_TYPE_NORMAL:
478 expire_at = time.time() + REFRESH_TOKEN_EXPIRATION
482 if user.system_generated != (token_type == models.TOKEN_TYPE_SYSTEM):
484 "System generated users can only have system type refresh tokens"
487 if token_type == models.TOKEN_TYPE_NORMAL
and client_id
is None:
488 raise ValueError(
"Client is required to generate a refresh token.")
491 token_type == models.TOKEN_TYPE_LONG_LIVED_ACCESS_TOKEN
492 and client_name
is None
494 raise ValueError(
"Client_name is required for long-lived access token")
496 if token_type == models.TOKEN_TYPE_LONG_LIVED_ACCESS_TOKEN:
497 for token
in user.refresh_tokens.values():
499 token.client_name == client_name
500 and token.token_type == models.TOKEN_TYPE_LONG_LIVED_ACCESS_TOKEN
504 raise ValueError(f
"{client_name} already exists")
512 access_token_expiration,
519 """Get refresh token by id."""
526 """Get refresh token by token."""
531 """Delete a refresh token."""
534 callbacks = self._revoke_callbacks.pop(refresh_token.id, ())
535 for revoke_callback
in callbacks:
540 self, refresh_token: models.RefreshToken, *, enable_expiry: bool
542 """Enable or disable expiry of a refresh token."""
547 """Remove expired refresh tokens."""
549 for token
in self.
_store_store.async_get_refresh_tokens():
550 if (expire_at := token.expire_at)
is not None and expire_at <= now:
556 """Initialise all token expiration scheduled tasks."""
557 next_expiration = time.time() + REFRESH_TOKEN_EXPIRATION
558 for token
in self.
_store_store.async_get_refresh_tokens():
560 expire_at := token.expire_at
561 )
is not None and expire_at < next_expiration:
562 next_expiration = expire_at
567 dt_util.utc_from_timestamp(next_expiration),
572 """Cancel tracking of expired refresh tokens."""
579 self, callbacks: set[CALLBACK_TYPE], callback_: CALLBACK_TYPE
581 """Unregister a callback."""
582 callbacks.remove(callback_)
586 self, refresh_token_id: str, revoke_callback: CALLBACK_TYPE
588 """Register a callback to be called when the refresh token id is revoked."""
589 if refresh_token_id
not in self._revoke_callbacks:
590 self._revoke_callbacks[refresh_token_id] = set()
592 callbacks = self._revoke_callbacks[refresh_token_id]
593 callbacks.add(revoke_callback)
594 return partial(self.
_async_unregister_async_unregister, callbacks, revoke_callback)
598 self, refresh_token: models.RefreshToken, remote_ip: str |
None =
None
600 """Create a new access token."""
603 self.
_store_store.async_log_refresh_token_usage(refresh_token, remote_ip)
605 now =
int(time.time())
606 expire_seconds =
int(refresh_token.access_token_expiration.total_seconds())
609 "iss": refresh_token.id,
611 "exp": now + expire_seconds,
613 refresh_token.jwt_key,
619 self, refresh_token: models.RefreshToken
620 ) -> AuthProvider |
None:
621 """Get the auth provider for the given refresh token.
623 Raises an exception if the expected provider is no longer available or return
624 None if no provider was expected for this refresh token.
626 if refresh_token.credential
is None:
630 refresh_token.credential.auth_provider_type,
631 refresh_token.credential.auth_provider_id,
635 f
"Auth provider {refresh_token.credential.auth_provider_type},"
636 f
" {refresh_token.credential.auth_provider_id} not available"
642 self, refresh_token: models.RefreshToken, remote_ip: str |
None =
None
644 """Validate that a refresh token is usable.
646 Will raise InvalidAuthError on errors.
649 provider.async_validate_refresh_token(refresh_token, remote_ip)
653 """Return refresh token if an access token is valid."""
655 unverif_claims = jwt_wrapper.unverified_hs256_token_decode(token)
656 except jwt.InvalidTokenError:
660 cast(str, unverif_claims.get(
"iss"))
663 if refresh_token
is None:
667 jwt_key = refresh_token.jwt_key
668 issuer = refresh_token.id
671 jwt_wrapper.verify_and_decode(
672 token, jwt_key, leeway=10, issuer=issuer, algorithms=[
"HS256"]
674 except jwt.InvalidTokenError:
677 if refresh_token
is None or not refresh_token.user.is_active:
684 self, credentials: models.Credentials
685 ) -> AuthProvider |
None:
686 """Get auth provider from a set of credentials."""
687 auth_provider_key = (
688 credentials.auth_provider_type,
689 credentials.auth_provider_id,
694 """Determine if user should be owner.
696 A user should be an owner if it is the first non-system user that is
700 if not user.system_generated:
AuthFlowResult async_finish_flow(self, FlowHandler[AuthFlowContext, AuthFlowResult, tuple[str, str]] flow, AuthFlowResult result)
LoginFlow async_create_flow(self, tuple[str, str] handler_key, *AuthFlowContext|None context=None, dict[str, Any]|None data=None)
None __init__(self, HomeAssistant hass, AuthManager auth_manager)
models.RefreshToken|None async_get_refresh_token(self, str token_id)
None async_set_expiry(self, models.RefreshToken refresh_token, *bool enable_expiry)
None async_update_user_credentials_data(self, models.Credentials credentials, dict[str, Any] data)
None _async_unregister(self, set[CALLBACK_TYPE] callbacks, CALLBACK_TYPE callback_)
AuthProvider|None get_auth_provider(self, str provider_type, str|None provider_id)
models.RefreshToken|None async_validate_access_token(self, str token)
str async_create_access_token(self, models.RefreshToken refresh_token, str|None remote_ip=None)
models.User async_get_or_create_user(self, models.Credentials credentials)
None async_remove_user(self, models.User user)
models.RefreshToken|None async_get_refresh_token_by_token(self, str token)
None async_enable_user_mfa(self, models.User user, str mfa_module_id, Any data)
None async_remove_refresh_token(self, models.RefreshToken refresh_token)
AuthProvider|None _async_resolve_provider(self, models.RefreshToken refresh_token)
None _async_track_next_refresh_token_expiration(self)
None async_link_user(self, models.User user, models.Credentials credentials)
models.User|None async_get_user_by_credentials(self, models.Credentials credentials)
None async_validate_refresh_token(self, models.RefreshToken refresh_token, str|None remote_ip=None)
AuthProvider|None _async_get_auth_provider(self, models.Credentials credentials)
MultiFactorAuthModule|None get_auth_mfa_module(self, str module_id)
None async_disable_user_mfa(self, models.User user, str mfa_module_id)
None _async_remove_expired_refresh_tokens(self, datetime|None _=None)
dict[str, str] async_get_enabled_mfa(self, models.User user)
models.User|None async_get_owner(self)
None _async_cancel_expiration_schedule(self)
None async_activate_user(self, models.User user)
list[AuthProvider] get_auth_providers(self, str provider_type)
models.Group|None async_get_group(self, str group_id)
models.User async_create_system_user(self, str name, *list[str]|None group_ids=None, bool|None local_only=None)
None async_deactivate_user(self, models.User user)
None __init__(self, HomeAssistant hass, auth_store.AuthStore store, _ProviderDict providers, _MfaModuleDict mfa_modules)
models.User|None async_get_user(self, str user_id)
models.RefreshToken async_create_refresh_token(self, models.User user, str|None client_id=None, str|None client_name=None, str|None client_icon=None, str|None token_type=None, timedelta access_token_expiration=ACCESS_TOKEN_EXPIRATION, models.Credentials|None credential=None)
list[models.User] async_get_users(self)
None async_update_user(self, models.User user, str|None name=None, bool|None is_active=None, list[str]|None group_ids=None, bool|None local_only=None)
models.User async_create_user(self, str name, *list[str]|None group_ids=None, bool|None local_only=None)
list[MultiFactorAuthModule] auth_mfa_modules(self)
CALLBACK_TYPE async_register_revoke_token_callback(self, str refresh_token_id, CALLBACK_TYPE revoke_callback)
bool _user_should_be_owner(self)
list[AuthProvider] auth_providers(self)
None async_remove_credentials(self, models.Credentials credentials)
MultiFactorAuthModule auth_mfa_module_from_config(HomeAssistant hass, dict[str, Any] config)
AuthProvider auth_provider_from_config(HomeAssistant hass, AuthStore store, dict[str, Any] config)
AuthManager auth_manager_from_config(HomeAssistant hass, list[dict[str, Any]] provider_configs, list[dict[str, Any]] module_configs)
web.Response get(self, web.Request request, str config_key)
CALLBACK_TYPE async_track_point_in_utc_time(HomeAssistant hass, HassJob[[datetime], Coroutine[Any, Any, None]|None]|Callable[[datetime], Coroutine[Any, Any, None]|None] action, datetime point_in_time)