1 """Use Bayesian Inference to trigger a binary sensor."""
3 from __future__
import annotations
5 from collections
import OrderedDict
6 from collections.abc
import Callable
9 from typing
import TYPE_CHECKING, Any, NamedTuple
12 import voluptuous
as vol
15 PLATFORM_SCHEMA
as BINARY_SENSOR_PLATFORM_SCHEMA,
16 BinarySensorDeviceClass,
40 TrackTemplateResultInfo,
41 async_track_state_change_event,
42 async_track_template_result,
48 from .
import DOMAIN, PLATFORMS
51 ATTR_OCCURRED_OBSERVATION_ENTITIES,
53 ATTR_PROBABILITY_THRESHOLD,
59 CONF_PROBABILITY_THRESHOLD,
63 DEFAULT_PROBABILITY_THRESHOLD,
65 from .helpers
import Observation
66 from .issues
import raise_mirrored_entries, raise_no_prob_given_false
68 _LOGGER = logging.getLogger(__name__)
72 if config[CONF_PLATFORM] == CONF_NUMERIC_STATE:
73 above = config.get(CONF_ABOVE)
74 below = config.get(CONF_BELOW)
75 if above
is None and below
is None:
77 "For bayesian numeric state for entity: %s at least one of 'above' or 'below' must be specified",
78 config[CONF_ENTITY_ID],
81 "For bayesian numeric state at least one of 'above' or 'below' must be specified."
83 if above
is not None and below
is not None:
86 "For bayesian numeric state 'above' (%s) must be less than 'below' (%s)",
90 raise vol.Invalid(
"'above' is greater than 'below'")
94 NUMERIC_STATE_SCHEMA = vol.All(
97 CONF_PLATFORM: CONF_NUMERIC_STATE,
98 vol.Required(CONF_ENTITY_ID): cv.entity_id,
99 vol.Optional(CONF_ABOVE): vol.Coerce(float),
100 vol.Optional(CONF_BELOW): vol.Coerce(float),
101 vol.Required(CONF_P_GIVEN_T): vol.Coerce(float),
102 vol.Optional(CONF_P_GIVEN_F): vol.Coerce(float),
106 _above_greater_than_below,
112 config
for config
in configs
if config[CONF_PLATFORM] == CONF_NUMERIC_STATE
114 if len(numeric_configs) < 2:
117 class NumericConfig(NamedTuple):
121 d: dict[str, list[NumericConfig]] = {}
122 for _, config
in enumerate(numeric_configs):
123 above = config.get(CONF_ABOVE, -math.inf)
124 below = config.get(CONF_BELOW, math.inf)
125 entity_id: str =
str(config[CONF_ENTITY_ID])
126 d.setdefault(entity_id, []).append(NumericConfig(above, below))
128 for ent_id, intervals
in d.items():
129 intervals = sorted(intervals, key=
lambda tup: tup.above)
131 for i, tup
in enumerate(intervals):
132 if len(intervals) > i + 1
and tup.below > intervals[i + 1].above:
134 f
"Ranges for bayesian numeric state entities must not overlap, but {ent_id} has overlapping ranges, above:{tup.above}, below:{tup.below} overlaps with above:{intervals[i+1].above}, below:{intervals[i+1].below}."
139 STATE_SCHEMA = vol.Schema(
141 CONF_PLATFORM: CONF_STATE,
142 vol.Required(CONF_ENTITY_ID): cv.entity_id,
143 vol.Required(CONF_TO_STATE): cv.string,
144 vol.Required(CONF_P_GIVEN_T): vol.Coerce(float),
145 vol.Optional(CONF_P_GIVEN_F): vol.Coerce(float),
150 TEMPLATE_SCHEMA = vol.Schema(
152 CONF_PLATFORM: CONF_TEMPLATE,
153 vol.Required(CONF_VALUE_TEMPLATE): cv.template,
154 vol.Required(CONF_P_GIVEN_T): vol.Coerce(float),
155 vol.Optional(CONF_P_GIVEN_F): vol.Coerce(float),
160 PLATFORM_SCHEMA = BINARY_SENSOR_PLATFORM_SCHEMA.extend(
162 vol.Optional(CONF_NAME, default=DEFAULT_NAME): cv.string,
163 vol.Optional(CONF_UNIQUE_ID): cv.string,
164 vol.Optional(CONF_DEVICE_CLASS): cv.string,
165 vol.Required(CONF_OBSERVATIONS): vol.Schema(
168 [vol.Any(TEMPLATE_SCHEMA, STATE_SCHEMA, NUMERIC_STATE_SCHEMA)],
172 vol.Required(CONF_PRIOR): vol.Coerce(float),
174 CONF_PROBABILITY_THRESHOLD, default=DEFAULT_PROBABILITY_THRESHOLD
175 ): vol.Coerce(float),
181 prior: float, prob_given_true: float, prob_given_false: float
183 """Update probability using Bayes' rule."""
184 numerator = prob_given_true * prior
185 denominator = numerator + prob_given_false * (1 - prior)
186 return numerator / denominator
192 async_add_entities: AddEntitiesCallback,
193 discovery_info: DiscoveryInfoType |
None =
None,
195 """Set up the Bayesian Binary sensor."""
198 name: str = config[CONF_NAME]
199 unique_id: str |
None = config.get(CONF_UNIQUE_ID)
200 observations: list[ConfigType] = config[CONF_OBSERVATIONS]
201 prior: float = config[CONF_PRIOR]
202 probability_threshold: float = config[CONF_PROBABILITY_THRESHOLD]
203 device_class: BinarySensorDeviceClass |
None = config.get(CONF_DEVICE_CLASS)
206 broken_observations: list[dict[str, Any]] = []
207 for observation
in observations:
208 if CONF_P_GIVEN_F
not in observation:
209 text: str = f
"{name}/{observation.get(CONF_ENTITY_ID,'')}{observation.get(CONF_VALUE_TEMPLATE,'')}"
211 _LOGGER.error(
"Missing prob_given_false YAML entry for %s", text)
212 broken_observations.append(observation)
213 observations = [x
for x
in observations
if x
not in broken_observations]
222 probability_threshold,
230 """Representation of a Bayesian sensor."""
232 _attr_should_poll =
False
237 unique_id: str |
None,
239 observations: list[ConfigType],
240 probability_threshold: float,
241 device_class: BinarySensorDeviceClass |
None,
243 """Initialize the Bayesian sensor."""
248 entity_id=observation.get(CONF_ENTITY_ID),
249 platform=observation[CONF_PLATFORM],
250 prob_given_false=observation[CONF_P_GIVEN_F],
251 prob_given_true=observation[CONF_P_GIVEN_T],
253 to_state=observation.get(CONF_TO_STATE),
254 above=observation.get(CONF_ABOVE),
255 below=observation.get(CONF_BELOW),
256 value_template=observation.get(CONF_VALUE_TEMPLATE),
258 for observation
in observations
263 self._callbacks: list[TrackTemplateResultInfo] = []
268 self.current_observations: OrderedDict[UUID, Observation] = OrderedDict({})
273 self.observation_handlers: dict[
274 str, Callable[[Observation, bool], bool |
None]
281 """Call when entity about to be added.
283 All relevant update logic for instance attributes occurs within this closure.
284 Other methods in this class are designed to avoid directly modifying instance
285 attributes, by instead focusing on returning relevant data back to this method.
287 The goal of this method is to ensure that `self.current_observations` and `self.probability`
288 are set on a best-effort basis when this entity is register with hass.
290 In addition, this method must register the state listener defined within, which
291 will be called any time a relevant entity changes its state.
295 def async_threshold_sensor_state_listener(
296 event: Event[EventStateChangedData],
298 """Handle sensor state changes.
300 When a state changes, we must update our list of current observations,
301 then calculate the new probability.
304 entity_id = event.data[
"entity_id"]
306 self.current_observations.
update(
316 async_threshold_sensor_state_listener,
321 def _async_template_result_changed(
322 event: Event[EventStateChangedData] |
None,
323 updates: list[TrackTemplateResult],
325 track_template_result = updates.pop()
326 template = track_template_result.template
327 result = track_template_result.result
328 entity_id =
None if event
is None else event.data[
"entity_id"]
329 if isinstance(result, TemplateError):
331 "TemplateError('%s') while processing template '%s' in entity '%s'",
342 observation.observed = observed
345 if entity_id
is not None:
346 observation.entity_id = entity_id
348 self.current_observations[observation.id] = observation
358 _async_template_result_changed,
361 self._callbacks.append(info)
372 self.
hasshass, observations, text=f
"{self._attr_name}/{entity}"
375 all_template_observations: list[Observation] = [
378 if len(all_template_observations) == 2:
381 all_template_observations,
382 text=f
"{self._attr_name}/{all_template_observations[0].value_template}",
392 local_observations: OrderedDict[UUID, Observation] = OrderedDict({})
395 return local_observations
399 ) -> OrderedDict[UUID, Observation]:
400 local_observations: OrderedDict[UUID, Observation] = OrderedDict({})
403 platform = observation.platform
405 observation.observed = self.observation_handlers[platform](
406 observation, observation.multi
408 local_observations[observation.id] = observation
410 return local_observations
413 prior = self.
priorprior
415 for observation
in self.current_observations.values():
416 if observation.observed
is True:
419 observation.prob_given_true,
420 observation.prob_given_false,
423 if observation.observed
is False:
426 1 - observation.prob_given_true,
427 1 - observation.prob_given_false,
431 if observation.entity_id
is not None:
434 "Observation for entity '%s' returned None, it will not be used"
435 " for Bayesian updating"
437 observation.entity_id,
442 "Observation for template entity returned None rather than a valid"
443 " boolean, it will not be used for Bayesian updating"
450 """Build and return data structure of the form below.
453 "sensor.sensor1": [Observation, Observation],
454 "sensor.sensor2": [Observation],
458 Each "observation" must be recognized uniquely, and it should be possible
459 for all relevant observations to be looked up via their `entity_id`.
462 observations_by_entity: dict[str, list[Observation]] = {}
464 if (key := observation.entity_id)
is None:
466 observations_by_entity.setdefault(key, []).append(observation)
468 for entity_observations
in observations_by_entity.values():
469 if len(entity_observations) == 1:
471 for observation
in entity_observations:
472 observation.multi =
True
474 return observations_by_entity
477 """Build and return data structure of the form below.
480 "template": [Observation, Observation],
481 "template2": [Observation],
485 Each "observation" must be recognized uniquely, and it should be possible
486 for all relevant observations to be looked up via their `template`.
489 observations_by_template: dict[Template, list[Observation]] = {}
491 if observation.value_template
is None:
494 template = observation.value_template
495 observations_by_template.setdefault(template, []).append(observation)
497 return observations_by_template
500 self, entity_observation: Observation, multi: bool =
False
502 """Return True if numeric condition is met, return False if not, return None otherwise."""
503 entity_id = entity_observation.entity_id
506 assert entity_id
is not None
508 entity = self.
hasshass.states.get(entity_id)
513 if condition.state(self.
hasshass, entity, [STATE_UNKNOWN, STATE_UNAVAILABLE]):
515 result = condition.async_numeric_state(
518 entity_observation.below,
519 entity_observation.above,
521 entity_observation.to_dict(),
526 state =
float(entity.state)
528 entity_observation.below
is not None
529 and state == entity_observation.below
533 except ConditionError:
539 self, entity_observation: Observation, multi: bool =
False
541 """Return True if state conditions are met, return False if they are not.
543 Returns None if the state is unavailable.
546 entity = entity_observation.entity_id
549 if condition.state(self.
hasshass, entity, [STATE_UNKNOWN, STATE_UNAVAILABLE]):
552 result = condition.state(self.
hasshass, entity, entity_observation.to_state)
553 if multi
and not result:
555 except ConditionError:
562 """Return the state attributes of the sensor."""
565 ATTR_PROBABILITY: round(self.
probabilityprobability, 2),
568 ATTR_OCCURRED_OBSERVATION_ENTITIES:
list(
570 observation.entity_id
571 for observation
in self.current_observations.values()
572 if observation
is not None
573 and observation.entity_id
is not None
574 and observation.observed
is not None
578 observation.to_dict()
579 for observation
in self.current_observations.values()
580 if observation
is not None
585 """Get the latest data and update the states."""
586 if not self._callbacks:
591 for call
in self._callbacks:
dict[Template, list[Observation]] _build_observations_by_template(self)
None __init__(self, str name, str|None unique_id, float prior, list[ConfigType] observations, float probability_threshold, BinarySensorDeviceClass|None device_class)
float _calculate_new_probability(self)
OrderedDict[UUID, Observation] _initialize_current_observations(self)
bool|None _process_numeric_state(self, Observation entity_observation, bool multi=False)
dict[str, list[Observation]] _build_observations_by_entity(self)
None async_added_to_hass(self)
dict[str, Any] extra_state_attributes(self)
bool|None _process_state(self, Observation entity_observation, bool multi=False)
OrderedDict[UUID, Observation] _record_entity_observations(self, str entity)
None _recalculate_and_write_state(self)
None async_write_ha_state(self)
None async_on_remove(self, CALLBACK_TYPE func)
None async_set_context(self, Context context)
float update_probability(float prior, float prob_given_true, float prob_given_false)
list[dict] _no_overlapping(list[dict] configs)
None async_setup_platform(HomeAssistant hass, ConfigType config, AddEntitiesCallback async_add_entities, DiscoveryInfoType|None discovery_info=None)
dict[str, Any] _above_greater_than_below(dict[str, Any] config)
None raise_mirrored_entries(HomeAssistant hass, list[Observation] observations, str text="")
None raise_no_prob_given_false(HomeAssistant hass, str text)
IssData update(pyiss.ISS iss)
CALLBACK_TYPE async_track_state_change_event(HomeAssistant hass, str|Iterable[str] entity_ids, Callable[[Event[EventStateChangedData]], Any] action, HassJobType|None job_type=None)
TrackTemplateResultInfo async_track_template_result(HomeAssistant hass, Sequence[TrackTemplate] track_templates, TrackTemplateResultListener action, bool strict=False, Callable[[int, str], None]|None log_fn=None, bool has_super_template=False)
None async_setup_reload_service(HomeAssistant hass, str domain, Iterable[str] platforms)
bool result_as_boolean(Any|None template_result)