Skip to content

Commit ad9beba

Browse files
authored
Recover metadata from the response when missing from older data (#140)
1 parent 2d6869a commit ad9beba

File tree

1 file changed

+43
-31
lines changed

1 file changed

+43
-31
lines changed

src/paperoni/prompt.py

Lines changed: 43 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,8 @@
99
from google import genai
1010
from google.genai import types
1111
from paperazzi.platforms.utils import Message
12-
from paperazzi.utils import _make_key as paperazzi_make_key, disk_cache, disk_store
12+
from paperazzi.utils import _make_key as paperazzi_make_key
13+
from paperazzi.utils import disk_cache, disk_store
1314
from serieux.features.comment import CommentProxy, comment_field
1415
from serieux.features.encrypt import Secret
1516

@@ -62,6 +63,36 @@ def cleanup_schema(schema: dict | Type[Any]) -> dict:
6263
return schema
6364

6465

66+
def _derive_metadata(
    response: types.GenerateContentResponse,
    structured_model: Type[Any] | None,
) -> "PromptMetadata":
    """Derive token counts and parsed output from *response*.

    Args:
        response: The raw Gemini API response to inspect.
        structured_model: When not ``None``, the type to deserialize
            ``response.parsed`` into; also parameterizes the returned
            ``PromptMetadata``.

    Returns:
        A ``PromptMetadata`` (or ``PromptMetadata[structured_model]``)
        carrying the token counts and, if requested, the parsed output.
    """
    usage = response.usage_metadata
    if usage is not None:
        input_tokens = usage.prompt_token_count
        # The SDK declares these usage fields as Optional[int];
        # thoughts_token_count in particular is None when the model
        # produced no thinking tokens. Coalesce to 0 so the addition
        # cannot raise TypeError.
        output_tokens = (usage.candidates_token_count or 0) + (
            usage.thoughts_token_count or 0
        )
        total_tokens = usage.total_token_count
    else:
        # No usage information at all (e.g. very old cached responses).
        input_tokens = output_tokens = total_tokens = None

    if structured_model is not None:
        metadata_type = PromptMetadata[structured_model]
        parsed = serieux.deserialize(structured_model, response.parsed)
    else:
        metadata_type = PromptMetadata
        parsed = None

    return metadata_type(
        input_tokens=input_tokens,
        output_tokens=output_tokens,
        total_tokens=total_tokens,
        parsed=parsed,
    )
94+
95+
6596
@dataclass
6697
class PromptMetadata[T]:
6798
# The number of input tokens in the prompt
@@ -109,15 +140,17 @@ def load(self, file_obj: BinaryIO) -> types.GenerateContentResponse:
109140
metadata_type = PromptMetadata
110141

111142
data = dict(data)
112-
comment = data.pop(comment_field, {})
143+
metadata = data.pop(comment_field, {})
144+
response: types.GenerateContentResponse = (
145+
types.GenerateContentResponse.model_validate(data)
146+
)
147+
metadata = metadata or _derive_metadata(response, self.content_type)
148+
113149
with SERIEUX_LOCK:
114-
response: types.GenerateContentResponse = (
115-
types.GenerateContentResponse.model_validate(data)
116-
)
150+
metadata = serieux.deserialize(metadata_type, metadata)
117151

118-
comment = serieux.deserialize(metadata_type, comment)
119-
response = CommentProxy(response, comment)
120-
response.parsed = comment.parsed
152+
response = CommentProxy(response, metadata)
153+
response.parsed = metadata.parsed
121154

122155
return response
123156

@@ -224,30 +257,9 @@ def prompt(
224257
contents=contents, model=model, config=config
225258
)
226259

227-
# Extract token information from usage_metadata
228-
input_tokens = response.usage_metadata.prompt_token_count
229-
output_tokens = (
230-
response.usage_metadata.candidates_token_count
231-
+ response.usage_metadata.thoughts_token_count
232-
)
233-
total_tokens = response.usage_metadata.total_token_count
260+
metadata = _derive_metadata(response, structured_model)
234261

235-
if structured_model is not None:
236-
metadata_type = PromptMetadata[structured_model]
237-
parsed = serieux.deserialize(structured_model, response.parsed)
238-
else:
239-
metadata_type = PromptMetadata
240-
parsed = None
241-
242-
response = CommentProxy(
243-
response,
244-
metadata_type(
245-
input_tokens=input_tokens,
246-
output_tokens=output_tokens,
247-
total_tokens=total_tokens,
248-
parsed=parsed,
249-
),
250-
)
262+
response = CommentProxy(response, metadata)
251263

252264
return response
253265

0 commit comments

Comments
 (0)