Home Assistant Unofficial Reference 2024.12.1
__init__.py
Go to the documentation of this file.
1 """Provide functionality to STT."""
2 
3 from __future__ import annotations
4 
5 from abc import abstractmethod
6 from collections.abc import AsyncIterable
7 from dataclasses import asdict
8 import logging
9 from typing import Any, final
10 
11 from aiohttp import web
12 from aiohttp.hdrs import istr
13 from aiohttp.web_exceptions import (
14  HTTPBadRequest,
15  HTTPNotFound,
16  HTTPUnsupportedMediaType,
17 )
18 import voluptuous as vol
19 
20 from homeassistant.components import websocket_api
21 from homeassistant.components.http import KEY_HASS, HomeAssistantView
22 from homeassistant.config_entries import ConfigEntry
23 from homeassistant.const import STATE_UNAVAILABLE, STATE_UNKNOWN
24 from homeassistant.core import HomeAssistant, callback
25 from homeassistant.helpers import config_validation as cv
26 from homeassistant.helpers.entity_component import EntityComponent
27 from homeassistant.helpers.restore_state import RestoreEntity
28 from homeassistant.helpers.typing import ConfigType
29 from homeassistant.loader import async_suggest_report_issue
30 from homeassistant.util import dt as dt_util, language as language_util
31 
32 from .const import (
33  DATA_COMPONENT,
34  DATA_PROVIDERS,
35  DOMAIN,
36  AudioBitRates,
37  AudioChannels,
38  AudioCodecs,
39  AudioFormats,
40  AudioSampleRates,
41  SpeechResultState,
42 )
43 from .legacy import (
44  Provider,
45  async_default_provider,
46  async_get_provider,
47  async_setup_legacy,
48 )
49 from .models import SpeechMetadata, SpeechResult
50 
# Public API of the stt integration. The module-level helpers that other
# integrations call (engine lookup, default engine, language union) are
# exported together so they stay consistent with one another.
__all__ = [
    "async_default_engine",
    "async_get_provider",
    "async_get_speech_to_text_engine",
    "async_get_speech_to_text_entity",
    "async_get_speech_to_text_languages",
    "AudioBitRates",
    "AudioChannels",
    "AudioCodecs",
    "AudioFormats",
    "AudioSampleRates",
    "DOMAIN",
    "Provider",
    "SpeechToTextEntity",
    "SpeechMetadata",
    "SpeechResult",
    "SpeechResultState",
]

_LOGGER = logging.getLogger(__name__)

# The integration declares an empty config schema for its own domain key
# (presumably: no YAML options of its own; verify against cv.empty_config_schema).
CONFIG_SCHEMA = cv.empty_config_schema(DOMAIN)
71 
72 
@callback
def async_default_engine(hass: HomeAssistant) -> str | None:
    """Choose the engine to use when a caller does not name one.

    Preference order:
    1. the first registered STT entity whose platform is ``cloud``;
    2. otherwise the first registered STT entity that has an entity id;
    3. otherwise the legacy layer's default provider.

    Returns the entity id or legacy domain of that engine, or ``None``.
    """
    registered = list(hass.data[DATA_COMPONENT].entities)

    cloud_entity = next(
        (
            candidate
            for candidate in registered
            if candidate.platform and candidate.platform.platform_name == "cloud"
        ),
        None,
    )
    if cloud_entity is not None:
        return cloud_entity.entity_id

    fallback_id = next(
        (
            candidate.entity_id
            for candidate in registered
            if candidate.entity_id is not None
        ),
        None,
    )
    return fallback_id or async_default_provider(hass)
86 
87 
@callback
def async_get_speech_to_text_entity(
    hass: HomeAssistant, entity_id: str
) -> SpeechToTextEntity | None:
    """Return the STT entity registered under ``entity_id``, or ``None``."""
    return hass.data[DATA_COMPONENT].get_entity(entity_id)
