# Source listing: config_flow.py (Home Assistant Unofficial Reference 2024.12.1).
# See the reference site for the generated documentation of this file.
"""Config flow for MySensors."""

from __future__ import annotations

import os
from typing import Any

from awesomeversion import (
    AwesomeVersion,
    AwesomeVersionStrategy,
    AwesomeVersionStrategyException,
)
import voluptuous as vol

from homeassistant.components.mqtt import (
    DOMAIN as MQTT_DOMAIN,
    valid_publish_topic,
    valid_subscribe_topic,
)
from homeassistant.config_entries import ConfigEntry, ConfigFlow, ConfigFlowResult
from homeassistant.const import CONF_DEVICE
from homeassistant.core import callback
from homeassistant.helpers import selector
import homeassistant.helpers.config_validation as cv
from homeassistant.helpers.typing import VolDictType

from .const import (
    CONF_BAUD_RATE,
    CONF_GATEWAY_TYPE,
    CONF_GATEWAY_TYPE_MQTT,
    CONF_GATEWAY_TYPE_SERIAL,
    CONF_GATEWAY_TYPE_TCP,
    CONF_PERSISTENCE_FILE,
    CONF_RETAIN,
    CONF_TCP_PORT,
    CONF_TOPIC_IN_PREFIX,
    CONF_TOPIC_OUT_PREFIX,
    CONF_VERSION,
    DOMAIN,
    ConfGatewayType,
)
from .gateway import MQTT_COMPONENT, is_serial_port, is_socket_address, try_connect
43 
# Defaults pre-filled in the gateway setup forms.
DEFAULT_BAUD_RATE = 115200  # default for the serial gateway's baud-rate field
DEFAULT_TCP_PORT = 5003  # default for the TCP gateway's port field
DEFAULT_VERSION = "1.4"  # suggested value for the version field
47 
# Validator for the TCP port field: a box-style number selector limited to
# the range 1-65535, whose result is then coerced to ``int``.
_PORT_SELECTOR = vol.All(
    selector.NumberSelector(
        selector.NumberSelectorConfig(
            min=1, max=65535, mode=selector.NumberSelectorMode.BOX
        ),
    ),
    vol.Coerce(int),
)
56 
57 
def is_persistence_file(value: str) -> str:
    """Validate that persistence file path ends in either .pickle or .json.

    Returns ``value`` unchanged when the suffix is accepted and raises
    ``vol.Invalid`` otherwise. The comparison is case-sensitive.
    """
    if value.endswith((".json", ".pickle")):
        return value
    raise vol.Invalid(f"{value} does not end in either `.json` or `.pickle`")
63 
64 
def _get_schema_common(user_input: dict[str, str]) -> dict:
    """Create a schema with options common to all gateway types.

    The returned mapping holds a required version field, whose suggested
    value is taken from ``user_input`` (falling back to ``DEFAULT_VERSION``),
    and an optional persistence-file field. Callers merge it into their
    gateway-specific schema dict.
    """
    return {
        vol.Required(
            CONF_VERSION,
            description={
                "suggested_value": user_input.get(CONF_VERSION, DEFAULT_VERSION)
            },
        ): str,
        vol.Optional(CONF_PERSISTENCE_FILE): str,
    }
76 
77 
def _validate_version(version: str) -> dict[str, str]:
    """Validate a version string from the user.

    Returns an empty dict when ``version`` matches the simple-version or
    semantic-version strategy, otherwise an error dict keyed by
    ``CONF_VERSION`` suitable for merging into a form's ``errors``.
    """
    try:
        # Only construction matters here; the parsed object is discarded.
        AwesomeVersion(
            version,
            ensure_strategy=[
                AwesomeVersionStrategy.SIMPLEVER,
                AwesomeVersionStrategy.SEMVER,
            ],
        )
    except AwesomeVersionStrategyException:
        return {CONF_VERSION: "invalid_version"}
    return {}
