Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
67 changes: 56 additions & 11 deletions src/client/content/config/tabs/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
"""
# spell-checker:ignore selectbox

import json
import logging
from time import sleep
from typing import Literal, Any
Expand Down Expand Up @@ -251,6 +252,41 @@ def _handle_form_submission(model: dict, action: str) -> bool:
return False


@st.dialog("Pull Ollama Model", width="large")
def pull_model_dialog(model_provider: str, model_id: str) -> None:
    """Confirm and stream the progress of an Ollama model pull.

    Args:
        model_provider: Provider key of the model (expected to contain "ollama").
        model_id: Identifier of the model to pull from the Ollama registry.

    Renders a confirm/cancel pair; on confirm, streams the server's NDJSON
    progress lines into a live status widget until completion or error.
    """
    st.write(f"Pull `{model_provider}/{model_id}` from Ollama?")
    st.caption("Large models may take several minutes to download.")

    col1, _, col2 = st.columns([1.4, 6, 1.4])
    if col1.button("Pull", type="primary", key="pull_confirm_btn"):
        # Model IDs may contain "/" or ":"; quote fully so they survive the URL path
        quoted_id = urllib.parse.quote(model_id, safe="")
        with st.status("Pulling model...", expanded=True) as pull_status:
            try:
                response = api_call.post_stream(endpoint=f"v1/models/pull/{model_provider}/{quoted_id}")
                progress_line = st.empty()
                for raw_line in response.iter_lines():
                    if not raw_line:
                        continue
                    try:
                        data = json.loads(raw_line)
                    except json.JSONDecodeError:
                        # Skip malformed/partial NDJSON lines rather than crash the dialog
                        continue
                    if "error" in data:
                        pull_status.update(label=f"Pull failed: {data['error']}", state="error")
                        progress_line.error(data["error"])
                        return
                    msg = data.get("status", "")
                    total = data.get("total", 0)
                    if total > 0:
                        # Ollama reports byte counts; show percentage progress
                        pct = int(data.get("completed", 0) / total * 100)
                        progress_line.write(f"{msg}: {pct}%")
                    elif msg:
                        progress_line.write(msg)
                pull_status.update(label="Pull complete!", state="complete")
                progress_line.success(f"Model `{model_id}` pulled successfully. Enable it in the model configuration.")
            except api_call.ApiError as ex:
                pull_status.update(label=f"Pull failed: {ex}", state="error")
    if col2.button("Cancel", key="pull_cancel_btn"):
        st.rerun()


@st.dialog("Model Configuration", width="large")
def edit_model(
model_type: str, action: Literal["add", "edit"], model_id: str = None, model_provider: str = None
Expand All @@ -272,33 +308,35 @@ def edit_model(

def render_model_rows(model_type: str) -> None:
"""Render rows of the models"""
data_col_widths = [0.08, 0.42, 0.28, 0.12]
table_col_format = st.columns(data_col_widths, vertical_alignment="center")
col1, col2, col3, col4 = table_col_format
col1.markdown("​", help="Active", unsafe_allow_html=True)
col2.markdown("**<u>Model</u>**", unsafe_allow_html=True)
col3.markdown("**<u>Provider URL</u>**", unsafe_allow_html=True)
col4.markdown("&#x200B;")
data_col_widths = [0.08, 0.40, 0.27, 0.12, 0.10]
header_cols = st.columns(data_col_widths)
header_cols[0].markdown("&#x200B;", help="Active", unsafe_allow_html=True)
header_cols[1].markdown("**<u>Model</u>**", unsafe_allow_html=True)
header_cols[2].markdown("**<u>Provider URL</u>**", unsafe_allow_html=True)
header_cols[3].markdown("&#x200B;")
header_cols[4].markdown("&#x200B;")
for model in [m for m in state.model_configs if m.get("type") == model_type]:
model_id = model["id"]
model_provider = model["provider"]
# Pre-set widget state keys so Streamlit displays current values instead of cached ones
state[f"{model_type}_{model_provider}_{model_id}_enabled"] = st_common.bool_to_emoji(model["enabled"])
state[f"{model_type}_{model_provider}_{model_id}"] = f"{model_provider}/{model_id}"
state[f"{model_type}_{model_provider}_{model_id}_api_base"] = model.get("api_base", "")
col1, col2, col3, col4, col5 = st.columns(data_col_widths, vertical_alignment="center")
col1.text_input(
"Enabled",
value=st_common.bool_to_emoji(model["enabled"]),
key=f"{model_type}_{model_provider}_{model_id}_enabled",
label_visibility="collapsed",
disabled=True,
)
col2.text_input(
"Model",
value=f"{model_provider}/{model_id}",
key=f"{model_type}_{model_provider}_{model_id}",
label_visibility="collapsed",
disabled=True,
)
col3.text_input(
"Server",
value=model["api_base"],
key=f"{model_type}_{model_provider}_{model_id}_api_base",
label_visibility="collapsed",
disabled=True,
Expand All @@ -314,6 +352,13 @@ def render_model_rows(model_type: str) -> None:
"model_provider": model_provider,
},
)
if "ollama" in model_provider:
col5.button(
"Pull",
on_click=pull_model_dialog,
key=f"{model_type}_{model_provider}_{model_id}_pull",
kwargs={"model_provider": model_provider, "model_id": model_id},
)

if st.button(label="Add", type="primary", key=f"add_{model_type}_model"):
edit_model(model_type=model_type, action="add")
Expand All @@ -327,7 +372,7 @@ def display_models() -> None:
st.title("Models")
st.write("Update, Add, or Delete model configuration parameters.")
try:
get_models()
get_models(force=True)
except api_call.ApiError:
st.stop()

Expand Down
17 changes: 17 additions & 0 deletions src/client/utils/api_call.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,6 +175,23 @@ def patch(
return result


def post_stream(endpoint: str, timeout: int = 600) -> requests.Response:
    """POST request with streaming enabled.

    Args:
        endpoint: Server endpoint path, joined onto the configured server URL.
        timeout: Request timeout in seconds (default 600 to allow long downloads).

    Returns:
        The open requests.Response for line-by-line iteration by the caller.
    """
    url = urljoin(f"{state.server['url']}:{state.server['port']}/", endpoint)
    headers = {"Authorization": f"Bearer {state.server['key']}"}
    if getattr(state, "client_settings", {}).get("client"):
        headers["Client"] = state.client_settings["client"]
    LOGGER.info("POST Stream Request: %s", url)
    try:
        response = requests.post(url, headers=headers, timeout=timeout, stream=True)
        response.raise_for_status()
    except requests.exceptions.HTTPError as ex:
        _error_response(_handle_http_error(ex))
    # Timeout added: without it a slow/hung server escaped as a raw traceback
    except (requests.exceptions.ConnectionError, requests.exceptions.Timeout) as ex:
        _error_response(f"Connection failed: {str(ex)}")
    # NOTE(review): assumes _error_response raises (e.g. ApiError); if it can
    # return, `response` would be unbound here on connection failure — confirm.
    return response


def delete(endpoint: str, timeout: int = 60, retries: int = 5, backoff_factor: float = 1.5, toast=True) -> dict:
"""DELETE Requests"""
result = send_request("DELETE", endpoint, timeout=timeout, retries=retries, backoff_factor=backoff_factor)
Expand Down
50 changes: 49 additions & 1 deletion src/server/api/utils/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,12 @@
"""
# spell-checker:ignore ollama pplx huggingface genai giskard litellm ocigenai rerank vllm

