Home Assistant Unofficial Reference 2024.12.1
account_link.py
Go to the documentation of this file.
1 """Account linking via the cloud."""
2 
3 from __future__ import annotations
4 
5 from datetime import datetime
6 import logging
7 from typing import Any
8 
9 import aiohttp
10 from awesomeversion import AwesomeVersion
11 from hass_nabucasa import account_link
12 
13 from homeassistant.const import __version__ as HA_VERSION
14 from homeassistant.core import HomeAssistant, callback
15 from homeassistant.helpers import config_entry_oauth2_flow, event
16 
17 from .const import DATA_CLOUD, DOMAIN
18 
# Key in hass.data under which the fetched account-link service list is cached.
DATA_SERVICES = "cloud_account_link_services"
# Delay (seconds, passed to event.async_call_later) before the cached list is dropped.
CACHE_TIMEOUT = 3600
_LOGGER = logging.getLogger(__name__)

# The running Home Assistant version, and the same version with the modifier
# reported by AwesomeVersion stripped; the latter is compared against each
# service's "min_version".
CURRENT_VERSION = AwesomeVersion(HA_VERSION)
CURRENT_PLAIN_VERSION = AwesomeVersion(
    # `modifier` is None when there is nothing to strip; removesuffix("") is a
    # no-op, whereas formatting None would look for a literal "None" suffix.
    CURRENT_VERSION.string.removesuffix(CURRENT_VERSION.modifier or "")
)
27 
28 
@callback
def async_setup(hass: HomeAssistant) -> None:
    """Register the cloud account-link implementation provider.

    ``async_provide_implementation`` is registered with the config-entry OAuth2
    helper under the cloud ``DOMAIN``.
    """
    add_provider = config_entry_oauth2_flow.async_add_implementation_provider
    add_provider(hass, DOMAIN, async_provide_implementation)
35 
36 
async def async_provide_implementation(
    hass: HomeAssistant, domain: str
) -> list[config_entry_oauth2_flow.AbstractOAuth2Implementation]:
    """Provide the cloud OAuth2 implementation for ``domain``, if available.

    The cloud implementation is offered when the cloud's service list contains
    an entry whose ``"service"`` equals ``domain`` and whose ``"min_version"``
    is not newer than the running version (with its modifier stripped), and
    that service either still accepts new authorizations (the default when the
    flag is absent) or ``domain`` already has a config entry that was
    authorized through the cloud.

    Args:
        hass: The Home Assistant instance.
        domain: The integration domain asking for implementations.

    Returns:
        A one-element list holding a ``CloudOAuth2Implementation`` for
        ``domain``, or an empty list when no matching service qualifies.
    """
    services = await _get_services(hass)

    for service in services:
        if (
            service["service"] == domain
            and service["min_version"] <= CURRENT_PLAIN_VERSION
            and (
                service.get("accepts_new_authorizations", True)
                # A service that stopped accepting new authorizations stays
                # available to domains that already have a cloud-linked entry.
                # any() of no entries is False, so no separate emptiness check
                # is needed.
                or any(
                    entry.data.get("auth_implementation") == DOMAIN
                    for entry in hass.config_entries.async_entries(domain)
                )
            )
        ):
            return [CloudOAuth2Implementation(hass, domain)]

    return []
61 
62 
async def _get_services(hass: HomeAssistant) -> list[dict[str, Any]]:
    """Return the account-link services offered by Home Assistant Cloud.

    A successfully fetched list is cached in ``hass.data[DATA_SERVICES]`` and a
    timer removes it again after ``CACHE_TIMEOUT`` seconds, so later calls
    refetch it. When fetching fails with ``aiohttp.ClientError`` or
    ``TimeoutError``, an empty list is returned and nothing is cached, so the
    next call retries.

    Args:
        hass: The Home Assistant instance; ``hass.data[DATA_CLOUD]`` is the
            cloud client used for the request.

    Returns:
        The list of service descriptions, or an empty list on fetch failure.
    """
    services: list[dict[str, Any]]
    if DATA_SERVICES in hass.data:
        services = hass.data[DATA_SERVICES]
        return services  # noqa: RET504

    try:
        services = await account_link.async_fetch_available_services(
            hass.data[DATA_CLOUD]
        )
    except (aiohttp.ClientError, TimeoutError) as err:
        # Best effort: no cloud implementations are offered this time, but
        # record why so a missing "Home Assistant Cloud" option can be traced.
        _LOGGER.debug("Unable to fetch available account link services: %r", err)
        return []

    hass.data[DATA_SERVICES] = services

    @callback
    def clear_services(_now: datetime) -> None:
        """Clear services cache."""
        hass.data.pop(DATA_SERVICES, None)

    event.async_call_later(hass, CACHE_TIMEOUT, clear_services)

    return services
87 
88 
class CloudOAuth2Implementation(config_entry_oauth2_flow.AbstractOAuth2Implementation):
    """Cloud implementation of the OAuth2 flow.

    Authorization and token refresh for a service are delegated to Home
    Assistant Cloud through ``hass_nabucasa.account_link``.
    """

    def __init__(self, hass: HomeAssistant, service: str) -> None:
        """Initialize cloud OAuth2 implementation.

        Args:
            hass: The Home Assistant instance; ``hass.data[DATA_CLOUD]`` is
                used as the cloud client.
            service: Name of the account-link service (the integration domain
                this implementation was provided for).
        """
        self.hass = hass
        self.service = service

    @property
    def name(self) -> str:
        """Name of the implementation."""
        return "Home Assistant Cloud"

    @property
    def domain(self) -> str:
        """Domain that is providing the implementation."""
        return DOMAIN

    async def async_generate_authorize_url(self, flow_id: str) -> str:
        """Generate a url for the user to authorize.

        A task is started that waits for the cloud to deliver tokens and then
        passes them to the flow ``flow_id`` as ``user_input``. A timeout or an
        ``AccountLinkException`` is logged at info level and the flow is not
        advanced.

        Args:
            flow_id: The config flow awaiting authorization.

        Returns:
            The authorize URL obtained from the cloud.
        """
        helper = account_link.AuthorizeAccountHelper(
            self.hass.data[DATA_CLOUD], self.service
        )
        authorize_url = await helper.async_get_authorize_url()

        async def await_tokens() -> None:
            """Wait for tokens and pass them on when received."""
            try:
                tokens = await helper.async_get_tokens()

            except TimeoutError:
                _LOGGER.info("Timeout fetching tokens for flow %s", flow_id)
            except account_link.AccountLinkException as err:
                _LOGGER.info(
                    "Failed to fetch tokens for flow %s: %s", flow_id, err.code
                )
            else:
                await self.hass.config_entries.flow.async_configure(
                    flow_id=flow_id, user_input=tokens
                )

        # Not awaited: the URL must be returned to the user before tokens exist.
        self.hass.async_create_task(await_tokens())

        return authorize_url

    async def async_resolve_external_data(self, external_data: Any) -> dict:
        """Resolve external data to tokens.

        The task started in ``async_generate_authorize_url`` already passes the
        received tokens as ``user_input``, so they are returned unchanged.
        """
        dict_data: dict = external_data
        return dict_data

    async def _async_refresh_token(self, token: dict) -> dict:
        """Refresh a token.

        A new access token is fetched through the cloud using
        ``token["refresh_token"]``. The result is merged over the existing
        token, so any fields the response omits are kept.
        """
        new_token = await account_link.async_fetch_access_token(
            self.hass.data[DATA_CLOUD], self.service, token["refresh_token"]
        )
        return {**token, **new_token}