Home Assistant Unofficial Reference 2024.12.1
websocket_api.py
Go to the documentation of this file.
1 """Assist pipeline Websocket API."""
2 
3 import asyncio
4 
5 # Suppressing disable=deprecated-module is needed for Python 3.11
6 import audioop # pylint: disable=deprecated-module
7 import base64
8 from collections.abc import AsyncGenerator, Callable
9 import contextlib
10 import logging
11 import math
12 from typing import Any, Final
13 
14 import voluptuous as vol
15 
16 from homeassistant.components import conversation, stt, tts, websocket_api
17 from homeassistant.const import ATTR_DEVICE_ID, ATTR_SECONDS, MATCH_ALL
18 from homeassistant.core import HomeAssistant, callback
19 from homeassistant.helpers import config_validation as cv, entity_registry as er
20 from homeassistant.util import language as language_util
21 
22 from .const import (
23  DEFAULT_PIPELINE_TIMEOUT,
24  DEFAULT_WAKE_WORD_TIMEOUT,
25  DOMAIN,
26  EVENT_RECORDING,
27  SAMPLE_CHANNELS,
28  SAMPLE_RATE,
29  SAMPLE_WIDTH,
30 )
31 from .error import PipelineNotFound
32 from .pipeline import (
33  AudioSettings,
34  DeviceAudioQueue,
35  PipelineData,
36  PipelineError,
37  PipelineEvent,
38  PipelineEventType,
39  PipelineInput,
40  PipelineRun,
41  PipelineStage,
42  WakeWordSettings,
43  async_get_pipeline,
44 )
45 
46 _LOGGER = logging.getLogger(__name__)
47 
48 # Format of audio forwarded by assist_pipeline/device/capture
49 # (see the "audio" events sent by websocket_device_capture below).
50 CAPTURE_RATE: Final = 16000  # hertz
51 CAPTURE_WIDTH: Final = 2  # bytes per sample
52 CAPTURE_CHANNELS: Final = 1  # mono
53 # Upper bound for the client-requested capture duration, in seconds
54 MAX_CAPTURE_TIMEOUT: Final = 60.0
52 
53 
@callback
def async_register_websocket_api(hass: HomeAssistant) -> None:
    """Register the websocket API."""
    # Register every command handler in this module, in a fixed order.
    for command_handler in (
        websocket_run,
        websocket_list_languages,
        websocket_list_runs,
        websocket_list_devices,
        websocket_get_run,
        websocket_device_capture,
    ):
        websocket_api.async_register_command(hass, command_handler)
