Home Assistant Unofficial Reference 2024.12.1
config_flow.py
Go to the documentation of this file.
1 """Config flow for Ollama integration."""
2 
3 from __future__ import annotations
4 
5 import asyncio
6 import logging
7 import sys
8 from types import MappingProxyType
9 from typing import Any
10 
11 import httpx
12 import ollama
13 import voluptuous as vol
14 
15 from homeassistant.config_entries import (
16  ConfigEntry,
17  ConfigFlow,
18  ConfigFlowResult,
19  OptionsFlow,
20 )
21 from homeassistant.const import CONF_LLM_HASS_API, CONF_URL
22 from homeassistant.core import HomeAssistant
23 from homeassistant.helpers import llm
25  NumberSelector,
26  NumberSelectorConfig,
27  NumberSelectorMode,
28  SelectOptionDict,
29  SelectSelector,
30  SelectSelectorConfig,
31  TemplateSelector,
32  TextSelector,
33  TextSelectorConfig,
34  TextSelectorType,
35 )
36 from homeassistant.util.ssl import get_default_context
37 
38 from .const import (
39  CONF_KEEP_ALIVE,
40  CONF_MAX_HISTORY,
41  CONF_MODEL,
42  CONF_NUM_CTX,
43  CONF_PROMPT,
44  DEFAULT_KEEP_ALIVE,
45  DEFAULT_MAX_HISTORY,
46  DEFAULT_MODEL,
47  DEFAULT_NUM_CTX,
48  DEFAULT_TIMEOUT,
49  DOMAIN,
50  MAX_NUM_CTX,
51  MIN_NUM_CTX,
52  MODEL_NAMES,
53 )
54 
_LOGGER = logging.getLogger(__name__)


