Skip to content

Commit e5daa22

Browse files
Merge pull request #153 from Aleph-Alpha/add_token_counts
Add new token usage count fields
2 parents 5ece809 + c0beb1d commit e5daa22

File tree

6 files changed

+71
-5
lines changed

6 files changed

+71
-5
lines changed

Changelog.md

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,11 @@
11
# Changelog
22

3+
## 5.0.0
4+
5+
- Added `num_tokens_prompt_total` and `num_tokens_generated` to `CompletionResponse`. This is a
6+
breaking change as these were introduced as mandatory parameters rather than optional ones.
7+
HTTP API version 1.14.0 or higher is required.
8+
39
## 4.1.0
410

511
- Added `verify_ssl` flag so you can disable SSL checking for your sessions.

aleph_alpha_client/completion.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -247,8 +247,31 @@ def _asdict(self) -> Mapping[str, Any]:
247247

248248
@dataclass(frozen=True)
249249
class CompletionResponse:
250+
"""
251+
Describes a completion response
252+
253+
Parameters:
254+
model_version:
255+
Model name and version (if any) of the used model for inference.
256+
completions:
257+
List of completions; may contain only one entry if no more are requested (see parameter n).
258+
num_tokens_prompt_total:
259+
Number of prompt tokens combined across all completion tasks.
260+
In particular, if you set best_of or n to a number larger than 1 then we report the
261+
combined prompt token count for all best_of or n tasks.
262+
num_tokens_generated:
263+
Number of generated tokens combined across all completion tasks.
264+
If multiple completions are returned or best_of is set to a value greater than 1 then
265+
this value contains the combined generated token count.
266+
optimized_prompt:
267+
Describes the prompt after optimizations. This field is only returned if the
268+
`disable_optimizations` flag is not set and the prompt has actually changed.
269+
"""
270+
250271
model_version: str
251272
completions: Sequence[CompletionResult]
273+
num_tokens_prompt_total: int
274+
num_tokens_generated: int
252275
optimized_prompt: Optional[Prompt] = None
253276

254277
@staticmethod
@@ -259,6 +282,8 @@ def from_json(json: Dict[str, Any]) -> "CompletionResponse":
259282
completions=[
260283
CompletionResult.from_json(item) for item in json["completions"]
261284
],
285+
num_tokens_prompt_total=json["num_tokens_prompt_total"],
286+
num_tokens_generated=json["num_tokens_generated"],
262287
optimized_prompt=Prompt.from_json(optimized_prompt_json)
263288
if optimized_prompt_json
264289
else None,

aleph_alpha_client/version.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
__version__ = "4.1.0"
1+
__version__ = "5.0.0"

tests/test_clients.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,9 @@ def test_nice_flag_on_client(httpserver: HTTPServer):
3333
).respond_with_json(
3434
CompletionResponse(
3535
"model_version",
36-
[CompletionResult(log_probs=[], completion="foo")],
36+
[CompletionResult(log_probs=[], completion="foo", )],
37+
num_tokens_prompt_total=2,
38+
num_tokens_generated=1,
3739
).to_json()
3840
)
3941

@@ -47,11 +49,14 @@ async def test_nice_flag_on_async_client(httpserver: HTTPServer):
4749
httpserver.expect_request("/version").respond_with_data("OK")
4850

4951
httpserver.expect_request(
50-
"/complete", query_string={"nice": "true"}
52+
"/complete",
53+
query_string={"nice": "true"},
5154
).respond_with_json(
5255
CompletionResponse(
5356
"model_version",
5457
[CompletionResult(log_probs=[], completion="foo")],
58+
num_tokens_prompt_total=2,
59+
num_tokens_generated=1,
5560
).to_json()
5661
)
5762

tests/test_complete.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -127,3 +127,34 @@ def test_complete_with_echo(sync_client: Client, model_name: str, prompt_image:
127127
assert len(completion_result.completion_tokens) > 0
128128
assert completion_result.log_probs is not None
129129
assert len(completion_result.log_probs) > 0
130+
131+
@pytest.mark.system_test
132+
def test_num_tokens_prompt_total_with_best_of(sync_client: Client, model_name: str):
133+
tokens = [49222, 2998] # Hello world
134+
best_of = 2
135+
request = CompletionRequest(
136+
prompt = Prompt.from_tokens(tokens),
137+
best_of = best_of,
138+
maximum_tokens = 1,
139+
)
140+
141+
response = sync_client.complete(request, model=model_name)
142+
assert response.num_tokens_prompt_total == len(tokens) * best_of
143+
144+
@pytest.mark.system_test
145+
def test_num_tokens_generated_with_best_of(sync_client: Client, model_name: str):
146+
hello_world = [49222, 2998] # Hello world
147+
best_of = 2
148+
request = CompletionRequest(
149+
prompt = Prompt.from_tokens(hello_world),
150+
best_of = best_of,
151+
maximum_tokens = 1,
152+
tokens = True,
153+
)
154+
155+
response = sync_client.complete(request, model=model_name)
156+
completion_result = response.completions[0]
157+
assert completion_result.completion_tokens is not None
158+
number_tokens_completion = len(completion_result.completion_tokens)
159+
160+
assert response.num_tokens_generated == best_of * number_tokens_completion

tests/test_error_handling.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -111,8 +111,7 @@ def expect_retryable_error(
111111

112112
def expect_valid_completion(httpserver: HTTPServer) -> None:
113113
httpserver.expect_ordered_request("/complete").respond_with_json(
114-
{"model_version": "1", "completions": []}
115-
)
114+
{"model_version": "1", "completions": [], "num_tokens_prompt_total": 0, "num_tokens_generated": 0})
116115

117116

118117
def expect_valid_version(httpserver: HTTPServer) -> None:

0 commit comments

Comments
 (0)