diff --git a/src/client/content/config/tabs/models.py b/src/client/content/config/tabs/models.py
index 27eeb1e7..3c4a3a3e 100644
--- a/src/client/content/config/tabs/models.py
+++ b/src/client/content/config/tabs/models.py
@@ -9,6 +9,7 @@
 """
 # spell-checker:ignore selectbox
 
+import json
 import logging
 from time import sleep
 from typing import Literal, Any
@@ -251,6 +252,50 @@ def _handle_form_submission(model: dict, action: str) -> bool:
     return False
 
 
+@st.dialog("Pull Ollama Model", width="large")
+def pull_model_dialog(model_provider: str, model_id: str) -> None:
+    """Confirm, then stream Ollama model pull progress into the dialog.
+
+    The server replies with NDJSON: each line is a JSON object holding either an
+    "error" message or a "status" with optional "total"/"completed" counters.
+    """
+    st.write(f"Pull `{model_provider}/{model_id}` from Ollama?")
+    st.caption("Large models may take several minutes to download.")
+
+    col1, _, col2 = st.columns([1.4, 6, 1.4])
+    if col1.button("Pull", type="primary", key="pull_confirm_btn"):
+        quoted_id = urllib.parse.quote(model_id, safe="")
+        with st.status("Pulling model...", expanded=True) as pull_status:
+            progress_line = st.empty()
+            try:
+                # Context manager releases the streamed connection on every exit path
+                with api_call.post_stream(endpoint=f"v1/models/pull/{model_provider}/{quoted_id}") as response:
+                    for raw_line in response.iter_lines():
+                        if not raw_line:
+                            continue
+                        data = json.loads(raw_line)
+                        if "error" in data:
+                            pull_status.update(label=f"Pull failed: {data['error']}", state="error")
+                            progress_line.error(data["error"])
+                            return
+                        msg = data.get("status", "")
+                        # Counters may be missing or null; treat both as "no percentage available"
+                        total = data.get("total") or 0
+                        if total > 0:
+                            pct = int((data.get("completed") or 0) / total * 100)
+                            progress_line.write(f"{msg}: {pct}%")
+                        elif msg:
+                            progress_line.write(msg)
+                pull_status.update(label="Pull complete!", state="complete")
+                progress_line.success(f"Model `{model_id}` pulled successfully. Enable it in the model configuration.")
+            except api_call.ApiError as ex:
+                pull_status.update(label=f"Pull failed: {ex}", state="error")
+            except json.JSONDecodeError as ex:
+                pull_status.update(label=f"Pull failed: malformed progress data ({ex})", state="error")
+    if col2.button("Cancel", key="pull_cancel_btn"):
+        st.rerun()
+
+
 @st.dialog("Model Configuration", width="large")
 def edit_model(
     model_type: str, action: Literal["add", "edit"], model_id: str = None, model_provider: str = None
@@ -272,33 +317,35 @@ def edit_model(
 
 def render_model_rows(model_type: str) -> None:
     """Render rows of the models"""
-    data_col_widths = [0.08, 0.42, 0.28, 0.12]
-    table_col_format = st.columns(data_col_widths, vertical_alignment="center")
-    col1, col2, col3, col4 = table_col_format
-    col1.markdown("​", help="Active", unsafe_allow_html=True)
-    col2.markdown("**Model**", unsafe_allow_html=True)
-    col3.markdown("**Provider URL**", unsafe_allow_html=True)
-    col4.markdown("​")
+    data_col_widths = [0.08, 0.40, 0.27, 0.12, 0.10]
+    header_cols = st.columns(data_col_widths)
+    header_cols[0].markdown("​", help="Active", unsafe_allow_html=True)
+    header_cols[1].markdown("**Model**", unsafe_allow_html=True)
+    header_cols[2].markdown("**Provider URL**", unsafe_allow_html=True)
+    header_cols[3].markdown("​")
+    header_cols[4].markdown("​")
     for model in [m for m in state.model_configs if m.get("type") == model_type]:
         model_id = model["id"]
         model_provider = model["provider"]
+        # Pre-set widget state keys so Streamlit displays current values instead of cached ones
+        state[f"{model_type}_{model_provider}_{model_id}_enabled"] = st_common.bool_to_emoji(model["enabled"])
+        state[f"{model_type}_{model_provider}_{model_id}"] = f"{model_provider}/{model_id}"
+        state[f"{model_type}_{model_provider}_{model_id}_api_base"] = model.get("api_base", "")
+        col1, col2, col3, col4, col5 = st.columns(data_col_widths, vertical_alignment="center")
         col1.text_input(
             "Enabled",
-            value=st_common.bool_to_emoji(model["enabled"]),
             key=f"{model_type}_{model_provider}_{model_id}_enabled",
             label_visibility="collapsed",
             disabled=True,
         )
         col2.text_input(
             "Model",
-            value=f"{model_provider}/{model_id}",
             key=f"{model_type}_{model_provider}_{model_id}",
             label_visibility="collapsed",
             disabled=True,
         )
         col3.text_input(
             "Server",
-            value=model["api_base"],
             key=f"{model_type}_{model_provider}_{model_id}_api_base",
             label_visibility="collapsed",
             disabled=True,
@@ -314,6 +361,13 @@ def render_model_rows(model_type: str) -> None:
                 "model_provider": model_provider,
             },
         )
+        if "ollama" in model_provider:
+            col5.button(
+                "Pull",
+                on_click=pull_model_dialog,
+                key=f"{model_type}_{model_provider}_{model_id}_pull",
+                kwargs={"model_provider": model_provider, "model_id": model_id},
+            )
 
     if st.button(label="Add", type="primary", key=f"add_{model_type}_model"):
         edit_model(model_type=model_type, action="add")
@@ -327,7 +381,7 @@ def display_models() -> None:
     st.title("Models")
     st.write("Update, Add, or Delete model configuration parameters.")
     try:
-        get_models()
+        get_models(force=True)
     except api_call.ApiError:
         st.stop()
 
diff --git a/src/client/utils/api_call.py b/src/client/utils/api_call.py
index 5a0371a8..6ca4f2e3 100644
--- a/src/client/utils/api_call.py
+++ b/src/client/utils/api_call.py
@@ -175,6 +175,29 @@ def patch(
     return result
 
 
+def post_stream(endpoint: str, timeout: int = 600) -> requests.Response:
+    """POST request with streaming enabled. Returns response for line-by-line iteration.
+
+    The caller owns the returned response and should close it (or use it as a
+    context manager) once iteration is finished.
+    """
+    url = urljoin(f"{state.server['url']}:{state.server['port']}/", endpoint)
+    headers = {"Authorization": f"Bearer {state.server['key']}"}
+    if getattr(state, "client_settings", {}).get("client"):
+        headers["Client"] = state.client_settings["client"]
+    LOGGER.info("POST Stream Request: %s", url)
+    try:
+        response = requests.post(url, headers=headers, timeout=timeout, stream=True)
+        response.raise_for_status()
+    except requests.exceptions.HTTPError as ex:
+        _error_response(_handle_http_error(ex))
+    except requests.exceptions.ConnectionError as ex:
+        _error_response(f"Connection failed: {str(ex)}")
+    except requests.exceptions.Timeout as ex:
+        _error_response(f"Request timed out: {str(ex)}")
+    return response
+
+
 def delete(endpoint: str, timeout: int = 60, retries: int = 5, backoff_factor: float = 1.5, toast=True) -> dict:
     """DELETE Requests"""
     result = send_request("DELETE", endpoint, timeout=timeout, retries=retries, backoff_factor=backoff_factor)
diff --git a/src/server/api/utils/models.py b/src/server/api/utils/models.py
index e1c82d3a..e26b2620 100644
--- a/src/server/api/utils/models.py
+++ b/src/server/api/utils/models.py
@@ -4,10 +4,12 @@
 """
 # spell-checker:ignore ollama pplx huggingface genai giskard litellm ocigenai rerank vllm
 
