Home Assistant Unofficial Reference 2024.12.1
wake_word.py
Go to the documentation of this file.
1 """Support for Wyoming wake-word-detection services."""
2 
3 import asyncio
4 from collections.abc import AsyncIterable
5 import logging
6 
7 from wyoming.audio import AudioChunk, AudioStart
8 from wyoming.client import AsyncTcpClient
9 from wyoming.wake import Detect, Detection
10 
11 from homeassistant.components import wake_word
12 from homeassistant.config_entries import ConfigEntry
13 from homeassistant.core import HomeAssistant
14 from homeassistant.helpers.entity_platform import AddEntitiesCallback
15 
16 from .const import DOMAIN
17 from .data import WyomingService, load_wyoming_info
18 from .error import WyomingError
19 from .models import DomainDataItem
20 
21 _LOGGER = logging.getLogger(__name__)
22 
23 
async def async_setup_entry(
    hass: HomeAssistant,
    config_entry: ConfigEntry,
    async_add_entities: AddEntitiesCallback,
) -> None:
    """Set up Wyoming wake-word-detection.

    Looks up the data stored for this config entry under
    ``hass.data[DOMAIN]`` and registers one wake-word provider entity that
    talks to the entry's Wyoming service.
    """
    item: DomainDataItem = hass.data[DOMAIN][config_entry.entry_id]
    async_add_entities(
        [
            WyomingWakeWordProvider(hass, config_entry, item.service),
        ]
    )
36 
37 
39  """Wyoming wake-word-detection provider."""
40 
def __init__(
    self,
    hass: HomeAssistant,
    config_entry: ConfigEntry,
    service: WyomingService,
) -> None:
    """Set up provider.

    Args:
        hass: Home Assistant instance.
        config_entry: Config entry this provider belongs to; its entry id
            forms the entity's unique id.
        service: Wyoming service whose info lists the wake programs; the
            first wake program supplies the initial wake word list and the
            entity name.
    """
    self.hass = hass
    # NOTE(review): the doubled attribute spellings below are kept on purpose
    # because the other methods read them under exactly these names; rename
    # them all together if cleaning up.
    self.serviceservice = service
    wake_service = service.info.wake[0]

    # Initial wake word list; get_supported_wake_words() refreshes it from
    # the service later.
    self._supported_wake_words_supported_wake_words = [
        wake_word.WakeWord(
            id=ww.name, name=ww.description or ww.name, phrase=ww.phrase
        )
        for ww in wake_service.models
    ]
    self._attr_name = wake_service.name
    self._attr_unique_id = f"{config_entry.entry_id}-wake_word"
60 
async def get_supported_wake_words(self) -> list[wake_word.WakeWord]:
    """Return a list of supported wake words.

    Re-queries the Wyoming service info (``retries=0``, ``timeout=1``) so the
    list reflects the service's current wake models. If the query returns
    None, or the service no longer reports any wake program, the previously
    cached list is returned unchanged.
    """
    info = await load_wyoming_info(
        self.serviceservice.host, self.serviceservice.port, retries=0, timeout=1
    )

    # Check for an empty ``wake`` list as well as a failed query; indexing
    # ``info.wake[0]`` unguarded would raise IndexError in that case.
    if info is not None and info.wake:
        wake_service = info.wake[0]
        self._supported_wake_words_supported_wake_words = [
            wake_word.WakeWord(
                id=ww.name,
                name=ww.description or ww.name,
                phrase=ww.phrase,
            )
            for ww in wake_service.models
        ]

    return self._supported_wake_words_supported_wake_words
