Home Assistant Unofficial Reference 2024.12.1
binary_sensor.py
Go to the documentation of this file.
1 """Use Bayesian Inference to trigger a binary sensor."""
2 
from __future__ import annotations

from collections import OrderedDict
from collections.abc import Callable
import logging
import math
from typing import TYPE_CHECKING, Any, NamedTuple
from uuid import UUID

import voluptuous as vol

from homeassistant.components.binary_sensor import (
    PLATFORM_SCHEMA as BINARY_SENSOR_PLATFORM_SCHEMA,
    BinarySensorDeviceClass,
    BinarySensorEntity,
)
from homeassistant.const import (
    CONF_ABOVE,
    CONF_BELOW,
    CONF_DEVICE_CLASS,
    CONF_ENTITY_ID,
    CONF_NAME,
    CONF_PLATFORM,
    CONF_STATE,
    CONF_UNIQUE_ID,
    CONF_VALUE_TEMPLATE,
    STATE_UNAVAILABLE,
    STATE_UNKNOWN,
)
from homeassistant.core import Event, EventStateChangedData, HomeAssistant, callback
from homeassistant.exceptions import ConditionError, TemplateError
from homeassistant.helpers import condition
import homeassistant.helpers.config_validation as cv
from homeassistant.helpers.entity_platform import AddEntitiesCallback
from homeassistant.helpers.event import (
    TrackTemplate,
    TrackTemplateResult,
    TrackTemplateResultInfo,
    async_track_state_change_event,
    async_track_template_result,
)
from homeassistant.helpers.reload import async_setup_reload_service
from homeassistant.helpers.template import Template, result_as_boolean
from homeassistant.helpers.typing import ConfigType, DiscoveryInfoType

from . import DOMAIN, PLATFORMS
from .const import (
    ATTR_OBSERVATIONS,
    ATTR_OCCURRED_OBSERVATION_ENTITIES,
    ATTR_PROBABILITY,
    ATTR_PROBABILITY_THRESHOLD,
    CONF_NUMERIC_STATE,
    CONF_OBSERVATIONS,
    CONF_P_GIVEN_F,
    CONF_P_GIVEN_T,
    CONF_PRIOR,
    CONF_PROBABILITY_THRESHOLD,
    CONF_TEMPLATE,
    CONF_TO_STATE,
    DEFAULT_NAME,
    DEFAULT_PROBABILITY_THRESHOLD,
)
from .helpers import Observation
from .issues import raise_mirrored_entries, raise_no_prob_given_false
67 
# Module logger, used for config-validation errors and runtime debug/error output.
_LOGGER = logging.getLogger(name=__name__)
69 
70 
def _above_greater_than_below(config: dict[str, Any]) -> dict[str, Any]:
    """Validate the ``above``/``below`` bounds of a numeric_state observation.

    Configs for any other platform are returned untouched. For numeric_state
    configs at least one bound must be present, and when both are present
    ``above`` must not be greater than ``below`` (equal bounds are accepted).
    Each failure is logged and then reported by raising ``vol.Invalid``.
    """
    if config[CONF_PLATFORM] != CONF_NUMERIC_STATE:
        return config

    lower = config.get(CONF_ABOVE)
    upper = config.get(CONF_BELOW)

    if lower is None and upper is None:
        _LOGGER.error(
            "For bayesian numeric state for entity: %s at least one of 'above' or 'below' must be specified",
            config[CONF_ENTITY_ID],
        )
        raise vol.Invalid(
            "For bayesian numeric state at least one of 'above' or 'below' must be specified."
        )

    # Only a strictly inverted range is rejected; a single bound needs no check.
    if lower is not None and upper is not None and lower > upper:
        _LOGGER.error(
            "For bayesian numeric state 'above' (%s) must be less than 'below' (%s)",
            lower,
            upper,
        )
        raise vol.Invalid("'above' is greater than 'below'")

    return config
