Home Assistant Unofficial Reference 2024.12.1
config_entry_oauth2_flow.py
Go to the documentation of this file.
"""Config Flow using OAuth2.

This module consists of the following parts:
 - OAuth2 config flow which supports multiple OAuth2 implementations
 - OAuth2 implementation that works with a locally provided client ID/secret
 - OAuth2 authorize callback view that resumes the waiting config flow
 - OAuth2 session and request helpers that attach and refresh tokens

"""
8 
9 from __future__ import annotations
10 
11 from abc import ABC, ABCMeta, abstractmethod
12 import asyncio
13 from asyncio import Lock
14 from collections.abc import Awaitable, Callable
15 from http import HTTPStatus
16 from json import JSONDecodeError
17 import logging
18 import secrets
19 import time
20 from typing import Any, cast
21 
22 from aiohttp import ClientError, ClientResponseError, client, web
23 import jwt
24 import voluptuous as vol
25 from yarl import URL
26 
27 from homeassistant import config_entries
28 from homeassistant.components import http
29 from homeassistant.core import HomeAssistant, callback
30 from homeassistant.loader import async_get_application_credentials
31 from homeassistant.util.hass_dict import HassKey
32 
33 from .aiohttp_client import async_get_clientsession
34 from .network import NoURLAvailableError
35 
_LOGGER = logging.getLogger(__name__)

# hass.data key for the lazily generated secret that signs the OAuth2 `state`.
DATA_JWT_SECRET = "oauth2_jwt_secret"

# hass.data key: integration domain -> {implementation domain -> implementation}.
DATA_IMPLEMENTATIONS: HassKey[dict[str, dict[str, AbstractOAuth2Implementation]]] = (
    HassKey("oauth2_impl")
)

# hass.data key: provider domain -> coroutine returning implementations for a domain.
DATA_PROVIDERS: HassKey[
    dict[
        str,
        Callable[[HomeAssistant, str], Awaitable[list[AbstractOAuth2Implementation]]],
    ]
] = HassKey("oauth2_providers")

AUTH_CALLBACK_PATH = "/auth/external/callback"
HEADER_FRONTEND_BASE = "HA-Frontend-Base"
MY_AUTH_CALLBACK_PATH = "https://my.home-assistant.io/redirect/oauth"

# A token is treated as expired this many seconds before its recorded expiry,
# so it is refreshed slightly early.
CLOCK_OUT_OF_SYNC_MAX_SEC = 20

# Upper bounds (seconds) for generating the authorize URL and resolving tokens.
OAUTH_AUTHORIZE_URL_TIMEOUT_SEC = 30
OAUTH_TOKEN_TIMEOUT_SEC = 30
56 
57 
class AbstractOAuth2Implementation(ABC):
    """Base class to abstract OAuth2 authentication.

    Subclasses provide the authorize URL, the exchange of external data for
    tokens and the raw token refresh. This base class normalises the expiry
    information of refreshed tokens.
    """

    @property
    @abstractmethod
    def name(self) -> str:
        """Name of the implementation."""

    @property
    @abstractmethod
    def domain(self) -> str:
        """Domain that is providing the implementation."""

    @abstractmethod
    async def async_generate_authorize_url(self, flow_id: str) -> str:
        """Generate a url for the user to authorize.

        This step is called when a config flow is initialized. It should redirect the
        user to the vendor website where they can authorize Home Assistant.

        The implementation is responsible for getting notified when the user is
        authorized and for passing this to the specified config flow. Do as little
        work as possible once notified. You can do the work inside
        async_resolve_external_data. This will give the best UX.

        Pass external data in with:

        await hass.config_entries.flow.async_configure(
            flow_id=flow_id, user_input={'code': 'abcd', 'state': {...}}
        )

        """

    @abstractmethod
    async def async_resolve_external_data(self, external_data: Any) -> dict:
        """Resolve external data to tokens.

        Turn the data that the implementation passed to the config flow as external
        step data into tokens. These tokens will be stored as 'token' in the
        config entry data.
        """

    async def async_refresh_token(self, token: dict) -> dict:
        """Refresh a token and update expires info.

        Delegates to `_async_refresh_token`, then coerces `expires_in` to int
        and stores an absolute `expires_at` timestamp (seconds since epoch).

        Raises KeyError if the refreshed token has no `expires_in`, and
        ValueError/TypeError if that value cannot be converted to int.
        """
        new_token = await self._async_refresh_token(token)
        # Force int for non-compliant oauth2 providers
        new_token["expires_in"] = int(new_token["expires_in"])
        new_token["expires_at"] = time.time() + new_token["expires_in"]
        return new_token

    @abstractmethod
    async def _async_refresh_token(self, token: dict) -> dict:
        """Refresh a token."""
