diff --git a/src/client/content/config/tabs/models.py b/src/client/content/config/tabs/models.py
index 27eeb1e7..3c4a3a3e 100644
--- a/src/client/content/config/tabs/models.py
+++ b/src/client/content/config/tabs/models.py
@@ -9,6 +9,7 @@
"""
# spell-checker:ignore selectbox
+import json
import logging
from time import sleep
from typing import Literal, Any
@@ -251,6 +252,41 @@ def _handle_form_submission(model: dict, action: str) -> bool:
return False
+@st.dialog("Pull Ollama Model", width="large")
+def pull_model_dialog(model_provider: str, model_id: str) -> None:
+ """Stream Ollama model pull progress"""
+ st.write(f"Pull `{model_provider}/{model_id}` from Ollama?")
+ st.caption("Large models may take several minutes to download.")
+
+ col1, _, col2 = st.columns([1.4, 6, 1.4])
+ if col1.button("Pull", type="primary", key="pull_confirm_btn"):
+ quoted_id = urllib.parse.quote(model_id, safe="")
+ with st.status("Pulling model...", expanded=True) as pull_status:
+ try:
+ response = api_call.post_stream(endpoint=f"v1/models/pull/{model_provider}/{quoted_id}")
+ progress_line = st.empty()
+ for raw_line in response.iter_lines():
+ if raw_line:
+ data = json.loads(raw_line)
+ if "error" in data:
+ pull_status.update(label=f"Pull failed: {data['error']}", state="error")
+ progress_line.error(data["error"])
+ return
+ msg = data.get("status", "")
+ total = data.get("total", 0)
+ if total > 0:
+ pct = int(data.get("completed", 0) / total * 100)
+ progress_line.write(f"{msg}: {pct}%")
+ elif msg:
+ progress_line.write(msg)
+ pull_status.update(label="Pull complete!", state="complete")
+ progress_line.success(f"Model `{model_id}` pulled successfully. Enable it in the model configuration.")
+ except api_call.ApiError as ex:
+ pull_status.update(label=f"Pull failed: {ex}", state="error")
+ if col2.button("Cancel", key="pull_cancel_btn"):
+ st.rerun()
+
+
@st.dialog("Model Configuration", width="large")
def edit_model(
model_type: str, action: Literal["add", "edit"], model_id: str = None, model_provider: str = None
@@ -272,33 +308,35 @@ def edit_model(
def render_model_rows(model_type: str) -> None:
"""Render rows of the models"""
- data_col_widths = [0.08, 0.42, 0.28, 0.12]
- table_col_format = st.columns(data_col_widths, vertical_alignment="center")
- col1, col2, col3, col4 = table_col_format
- col1.markdown("", help="Active", unsafe_allow_html=True)
- col2.markdown("**Model**", unsafe_allow_html=True)
- col3.markdown("**Provider URL**", unsafe_allow_html=True)
- col4.markdown("")
+ data_col_widths = [0.08, 0.40, 0.27, 0.12, 0.10]
+ header_cols = st.columns(data_col_widths)
+ header_cols[0].markdown("", help="Active", unsafe_allow_html=True)
+ header_cols[1].markdown("**Model**", unsafe_allow_html=True)
+ header_cols[2].markdown("**Provider URL**", unsafe_allow_html=True)
+ header_cols[3].markdown("")
+ header_cols[4].markdown("")
for model in [m for m in state.model_configs if m.get("type") == model_type]:
model_id = model["id"]
model_provider = model["provider"]
+ # Pre-set widget state keys so Streamlit displays current values instead of cached ones
+ state[f"{model_type}_{model_provider}_{model_id}_enabled"] = st_common.bool_to_emoji(model["enabled"])
+ state[f"{model_type}_{model_provider}_{model_id}"] = f"{model_provider}/{model_id}"
+ state[f"{model_type}_{model_provider}_{model_id}_api_base"] = model.get("api_base", "")
+ col1, col2, col3, col4, col5 = st.columns(data_col_widths, vertical_alignment="center")
col1.text_input(
"Enabled",
- value=st_common.bool_to_emoji(model["enabled"]),
key=f"{model_type}_{model_provider}_{model_id}_enabled",
label_visibility="collapsed",
disabled=True,
)
col2.text_input(
"Model",
- value=f"{model_provider}/{model_id}",
key=f"{model_type}_{model_provider}_{model_id}",
label_visibility="collapsed",
disabled=True,
)
col3.text_input(
"Server",
- value=model["api_base"],
key=f"{model_type}_{model_provider}_{model_id}_api_base",
label_visibility="collapsed",
disabled=True,
@@ -314,6 +352,13 @@ def render_model_rows(model_type: str) -> None:
"model_provider": model_provider,
},
)
+            if "ollama" in model_provider:
+                # Call the dialog directly from the button result (same pattern as
+                # edit_model below); st.dialog-decorated functions are not reliably
+                # supported as on_click callbacks.
+                pull_key = f"{model_type}_{model_provider}_{model_id}_pull"
+                if col5.button("Pull", key=pull_key):
+                    pull_model_dialog(model_provider=model_provider, model_id=model_id)
if st.button(label="Add", type="primary", key=f"add_{model_type}_model"):
edit_model(model_type=model_type, action="add")
@@ -327,7 +372,7 @@ def display_models() -> None:
st.title("Models")
st.write("Update, Add, or Delete model configuration parameters.")
try:
- get_models()
+ get_models(force=True)
except api_call.ApiError:
st.stop()
diff --git a/src/client/utils/api_call.py b/src/client/utils/api_call.py
index 5a0371a8..6ca4f2e3 100644
--- a/src/client/utils/api_call.py
+++ b/src/client/utils/api_call.py
@@ -175,6 +175,23 @@ def patch(
return result
+def post_stream(endpoint: str, timeout: int = 600) -> requests.Response:
+    """POST with streaming enabled; returns the live Response for line-by-line iteration."""
+    url = urljoin(f"{state.server['url']}:{state.server['port']}/", endpoint)
+    headers = {"Authorization": f"Bearer {state.server['key']}"}
+    if getattr(state, "client_settings", {}).get("client"):
+        headers["Client"] = state.client_settings["client"]
+    LOGGER.info("POST Stream Request: %s", url)
+    try:
+        response = requests.post(url, headers=headers, timeout=timeout, stream=True)
+        response.raise_for_status()
+        return response  # return on the success path only; 'response' may be unbound below
+    except requests.exceptions.HTTPError as ex:
+        _error_response(_handle_http_error(ex))
+    except requests.exceptions.ConnectionError as ex:
+        _error_response(f"Connection failed: {str(ex)}")
+
+
def delete(endpoint: str, timeout: int = 60, retries: int = 5, backoff_factor: float = 1.5, toast=True) -> dict:
"""DELETE Requests"""
result = send_request("DELETE", endpoint, timeout=timeout, retries=retries, backoff_factor=backoff_factor)
diff --git a/src/server/api/utils/models.py b/src/server/api/utils/models.py
index e1c82d3a..e26b2620 100644
--- a/src/server/api/utils/models.py
+++ b/src/server/api/utils/models.py
@@ -4,10 +4,12 @@
"""
# spell-checker:ignore ollama pplx huggingface genai giskard litellm ocigenai rerank vllm
+import json
import logging
-from typing import Optional, Union, Any
+from typing import Optional, Union, Any, AsyncGenerator
from urllib.parse import urlparse
+import httpx
import litellm
from langchain.embeddings import init_embeddings
@@ -43,6 +45,14 @@ class UnknownModelError(ValueError):
"""Raised when the model data doesn't exist."""
+class OllamaPullError(ValueError):
+ """Raised when an Ollama model pull fails."""
+
+
+class OllamaModelNotPulledError(ValueError):
+ """Raised when trying to enable an Ollama model that has not been pulled."""
+
+
#####################################################
# CRUD Functions
#####################################################
@@ -98,6 +108,13 @@ def update(payload: schema.Model) -> schema.Model:
model_existing.enabled = False
raise URLUnreachableError("Model: Unable to update. API URL is inaccessible.")
+ # For Ollama models, require the model to be pulled before enabling
+ if payload.enabled and "ollama" in payload.provider:
+ if not is_ollama_model_pulled(payload.id, payload.api_base):
+ raise OllamaModelNotPulledError(
+ f"Model '{payload.id}' has not been pulled. Use the Pull button to download it first."
+ )
+
# Update all fields from payload in place
for key, value in payload.model_dump().items():
setattr(model_existing, key, value)
@@ -109,6 +126,37 @@ def delete(model_provider: schema.ModelProviderType, model_id: schema.ModelIdTyp
MODEL_OBJECTS[:] = [m for m in MODEL_OBJECTS if (m.id, m.provider) != (model_id, model_provider)]
+def is_ollama_model_pulled(model_id: str, api_base: str) -> bool:
+    """Return True when the Ollama instance at api_base lists model_id locally."""
+    tags_url = f"{api_base.rstrip('/')}/api/tags"
+    try:
+        response = httpx.get(tags_url, timeout=5)
+        response.raise_for_status()
+        models = response.json().get("models", [])
+        # Ollama stores untagged models as "name:latest"
+        normalized_id = model_id if ":" in model_id else f"{model_id}:latest"
+        return any(m.get("name") == normalized_id for m in models)
+    except (httpx.HTTPError, ValueError):  # any transport/HTTP error, or malformed JSON
+        return False
+
+
+async def pull_ollama_model(model_id: str, api_base: str) -> AsyncGenerator[str, None]:
+    """Async generator that streams Ollama model pull progress as NDJSON lines."""
+    pull_url = f"{api_base.rstrip('/')}/api/pull"
+    LOGGER.debug("Pulling Ollama model %s from %s", model_id, pull_url)
+    try:
+        async with httpx.AsyncClient(timeout=None) as client:
+            async with client.stream("POST", pull_url, json={"model": model_id}) as response:
+                response.raise_for_status()
+                async for line in response.aiter_lines():
+                    if line:
+                        yield line + "\n"
+    except httpx.HTTPStatusError as ex:
+        yield json.dumps({"error": f"Ollama returned HTTP {ex.response.status_code}"}) + "\n"
+    except httpx.TransportError as ex:  # connect AND mid-stream read/timeout failures
+        yield json.dumps({"error": f"Ollama connection/stream failure at {api_base}: {ex}"}) + "\n"
+
+
#####################################################
# Utility Functions
#####################################################
diff --git a/src/server/api/v1/models.py b/src/server/api/v1/models.py
index fcf47424..e1d603ac 100644
--- a/src/server/api/v1/models.py
+++ b/src/server/api/v1/models.py
@@ -6,7 +6,7 @@
import logging
from typing import Optional, Any
from fastapi import APIRouter, HTTPException, Query
-from fastapi.responses import JSONResponse
+from fastapi.responses import JSONResponse, StreamingResponse
import server.api.utils.models as utils_models
@@ -84,7 +84,7 @@ async def models_update(payload: schema.Model) -> schema.Model:
return utils_models.update(payload=payload)
except utils_models.UnknownModelError as ex:
raise HTTPException(status_code=404, detail=str(ex)) from ex
- except utils_models.URLUnreachableError as ex:
+ except (utils_models.URLUnreachableError, utils_models.OllamaModelNotPulledError) as ex:
raise HTTPException(status_code=422, detail=str(ex)) from ex
@@ -101,6 +101,34 @@ async def models_create(
raise HTTPException(status_code=409, detail=str(ex)) from ex
+@auth.post(
+ "/pull/{model_provider}/{model_id:path}",
+ description="Pull an Ollama model from the registry",
+)
+async def models_pull(
+ model_provider: schema.ModelProviderType,
+ model_id: schema.ModelIdType,
+) -> StreamingResponse:
+ """Pull an Ollama model and stream progress as NDJSON"""
+ LOGGER.debug("Received models_pull - model: %s/%s", model_provider, model_id)
+
+ if "ollama" not in model_provider:
+ raise HTTPException(status_code=422, detail="Pull is only supported for Ollama models")
+
+ try:
+ (model,) = utils_models.get(model_provider=model_provider, model_id=model_id)
+ except utils_models.UnknownModelError as ex:
+ raise HTTPException(status_code=404, detail=str(ex)) from ex
+
+ if not model.api_base:
+ raise HTTPException(status_code=422, detail="Model has no API base URL configured")
+
+ return StreamingResponse(
+ utils_models.pull_ollama_model(model_id=model_id, api_base=model.api_base),
+ media_type="application/x-ndjson",
+ )
+
+
@auth.delete(
"/{model_provider}/{model_id:path}",
description="Delete a model",