94 
95 
@callback
def async_get_speech_to_text_engine(
    hass: HomeAssistant, engine_id: str
) -> SpeechToTextEntity | Provider | None:
    """Return stt entity or legacy provider.

    ``engine_id`` is first resolved as an STT entity id. If no entity matches,
    it is looked up among the legacy providers instead.
    """
    if entity := async_get_speech_to_text_entity(hass, engine_id):
        return entity
    return async_get_provider(hass, engine_id)
104 
105 
@callback
def async_get_speech_to_text_languages(hass: HomeAssistant) -> set[str]:
    """Collect every language tag supported by any STT engine.

    Both entity-based engines and legacy providers contribute. The result is
    the union of their ``supported_languages``.
    """
    entity_tags = {
        tag
        for stt_entity in hass.data[DATA_COMPONENT].entities
        for tag in stt_entity.supported_languages
    }
    legacy_tags = {
        tag
        for legacy in hass.data[DATA_PROVIDERS].values()
        for tag in legacy.supported_languages
    }
    return entity_tags | legacy_tags
120 
121 
async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
    """Set up STT.

    Registers the engine-listing websocket command, creates the entity
    component, schedules legacy platform setup and exposes the HTTP view.
    """
    websocket_api.async_register_command(hass, websocket_list_engines)

    component = EntityComponent[SpeechToTextEntity](_LOGGER, DOMAIN, hass)
    hass.data[DATA_COMPONENT] = component
    component.register_shutdown()

    # Each legacy platform setup becomes a tracked task, so startup still waits
    # for it. It is deliberately NOT awaited here: config entries that use stt
    # as a base platform must be able to start without waiting for the legacy
    # platforms to finish setting up.
    for legacy_setup in async_setup_legacy(hass, config):
        hass.async_create_task(legacy_setup, eager_start=True)

    hass.http.register_view(SpeechToTextView(hass.data[DATA_PROVIDERS]))
    return True
144 
145 
async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
    """Set up a config entry by handing it to the STT entity component."""
    component = hass.data[DATA_COMPONENT]
    return await component.async_setup_entry(entry)
149 
150 
async def async_unload_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
    """Unload a config entry through the STT entity component."""
    component = hass.data[DATA_COMPONENT]
    return await component.async_unload_entry(entry)