95 
96 
def _is_same_device(
    gw_type: ConfGatewayType, user_input: dict[str, Any], entry: ConfigEntry
) -> bool:
    """Check if an existing config entry describes the same gateway as user_input.

    The device address must match in every case. For TCP gateways the port
    must match as well; for MQTT gateways at least one of the submitted topic
    prefixes must already be used by the entry. Only these stored values are
    compared, so it is possible to fool the check with tricks like port
    forwarding.
    """
    if entry.data[CONF_DEVICE] != user_input[CONF_DEVICE]:
        return False
    if gw_type == CONF_GATEWAY_TYPE_TCP:
        entry_tcp_port: int = entry.data[CONF_TCP_PORT]
        input_tcp_port: int = user_input[CONF_TCP_PORT]
        return entry_tcp_port == input_tcp_port
    if gw_type == CONF_GATEWAY_TYPE_MQTT:
        entry_topics = {
            entry.data[CONF_TOPIC_IN_PREFIX],
            entry.data[CONF_TOPIC_OUT_PREFIX],
        }
        return (
            user_input.get(CONF_TOPIC_IN_PREFIX) in entry_topics
            or user_input.get(CONF_TOPIC_OUT_PREFIX) in entry_topics
        )
    # Any other gateway type (serial): same device path means same gateway.
    return True
