1 """Provide functionality to STT."""
3 from __future__
import annotations
5 from abc
import abstractmethod
6 from collections.abc
import AsyncIterable
7 from dataclasses
import asdict
9 from typing
import Any, final
11 from aiohttp
import web
12 from aiohttp.hdrs
import istr
13 from aiohttp.web_exceptions
import (
16 HTTPUnsupportedMediaType,
18 import voluptuous
as vol
45 async_default_provider,
49 from .models
import SpeechMetadata, SpeechResult
53 "async_get_speech_to_text_engine",
54 "async_get_speech_to_text_entity",
68 _LOGGER = logging.getLogger(__name__)
70 CONFIG_SCHEMA = cv.empty_config_schema(DOMAIN)
75 """Return the domain or entity id of the default engine."""
76 default_entity_id: str |
None =
None
78 for entity
in hass.data[DATA_COMPONENT].entities:
79 if entity.platform
and entity.platform.platform_name ==
"cloud":
80 return entity.entity_id
82 if default_entity_id
is None:
83 default_entity_id = entity.entity_id
90 hass: HomeAssistant, entity_id: str
91 ) -> SpeechToTextEntity |
None:
92 """Return stt entity."""
93 return hass.data[DATA_COMPONENT].
get_entity(entity_id)
98 hass: HomeAssistant, engine_id: str
99 ) -> SpeechToTextEntity | Provider |
None:
100 """Return stt entity or legacy provider."""
108 """Return a set with the union of languages supported by stt engines."""
111 for entity
in hass.data[DATA_COMPONENT].entities:
112 for language_tag
in entity.supported_languages:
113 languages.add(language_tag)
115 for engine
in hass.data[DATA_PROVIDERS].values():
116 for language_tag
in engine.supported_languages:
117 languages.add(language_tag)
122 async
def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
124 websocket_api.async_register_command(hass, websocket_list_engines)
126 component = hass.data[DATA_COMPONENT] = EntityComponent[SpeechToTextEntity](
127 _LOGGER, DOMAIN, hass
130 component.register_shutdown()
133 for setup
in platform_setups:
140 hass.async_create_task(setup, eager_start=
True)
147 """Set up a config entry."""
152 """Unload a config entry."""
157 """Represent a single STT provider."""
159 _attr_should_poll =
False
160 __last_processed: str |
None =
None
165 """Return the state of the provider entity."""
173 """Return a list of supported languages."""
178 """Return a list of supported formats."""
183 """Return a list of supported codecs."""
188 """Return a list of supported bit rates."""
193 """Return a list of supported sample rates."""
198 """Return a list of supported channels."""
201 """Call when the provider entity is added to hass."""
203 state = await self.async_get_last_state()
206 and state.state
is not None
207 and state.state
not in (STATE_UNAVAILABLE, STATE_UNKNOWN)
213 self, metadata: SpeechMetadata, stream: AsyncIterable[bytes]
215 """Process an audio stream to STT service.
217 Only streaming content is allowed!
225 self, metadata: SpeechMetadata, stream: AsyncIterable[bytes]
227 """Process an audio stream to STT service.
229 Only streaming content is allowed!
234 """Check if the given metadata is supported by this provider."""
248 """STT view to generate a text from audio stream."""
250 _legacy_provider_reported =
False
252 url =
"/api/stt/{provider}"
253 name =
"api:stt:provider"
255 def __init__(self, providers: dict[str, Provider]) ->
None:
256 """Initialize an STT view."""
259 async
def post(self, request: web.Request, provider: str) -> web.Response:
260 """Convert Speech (audio) to text."""
261 hass = request.app[KEY_HASS]
262 provider_entity: SpeechToTextEntity |
None =
None
265 and provider not in self.providers
272 except ValueError
as err:
273 raise HTTPBadRequest(text=str(err)) from err
275 if not provider_entity:
276 stt_provider = self._get_provider(hass, provider)
279 if not stt_provider.check_metadata(metadata):
280 raise HTTPUnsupportedMediaType
283 result = await stt_provider.async_process_audio_stream(
284 metadata, request.content
288 if not provider_entity.check_metadata(metadata):
289 raise HTTPUnsupportedMediaType
292 result = await provider_entity.internal_async_process_audio_stream(
293 metadata, request.content
297 return self.json(asdict(result))
299 async
def get(self, request: web.Request, provider: str) -> web.Response:
300 """Return provider specific audio information."""
301 hass = request.app[KEY_HASS]
304 and provider not in self.providers
308 if not provider_entity:
309 stt_provider = self._get_provider(hass, provider)
313 "languages": stt_provider.supported_languages,
314 "formats": stt_provider.supported_formats,
315 "codecs": stt_provider.supported_codecs,
316 "sample_rates": stt_provider.supported_sample_rates,
317 "bit_rates": stt_provider.supported_bit_rates,
318 "channels": stt_provider.supported_channels,
324 "languages": provider_entity.supported_languages,
325 "formats": provider_entity.supported_formats,
326 "codecs": provider_entity.supported_codecs,
327 "sample_rates": provider_entity.supported_sample_rates,
328 "bit_rates": provider_entity.supported_bit_rates,
329 "channels": provider_entity.supported_channels,
336 Method for legacy providers.
337 This can be removed when we remove the legacy provider support.
339 stt_provider = self.providers[provider]
346 "Provider %s (%s) is using a legacy implementation, "
347 "and should be updated to use the SpeechToTextEntity. Please "
357 self, hass: HomeAssistant, provider: str, provider_instance: object
359 """Suggest to report an issue."""
361 hass, integration_domain=provider, module=type(provider_instance).__module__
366 """Extract STT metadata from header.
369 format=wav; codec=pcm; sample_rate=16000; bit_rate=16; channel=1; language=de_de
372 data = request.headers[istr(
"X-Speech-Content")].split(
";")
373 except KeyError
as err:
374 raise ValueError("Missing X-Speech-Content header") from err
386 args: dict[str, Any] = {}
388 key, _, value = entry.strip().partition(
"=")
389 if key
not in fields:
390 raise ValueError(f"Invalid field: {key}")
394 if field
not in args:
395 raise ValueError(f"Missing {field} in X-Speech-Content header")
399 language=args[
"language"],
400 format=args[
"format"],
402 bit_rate=args[
"bit_rate"],
403 sample_rate=args[
"sample_rate"],
404 channel=args[
"channel"],
406 except ValueError
as err:
407 raise ValueError(f"Wrong format of X-Speech-Content: {err}") from err
410 @websocket_api.websocket_command(
{
"type": "stt/engine/list",
vol.Optional("language"): str,
411 vol.Optional(
"country"): str,
418 """List speech-to-text engines and, optionally, if they support a given language."""
419 country = msg.get(
"country")
420 language = msg.get(
"language")
422 provider_info: dict[str, Any]
424 for entity
in hass.data[DATA_COMPONENT].entities:
426 "engine_id": entity.entity_id,
427 "supported_languages": entity.supported_languages,
430 provider_info[
"supported_languages"] = language_util.matches(
431 language, entity.supported_languages, country
433 providers.append(provider_info)
435 for engine_id, provider
in hass.data[DATA_PROVIDERS].items():
437 "engine_id": engine_id,
438 "name": provider.name,
439 "supported_languages": provider.supported_languages,
442 provider_info[
"supported_languages"] = language_util.matches(
443 language, provider.supported_languages, country
445 providers.append(provider_info)
447 connection.send_message(
448 websocket_api.result_message(msg[
"id"], {
"providers": providers})
450
list[AudioFormats] supported_formats(self)
None async_internal_added_to_hass(self)
list[AudioChannels] supported_channels(self)
bool check_metadata(self, SpeechMetadata metadata)
list[AudioSampleRates] supported_sample_rates(self)
list[AudioBitRates] supported_bit_rates(self)
SpeechResult async_process_audio_stream(self, SpeechMetadata metadata, AsyncIterable[bytes] stream)
list[AudioCodecs] supported_codecs(self)
list[str] supported_languages(self)
SpeechResult internal_async_process_audio_stream(self, SpeechMetadata metadata, AsyncIterable[bytes] stream)
_legacy_provider_reported
Provider _get_provider(self, HomeAssistant hass, str provider)
bool _legacy_provider_reported
None __init__(self, dict[str, Provider] providers)
web.Response post(self, web.Request request, str provider)
str _suggest_report_issue(self, HomeAssistant hass, str provider, object provider_instance)
web.Response get(self, web.Request request, str provider)
None async_write_ha_state(self)
HassAuthProvider async_get_provider(HomeAssistant hass)
CalendarEntity get_entity(HomeAssistant hass, str entity_id)
list[Coroutine[Any, Any, None]] async_setup_legacy(HomeAssistant hass, ConfigType config)
str|None async_default_provider(HomeAssistant hass)
set[str] async_get_speech_to_text_languages(HomeAssistant hass)
SpeechToTextEntity|Provider|None async_get_speech_to_text_engine(HomeAssistant hass, str engine_id)
bool async_setup_entry(HomeAssistant hass, ConfigEntry entry)
bool async_setup(HomeAssistant hass, ConfigType config)
None websocket_list_engines(HomeAssistant hass, websocket_api.ActiveConnection connection, dict msg)
SpeechMetadata _metadata_from_header(web.Request request)
bool async_unload_entry(HomeAssistant hass, ConfigEntry entry)
str|None async_default_engine(HomeAssistant hass)
SpeechToTextEntity|None async_get_speech_to_text_entity(HomeAssistant hass, str entity_id)
str async_suggest_report_issue(HomeAssistant|None hass, *Integration|None integration=None, str|None integration_domain=None, str|None module=None)