1 """Support for Wyoming wake-word-detection services."""
4 from collections.abc
import AsyncIterable
7 from wyoming.audio
import AudioChunk, AudioStart
8 from wyoming.client
import AsyncTcpClient
9 from wyoming.wake
import Detect, Detection
16 from .const
import DOMAIN
17 from .data
import WyomingService, load_wyoming_info
18 from .error
import WyomingError
19 from .models
import DomainDataItem
21 _LOGGER = logging.getLogger(__name__)
26 config_entry: ConfigEntry,
27 async_add_entities: AddEntitiesCallback,
29 """Set up Wyoming wake-word-detection."""
30 item: DomainDataItem = hass.data[DOMAIN][config_entry.entry_id]
39 """Wyoming wake-word-detection provider."""
44 config_entry: ConfigEntry,
45 service: WyomingService,
47 """Set up provider."""
50 wake_service = service.info.wake[0]
54 id=ww.name, name=ww.description
or ww.name, phrase=ww.phrase
56 for ww
in wake_service.models
62 """Return a list of supported wake words."""
64 self.
serviceservice.host, self.
serviceservice.port, retries=0, timeout=1
68 wake_service = info.wake[0]
72 name=ww.description
or ww.name,
75 for ww
in wake_service.models
81 self, stream: AsyncIterable[tuple[bytes, int]], wake_word_id: str |
None
83 """Try to detect one or more wake words in an audio stream.
85 Audio must be 16Khz sample rate with 16-bit mono PCM samples.
88 async
def next_chunk():
89 """Get the next chunk from audio stream."""
90 async
for chunk_bytes
in stream:
95 async
with AsyncTcpClient(self.
serviceservice.host, self.
serviceservice.port)
as client:
97 await client.write_event(
98 Detect(names=[wake_word_id]
if wake_word_id
else None).event()
101 await client.write_event(
110 audio_task = asyncio.create_task(next_chunk())
111 wake_task = asyncio.create_task(client.read_event())
112 pending = {audio_task, wake_task}
116 done, pending = await asyncio.wait(
117 pending, return_when=asyncio.FIRST_COMPLETED
120 if wake_task
in done:
121 event = wake_task.result()
123 _LOGGER.debug(
"Connection lost")
126 if Detection.is_type(event.type):
128 detection = Detection.from_event(event)
129 _LOGGER.info(detection)
131 if wake_word_id
and (detection.name != wake_word_id):
133 "Expected wake word %s but got %s, skipping",
137 wake_task = asyncio.create_task(client.read_event())
138 pending.add(wake_task)
142 queued_audio: list[tuple[bytes, int]] |
None =
None
143 if audio_task
in pending:
146 pending.remove(audio_task)
147 queued_audio = [audio_task.result()]
150 wake_word_id=detection.name,
151 wake_word_phrase=self.
_get_phrase_get_phrase(detection.name),
152 timestamp=detection.timestamp,
153 queued_audio=queued_audio,
157 wake_task = asyncio.create_task(client.read_event())
158 pending.add(wake_task)
160 if audio_task
in done:
162 chunk_info = audio_task.result()
163 if chunk_info
is None:
166 chunk_bytes, chunk_timestamp = chunk_info
172 timestamp=chunk_timestamp,
174 await client.write_event(chunk.event())
177 audio_task = asyncio.create_task(next_chunk())
178 pending.add(audio_task)
181 if audio_task
in pending:
186 pending.remove(audio_task)
191 except (OSError, WyomingError):
192 _LOGGER.exception(
"Error processing audio stream")
197 """Get wake word phrase for model id."""
199 if not ww_model.phrase:
202 if ww_model.id == model_id:
203 return ww_model.phrase
wake_word.DetectionResult|None _async_process_audio_stream(self, AsyncIterable[tuple[bytes, int]] stream, str|None wake_word_id)
list[wake_word.WakeWord] get_supported_wake_words(self)
str _get_phrase(self, str model_id)
None __init__(self, HomeAssistant hass, ConfigEntry config_entry, WyomingService service)
Info|None load_wyoming_info(str host, int port, int retries=_INFO_RETRIES, float retry_wait=_INFO_RETRY_WAIT, float timeout=_INFO_TIMEOUT)
None async_setup_entry(HomeAssistant hass, ConfigEntry config_entry, AddEntitiesCallback async_add_entities)