120 
121 
# NOTE(review): the ``class`` statement was missing from this listing. The
# base class and ``domain`` follow from the imports in this file; the class
# name is restored from upstream Home Assistant - confirm.
class MySensorsConfigFlowHandler(ConfigFlow, domain=DOMAIN):
    """Handle a config flow."""

    def __init__(self) -> None:
        """Set up config flow."""
        # Gateway type of the step the user is in; written to the entry data
        # by _async_create_entry.
        self._gw_type: str | None = None

    async def async_step_user(
        self, user_input: dict[str, str] | None = None
    ) -> ConfigFlowResult:
        """Create a config entry from frontend user input."""
        return await self.async_step_select_gateway_type()

    async def async_step_select_gateway_type(
        self, user_input: dict[str, str] | None = None
    ) -> ConfigFlowResult:
        """Show the select gateway type menu."""
        return self.async_show_menu(
            step_id="select_gateway_type",
            menu_options=["gw_serial", "gw_tcp", "gw_mqtt"],
        )

    async def async_step_gw_serial(
        self, user_input: dict[str, Any] | None = None
    ) -> ConfigFlowResult:
        """Create config entry for a serial gateway."""
        gw_type = self._gw_type = CONF_GATEWAY_TYPE_SERIAL
        errors: dict[str, str] = {}

        if user_input is not None:
            errors.update(await self.validate_common(gw_type, errors, user_input))
            if not errors:
                return self._async_create_entry(user_input)

        # First display or failed validation: (re)show the form, pre-filled
        # with whatever the user already entered.
        user_input = user_input or {}
        schema: VolDictType = {
            vol.Required(
                CONF_DEVICE, default=user_input.get(CONF_DEVICE, "/dev/ttyACM0")
            ): str,
            vol.Required(
                CONF_BAUD_RATE,
                default=user_input.get(CONF_BAUD_RATE, DEFAULT_BAUD_RATE),
            ): cv.positive_int,
        }
        schema.update(_get_schema_common(user_input))

        return self.async_show_form(
            step_id="gw_serial", data_schema=vol.Schema(schema), errors=errors
        )

    async def async_step_gw_tcp(
        self, user_input: dict[str, Any] | None = None
    ) -> ConfigFlowResult:
        """Create a config entry for a tcp gateway."""
        gw_type = self._gw_type = CONF_GATEWAY_TYPE_TCP
        errors: dict[str, str] = {}

        if user_input is not None:
            errors.update(await self.validate_common(gw_type, errors, user_input))
            if not errors:
                return self._async_create_entry(user_input)

        user_input = user_input or {}
        schema: VolDictType = {
            vol.Required(
                CONF_DEVICE, default=user_input.get(CONF_DEVICE, "127.0.0.1")
            ): str,
            vol.Optional(
                CONF_TCP_PORT, default=user_input.get(CONF_TCP_PORT, DEFAULT_TCP_PORT)
            ): _PORT_SELECTOR,
        }
        schema.update(_get_schema_common(user_input))

        return self.async_show_form(
            step_id="gw_tcp", data_schema=vol.Schema(schema), errors=errors
        )

    def _check_topic_exists(self, topic: str) -> bool:
        """Return True if any existing entry uses topic as its in or out prefix."""
        for other_config in self._async_current_entries():
            if topic == other_config.data.get(
                CONF_TOPIC_IN_PREFIX
            ) or topic == other_config.data.get(CONF_TOPIC_OUT_PREFIX):
                return True
        return False

    async def async_step_gw_mqtt(
        self, user_input: dict[str, Any] | None = None
    ) -> ConfigFlowResult:
        """Create a config entry for a mqtt gateway."""
        # Naive check that doesn't consider config entry state.
        if MQTT_DOMAIN not in self.hass.config.components:
            return self.async_abort(reason="mqtt_required")

        gw_type = self._gw_type = CONF_GATEWAY_TYPE_MQTT
        errors: dict[str, str] = {}

        if user_input is not None:
            # MQTT gateways have no device address of their own; a fixed
            # marker is stored in its place.
            user_input[CONF_DEVICE] = MQTT_COMPONENT

            try:
                valid_subscribe_topic(user_input[CONF_TOPIC_IN_PREFIX])
            except vol.Invalid:
                errors[CONF_TOPIC_IN_PREFIX] = "invalid_subscribe_topic"
            else:
                if self._check_topic_exists(user_input[CONF_TOPIC_IN_PREFIX]):
                    errors[CONF_TOPIC_IN_PREFIX] = "duplicate_topic"

            try:
                valid_publish_topic(user_input[CONF_TOPIC_OUT_PREFIX])
            except vol.Invalid:
                errors[CONF_TOPIC_OUT_PREFIX] = "invalid_publish_topic"
            # Cross-topic checks only run when both topics passed so far.
            if not errors:
                if (
                    user_input[CONF_TOPIC_IN_PREFIX]
                    == user_input[CONF_TOPIC_OUT_PREFIX]
                ):
                    errors[CONF_TOPIC_OUT_PREFIX] = "same_topic"
                elif self._check_topic_exists(user_input[CONF_TOPIC_OUT_PREFIX]):
                    errors[CONF_TOPIC_OUT_PREFIX] = "duplicate_topic"

            errors.update(await self.validate_common(gw_type, errors, user_input))
            if not errors:
                return self._async_create_entry(user_input)

        user_input = user_input or {}
        schema: VolDictType = {
            vol.Required(
                CONF_TOPIC_IN_PREFIX, default=user_input.get(CONF_TOPIC_IN_PREFIX, "")
            ): str,
            vol.Required(
                CONF_TOPIC_OUT_PREFIX, default=user_input.get(CONF_TOPIC_OUT_PREFIX, "")
            ): str,
            vol.Required(CONF_RETAIN, default=user_input.get(CONF_RETAIN, True)): bool,
        }
        schema.update(_get_schema_common(user_input))

        return self.async_show_form(
            step_id="gw_mqtt", data_schema=vol.Schema(schema), errors=errors
        )

    @callback
    def _async_create_entry(self, user_input: dict[str, Any]) -> ConfigFlowResult:
        """Create the config entry.

        The entry is titled after the device value and stores the user input
        together with the gateway type of the step that produced it.
        """
        return self.async_create_entry(
            title=f"{user_input[CONF_DEVICE]}",
            data={**user_input, CONF_GATEWAY_TYPE: self._gw_type},
        )

    def _normalize_persistence_file(self, path: str) -> str:
        """Return a canonical form of path, used to compare persistence files."""
        return os.path.realpath(os.path.normcase(self.hass.config.path(path)))

    async def validate_common(
        self,
        gw_type: ConfGatewayType,
        errors: dict[str, str],
        user_input: dict[str, Any],
    ) -> dict[str, str]:
        """Validate parameters common to all gateway types.

        ``errors`` is updated in place and also returned. When a persistence
        file passes validation, its value in ``user_input`` is replaced by
        the normalized path.
        """
        errors.update(_validate_version(user_input[CONF_VERSION]))

        if gw_type != CONF_GATEWAY_TYPE_MQTT:
            if gw_type == CONF_GATEWAY_TYPE_TCP:
                verification_func = is_socket_address
            else:
                verification_func = is_serial_port

            try:
                # The address/port check is run in the executor.
                await self.hass.async_add_executor_job(
                    verification_func, user_input[CONF_DEVICE]
                )
            except vol.Invalid:
                errors[CONF_DEVICE] = (
                    "invalid_ip"
                    if gw_type == CONF_GATEWAY_TYPE_TCP
                    else "invalid_serial"
                )
        if CONF_PERSISTENCE_FILE in user_input:
            try:
                is_persistence_file(user_input[CONF_PERSISTENCE_FILE])
            except vol.Invalid:
                errors[CONF_PERSISTENCE_FILE] = "invalid_persistence_file"
            else:
                real_persistence_path = user_input[CONF_PERSISTENCE_FILE] = (
                    self._normalize_persistence_file(
                        user_input[CONF_PERSISTENCE_FILE]
                    )
                )
                # Two entries must not share one persistence file.
                for other_entry in self._async_current_entries():
                    if CONF_PERSISTENCE_FILE not in other_entry.data:
                        continue
                    if real_persistence_path == self._normalize_persistence_file(
                        other_entry.data[CONF_PERSISTENCE_FILE]
                    ):
                        errors[CONF_PERSISTENCE_FILE] = "duplicate_persistence_file"
                        break

        if not errors:
            for other_entry in self._async_current_entries():
                if _is_same_device(gw_type, user_input, other_entry):
                    errors["base"] = "already_configured"
                    break

        # if no errors so far, try to connect
        if not errors and not await try_connect(self.hass, gw_type, user_input):
            errors["base"] = "cannot_connect"

        return errors
