generated from oracle/template-repo
-
Notifications
You must be signed in to change notification settings - Fork 41
Expand file tree
/
Copy pathmodels.py
More file actions
143 lines (115 loc) · 4.83 KB
/
models.py
File metadata and controls
143 lines (115 loc) · 4.83 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
"""
Copyright (c) 2024, 2026, Oracle and/or its affiliates.
Licensed under the Universal Permissive License v1.0 as shown at http://oss.oracle.com/licenses/upl.
"""
import logging
from typing import Optional, Any
from fastapi import APIRouter, HTTPException, Query
from fastapi.responses import JSONResponse, StreamingResponse
import server.api.utils.models as utils_models
from common import schema
# Router holding every model-management endpoint in this module; the list and
# create routes use an empty path, so the URL prefix comes from wherever the
# application includes this router.
auth = APIRouter()
# Module logger; handlers log request details at DEBUG level.
LOGGER = logging.getLogger("endpoints.v1.models")
@auth.get(
    "",
    description="Get all models (by default, only enabled)",
    response_model=list[schema.Model],
)
async def models_list(
    model_type: Optional[schema.ModelTypeType] = Query(None),
    include_disabled: schema.ModelEnabledType = Query(False, description="Include disabled models"),
) -> list[schema.Model]:
    """List all models after applying filters if specified.

    Args:
        model_type: Optional model-type filter; ``None`` applies no type filter.
        include_disabled: When true, disabled models are returned as well.

    Returns:
        The models returned by ``utils_models.get`` for the given filters.
    """
    LOGGER.debug("Received models_list - type: %s; include_disabled: %s", model_type, include_disabled)
    # Return the result as-is: writing model objects to stdout bypasses the
    # logging configuration and may expose whatever configuration they carry
    # (NOTE(review): check schema.Model for credential fields).
    return utils_models.get(model_type=model_type, include_disabled=include_disabled)
@auth.get(
    "/supported",
    description="Get supported providers and models",
    response_model=list[dict[str, Any]],
)
async def models_supported(
    model_provider: Optional[schema.ModelProviderType] = Query(None),
    model_type: Optional[schema.ModelTypeType] = Query(None),
) -> list[dict[str, Any]]:
    """Return the supported providers/models, optionally filtered.

    Args:
        model_provider: Optional provider filter; ``None`` applies no provider filter.
        model_type: Optional model-type filter; ``None`` applies no type filter.

    Returns:
        The list of dictionaries produced by ``utils_models.get_supported``.
    """
    LOGGER.debug("Received models_supported")
    supported = utils_models.get_supported(model_provider=model_provider, model_type=model_type)
    return supported
@auth.get(
    "/{model_provider}/{model_id:path}",
    description="Get a single model (provider/name)",
    response_model=schema.Model,
)
async def models_get(
    model_provider: schema.ModelProviderType,
    model_id: schema.ModelIdType,
) -> schema.Model:
    """Return the single model identified by ``model_provider``/``model_id``.

    ``model_id`` uses the ``:path`` converter, so it may itself contain slashes.

    Raises:
        HTTPException: 404 when ``utils_models.get`` raises ``UnknownModelError``,
            when the lookup yields no match, or when it yields more than one match.
    """
    LOGGER.debug("Received models_get - model: %s/%s", model_provider, model_id)
    try:
        matches = list(utils_models.get(model_provider=model_provider, model_id=model_id))
    except utils_models.UnknownModelError as ex:
        raise HTTPException(status_code=404, detail=str(ex)) from ex
    # Count matches explicitly rather than relying on a tuple-unpacking
    # ValueError: that error also fires for zero matches, and catching
    # ValueError around get() would mask unrelated errors raised inside it.
    if not matches:
        raise HTTPException(status_code=404, detail=f"Model: {model_provider}/{model_id} not found")
    if len(matches) > 1:
        raise HTTPException(status_code=404, detail="Multiple models returned")
    return matches[0]