# Form schema for the first screen of the config flow: a single required
# field, keyed by CONF_URL and rendered as a URL-typed text input.
STEP_USER_DATA_SCHEMA = vol.Schema(
    {
        vol.Required(CONF_URL): TextSelector(
            TextSelectorConfig(type=TextSelectorType.URL),
        ),
    },
)
65 
66 
class OllamaConfigFlow(ConfigFlow, domain=DOMAIN):
    """Handle a config flow for Ollama.

    The "user" step is shown up to twice: first to ask for the server URL and
    then, once the server has answered, to ask for a model. If the chosen
    model is not among the models the server reports, the flow continues
    with the "download" step before the entry is created.
    """

    VERSION = 1

    def __init__(self) -> None:
        """Initialize config flow."""
        # Values collected across repeated invocations of the "user" step.
        self.url: str | None = None
        self.model: str | None = None
        # Client created in the "user" step and reused by the "download" step.
        self.client: ollama.AsyncClient | None = None
        # Background task that pulls the model; created lazily on download.
        self.download_task: asyncio.Task | None = None

    async def async_step_user(
        self, user_input: dict[str, Any] | None = None
    ) -> ConfigFlowResult:
        """Handle the initial step.

        Collects the server URL and the model name, keeping values from
        earlier invocations when ``user_input`` does not contain them.
        Returns a form while information is missing or the server cannot be
        queried, hands over to the download step when the model is not
        listed by the server, and otherwise creates the config entry.
        """
        user_input = user_input or {}
        self.url = user_input.get(CONF_URL, self.url)
        self.model = user_input.get(CONF_MODEL, self.model)

        if self.url is None:
            return self.async_show_form(
                step_id="user", data_schema=STEP_USER_DATA_SCHEMA, last_step=False
            )

        errors: dict[str, str] = {}

        try:
            self.client = ollama.AsyncClient(
                host=self.url, verify=get_default_context()
            )
            async with asyncio.timeout(DEFAULT_TIMEOUT):
                response = await self.client.list()

            downloaded_models: set[str] = {
                model_info["model"] for model_info in response.get("models", [])
            }
        except (TimeoutError, httpx.ConnectError):
            errors["base"] = "cannot_connect"
        except Exception:
            _LOGGER.exception("Unexpected exception")
            errors["base"] = "unknown"

        if errors:
            # downloaded_models is only bound when the try block succeeded,
            # so nothing below this return may run on the error path.
            return self.async_show_form(
                step_id="user", data_schema=STEP_USER_DATA_SCHEMA, errors=errors
            )

        if self.model is None:
            # Show models that have been downloaded first, followed by all known
            # models (only latest tags).
            models_to_list = [
                SelectOptionDict(label=f"{m} (downloaded)", value=m)
                for m in sorted(downloaded_models)
            ] + [
                SelectOptionDict(label=m, value=f"{m}:latest")
                for m in sorted(MODEL_NAMES)
                if m not in downloaded_models
            ]
            model_step_schema = vol.Schema(
                {
                    vol.Required(
                        CONF_MODEL, description={"suggested_value": DEFAULT_MODEL}
                    ): SelectSelector(
                        SelectSelectorConfig(options=models_to_list, custom_value=True)
                    ),
                }
            )

            return self.async_show_form(
                step_id="user",
                data_schema=model_step_schema,
            )

        if self.model not in downloaded_models:
            # Ollama server needs to download model first
            return await self.async_step_download()

        return self.async_create_entry(
            title=_get_title(self.model),
            data={CONF_URL: self.url, CONF_MODEL: self.model},
        )

    async def async_step_download(
        self, user_input: dict[str, Any] | None = None
    ) -> ConfigFlowResult:
        """Step to wait for Ollama server to download a model.

        Starts the pull as a background task on first entry and shows a
        progress result while it runs. Once the task is done, routes to the
        "failed" step if it raised and to the "finish" step otherwise.
        """
        assert self.model is not None
        assert self.client is not None

        if self.download_task is None:
            # Tell Ollama server to pull the model.
            # The task will block until the model and metadata are fully
            # downloaded.
            self.download_task = self.hass.async_create_background_task(
                self.client.pull(self.model),
                f"Downloading {self.model}",
            )

        if self.download_task.done():
            if err := self.download_task.exception():
                # We are not inside an except block here, so the task's
                # exception has to be handed to the logger explicitly for
                # its traceback to be recorded.
                _LOGGER.error(
                    "Unexpected error while downloading model: %s",
                    err,
                    exc_info=err,
                )
                return self.async_show_progress_done(next_step_id="failed")

            return self.async_show_progress_done(next_step_id="finish")

        return self.async_show_progress(
            step_id="download",
            progress_action="download",
            progress_task=self.download_task,
        )

    async def async_step_finish(
        self, user_input: dict[str, Any] | None = None
    ) -> ConfigFlowResult:
        """Step after model downloading has succeeded."""
        assert self.url is not None
        assert self.model is not None

        return self.async_create_entry(
            title=_get_title(self.model),
            data={CONF_URL: self.url, CONF_MODEL: self.model},
        )

    async def async_step_failed(
        self, user_input: dict[str, Any] | None = None
    ) -> ConfigFlowResult:
        """Step after model downloading has failed."""
        return self.async_abort(reason="download_failed")

    @staticmethod
    def async_get_options_flow(
        config_entry: ConfigEntry,
    ) -> OptionsFlow:
        """Create the options flow."""
        return OllamaOptionsFlow(config_entry)
203 
204 
class OllamaOptionsFlow(OptionsFlow):
    """Ollama options flow."""

    def __init__(self, config_entry: ConfigEntry) -> None:
        """Initialize options flow."""
        # Copied from the entry data so the title can be built on save.
        self.url: str = config_entry.data[CONF_URL]
        self.model: str = config_entry.data[CONF_MODEL]

    async def async_step_init(
        self, user_input: dict[str, Any] | None = None
    ) -> ConfigFlowResult:
        """Manage the options.

        Without ``user_input`` the options form is shown, pre-filled from the
        entry's current options. With ``user_input`` the submitted values are
        stored as the new options.
        """
        if user_input is not None:
            # "none" is the form's placeholder choice for "No control"; it is
            # dropped so the stored options simply lack the key.
            if user_input[CONF_LLM_HASS_API] == "none":
                user_input.pop(CONF_LLM_HASS_API)
            return self.async_create_entry(
                title=_get_title(self.model), data=user_input
            )

        options = self.config_entry.options or MappingProxyType({})
        schema = ollama_config_option_schema(self.hass, options)
        return self.async_show_form(
            step_id="init",
            data_schema=vol.Schema(schema),
        )
