Home Assistant Unofficial Reference 2024.12.1
assist_satellite.py
Go to the documentation of this file.
1 """Assist satellite entity for VoIP integration."""
2 
3 from __future__ import annotations
4 
5 import asyncio
6 from enum import IntFlag
7 from functools import partial
8 import io
9 import logging
10 from pathlib import Path
11 from typing import TYPE_CHECKING, Any, Final
12 import wave
13 
14 from voip_utils import RtpDatagramProtocol
15 
16 from homeassistant.components import tts
17 from homeassistant.components.assist_pipeline import PipelineEvent, PipelineEventType
19  AssistSatelliteConfiguration,
20  AssistSatelliteEntity,
21  AssistSatelliteEntityDescription,
22 )
23 from homeassistant.config_entries import ConfigEntry
24 from homeassistant.core import Context, HomeAssistant, callback
25 from homeassistant.helpers.entity_platform import AddEntitiesCallback
26 
27 from .const import CHANNELS, DOMAIN, RATE, RTP_AUDIO_SETTINGS, WIDTH
28 from .devices import VoIPDevice
29 from .entity import VoIPEntity
30 
31 if TYPE_CHECKING:
32  from . import DomainData
33 
34 _LOGGER = logging.getLogger(__name__)
35 
36 _PIPELINE_TIMEOUT_SEC: Final = 30
37 
38 
39 class Tones(IntFlag):
40  """Feedback tones for specific events."""
41 
42  LISTENING = 1
43  PROCESSING = 2
44  ERROR = 4
45 
46 
47 _TONE_FILENAMES: dict[Tones, str] = {
48  Tones.LISTENING: "tone.pcm",
49  Tones.PROCESSING: "processing.pcm",
50  Tones.ERROR: "error.pcm",
51 }
52 
53 
55  hass: HomeAssistant,
56  config_entry: ConfigEntry,
57  async_add_entities: AddEntitiesCallback,
58 ) -> None:
59  """Set up VoIP Assist satellite entity."""
60  domain_data: DomainData = hass.data[DOMAIN]
61 
62  @callback
63  def async_add_device(device: VoIPDevice) -> None:
64  """Add device."""
65  async_add_entities([VoipAssistSatellite(hass, device, config_entry)])
66 
67  domain_data.devices.async_add_new_device_listener(async_add_device)
68 
69  entities: list[VoIPEntity] = [
70  VoipAssistSatellite(hass, device, config_entry)
71  for device in domain_data.devices
72  ]
73 
74  async_add_entities(entities)
75 
76 
78  """Assist satellite for VoIP devices."""
79 
80  entity_description = AssistSatelliteEntityDescription(key="assist_satellite")
81  _attr_translation_key = "assist_satellite"
82  _attr_name = None
83 
84  def __init__(
85  self,
86  hass: HomeAssistant,
87  voip_device: VoIPDevice,
88  config_entry: ConfigEntry,
89  tones=Tones.LISTENING | Tones.PROCESSING | Tones.ERROR,
90  ) -> None:
91  """Initialize an Assist satellite."""
92  VoIPEntity.__init__(self, voip_device)
93  AssistSatelliteEntity.__init__(self)
94  RtpDatagramProtocol.__init__(self)
95 
96  self.config_entryconfig_entry = config_entry
97 
98  self._audio_queue: asyncio.Queue[bytes | None] = asyncio.Queue()
99  self._audio_chunk_timeout: float = 2.0
100  self._run_pipeline_task_run_pipeline_task: asyncio.Task | None = None
101  self._pipeline_had_error_pipeline_had_error: bool = False
102  self._tts_done_tts_done = asyncio.Event()
103  self._tts_extra_timeout: float = 1.0
104  self._tone_bytes: dict[Tones, bytes] = {}
105  self._tones_tones = tones
106  self._processing_tone_done_processing_tone_done = asyncio.Event()
107 
108  @property
109  def pipeline_entity_id(self) -> str | None:
110  """Return the entity ID of the pipeline to use for the next conversation."""
111  return self.voip_devicevoip_device.get_pipeline_entity_id(self.hass)
112 
113  @property
114  def vad_sensitivity_entity_id(self) -> str | None:
115  """Return the entity ID of the VAD sensitivity to use for the next conversation."""
116  return self.voip_devicevoip_device.get_vad_sensitivity_entity_id(self.hass)
117 
118  @property
119  def tts_options(self) -> dict[str, Any] | None:
120  """Options passed for text-to-speech."""
121  return {
122  tts.ATTR_PREFERRED_FORMAT: "wav",
123  tts.ATTR_PREFERRED_SAMPLE_RATE: 16000,
124  tts.ATTR_PREFERRED_SAMPLE_CHANNELS: 1,
125  tts.ATTR_PREFERRED_SAMPLE_BYTES: 2,
126  }
127 
128  async def async_added_to_hass(self) -> None:
129  """Run when entity about to be added to hass."""
130  await super().async_added_to_hass()
131  self.voip_devicevoip_device.protocol = self
132 
133  async def async_will_remove_from_hass(self) -> None:
134  """Run when entity will be removed from hass."""
135  await super().async_will_remove_from_hass()
136  assert self.voip_devicevoip_device.protocol == self
137  self.voip_devicevoip_device.protocol = None
138 
139  @callback
141  self,
142  ) -> AssistSatelliteConfiguration:
143  """Get the current satellite configuration."""
144  raise NotImplementedError
145 
147  self, config: AssistSatelliteConfiguration
148  ) -> None:
149  """Set the current satellite configuration."""
150  raise NotImplementedError
151 
152  # -------------------------------------------------------------------------
153  # VoIP
154  # -------------------------------------------------------------------------
155 
156  def on_chunk(self, audio_bytes: bytes) -> None:
157  """Handle raw audio chunk."""
158  if self._run_pipeline_task_run_pipeline_task is None:
159  # Run pipeline until voice command finishes, then start over
160  self._clear_audio_queue_clear_audio_queue()
161  self._tts_done_tts_done.clear()
162  self._run_pipeline_task_run_pipeline_task = self.config_entryconfig_entry.async_create_background_task(
163  self.hass,
164  self._run_pipeline_run_pipeline(),
165  "voip_pipeline_run",
166  )
167 
168  self._audio_queue.put_nowait(audio_bytes)
169 
170  async def _run_pipeline(self) -> None:
171  _LOGGER.debug("Starting pipeline")
172 
173  self.async_set_context(Context(user_id=self.config_entryconfig_entry.data["user"]))
174  self.voip_devicevoip_device.set_is_active(True)
175 
176  async def stt_stream():
177  while True:
178  async with asyncio.timeout(self._audio_chunk_timeout):
179  chunk = await self._audio_queue.get()
180  if not chunk:
181  break
182 
183  yield chunk
184 
185  # Play listening tone at the start of each cycle
186  await self._play_tone_play_tone(Tones.LISTENING, silence_before=0.2)
187 
188  try:
189  await self.async_accept_pipeline_from_satelliteasync_accept_pipeline_from_satellite(
190  audio_stream=stt_stream(),
191  )
192 
193  if self._pipeline_had_error_pipeline_had_error:
194  self._pipeline_had_error_pipeline_had_error = False
195  await self._play_tone_play_tone(Tones.ERROR)
196  else:
197  # Block until TTS is done speaking.
198  #
199  # This is set in _send_tts and has a timeout that's based on the
200  # length of the TTS audio.
201  await self._tts_done_tts_done.wait()
202  except TimeoutError:
203  self.disconnect() # caller hung up
204  finally:
205  # Stop audio stream
206  await self._audio_queue.put(None)
207 
208  self.voip_devicevoip_device.set_is_active(False)
209  self._run_pipeline_task_run_pipeline_task = None
210  _LOGGER.debug("Pipeline finished")
211 
212  def _clear_audio_queue(self) -> None:
213  """Ensure audio queue is empty."""
214  while not self._audio_queue.empty():
215  self._audio_queue.get_nowait()
216 
217  def on_pipeline_event(self, event: PipelineEvent) -> None:
218  """Set state based on pipeline stage."""
219  if event.type == PipelineEventType.STT_END:
220  if (self._tones_tones & Tones.PROCESSING) == Tones.PROCESSING:
221  self._processing_tone_done_processing_tone_done.clear()
222  self.config_entryconfig_entry.async_create_background_task(
223  self.hass, self._play_tone_play_tone(Tones.PROCESSING), "voip_process_tone"
224  )
225  elif event.type == PipelineEventType.TTS_END:
226  # Send TTS audio to caller over RTP
227  if event.data and (tts_output := event.data["tts_output"]):
228  media_id = tts_output["media_id"]
229  self.config_entryconfig_entry.async_create_background_task(
230  self.hass,
231  self._send_tts_send_tts(media_id),
232  "voip_pipeline_tts",
233  )
234  else:
235  # Empty TTS response
236  self._tts_done_tts_done.set()
237  elif event.type == PipelineEventType.ERROR:
238  # Play error tone instead of wait for TTS when pipeline is finished.
239  self._pipeline_had_error_pipeline_had_error = True
240  _LOGGER.warning(event)
241 
242  async def _send_tts(self, media_id: str) -> None:
243  """Send TTS audio to caller via RTP."""
244  try:
245  if self.transport is None:
246  return # not connected
247 
248  extension, data = await tts.async_get_media_source_audio(
249  self.hass,
250  media_id,
251  )
252 
253  if extension != "wav":
254  raise ValueError(f"Only WAV audio can be streamed, got {extension}")
255 
256  if (self._tones_tones & Tones.PROCESSING) == Tones.PROCESSING:
257  # Don't overlap TTS and processing beep
258  _LOGGER.debug("Waiting for processing tone")
259  await self._processing_tone_done_processing_tone_done.wait()
260 
261  with io.BytesIO(data) as wav_io:
262  with wave.open(wav_io, "rb") as wav_file:
263  sample_rate = wav_file.getframerate()
264  sample_width = wav_file.getsampwidth()
265  sample_channels = wav_file.getnchannels()
266 
267  if (
268  (sample_rate != RATE)
269  or (sample_width != WIDTH)
270  or (sample_channels != CHANNELS)
271  ):
272  raise ValueError(
273  f"Expected rate/width/channels as {RATE}/{WIDTH}/{CHANNELS},"
274  f" got {sample_rate}/{sample_width}/{sample_channels}"
275  )
276 
277  audio_bytes = wav_file.readframes(wav_file.getnframes())
278 
279  _LOGGER.debug("Sending %s byte(s) of audio", len(audio_bytes))
280 
281  # Time out 1 second after TTS audio should be finished
282  tts_samples = len(audio_bytes) / (WIDTH * CHANNELS)
283  tts_seconds = tts_samples / RATE
284 
285  async with asyncio.timeout(tts_seconds + self._tts_extra_timeout):
286  # TTS audio is 16Khz 16-bit mono
287  await self._async_send_audio_async_send_audio(audio_bytes)
288  except TimeoutError:
289  _LOGGER.warning("TTS timeout")
290  raise
291  finally:
292  # Update satellite state
293  self.tts_response_finishedtts_response_finished()
294 
295  # Signal pipeline to restart
296  self._tts_done_tts_done.set()
297 
298  async def _async_send_audio(self, audio_bytes: bytes, **kwargs):
299  """Send audio in executor."""
300  await self.hass.async_add_executor_job(
301  partial(self.send_audio, audio_bytes, **RTP_AUDIO_SETTINGS, **kwargs)
302  )
303 
304  async def _play_tone(self, tone: Tones, silence_before: float = 0.0) -> None:
305  """Play a tone as feedback to the user if it's enabled."""
306  if (self._tones_tones & tone) != tone:
307  return # not enabled
308 
309  if tone not in self._tone_bytes:
310  # Do I/O in executor
311  self._tone_bytes[tone] = await self.hass.async_add_executor_job(
312  self._load_pcm_load_pcm,
313  _TONE_FILENAMES[tone],
314  )
315 
316  await self._async_send_audio_async_send_audio(
317  self._tone_bytes[tone],
318  silence_before=silence_before,
319  )
320 
321  if tone == Tones.PROCESSING:
322  self._processing_tone_done_processing_tone_done.set()
323 
324  def _load_pcm(self, file_name: str) -> bytes:
325  """Load raw audio (16Khz, 16-bit mono)."""
326  return (Path(__file__).parent / file_name).read_bytes()
None async_accept_pipeline_from_satellite(self, AsyncIterable[bytes] audio_stream, PipelineStage start_stage=PipelineStage.STT, PipelineStage end_stage=PipelineStage.TTS, str|None wake_word_phrase=None)
Definition: entity.py:260
None async_set_configuration(self, AssistSatelliteConfiguration config)
None _play_tone(self, Tones tone, float silence_before=0.0)
None __init__(self, HomeAssistant hass, VoIPDevice voip_device, ConfigEntry config_entry, tones=Tones.LISTENING|Tones.PROCESSING|Tones.ERROR)
web.Response get(self, web.Request request, str config_key)
Definition: view.py:88
None async_setup_entry(HomeAssistant hass, ConfigEntry config_entry, AddEntitiesCallback async_add_entities)