154 
155 
class SpeechToTextEntity(RestoreEntity):
    """Represent a single STT provider.

    Subclasses declare the audio they accept through the ``supported_*``
    properties and implement :meth:`async_process_audio_stream`. Callers use
    :meth:`internal_async_process_audio_stream`, which also records when the
    entity last processed a stream. That timestamp is the entity state and is
    restored from the last recorded state when the entity is added.
    """

    _attr_should_poll = False
    # ISO-8601 UTC timestamp of the last processed stream. The double
    # underscore triggers Python name mangling, so subclasses cannot
    # accidentally shadow it.
    __last_processed: str | None = None

    @property
    @final
    def state(self) -> str | None:
        """Return the state of the provider entity.

        This is the timestamp of the last processed (or restored) audio
        stream, or ``None`` if there is none yet.
        """
        return self.__last_processed

    @property
    @abstractmethod
    def supported_languages(self) -> list[str]:
        """Return a list of supported languages."""

    @property
    @abstractmethod
    def supported_formats(self) -> list[AudioFormats]:
        """Return a list of supported formats."""

    @property
    @abstractmethod
    def supported_codecs(self) -> list[AudioCodecs]:
        """Return a list of supported codecs."""

    @property
    @abstractmethod
    def supported_bit_rates(self) -> list[AudioBitRates]:
        """Return a list of supported bit rates."""

    @property
    @abstractmethod
    def supported_sample_rates(self) -> list[AudioSampleRates]:
        """Return a list of supported sample rates."""

    @property
    @abstractmethod
    def supported_channels(self) -> list[AudioChannels]:
        """Return a list of supported channels."""

    async def async_internal_added_to_hass(self) -> None:
        """Call when the provider entity is added to hass.

        Restores the last-processed timestamp from the last recorded state.
        ``unavailable`` and ``unknown`` placeholder states are ignored.
        """
        await super().async_internal_added_to_hass()
        state = await self.async_get_last_state()
        if (
            state is not None
            and state.state is not None
            and state.state not in (STATE_UNAVAILABLE, STATE_UNKNOWN)
        ):
            self.__last_processed = state.state

    @final
    async def internal_async_process_audio_stream(
        self, metadata: SpeechMetadata, stream: AsyncIterable[bytes]
    ) -> SpeechResult:
        """Process an audio stream to STT service.

        Records the current UTC time as the entity state and writes it. Then
        delegates to :meth:`async_process_audio_stream`.

        Only streaming content is allowed!
        """
        self.__last_processed = dt_util.utcnow().isoformat()
        self.async_write_ha_state()
        return await self.async_process_audio_stream(metadata=metadata, stream=stream)

    @abstractmethod
    async def async_process_audio_stream(
        self, metadata: SpeechMetadata, stream: AsyncIterable[bytes]
    ) -> SpeechResult:
        """Process an audio stream to STT service.

        Only streaming content is allowed!
        """

    @callback
    def check_metadata(self, metadata: SpeechMetadata) -> bool:
        """Check if given metadata supported by this provider.

        Every field of ``metadata`` must appear in the matching ``supported_*``
        list. Evaluation stops at the first mismatch.
        """
        return (
            metadata.language in self.supported_languages
            and metadata.format in self.supported_formats
            and metadata.codec in self.supported_codecs
            and metadata.bit_rate in self.supported_bit_rates
            and metadata.sample_rate in self.supported_sample_rates
            and metadata.channel in self.supported_channels
        )
245 
246 
class SpeechToTextView(HomeAssistantView):
    """STT view to generate a text from audio stream.

    ``GET /api/stt/{provider}`` describes the audio the engine accepts.
    ``POST /api/stt/{provider}`` transcribes the request body. ``provider`` may
    be an STT entity id or the name of a legacy provider.
    """

    # Flipped to True on the instance the first time a legacy provider is
    # used, so the deprecation warning is logged only once per view.
    _legacy_provider_reported = False
    requires_auth = True
    url = "/api/stt/{provider}"
    name = "api:stt:provider"

    def __init__(self, providers: dict[str, Provider]) -> None:
        """Initialize an stt view.

        ``providers`` maps legacy provider names to their instances. It is
        consulted when ``provider`` is not an STT entity id.
        """
        self.providers = providers

    async def post(self, request: web.Request, provider: str) -> web.Response:
        """Convert Speech (audio) to text.

        Raises ``HTTPNotFound`` for an unknown engine and ``HTTPBadRequest``
        for a missing or malformed ``X-Speech-Content`` header. Raises
        ``HTTPUnsupportedMediaType`` when the engine does not support the
        declared audio metadata.
        """
        hass = request.app[KEY_HASS]
        provider_entity = async_get_speech_to_text_entity(hass, provider)
        if not provider_entity and provider not in self.providers:
            raise HTTPNotFound

        # Get metadata
        try:
            metadata = _metadata_from_header(request)
        except ValueError as err:
            raise HTTPBadRequest(text=str(err)) from err

        if not provider_entity:
            stt_provider = self._get_provider(hass, provider)

            # Check format
            if not stt_provider.check_metadata(metadata):
                raise HTTPUnsupportedMediaType

            # Process audio stream
            result = await stt_provider.async_process_audio_stream(
                metadata, request.content
            )
        else:
            # Check format
            if not provider_entity.check_metadata(metadata):
                raise HTTPUnsupportedMediaType

            # Process audio stream; the entity's internal wrapper also records
            # the processing timestamp as the entity state.
            result = await provider_entity.internal_async_process_audio_stream(
                metadata, request.content
            )

        # Return result
        return self.json(asdict(result))

    async def get(self, request: web.Request, provider: str) -> web.Response:
        """Return provider specific audio information.

        Raises ``HTTPNotFound`` when ``provider`` is neither an STT entity id
        nor a legacy provider name.
        """
        hass = request.app[KEY_HASS]
        provider_entity = async_get_speech_to_text_entity(hass, provider)
        if not provider_entity and provider not in self.providers:
            raise HTTPNotFound

        # Entities and legacy providers expose the same ``supported_*`` API,
        # so one payload builder serves both.
        engine = (
            provider_entity if provider_entity else self._get_provider(hass, provider)
        )
        return self.json(
            {
                "languages": engine.supported_languages,
                "formats": engine.supported_formats,
                "codecs": engine.supported_codecs,
                "sample_rates": engine.supported_sample_rates,
                "bit_rates": engine.supported_bit_rates,
                "channels": engine.supported_channels,
            }
        )

    def _get_provider(self, hass: HomeAssistant, provider: str) -> Provider:
        """Get provider.

        Method for legacy providers.
        This can be removed when we remove the legacy provider support.

        Logs a one-time warning that the provider uses the legacy
        implementation.
        """
        stt_provider = self.providers[provider]

        if not self._legacy_provider_reported:
            self._legacy_provider_reported = True
            report_issue = self._suggest_report_issue(hass, provider, stt_provider)
            # This should raise in Home Assistant Core 2023.9
            _LOGGER.warning(
                "Provider %s (%s) is using a legacy implementation, "
                "and should be updated to use the SpeechToTextEntity. Please "
                "%s",
                provider,
                type(stt_provider),
                report_issue,
            )

        return stt_provider

    def _suggest_report_issue(
        self, hass: HomeAssistant, provider: str, provider_instance: object
    ) -> str:
        """Suggest to report an issue.

        Builds the "report an issue" hint for the integration ``provider``,
        using the module that defines ``provider_instance``'s class.
        """
        return async_suggest_report_issue(
            hass, integration_domain=provider, module=type(provider_instance).__module__
        )