112 
113 
class LocalOAuth2Implementation(AbstractOAuth2Implementation):
    """Local OAuth2 implementation.

    Uses a client ID/secret supplied locally and talks to the configured
    authorize and token endpoints directly.
    """

    def __init__(
        self,
        hass: HomeAssistant,
        domain: str,
        client_id: str,
        client_secret: str,
        authorize_url: str,
        token_url: str,
    ) -> None:
        """Initialize local auth implementation."""
        self.hass = hass
        self._domain = domain
        self.client_id = client_id
        self.client_secret = client_secret
        self.authorize_url = authorize_url
        self.token_url = token_url

    @property
    def name(self) -> str:
        """Name of the implementation."""
        return "Configuration.yaml"

    @property
    def domain(self) -> str:
        """Domain providing the implementation."""
        return self._domain

    @property
    def redirect_uri(self) -> str:
        """Return the redirect uri.

        Uses the My Home Assistant redirect when the `my` component is loaded.
        Otherwise the URI is built from the `HA-Frontend-Base` header of the
        current HTTP request.

        Raises RuntimeError if there is no current request, or if it lacks that
        header.
        """
        if "my" in self.hass.config.components:
            return MY_AUTH_CALLBACK_PATH

        if (req := http.current_request.get()) is None:
            raise RuntimeError("No current request in context")

        if (ha_host := req.headers.get(HEADER_FRONTEND_BASE)) is None:
            raise RuntimeError("No header in request")

        return f"{ha_host}{AUTH_CALLBACK_PATH}"

    @property
    def extra_authorize_data(self) -> dict:
        """Extra data that needs to be appended to the authorize url."""
        return {}

    async def async_generate_authorize_url(self, flow_id: str) -> str:
        """Generate a url for the user to authorize.

        The `state` query parameter is a signed JWT carrying the flow ID and the
        redirect URI, so the callback view can route the result back to the flow.
        """
        redirect_uri = self.redirect_uri
        return str(
            URL(self.authorize_url)
            .with_query(
                {
                    "response_type": "code",
                    "client_id": self.client_id,
                    "redirect_uri": redirect_uri,
                    "state": _encode_jwt(
                        self.hass, {"flow_id": flow_id, "redirect_uri": redirect_uri}
                    ),
                }
            )
            .update_query(self.extra_authorize_data)
        )

    async def async_resolve_external_data(self, external_data: Any) -> dict:
        """Resolve the authorization code to tokens."""
        return await self._token_request(
            {
                "grant_type": "authorization_code",
                "code": external_data["code"],
                # Reuse the redirect_uri embedded in the signed state when the
                # authorize URL was generated.
                "redirect_uri": external_data["state"]["redirect_uri"],
            }
        )

    async def _async_refresh_token(self, token: dict) -> dict:
        """Refresh tokens and merge the response over the existing token."""
        new_token = await self._token_request(
            {
                "grant_type": "refresh_token",
                "client_id": self.client_id,
                "refresh_token": token["refresh_token"],
            }
        )
        return {**token, **new_token}

    async def _token_request(self, data: dict) -> dict:
        """Make a token request.

        Adds the client credentials to `data` (mutating it), POSTs it as a form
        to `token_url` and returns the decoded JSON body.

        Raises aiohttp.ClientResponseError for an HTTP status >= 400, after
        logging the `error`/`error_description` fields of the body when present.
        """
        session = async_get_clientsession(self.hass)

        data["client_id"] = self.client_id

        # client_secret may be None despite its annotation; only send it if set.
        if self.client_secret is not None:
            data["client_secret"] = self.client_secret

        _LOGGER.debug("Sending token request to %s", self.token_url)
        resp = await session.post(self.token_url, data=data)
        if resp.status >= 400:
            try:
                error_response = await resp.json()
            except (ClientError, JSONDecodeError):
                error_response = {}
            # The body is provider-controlled and may be valid JSON that is not
            # an object (a list, string, number...), which has no .get().
            if not isinstance(error_response, dict):
                error_response = {}
            error_code = error_response.get("error", "unknown")
            error_description = error_response.get("error_description", "unknown error")
            _LOGGER.error(
                "Token request for %s failed (%s): %s",
                self.domain,
                error_code,
                error_description,
            )
        resp.raise_for_status()
        return cast(dict, await resp.json())
