# Home Assistant Unofficial Reference 2024.12.1
# image_processing.py
# (Doc-site header retained as a comment; not part of the original module.)
1 """Support for performing TensorFlow classification on images."""
2 
from __future__ import annotations

import io
import logging
import os
import sys
import time

import numpy as np
from PIL import Image, ImageDraw, UnidentifiedImageError
import tensorflow as tf
import voluptuous as vol

from homeassistant.components.image_processing import (
    CONF_CONFIDENCE,
    PLATFORM_SCHEMA as IMAGE_PROCESSING_PLATFORM_SCHEMA,
    ImageProcessingEntity,
)
from homeassistant.const import (
    CONF_ENTITY_ID,
    CONF_MODEL,
    CONF_NAME,
    CONF_SOURCE,
    EVENT_HOMEASSISTANT_START,
)
from homeassistant.core import HomeAssistant, split_entity_id
from homeassistant.helpers import config_validation as cv, template
from homeassistant.helpers.entity_platform import AddEntitiesCallback
from homeassistant.helpers.typing import ConfigType, DiscoveryInfoType
from homeassistant.util.pil import draw_box
34 
35 os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2"
36 
37 DOMAIN = "tensorflow"
38 _LOGGER = logging.getLogger(__name__)
39 
40 ATTR_MATCHES = "matches"
41 ATTR_SUMMARY = "summary"
42 ATTR_TOTAL_MATCHES = "total_matches"
43 ATTR_PROCESS_TIME = "process_time"
44 
45 CONF_AREA = "area"
46 CONF_BOTTOM = "bottom"
47 CONF_CATEGORIES = "categories"
48 CONF_CATEGORY = "category"
49 CONF_FILE_OUT = "file_out"
50 CONF_GRAPH = "graph"
51 CONF_LABELS = "labels"
52 CONF_LABEL_OFFSET = "label_offset"
53 CONF_LEFT = "left"
54 CONF_MODEL_DIR = "model_dir"
55 CONF_RIGHT = "right"
56 CONF_TOP = "top"
57 
58 AREA_SCHEMA = vol.Schema(
59  {
60  vol.Optional(CONF_BOTTOM, default=1): cv.small_float,
61  vol.Optional(CONF_LEFT, default=0): cv.small_float,
62  vol.Optional(CONF_RIGHT, default=1): cv.small_float,
63  vol.Optional(CONF_TOP, default=0): cv.small_float,
64  }
65 )
66 
67 CATEGORY_SCHEMA = vol.Schema(
68  {vol.Required(CONF_CATEGORY): cv.string, vol.Optional(CONF_AREA): AREA_SCHEMA}
69 )
70 
71 PLATFORM_SCHEMA = IMAGE_PROCESSING_PLATFORM_SCHEMA.extend(
72  {
73  vol.Optional(CONF_FILE_OUT, default=[]): vol.All(cv.ensure_list, [cv.template]),
74  vol.Required(CONF_MODEL): vol.Schema(
75  {
76  vol.Required(CONF_GRAPH): cv.isdir,
77  vol.Optional(CONF_AREA): AREA_SCHEMA,
78  vol.Optional(CONF_CATEGORIES, default=[]): vol.All(
79  cv.ensure_list, [vol.Any(cv.string, CATEGORY_SCHEMA)]
80  ),
81  vol.Optional(CONF_LABELS): cv.isfile,
82  vol.Optional(CONF_LABEL_OFFSET, default=1): int,
83  vol.Optional(CONF_MODEL_DIR): cv.isdir,
84  }
85  ),
86  }
87 )
88 
89 
91  """Get a tf.function for detection."""
92 
93  @tf.function
94  def detect_fn(image):
95  """Detect objects in image."""
96 
97  image, shapes = model.preprocess(image)
98  prediction_dict = model.predict(image, shapes)
99  return model.postprocess(prediction_dict, shapes)
100 
101  return detect_fn
102 
103 
105  hass: HomeAssistant,
106  config: ConfigType,
107  add_entities: AddEntitiesCallback,
108  discovery_info: DiscoveryInfoType | None = None,
109 ) -> None:
110  """Set up the TensorFlow image processing platform."""
111  model_config = config[CONF_MODEL]
112  model_dir = model_config.get(CONF_MODEL_DIR) or hass.config.path("tensorflow")
113  labels = model_config.get(CONF_LABELS) or hass.config.path(
114  "tensorflow", "object_detection", "data", "mscoco_label_map.pbtxt"
115  )
116  checkpoint = os.path.join(model_config[CONF_GRAPH], "checkpoint")
117  pipeline_config = os.path.join(model_config[CONF_GRAPH], "pipeline.config")
118 
119  # Make sure locations exist
120  if (
121  not os.path.isdir(model_dir)
122  or not os.path.isdir(checkpoint)
123  or not os.path.exists(pipeline_config)
124  or not os.path.exists(labels)
125  ):
126  _LOGGER.error("Unable to locate tensorflow model or label map")
127  return
128 
129  # append custom model path to sys.path
130  sys.path.append(model_dir)
131 
132  try:
133  # Verify that the TensorFlow Object Detection API is pre-installed
134  # These imports shouldn't be moved to the top, because they depend on code from the model_dir.
135  # (The model_dir is created during the manual setup process. See integration docs.)
136 
137  # pylint: disable=import-outside-toplevel
138  from object_detection.builders import model_builder
139  from object_detection.utils import config_util, label_map_util
140  except ImportError:
141  _LOGGER.error(
142  "No TensorFlow Object Detection library found! Install or compile "
143  "for your system following instructions here: "
144  "https://github.com/tensorflow/models/blob/master/research/object_detection/g3doc/tf2.md#installation"
145  )
146  return
147 
148  try:
149  # Display warning that PIL will be used if no OpenCV is found.
150  import cv2 # noqa: F401 pylint: disable=import-outside-toplevel
151  except ImportError:
152  _LOGGER.warning(
153  "No OpenCV library found. TensorFlow will process image with "
154  "PIL at reduced resolution"
155  )
156 
157  hass.data[DOMAIN] = {CONF_MODEL: None}
158 
159  def tensorflow_hass_start(_event):
160  """Set up TensorFlow model on hass start."""
161  start = time.perf_counter()
162 
163  # Load pipeline config and build a detection model
164  pipeline_configs = config_util.get_configs_from_pipeline_file(pipeline_config)
165  detection_model = model_builder.build(
166  model_config=pipeline_configs["model"], is_training=False
167  )
168 
169  # Restore checkpoint
170  ckpt = tf.compat.v2.train.Checkpoint(model=detection_model)
171  ckpt.restore(os.path.join(checkpoint, "ckpt-0")).expect_partial()
172 
173  _LOGGER.debug(
174  "Model checkpoint restore took %d seconds", time.perf_counter() - start
175  )
176 
177  model = get_model_detection_function(detection_model)
178 
179  # Preload model cache with empty image tensor
180  inp = np.zeros([2160, 3840, 3], dtype=np.uint8)
181  # The input needs to be a tensor, convert it using `tf.convert_to_tensor`.
182  input_tensor = tf.convert_to_tensor(inp, dtype=tf.float32)
183  # The model expects a batch of images, so add an axis with `tf.newaxis`.
184  input_tensor = input_tensor[tf.newaxis, ...]
185  # Run inference
186  model(input_tensor)
187 
188  _LOGGER.debug("Model load took %d seconds", time.perf_counter() - start)
189  hass.data[DOMAIN][CONF_MODEL] = model
190 
191  hass.bus.listen_once(EVENT_HOMEASSISTANT_START, tensorflow_hass_start)
192 
193  category_index = label_map_util.create_category_index_from_labelmap(
194  labels, use_display_name=True
195  )
196 
197  add_entities(
199  hass,
200  camera[CONF_ENTITY_ID],
201  camera.get(CONF_NAME),
202  category_index,
203  config,
204  )
205  for camera in config[CONF_SOURCE]
206  )
207 
208 
210  """Representation of an TensorFlow image processor."""
211 
212  def __init__(
213  self,
214  hass,
215  camera_entity,
216  name,
217  category_index,
218  config,
219  ):
220  """Initialize the TensorFlow entity."""
221  model_config = config.get(CONF_MODEL)
222  self.hasshasshass = hass
223  self._camera_entity_camera_entity = camera_entity
224  if name:
225  self._name_name = name
226  else:
227  self._name_name = f"TensorFlow {split_entity_id(camera_entity)[1]}"
228  self._category_index_category_index = category_index
229  self._min_confidence_min_confidence = config.get(CONF_CONFIDENCE)
230  self._file_out_file_out = config.get(CONF_FILE_OUT)
231 
232  # handle categories and specific detection areas
233  self._label_id_offset_label_id_offset = model_config.get(CONF_LABEL_OFFSET)
234  categories = model_config.get(CONF_CATEGORIES)
235  self._include_categories_include_categories = []
236  self._category_areas_category_areas = {}
237  for category in categories:
238  if isinstance(category, dict):
239  category_name = category.get(CONF_CATEGORY)
240  category_area = category.get(CONF_AREA)
241  self._include_categories_include_categories.append(category_name)
242  self._category_areas_category_areas[category_name] = [0, 0, 1, 1]
243  if category_area:
244  self._category_areas_category_areas[category_name] = [
245  category_area.get(CONF_TOP),
246  category_area.get(CONF_LEFT),
247  category_area.get(CONF_BOTTOM),
248  category_area.get(CONF_RIGHT),
249  ]
250  else:
251  self._include_categories_include_categories.append(category)
252  self._category_areas_category_areas[category] = [0, 0, 1, 1]
253 
254  # Handle global detection area
255  self._area_area = [0, 0, 1, 1]
256  if area_config := model_config.get(CONF_AREA):
257  self._area_area = [
258  area_config.get(CONF_TOP),
259  area_config.get(CONF_LEFT),
260  area_config.get(CONF_BOTTOM),
261  area_config.get(CONF_RIGHT),
262  ]
263 
264  self._matches_matches = {}
265  self._total_matches_total_matches = 0
266  self._last_image_last_image = None
267  self._process_time_process_time = 0
268 
269  @property
270  def camera_entity(self):
271  """Return camera entity id from process pictures."""
272  return self._camera_entity_camera_entity
273 
274  @property
275  def name(self):
276  """Return the name of the image processor."""
277  return self._name_name
278 
279  @property
280  def state(self):
281  """Return the state of the entity."""
282  return self._total_matches_total_matches
283 
284  @property
286  """Return device specific state attributes."""
287  return {
288  ATTR_MATCHES: self._matches_matches,
289  ATTR_SUMMARY: {
290  category: len(values) for category, values in self._matches_matches.items()
291  },
292  ATTR_TOTAL_MATCHES: self._total_matches_total_matches,
293  ATTR_PROCESS_TIME: self._process_time_process_time,
294  }
295 
296  def _save_image(self, image, matches, paths):
297  img = Image.open(io.BytesIO(bytearray(image))).convert("RGB")
298  img_width, img_height = img.size
299  draw = ImageDraw.Draw(img)
300 
301  # Draw custom global region/area
302  if self._area_area != [0, 0, 1, 1]:
303  draw_box(
304  draw, self._area_area, img_width, img_height, "Detection Area", (0, 255, 255)
305  )
306 
307  for category, values in matches.items():
308  # Draw custom category regions/areas
309  if category in self._category_areas_category_areas and self._category_areas_category_areas[category] != [
310  0,
311  0,
312  1,
313  1,
314  ]:
315  label = f"{category.capitalize()} Detection Area"
316  draw_box(
317  draw,
318  self._category_areas_category_areas[category],
319  img_width,
320  img_height,
321  label,
322  (0, 255, 0),
323  )
324 
325  # Draw detected objects
326  for instance in values:
327  label = f"{category} {instance['score']:.1f}%"
328  draw_box(
329  draw, instance["box"], img_width, img_height, label, (255, 255, 0)
330  )
331 
332  for path in paths:
333  _LOGGER.debug("Saving results image to %s", path)
334  os.makedirs(os.path.dirname(path), exist_ok=True)
335  img.save(path)
336 
337  def process_image(self, image):
338  """Process the image."""
339  if not (model := self.hasshasshass.data[DOMAIN][CONF_MODEL]):
340  _LOGGER.debug("Model not yet ready")
341  return
342 
343  start = time.perf_counter()
344  try:
345  import cv2 # pylint: disable=import-outside-toplevel
346 
347  img = cv2.imdecode(np.asarray(bytearray(image)), cv2.IMREAD_UNCHANGED)
348  inp = img[:, :, [2, 1, 0]] # BGR->RGB
349  inp_expanded = inp.reshape(1, inp.shape[0], inp.shape[1], 3)
350  except ImportError:
351  try:
352  img = Image.open(io.BytesIO(bytearray(image))).convert("RGB")
353  except UnidentifiedImageError:
354  _LOGGER.warning("Unable to process image, bad data")
355  return
356  img.thumbnail((460, 460), Image.ANTIALIAS)
357  img_width, img_height = img.size
358  inp = (
359  np.array(img.getdata())
360  .reshape((img_height, img_width, 3))
361  .astype(np.uint8)
362  )
363  inp_expanded = np.expand_dims(inp, axis=0)
364 
365  # The input needs to be a tensor, convert it using `tf.convert_to_tensor`.
366  input_tensor = tf.convert_to_tensor(inp_expanded, dtype=tf.float32)
367 
368  detections = model(input_tensor)
369  boxes = detections["detection_boxes"][0].numpy()
370  scores = detections["detection_scores"][0].numpy()
371  classes = (
372  detections["detection_classes"][0].numpy() + self._label_id_offset_label_id_offset
373  ).astype(int)
374 
375  matches = {}
376  total_matches = 0
377  for box, score, obj_class in zip(boxes, scores, classes, strict=False):
378  score = score * 100
379  boxes = box.tolist()
380 
381  # Exclude matches below min confidence value
382  if score < self._min_confidence_min_confidence:
383  continue
384 
385  # Exclude matches outside global area definition
386  if (
387  boxes[0] < self._area_area[0]
388  or boxes[1] < self._area_area[1]
389  or boxes[2] > self._area_area[2]
390  or boxes[3] > self._area_area[3]
391  ):
392  continue
393 
394  category = self._category_index_category_index[obj_class]["name"]
395 
396  # Exclude unlisted categories
397  if self._include_categories_include_categories and category not in self._include_categories_include_categories:
398  continue
399 
400  # Exclude matches outside category specific area definition
401  if self._category_areas_category_areas and (
402  boxes[0] < self._category_areas_category_areas[category][0]
403  or boxes[1] < self._category_areas_category_areas[category][1]
404  or boxes[2] > self._category_areas_category_areas[category][2]
405  or boxes[3] > self._category_areas_category_areas[category][3]
406  ):
407  continue
408 
409  # If we got here, we should include it
410  if category not in matches:
411  matches[category] = []
412  matches[category].append({"score": float(score), "box": boxes})
413  total_matches += 1
414 
415  # Save Images
416  if total_matches and self._file_out_file_out:
417  paths = []
418  for path_template in self._file_out_file_out:
419  if isinstance(path_template, template.Template):
420  paths.append(
421  path_template.render(camera_entity=self._camera_entity_camera_entity)
422  )
423  else:
424  paths.append(path_template)
425  self._save_image_save_image(image, matches, paths)
426 
427  self._matches_matches = matches
428  self._total_matches_total_matches = total_matches
429  self._process_time_process_time = time.perf_counter() - start
# Doc-site cross-reference residue (retained as comments; not module code):
# def __init__(self, hass, camera_entity, name, category_index, config)
# None add_entities(AsusWrtRouter router, AddEntitiesCallback async_add_entities, set[str] tracked)
# None setup_platform(HomeAssistant hass, ConfigType config, AddEntitiesCallback add_entities, DiscoveryInfoType|None discovery_info=None)
# None draw_box(ImageDraw draw, tuple[float, float, float, float] box, int img_width, int img_height, str text="", tuple[int, int, int] color=(255, 255, 0))
# Definition: pil.py:18