92 
93 
# Schema for one ``numeric_state`` observation. The unmarked ``platform`` key is a
# literal that must equal CONF_NUMERIC_STATE; ``required=True`` makes voluptuous
# treat unmarked keys as required. Bounds and probabilities are coerced to float,
# then ``_above_greater_than_below`` checks that at least one bound is given and
# that ``above`` does not exceed ``below``.
NUMERIC_STATE_SCHEMA = vol.All(
    vol.Schema(
        {
            CONF_PLATFORM: CONF_NUMERIC_STATE,
            vol.Required(CONF_ENTITY_ID): cv.entity_id,
            vol.Optional(CONF_ABOVE): vol.Coerce(float),
            vol.Optional(CONF_BELOW): vol.Coerce(float),
            vol.Required(CONF_P_GIVEN_T): vol.Coerce(float),
            # Still optional here; observations without it are reported and
            # skipped in async_setup_platform.
            vol.Optional(CONF_P_GIVEN_F): vol.Coerce(float),
        },
        required=True,
    ),
    _above_greater_than_below,
)
108 
109 
def _no_overlapping(configs: list[dict]) -> list[dict]:
    """Reject numeric_state observations whose ranges overlap for one entity.

    Only observations whose platform is numeric_state are checked. A missing
    ``above`` is treated as -inf and a missing ``below`` as +inf. The ranges of
    each entity are sorted by ``above``. Two neighbouring ranges overlap when the
    first one's ``below`` is strictly greater than the next one's ``above``;
    ranges that merely touch are allowed.

    Returns ``configs`` unchanged when valid.

    Raises:
        vol.Invalid: if any entity has overlapping ranges.
    """
    numeric_configs = [
        config for config in configs if config[CONF_PLATFORM] == CONF_NUMERIC_STATE
    ]
    # Fewer than two numeric observations cannot overlap.
    if len(numeric_configs) < 2:
        return configs

    class NumericConfig(NamedTuple):
        above: float
        below: float

    intervals_by_entity: dict[str, list[NumericConfig]] = {}
    for config in numeric_configs:
        interval = NumericConfig(
            config.get(CONF_ABOVE, -math.inf), config.get(CONF_BELOW, math.inf)
        )
        intervals_by_entity.setdefault(str(config[CONF_ENTITY_ID]), []).append(
            interval
        )

    for ent_id, unsorted_intervals in intervals_by_entity.items():
        # Stable sort on the lower bound only; ties keep their config order.
        intervals = sorted(unsorted_intervals, key=lambda tup: tup.above)
        for current, following in zip(intervals, intervals[1:]):
            if current.below > following.above:
                raise vol.Invalid(
                    f"Ranges for bayesian numeric state entities must not overlap, but {ent_id} has overlapping ranges, above:{current.above}, below:{current.below} overlaps with above:{following.above}, below:{following.below}."
                )
    return configs
137 
138 
# Schema for one ``state`` observation. The entity's state is compared against
# ``to_state`` at runtime. ``required=True`` makes the unmarked ``platform`` key
# mandatory, and it must equal CONF_STATE.
STATE_SCHEMA = vol.Schema(
    {
        CONF_PLATFORM: CONF_STATE,
        vol.Required(CONF_ENTITY_ID): cv.entity_id,
        vol.Required(CONF_TO_STATE): cv.string,
        vol.Required(CONF_P_GIVEN_T): vol.Coerce(float),
        vol.Optional(CONF_P_GIVEN_F): vol.Coerce(float),
    },
    required=True,
)

# Schema for one ``template`` observation. Its result is turned into a boolean at
# runtime; ``platform`` must equal CONF_TEMPLATE.
TEMPLATE_SCHEMA = vol.Schema(
    {
        CONF_PLATFORM: CONF_TEMPLATE,
        vol.Required(CONF_VALUE_TEMPLATE): cv.template,
        vol.Required(CONF_P_GIVEN_T): vol.Coerce(float),
        vol.Optional(CONF_P_GIVEN_F): vol.Coerce(float),
    },
    required=True,
)