@auth.patch(
    "/{model_provider}/{model_id:path}",
    description="Update a model",
    response_model=schema.Model,
)
async def models_update(payload: schema.Model) -> schema.Model:
    """Apply ``payload`` as an update to an existing model and return the result.

    NOTE(review): the ``model_provider``/``model_id`` path segments are not
    parameters of this handler; only ``payload`` is passed to
    ``utils_models.update``. Confirm the path and payload are meant to agree.

    Raises:
        HTTPException: 404 on ``UnknownModelError``; 422 on
            ``URLUnreachableError`` or ``OllamaModelNotPulledError``.
    """
    LOGGER.debug("Received models_update - payload: %s", payload)
    try:
        updated = utils_models.update(payload=payload)
    except utils_models.UnknownModelError as ex:
        raise HTTPException(status_code=404, detail=str(ex)) from ex
    except (utils_models.URLUnreachableError, utils_models.OllamaModelNotPulledError) as ex:
        raise HTTPException(status_code=422, detail=str(ex)) from ex
    return updated
@auth.post("", description="Create a model", response_model=schema.Model, status_code=201)
async def models_create(
    payload: schema.Model,
) -> schema.Model:
    """Create a model from ``payload``; responds with HTTP 201 on success.

    Raises:
        HTTPException: 409 when ``utils_models.create`` raises ``ExistsModelError``.
    """
    LOGGER.debug("Received model_create - payload: %s", payload)
    try:
        created = utils_models.create(payload)
    except utils_models.ExistsModelError as ex:
        raise HTTPException(status_code=409, detail=str(ex)) from ex
    return created
@auth.post(
    "/pull/{model_provider}/{model_id:path}",
    description="Pull an Ollama model from the registry",
)
async def models_pull(
    model_provider: schema.ModelProviderType,
    model_id: schema.ModelIdType,
) -> StreamingResponse:
    """Pull an Ollama model and stream progress as NDJSON.

    Returns:
        A ``StreamingResponse`` (``application/x-ndjson``) wrapping
        ``utils_models.pull_ollama_model`` for the model's ``api_base``.

    Raises:
        HTTPException: 422 when the provider name does not contain "ollama" or
            the model has no ``api_base``; 404 when the model is unknown, not
            found, or matches more than one configured model.
    """
    LOGGER.debug("Received models_pull - model: %s/%s", model_provider, model_id)
    # Substring check: any provider name containing "ollama" is accepted.
    if "ollama" not in model_provider:
        raise HTTPException(status_code=422, detail="Pull is only supported for Ollama models")
    try:
        matches = list(utils_models.get(model_provider=model_provider, model_id=model_id))
    except utils_models.UnknownModelError as ex:
        raise HTTPException(status_code=404, detail=str(ex)) from ex
    # Reject zero or multiple matches explicitly; bare tuple unpacking would
    # raise an uncaught ValueError and surface as HTTP 500.
    if not matches:
        raise HTTPException(status_code=404, detail=f"Model: {model_provider}/{model_id} not found")
    if len(matches) > 1:
        raise HTTPException(status_code=404, detail="Multiple models returned")
    model = matches[0]
    if not model.api_base:
        raise HTTPException(status_code=422, detail="Model has no API base URL configured")
    return StreamingResponse(
        utils_models.pull_ollama_model(model_id=model_id, api_base=model.api_base),
        media_type="application/x-ndjson",
    )
@auth.delete(
    "/{model_provider}/{model_id:path}",
    description="Delete a model",
)
async def models_delete(
    model_provider: schema.ModelProviderType,
    model_id: schema.ModelIdType,
) -> JSONResponse:
    """Delete ``model_provider``/``model_id`` and confirm with a 200 JSON message.

    NOTE(review): no exception from ``utils_models.delete`` is translated into an
    HTTP error here; anything it raises propagates to FastAPI's default handling.
    """
    LOGGER.debug("Received models_delete - model: %s/%s", model_provider, model_id)
    utils_models.delete(model_provider=model_provider, model_id=model_id)
    message = f"Model: {model_provider}/{model_id} deleted."
    return JSONResponse(content={"message": message}, status_code=200)