74 changes: 43 additions & 31 deletions src/paperoni/prompt.py
@@ -9,7 +9,8 @@
 from google import genai
 from google.genai import types
 from paperazzi.platforms.utils import Message
-from paperazzi.utils import _make_key as paperazzi_make_key, disk_cache, disk_store
+from paperazzi.utils import _make_key as paperazzi_make_key
+from paperazzi.utils import disk_cache, disk_store
 from serieux.features.comment import CommentProxy, comment_field
 from serieux.features.encrypt import Secret

@@ -62,6 +63,36 @@ def cleanup_schema(schema: dict | Type[Any]) -> dict:
     return schema


+def _derive_metadata(
+    response: types.GenerateContentResponse,
+    structured_model: Type[Any] | None,
+) -> "PromptMetadata":
+    """Derive token counts and parsed output from response, return PromptMetadata."""
+    usage = response.usage_metadata
+    if usage is not None:
+        input_tokens = usage.prompt_token_count
+        output_tokens = (
+            usage.candidates_token_count + usage.thoughts_token_count
+        )
+        total_tokens = usage.total_token_count
+    else:
+        input_tokens = output_tokens = total_tokens = None
+
+    if structured_model is not None:
+        metadata_type = PromptMetadata[structured_model]
+        parsed = serieux.deserialize(structured_model, response.parsed)
+    else:
+        metadata_type = PromptMetadata
+        parsed = None
+
+    return metadata_type(
+        input_tokens=input_tokens,
+        output_tokens=output_tokens,
+        total_tokens=total_tokens,
+        parsed=parsed,
+    )
+
+
 @dataclass
 class PromptMetadata[T]:
     # The number of input tokens in the prompt
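Reviewer sketch of the new helper (not part of the diff): it assumes the SDK's types.GenerateContentResponseUsageMetadata model (the type behind response.usage_metadata) and passes structured_model=None so the serieux branch is skipped.

    from google.genai import types

    # Made-up token counts; only usage_metadata matters for this check.
    response = types.GenerateContentResponse(
        usage_metadata=types.GenerateContentResponseUsageMetadata(
            prompt_token_count=120,
            candidates_token_count=40,
            thoughts_token_count=15,
            total_token_count=175,
        )
    )

    meta = _derive_metadata(response, structured_model=None)
    assert meta.input_tokens == 120
    assert meta.output_tokens == 55   # candidates (40) + thoughts (15)
    assert meta.total_tokens == 175
    assert meta.parsed is None

With usage_metadata absent, the helper returns None for all three counts instead of raising, which the old inline code in prompt() did not guard against.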
Expand Down Expand Up @@ -109,15 +140,17 @@ def load(self, file_obj: BinaryIO) -> types.GenerateContentResponse:
         metadata_type = PromptMetadata

         data = dict(data)
-        comment = data.pop(comment_field, {})
+        metadata = data.pop(comment_field, {})
+        response: types.GenerateContentResponse = (
+            types.GenerateContentResponse.model_validate(data)
+        )
+        metadata = metadata or _derive_metadata(response, self.content_type)

         with SERIEUX_LOCK:
-            response: types.GenerateContentResponse = (
-                types.GenerateContentResponse.model_validate(data)
-            )
-            comment = serieux.deserialize(metadata_type, comment)
+            metadata = serieux.deserialize(metadata_type, metadata)

-        response = CommentProxy(response, comment)
-        response.parsed = comment.parsed
+        response = CommentProxy(response, metadata)
+        response.parsed = metadata.parsed

         return response
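One behavior worth calling out in load(): data.pop(comment_field, {}) returns an empty dict for cache entries written before this change, so the new `metadata or _derive_metadata(...)` line rebuilds their metadata from the stored response body. A minimal sketch of that fallback (the cache payload and its camelCase keys are assumptions about what the SDK serializes, not from the diff):

    from google.genai import types
    from serieux.features.comment import comment_field

    # Hypothetical legacy cache entry: response data, no comment payload.
    data = {
        "usageMetadata": {
            "promptTokenCount": 7,
            "candidatesTokenCount": 5,
            "thoughtsTokenCount": 0,
            "totalTokenCount": 12,
        }
    }

    metadata = data.pop(comment_field, {})   # -> {} for legacy entries
    response = types.GenerateContentResponse.model_validate(data)
    metadata = metadata or _derive_metadata(response, None)
    assert metadata.input_tokens == 7 and metadata.output_tokens == 5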

@@ -224,30 +257,9 @@ def prompt(
         contents=contents, model=model, config=config
     )

-    # Extract token information from usage_metadata
-    input_tokens = response.usage_metadata.prompt_token_count
-    output_tokens = (
-        response.usage_metadata.candidates_token_count
-        + response.usage_metadata.thoughts_token_count
-    )
-    total_tokens = response.usage_metadata.total_token_count
+    metadata = _derive_metadata(response, structured_model)

-    if structured_model is not None:
-        metadata_type = PromptMetadata[structured_model]
-        parsed = serieux.deserialize(structured_model, response.parsed)
-    else:
-        metadata_type = PromptMetadata
-        parsed = None
-
-    response = CommentProxy(
-        response,
-        metadata_type(
-            input_tokens=input_tokens,
-            output_tokens=output_tokens,
-            total_tokens=total_tokens,
-            parsed=parsed,
-        ),
-    )
+    response = CommentProxy(response, metadata)

     return response