# YAML platform schema for a Bayesian binary sensor, extending the generic
# binary_sensor platform schema.
PLATFORM_SCHEMA = BINARY_SENSOR_PLATFORM_SCHEMA.extend(
    {
        vol.Optional(CONF_NAME, default=DEFAULT_NAME): cv.string,
        vol.Optional(CONF_UNIQUE_ID): cv.string,
        vol.Optional(CONF_DEVICE_CLASS): cv.string,
        # Every observation must match one of the three observation schemas.
        # The whole list is then checked for overlapping numeric ranges.
        vol.Required(CONF_OBSERVATIONS): vol.Schema(
            vol.All(
                cv.ensure_list,
                [vol.Any(TEMPLATE_SCHEMA, STATE_SCHEMA, NUMERIC_STATE_SCHEMA)],
                _no_overlapping,
            )
        ),
        vol.Required(CONF_PRIOR): vol.Coerce(float),
        vol.Optional(
            CONF_PROBABILITY_THRESHOLD, default=DEFAULT_PROBABILITY_THRESHOLD
        ): vol.Coerce(float),
    }
)
178 
179 
def update_probability(
    prior: float, prob_given_true: float, prob_given_false: float
) -> float:
    """Update probability using Bayes' rule.

    Returns the posterior P(H|E) = P(E|H)·P(H) / (P(E|H)·P(H) + P(E|¬H)·P(¬H)).
    Here ``prior`` is P(H), ``prob_given_true`` is P(E|H) and
    ``prob_given_false`` is P(E|¬H).

    When the evidence has zero probability under both hypotheses, the
    denominator is 0 and the posterior is undefined. Examples are a prior of 1
    with ``prob_given_true`` of 0, or a prior of 0 with ``prob_given_false`` of 0.
    In that case the prior is returned unchanged instead of raising
    ZeroDivisionError. This way a single impossible observation cannot abort
    the probability recalculation.
    """
    numerator = prob_given_true * prior
    denominator = numerator + prob_given_false * (1 - prior)
    if denominator == 0:
        return prior
    return numerator / denominator
187 
188 
async def async_setup_platform(
    hass: HomeAssistant,
    config: ConfigType,
    async_add_entities: AddEntitiesCallback,
    discovery_info: DiscoveryInfoType | None = None,
) -> None:
    """Set up the Bayesian Binary sensor.

    Registers the reload service and drops any observation that lacks
    ``prob_given_false``, reporting it via a repair issue and an error log.
    Then adds one BayesianBinarySensor built from the remaining configuration.
    """
    await async_setup_reload_service(hass, DOMAIN, PLATFORMS)

    name: str = config[CONF_NAME]
    unique_id: str | None = config.get(CONF_UNIQUE_ID)
    prior: float = config[CONF_PRIOR]
    probability_threshold: float = config[CONF_PROBABILITY_THRESHOLD]
    device_class: BinarySensorDeviceClass | None = config.get(CONF_DEVICE_CLASS)

    # Should deprecate in some future version (2022.10 at time of writing) & make
    # prob_given_false required in schemas. Until then, partition the observations
    # in a single order-preserving pass. The previous `x not in broken` filter
    # compared dicts pairwise and was quadratic.
    observations: list[ConfigType] = []
    for observation in config[CONF_OBSERVATIONS]:
        if CONF_P_GIVEN_F in observation:
            observations.append(observation)
            continue
        text: str = f"{name}/{observation.get(CONF_ENTITY_ID,'')}{observation.get(CONF_VALUE_TEMPLATE,'')}"
        raise_no_prob_given_false(hass, text)
        _LOGGER.error("Missing prob_given_false YAML entry for %s", text)

    async_add_entities(
        [
            BayesianBinarySensor(
                name,
                unique_id,
                prior,
                observations,
                probability_threshold,
                device_class,
            )
        ]
    )
227 
228 
class BayesianBinarySensor(BinarySensorEntity):
    """Representation of a Bayesian sensor.

    The probability starts at the configured prior. It is updated with Bayes'
    rule for every current observation that is True or False. The sensor is on
    while the resulting probability is at or above the configured threshold.
    """

    # State is pushed from state/template listeners rather than polled.
    _attr_should_poll = False