230 
231 
def ollama_config_option_schema(
    hass: HomeAssistant, options: MappingProxyType[str, Any]
) -> dict:
    """Ollama options schema.

    Builds the mapping of voluptuous keys to selectors for the options form.

    Args:
        hass: Home Assistant instance, used to enumerate the LLM APIs.
        options: Currently stored options; used as suggested values, with
            the module defaults as fallback.

    Returns:
        A dict suitable for wrapping in ``vol.Schema``.
    """
    # "No control" always comes first; its value "none" is the placeholder
    # the options flow removes again before saving.
    hass_apis: list[SelectOptionDict] = [
        SelectOptionDict(
            label="No control",
            value="none",
        )
    ]
    hass_apis.extend(
        SelectOptionDict(
            label=api.name,
            value=api.id,
        )
        for api in llm.async_get_apis(hass)
    )

    return {
        vol.Optional(
            CONF_PROMPT,
            description={
                "suggested_value": options.get(
                    CONF_PROMPT, llm.DEFAULT_INSTRUCTIONS_PROMPT
                )
            },
        ): TemplateSelector(),
        vol.Optional(
            CONF_LLM_HASS_API,
            description={"suggested_value": options.get(CONF_LLM_HASS_API)},
            default="none",
        ): SelectSelector(SelectSelectorConfig(options=hass_apis)),
        vol.Optional(
            CONF_NUM_CTX,
            description={"suggested_value": options.get(CONF_NUM_CTX, DEFAULT_NUM_CTX)},
        ): NumberSelector(
            NumberSelectorConfig(
                min=MIN_NUM_CTX, max=MAX_NUM_CTX, step=1, mode=NumberSelectorMode.BOX
            )
        ),
        vol.Optional(
            CONF_MAX_HISTORY,
            description={
                "suggested_value": options.get(CONF_MAX_HISTORY, DEFAULT_MAX_HISTORY)
            },
        ): NumberSelector(
            NumberSelectorConfig(
                min=0, max=sys.maxsize, step=1, mode=NumberSelectorMode.BOX
            )
        ),
        vol.Optional(
            CONF_KEEP_ALIVE,
            description={
                "suggested_value": options.get(CONF_KEEP_ALIVE, DEFAULT_KEEP_ALIVE)
            },
        ): NumberSelector(
            NumberSelectorConfig(
                min=-1, max=sys.maxsize, step=1, mode=NumberSelectorMode.BOX
            )
        ),
    }
293 
294 
295 def _get_title(model: str) -> str:
296  """Get title for config entry."""
297  if model.endswith(":latest"):
298  model = model.split(":", maxsplit=1)[0]
299 
300  return model
ConfigFlowResult async_step_finish(self, dict[str, Any]|None user_input=None)
Definition: config_flow.py:181
OptionsFlow async_get_options_flow(ConfigEntry config_entry)
Definition: config_flow.py:200
ConfigFlowResult async_step_failed(self, dict[str, Any]|None user_input=None)
Definition: config_flow.py:193
ConfigFlowResult async_step_user(self, dict[str, Any]|None user_input=None)
Definition: config_flow.py:81
ConfigFlowResult async_step_download(self, dict[str, Any]|None user_input=None)
Definition: config_flow.py:152
ConfigFlowResult async_step_init(self, dict[str, Any]|None user_input=None)
Definition: config_flow.py:215
ConfigFlowResult async_create_entry(self, *str title, Mapping[str, Any] data, str|None description=None, Mapping[str, str]|None description_placeholders=None, Mapping[str, Any]|None options=None)
ConfigFlowResult async_abort(self, *str reason, Mapping[str, str]|None description_placeholders=None)
ConfigFlowResult async_show_form(self, *str|None step_id=None, vol.Schema|None data_schema=None, dict[str, str]|None errors=None, Mapping[str, str]|None description_placeholders=None, bool|None last_step=None, str|None preview=None)
None config_entry(self, ConfigEntry value)
_FlowResultT async_show_form(self, *str|None step_id=None, vol.Schema|None data_schema=None, dict[str, str]|None errors=None, Mapping[str, str]|None description_placeholders=None, bool|None last_step=None, str|None preview=None)
_FlowResultT async_show_progress(self, *str|None step_id=None, str progress_action, Mapping[str, str]|None description_placeholders=None, asyncio.Task[Any]|None progress_task=None)
_FlowResultT async_show_progress_done(self, *str next_step_id)
_FlowResultT async_create_entry(self, *str|None title=None, Mapping[str, Any] data, str|None description=None, Mapping[str, str]|None description_placeholders=None)
_FlowResultT async_abort(self, *str reason, Mapping[str, str]|None description_placeholders=None)
dict ollama_config_option_schema(HomeAssistant hass, MappingProxyType[str, Any] options)
Definition: config_flow.py:234
ssl.SSLContext get_default_context()
Definition: ssl.py:118