63 
64 
@websocket_api.websocket_command(
    vol.All(
        websocket_api.BASE_COMMAND_MESSAGE_SCHEMA.extend(
            {
                vol.Required("type"): "assist_pipeline/run",
                # pylint: disable-next=unnecessary-lambda
                vol.Required("start_stage"): lambda val: PipelineStage(val),
                # pylint: disable-next=unnecessary-lambda
                vol.Required("end_stage"): lambda val: PipelineStage(val),
                vol.Optional("input"): dict,
                vol.Optional("pipeline"): str,
                vol.Optional("conversation_id"): vol.Any(str, None),
                vol.Optional("device_id"): vol.Any(str, None),
                vol.Optional("timeout"): vol.Any(float, int),
            },
        ),
        # Additional "input" requirements depend on the chosen start stage
        cv.key_value_schemas(
            "start_stage",
            {
                PipelineStage.WAKE_WORD: vol.Schema(
                    {
                        vol.Required("input"): {
                            vol.Required("sample_rate"): int,
                            vol.Optional("timeout"): vol.Any(float, int),
                            vol.Optional("audio_seconds_to_buffer"): vol.Any(
                                float, int
                            ),
                            # Audio enhancement
                            vol.Optional("noise_suppression_level"): int,
                            vol.Optional("auto_gain_dbfs"): int,
                            vol.Optional("volume_multiplier"): float,
                            # Advanced use cases/testing
                            vol.Optional("no_vad"): bool,
                        }
                    },
                    extra=vol.ALLOW_EXTRA,
                ),
                PipelineStage.STT: vol.Schema(
                    {
                        vol.Required("input"): {
                            vol.Required("sample_rate"): int,
                            vol.Optional("wake_word_phrase"): str,
                        }
                    },
                    extra=vol.ALLOW_EXTRA,
                ),
                PipelineStage.INTENT: vol.Schema(
                    {vol.Required("input"): {"text": str}},
                    extra=vol.ALLOW_EXTRA,
                ),
                PipelineStage.TTS: vol.Schema(
                    {vol.Required("input"): {"text": str}},
                    extra=vol.ALLOW_EXTRA,
                ),
            },
        ),
    )
)
@websocket_api.async_response
async def websocket_run(
    hass: HomeAssistant,
    connection: websocket_api.ActiveConnection,
    msg: dict[str, Any],
) -> None:
    """Run a pipeline.

    Resolves the requested pipeline, wires up an audio stream for
    wake-word/STT start stages (audio arrives as binary websocket messages),
    then executes the pipeline, forwarding pipeline events to the client.
    """
    pipeline_id = msg.get("pipeline")
    try:
        pipeline = async_get_pipeline(hass, pipeline_id=pipeline_id)
    except PipelineNotFound:
        connection.send_error(
            msg["id"],
            "pipeline-not-found",
            f"Pipeline not found: id={pipeline_id}",
        )
        return

    timeout = msg.get("timeout", DEFAULT_PIPELINE_TIMEOUT)
    start_stage = PipelineStage(msg["start_stage"])
    end_stage = PipelineStage(msg["end_stage"])
    handler_id: int | None = None
    unregister_handler: Callable[[], None] | None = None
    wake_word_settings: WakeWordSettings | None = None
    audio_settings: AudioSettings | None = None

    # Arguments to PipelineInput
    input_args: dict[str, Any] = {
        "conversation_id": msg.get("conversation_id"),
        "device_id": msg.get("device_id"),
    }

    if start_stage in (PipelineStage.WAKE_WORD, PipelineStage.STT):
        # Audio pipeline that will receive audio as binary websocket messages
        msg_input = msg["input"]
        audio_queue: asyncio.Queue[bytes] = asyncio.Queue()
        incoming_sample_rate = msg_input["sample_rate"]
        wake_word_phrase: str | None = None

        if start_stage == PipelineStage.WAKE_WORD:
            wake_word_settings = WakeWordSettings(
                timeout=msg["input"].get("timeout", DEFAULT_WAKE_WORD_TIMEOUT),
                audio_seconds_to_buffer=msg_input.get("audio_seconds_to_buffer", 0),
            )
        elif start_stage == PipelineStage.STT:
            wake_word_phrase = msg["input"].get("wake_word_phrase")

        async def stt_stream() -> AsyncGenerator[bytes]:
            """Yield queued audio chunks, resampled to the pipeline rate."""
            state = None

            # Yield until we receive an empty chunk
            while chunk := await audio_queue.get():
                if incoming_sample_rate != SAMPLE_RATE:
                    chunk, state = audioop.ratecv(
                        chunk,
                        SAMPLE_WIDTH,
                        SAMPLE_CHANNELS,
                        incoming_sample_rate,
                        SAMPLE_RATE,
                        state,
                    )
                yield chunk

        def handle_binary(
            _hass: HomeAssistant,
            _connection: websocket_api.ActiveConnection,
            data: bytes,
        ) -> None:
            # Forward to STT audio stream
            audio_queue.put_nowait(data)

        handler_id, unregister_handler = connection.async_register_binary_handler(
            handle_binary
        )

        # Audio input must be raw PCM at 16Khz with 16-bit mono samples
        input_args["stt_metadata"] = stt.SpeechMetadata(
            language=pipeline.stt_language or pipeline.language,
            format=stt.AudioFormats.WAV,
            codec=stt.AudioCodecs.PCM,
            bit_rate=stt.AudioBitRates.BITRATE_16,
            sample_rate=stt.AudioSampleRates.SAMPLERATE_16000,
            channel=stt.AudioChannels.CHANNEL_MONO,
        )
        input_args["stt_stream"] = stt_stream()
        input_args["wake_word_phrase"] = wake_word_phrase

        # Audio settings
        audio_settings = AudioSettings(
            noise_suppression_level=msg_input.get("noise_suppression_level", 0),
            auto_gain_dbfs=msg_input.get("auto_gain_dbfs", 0),
            volume_multiplier=msg_input.get("volume_multiplier", 1.0),
            is_vad_enabled=not msg_input.get("no_vad", False),
        )
    elif start_stage == PipelineStage.INTENT:
        # Input to conversation agent
        input_args["intent_input"] = msg["input"]["text"]
    elif start_stage == PipelineStage.TTS:
        # Input to text-to-speech system
        input_args["tts_input"] = msg["input"]["text"]

    input_args["run"] = PipelineRun(
        hass,
        context=connection.context(msg),
        pipeline=pipeline,
        start_stage=start_stage,
        end_stage=end_stage,
        event_callback=lambda event: connection.send_event(msg["id"], event),
        runner_data={
            "stt_binary_handler_id": handler_id,
            "timeout": timeout,
        },
        wake_word_settings=wake_word_settings,
        audio_settings=audio_settings or AudioSettings(),
    )

    pipeline_input = PipelineInput(**input_args)

    try:
        await pipeline_input.validate()
    except PipelineError as error:
        # Report more specific error when possible
        connection.send_error(msg["id"], error.code, error.message)
        return

    # Confirm subscription
    connection.send_result(msg["id"])

    run_task = hass.async_create_task(pipeline_input.execute())

    # Cancel pipeline if user unsubscribes
    connection.subscriptions[msg["id"]] = run_task.cancel

    try:
        # Task contains a timeout
        async with asyncio.timeout(timeout):
            await run_task
    except TimeoutError:
        # FIX: process_event expects a PipelineEvent; the wrapping
        # PipelineEvent(...) constructor call was missing, leaving an
        # unbalanced/incorrect call.
        pipeline_input.run.process_event(
            PipelineEvent(
                PipelineEventType.ERROR,
                {"code": "timeout", "message": "Timeout running pipeline"},
            )
        )
    finally:
        if unregister_handler is not None:
            # Unregister binary handler
            unregister_handler()
