Skip to content

Commit bfb7d34

Browse files
committed
Fix chat completion request and logprob contracts
Signed-off-by: Ritwij Aryan Parmar <ritwij.aryan.parmar@gmail.com>
1 parent cc9f253 commit bfb7d34

4 files changed

Lines changed: 75 additions & 2 deletions

File tree

src/together/resources/chat/completions.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from __future__ import annotations
22

3-
from typing import Any, AsyncGenerator, Dict, Iterator, List
3+
from typing import Any, AsyncGenerator, Dict, Iterator, List, Literal
44

55
from together.abstract import api_requestor
66
from together.together_response import TogetherResponse
@@ -32,6 +32,7 @@ def create(
3232
frequency_penalty: float | None = None,
3333
min_p: float | None = None,
3434
logit_bias: Dict[str, float] | None = None,
35+
context_length_exceeded_behavior: Literal["truncate", "error"] | None = None,
3536
seed: int | None = None,
3637
stream: bool = False,
3738
logprobs: int | None = None,
@@ -80,6 +81,9 @@ def create(
8081
logit_bias (Dict[str, float], optional): A dictionary of tokens and their bias values that modify the
8182
likelihood of specific tokens being sampled. Bias values must be in the range [-100, 100].
8283
Defaults to None.
84+
context_length_exceeded_behavior ("truncate" | "error", optional): Behavior when max_tokens exceeds the
85+
model context length. "error" returns a 400, while "truncate" overrides max_tokens with the model's
86+
maximum context length.
8387
seed (int, optional): A seed value to use for reproducibility.
8488
stream (bool, optional): Flag indicating whether to stream the generated completions.
8589
Defaults to False.
@@ -126,6 +130,7 @@ def create(
126130
frequency_penalty=frequency_penalty,
127131
min_p=min_p,
128132
logit_bias=logit_bias,
133+
context_length_exceeded_behavior=context_length_exceeded_behavior,
129134
seed=seed,
130135
stream=stream,
131136
logprobs=logprobs,
@@ -174,6 +179,7 @@ async def create(
174179
frequency_penalty: float | None = None,
175180
min_p: float | None = None,
176181
logit_bias: Dict[str, float] | None = None,
182+
context_length_exceeded_behavior: Literal["truncate", "error"] | None = None,
177183
seed: int | None = None,
178184
stream: bool = False,
179185
logprobs: int | None = None,
@@ -222,6 +228,9 @@ async def create(
222228
logit_bias (Dict[str, float], optional): A dictionary of tokens and their bias values that modify the
223229
likelihood of specific tokens being sampled. Bias values must be in the range [-100, 100].
224230
Defaults to None.
231+
context_length_exceeded_behavior ("truncate" | "error", optional): Behavior when max_tokens exceeds the
232+
model context length. "error" returns a 400, while "truncate" overrides max_tokens with the model's
233+
maximum context length.
225234
seed (int, optional): A seed value to use for reproducibility.
226235
stream (bool, optional): Flag indicating whether to stream the generated completions.
227236
Defaults to False.
@@ -268,6 +277,7 @@ async def create(
268277
frequency_penalty=frequency_penalty,
269278
min_p=min_p,
270279
logit_bias=logit_bias,
280+
context_length_exceeded_behavior=context_length_exceeded_behavior,
271281
seed=seed,
272282
stream=stream,
273283
logprobs=logprobs,

src/together/types/chat_completions.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
import warnings
44
from enum import Enum
5-
from typing import Any, Dict, List
5+
from typing import Any, Dict, List, Literal
66

77
from pydantic import model_validator
88
from typing_extensions import Self
@@ -132,6 +132,8 @@ class ChatCompletionRequest(BaseModel):
132132
frequency_penalty: float | None = None
133133
min_p: float | None = None
134134
logit_bias: Dict[str, float] | None = None
135+
# behavior when max_tokens exceeds the model context length
136+
context_length_exceeded_behavior: Literal["truncate", "error"] | None = None
135137
seed: int | None = None
136138
# stream SSE token chunks
137139
stream: bool = False

src/together/types/common.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,8 @@ class LogprobsPart(BaseModel):
4242
tokens: List[str | None] | None = None
4343
# token logprob list
4444
token_logprobs: List[float | None] | None = None
45+
# top-k logprobs per token
46+
top_logprobs: List[Dict[str, float]] | None = None
4547

4648

4749
class PromptPart(BaseModel):
Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,59 @@
1+
import inspect
2+
3+
from together.resources.chat.completions import AsyncChatCompletions, ChatCompletions
4+
from together.types.chat_completions import (
5+
ChatCompletionRequest,
6+
ChatCompletionResponse,
7+
)
8+
from together.types.common import LogprobsPart
9+
10+
11+
def test_chat_completion_create_exposes_context_length_behavior() -> None:
    """Check that the sync and async ``create`` methods both accept the
    ``context_length_exceeded_behavior`` keyword."""
    # Resolve both parameter mappings up front, then assert on each in turn.
    sync_params = inspect.signature(ChatCompletions.create).parameters
    async_params = inspect.signature(AsyncChatCompletions.create).parameters

    assert "context_length_exceeded_behavior" in sync_params
    assert "context_length_exceeded_behavior" in async_params
17+
18+
19+
def test_chat_completion_request_serializes_context_length_behavior() -> None:
    """Check that ``context_length_exceeded_behavior`` set on the request
    model appears unchanged in its ``exclude_none`` dump."""
    request = ChatCompletionRequest(
        model="meta-llama/Llama-3.3-70B-Instruct-Turbo",
        messages=[{"role": "user", "content": "Hello"}],
        context_length_exceeded_behavior="truncate",
    )

    # Dump once with unset (None) fields dropped, then look the key up.
    payload = request.model_dump(exclude_none=True)

    assert payload["context_length_exceeded_behavior"] == "truncate"
30+
31+
32+
def test_logprobs_part_models_top_logprobs_as_list_per_token() -> None:
    """Check that ``LogprobsPart`` declares ``top_logprobs`` and that a list
    of per-token dicts round-trips through ``ChatCompletionResponse``."""
    assert "top_logprobs" in LogprobsPart.model_fields

    # Expected value is kept separate from the input literal below, so the
    # comparison never shares objects with the data handed to the model.
    expected_top_logprobs = [
        {"Hello": -0.1, "Hi": -1.4},
        {".": -0.2, "!": -1.7},
    ]

    response = ChatCompletionResponse(
        choices=[
            {
                "logprobs": {
                    "tokens": ["Hello", "."],
                    "token_logprobs": [-0.1, -0.2],
                    "top_logprobs": [
                        {"Hello": -0.1, "Hi": -1.4},
                        {".": -0.2, "!": -1.7},
                    ],
                }
            }
        ]
    )

    assert response.choices is not None
    assert response.choices[0].logprobs is not None
    assert response.choices[0].logprobs.top_logprobs == expected_top_logprobs

    dumped = response.model_dump()
    assert dumped["choices"][0]["logprobs"]["top_logprobs"] == expected_top_logprobs

0 commit comments

Comments
 (0)