# Cross-reference index generated by the documentation site (not code):
#
# ConfigFlowResult async_step_gw_tcp(self, dict[str, Any]|None user_input=None)
#   Definition: config_flow.py:174
# ConfigFlowResult async_step_select_gateway_type(self, dict[str, str]|None user_input=None)
#   Definition: config_flow.py:137
# ConfigFlowResult _async_create_entry(self, dict[str, Any] user_input)
#   Definition: config_flow.py:263
# ConfigFlowResult async_step_gw_serial(self, dict[str, Any]|None user_input=None)
#   Definition: config_flow.py:146
# ConfigFlowResult async_step_user(self, dict[str, str]|None user_input=None)
#   Definition: config_flow.py:131
# ConfigFlowResult async_step_gw_mqtt(self, dict[str, Any]|None user_input=None)
#   Definition: config_flow.py:209
# dict[str, str] validate_common(self, ConfGatewayType gw_type, dict[str, str] errors, dict[str, Any] user_input)
#   Definition: config_flow.py:278
# ConfigFlowResult async_create_entry(self, *str title, Mapping[str, Any] data, str|None description=None, Mapping[str, str]|None description_placeholders=None, Mapping[str, Any]|None options=None)
# list[ConfigEntry] _async_current_entries(self, bool|None include_ignore=None)
# ConfigFlowResult async_abort(self, *str reason, Mapping[str, str]|None description_placeholders=None)
# ConfigFlowResult async_show_form(self, *str|None step_id=None, vol.Schema|None data_schema=None, dict[str, str]|None errors=None, Mapping[str, str]|None description_placeholders=None, bool|None last_step=None, str|None preview=None)
# _FlowResultT async_show_form(self, *str|None step_id=None, vol.Schema|None data_schema=None, dict[str, str]|None errors=None, Mapping[str, str]|None description_placeholders=None, bool|None last_step=None, str|None preview=None)
# _FlowResultT async_show_menu(self, *str|None step_id=None, Container[str] menu_options, Mapping[str, str]|None description_placeholders=None)
# _FlowResultT async_create_entry(self, *str|None title=None, Mapping[str, Any] data, str|None description=None, Mapping[str, str]|None description_placeholders=None)
# _FlowResultT async_abort(self, *str reason, Mapping[str, str]|None description_placeholders=None)
# str valid_subscribe_topic(Any topic)
#   Definition: util.py:270
# str valid_publish_topic(Any topic)
#   Definition: util.py:308
# dict[str, str] _validate_version(str version)
#   Definition: config_flow.py:78
# dict _get_schema_common(dict[str, str] user_input)
#   Definition: config_flow.py:65
# bool _is_same_device(ConfGatewayType gw_type, dict[str, Any] user_input, ConfigEntry entry)
#   Definition: config_flow.py:99
# bool try_connect(HomeAssistant hass, ConfGatewayType gateway_type, dict[str, Any] user_input)
#   Definition: gateway.py:82