228 
229 
class AbstractOAuth2FlowHandler(config_entries.ConfigFlow, metaclass=ABCMeta):
    """Handle a config flow."""

    DOMAIN = ""

    VERSION = 1

    def __init__(self) -> None:
        """Instantiate config flow.

        Raises TypeError if the subclass did not set DOMAIN.
        """
        if self.DOMAIN == "":
            raise TypeError(
                f"Can't instantiate class {self.__class__.__name__} without DOMAIN"
                " being set"
            )

        self.external_data: Any = None
        self.flow_impl: AbstractOAuth2Implementation = None  # type: ignore[assignment]

    @property
    @abstractmethod
    def logger(self) -> logging.Logger:
        """Return logger."""

    @property
    def extra_authorize_data(self) -> dict:
        """Extra data that needs to be appended to the authorize url."""
        return {}

    async def async_generate_authorize_url(self) -> str:
        """Generate a url for the user to authorize."""
        url = await self.flow_impl.async_generate_authorize_url(self.flow_id)
        return str(URL(url).update_query(self.extra_authorize_data))

    async def async_step_pick_implementation(
        self, user_input: dict | None = None
    ) -> config_entries.ConfigFlowResult:
        """Handle a flow start.

        Aborts when no implementation is available, auto-selects the only one
        when the flow was started by a user request, and otherwise shows a form
        to pick an implementation.
        """
        implementations = await async_get_implementations(self.hass, self.DOMAIN)

        if user_input is not None:
            self.flow_impl = implementations[user_input["implementation"]]
            return await self.async_step_auth()

        if not implementations:
            if self.DOMAIN in await async_get_application_credentials(self.hass):
                return self.async_abort(reason="missing_credentials")
            return self.async_abort(reason="missing_configuration")

        req = http.current_request.get()
        if len(implementations) == 1 and req is not None:
            # Pick first implementation if we have only one, but only
            # if this is triggered by a user interaction (request).
            self.flow_impl = next(iter(implementations.values()))
            return await self.async_step_auth()

        return self.async_show_form(
            step_id="pick_implementation",
            data_schema=vol.Schema(
                {
                    vol.Required(
                        "implementation", default=next(iter(implementations))
                    ): vol.In({key: impl.name for key, impl in implementations.items()})
                }
            ),
        )

    async def async_step_auth(
        self, user_input: dict[str, Any] | None = None
    ) -> config_entries.ConfigFlowResult:
        """Create an entry for auth.

        On first entry, sends the user to the authorize URL as an external step.
        When re-entered with external data, routes to `authorize_rejected` if it
        contains an `error`, and to `creation` otherwise.
        """
        # Flow has been triggered by external data
        if user_input is not None:
            self.external_data = user_input
            next_step = "authorize_rejected" if "error" in user_input else "creation"
            return self.async_external_step_done(next_step_id=next_step)

        try:
            async with asyncio.timeout(OAUTH_AUTHORIZE_URL_TIMEOUT_SEC):
                url = await self.async_generate_authorize_url()
        except TimeoutError as err:
            _LOGGER.error("Timeout generating authorize url: %s", err)
            return self.async_abort(reason="authorize_url_timeout")
        except NoURLAvailableError:
            return self.async_abort(
                reason="no_url_available",
                description_placeholders={
                    "docs_url": (
                        "https://www.home-assistant.io/more-info/no-url-available"
                    )
                },
            )

        return self.async_external_step(step_id="auth", url=url)

    async def async_step_creation(
        self, user_input: dict[str, Any] | None = None
    ) -> config_entries.ConfigFlowResult:
        """Create config entry from external data."""
        _LOGGER.debug("Creating config entry from external data")

        try:
            async with asyncio.timeout(OAUTH_TOKEN_TIMEOUT_SEC):
                token = await self.flow_impl.async_resolve_external_data(
                    self.external_data
                )
        except TimeoutError as err:
            _LOGGER.error("Timeout resolving OAuth token: %s", err)
            return self.async_abort(reason="oauth_timeout")
        except ClientError as err:
            # ClientResponseError is a ClientError subclass; check it for a 401.
            _LOGGER.error("Error resolving OAuth token: %s", err)
            if (
                isinstance(err, ClientResponseError)
                and err.status == HTTPStatus.UNAUTHORIZED
            ):
                return self.async_abort(reason="oauth_unauthorized")
            return self.async_abort(reason="oauth_failed")

        if "expires_in" not in token:
            # Log only the keys: the token may hold access/refresh secrets.
            _LOGGER.warning("Invalid token: missing expires_in (keys: %s)", list(token))
            return self.async_abort(reason="oauth_error")

        # Force int for non-compliant oauth2 providers
        try:
            token["expires_in"] = int(token["expires_in"])
        except (TypeError, ValueError) as err:
            # TypeError covers values such as None or lists, not only bad strings.
            _LOGGER.warning("Error converting expires_in to int: %s", err)
            return self.async_abort(reason="oauth_error")
        token["expires_at"] = time.time() + token["expires_in"]

        self.logger.info("Successfully authenticated")

        return await self.async_oauth_create_entry(
            {"auth_implementation": self.flow_impl.domain, "token": token}
        )

    async def async_step_authorize_rejected(
        self, data: None = None
    ) -> config_entries.ConfigFlowResult:
        """Step to handle flow rejection."""
        return self.async_abort(
            reason="user_rejected_authorize",
            description_placeholders={"error": self.external_data["error"]},
        )

    async def async_oauth_create_entry(
        self, data: dict
    ) -> config_entries.ConfigFlowResult:
        """Create an entry for the flow.

        Ok to override if you want to fetch extra info or even add another step.
        """
        return self.async_create_entry(title=self.flow_impl.name, data=data)

    async def async_step_user(
        self, user_input: dict[str, Any] | None = None
    ) -> config_entries.ConfigFlowResult:
        """Handle a flow start."""
        return await self.async_step_pick_implementation(user_input)

    @classmethod
    def async_register_implementation(
        cls, hass: HomeAssistant, local_impl: LocalOAuth2Implementation
    ) -> None:
        """Register a local implementation."""
        async_register_implementation(hass, cls.DOMAIN, local_impl)
