1 """Assist satellite entity for Wyoming integration."""
3 from __future__
import annotations
6 from collections.abc
import AsyncGenerator
9 from typing
import Any, Final
12 from wyoming.asr
import Transcribe, Transcript
13 from wyoming.audio
import AudioChunk, AudioChunkConverter, AudioStart, AudioStop
14 from wyoming.client
import AsyncTcpClient
15 from wyoming.error
import Error
16 from wyoming.event
import Event
17 from wyoming.info
import Describe, Info
18 from wyoming.ping
import Ping, Pong
19 from wyoming.pipeline
import PipelineStage, RunPipeline
20 from wyoming.satellite
import PauseSatellite, RunSatellite
21 from wyoming.snd
import Played
22 from wyoming.timer
import TimerCancelled, TimerFinished, TimerStarted, TimerUpdated
23 from wyoming.tts
import Synthesize, SynthesizeVoice
24 from wyoming.vad
import VoiceStarted, VoiceStopped
25 from wyoming.wake
import Detect, Detection
30 AssistSatelliteConfiguration,
31 AssistSatelliteEntity,
32 AssistSatelliteEntityDescription,
38 from .const
import DOMAIN
39 from .data
import WyomingService
40 from .devices
import SatelliteDevice
41 from .entity
import WyomingSatelliteEntity
42 from .models
import DomainDataItem
44 _LOGGER = logging.getLogger(__name__)
46 _SAMPLES_PER_CHUNK: Final = 1024
47 _RECONNECT_SECONDS: Final = 10
48 _RESTART_SECONDS: Final = 3
49 _PING_TIMEOUT: Final = 5
50 _PING_SEND_DELAY: Final = 2
51 _PIPELINE_FINISH_TIMEOUT: Final = 1
55 PipelineStage.WAKE: assist_pipeline.PipelineStage.WAKE_WORD,
56 PipelineStage.ASR: assist_pipeline.PipelineStage.STT,
57 PipelineStage.HANDLE: assist_pipeline.PipelineStage.INTENT,
58 PipelineStage.TTS: assist_pipeline.PipelineStage.TTS,
64 config_entry: ConfigEntry,
65 async_add_entities: AddEntitiesCallback,
67 """Set up Wyoming Assist satellite entity."""
68 domain_data: DomainDataItem = hass.data[DOMAIN][config_entry.entry_id]
69 assert domain_data.device
is not None
74 hass, domain_data.service, domain_data.device, config_entry
81 """Assist satellite for Wyoming devices."""
84 _attr_translation_key =
"assist_satellite"
90 service: WyomingService,
91 device: SatelliteDevice,
92 config_entry: ConfigEntry,
94 """Initialize an Assist satellite."""
95 WyomingSatelliteEntity.__init__(self, device)
96 AssistSatelliteEntity.__init__(self)
104 self.
_client_client: AsyncTcpClient |
None =
None
105 self.
_chunk_converter_chunk_converter = AudioChunkConverter(rate=16000, width=2, channels=1)
108 self.
_audio_queue_audio_queue: asyncio.Queue[bytes |
None] = asyncio.Queue()
109 self._pipeline_id: str |
None =
None
121 """Return the entity ID of the pipeline to use for the next conversation."""
122 return self.
devicedevice.get_pipeline_entity_id(self.hass)
126 """Return the entity ID of the VAD sensitivity to use for the next conversation."""
127 return self.
devicedevice.get_vad_sensitivity_entity_id(self.hass)
131 """Options passed for text-to-speech."""
133 tts.ATTR_PREFERRED_FORMAT:
"wav",
134 tts.ATTR_PREFERRED_SAMPLE_RATE: 16000,
135 tts.ATTR_PREFERRED_SAMPLE_CHANNELS: 1,
136 tts.ATTR_PREFERRED_SAMPLE_BYTES: 2,
140 """Run when entity about to be added to hass."""
145 """Run when entity will be removed from hass."""
152 ) -> AssistSatelliteConfiguration:
153 """Get the current satellite configuration."""
154 raise NotImplementedError
157 self, config: AssistSatelliteConfiguration
159 """Set the current satellite configuration."""
160 raise NotImplementedError
163 """Set state based on pipeline stage."""
164 assert self.
_client_client
is not None
166 if event.type == assist_pipeline.PipelineEventType.RUN_END:
170 self.
devicedevice.set_is_active(
False)
171 elif event.type == assist_pipeline.PipelineEventType.WAKE_WORD_START:
172 self.hass.add_job(self.
_client_client.write_event(Detect().event()))
173 elif event.type == assist_pipeline.PipelineEventType.WAKE_WORD_END:
176 if event.data
and (wake_word_output := event.data.get(
"wake_word_output")):
177 detection = Detection(
178 name=wake_word_output[
"wake_word_id"],
179 timestamp=wake_word_output.get(
"timestamp"),
181 self.hass.add_job(self.
_client_client.write_event(detection.event()))
182 elif event.type == assist_pipeline.PipelineEventType.STT_START:
184 self.
devicedevice.set_is_active(
True)
188 self.
_client_client.write_event(
189 Transcribe(language=event.data[
"metadata"][
"language"]).event()
192 elif event.type == assist_pipeline.PipelineEventType.STT_VAD_START:
196 self.
_client_client.write_event(
197 VoiceStarted(timestamp=event.data[
"timestamp"]).event()
200 elif event.type == assist_pipeline.PipelineEventType.STT_VAD_END:
204 self.
_client_client.write_event(
205 VoiceStopped(timestamp=event.data[
"timestamp"]).event()
208 elif event.type == assist_pipeline.PipelineEventType.STT_END:
212 stt_text = event.data[
"stt_output"][
"text"]
214 self.
_client_client.write_event(Transcript(text=stt_text).event())
216 elif event.type == assist_pipeline.PipelineEventType.TTS_START:
221 self.
_client_client.write_event(
223 text=event.data[
"tts_input"],
224 voice=SynthesizeVoice(
225 name=event.data.get(
"voice"),
226 language=event.data.get(
"language"),
231 elif event.type == assist_pipeline.PipelineEventType.TTS_END:
233 if event.data
and (tts_output := event.data[
"tts_output"]):
234 media_id = tts_output[
"media_id"]
235 self.hass.add_job(self.
_stream_tts_stream_tts(media_id))
236 elif event.type == assist_pipeline.PipelineEventType.ERROR:
240 self.
_client_client.write_event(
242 text=event.data[
"message"], code=event.data[
"code"]
250 """Start satellite task."""
253 self.
config_entryconfig_entry.async_create_background_task(
254 self.hass, self.
runrun(),
"wyoming satellite run"
258 """Signal satellite task to stop running."""
273 async
def run(self) -> None:
274 """Run and maintain a connection to satellite."""
275 _LOGGER.debug(
"Running satellite task")
277 unregister_timer_handler = intent.async_register_timer_handler(
285 while self.
devicedevice.is_muted:
286 _LOGGER.debug(
"Satellite is muted")
294 except asyncio.CancelledError:
296 except Exception
as err:
297 _LOGGER.debug(
"%s: %s", err.__class__.__name__,
str(err))
303 self.
devicedevice.set_is_active(
False)
308 unregister_timer_handler()
311 self.
devicedevice.set_is_active(
False)
316 """Block until pipeline loop will be restarted."""
318 "Satellite has been disconnected. Reconnecting in %s second(s)",
321 await asyncio.sleep(_RESTART_SECONDS)
324 """Block until a reconnection attempt should be made."""
326 "Failed to connect to satellite. Reconnecting in %s second(s)",
329 await asyncio.sleep(_RECONNECT_SECONDS)
332 """Block until device may be unmuted again."""
336 """Run when run() has fully stopped."""
337 _LOGGER.debug(
"Satellite task stopped")
342 """Send a pause message to satellite."""
343 if self.
_client_client
is not None:
344 self.
config_entryconfig_entry.async_create_background_task(
346 self.
_client_client.write_event(PauseSatellite().event()),
351 """Run when device muted status changes."""
352 if self.
devicedevice.is_muted:
363 """Run when device pipeline changes."""
369 """Run when device audio settings."""
375 """Connect to satellite and run pipelines until an error occurs."""
380 except ConnectionError:
385 if self.
_client_client
is None:
388 _LOGGER.debug(
"Connected to satellite")
395 await self.
_client_client.write_event(RunSatellite().event())
402 """Run a pipeline one or more times."""
403 assert self.
_client_client
is not None
404 client_info: Info |
None =
None
405 wake_word_phrase: str |
None =
None
406 run_pipeline: RunPipeline |
None =
None
410 pipeline_ended_task = self.
config_entryconfig_entry.async_create_background_task(
413 client_event_task = self.
config_entryconfig_entry.async_create_background_task(
414 self.hass, self.
_client_client.read_event(),
"satellite event read"
416 pending = {pipeline_ended_task, client_event_task}
419 await self.
_client_client.write_event(Describe().event())
425 self.
config_entryconfig_entry.async_create_background_task(
429 async
with asyncio.timeout(_PING_TIMEOUT):
430 done, pending = await asyncio.wait(
431 pending, return_when=asyncio.FIRST_COMPLETED
434 if pipeline_ended_task
in done:
436 _LOGGER.debug(
"Pipeline finished")
438 pipeline_ended_task = (
439 self.
config_entryconfig_entry.async_create_background_task(
442 "satellite pipeline ended",
445 pending.add(pipeline_ended_task)
448 wake_word_phrase =
None
450 if (run_pipeline
is not None)
and run_pipeline.restart_on_end:
456 if client_event_task
not in done:
459 client_event = client_event_task.result()
460 if client_event
is None:
461 raise ConnectionResetError(
"Satellite disconnected")
463 if Pong.is_type(client_event.type):
466 elif Ping.is_type(client_event.type):
468 ping = Ping.from_event(client_event)
469 await self.
_client_client.write_event(Pong(text=ping.text).event())
470 elif RunPipeline.is_type(client_event.type):
472 run_pipeline = RunPipeline.from_event(client_event)
478 chunk = AudioChunk.from_event(client_event)
483 _LOGGER.debug(
"Client requested pipeline to stop")
485 elif Info.is_type(client_event.type):
486 client_info = Info.from_event(client_event)
487 _LOGGER.debug(
"Updated client info: %s", client_info)
488 elif Detection.is_type(client_event.type):
489 detection = Detection.from_event(client_event)
490 wake_word_phrase = detection.name
496 if (client_info
is not None)
and (client_info.wake
is not None):
498 for wake_service
in client_info.wake:
499 for wake_model
in wake_service.models:
500 if wake_model.name == detection.name:
502 wake_model.phrase
or wake_model.name
510 _LOGGER.debug(
"Client detected wake word: %s", wake_word_phrase)
511 elif Played.is_type(client_event.type):
515 _LOGGER.debug(
"Unexpected event from satellite: %s", client_event)
518 client_event_task = self.
config_entryconfig_entry.async_create_background_task(
519 self.hass, self.
_client_client.read_event(),
"satellite event read"
521 pending.add(client_event_task)
524 self, run_pipeline: RunPipeline, wake_word_phrase: str |
None =
None
526 """Run a pipeline once."""
527 _LOGGER.debug(
"Received run information: %s", run_pipeline)
529 start_stage = _STAGES.get(run_pipeline.start_stage)
530 end_stage = _STAGES.get(run_pipeline.end_stage)
532 if start_stage
is None:
533 raise ValueError(f
"Invalid start stage: {start_stage}")
535 if end_stage
is None:
536 raise ValueError(f
"Invalid end stage: {end_stage}")
543 self.
config_entryconfig_entry.async_create_background_task(
547 start_stage=start_stage,
549 wake_word_phrase=wake_word_phrase,
551 "wyoming satellite pipeline",
555 """Send ping to satellite after a delay."""
556 assert self.
_client_client
is not None
559 await asyncio.sleep(_PING_SEND_DELAY)
560 await self.
_client_client.write_event(Ping().event())
561 except ConnectionError:
565 """Connect to satellite over TCP."""
569 "Connecting to satellite at %s:%s", self.
serviceservice.host, self.
serviceservice.port
572 await self.
_client_client.connect()
575 """Disconnect if satellite is currently connected."""
576 if self.
_client_client
is None:
579 _LOGGER.debug(
"Disconnecting from satellite")
580 await self.
_client_client.disconnect()
584 """Stream TTS WAV audio to satellite in chunks."""
585 assert self.
_client_client
is not None
587 extension, data = await tts.async_get_media_source_audio(self.hass, media_id)
588 if extension !=
"wav":
589 raise ValueError(f
"Cannot stream audio format to satellite: {extension}")
591 with io.BytesIO(data)
as wav_io, wave.open(wav_io,
"rb")
as wav_file:
592 sample_rate = wav_file.getframerate()
593 sample_width = wav_file.getsampwidth()
594 sample_channels = wav_file.getnchannels()
595 _LOGGER.debug(
"Streaming %s TTS sample(s)", wav_file.getnframes())
598 await self.
_client_client.write_event(
602 channels=sample_channels,
608 while audio_bytes := wav_file.readframes(_SAMPLES_PER_CHUNK):
612 channels=sample_channels,
616 await self.
_client_client.write_event(chunk.event())
617 timestamp += chunk.seconds
619 await self.
_client_client.write_event(AudioStop(timestamp=timestamp).event())
620 _LOGGER.debug(
"TTS streaming complete")
623 """Yield audio chunks from a queue."""
624 is_first_chunk =
True
630 is_first_chunk =
False
631 _LOGGER.debug(
"Receiving audio from satellite")
639 """Forward timer events to satellite."""
640 assert self.
_client_client
is not None
642 _LOGGER.debug(
"Timer event: type=%s, info=%s", event_type, timer)
643 event: Event |
None =
None
644 if event_type == intent.TimerEventType.STARTED:
645 event = TimerStarted(
647 total_seconds=timer.seconds,
649 start_hours=timer.start_hours,
650 start_minutes=timer.start_minutes,
651 start_seconds=timer.start_seconds,
653 elif event_type == intent.TimerEventType.UPDATED:
654 event = TimerUpdated(
656 is_active=timer.is_active,
657 total_seconds=timer.seconds,
659 elif event_type == intent.TimerEventType.CANCELLED:
660 event = TimerCancelled(id=timer.id).event()
661 elif event_type == intent.TimerEventType.FINISHED:
662 event = TimerFinished(id=timer.id).event()
664 if event
is not None:
666 self.
config_entryconfig_entry.async_create_background_task(
667 self.hass, self.
_client_client.write_event(event),
"wyoming timer event"
None async_accept_pipeline_from_satellite(self, AsyncIterable[bytes] audio_stream, PipelineStage start_stage=PipelineStage.STT, PipelineStage end_stage=PipelineStage.TTS, str|None wake_word_phrase=None)
None tts_response_finished(self)
None _audio_settings_changed(self)
None _send_delayed_ping(self)
None _run_pipeline_once(self, RunPipeline run_pipeline, str|None wake_word_phrase=None)
None _handle_timer(self, intent.TimerEventType event_type, intent.TimerInfo timer)
None _run_pipeline_loop(self)
None start_satellite(self)
None async_set_configuration(self, AssistSatelliteConfiguration config)
None _stream_tts(self, str media_id)
dict[str, Any]|None tts_options(self)
None _pipeline_changed(self)
str|None pipeline_entity_id(self)
None async_added_to_hass(self)
None __init__(self, HomeAssistant hass, WyomingService service, SatelliteDevice device, ConfigEntry config_entry)
str|None vad_sensitivity_entity_id(self)
None stop_satellite(self)
None on_pipeline_event(self, PipelineEvent event)
AsyncGenerator[bytes] _stt_stream(self)
None _muted_changed(self)
AssistSatelliteConfiguration async_get_configuration(self)
None _connect_and_loop(self)
None async_will_remove_from_hass(self)
web.Response get(self, web.Request request, str config_key)
None async_setup_entry(HomeAssistant hass, ConfigEntry config_entry, AddEntitiesCallback async_add_entities)