import json
import logging
from typing import Optional, Union, Any
from typing import Optional, Union, Any, AsyncGenerator
from urllib.parse import urlparse

import httpx
import litellm

from langchain.embeddings import init_embeddings
Expand Down Expand Up @@ -43,6 +45,14 @@ class UnknownModelError(ValueError):
"""Raised when the model data doesn't exist."""


class OllamaPullError(ValueError):
    """Signals that downloading a model from the Ollama registry did not succeed."""


class OllamaModelNotPulledError(ValueError):
    """Signals an attempt to enable an Ollama model before it has been pulled locally."""


#####################################################
# CRUD Functions
#####################################################
Expand Down Expand Up @@ -98,6 +108,13 @@ def update(payload: schema.Model) -> schema.Model:
model_existing.enabled = False
raise URLUnreachableError("Model: Unable to update. API URL is inaccessible.")

# For Ollama models, require the model to be pulled before enabling
if payload.enabled and "ollama" in payload.provider:
if not is_ollama_model_pulled(payload.id, payload.api_base):
raise OllamaModelNotPulledError(
f"Model '{payload.id}' has not been pulled. Use the Pull button to download it first."
)

# Update all fields from payload in place
for key, value in payload.model_dump().items():
setattr(model_existing, key, value)
Expand All @@ -109,6 +126,37 @@ def delete(model_provider: schema.ModelProviderType, model_id: schema.ModelIdTyp
MODEL_OBJECTS[:] = [m for m in MODEL_OBJECTS if (m.id, m.provider) != (model_id, model_provider)]


def is_ollama_model_pulled(model_id: str, api_base: str) -> bool:
    """Check whether an Ollama model has been pulled to the local instance.

    Args:
        model_id: Model name, with or without a tag (e.g. "llama3" or "llama3:8b").
        api_base: Base URL of the Ollama server.

    Returns:
        True if the model appears in the server's local tag list; False on any
        connection, HTTP, or response-parsing failure (treated as "not pulled").
    """
    tags_url = f"{api_base.rstrip('/')}/api/tags"
    try:
        response = httpx.get(tags_url, timeout=5)
        response.raise_for_status()
        models = response.json().get("models", [])
    except (httpx.HTTPError, ValueError):
        # httpx.HTTPError is the base of connect/timeout/status/transport errors;
        # ValueError covers a non-JSON body (previously an uncaught crash).
        return False
    # Ollama stores untagged models as "name:latest"
    normalized_id = model_id if ":" in model_id else f"{model_id}:latest"
    return any(m.get("name") == normalized_id for m in models)


async def pull_ollama_model(model_id: str, api_base: str) -> AsyncGenerator[str, None]:
    """Async generator that streams Ollama model pull progress as NDJSON lines.

    Args:
        model_id: Model to pull (passed through to Ollama's /api/pull endpoint).
        api_base: Base URL of the Ollama server.

    Yields:
        Newline-terminated NDJSON strings: Ollama's own progress objects, or a
        single {"error": ...} object if the pull could not be (fully) performed.
    """
    pull_url = f"{api_base.rstrip('/')}/api/pull"
    LOGGER.debug("Pulling Ollama model %s from %s", model_id, pull_url)
    try:
        # No client timeout: large model downloads can legitimately run for minutes
        async with httpx.AsyncClient(timeout=None) as client:
            async with client.stream("POST", pull_url, json={"model": model_id}) as response:
                response.raise_for_status()
                async for line in response.aiter_lines():
                    if line:
                        yield line + "\n"
    except httpx.HTTPStatusError as ex:
        yield json.dumps({"error": f"Ollama returned HTTP {ex.response.status_code}"}) + "\n"
    except httpx.TransportError as ex:
        # TransportError covers ConnectError/ConnectTimeout (as before) plus
        # mid-stream read/network failures that previously escaped uncaught.
        yield json.dumps({"error": f"Cannot connect to Ollama at {api_base}: {ex}"}) + "\n"


#####################################################
# Utility Functions
#####################################################
Expand Down
32 changes: 30 additions & 2 deletions src/server/api/v1/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import logging
from typing import Optional, Any
from fastapi import APIRouter, HTTPException, Query
from fastapi.responses import JSONResponse
from fastapi.responses import JSONResponse, StreamingResponse

import server.api.utils.models as utils_models

Expand Down Expand Up @@ -84,7 +84,7 @@ async def models_update(payload: schema.Model) -> schema.Model:
return utils_models.update(payload=payload)
except utils_models.UnknownModelError as ex:
raise HTTPException(status_code=404, detail=str(ex)) from ex
except utils_models.URLUnreachableError as ex:
except (utils_models.URLUnreachableError, utils_models.OllamaModelNotPulledError) as ex:
raise HTTPException(status_code=422, detail=str(ex)) from ex


Expand All @@ -101,6 +101,34 @@ async def models_create(
raise HTTPException(status_code=409, detail=str(ex)) from ex


@auth.post(
    "/pull/{model_provider}/{model_id:path}",
    description="Pull an Ollama model from the registry",
)
async def models_pull(
    model_provider: schema.ModelProviderType,
    model_id: schema.ModelIdType,
) -> StreamingResponse:
    """Pull an Ollama model and stream progress as NDJSON"""
    LOGGER.debug("Received models_pull - model: %s/%s", model_provider, model_id)

    # Pulling is an Ollama-only concept; reject other providers up front.
    if "ollama" not in model_provider:
        raise HTTPException(status_code=422, detail="Pull is only supported for Ollama models")

    # The model must already be registered so we know which server to pull onto.
    try:
        (model_config,) = utils_models.get(model_provider=model_provider, model_id=model_id)
    except utils_models.UnknownModelError as ex:
        raise HTTPException(status_code=404, detail=str(ex)) from ex

    api_base = model_config.api_base
    if not api_base:
        raise HTTPException(status_code=422, detail="Model has no API base URL configured")

    # Relay Ollama's NDJSON progress stream straight through to the client.
    return StreamingResponse(
        utils_models.pull_ollama_model(model_id=model_id, api_base=api_base),
        media_type="application/x-ndjson",
    )


@auth.delete(
"/{model_provider}/{model_id:path}",
description="Delete a model",
Expand Down
Loading