267 
268 
@callback
@websocket_api.require_admin
@websocket_api.websocket_command(
    {
        vol.Required("type"): "assist_pipeline/pipeline_debug/list",
        vol.Required("pipeline_id"): str,
    }
)
def websocket_list_runs(
    hass: HomeAssistant,
    connection: websocket_api.connection.ActiveConnection,
    msg: dict[str, Any],
) -> None:
    """List pipeline runs for which debug data is available.

    Sends a result with one entry per stored run (id + timestamp); an
    unknown pipeline_id yields an empty list rather than an error.
    """
    pipeline_data: PipelineData = hass.data[DOMAIN]
    pipeline_id = msg["pipeline_id"]

    if pipeline_id not in pipeline_data.pipeline_debug:
        connection.send_result(msg["id"], {"pipeline_runs": []})
        return

    pipeline_debug = pipeline_data.pipeline_debug[pipeline_id]

    connection.send_result(
        msg["id"],
        {
            "pipeline_runs": [
                {
                    "pipeline_run_id": pipeline_run_id,
                    "timestamp": pipeline_run.timestamp,
                }
                for pipeline_run_id, pipeline_run in pipeline_debug.items()
            ]
        },
    )
302 
303 
@callback
@websocket_api.require_admin
@websocket_api.websocket_command(
    {
        vol.Required("type"): "assist_pipeline/device/list",
    }
)
def websocket_list_devices(
    hass: HomeAssistant,
    connection: websocket_api.connection.ActiveConnection,
    msg: dict[str, Any],
) -> None:
    """List assist devices.

    For each registered assist device, includes the entity id of its
    pipeline-selection select entity (looked up via the entity registry).
    """
    pipeline_data: PipelineData = hass.data[DOMAIN]
    ent_reg = er.async_get(hass)
    connection.send_result(
        msg["id"],
        [
            {
                "device_id": device_id,
                "pipeline_entity": ent_reg.async_get_entity_id(
                    "select", info.domain, f"{info.unique_id_prefix}-pipeline"
                ),
            }
            for device_id, info in pipeline_data.pipeline_devices.items()
        ],
    )
