Commit ea86ead
fix(environment): Split WRAPPER_MODEL_NAME from COLLECTION_NAME
Still, the former will default to the latter when not defined, so this change is backwards compatible (BC). But it enables us to use better model names, not subject to the collection name restrictions (alphanumerics plus underscore).
1 parent 40c3f3c commit ea86ead
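
The fallback described above is just an `or` over the two environment variables (the exact line main.py gains below). A quick illustration, with made-up values:

import os

# Fallback rule from the commit message; the variable values are illustrative.
os.environ["COLLECTION_NAME"] = "my_wiki_collection"
os.environ.pop("WRAPPER_MODEL_NAME", None)
assert (os.getenv("WRAPPER_MODEL_NAME") or os.getenv("COLLECTION_NAME")) == "my_wiki_collection"

os.environ["WRAPPER_MODEL_NAME"] = "Your Model Name"  # spaces allowed here, unlike in collection names
assert (os.getenv("WRAPPER_MODEL_NAME") or os.getenv("COLLECTION_NAME")) == "Your Model Name"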

File tree

5 files changed: +16 -6 lines changed

env_template (+2)

@@ -29,6 +29,8 @@ WRAPPER_API_BASE="0.0.0.0:8080"
 WRAPPER_CHAT_MAX_TURNS=10
 # To limit the number of OpenAI history tokens allowed in a chat (0 = no limit).
 WRAPPER_CHAT_MAX_TOKENS=1536
+# Public name that the wrapper will use to identify itself as a model. Will default to COLLECTION_NAME if not set.
+WRAPPER_MODEL_NAME="Your Model Name"

 # Validate requests auth against these bearer tokens.
 AUTH_TOKENS="11111111,22222222,33333333"
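
How the template becomes live configuration is not part of this diff; a minimal sketch assuming the usual pattern of copying env_template to .env and loading it with python-dotenv:

import os
from dotenv import load_dotenv

load_dotenv()  # assumed loading mechanism; reads .env into the process environment
print(os.getenv("WRAPPER_MODEL_NAME"))  # "Your Model Name"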

wiki_rag/search/util.py (+3 -1)

@@ -20,7 +20,8 @@
 )
 from langchain_core.runnables import RunnableConfig
 from langchain_openai import ChatOpenAI, OpenAIEmbeddings
-from langgraph.graph.state import START, CompiledStateGraph, StateGraph
+from langgraph.constants import START
+from langgraph.graph.state import CompiledStateGraph, StateGraph
 from pymilvus import AnnSearchRequest, MilvusClient, WeightedRanker

 import wiki_rag.index as index
@@ -51,6 +52,7 @@ class ConfigSchema(TypedDict):
     stream: bool
     wrapper_chat_max_turns: int
     wrapper_chat_max_tokens: int
+    wrapper_model_name: str


 class RagState(TypedDict):
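
For context, a minimal sketch (hypothetical node, not from this commit) of how a LangGraph node reads the new ConfigSchema field at run time:

from langchain_core.runnables import RunnableConfig

def report_model(state: dict, config: RunnableConfig) -> dict:
    # Values declared in ConfigSchema travel in config["configurable"],
    # so nodes never need to re-read the environment.
    return {"model": config["configurable"]["wrapper_model_name"]}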

wiki_rag/server/main.py (+6)

@@ -67,6 +67,7 @@ def main():
     if not collection_name:
         logger.error("Collection name not found in environment. Exiting.")
         sys.exit(1)
+    # TODO: Validate that only numbers, letters and underscores are used.

     index.milvus_url = os.getenv("MILVUS_URL")
     if not index.milvus_url:
@@ -113,6 +114,10 @@ def main():
     # These are optional, default to 0 (unlimited).
     wrapper_chat_max_turns = int(os.getenv("WRAPPER_CHAT_MAX_TURNS", 0))
     wrapper_chat_max_tokens = int(os.getenv("WRAPPER_CHAT_MAX_TOKENS", 0))
+    wrapper_model_name = os.getenv("WRAPPER_MODEL_NAME") or os.getenv("COLLECTION_NAME")
+    if not wrapper_model_name:
+        logger.error("Public wrapper name not found in environment. Exiting.")  # This is unreachable.
+        sys.exit(1)

     logger.info("Building the graph")
     server.graph = build_graph()
@@ -136,6 +141,7 @@ def main():
         stream=False,
         wrapper_chat_max_turns=wrapper_chat_max_turns,
         wrapper_chat_max_tokens=wrapper_chat_max_tokens,
+        wrapper_model_name=wrapper_model_name,
     ).items()

     # Prepare the configuration.
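
Two notes on this file. The error branch is unreachable in practice because main() already exits when COLLECTION_NAME is missing, so the `or` fallback always yields a value. And the new TODO asks for collection name validation; a minimal sketch of such a check (hypothetical helper, not part of the commit):

import re

def is_valid_collection_name(name: str) -> bool:
    # Accept only letters, numbers and underscores, per the TODO above.
    return re.fullmatch(r"[A-Za-z0-9_]+", name) is not None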

wiki_rag/server/server.py (+2 -2)

@@ -82,7 +82,7 @@ async def models_list() -> ModelsListResponse:
         object="list",
         data=[
             ModelResponse(
-                id=server.config["configurable"]["collection_name"],
+                id=server.config["configurable"]["wrapper_model_name"],
                 object="model",
                 created=int(time.time()),
                 owned_by=server.config["configurable"]["kb_name"],
@@ -113,7 +113,7 @@ async def chat_completions(request: ChatCompletionRequest) -> ChatCompletionResponse:
     if not request.model:
         raise HTTPException(status_code=400, detail="No model provided.")

-    if request.model != server.config["configurable"]["collection_name"]:
+    if request.model != server.config["configurable"]["wrapper_model_name"]:
         raise HTTPException(status_code=400, detail="Model not supported.")

     logger.debug(f"Request: {request}")
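
The visible effect for clients: the id returned by the models endpoint, and the only value accepted as request.model, is now WRAPPER_MODEL_NAME. A sketch assuming the conventional OpenAI-style routes and bearer auth (neither is shown in this diff):

import requests

base = "http://0.0.0.0:8080"  # WRAPPER_API_BASE from env_template
headers = {"Authorization": "Bearer 11111111"}  # one of the AUTH_TOKENS examples

# List models; the single id is the public wrapper name after this commit.
model_id = requests.get(f"{base}/v1/models", headers=headers).json()["data"][0]["id"]

reply = requests.post(
    f"{base}/v1/chat/completions",
    headers=headers,
    json={"model": model_id, "messages": [{"role": "user", "content": "Hello!"}]},
)
print(reply.json())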

wiki_rag/server/util.py (+3 -3)

@@ -42,7 +42,7 @@ class ChatCompletionRequest(BaseModel):
     max_completion_tokens: int | None = 768  # Max tokens to generate (not all models support this).
     temperature: float | None = 0.05  # Temperature for sampling (0.0, deterministic to 2.0, creative).
     top_p: float | None = 0.85  # Which probability (0.0 - 1.0) is used to consider the next token (0.85 default).
-    model: str = server.config["configurable"]["collection_name"]
+    model: str = server.config["configurable"]["wrapper_model_name"]
     messages: list[Message] = [Message(role="user", content="Hello!")]
     stream: bool | None = False

@@ -60,14 +60,14 @@ class ChatCompletionResponse(BaseModel):
     id: UUID4 = uuid.uuid4()
     object: str = "chat.completion"
     created: int = int(time.time())
-    model: str = server.config["configurable"]["collection_name"]
+    model: str = server.config["configurable"]["wrapper_model_name"]
     choices: list[ChoiceResponse] = [ChoiceResponse()]


 class ModelResponse(BaseModel):
     """Information about a LLM model."""

-    id: str = server.config["configurable"]["collection_name"]
+    id: str = server.config["configurable"]["wrapper_model_name"]
     object: str = "model"
     created: int = int(time.time())
     owned_by: str = server.config["configurable"]["kb_name"]
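
One behaviour these class-level defaults rely on, shown with a toy model (not the wiki_rag code): Pydantic evaluates them once, when the class body runs, so server.config must be readable at import time and values like created are frozen per process, not per instance:

import time
from pydantic import BaseModel

class Stamp(BaseModel):
    created: int = int(time.time())  # computed once, when the class body executes

a = Stamp()
time.sleep(2)
b = Stamp()
assert a.created == b.created  # same value for both instances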