363 
364 
def _metadata_from_header(request: web.Request) -> SpeechMetadata:
    """Parse the ``X-Speech-Content`` request header into ``SpeechMetadata``.

    The header is a ``;``-separated list of ``key=value`` pairs, for example::

        X-Speech-Content:
            format=wav; codec=pcm; sample_rate=16000; bit_rate=16; channel=1; language=de_de

    Raises ``ValueError`` in these cases:
    - the header is absent;
    - it contains an unknown key;
    - a required key is missing;
    - ``SpeechMetadata`` rejects a value.
    """
    try:
        raw_header = request.headers[istr("X-Speech-Content")]
    except KeyError as err:
        raise ValueError("Missing X-Speech-Content header") from err

    required = (
        "language",
        "format",
        "codec",
        "bit_rate",
        "sample_rate",
        "channel",
    )

    parsed: dict[str, Any] = {}
    for segment in raw_header.split(";"):
        name, _sep, raw_value = segment.strip().partition("=")
        if name not in required:
            raise ValueError(f"Invalid field: {name}")
        parsed[name] = raw_value

    # Report the first absent key, in declaration order.
    absent = next((name for name in required if name not in parsed), None)
    if absent is not None:
        raise ValueError(f"Missing {absent} in X-Speech-Content header")

    try:
        return SpeechMetadata(**{name: parsed[name] for name in required})
    except ValueError as err:
        raise ValueError(f"Wrong format of X-Speech-Content: {err}") from err