395 
396 
@callback
def async_register_implementation(
    hass: HomeAssistant, domain: str, implementation: AbstractOAuth2Implementation
) -> None:
    """Register an OAuth2 flow implementation for an integration.

    The implementation is stored for `domain`, keyed by its own
    `implementation.domain`. Registering again under the same key replaces it.
    """
    all_domains = hass.data.setdefault(DATA_IMPLEMENTATIONS, {})
    domain_impls = all_domains.setdefault(domain, {})
    domain_impls[implementation.domain] = implementation
404 
405 
async def async_get_implementations(
    hass: HomeAssistant, domain: str
) -> dict[str, AbstractOAuth2Implementation]:
    """Return OAuth2 implementations for specified domain.

    Combines the implementations registered for `domain` with those returned
    by every registered provider. Provider results override registered
    implementations that share the same implementation domain.
    """
    registered = hass.data.setdefault(DATA_IMPLEMENTATIONS, {}).get(domain, {})

    if DATA_PROVIDERS not in hass.data:
        return registered

    # Work on a copy so provider results never end up in the stored registry.
    combined = dict(registered)
    # Iterate a snapshot: the providers dict could change while awaiting a provider.
    for provide in list(hass.data[DATA_PROVIDERS].values()):
        provided = await provide(hass, domain)
        combined.update({impl.domain: impl for impl in provided})

    return combined
