|
9 | 9 | from google import genai |
10 | 10 | from google.genai import types |
11 | 11 | from paperazzi.platforms.utils import Message |
12 | | -from paperazzi.utils import _make_key as paperazzi_make_key, disk_cache, disk_store |
| 12 | +from paperazzi.utils import _make_key as paperazzi_make_key |
| 13 | +from paperazzi.utils import disk_cache, disk_store |
13 | 14 | from serieux.features.comment import CommentProxy, comment_field |
14 | 15 | from serieux.features.encrypt import Secret |
15 | 16 |
|
@@ -62,6 +63,36 @@ def cleanup_schema(schema: dict | Type[Any]) -> dict: |
62 | 63 | return schema |
63 | 64 |
|
64 | 65 |
|
| 66 | +def _derive_metadata( |
| 67 | +    response: types.GenerateContentResponse, |
| 68 | +    structured_model: Type[Any] | None, |
| 69 | +) -> "PromptMetadata": |
| 70 | +    """Derive token counts and parsed output from response, return PromptMetadata.""" |
| 71 | +    usage = response.usage_metadata |
| 72 | +    if usage is not None: |
| 73 | +        input_tokens = usage.prompt_token_count |
| 74 | +        output_tokens = ( |
| 75 | +            usage.candidates_token_count + usage.thoughts_token_count |
| 76 | +        ) |
| 77 | +        total_tokens = usage.total_token_count |
| 78 | +    else: |
| 79 | +        input_tokens = output_tokens = total_tokens = None |
| 80 | + |
| 81 | +    if structured_model is not None: |
| 82 | +        metadata_type = PromptMetadata[structured_model] |
| 83 | +        parsed = serieux.deserialize(structured_model, response.parsed) |
| 84 | +    else: |
| 85 | +        metadata_type = PromptMetadata |
| 86 | +        parsed = None |
| 87 | + |
| 88 | +    return metadata_type( |
| 89 | +        input_tokens=input_tokens, |
| 90 | +        output_tokens=output_tokens, |
| 91 | +        total_tokens=total_tokens, |
| 92 | +        parsed=parsed, |
| 93 | +    ) |
| 94 | + |
| 95 | + |
65 | 96 | @dataclass |
66 | 97 | class PromptMetadata[T]: |
67 | 98 | # The number of input tokens in the prompt |
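
A minimal usage sketch, not part of this diff, of what the new `_derive_metadata` helper reports. The token counts below are made-up values, the snippet assumes it runs in this module's namespace (so `_derive_metadata` and `PromptMetadata` are in scope), and it assumes a recent `google-genai` release where `types.GenerateContentResponseUsageMetadata` exposes `thoughts_token_count`:

```python
from google.genai import types

# Hand-built response with made-up usage numbers (illustration only).
response = types.GenerateContentResponse(
    usage_metadata=types.GenerateContentResponseUsageMetadata(
        prompt_token_count=12,
        candidates_token_count=34,
        thoughts_token_count=5,
        total_token_count=51,
    )
)

# No structured model: parsed stays None and plain PromptMetadata is used.
metadata = _derive_metadata(response, structured_model=None)
assert metadata.input_tokens == 12
assert metadata.output_tokens == 39  # candidates + thoughts
assert metadata.total_tokens == 51
assert metadata.parsed is None
```
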
@@ -109,15 +140,17 @@ def load(self, file_obj: BinaryIO) -> types.GenerateContentResponse: |
109 | 140 | metadata_type = PromptMetadata |
110 | 141 |
|
111 | 142 | data = dict(data) |
112 | | - comment = data.pop(comment_field, {}) |
| 143 | + metadata = data.pop(comment_field, {}) |
| 144 | + response: types.GenerateContentResponse = ( |
| 145 | + types.GenerateContentResponse.model_validate(data) |
| 146 | + ) |
| 147 | + metadata = metadata or _derive_metadata(response, self.content_type) |
| 148 | + |
113 | 149 | with SERIEUX_LOCK: |
114 | | - response: types.GenerateContentResponse = ( |
115 | | - types.GenerateContentResponse.model_validate(data) |
116 | | - ) |
| 150 | + metadata = serieux.deserialize(metadata_type, metadata) |
117 | 151 |
|
118 | | - comment = serieux.deserialize(metadata_type, comment) |
119 | | - response = CommentProxy(response, comment) |
120 | | - response.parsed = comment.parsed |
| 152 | + response = CommentProxy(response, metadata) |
| 153 | + response.parsed = metadata.parsed |
121 | 154 |
|
122 | 155 | return response |
123 | 156 |
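
A short sketch, also not from the diff, of the precedence the reworked `load` relies on: metadata saved under `comment_field` is used as-is, while an empty dict falls back to re-deriving it from the response. It reuses the hand-built `response` from the sketch above; either way, the result is then passed through `serieux.deserialize(metadata_type, ...)` as in the diff.

```python
# Nothing was saved under comment_field: fall back to re-derivation.
stored = {}
metadata = stored or _derive_metadata(response, None)
assert metadata.input_tokens == 12

# Metadata was saved alongside the response: the stored dict is kept as-is
# and only normalized later by serieux.deserialize.
stored = {"input_tokens": 12}
metadata = stored or _derive_metadata(response, None)
assert metadata == {"input_tokens": 12}
```
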
|
@@ -224,30 +257,9 @@ def prompt( |
224 | 257 | contents=contents, model=model, config=config |
225 | 258 | ) |
226 | 259 |
|
227 | | - # Extract token information from usage_metadata |
228 | | - input_tokens = response.usage_metadata.prompt_token_count |
229 | | - output_tokens = ( |
230 | | - response.usage_metadata.candidates_token_count |
231 | | - + response.usage_metadata.thoughts_token_count |
232 | | - ) |
233 | | - total_tokens = response.usage_metadata.total_token_count |
| 260 | + metadata = _derive_metadata(response, structured_model) |
234 | 261 |
|
235 | | - if structured_model is not None: |
236 | | - metadata_type = PromptMetadata[structured_model] |
237 | | - parsed = serieux.deserialize(structured_model, response.parsed) |
238 | | - else: |
239 | | - metadata_type = PromptMetadata |
240 | | - parsed = None |
241 | | - |
242 | | - response = CommentProxy( |
243 | | - response, |
244 | | - metadata_type( |
245 | | - input_tokens=input_tokens, |
246 | | - output_tokens=output_tokens, |
247 | | - total_tokens=total_tokens, |
248 | | - parsed=parsed, |
249 | | - ), |
250 | | - ) |
| 262 | + response = CommentProxy(response, metadata) |
251 | 263 |
|
252 | 264 | return response |
253 | 265 |
|
|