+import json
 import logging
-from typing import Optional, Union, Any
+from typing import Optional, Union, Any, AsyncGenerator
 from urllib.parse import urlparse
 
+import httpx
 import litellm
 
 from langchain.embeddings import init_embeddings
@@ -43,6 +45,14 @@ class UnknownModelError(ValueError):
     """Raised when the model data doesn't exist."""
 
 
+class OllamaPullError(ValueError):
+    """Raised when an Ollama model pull fails."""
+
+
+class OllamaModelNotPulledError(ValueError):
+    """Raised when trying to enable an Ollama model that has not been pulled."""
+
+
 #####################################################
 # CRUD Functions
 #####################################################
@@ -98,6 +108,13 @@ def update(payload: schema.Model) -> schema.Model:
         model_existing.enabled = False
         raise URLUnreachableError("Model: Unable to update. API URL is inaccessible.")
 
+    # For Ollama models, require the model to be pulled before enabling
+    if payload.enabled and "ollama" in payload.provider:
+        if not is_ollama_model_pulled(payload.id, payload.api_base):
+            raise OllamaModelNotPulledError(
+                f"Model '{payload.id}' has not been pulled. Use the Pull button to download it first."
+            )
+
     # Update all fields from payload in place
     for key, value in payload.model_dump().items():
         setattr(model_existing, key, value)