421 
422 
async def async_get_config_entry_implementation(
    hass: HomeAssistant, config_entry: config_entries.ConfigEntry
) -> AbstractOAuth2Implementation:
    """Return the implementation for this config entry.

    Looks up `config_entry.data["auth_implementation"]` among the
    implementations currently available for the entry's domain.

    Raises ValueError if that implementation is not available.
    """
    available = await async_get_implementations(hass, config_entry.domain)
    wanted = config_entry.data["auth_implementation"]

    if (implementation := available.get(wanted)) is None:
        raise ValueError("Implementation not available")
    return implementation
434 
435 
@callback
def async_add_implementation_provider(
    hass: HomeAssistant,
    provider_domain: str,
    async_provide_implementation: Callable[
        [HomeAssistant, str], Awaitable[list[AbstractOAuth2Implementation]]
    ],
) -> None:
    """Add an implementation provider.

    `async_provide_implementation` is awaited with `hass` and an integration
    domain whenever implementations for that domain are looked up. It must
    return a list of implementations; use an empty list, not None, when it has
    none for that domain. Adding a provider again under the same
    `provider_domain` replaces the previous one.
    """
    providers = hass.data.setdefault(DATA_PROVIDERS, {})
    providers[provider_domain] = async_provide_implementation
451 
452 
class OAuth2AuthorizeCallbackView(http.HomeAssistantView):
    """OAuth2 Authorization Callback View.

    Receives the redirect after the user authorizes (or rejects) access and
    resumes the config flow identified by the signed `state` parameter.
    """

    # Not authenticated: the signed `state` JWT ties the request to a flow.
    requires_auth = False
    url = AUTH_CALLBACK_PATH
    name = "auth:external:callback"

    async def get(self, request: web.Request) -> web.Response:
        """Receive authorization code.

        Returns 400 when `state` is missing or invalid, or when neither `code`
        nor `error` is present. Otherwise forwards them to the flow and returns
        a page that closes the window.
        """
        if "state" not in request.query:
            return web.Response(text="Missing state parameter", status=400)

        hass = request.app[http.KEY_HASS]

        state = _decode_jwt(hass, request.query["state"])

        if state is None:
            return web.Response(
                text=(
                    "Invalid state. Is My Home Assistant configured "
                    "to go to the right instance?"
                ),
                status=400,
            )

        user_input: dict[str, Any] = {"state": state}

        if "code" in request.query:
            user_input["code"] = request.query["code"]
        elif "error" in request.query:
            user_input["error"] = request.query["error"]
        else:
            return web.Response(text="Missing code or error parameter", status=400)

        await hass.config_entries.flow.async_configure(
            flow_id=state["flow_id"], user_input=user_input
        )
        _LOGGER.debug("Resumed OAuth configuration flow")
        return web.Response(
            headers={"content-type": "text/html"},
            text="<script>window.close()</script>",
        )