329 
330 
@callback
@websocket_api.require_admin
@websocket_api.websocket_command(
    {
        vol.Required("type"): "assist_pipeline/pipeline_debug/get",
        vol.Required("pipeline_id"): str,
        vol.Required("pipeline_run_id"): str,
    }
)
def websocket_get_run(
    hass: HomeAssistant,
    connection: websocket_api.connection.ActiveConnection,
    msg: dict[str, Any],
) -> None:
    """Get debug data for a pipeline run.

    Sends ERR_NOT_FOUND if either the pipeline or the run id is unknown;
    otherwise sends the stored event list for that run.
    """
    pipeline_data: PipelineData = hass.data[DOMAIN]
    pipeline_id = msg["pipeline_id"]
    pipeline_run_id = msg["pipeline_run_id"]

    if pipeline_id not in pipeline_data.pipeline_debug:
        connection.send_error(
            msg["id"],
            websocket_api.ERR_NOT_FOUND,
            f"pipeline_id {pipeline_id} not found",
        )
        return

    pipeline_debug = pipeline_data.pipeline_debug[pipeline_id]

    if pipeline_run_id not in pipeline_debug:
        connection.send_error(
            msg["id"],
            websocket_api.ERR_NOT_FOUND,
            f"pipeline_run_id {pipeline_run_id} not found",
        )
        return

    connection.send_result(
        msg["id"],
        {"events": pipeline_debug[pipeline_run_id].events},
    )
370 
371 
@websocket_api.websocket_command(
    {
        vol.Required("type"): "assist_pipeline/language/list",
    }
)
@callback
def websocket_list_languages(
    hass: HomeAssistant,
    connection: websocket_api.connection.ActiveConnection,
    msg: dict[str, Any],
) -> None:
    """List languages which are supported by a complete pipeline.

    This will return a list of languages which are supported by at least one stt, tts
    and conversation engine respectively.
    """
    conv_language_tags = conversation.async_get_conversation_languages(hass)
    stt_language_tags = stt.async_get_speech_to_text_languages(hass)
    tts_language_tags = tts.async_get_text_to_speech_languages(hass)
    # None means "no constraint yet"; each engine type narrows the set.
    pipeline_languages: set[str] | None = None

    # Conversation engines may support MATCH_ALL, which imposes no constraint
    if conv_language_tags and conv_language_tags != MATCH_ALL:
        languages = set()
        for language_tag in conv_language_tags:
            dialect = language_util.Dialect.parse(language_tag)
            languages.add(dialect.language)
        pipeline_languages = languages

    if stt_language_tags:
        languages = set()
        for language_tag in stt_language_tags:
            dialect = language_util.Dialect.parse(language_tag)
            languages.add(dialect.language)
        if pipeline_languages is not None:
            pipeline_languages = language_util.intersect(pipeline_languages, languages)
        else:
            pipeline_languages = languages

    if tts_language_tags:
        languages = set()
        for language_tag in tts_language_tags:
            dialect = language_util.Dialect.parse(language_tag)
            languages.add(dialect.language)
        if pipeline_languages is not None:
            pipeline_languages = language_util.intersect(pipeline_languages, languages)
        else:
            pipeline_languages = languages

    connection.send_result(
        msg["id"],
        {
            "languages": (
                sorted(pipeline_languages) if pipeline_languages else pipeline_languages
            )
        },
    )
