1 """Assist pipeline Websocket API."""
8 from collections.abc
import AsyncGenerator, Callable
12 from typing
import Any, Final
14 import voluptuous
as vol
23 DEFAULT_PIPELINE_TIMEOUT,
24 DEFAULT_WAKE_WORD_TIMEOUT,
31 from .error
import PipelineNotFound
32 from .pipeline
import (
# Module-level logger for this websocket API module.
46 _LOGGER = logging.getLogger(__name__)
# Audio format of raw device capture: 16 kHz sample rate, 16-bit (2-byte)
# samples, mono — matches the "rate"/"width"/"channels" fields sent to the
# client in websocket_device_capture's "audio" events below.
48 CAPTURE_RATE: Final = 16000
49 CAPTURE_WIDTH: Final = 2
50 CAPTURE_CHANNELS: Final = 1
# Upper bound (seconds) on the "timeout" accepted by
# assist_pipeline/device/capture (see the vol.Range in its command schema).
51 MAX_CAPTURE_TIMEOUT: Final = 60.0
# NOTE(review): the enclosing `def` line (orig. line 55) is missing from this
# extraction; per the trailing signature index this is
# async_register_websocket_api(hass: HomeAssistant) -> None.
# It registers every websocket command handler defined in this module.
56 """Register the websocket API."""
57 websocket_api.async_register_command(hass, websocket_run)
58 websocket_api.async_register_command(hass, websocket_list_languages)
59 websocket_api.async_register_command(hass, websocket_list_runs)
60 websocket_api.async_register_command(hass, websocket_list_devices)
61 websocket_api.async_register_command(hass, websocket_get_run)
62 websocket_api.async_register_command(hass, websocket_device_capture)
# NOTE(review): this span is the "assist_pipeline/run" command — its schema
# decorator plus the websocket_run handler body. The extraction is heavily
# corrupted: original line numbers are fused into the text, statements are
# wrapped mid-expression, and many interior lines (including the `def` line
# and several try/except frames) are elided. Bytes below are left untouched;
# only comments are added. Restore from upstream before editing logic.
#
# Command schema: base fields, then per-start-stage extensions (validated
# against "start_stage" — presumably via a key-function dispatch on
# PipelineStage; the selector lines are elided — TODO confirm upstream).
65 @websocket_api.websocket_command(
vol.All(
websocket_api.BASE_COMMAND_MESSAGE_SCHEMA.extend(
{
vol.Required("type"):
"assist_pipeline/run",
70 vol.Optional(
"input"): dict,
71 vol.Optional(
"pipeline"): str,
72 vol.Optional(
"conversation_id"): vol.Any(str,
None),
73 vol.Optional(
"device_id"): vol.Any(str,
None),
74 vol.Optional(
"timeout"): vol.Any(float, int),
# Extra schema when starting at the wake-word stage: raw audio input with
# optional wake-word timeout, buffering, and audio-enhancement settings.
80 PipelineStage.WAKE_WORD: vol.Schema(
82 vol.Required(
"input"): {
83 vol.Required(
"sample_rate"): int,
84 vol.Optional(
"timeout"): vol.Any(float, int),
85 vol.Optional(
"audio_seconds_to_buffer"): vol.Any(
89 vol.Optional(
"noise_suppression_level"): int,
90 vol.Optional(
"auto_gain_dbfs"): int,
91 vol.Optional(
"volume_multiplier"): float,
93 vol.Optional(
"no_vad"): bool,
96 extra=vol.ALLOW_EXTRA,
# Extra schema when starting at speech-to-text: audio sample rate plus the
# wake-word phrase that (presumably) triggered this run — TODO confirm.
98 PipelineStage.STT: vol.Schema(
100 vol.Required(
"input"): {
101 vol.Required(
"sample_rate"): int,
102 vol.Optional(
"wake_word_phrase"): str,
105 extra=vol.ALLOW_EXTRA,
# Intent and TTS start stages take plain text input.
107 PipelineStage.INTENT: vol.Schema(
108 {vol.Required(
"input"): {
"text": str}},
109 extra=vol.ALLOW_EXTRA,
111 PipelineStage.TTS: vol.Schema(
112 {vol.Required(
"input"): {
"text": str}},
113 extra=vol.ALLOW_EXTRA,
# Handler body. NOTE(review): `async def websocket_run(...)` (orig. ~120-124)
# is elided; signature per the trailing index:
# websocket_run(hass, connection, msg) -> None.
119 @websocket_api.async_response
125 """Run a pipeline."""
# Resolve the requested pipeline; a missing "pipeline" key presumably selects
# the default pipeline inside the elided lookup (orig. 127-128) — TODO confirm.
126 pipeline_id = msg.get(
"pipeline")
129 except PipelineNotFound:
130 connection.send_error(
132 "pipeline-not-found",
133 f
"Pipeline not found: id={pipeline_id}",
137 timeout = msg.get(
"timeout", DEFAULT_PIPELINE_TIMEOUT)
# State that is only populated for audio start stages (wake word / STT).
140 handler_id: int |
None =
None
141 unregister_handler: Callable[[],
None] |
None =
None
142 wake_word_settings: WakeWordSettings |
None =
None
143 audio_settings: AudioSettings |
None =
None
146 input_args: dict[str, Any] = {
147 "conversation_id": msg.get(
"conversation_id"),
148 "device_id": msg.get(
"device_id"),
# Audio start stages: set up an incoming-audio queue and binary handler.
151 if start_stage
in (PipelineStage.WAKE_WORD, PipelineStage.STT):
153 msg_input = msg[
"input"]
154 audio_queue: asyncio.Queue[bytes] = asyncio.Queue()
155 incoming_sample_rate = msg_input[
"sample_rate"]
156 wake_word_phrase: str |
None =
None
158 if start_stage == PipelineStage.WAKE_WORD:
160 timeout=msg[
"input"].
get(
"timeout", DEFAULT_WAKE_WORD_TIMEOUT),
161 audio_seconds_to_buffer=msg_input.get(
"audio_seconds_to_buffer", 0),
163 elif start_stage == PipelineStage.STT:
164 wake_word_phrase = msg[
"input"].
get(
"wake_word_phrase")
# Generator feeding queued audio chunks to the pipeline, resampling to
# SAMPLE_RATE via audioop.ratecv when the client's rate differs. A falsy
# chunk (empty bytes) ends the stream (walrus condition below).
166 async
def stt_stream() -> AsyncGenerator[bytes]:
170 while chunk := await audio_queue.get():
171 if incoming_sample_rate != SAMPLE_RATE:
172 chunk, state = audioop.ratecv(
176 incoming_sample_rate,
# Binary-message callback (body mostly elided): pushes received audio
# into audio_queue for stt_stream to consume.
183 _hass: HomeAssistant,
188 audio_queue.put_nowait(data)
190 handler_id, unregister_handler = connection.async_register_binary_handler(
# STT metadata: 16-bit PCM WAV, 16 kHz mono, language from the pipeline.
196 language=pipeline.stt_language
or pipeline.language,
197 format=stt.AudioFormats.WAV,
198 codec=stt.AudioCodecs.PCM,
199 bit_rate=stt.AudioBitRates.BITRATE_16,
200 sample_rate=stt.AudioSampleRates.SAMPLERATE_16000,
201 channel=stt.AudioChannels.CHANNEL_MONO,
203 input_args[
"stt_stream"] = stt_stream()
204 input_args[
"wake_word_phrase"] = wake_word_phrase
# Audio enhancement defaults: no noise suppression, no gain change,
# unity volume, VAD enabled unless "no_vad" is set.
208 noise_suppression_level=msg_input.get(
"noise_suppression_level", 0),
209 auto_gain_dbfs=msg_input.get(
"auto_gain_dbfs", 0),
210 volume_multiplier=msg_input.get(
"volume_multiplier", 1.0),
211 is_vad_enabled=
not msg_input.get(
"no_vad",
False),
# Text start stages: pass the "text" input straight through.
213 elif start_stage == PipelineStage.INTENT:
215 input_args[
"intent_input"] = msg[
"input"][
"text"]
216 elif start_stage == PipelineStage.TTS:
218 input_args[
"tts_input"] = msg[
"input"][
"text"]
# Assemble the PipelineInput / PipelineRun (constructor lines elided);
# events are forwarded to the client on this message id.
222 context=connection.context(msg),
224 start_stage=start_stage,
226 event_callback=
lambda event: connection.send_event(msg[
"id"], event),
228 "stt_binary_handler_id": handler_id,
231 wake_word_settings=wake_word_settings,
238 await pipeline_input.validate()
239 except PipelineError
as error:
241 connection.send_error(msg[
"id"], error.code, error.message)
# Ack the command, start execution, and let unsubscribe cancel the run.
245 connection.send_result(msg[
"id"])
247 run_task = hass.async_create_task(pipeline_input.execute())
250 connection.subscriptions[msg[
"id"]] = run_task.cancel
# Await completion under the overall pipeline timeout; on timeout an
# ERROR event with code "timeout" is emitted (TimeoutError frame elided).
254 async
with asyncio.timeout(timeout):
257 pipeline_input.run.process_event(
259 PipelineEventType.ERROR,
260 {
"code":
"timeout",
"message":
"Timeout running pipeline"},
# Cleanup (presumably a finally: frame — elided): drop the binary handler.
264 if unregister_handler
is not None:
# "assist_pipeline/pipeline_debug/list" (admin only): list the pipeline run
# ids (with timestamps) for which debug data is retained for a pipeline.
# NOTE(review): the `def` line and @callback/@async_response decorator are
# elided; signature per the trailing index:
# websocket_list_runs(hass, connection, msg) -> None.
270 @websocket_api.require_admin
271 @websocket_api.websocket_command(
{
vol.Required("type"):
"assist_pipeline/pipeline_debug/list",
272 vol.Required(
"pipeline_id"): str,
280 """List pipeline runs for which debug data is available."""
281 pipeline_data: PipelineData = hass.data[DOMAIN]
282 pipeline_id = msg[
"pipeline_id"]
# Unknown pipeline id: succeed with an empty run list rather than error.
284 if pipeline_id
not in pipeline_data.pipeline_debug:
285 connection.send_result(msg[
"id"], {
"pipeline_runs": []})
288 pipeline_debug = pipeline_data.pipeline_debug[pipeline_id]
# One entry per stored run: its id and the run's timestamp.
290 connection.send_result(
295 "pipeline_run_id": pipeline_run_id,
296 "timestamp": pipeline_run.timestamp,
298 for pipeline_run_id, pipeline_run
in pipeline_debug.items()
# "assist_pipeline/device/list" (admin only): list registered assist devices
# and, for each, the entity id of its pipeline `select` entity (looked up via
# the entity registry from the device's unique-id prefix).
# NOTE(review): the `def` line is elided; signature per the trailing index:
# websocket_list_devices(hass, connection, msg) -> None.
305 @websocket_api.require_admin
306 @websocket_api.websocket_command(
{
vol.Required("type"):
"assist_pipeline/device/list",
314 """List assist devices."""
315 pipeline_data: PipelineData = hass.data[DOMAIN]
316 ent_reg = er.async_get(hass)
317 connection.send_result(
321 "device_id": device_id,
322 "pipeline_entity": ent_reg.async_get_entity_id(
323 "select", info.domain, f
"{info.unique_id_prefix}-pipeline"
326 for device_id, info
in pipeline_data.pipeline_devices.items()
# "assist_pipeline/pipeline_debug/get" (admin only): return the stored debug
# events for one pipeline run. Unlike the list command above, an unknown
# pipeline_id or pipeline_run_id is reported as ERR_NOT_FOUND.
# NOTE(review): the `def` line is elided; signature per the trailing index:
# websocket_get_run(hass, connection, msg) -> None.
332 @websocket_api.require_admin
333 @websocket_api.websocket_command(
{
vol.Required("type"):
"assist_pipeline/pipeline_debug/get",
334 vol.Required(
"pipeline_id"): str,
335 vol.Required(
"pipeline_run_id"): str,
343 """Get debug data for a pipeline run."""
344 pipeline_data: PipelineData = hass.data[DOMAIN]
345 pipeline_id = msg[
"pipeline_id"]
346 pipeline_run_id = msg[
"pipeline_run_id"]
348 if pipeline_id
not in pipeline_data.pipeline_debug:
349 connection.send_error(
351 websocket_api.ERR_NOT_FOUND,
352 f
"pipeline_id {pipeline_id} not found",
356 pipeline_debug = pipeline_data.pipeline_debug[pipeline_id]
358 if pipeline_run_id
not in pipeline_debug:
359 connection.send_error(
361 websocket_api.ERR_NOT_FOUND,
362 f
"pipeline_run_id {pipeline_run_id} not found",
366 connection.send_result(
368 {
"events": pipeline_debug[pipeline_run_id].events},
# "assist_pipeline/language/list": intersect the base languages supported by
# the conversation, STT, and TTS engines. Each block parses engine language
# tags into dialects, collects the base language codes into a set, and folds
# that set into pipeline_languages via language_util.intersect.
# NOTE(review): the `def` line is elided; signature per the trailing index:
# websocket_list_languages(hass, connection, msg) -> None. The `languages =
# set()` initializers inside each branch are also elided.
372 @websocket_api.websocket_command(
{
vol.Required("type"):
"assist_pipeline/language/list",
381 """List languages which are supported by a complete pipeline.
383 This will return a list of languages which are supported by at least one stt, tts
384 and conversation engine respectively.
386 conv_language_tags = conversation.async_get_conversation_languages(hass)
387 stt_language_tags = stt.async_get_speech_to_text_languages(hass)
388 tts_language_tags = tts.async_get_text_to_speech_languages(hass)
# None means "no constraint yet"; MATCH_ALL conversation engines add none.
389 pipeline_languages: set[str] |
None =
None
391 if conv_language_tags
and conv_language_tags != MATCH_ALL:
393 for language_tag
in conv_language_tags:
394 dialect = language_util.Dialect.parse(language_tag)
395 languages.add(dialect.language)
396 pipeline_languages = languages
398 if stt_language_tags:
400 for language_tag
in stt_language_tags:
401 dialect = language_util.Dialect.parse(language_tag)
402 languages.add(dialect.language)
403 if pipeline_languages
is not None:
404 pipeline_languages = language_util.intersect(pipeline_languages, languages)
406 pipeline_languages = languages
408 if tts_language_tags:
410 for language_tag
in tts_language_tags:
411 dialect = language_util.Dialect.parse(language_tag)
412 languages.add(dialect.language)
413 if pipeline_languages
is not None:
414 pipeline_languages = language_util.intersect(pipeline_languages, languages)
416 pipeline_languages = languages
# Sort for a stable response; an empty/None set is passed through as-is.
418 connection.send_result(
422 sorted(pipeline_languages)
if pipeline_languages
else pipeline_languages
# "assist_pipeline/device/capture" (admin only): stream raw audio captured
# from a satellite device back to this websocket client as base64 "audio"
# events, ending with an "end" event carrying the queue-overflow flag.
# NOTE(review): the `def` line is elided; signature per the trailing index:
# websocket_device_capture(hass, connection, msg) -> None. Several interior
# lines (the max_queue_items comment frame, the event-loop forwarding) are
# also elided below.
428 @websocket_api.require_admin
429 @websocket_api.websocket_command(
{
vol.Required("type"):
"assist_pipeline/device/capture",
430 vol.Required(
"device_id"): str,
431 vol.Required(
"timeout"): vol.All(
# Timeout must be > 0 and at most MAX_CAPTURE_TIMEOUT seconds.
434 vol.Range(min=0, min_included=
False, max=MAX_CAPTURE_TIMEOUT),
438 @websocket_api.async_response
444 """Capture raw audio from a satellite device and forward to client."""
445 pipeline_data: PipelineData = hass.data[DOMAIN]
446 device_id = msg[
"device_id"]
449 timeout_seconds = msg[
"timeout"]
# Queue bound derived from timeout * sample rate (+1 slack); presumably the
# elided context divides by a per-chunk sample count — TODO confirm upstream.
455 int(math.ceil(timeout_seconds * CAPTURE_RATE)) + 1
458 audio_queue =
DeviceAudioQueue(queue=asyncio.Queue(maxsize=max_queue_items))
# Replace any existing capture for this device: pop its queue and push None
# so the previous consumer stops; QueueFull is deliberately ignored.
463 old_audio_queue := pipeline_data.device_audio_queues.pop(device_id,
None)
465 with contextlib.suppress(asyncio.QueueFull):
467 old_audio_queue.queue.put_nowait(
None)
470 pipeline_data.device_audio_queues[device_id] = audio_queue
# Unsubscribe hook: remove our queue only if it is still the registered one
# (compared by queue id, so a newer capture is not clobbered).
472 def clean_up_queue() -> None:
474 maybe_audio_queue = pipeline_data.device_audio_queues.get(device_id)
475 if (maybe_audio_queue
is not None)
and (maybe_audio_queue.id == audio_queue.id):
477 pipeline_data.device_audio_queues.pop(device_id)
480 connection.subscriptions[msg[
"id"]] = clean_up_queue
483 connection.send_result(msg[
"id"])
# Announce the capture (presumably an event/intent with device + duration).
489 ATTR_DEVICE_ID: device_id,
490 ATTR_SECONDS: timeout_seconds,
# Forward queued chunks until a None sentinel or the timeout elapses;
# TimeoutError is suppressed so the "end" event below is always sent.
495 with contextlib.suppress(TimeoutError):
496 async
with asyncio.timeout(timeout_seconds):
499 audio_bytes = await audio_queue.queue.get()
500 if audio_bytes
is None:
504 connection.send_event(
508 "rate": CAPTURE_RATE,
509 "width": CAPTURE_WIDTH,
510 "channels": CAPTURE_CHANNELS,
511 "audio": base64.b64encode(audio_bytes).decode(
"ascii"),
516 connection.send_event(
517 msg[
"id"], {
"type":
"end",
"overflow": audio_queue.overflow}
521
Pipeline async_get_pipeline(HomeAssistant hass, str|None pipeline_id=None)
None websocket_run(HomeAssistant hass, websocket_api.ActiveConnection connection, dict[str, Any] msg)
None websocket_list_languages(HomeAssistant hass, websocket_api.connection.ActiveConnection connection, dict[str, Any] msg)
None websocket_device_capture(HomeAssistant hass, websocket_api.connection.ActiveConnection connection, dict[str, Any] msg)
None websocket_get_run(HomeAssistant hass, websocket_api.connection.ActiveConnection connection, dict[str, Any] msg)
None websocket_list_runs(HomeAssistant hass, websocket_api.connection.ActiveConnection connection, dict[str, Any] msg)
None websocket_list_devices(HomeAssistant hass, websocket_api.connection.ActiveConnection connection, dict[str, Any] msg)
None async_register_websocket_api(HomeAssistant hass)
web.Response get(self, web.Request request, str config_key)