1 """Support for assist satellites in ESPHome."""
3 from __future__
import annotations
6 from collections.abc
import AsyncIterable
7 from functools
import partial
9 from itertools
import chain
12 from typing
import Any, cast
15 from aioesphomeapi
import (
16 MediaPlayerFormatPurpose,
17 MediaPlayerSupportedFormat,
18 VoiceAssistantAnnounceFinished,
19 VoiceAssistantAudioSettings,
20 VoiceAssistantCommandFlag,
21 VoiceAssistantEventType,
22 VoiceAssistantFeature,
23 VoiceAssistantTimerEventType,
35 async_register_timer_handler,
44 from .const
import DOMAIN
45 from .entity
import EsphomeAssistEntity
46 from .entry_data
import ESPHomeConfigEntry, RuntimeEntryData
47 from .enum_mapper
import EsphomeEnumMapper
48 from .ffmpeg_proxy
import async_create_proxy_url
50 _LOGGER = logging.getLogger(__name__)
52 _VOICE_ASSISTANT_EVENT_TYPES: EsphomeEnumMapper[
53 VoiceAssistantEventType, PipelineEventType
56 VoiceAssistantEventType.VOICE_ASSISTANT_ERROR: PipelineEventType.ERROR,
57 VoiceAssistantEventType.VOICE_ASSISTANT_RUN_START: PipelineEventType.RUN_START,
58 VoiceAssistantEventType.VOICE_ASSISTANT_RUN_END: PipelineEventType.RUN_END,
59 VoiceAssistantEventType.VOICE_ASSISTANT_STT_START: PipelineEventType.STT_START,
60 VoiceAssistantEventType.VOICE_ASSISTANT_STT_END: PipelineEventType.STT_END,
61 VoiceAssistantEventType.VOICE_ASSISTANT_INTENT_START: PipelineEventType.INTENT_START,
62 VoiceAssistantEventType.VOICE_ASSISTANT_INTENT_END: PipelineEventType.INTENT_END,
63 VoiceAssistantEventType.VOICE_ASSISTANT_TTS_START: PipelineEventType.TTS_START,
64 VoiceAssistantEventType.VOICE_ASSISTANT_TTS_END: PipelineEventType.TTS_END,
65 VoiceAssistantEventType.VOICE_ASSISTANT_WAKE_WORD_START: PipelineEventType.WAKE_WORD_START,
66 VoiceAssistantEventType.VOICE_ASSISTANT_WAKE_WORD_END: PipelineEventType.WAKE_WORD_END,
67 VoiceAssistantEventType.VOICE_ASSISTANT_STT_VAD_START: PipelineEventType.STT_VAD_START,
68 VoiceAssistantEventType.VOICE_ASSISTANT_STT_VAD_END: PipelineEventType.STT_VAD_END,
72 _TIMER_EVENT_TYPES: EsphomeEnumMapper[VoiceAssistantTimerEventType, TimerEventType] = (
75 VoiceAssistantTimerEventType.VOICE_ASSISTANT_TIMER_STARTED: TimerEventType.STARTED,
76 VoiceAssistantTimerEventType.VOICE_ASSISTANT_TIMER_UPDATED: TimerEventType.UPDATED,
77 VoiceAssistantTimerEventType.VOICE_ASSISTANT_TIMER_CANCELLED: TimerEventType.CANCELLED,
78 VoiceAssistantTimerEventType.VOICE_ASSISTANT_TIMER_FINISHED: TimerEventType.FINISHED,
83 _ANNOUNCEMENT_TIMEOUT_SEC = 5 * 60
84 _CONFIG_TIMEOUT_SEC = 5
89 entry: ESPHomeConfigEntry,
90 async_add_entities: AddEntitiesCallback,
92 """Set up Assist satellite entity."""
93 entry_data = entry.runtime_data
94 assert entry_data.device_info
is not None
95 if entry_data.device_info.voice_assistant_feature_flags_compat(
96 entry_data.api_version
104 """Satellite running ESPHome."""
107 key=
"assist_satellite", translation_key=
"assist_satellite"
112 config_entry: ConfigEntry,
113 entry_data: RuntimeEntryData,
115 """Initialize satellite."""
124 self._audio_queue: asyncio.Queue[bytes |
None] = asyncio.Queue()
126 self.
_udp_server_udp_server: VoiceAssistantUDPServer |
None =
None
130 available_wake_words=[], active_wake_words=[], max_active_wake_words=1
135 """Return the entity ID of the pipeline to use for the next conversation."""
136 assert self.
entry_dataentry_data.device_info
is not None
137 ent_reg = er.async_get(self.
hasshass)
138 return ent_reg.async_get_entity_id(
141 f
"{self.entry_data.device_info.mac_address}-pipeline",
146 """Return the entity ID of the VAD sensitivity to use for the next conversation."""
147 assert self.
entry_dataentry_data.device_info
is not None
148 ent_reg = er.async_get(self.
hasshass)
149 return ent_reg.async_get_entity_id(
152 f
"{self.entry_data.device_info.mac_address}-vad_sensitivity",
158 ) -> assist_satellite.AssistSatelliteConfiguration:
159 """Get the current satellite configuration."""
163 self, config: assist_satellite.AssistSatelliteConfiguration
165 """Set the current satellite configuration."""
166 await self.
clicli.set_voice_assistant_configuration(
167 active_wake_words=config.active_wake_words
169 _LOGGER.debug(
"Set active wake words: %s", config.active_wake_words)
175 """Get the latest satellite configuration from the device."""
177 config = await self.
clicli.get_voice_assistant_configuration(
188 wake_word=model.wake_word,
189 trained_languages=
list(model.trained_languages),
191 for model
in config.available_wake_words
194 self.
_satellite_config_satellite_config.max_active_wake_words = config.max_active_wake_words
195 _LOGGER.debug(
"Received satellite configuration: %s", self.
_satellite_config_satellite_config)
201 """Run when entity about to be added to hass."""
204 assert self.
entry_dataentry_data.device_info
is not None
206 self.
entry_dataentry_data.device_info.voice_assistant_feature_flags_compat(
210 if feature_flags & VoiceAssistantFeature.API_AUDIO:
213 self.
clicli.subscribe_voice_assistant(
223 self.
clicli.subscribe_voice_assistant(
230 if feature_flags & VoiceAssistantFeature.TIMERS:
241 if feature_flags & VoiceAssistantFeature.ANNOUNCE:
244 assist_satellite.AssistSatelliteEntityFeature.ANNOUNCE
249 _LOGGER.debug(
"Waiting for satellite configuration")
252 if not (feature_flags & VoiceAssistantFeature.SPEAKER):
258 self.
entry_dataentry_data.async_register_assist_satellite_set_wake_word_callback(
264 """Run when entity will be removed from hass."""
271 """Handle pipeline events."""
273 event_type = _VOICE_ASSISTANT_EVENT_TYPES.from_hass(event.type)
275 _LOGGER.debug(
"Received unknown pipeline event type: %s", event.type)
278 data_to_send: dict[str, Any] = {}
279 if event_type == VoiceAssistantEventType.VOICE_ASSISTANT_STT_START:
280 self.
entry_dataentry_data.async_set_assist_pipeline_state(
True)
281 elif event_type == VoiceAssistantEventType.VOICE_ASSISTANT_STT_END:
282 assert event.data
is not None
283 data_to_send = {
"text": event.data[
"stt_output"][
"text"]}
284 elif event_type == VoiceAssistantEventType.VOICE_ASSISTANT_INTENT_END:
285 assert event.data
is not None
287 "conversation_id": event.data[
"intent_output"][
"conversation_id"]
or "",
289 elif event_type == VoiceAssistantEventType.VOICE_ASSISTANT_TTS_START:
290 assert event.data
is not None
291 data_to_send = {
"text": event.data[
"tts_input"]}
292 elif event_type == VoiceAssistantEventType.VOICE_ASSISTANT_TTS_END:
293 assert event.data
is not None
294 if tts_output := event.data[
"tts_output"]:
295 path = tts_output[
"url"]
297 data_to_send = {
"url": url}
299 assert self.
entry_dataentry_data.device_info
is not None
301 self.
entry_dataentry_data.device_info.voice_assistant_feature_flags_compat(
305 if feature_flags & VoiceAssistantFeature.SPEAKER:
306 media_id = tts_output[
"media_id"]
308 self.
config_entryconfig_entry.async_create_background_task(
311 "esphome_voice_assistant_tts",
314 elif event_type == VoiceAssistantEventType.VOICE_ASSISTANT_WAKE_WORD_END:
315 assert event.data
is not None
316 if not event.data[
"wake_word_output"]:
317 event_type = VoiceAssistantEventType.VOICE_ASSISTANT_ERROR
319 "code":
"no_wake_word",
320 "message":
"No wake word detected",
322 elif event_type == VoiceAssistantEventType.VOICE_ASSISTANT_ERROR:
323 assert event.data
is not None
325 "code": event.data[
"code"],
326 "message": event.data[
"message"],
328 elif event_type == VoiceAssistantEventType.VOICE_ASSISTANT_RUN_END:
331 self.
entry_dataentry_data.async_set_assist_pipeline_state(
False)
333 self.
clicli.send_voice_assistant_event(event_type, data_to_send)
336 self, announcement: assist_satellite.AssistSatelliteAnnouncement
338 """Announce media on the satellite.
340 Should block until the announcement is done playing.
343 "Waiting for announcement to finished (message=%s, media_id=%s)",
344 announcement.message,
345 announcement.media_id,
347 media_id = announcement.media_id
348 if announcement.media_id_source !=
"tts":
350 format_to_use: MediaPlayerSupportedFormat |
None =
None
351 for supported_format
in chain(
352 *self.
entry_dataentry_data.media_player_formats.values()
354 if supported_format.purpose == MediaPlayerFormatPurpose.ANNOUNCEMENT:
355 format_to_use = supported_format
358 if format_to_use
is not None:
366 media_format=format_to_use.format,
367 rate=format_to_use.sample_rate
or None,
368 channels=format_to_use.num_channels
or None,
369 width=format_to_use.sample_bytes
or None,
373 await self.
clicli.send_voice_assistant_announcement_await_response(
374 media_id, _ANNOUNCEMENT_TIMEOUT_SEC, announcement.message
379 conversation_id: str,
381 audio_settings: VoiceAssistantAudioSettings,
382 wake_word_phrase: str |
None,
384 """Handle pipeline run request."""
386 while not self._audio_queue.empty():
387 await self._audio_queue.
get()
396 assert self.
entry_dataentry_data.device_info
is not None
398 self.
entry_dataentry_data.device_info.voice_assistant_feature_flags_compat(
402 if (feature_flags & VoiceAssistantFeature.SPEAKER)
and not (
403 feature_flags & VoiceAssistantFeature.API_AUDIO
406 _LOGGER.debug(
"Started UDP server on port %s", port)
409 if flags & VoiceAssistantCommandFlag.USE_WAKE_WORD:
410 start_stage = PipelineStage.WAKE_WORD
412 start_stage = PipelineStage.STT
414 end_stage = PipelineStage.TTS
416 if feature_flags & VoiceAssistantFeature.SPEAKER:
419 tts.ATTR_PREFERRED_FORMAT:
"wav",
420 tts.ATTR_PREFERRED_SAMPLE_RATE: 16000,
421 tts.ATTR_PREFERRED_SAMPLE_CHANNELS: 1,
422 tts.ATTR_PREFERRED_SAMPLE_BYTES: 2,
429 _LOGGER.debug(
"Running pipeline from %s to %s", start_stage, end_stage)
434 start_stage=start_stage,
436 wake_word_phrase=wake_word_phrase,
438 "esphome_assist_satellite_pipeline",
447 """Handle incoming audio chunk from API."""
448 self._audio_queue.put_nowait(data)
451 """Handle request for pipeline to stop."""
458 """Handle when pipeline has finished running."""
460 _LOGGER.debug(
"Pipeline finished")
463 self, event_type: TimerEventType, timer_info: TimerInfo
465 """Handle timer events."""
467 native_event_type = _TIMER_EVENT_TYPES.from_hass(event_type)
469 _LOGGER.debug(
"Received unknown timer event type: %s", event_type)
472 self.
clicli.send_voice_assistant_timer_event(
476 timer_info.created_seconds,
477 timer_info.seconds_left,
478 timer_info.is_active,
482 self, announce_finished: VoiceAssistantAnnounceFinished
484 """Handle announcement finished message (also sent for TTS)."""
489 """Set active wake word and update config on satellite."""
491 self.
config_entryconfig_entry.async_create_background_task(
494 "esphome_voice_assistant_set_config",
496 _LOGGER.debug(
"Setting active wake word: %s", wake_word_id)
499 """Update the TTS format from the first media player."""
500 for supported_format
in chain(*self.
entry_dataentry_data.media_player_formats.values()):
502 if supported_format.purpose == MediaPlayerFormatPurpose.ANNOUNCEMENT:
504 tts.ATTR_PREFERRED_FORMAT: supported_format.format,
507 if supported_format.sample_rate > 0:
509 supported_format.sample_rate
512 if supported_format.sample_rate > 0:
514 supported_format.num_channels
517 if supported_format.sample_rate > 0:
519 supported_format.sample_bytes
527 sample_rate: int = 16000,
528 sample_width: int = 2,
529 sample_channels: int = 1,
530 samples_per_chunk: int = 512,
532 """Stream TTS audio chunks to device via API or UDP."""
533 self.
clicli.send_voice_assistant_event(
534 VoiceAssistantEventType.VOICE_ASSISTANT_TTS_STREAM_START, {}
541 extension, data = await tts.async_get_media_source_audio(
546 if extension !=
"wav":
547 _LOGGER.error(
"Only WAV audio can be streamed, got %s", extension)
550 with io.BytesIO(data)
as wav_io, wave.open(wav_io,
"rb")
as wav_file:
552 (wav_file.getframerate() != sample_rate)
553 or (wav_file.getsampwidth() != sample_width)
554 or (wav_file.getnchannels() != sample_channels)
556 _LOGGER.error(
"Can only stream 16Khz 16-bit mono WAV")
559 _LOGGER.debug(
"Streaming %s audio samples", wav_file.getnframes())
562 chunk = wav_file.readframes(samples_per_chunk)
567 self.
_udp_server_udp_server.send_audio_bytes(chunk)
569 self.
clicli.send_voice_assistant_audio(chunk)
575 samples_in_chunk = len(chunk) // (sample_width * sample_channels)
576 seconds_in_chunk = samples_in_chunk / sample_rate
577 await asyncio.sleep(seconds_in_chunk * 0.9)
578 except asyncio.CancelledError:
581 self.
clicli.send_voice_assistant_event(
582 VoiceAssistantEventType.VOICE_ASSISTANT_TTS_STREAM_END, {}
587 self.
entry_dataentry_data.async_set_assist_pipeline_state(
False)
590 """Yield audio chunks from the queue until None."""
592 chunk = await self._audio_queue.
get()
599 """Request pipeline to be stopped by ending the audio stream and continue processing."""
600 self._audio_queue.put_nowait(
None)
601 _LOGGER.debug(
"Requested pipeline stop")
604 """Request pipeline to be aborted (no further processing)."""
605 _LOGGER.debug(
"Requested pipeline abort")
606 self._audio_queue.put_nowait(
None)
611 """Start a UDP server on a random free port."""
612 sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
613 sock.setblocking(
False)
619 ) = await asyncio.get_running_loop().create_datagram_endpoint(
620 partial(VoiceAssistantUDPServer, self._audio_queue), sock=sock
623 assert isinstance(protocol, VoiceAssistantUDPServer)
627 return cast(int, sock.getsockname()[1])
630 """Stop the UDP server if it's running."""
639 _LOGGER.debug(
"Stopped UDP server")
643 """Receive UDP packets and forward them to the audio queue."""
645 transport: asyncio.DatagramTransport |
None =
None
646 remote_addr: tuple[str, int] |
None =
None
649 self, audio_queue: asyncio.Queue[bytes |
None], *args: Any, **kwargs: Any
651 """Initialize protocol."""
656 """Store transport for later use."""
657 self.
transporttransport = cast(asyncio.DatagramTransport, transport)
660 """Handle incoming UDP packet."""
667 """Handle when a send or receive operation raises an OSError.
669 (Other than BlockingIOError or InterruptedError.)
671 _LOGGER.error(
"ESPHome Voice Assistant UDP server error received: %s", exc)
677 """Close the receiver."""
684 """Send bytes to the device via UDP."""
686 _LOGGER.error(
"No transport to send audio to")
690 _LOGGER.error(
"No address to send audio to")
None async_set_configuration(self, AssistSatelliteConfiguration config)
None async_accept_pipeline_from_satellite(self, AsyncIterable[bytes] audio_stream, PipelineStage start_stage=PipelineStage.STT, PipelineStage end_stage=PipelineStage.TTS, str|None wake_word_phrase=None)
None tts_response_finished(self)
None handle_pipeline_stop(self, bool abort)
int _start_udp_server(self)
str|None pipeline_entity_id(self)
None _update_tts_format(self)
None async_set_wake_word(self, str wake_word_id)
None on_pipeline_event(self, PipelineEvent event)
None handle_announcement_finished(self, VoiceAssistantAnnounceFinished announce_finished)
assist_satellite.AssistSatelliteConfiguration async_get_configuration(self)
None _stop_pipeline(self)
None __init__(self, ConfigEntry config_entry, RuntimeEntryData entry_data)
str|None vad_sensitivity_entity_id(self)
None async_will_remove_from_hass(self)
None async_announce(self, assist_satellite.AssistSatelliteAnnouncement announcement)
None handle_pipeline_finished(self)
AsyncIterable[bytes] _wrap_audio_stream(self)
None _update_satellite_config(self)
None async_set_configuration(self, assist_satellite.AssistSatelliteConfiguration config)
int|None handle_pipeline_start(self, str conversation_id, int flags, VoiceAssistantAudioSettings audio_settings, str|None wake_word_phrase)
None _stop_udp_server(self)
None _stream_tts_audio(self, str media_id, int sample_rate=16000, int sample_width=2, int sample_channels=1, int samples_per_chunk=512)
None handle_timer_event(self, TimerEventType event_type, TimerInfo timer_info)
None _abort_pipeline(self)
None handle_audio(self, bytes data)
None async_added_to_hass(self)
None __init__(self, asyncio.Queue[bytes|None] audio_queue, *Any args, **Any kwargs)
None send_audio_bytes(self, bytes data)
None error_received(self, Exception exc)
None connection_made(self, asyncio.BaseTransport transport)
None datagram_received(self, bytes data, tuple[str, int] addr)
None async_on_remove(self, CALLBACK_TYPE func)
web.Response get(self, web.Request request, str config_key)
None async_setup_entry(HomeAssistant hass, ESPHomeConfigEntry entry, AddEntitiesCallback async_add_entities)
str async_create_proxy_url(HomeAssistant hass, str device_id, str media_url, str media_format, int|None rate=None, int|None channels=None, int|None width=None)
Callable[[], None] async_register_timer_handler(HomeAssistant hass, str device_id, TimerHandler handler)