426 
427 
@websocket_api.require_admin
@websocket_api.websocket_command(
    {
        vol.Required("type"): "assist_pipeline/device/capture",
        vol.Required("device_id"): str,
        vol.Required("timeout"): vol.All(
            # 0 < timeout <= MAX_CAPTURE_TIMEOUT
            vol.Coerce(float),
            vol.Range(min=0, min_included=False, max=MAX_CAPTURE_TIMEOUT),
        ),
    }
)
@websocket_api.async_response
async def websocket_device_capture(
    hass: HomeAssistant,
    connection: websocket_api.connection.ActiveConnection,
    msg: dict[str, Any],
) -> None:
    """Capture raw audio from a satellite device and forward to client.

    Audio chunks are sent to the client as base64-encoded "audio" events,
    followed by an "end" event when the capture finishes or times out.
    Only one client may capture from a device at a time; a new capture
    displaces any existing one.
    """
    pipeline_data: PipelineData = hass.data[DOMAIN]
    device_id = msg["device_id"]

    # Number of seconds to record audio in wall clock time
    timeout_seconds = msg["timeout"]

    # We don't know the chunk size, so the upper bound is calculated assuming a
    # single sample (16 bits) per queue item.
    max_queue_items = (
        # +1 for None to signal end
        int(math.ceil(timeout_seconds * CAPTURE_RATE)) + 1
    )

    audio_queue = DeviceAudioQueue(queue=asyncio.Queue(maxsize=max_queue_items))

    # Running simultaneous captures for a single device will not work by design.
    # The new capture will cause the old capture to stop.
    if (
        old_audio_queue := pipeline_data.device_audio_queues.pop(device_id, None)
    ) is not None:
        with contextlib.suppress(asyncio.QueueFull):
            # Signal other websocket command that we're taking over
            old_audio_queue.queue.put_nowait(None)

    # Only one client can be capturing audio at a time
    pipeline_data.device_audio_queues[device_id] = audio_queue

    def clean_up_queue() -> None:
        # Clean up our audio queue
        maybe_audio_queue = pipeline_data.device_audio_queues.get(device_id)
        if (maybe_audio_queue is not None) and (maybe_audio_queue.id == audio_queue.id):
            # Only pop if this is our queue
            pipeline_data.device_audio_queues.pop(device_id)

    # Unsubscribe cleans up queue
    connection.subscriptions[msg["id"]] = clean_up_queue

    # Audio will follow as events
    connection.send_result(msg["id"])

    # Record to logbook
    hass.bus.async_fire(
        EVENT_RECORDING,
        {
            ATTR_DEVICE_ID: device_id,
            ATTR_SECONDS: timeout_seconds,
        },
    )

    try:
        with contextlib.suppress(TimeoutError):
            async with asyncio.timeout(timeout_seconds):
                while True:
                    # Send audio chunks encoded as base64
                    audio_bytes = await audio_queue.queue.get()
                    if audio_bytes is None:
                        # Signal to stop
                        break

                    connection.send_event(
                        msg["id"],
                        {
                            "type": "audio",
                            "rate": CAPTURE_RATE,  # hertz
                            "width": CAPTURE_WIDTH,  # bytes
                            "channels": CAPTURE_CHANNELS,
                            "audio": base64.b64encode(audio_bytes).decode("ascii"),
                        },
                    )

        # Capture has ended
        connection.send_event(
            msg["id"], {"type": "end", "overflow": audio_queue.overflow}
        )
    finally:
        clean_up_queue()
521 
Pipeline async_get_pipeline(HomeAssistant hass, str|None pipeline_id=None)
Definition: pipeline.py:282
None websocket_run(HomeAssistant hass, websocket_api.ActiveConnection connection, dict[str, Any] msg)
None websocket_list_languages(HomeAssistant hass, websocket_api.connection.ActiveConnection connection, dict[str, Any] msg)
None websocket_device_capture(HomeAssistant hass, websocket_api.connection.ActiveConnection connection, dict[str, Any] msg)
None websocket_get_run(HomeAssistant hass, websocket_api.connection.ActiveConnection connection, dict[str, Any] msg)
None websocket_list_runs(HomeAssistant hass, websocket_api.connection.ActiveConnection connection, dict[str, Any] msg)
None websocket_list_devices(HomeAssistant hass, websocket_api.connection.ActiveConnection connection, dict[str, Any] msg)
web.Response get(self, web.Request request, str config_key)
Definition: view.py:88