495 
496 
class OAuth2Session:
    """Session to make requests authenticated with OAuth2.

    Reads the token from the config entry's `data["token"]`. When the token is
    about to expire, it is refreshed through `implementation` and written back
    to the entry.
    """

    def __init__(
        self,
        hass: HomeAssistant,
        config_entry: config_entries.ConfigEntry,
        implementation: AbstractOAuth2Implementation,
    ) -> None:
        """Initialize an OAuth2 session."""
        self.hass = hass
        self.config_entry = config_entry
        self.implementation = implementation
        # Held across check-and-refresh so concurrent callers refresh only once.
        self._token_lock = Lock()

    @property
    def token(self) -> dict:
        """Return the token stored in the config entry."""
        return cast(dict, self.config_entry.data["token"])

    @property
    def valid_token(self) -> bool:
        """Return True if the token expires more than CLOCK_OUT_OF_SYNC_MAX_SEC from now."""
        expires_at = cast(float, self.token["expires_at"])
        return expires_at > time.time() + CLOCK_OUT_OF_SYNC_MAX_SEC

    async def async_ensure_token_valid(self) -> None:
        """Refresh and persist the token unless it is still valid."""
        async with self._token_lock:
            if not self.valid_token:
                refreshed = await self.implementation.async_refresh_token(self.token)
                self.hass.config_entries.async_update_entry(
                    self.config_entry,
                    data={**self.config_entry.data, "token": refreshed},
                )

    async def async_request(
        self, method: str, url: str, **kwargs: Any
    ) -> client.ClientResponse:
        """Make a request, refreshing the token first when needed."""
        await self.async_ensure_token_valid()
        return await async_oauth2_request(
            self.hass, self.token, method, url, **kwargs
        )
545 
546 
async def async_oauth2_request(
    hass: HomeAssistant, token: dict, method: str, url: str, **kwargs: Any
) -> client.ClientResponse:
    """Make an OAuth2 authenticated request.

    This method will not refresh tokens. Use OAuth2 session for that.

    A bearer `authorization` header built from `token["access_token"]` is
    added on top of any `headers` passed in `kwargs`. `headers=None` is
    treated as no extra headers.
    """
    session = async_get_clientsession(hass)
    # Callers may pass headers=None explicitly; `**None` would raise TypeError.
    headers = kwargs.pop("headers", None) or {}
    return await session.request(
        method,
        url,
        **kwargs,
        headers={
            **headers,
            "authorization": f"Bearer {token['access_token']}",
        },
    )
565 
566 
@callback
def _encode_jwt(hass: HomeAssistant, data: dict) -> str:
    """JWT encode data.

    Signs `data` with HS256 using a secret kept in hass.data under
    DATA_JWT_SECRET. The secret is generated with secrets.token_hex() on first use.
    """
    secret = hass.data.get(DATA_JWT_SECRET)
    if secret is None:
        secret = secrets.token_hex()
        hass.data[DATA_JWT_SECRET] = secret

    return jwt.encode(data, secret, algorithm="HS256")
574 
575 
@callback
def _decode_jwt(hass: HomeAssistant, encoded: str) -> dict[str, Any] | None:
    """JWT decode data.

    Verifies and decodes an HS256 JWT signed with the secret that _encode_jwt
    stores in hass.data under DATA_JWT_SECRET.

    Returns None when no secret exists yet (_encode_jwt has not run), or when
    the token fails validation (jwt.InvalidTokenError).
    """
    secret: str | None = hass.data.get(DATA_JWT_SECRET)

    if secret is None:
        return None

    try:
        return jwt.decode(encoded, secret, algorithms=["HS256"])  # type: ignore[no-any-return]
    except jwt.InvalidTokenError:
        return None
