Commit 82bead9

Update types.py
1 parent e15563f commit 82bead9

File tree

1 file changed: +68 −68 lines

llama_cpp/server/types.py

+68-68
@@ -1,18 +1,18 @@
 from __future__ import annotations
 
-from typing import Dict, List, Literal, Optional, Union
+from typing import Dict, List, Optional, Union
 
 from pydantic import BaseModel, Field
-from typing_extensions import TypedDict
+from typing_extensions import Literal, TypedDict
 
 import llama_cpp
 
 model_field = Field(
-    description="The model to use for generating completions.", default=None,
+    description="The model to use for generating completions.", default=None
 )
 
 max_tokens_field = Field(
-    default=16, ge=1, description="The maximum number of tokens to generate.",
+    default=16, ge=1, description="The maximum number of tokens to generate."
 )
 
 min_tokens_field = Field(
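The hunk above moves `Literal` out of the `typing` import and into the `typing_extensions` import (alongside `TypedDict`), and drops the trailing comma from single-line `Field(...)` calls. As a minimal sketch, assuming only that the `typing_extensions` package is installed (it is already imported by this module), the re-exported `Literal` behaves the same as `typing.Literal` while also covering interpreters that predate it:

```python
# Illustrative snippet, not part of types.py: typing_extensions re-exports Literal,
# so this alias type-checks the same as one built from typing.Literal.
from typing_extensions import Literal

LogitBiasType = Literal["input_ids", "tokens"]

def describe(kind: LogitBiasType) -> str:
    # At runtime a Literal annotation is just metadata; enforcement happens in
    # type checkers or in libraries such as pydantic that inspect annotations.
    return f"logit bias applied to {kind}"

print(describe("tokens"))  # logit bias applied to tokens
```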
@@ -96,7 +96,7 @@
 )
 
 mirostat_eta_field = Field(
-    default=0.1, ge=0.001, le=1.0, description="Mirostat learning rate",
+    default=0.1, ge=0.001, le=1.0, description="Mirostat learning rate"
 )
 
 grammar = Field(
@@ -106,15 +106,15 @@
 
 
 class CreateCompletionRequest(BaseModel):
-    prompt: str | list[str] = Field(
-        default="", description="The prompt to generate completions for.",
+    prompt: Union[str, List[str]] = Field(
+        default="", description="The prompt to generate completions for."
     )
-    suffix: str | None = Field(
+    suffix: Optional[str] = Field(
         default=None,
         description="A suffix to append to the generated text. If None, no suffix is appended. Useful for chatbots.",
     )
-    max_tokens: int | None = Field(
-        default=16, ge=0, description="The maximum number of tokens to generate.",
+    max_tokens: Optional[int] = Field(
+        default=16, ge=0, description="The maximum number of tokens to generate."
     )
     min_tokens: int = min_tokens_field
     temperature: float = temperature_field
@@ -124,172 +124,172 @@ class CreateCompletionRequest(BaseModel):
         default=False,
         description="Whether to echo the prompt in the generated text. Useful for chatbots.",
     )
-    stop: str | list[str] | None = stop_field
+    stop: Optional[Union[str, List[str]]] = stop_field
     stream: bool = stream_field
-    logprobs: int | None = Field(
+    logprobs: Optional[int] = Field(
         default=None,
         ge=0,
         description="The number of logprobs to generate. If None, no logprobs are generated.",
     )
-    presence_penalty: float | None = presence_penalty_field
-    frequency_penalty: float | None = frequency_penalty_field
-    logit_bias: dict[str, float] | None = Field(None)
-    seed: int | None = Field(None)
+    presence_penalty: Optional[float] = presence_penalty_field
+    frequency_penalty: Optional[float] = frequency_penalty_field
+    logit_bias: Optional[Dict[str, float]] = Field(None)
+    seed: Optional[int] = Field(None)
 
     # ignored or currently unsupported
-    model: str | None = model_field
-    n: int | None = 1
-    best_of: int | None = 1
-    user: str | None = Field(default=None)
+    model: Optional[str] = model_field
+    n: Optional[int] = 1
+    best_of: Optional[int] = 1
+    user: Optional[str] = Field(default=None)
 
     # llama.cpp specific parameters
     top_k: int = top_k_field
     repeat_penalty: float = repeat_penalty_field
-    logit_bias_type: Literal["input_ids", "tokens"] | None = Field(None)
+    logit_bias_type: Optional[Literal["input_ids", "tokens"]] = Field(None)
     mirostat_mode: int = mirostat_mode_field
     mirostat_tau: float = mirostat_tau_field
     mirostat_eta: float = mirostat_eta_field
-    grammar: str | None = None
+    grammar: Optional[str] = None
 
     model_config = {
         "json_schema_extra": {
             "examples": [
                 {
                     "prompt": "\n\n### Instructions:\nWhat is the capital of France?\n\n### Response:\n",
                     "stop": ["\n", "###"],
-                },
-            ],
-        },
+                }
+            ]
+        }
     }
 
 
 class CreateEmbeddingRequest(BaseModel):
-    model: str | None = model_field
-    input: str | list[str] = Field(description="The input to embed.")
-    user: str | None = Field(default=None)
+    model: Optional[str] = model_field
+    input: Union[str, List[str]] = Field(description="The input to embed.")
+    user: Optional[str] = Field(default=None)
 
     model_config = {
         "json_schema_extra": {
             "examples": [
                 {
                     "input": "The food was delicious and the waiter...",
-                },
-            ],
-        },
+                }
+            ]
+        }
     }
 
 
 class ChatCompletionRequestMessage(BaseModel):
     role: Literal["system", "user", "assistant", "function"] = Field(
-        default="user", description="The role of the message.",
+        default="user", description="The role of the message."
     )
-    content: str | None = Field(
-        default="", description="The content of the message.",
+    content: Optional[str] = Field(
+        default="", description="The content of the message."
     )
 
 
 class CreateChatCompletionRequest(BaseModel):
-    messages: list[llama_cpp.ChatCompletionRequestMessage] = Field(
-        default=[], description="A list of messages to generate completions for.",
+    messages: List[llama_cpp.ChatCompletionRequestMessage] = Field(
+        default=[], description="A list of messages to generate completions for."
     )
-    functions: list[llama_cpp.ChatCompletionFunction] | None = Field(
+    functions: Optional[List[llama_cpp.ChatCompletionFunction]] = Field(
         default=None,
         description="A list of functions to apply to the generated completions.",
     )
-    function_call: llama_cpp.ChatCompletionRequestFunctionCall | None = Field(
+    function_call: Optional[llama_cpp.ChatCompletionRequestFunctionCall] = Field(
         default=None,
         description="A function to apply to the generated completions.",
     )
-    tools: list[llama_cpp.ChatCompletionTool] | None = Field(
+    tools: Optional[List[llama_cpp.ChatCompletionTool]] = Field(
         default=None,
         description="A list of tools to apply to the generated completions.",
     )
-    tool_choice: llama_cpp.ChatCompletionToolChoiceOption | None = Field(
+    tool_choice: Optional[llama_cpp.ChatCompletionToolChoiceOption] = Field(
         default=None,
         description="A tool to apply to the generated completions.",
     )  # TODO: verify
-    max_tokens: int | None = Field(
+    max_tokens: Optional[int] = Field(
         default=None,
         description="The maximum number of tokens to generate. Defaults to inf",
     )
     min_tokens: int = min_tokens_field
-    logprobs: bool | None = Field(
+    logprobs: Optional[bool] = Field(
         default=False,
         description="Whether to output the logprobs or not. Default is True",
     )
-    top_logprobs: int | None = Field(
+    top_logprobs: Optional[int] = Field(
         default=None,
         ge=0,
         description="The number of logprobs to generate. If None, no logprobs are generated. logprobs need to set to True.",
     )
     temperature: float = temperature_field
     top_p: float = top_p_field
     min_p: float = min_p_field
-    stop: str | list[str] | None = stop_field
+    stop: Optional[Union[str, List[str]]] = stop_field
     stream: bool = stream_field
-    presence_penalty: float | None = presence_penalty_field
-    frequency_penalty: float | None = frequency_penalty_field
-    logit_bias: dict[str, float] | None = Field(None)
-    seed: int | None = Field(None)
-    response_format: llama_cpp.ChatCompletionRequestResponseFormat | None = Field(
+    presence_penalty: Optional[float] = presence_penalty_field
+    frequency_penalty: Optional[float] = frequency_penalty_field
+    logit_bias: Optional[Dict[str, float]] = Field(None)
+    seed: Optional[int] = Field(None)
+    response_format: Optional[llama_cpp.ChatCompletionRequestResponseFormat] = Field(
         default=None,
     )
 
     # ignored or currently unsupported
-    model: str | None = model_field
-    n: int | None = 1
-    user: str | None = Field(None)
+    model: Optional[str] = model_field
+    n: Optional[int] = 1
+    user: Optional[str] = Field(None)
 
     # llama.cpp specific parameters
     top_k: int = top_k_field
     repeat_penalty: float = repeat_penalty_field
-    logit_bias_type: Literal["input_ids", "tokens"] | None = Field(None)
+    logit_bias_type: Optional[Literal["input_ids", "tokens"]] = Field(None)
     mirostat_mode: int = mirostat_mode_field
     mirostat_tau: float = mirostat_tau_field
     mirostat_eta: float = mirostat_eta_field
-    grammar: str | None = None
+    grammar: Optional[str] = None
 
     model_config = {
         "json_schema_extra": {
             "examples": [
                 {
                     "messages": [
                         ChatCompletionRequestMessage(
-                            role="system", content="You are a helpful assistant.",
+                            role="system", content="You are a helpful assistant."
                         ).model_dump(),
                         ChatCompletionRequestMessage(
-                            role="user", content="What is the capital of France?",
+                            role="user", content="What is the capital of France?"
                         ).model_dump(),
-                    ],
-                },
-            ],
-        },
+                    ]
+                }
+            ]
+        }
     }
 
 
 class ModelData(TypedDict):
     id: str
     object: Literal["model"]
     owned_by: str
-    permissions: list[str]
+    permissions: List[str]
 
 
 class ModelList(TypedDict):
     object: Literal["list"]
-    data: list[ModelData]
+    data: List[ModelData]
 
 
 class TokenizeInputRequest(BaseModel):
-    model: str | None = model_field
+    model: Optional[str] = model_field
     input: str = Field(description="The input to tokenize.")
 
     model_config = {
-        "json_schema_extra": {"examples": [{"input": "How many tokens in this query?"}]},
+        "json_schema_extra": {"examples": [{"input": "How many tokens in this query?"}]}
     }
 
 
 class TokenizeInputResponse(BaseModel):
-    tokens: list[int] = Field(description="A list of tokens.")
+    tokens: List[int] = Field(description="A list of tokens.")
 
     model_config = {"json_schema_extra": {"example": {"tokens": [123, 321, 222]}}}
 
@@ -301,8 +301,8 @@ class TokenizeInputCountResponse(BaseModel):
 
 
 class DetokenizeInputRequest(BaseModel):
-    model: str | None = model_field
-    tokens: list[int] = Field(description="A list of toekns to detokenize.")
+    model: Optional[str] = model_field
+    tokens: List[int] = Field(description="A list of toekns to detokenize.")
 
     model_config = {"json_schema_extra": {"example": [{"tokens": [123, 321, 222]}]}}
 
@@ -311,5 +311,5 @@ class DetokenizeInputResponse(BaseModel):
     text: str = Field(description="The detokenized text.")
 
     model_config = {
-        "json_schema_extra": {"example": {"text": "How many tokens in this query?"}},
+        "json_schema_extra": {"example": {"text": "How many tokens in this query?"}}
     }
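Taken together, the diff converts the module from PEP 604 unions and builtin generics (`str | None`, `list[str]`, `dict[str, float]`) to the explicit `typing` forms (`Optional`, `Union`, `List`, `Dict`) plus `Literal` from `typing_extensions`, and removes trailing commas after the last argument of single-line calls and after the last element of the `json_schema_extra` examples. A minimal sketch of the resulting style, using a hypothetical `ExampleRequest` model rather than anything defined in `types.py`:

```python
# Hypothetical model illustrating the post-commit typing style; the field names
# echo CreateCompletionRequest but this class is not part of the changed file.
from typing import Dict, List, Optional, Union

from pydantic import BaseModel, Field
from typing_extensions import Literal


class ExampleRequest(BaseModel):
    prompt: Union[str, List[str]] = Field(
        default="", description="The prompt(s) to generate completions for."
    )
    max_tokens: Optional[int] = Field(default=16, ge=0)
    logit_bias: Optional[Dict[str, float]] = None
    logit_bias_type: Optional[Literal["input_ids", "tokens"]] = None


# Pydantic validates against the annotations regardless of which spelling is used.
req = ExampleRequest(prompt=["Hello"], max_tokens=32, logit_bias_type="tokens")
print(req.model_dump())
```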