408 
409 
@websocket_api.websocket_command(
    {
        "type": "stt/engine/list",
        vol.Optional("language"): str,
        vol.Optional("country"): str,
    }
)
@callback
def websocket_list_engines(
    hass: HomeAssistant, connection: websocket_api.ActiveConnection, msg: dict
) -> None:
    """List speech-to-text engines and, optionally, if they support a given language.

    Entity-based engines are listed first, then legacy providers; only legacy
    providers include a ``name``. When ``language`` is given, each engine's
    ``supported_languages`` is narrowed with ``language_util.matches``, which
    also receives the optional ``country``.
    """
    country = msg.get("country")
    language = msg.get("language")
    providers: list[dict[str, Any]] = []
    provider_info: dict[str, Any]

    for entity in hass.data[DATA_COMPONENT].entities:
        provider_info = {
            "engine_id": entity.entity_id,
            "supported_languages": entity.supported_languages,
        }
        if language:
            provider_info["supported_languages"] = language_util.matches(
                language, entity.supported_languages, country
            )
        providers.append(provider_info)

    for engine_id, provider in hass.data[DATA_PROVIDERS].items():
        provider_info = {
            "engine_id": engine_id,
            "name": provider.name,
            "supported_languages": provider.supported_languages,
        }
        if language:
            provider_info["supported_languages"] = language_util.matches(
                language, provider.supported_languages, country
            )
        providers.append(provider_info)

    connection.send_message(
        websocket_api.result_message(msg["id"], {"providers": providers})
    )
450 
list[AudioFormats] supported_formats(self)
Definition: __init__.py:177
list[AudioChannels] supported_channels(self)
Definition: __init__.py:197
bool check_metadata(self, SpeechMetadata metadata)
Definition: __init__.py:233
list[AudioSampleRates] supported_sample_rates(self)
Definition: __init__.py:192
list[AudioBitRates] supported_bit_rates(self)
Definition: __init__.py:187
SpeechResult async_process_audio_stream(self, SpeechMetadata metadata, AsyncIterable[bytes] stream)
Definition: __init__.py:226
SpeechResult internal_async_process_audio_stream(self, SpeechMetadata metadata, AsyncIterable[bytes] stream)
Definition: __init__.py:214
Provider _get_provider(self, HomeAssistant hass, str provider)
Definition: __init__.py:333
None __init__(self, dict[str, Provider] providers)
Definition: __init__.py:255
web.Response post(self, web.Request request, str provider)
Definition: __init__.py:259
str _suggest_report_issue(self, HomeAssistant hass, str provider, object provider_instance)
Definition: __init__.py:358
web.Response get(self, web.Request request, str provider)
Definition: __init__.py:299
HassAuthProvider async_get_provider(HomeAssistant hass)
CalendarEntity get_entity(HomeAssistant hass, str entity_id)
Definition: trigger.py:96
list[Coroutine[Any, Any, None]] async_setup_legacy(HomeAssistant hass, ConfigType config)
Definition: legacy.py:70
str|None async_default_provider(HomeAssistant hass)
Definition: legacy.py:35
set[str] async_get_speech_to_text_languages(HomeAssistant hass)
Definition: __init__.py:107
SpeechToTextEntity|Provider|None async_get_speech_to_text_engine(HomeAssistant hass, str engine_id)
Definition: __init__.py:99
bool async_setup_entry(HomeAssistant hass, ConfigEntry entry)
Definition: __init__.py:146
bool async_setup(HomeAssistant hass, ConfigType config)
Definition: __init__.py:122
None websocket_list_engines(HomeAssistant hass, websocket_api.ActiveConnection connection, dict msg)
Definition: __init__.py:420
SpeechMetadata _metadata_from_header(web.Request request)
Definition: __init__.py:365
bool async_unload_entry(HomeAssistant hass, ConfigEntry entry)
Definition: __init__.py:151
str|None async_default_engine(HomeAssistant hass)
Definition: __init__.py:74
SpeechToTextEntity|None async_get_speech_to_text_entity(HomeAssistant hass, str entity_id)
Definition: __init__.py:91
str async_suggest_report_issue(HomeAssistant|None hass, *Integration|None integration=None, str|None integration_domain=None, str|None module=None)
Definition: loader.py:1752