File tree 5 files changed +16
-6
lines changed
5 files changed +16
-6
lines changed Original file line number Diff line number Diff line change @@ -29,6 +29,8 @@ WRAPPER_API_BASE="0.0.0.0:8080"
29
29
WRAPPER_CHAT_MAX_TURNS=10
30
30
# To limit the number of OpenAI history tokens allowed in a chat (0 = no limit).
31
31
WRAPPER_CHAT_MAX_TOKENS=1536
32
+ # Public name that the wrapper will use to identify itself as a model. Will default to COLLECTION_NAME if not set.
33
+ WRAPPER_MODEL_NAME="Your Model Name"
32
34
33
35
# Validate requests auth against these bearer tokens.
34
36
AUTH_TOKENS="11111111,22222222,33333333"
Original file line number Diff line number Diff line change 20
20
)
21
21
from langchain_core .runnables import RunnableConfig
22
22
from langchain_openai import ChatOpenAI , OpenAIEmbeddings
23
- from langgraph .graph .state import START , CompiledStateGraph , StateGraph
23
+ from langgraph .constants import START
24
+ from langgraph .graph .state import CompiledStateGraph , StateGraph
24
25
from pymilvus import AnnSearchRequest , MilvusClient , WeightedRanker
25
26
26
27
import wiki_rag .index as index
@@ -51,6 +52,7 @@ class ConfigSchema(TypedDict):
51
52
stream : bool
52
53
wrapper_chat_max_turns : int
53
54
wrapper_chat_max_tokens : int
55
+ wrapper_model_name : str
54
56
55
57
56
58
class RagState (TypedDict ):
Original file line number Diff line number Diff line change @@ -67,6 +67,7 @@ def main():
67
67
if not collection_name :
68
68
logger .error ("Collection name not found in environment. Exiting." )
69
69
sys .exit (1 )
70
+ # TODO: Validate that only numbers, letters and underscores are used.
70
71
71
72
index .milvus_url = os .getenv ("MILVUS_URL" )
72
73
if not index .milvus_url :
@@ -113,6 +114,10 @@ def main():
113
114
# These are optional, default to 0 (unlimited).
114
115
wrapper_chat_max_turns = int (os .getenv ("WRAPPER_CHAT_MAX_TURNS" , 0 ))
115
116
wrapper_chat_max_tokens = int (os .getenv ("WRAPPER_CHAT_MAX_TOKENS" , 0 ))
117
+ wrapper_model_name = os .getenv ("WRAPPER_MODEL_NAME" ) or os .getenv ("COLLECTION_NAME" )
118
+ if not wrapper_model_name :
119
+ logger .error ("Public wrapper name not found in environment. Exiting." ) # This is unreachable.
120
+ sys .exit (1 )
116
121
117
122
logger .info ("Building the graph" )
118
123
server .graph = build_graph ()
@@ -136,6 +141,7 @@ def main():
136
141
stream = False ,
137
142
wrapper_chat_max_turns = wrapper_chat_max_turns ,
138
143
wrapper_chat_max_tokens = wrapper_chat_max_tokens ,
144
+ wrapper_model_name = wrapper_model_name ,
139
145
).items ()
140
146
141
147
# Prepare the configuration.
Original file line number Diff line number Diff line change @@ -82,7 +82,7 @@ async def models_list() -> ModelsListResponse:
82
82
object = "list" ,
83
83
data = [
84
84
ModelResponse (
85
- id = server .config ["configurable" ]["collection_name " ],
85
+ id = server .config ["configurable" ]["wrapper_model_name " ],
86
86
object = "model" ,
87
87
created = int (time .time ()),
88
88
owned_by = server .config ["configurable" ]["kb_name" ],
@@ -113,7 +113,7 @@ async def chat_completions(request: ChatCompletionRequest) -> ChatCompletionResp
113
113
if not request .model :
114
114
raise HTTPException (status_code = 400 , detail = "No model provided." )
115
115
116
- if request .model != server .config ["configurable" ]["collection_name " ]:
116
+ if request .model != server .config ["configurable" ]["wrapper_model_name " ]:
117
117
raise HTTPException (status_code = 400 , detail = "Model not supported." )
118
118
119
119
logger .debug (f"Request: { request } " )
Original file line number Diff line number Diff line change @@ -42,7 +42,7 @@ class ChatCompletionRequest(BaseModel):
42
42
max_completion_tokens : int | None = 768 # Max tokens to generate (not all models support this).
43
43
temperature : float | None = 0.05 # Temperature for sampling (0.0, deterministic to 2.0, creative).
44
44
top_p : float | None = 0.85 # Which probability (0.0 - 1.0) is used to consider the next token (0.85 default).
45
- model : str = server .config ["configurable" ]["collection_name " ]
45
+ model : str = server .config ["configurable" ]["wrapper_model_name " ]
46
46
messages : list [Message ] = [Message (role = "user" , content = "Hello!" )]
47
47
stream : bool | None = False
48
48
@@ -60,14 +60,14 @@ class ChatCompletionResponse(BaseModel):
60
60
id : UUID4 = uuid .uuid4 ()
61
61
object : str = "chat.completion"
62
62
created : int = int (time .time ())
63
- model : str = server .config ["configurable" ]["collection_name " ]
63
+ model : str = server .config ["configurable" ]["wrapper_model_name " ]
64
64
choices : list [ChoiceResponse ] = [ChoiceResponse ()]
65
65
66
66
67
67
class ModelResponse (BaseModel ):
68
68
"""Information about a LLM model."""
69
69
70
- id : str = server .config ["configurable" ]["collection_name " ]
70
+ id : str = server .config ["configurable" ]["wrapper_model_name " ]
71
71
object : str = "model"
72
72
created : int = int (time .time ())
73
73
owned_by : str = server .config ["configurable" ]["kb_name" ]
You can’t perform that action at this time.
0 commit comments