1 """Support for performing TensorFlow classification on images."""
3 from __future__
import annotations
12 from PIL
import Image, ImageDraw, UnidentifiedImageError
13 import tensorflow
as tf
14 import voluptuous
as vol
18 PLATFORM_SCHEMA
as IMAGE_PROCESSING_PLATFORM_SCHEMA,
19 ImageProcessingEntity,
26 EVENT_HOMEASSISTANT_START,
35 os.environ[
"TF_CPP_MIN_LOG_LEVEL"] =
"2"
38 _LOGGER = logging.getLogger(__name__)
40 ATTR_MATCHES =
"matches"
41 ATTR_SUMMARY =
"summary"
42 ATTR_TOTAL_MATCHES =
"total_matches"
43 ATTR_PROCESS_TIME =
"process_time"
46 CONF_BOTTOM =
"bottom"
47 CONF_CATEGORIES =
"categories"
48 CONF_CATEGORY =
"category"
49 CONF_FILE_OUT =
"file_out"
51 CONF_LABELS =
"labels"
52 CONF_LABEL_OFFSET =
"label_offset"
54 CONF_MODEL_DIR =
"model_dir"
58 AREA_SCHEMA = vol.Schema(
60 vol.Optional(CONF_BOTTOM, default=1): cv.small_float,
61 vol.Optional(CONF_LEFT, default=0): cv.small_float,
62 vol.Optional(CONF_RIGHT, default=1): cv.small_float,
63 vol.Optional(CONF_TOP, default=0): cv.small_float,
67 CATEGORY_SCHEMA = vol.Schema(
68 {vol.Required(CONF_CATEGORY): cv.string, vol.Optional(CONF_AREA): AREA_SCHEMA}
71 PLATFORM_SCHEMA = IMAGE_PROCESSING_PLATFORM_SCHEMA.extend(
73 vol.Optional(CONF_FILE_OUT, default=[]): vol.All(cv.ensure_list, [cv.template]),
74 vol.Required(CONF_MODEL): vol.Schema(
76 vol.Required(CONF_GRAPH): cv.isdir,
77 vol.Optional(CONF_AREA): AREA_SCHEMA,
78 vol.Optional(CONF_CATEGORIES, default=[]): vol.All(
79 cv.ensure_list, [vol.Any(cv.string, CATEGORY_SCHEMA)]
81 vol.Optional(CONF_LABELS): cv.isfile,
82 vol.Optional(CONF_LABEL_OFFSET, default=1): int,
83 vol.Optional(CONF_MODEL_DIR): cv.isdir,
91 """Get a tf.function for detection."""
95 """Detect objects in image."""
97 image, shapes = model.preprocess(image)
98 prediction_dict = model.predict(image, shapes)
99 return model.postprocess(prediction_dict, shapes)
107 add_entities: AddEntitiesCallback,
108 discovery_info: DiscoveryInfoType |
None =
None,
110 """Set up the TensorFlow image processing platform."""
111 model_config = config[CONF_MODEL]
112 model_dir = model_config.get(CONF_MODEL_DIR)
or hass.config.path(
"tensorflow")
113 labels = model_config.get(CONF_LABELS)
or hass.config.path(
114 "tensorflow",
"object_detection",
"data",
"mscoco_label_map.pbtxt"
116 checkpoint = os.path.join(model_config[CONF_GRAPH],
"checkpoint")
117 pipeline_config = os.path.join(model_config[CONF_GRAPH],
"pipeline.config")
121 not os.path.isdir(model_dir)
122 or not os.path.isdir(checkpoint)
123 or not os.path.exists(pipeline_config)
124 or not os.path.exists(labels)
126 _LOGGER.error(
"Unable to locate tensorflow model or label map")
130 sys.path.append(model_dir)
138 from object_detection.builders
import model_builder
139 from object_detection.utils
import config_util, label_map_util
142 "No TensorFlow Object Detection library found! Install or compile "
143 "for your system following instructions here: "
144 "https://github.com/tensorflow/models/blob/master/research/object_detection/g3doc/tf2.md#installation"
153 "No OpenCV library found. TensorFlow will process image with "
154 "PIL at reduced resolution"
157 hass.data[DOMAIN] = {CONF_MODEL:
None}
159 def tensorflow_hass_start(_event):
160 """Set up TensorFlow model on hass start."""
161 start = time.perf_counter()
164 pipeline_configs = config_util.get_configs_from_pipeline_file(pipeline_config)
165 detection_model = model_builder.build(
166 model_config=pipeline_configs[
"model"], is_training=
False
170 ckpt = tf.compat.v2.train.Checkpoint(model=detection_model)
171 ckpt.restore(os.path.join(checkpoint,
"ckpt-0")).expect_partial()
174 "Model checkpoint restore took %d seconds", time.perf_counter() - start
180 inp = np.zeros([2160, 3840, 3], dtype=np.uint8)
182 input_tensor = tf.convert_to_tensor(inp, dtype=tf.float32)
184 input_tensor = input_tensor[tf.newaxis, ...]
188 _LOGGER.debug(
"Model load took %d seconds", time.perf_counter() - start)
189 hass.data[DOMAIN][CONF_MODEL] = model
191 hass.bus.listen_once(EVENT_HOMEASSISTANT_START, tensorflow_hass_start)
193 category_index = label_map_util.create_category_index_from_labelmap(
194 labels, use_display_name=
True
200 camera[CONF_ENTITY_ID],
201 camera.get(CONF_NAME),
205 for camera
in config[CONF_SOURCE]
210 """Representation of an TensorFlow image processor."""
220 """Initialize the TensorFlow entity."""
221 model_config = config.get(CONF_MODEL)
227 self.
_name_name = f
"TensorFlow {split_entity_id(camera_entity)[1]}"
234 categories = model_config.get(CONF_CATEGORIES)
237 for category
in categories:
238 if isinstance(category, dict):
239 category_name = category.get(CONF_CATEGORY)
240 category_area = category.get(CONF_AREA)
245 category_area.get(CONF_TOP),
246 category_area.get(CONF_LEFT),
247 category_area.get(CONF_BOTTOM),
248 category_area.get(CONF_RIGHT),
256 if area_config := model_config.get(CONF_AREA):
258 area_config.get(CONF_TOP),
259 area_config.get(CONF_LEFT),
260 area_config.get(CONF_BOTTOM),
261 area_config.get(CONF_RIGHT),
271 """Return camera entity id from process pictures."""
276 """Return the name of the image processor."""
277 return self.
_name_name
281 """Return the state of the entity."""
286 """Return device specific state attributes."""
288 ATTR_MATCHES: self.
_matches_matches,
290 category: len(values)
for category, values
in self.
_matches_matches.items()
297 img = Image.open(io.BytesIO(bytearray(image))).convert(
"RGB")
298 img_width, img_height = img.size
299 draw = ImageDraw.Draw(img)
302 if self.
_area_area != [0, 0, 1, 1]:
304 draw, self.
_area_area, img_width, img_height,
"Detection Area", (0, 255, 255)
307 for category, values
in matches.items():
315 label = f
"{category.capitalize()} Detection Area"
326 for instance
in values:
327 label = f
"{category} {instance['score']:.1f}%"
329 draw, instance[
"box"], img_width, img_height, label, (255, 255, 0)
333 _LOGGER.debug(
"Saving results image to %s", path)
334 os.makedirs(os.path.dirname(path), exist_ok=
True)
338 """Process the image."""
339 if not (model := self.
hasshasshass.data[DOMAIN][CONF_MODEL]):
340 _LOGGER.debug(
"Model not yet ready")
343 start = time.perf_counter()
347 img = cv2.imdecode(np.asarray(bytearray(image)), cv2.IMREAD_UNCHANGED)
348 inp = img[:, :, [2, 1, 0]]
349 inp_expanded = inp.reshape(1, inp.shape[0], inp.shape[1], 3)
352 img = Image.open(io.BytesIO(bytearray(image))).convert(
"RGB")
353 except UnidentifiedImageError:
354 _LOGGER.warning(
"Unable to process image, bad data")
356 img.thumbnail((460, 460), Image.ANTIALIAS)
357 img_width, img_height = img.size
359 np.array(img.getdata())
360 .reshape((img_height, img_width, 3))
363 inp_expanded = np.expand_dims(inp, axis=0)
366 input_tensor = tf.convert_to_tensor(inp_expanded, dtype=tf.float32)
368 detections =
model(input_tensor)
369 boxes = detections[
"detection_boxes"][0].numpy()
370 scores = detections[
"detection_scores"][0].numpy()
372 detections[
"detection_classes"][0].numpy() + self.
_label_id_offset_label_id_offset
377 for box, score, obj_class
in zip(boxes, scores, classes, strict=
False):
387 boxes[0] < self.
_area_area[0]
388 or boxes[1] < self.
_area_area[1]
389 or boxes[2] > self.
_area_area[2]
390 or boxes[3] > self.
_area_area[3]
410 if category
not in matches:
411 matches[category] = []
412 matches[category].append({
"score":
float(score),
"box": boxes})
416 if total_matches
and self.
_file_out_file_out:
418 for path_template
in self.
_file_out_file_out:
419 if isinstance(path_template, template.Template):
421 path_template.render(camera_entity=self.
_camera_entity_camera_entity)
424 paths.append(path_template)
429 self.
_process_time_process_time = time.perf_counter() - start
def __init__(self, hass, camera_entity, name, category_index, config)
def extra_state_attributes(self)
def process_image(self, image)
def _save_image(self, image, matches, paths)
None add_entities(AsusWrtRouter router, AddEntitiesCallback async_add_entities, set[str] tracked)
def get_model_detection_function(model)
None setup_platform(HomeAssistant hass, ConfigType config, AddEntitiesCallback add_entities, DiscoveryInfoType|None discovery_info=None)
None draw_box(ImageDraw draw, tuple[float, float, float, float] box, int img_width, int img_height, str text="", tuple[int, int, int] color=(255, 255, 0))