1 """Assist satellite entity for VoIP integration."""
3 from __future__
import annotations
6 from enum
import IntFlag
7 from functools
import partial
10 from pathlib
import Path
11 from typing
import TYPE_CHECKING, Any, Final
14 from voip_utils
import RtpDatagramProtocol
19 AssistSatelliteConfiguration,
20 AssistSatelliteEntity,
21 AssistSatelliteEntityDescription,
27 from .const
import CHANNELS, DOMAIN, RATE, RTP_AUDIO_SETTINGS, WIDTH
28 from .devices
import VoIPDevice
29 from .entity
import VoIPEntity
32 from .
import DomainData
34 _LOGGER = logging.getLogger(__name__)
36 _PIPELINE_TIMEOUT_SEC: Final = 30
40 """Feedback tones for specific events."""
47 _TONE_FILENAMES: dict[Tones, str] = {
48 Tones.LISTENING:
"tone.pcm",
49 Tones.PROCESSING:
"processing.pcm",
50 Tones.ERROR:
"error.pcm",
56 config_entry: ConfigEntry,
57 async_add_entities: AddEntitiesCallback,
59 """Set up VoIP Assist satellite entity."""
60 domain_data: DomainData = hass.data[DOMAIN]
63 def async_add_device(device: VoIPDevice) ->
None:
67 domain_data.devices.async_add_new_device_listener(async_add_device)
69 entities: list[VoIPEntity] = [
71 for device
in domain_data.devices
78 """Assist satellite for VoIP devices."""
81 _attr_translation_key =
"assist_satellite"
87 voip_device: VoIPDevice,
88 config_entry: ConfigEntry,
89 tones=Tones.LISTENING | Tones.PROCESSING | Tones.ERROR,
91 """Initialize an Assist satellite."""
92 VoIPEntity.__init__(self, voip_device)
93 AssistSatelliteEntity.__init__(self)
94 RtpDatagramProtocol.__init__(self)
98 self._audio_queue: asyncio.Queue[bytes |
None] = asyncio.Queue()
99 self._audio_chunk_timeout: float = 2.0
103 self._tts_extra_timeout: float = 1.0
104 self._tone_bytes: dict[Tones, bytes] = {}
110 """Return the entity ID of the pipeline to use for the next conversation."""
111 return self.
voip_devicevoip_device.get_pipeline_entity_id(self.hass)
115 """Return the entity ID of the VAD sensitivity to use for the next conversation."""
116 return self.
voip_devicevoip_device.get_vad_sensitivity_entity_id(self.hass)
120 """Options passed for text-to-speech."""
122 tts.ATTR_PREFERRED_FORMAT:
"wav",
123 tts.ATTR_PREFERRED_SAMPLE_RATE: 16000,
124 tts.ATTR_PREFERRED_SAMPLE_CHANNELS: 1,
125 tts.ATTR_PREFERRED_SAMPLE_BYTES: 2,
129 """Run when entity about to be added to hass."""
134 """Run when entity will be removed from hass."""
136 assert self.
voip_devicevoip_device.protocol == self
142 ) -> AssistSatelliteConfiguration:
143 """Get the current satellite configuration."""
144 raise NotImplementedError
147 self, config: AssistSatelliteConfiguration
149 """Set the current satellite configuration."""
150 raise NotImplementedError
157 """Handle raw audio chunk."""
168 self._audio_queue.put_nowait(audio_bytes)
171 _LOGGER.debug(
"Starting pipeline")
176 async
def stt_stream():
178 async
with asyncio.timeout(self._audio_chunk_timeout):
179 chunk = await self._audio_queue.
get()
186 await self.
_play_tone_play_tone(Tones.LISTENING, silence_before=0.2)
190 audio_stream=stt_stream(),
206 await self._audio_queue.put(
None)
210 _LOGGER.debug(
"Pipeline finished")
213 """Ensure audio queue is empty."""
214 while not self._audio_queue.empty():
215 self._audio_queue.get_nowait()
218 """Set state based on pipeline stage."""
219 if event.type == PipelineEventType.STT_END:
220 if (self.
_tones_tones & Tones.PROCESSING) == Tones.PROCESSING:
222 self.
config_entryconfig_entry.async_create_background_task(
223 self.hass, self.
_play_tone_play_tone(Tones.PROCESSING),
"voip_process_tone"
225 elif event.type == PipelineEventType.TTS_END:
227 if event.data
and (tts_output := event.data[
"tts_output"]):
228 media_id = tts_output[
"media_id"]
229 self.
config_entryconfig_entry.async_create_background_task(
237 elif event.type == PipelineEventType.ERROR:
240 _LOGGER.warning(event)
243 """Send TTS audio to caller via RTP."""
245 if self.transport
is None:
248 extension, data = await tts.async_get_media_source_audio(
253 if extension !=
"wav":
254 raise ValueError(f
"Only WAV audio can be streamed, got {extension}")
256 if (self.
_tones_tones & Tones.PROCESSING) == Tones.PROCESSING:
258 _LOGGER.debug(
"Waiting for processing tone")
261 with io.BytesIO(data)
as wav_io:
262 with wave.open(wav_io,
"rb")
as wav_file:
263 sample_rate = wav_file.getframerate()
264 sample_width = wav_file.getsampwidth()
265 sample_channels = wav_file.getnchannels()
268 (sample_rate != RATE)
269 or (sample_width != WIDTH)
270 or (sample_channels != CHANNELS)
273 f
"Expected rate/width/channels as {RATE}/{WIDTH}/{CHANNELS},"
274 f
" got {sample_rate}/{sample_width}/{sample_channels}"
277 audio_bytes = wav_file.readframes(wav_file.getnframes())
279 _LOGGER.debug(
"Sending %s byte(s) of audio", len(audio_bytes))
282 tts_samples = len(audio_bytes) / (WIDTH * CHANNELS)
283 tts_seconds = tts_samples / RATE
285 async
with asyncio.timeout(tts_seconds + self._tts_extra_timeout):
289 _LOGGER.warning(
"TTS timeout")
299 """Send audio in executor."""
300 await self.hass.async_add_executor_job(
301 partial(self.send_audio, audio_bytes, **RTP_AUDIO_SETTINGS, **kwargs)
304 async
def _play_tone(self, tone: Tones, silence_before: float = 0.0) ->
None:
305 """Play a tone as feedback to the user if it's enabled."""
306 if (self.
_tones_tones & tone) != tone:
309 if tone
not in self._tone_bytes:
311 self._tone_bytes[tone] = await self.hass.async_add_executor_job(
313 _TONE_FILENAMES[tone],
317 self._tone_bytes[tone],
318 silence_before=silence_before,
321 if tone == Tones.PROCESSING:
325 """Load raw audio (16Khz, 16-bit mono)."""
326 return (Path(__file__).parent / file_name).read_bytes()
None async_accept_pipeline_from_satellite(self, AsyncIterable[bytes] audio_stream, PipelineStage start_stage=PipelineStage.STT, PipelineStage end_stage=PipelineStage.TTS, str|None wake_word_phrase=None)
None tts_response_finished(self)
None async_set_configuration(self, AssistSatelliteConfiguration config)
str|None vad_sensitivity_entity_id(self)
None on_pipeline_event(self, PipelineEvent event)
None async_will_remove_from_hass(self)
None _clear_audio_queue(self)
str|None pipeline_entity_id(self)
None _play_tone(self, Tones tone, float silence_before=0.0)
None async_added_to_hass(self)
def _async_send_audio(self, bytes audio_bytes, **kwargs)
AssistSatelliteConfiguration async_get_configuration(self)
None on_chunk(self, bytes audio_bytes)
None __init__(self, HomeAssistant hass, VoIPDevice voip_device, ConfigEntry config_entry, tones=Tones.LISTENING|Tones.PROCESSING|Tones.ERROR)
bytes _load_pcm(self, str file_name)
None _send_tts(self, str media_id)
dict[str, Any]|None tts_options(self)
web.Response get(self, web.Request request, str config_key)
None async_setup_entry(HomeAssistant hass, ConfigEntry config_entry, AddEntitiesCallback async_add_entities)