233 
234  def __init__(
235  self,
236  name: str,
237  unique_id: str | None,
238  prior: float,
239  observations: list[ConfigType],
240  probability_threshold: float,
241  device_class: BinarySensorDeviceClass | None,
242  ) -> None:
243  """Initialize the Bayesian sensor."""
244  self._attr_name_attr_name = name
245  self._attr_unique_id_attr_unique_id = unique_id and f"bayesian-{unique_id}"
246  self._observations_observations = [
247  Observation(
248  entity_id=observation.get(CONF_ENTITY_ID),
249  platform=observation[CONF_PLATFORM],
250  prob_given_false=observation[CONF_P_GIVEN_F],
251  prob_given_true=observation[CONF_P_GIVEN_T],
252  observed=None,
253  to_state=observation.get(CONF_TO_STATE),
254  above=observation.get(CONF_ABOVE),
255  below=observation.get(CONF_BELOW),
256  value_template=observation.get(CONF_VALUE_TEMPLATE),
257  )
258  for observation in observations
259  ]
260  self._probability_threshold_probability_threshold = probability_threshold
261  self._attr_device_class_attr_device_class = device_class
262  self._attr_is_on_attr_is_on = False
263  self._callbacks: list[TrackTemplateResultInfo] = []
264 
265  self.priorprior = prior
266  self.probabilityprobability = prior
267 
268  self.current_observations: OrderedDict[UUID, Observation] = OrderedDict({})
269 
270  self.observations_by_entityobservations_by_entity = self._build_observations_by_entity_build_observations_by_entity()
271  self.observations_by_templateobservations_by_template = self._build_observations_by_template_build_observations_by_template()
272 
273  self.observation_handlers: dict[
274  str, Callable[[Observation, bool], bool | None]
275  ] = {
276  "numeric_state": self._process_numeric_state_process_numeric_state,
277  "state": self._process_state_process_state,
278  }
279 
280  async def async_added_to_hass(self) -> None:
281  """Call when entity about to be added.
282 
283  All relevant update logic for instance attributes occurs within this closure.
284  Other methods in this class are designed to avoid directly modifying instance
285  attributes, by instead focusing on returning relevant data back to this method.
286 
287  The goal of this method is to ensure that `self.current_observations` and `self.probability`
288  are set on a best-effort basis when this entity is register with hass.
289 
290  In addition, this method must register the state listener defined within, which
291  will be called any time a relevant entity changes its state.
292  """
293 
294  @callback
295  def async_threshold_sensor_state_listener(
296  event: Event[EventStateChangedData],
297  ) -> None:
298  """Handle sensor state changes.
299 
300  When a state changes, we must update our list of current observations,
301  then calculate the new probability.
302  """
303 
304  entity_id = event.data["entity_id"]
305 
306  self.current_observations.update(
307  self._record_entity_observations_record_entity_observations(entity_id)
308  )
309  self.async_set_contextasync_set_context(event.context)
310  self._recalculate_and_write_state_recalculate_and_write_state()
311 
312  self.async_on_removeasync_on_remove(
314  self.hasshass,
315  list(self.observations_by_entityobservations_by_entity),
316  async_threshold_sensor_state_listener,
317  )
318  )
319 
320  @callback
321  def _async_template_result_changed(
322  event: Event[EventStateChangedData] | None,
323  updates: list[TrackTemplateResult],
324  ) -> None:
325  track_template_result = updates.pop()
326  template = track_template_result.template
327  result = track_template_result.result
328  entity_id = None if event is None else event.data["entity_id"]
329  if isinstance(result, TemplateError):
330  _LOGGER.error(
331  "TemplateError('%s') while processing template '%s' in entity '%s'",
332  result,
333  template,
334  self.entity_identity_id,
335  )
336 
337  observed = None
338  else:
339  observed = result_as_boolean(result)
340 
341  for observation in self.observations_by_templateobservations_by_template[template]:
342  observation.observed = observed
343 
344  # in some cases a template may update because of the absence of an entity
345  if entity_id is not None:
346  observation.entity_id = entity_id
347 
348  self.current_observations[observation.id] = observation
349 
350  if event:
351  self.async_set_contextasync_set_context(event.context)
352  self._recalculate_and_write_state_recalculate_and_write_state()
353 
354  for template in self.observations_by_templateobservations_by_template:
356  self.hasshass,
357  [TrackTemplate(template, None)],
358  _async_template_result_changed,
359  )
360 
361  self._callbacks.append(info)
362  self.async_on_removeasync_on_remove(info.async_remove)
363  info.async_refresh()
364 
365  self.current_observations.update(self._initialize_current_observations_initialize_current_observations())
366  self.probabilityprobability = self._calculate_new_probability_calculate_new_probability()
367  self._attr_is_on_attr_is_on = self.probabilityprobability >= self._probability_threshold_probability_threshold
368 
369  # detect mirrored entries
370  for entity, observations in self.observations_by_entityobservations_by_entity.items():
372  self.hasshass, observations, text=f"{self._attr_name}/{entity}"
373  )
374 
375  all_template_observations: list[Observation] = [
376  observations[0] for observations in self.observations_by_templateobservations_by_template.values()
377  ]
378  if len(all_template_observations) == 2:
380  self.hasshass,
381  all_template_observations,
382  text=f"{self._attr_name}/{all_template_observations[0].value_template}",
383  )
384 
385  @callback
386  def _recalculate_and_write_state(self) -> None:
387  self.probabilityprobability = self._calculate_new_probability_calculate_new_probability()
388  self._attr_is_on_attr_is_on = bool(self.probabilityprobability >= self._probability_threshold_probability_threshold)
389  self.async_write_ha_stateasync_write_ha_state()
390 
391  def _initialize_current_observations(self) -> OrderedDict[UUID, Observation]:
392  local_observations: OrderedDict[UUID, Observation] = OrderedDict({})
393  for entity in self.observations_by_entityobservations_by_entity:
394  local_observations.update(self._record_entity_observations_record_entity_observations(entity))
395  return local_observations
396 
398  self, entity: str
399  ) -> OrderedDict[UUID, Observation]:
400  local_observations: OrderedDict[UUID, Observation] = OrderedDict({})
401 
402  for observation in self.observations_by_entityobservations_by_entity[entity]:
403  platform = observation.platform
404 
405  observation.observed = self.observation_handlers[platform](
406  observation, observation.multi
407  )
408  local_observations[observation.id] = observation
409 
410  return local_observations
411 
412  def _calculate_new_probability(self) -> float:
413  prior = self.priorprior
414 
415  for observation in self.current_observations.values():
416  if observation.observed is True:
417  prior = update_probability(
418  prior,
419  observation.prob_given_true,
420  observation.prob_given_false,
421  )
422  continue
423  if observation.observed is False:
424  prior = update_probability(
425  prior,
426  1 - observation.prob_given_true,
427  1 - observation.prob_given_false,
428  )
429  continue
430  # observation.observed is None
431  if observation.entity_id is not None:
432  _LOGGER.debug(
433  (
434  "Observation for entity '%s' returned None, it will not be used"
435  " for Bayesian updating"
436  ),
437  observation.entity_id,
438  )
439  continue
440  _LOGGER.debug(
441  (
442  "Observation for template entity returned None rather than a valid"
443  " boolean, it will not be used for Bayesian updating"
444  ),
445  )
446  # the prior has been updated and is now the posterior
447  return prior
448 
449  def _build_observations_by_entity(self) -> dict[str, list[Observation]]:
450  """Build and return data structure of the form below.
451 
452  {
453  "sensor.sensor1": [Observation, Observation],
454  "sensor.sensor2": [Observation],
455  ...
456  }
457 
458  Each "observation" must be recognized uniquely, and it should be possible
459  for all relevant observations to be looked up via their `entity_id`.
460  """
461 
462  observations_by_entity: dict[str, list[Observation]] = {}
463  for observation in self._observations_observations:
464  if (key := observation.entity_id) is None:
465  continue
466  observations_by_entity.setdefault(key, []).append(observation)
467 
468  for entity_observations in observations_by_entity.values():
469  if len(entity_observations) == 1:
470  continue
471  for observation in entity_observations:
472  observation.multi = True
473 
474  return observations_by_entity
475 
476  def _build_observations_by_template(self) -> dict[Template, list[Observation]]:
477  """Build and return data structure of the form below.
478 
479  {
480  "template": [Observation, Observation],
481  "template2": [Observation],
482  ...
483  }
484 
485  Each "observation" must be recognized uniquely, and it should be possible
486  for all relevant observations to be looked up via their `template`.
487  """
488 
489  observations_by_template: dict[Template, list[Observation]] = {}
490  for observation in self._observations_observations:
491  if observation.value_template is None:
492  continue
493 
494  template = observation.value_template
495  observations_by_template.setdefault(template, []).append(observation)
496 
497  return observations_by_template
498 
500  self, entity_observation: Observation, multi: bool = False
501  ) -> bool | None:
502  """Return True if numeric condition is met, return False if not, return None otherwise."""
503  entity_id = entity_observation.entity_id
504  # if we are dealing with numeric_state observations entity_id cannot be None
505  if TYPE_CHECKING:
506  assert entity_id is not None
507 
508  entity = self.hasshass.states.get(entity_id)
509  if entity is None:
510  return None
511 
512  try:
513  if condition.state(self.hasshass, entity, [STATE_UNKNOWN, STATE_UNAVAILABLE]):
514  return None
515  result = condition.async_numeric_state(
516  self.hasshass,
517  entity,
518  entity_observation.below,
519  entity_observation.above,
520  None,
521  entity_observation.to_dict(),
522  )
523  if result:
524  return True
525  if multi:
526  state = float(entity.state)
527  if (
528  entity_observation.below is not None
529  and state == entity_observation.below
530  ):
531  return True
532  return None
533  except ConditionError:
534  return None
535  else:
536  return False
537 
539  self, entity_observation: Observation, multi: bool = False
540  ) -> bool | None:
541  """Return True if state conditions are met, return False if they are not.
542 
543  Returns None if the state is unavailable.
544  """
545 
546  entity = entity_observation.entity_id
547 
548  try:
549  if condition.state(self.hasshass, entity, [STATE_UNKNOWN, STATE_UNAVAILABLE]):
550  return None
551 
552  result = condition.state(self.hasshass, entity, entity_observation.to_state)
553  if multi and not result:
554  return None
555  except ConditionError:
556  return None
557  else:
558  return result
559 
560  @property
561  def extra_state_attributes(self) -> dict[str, Any]:
562  """Return the state attributes of the sensor."""
563 
564  return {
565  ATTR_PROBABILITY: round(self.probabilityprobability, 2),
566  ATTR_PROBABILITY_THRESHOLD: self._probability_threshold_probability_threshold,
567  # An entity can be in more than one observation so set then list to deduplicate
568  ATTR_OCCURRED_OBSERVATION_ENTITIES: list(
569  {
570  observation.entity_id
571  for observation in self.current_observations.values()
572  if observation is not None
573  and observation.entity_id is not None
574  and observation.observed is not None
575  }
576  ),
577  ATTR_OBSERVATIONS: [
578  observation.to_dict()
579  for observation in self.current_observations.values()
580  if observation is not None
581  ],
582  }
583 
584  async def async_update(self) -> None:
585  """Get the latest data and update the states."""
586  if not self._callbacks:
587  self._recalculate_and_write_state_recalculate_and_write_state()
588  return
589  # Force recalc of the templates. The states will
590  # update automatically.
591  for call in self._callbacks:
592  call.async_refresh()
dict[Template, list[Observation]] _build_observations_by_template(self)
None __init__(self, str name, str|None unique_id, float prior, list[ConfigType] observations, float probability_threshold, BinarySensorDeviceClass|None device_class)
OrderedDict[UUID, Observation] _initialize_current_observations(self)
bool|None _process_numeric_state(self, Observation entity_observation, bool multi=False)
bool|None _process_state(self, Observation entity_observation, bool multi=False)
OrderedDict[UUID, Observation] _record_entity_observations(self, str entity)
None async_on_remove(self, CALLBACK_TYPE func)
Definition: entity.py:1331
None async_set_context(self, Context context)
Definition: entity.py:937
float update_probability(float prior, float prob_given_true, float prob_given_false)
list[dict] _no_overlapping(list[dict] configs)
None async_setup_platform(HomeAssistant hass, ConfigType config, AddEntitiesCallback async_add_entities, DiscoveryInfoType|None discovery_info=None)
dict[str, Any] _above_greater_than_below(dict[str, Any] config)
None raise_mirrored_entries(HomeAssistant hass, list[Observation] observations, str text="")
Definition: issues.py:14
None raise_no_prob_given_false(HomeAssistant hass, str text)
Definition: issues.py:33
IssData update(pyiss.ISS iss)
Definition: __init__.py:33
CALLBACK_TYPE async_track_state_change_event(HomeAssistant hass, str|Iterable[str] entity_ids, Callable[[Event[EventStateChangedData]], Any] action, HassJobType|None job_type=None)
Definition: event.py:314
TrackTemplateResultInfo async_track_template_result(HomeAssistant hass, Sequence[TrackTemplate] track_templates, TrackTemplateResultListener action, bool strict=False, Callable[[int, str], None]|None log_fn=None, bool has_super_template=False)
Definition: event.py:1345
None async_setup_reload_service(HomeAssistant hass, str domain, Iterable[str] platforms)
Definition: reload.py:191
bool result_as_boolean(Any|None template_result)
Definition: template.py:1277