79 
async def _async_process_audio_stream(
    self, stream: AsyncIterable[tuple[bytes, int]], wake_word_id: str | None
) -> wake_word.DetectionResult | None:
    """Try to detect one or more wake words in an audio stream.

    Audio must be 16 kHz sample rate with 16-bit mono PCM samples.

    Audio chunks are forwarded to the Wyoming wake service while its events
    are read concurrently. This continues until an accepted detection
    arrives, the audio stream ends, or the connection is lost.

    Args:
        stream: Async iterable of ``(audio_bytes, timestamp)`` chunks.
        wake_word_id: Wake word to detect, or None for the service default.
            Detections of any other wake word are logged and ignored.

    Returns:
        A DetectionResult for the first accepted detection, with any audio
        chunk read from the stream but not forwarded passed back as
        ``queued_audio``. Returns None if the stream ended, the connection
        was lost, or an OSError/WyomingError occurred (the error is logged).
    """

    async def next_chunk() -> tuple[bytes, int] | None:
        """Get the next chunk from audio stream, or None once it is exhausted."""
        async for chunk_info in stream:
            return chunk_info
        return None

    try:
        async with AsyncTcpClient(
            self.serviceservice.host, self.serviceservice.port
        ) as client:
            # Inform client which wake word we want to detect (None = default)
            await client.write_event(
                Detect(names=[wake_word_id] if wake_word_id else None).event()
            )

            await client.write_event(
                AudioStart(
                    rate=16000,
                    width=2,
                    channels=1,
                ).event(),
            )

            # Read audio and wake events in "parallel"
            audio_task = asyncio.create_task(next_chunk())
            wake_task = asyncio.create_task(client.read_event())
            pending = {audio_task, wake_task}

            try:
                while True:
                    done, pending = await asyncio.wait(
                        pending, return_when=asyncio.FIRST_COMPLETED
                    )

                    if wake_task in done:
                        event = wake_task.result()
                        if event is None:
                            _LOGGER.debug("Connection lost")
                            break

                        if Detection.is_type(event.type):
                            # Possible detection
                            detection = Detection.from_event(event)
                            _LOGGER.info(detection)

                            if wake_word_id and (detection.name != wake_word_id):
                                # Deliberately no ``continue``: an audio chunk
                                # that finished in this same wait must still be
                                # forwarded below, or streaming would stall.
                                _LOGGER.warning(
                                    "Expected wake word %s but got %s, skipping",
                                    wake_word_id,
                                    detection.name,
                                )
                            else:
                                # Retrieve queued audio. Never cancel an
                                # in-flight read; wait for it instead. A chunk
                                # that finished in this same wait was read from
                                # the stream but not forwarded, so hand it back
                                # too rather than dropping it.
                                if not audio_task.done():
                                    await audio_task
                                pending.discard(audio_task)

                                queued_audio: list[tuple[bytes, int]] | None = None
                                queued_chunk = audio_task.result()
                                if queued_chunk is not None:
                                    queued_audio = [queued_chunk]

                                return wake_word.DetectionResult(
                                    wake_word_id=detection.name,
                                    wake_word_phrase=self._get_phrase(
                                        detection.name
                                    ),
                                    timestamp=detection.timestamp,
                                    queued_audio=queued_audio,
                                )

                        # Next event
                        wake_task = asyncio.create_task(client.read_event())
                        pending.add(wake_task)

                    if audio_task in done:
                        # Forward audio to wake service
                        chunk_info = audio_task.result()
                        if chunk_info is None:
                            break

                        chunk_bytes, chunk_timestamp = chunk_info
                        chunk = AudioChunk(
                            rate=16000,
                            width=2,
                            channels=1,
                            audio=chunk_bytes,
                            timestamp=chunk_timestamp,
                        )
                        await client.write_event(chunk.event())

                        # Next chunk
                        audio_task = asyncio.create_task(next_chunk())
                        pending.add(audio_task)
            finally:
                # Clean up. Cancel the wake-event reader first so it is not
                # left running if awaiting the audio task below raises.
                for task in pending:
                    if task is not audio_task:
                        task.cancel()

                if audio_task in pending:
                    # It's critical that we don't cancel the audio task or
                    # leave it hanging. This would mess up the pipeline STT
                    # by stopping the audio stream.
                    await audio_task

    except (OSError, WyomingError):
        _LOGGER.exception("Error processing audio stream")

    return None
195 
196  def _get_phrase(self, model_id: str) -> str:
197  """Get wake word phrase for model id."""
198  for ww_model in self._supported_wake_words_supported_wake_words:
199  if not ww_model.phrase:
200  continue
201 
202  if ww_model.id == model_id:
203  return ww_model.phrase
204 
205  return model_id
wake_word.DetectionResult|None _async_process_audio_stream(self, AsyncIterable[tuple[bytes, int]] stream, str|None wake_word_id)
Definition: wake_word.py:82
None __init__(self, HomeAssistant hass, ConfigEntry config_entry, WyomingService service)
Definition: wake_word.py:46
Info|None load_wyoming_info(str host, int port, int retries=_INFO_RETRIES, float retry_wait=_INFO_RETRY_WAIT, float timeout=_INFO_TIMEOUT)
Definition: data.py:107
None async_setup_entry(HomeAssistant hass, ConfigEntry config_entry, AddEntitiesCallback async_add_entities)
Definition: wake_word.py:28