@@ -109,6 +126,50 @@ def delete(model_provider: schema.ModelProviderType, model_id: schema.ModelIdTyp
     MODEL_OBJECTS[:] = [m for m in MODEL_OBJECTS if (m.id, m.provider) != (model_id, model_provider)]
 
 
+def is_ollama_model_pulled(model_id: str, api_base: str) -> bool:
+    """Check whether an Ollama model has been pulled to the local instance.
+
+    Returns False instead of raising when api_base is unset, when the Ollama
+    server cannot be queried, or when its reply is not valid JSON.
+    """
+    if not api_base:
+        return False
+    tags_url = f"{api_base.rstrip('/')}/api/tags"
+    try:
+        response = httpx.get(tags_url, timeout=5)
+        response.raise_for_status()
+        models = response.json().get("models") or []
+    except (httpx.HTTPError, ValueError):
+        # httpx.HTTPError covers transport failures and non-2xx replies; ValueError covers a non-JSON body
+        return False
+    # Ollama stores untagged models as "name:latest"
+    normalized_id = model_id if ":" in model_id else f"{model_id}:latest"
+    return any(m.get("name") == normalized_id for m in models)
+
+
+async def pull_ollama_model(model_id: str, api_base: str) -> AsyncGenerator[str, None]:
+    """Async generator that streams Ollama model pull progress as NDJSON lines.
+
+    HTTP and transport failures are reported in-band as an {"error": ...} line.
+    """
+    pull_url = f"{api_base.rstrip('/')}/api/pull"
+    LOGGER.debug("Pulling Ollama model %s from %s", model_id, pull_url)
+    try:
+        async with httpx.AsyncClient(timeout=None) as client:
+            async with client.stream("POST", pull_url, json={"model": model_id}) as response:
+                response.raise_for_status()
+                async for line in response.aiter_lines():
+                    if line:
+                        yield line + "\n"
+    except httpx.HTTPStatusError as ex:
+        yield json.dumps({"error": f"Ollama returned HTTP {ex.response.status_code}"}) + "\n"
+    except (httpx.ConnectError, httpx.ConnectTimeout) as ex:
+        yield json.dumps({"error": f"Cannot connect to Ollama at {api_base}: {ex}"}) + "\n"
+    except httpx.RequestError as ex:
+        # Any other transport failure, e.g. the connection dropping part-way through the download
+        yield json.dumps({"error": f"Ollama pull interrupted: {ex}"}) + "\n"
+
+
 #####################################################
 # Utility Functions
 #####################################################
diff --git a/src/server/api/v1/models.py b/src/server/api/v1/models.py
index fcf47424..e1d603ac 100644
--- a/src/server/api/v1/models.py
+++ b/src/server/api/v1/models.py
@@ -6,7 +6,7 @@ import logging
 from typing import Optional, Any
 
 from fastapi import APIRouter, HTTPException, Query
-from fastapi.responses import JSONResponse
+from fastapi.responses import JSONResponse, StreamingResponse
 
 import server.api.utils.models as utils_models
 
@@ -84,7 +84,7 @@ async def models_update(payload: schema.Model) -> schema.Model:
         return utils_models.update(payload=payload)
     except utils_models.UnknownModelError as ex:
         raise HTTPException(status_code=404, detail=str(ex)) from ex
-    except utils_models.URLUnreachableError as ex:
+    except (utils_models.URLUnreachableError, utils_models.OllamaModelNotPulledError) as ex:
         raise HTTPException(status_code=422, detail=str(ex)) from ex
 
 
@@ -101,6 +101,34 @@ async def models_create(
         raise HTTPException(status_code=409, detail=str(ex)) from ex
 
 
+@auth.post(
+    "/pull/{model_provider}/{model_id:path}",
+    description="Pull an Ollama model from the registry",
+)
+async def models_pull(
+    model_provider: schema.ModelProviderType,
+    model_id: schema.ModelIdType,
+) -> StreamingResponse:
+    """Pull an Ollama model and stream progress as NDJSON"""
+    LOGGER.debug("Received models_pull - model: %s/%s", model_provider, model_id)
+
+    if "ollama" not in model_provider:
+        raise HTTPException(status_code=422, detail="Pull is only supported for Ollama models")
+
+    try:
+        (model,) = utils_models.get(model_provider=model_provider, model_id=model_id)
+    except utils_models.UnknownModelError as ex:
+        raise HTTPException(status_code=404, detail=str(ex)) from ex
+
+    if not model.api_base:
+        raise HTTPException(status_code=422, detail="Model has no API base URL configured")
+
+    return StreamingResponse(
+        utils_models.pull_ollama_model(model_id=model_id, api_base=model.api_base),
+        media_type="application/x-ndjson",
+    )
+
+
 @auth.delete(
     "/{model_provider}/{model_id:path}",
     description="Delete a model",