ConfigFlowResult async_create_entry(self, *str title, Mapping[str, Any] data, str|None description=None, Mapping[str, str]|None description_placeholders=None, Mapping[str, Any]|None options=None)
ConfigFlowResult async_abort(self, *str reason, Mapping[str, str]|None description_placeholders=None)
ConfigFlowResult async_show_form(self, *str|None step_id=None, vol.Schema|None data_schema=None, dict[str, str]|None errors=None, Mapping[str, str]|None description_placeholders=None, bool|None last_step=None, str|None preview=None)
_FlowResultT async_external_step(self, *str|None step_id=None, str url, Mapping[str, str]|None description_placeholders=None)
str
_FlowResultT async_show_form(self, *str|None step_id=None, vol.Schema|None data_schema=None, dict[str, str]|None errors=None, Mapping[str, str]|None description_placeholders=None, bool|None last_step=None, str|None preview=None)
_FlowResultT async_external_step_done(self, *str next_step_id)
_FlowResultT async_create_entry(self, *str|None title=None, Mapping[str, Any] data, str|None description=None, Mapping[str, str]|None description_placeholders=None)
_FlowResultT async_abort(self, *str reason, Mapping[str, str]|None description_placeholders=None)
dict extra_authorize_data(self)
config_entries.ConfigFlowResult async_step_pick_implementation(self, dict|None user_input=None)
DOMAIN
config_entries.ConfigFlowResult async_step_auth(self, dict[str, Any]|None user_input=None)
string DOMAIN
None __init__(self)
config_entries.ConfigFlowResult async_step_authorize_rejected(self, None data=None)
external_data
config_entries.ConfigFlowResult async_oauth_create_entry(self, dict data)
str async_generate_authorize_url(self)
flow_impl
config_entries.ConfigFlowResult async_step_creation(self, dict[str, Any]|None user_input=None)
logging.Logger logger(self)
config_entries.ConfigFlowResult async_step_user(self, dict[str, Any]|None user_input=None)
None async_register_implementation(cls, HomeAssistant hass, LocalOAuth2Implementation local_impl)
dict async_refresh_token(self, dict token)
dict _async_refresh_token(self, dict token)
str name(self)
str async_generate_authorize_url(self, str flow_id)
dict async_resolve_external_data(self, Any external_data)
str domain(self)
authorize_url
None __init__(self, HomeAssistant hass, str domain, str client_id, str client_secret, str authorize_url, str token_url)
str domain(self)
client_id
str async_generate_authorize_url(self, str flow_id)
str name(self)
dict extra_authorize_data(self)
dict _token_request(self, dict data)
_domain
dict async_resolve_external_data(self, Any external_data)
str redirect_uri(self)
hass
dict _async_refresh_token(self, dict token)
token_url
client_secret
web.Response get(self, web.Request request)
implementation
config_entry
None __init__(self, HomeAssistant hass, config_entries.ConfigEntry config_entry, AbstractOAuth2Implementation implementation)
None async_ensure_token_valid(self)
client.ClientResponse async_request(self, str method, str url, **Any kwargs)
dict token(self)
hass
bool valid_token(self)
_token_lock
web.Response get(self, web.Request request, str config_key)
Definition: view.py:88
aiohttp.ClientSession async_get_clientsession(HomeAssistant hass, bool verify_ssl=True, socket.AddressFamily family=socket.AF_UNSPEC, ssl_util.SSLCipherList ssl_cipher=ssl_util.SSLCipherList.PYTHON_DEFAULT)
None async_register_implementation(HomeAssistant hass, str domain, AbstractOAuth2Implementation implementation)
str _encode_jwt(HomeAssistant hass, dict data)
dict[str, Any]|None _decode_jwt(HomeAssistant hass, str encoded)
AbstractOAuth2Implementation async_get_config_entry_implementation(HomeAssistant hass, config_entries.ConfigEntry config_entry)
dict[str, AbstractOAuth2Implementation] async_get_implementations(HomeAssistant hass, str domain)
client.ClientResponse async_oauth2_request(HomeAssistant hass, dict token, str method, str url, **Any kwargs)
None async_add_implementation_provider(HomeAssistant hass, str provider_domain, Callable[[HomeAssistant, str], Awaitable[list[AbstractOAuth2Implementation]]] async_provide_implementation)
list[str] async_get_application_credentials(HomeAssistant hass)
Definition: loader.py:453