Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions docs/wayflowcore/source/core/changelog.rst
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,10 @@ New features
Added support for converting `MessageSummarizationTransform` and `ConversationSummarizationTransform` between Agent Spec and Wayflow. Similarly for Datastores (`OracleDatabaseDatastore`, `PostgreSQLDatabaseDatastore`).
You can now declare your agents with summarization transforms and summary caching in Agent Spec and run them in WayFlow.

* **Logprob support in `LlmGenerationConfig` and `PromptExecutionStep`**

Added token logprobs support with the `top_logprobs` generation config parameter, and support for returning
logprobs in the `PromptExecutionStep`.

Improvements
^^^^^^^^^^^^
Expand Down
58 changes: 58 additions & 0 deletions wayflowcore/src/wayflowcore/messagelist.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
from wayflowcore._utils.hash import fast_stable_hash
from wayflowcore.serialization.context import DeserializationContext, SerializationContext
from wayflowcore.serialization.serializer import (
FrozenSerializableDataclass,
SerializableDataclass,
SerializableDataclassMixin,
SerializableObject,
Expand All @@ -51,6 +52,28 @@
_ReasoningContent: TypeAlias = Dict[str, Any]


@dataclass(frozen=True, slots=True)
class TextTokenTopLogProb(FrozenSerializableDataclass):
    """Represents a single candidate token with its associated log probability.

    Immutable (``frozen=True``) value object used as an element of
    ``TextTokenLogProb.top_logprobs`` to describe one alternate token the
    model considered at a given generation position.
    """

    token: str
    """The literal text of the candidate token."""
    logprob: float
    """The log probability assigned to the candidate token."""


@dataclass(frozen=True, slots=True)
class TextTokenLogProb(FrozenSerializableDataclass):
    """Captures a generated token, its log probability, and alternate candidates.

    Immutable (``frozen=True``) value object describing one token actually
    emitted by the model, optionally accompanied by the ranked alternate
    candidates the provider reported for the same position.
    """

    token: str
    """The literal text of the generated token."""
    logprob: float
    """The log probability assigned to the generated token."""
    top_logprobs: Optional[List[TextTokenTopLogProb]] = None
    """Optional ranked list of alternate tokens with probabilities."""


class MessageType(str, Enum):
"""Type of messages"""

Expand Down Expand Up @@ -98,6 +121,7 @@ class TextContent(MessageContent, SerializableObject):
"""

content: str = ""
logprobs: Optional[List[TextTokenLogProb]] = None
type: ClassVar[Literal["text"]] = "text"

def __post_init__(self) -> None:
Expand All @@ -109,6 +133,40 @@ def __post_init__(self) -> None:
)
self.content = str(self.content)

if self.logprobs is None:
return

# We accept both already-built `TextTokenLogProb` objects and raw dicts
# (e.g., from provider payloads) to keep adapters simple.
validated: List[TextTokenLogProb] = []
for item in self.logprobs:
if isinstance(item, TextTokenLogProb):
validated.append(item)
continue

raw_item = cast(Dict[str, Any], item)
raw_top = raw_item.get("top_logprobs")
top_converted = None
if raw_top is not None:
top_converted = [
(
c
if isinstance(c, TextTokenTopLogProb)
else TextTokenTopLogProb(**cast(Dict[str, Any], c))
)
for c in raw_top
]

validated.append(
TextTokenLogProb(
token=raw_item["token"],
logprob=raw_item["logprob"],
top_logprobs=top_converted,
)
)

self.logprobs = validated


@dataclass
class ImageContent(MessageContent, SerializableObject):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,13 @@
from typing import Any, AsyncIterable, Callable, Dict, List, Optional, TypedDict

from wayflowcore._utils.formatting import stringify
from wayflowcore.messagelist import ImageContent, Message, TextContent
from wayflowcore.messagelist import (
ImageContent,
Message,
TextContent,
TextTokenLogProb,
TextTokenTopLogProb,
)
from wayflowcore.tokenusage import TokenUsage
from wayflowcore.tools import Tool, ToolRequest
from wayflowcore.tools.tools import ExtraContentT
Expand All @@ -31,6 +37,28 @@ class OpenAIToolRequestAsDictT(TypedDict, total=True):

class _ChatCompletionsAPIProcessor(_APIProcessor):

@staticmethod
def _convert_openai_logprobs_into_text_logprobs(logprobs: Any) -> List[TextTokenLogProb]:
    """Convert raw OpenAI logprob entries into ``TextTokenLogProb`` objects.

    Each entry is expected to be a mapping carrying ``token``, ``logprob``
    and an optional ``top_logprobs`` list of alternate-candidate mappings.
    A falsy input (``None`` or an empty sequence) yields an empty list.
    """
    if not logprobs:
        return []

    result: List[TextTokenLogProb] = []
    for entry in logprobs:
        candidates = entry.get("top_logprobs")
        # Preserve the None/empty distinction: only build the alternate
        # list when the provider actually supplied one.
        alternates = (
            None
            if candidates is None
            else [
                TextTokenTopLogProb(token=alt["token"], logprob=alt["logprob"])
                for alt in candidates
            ]
        )
        result.append(
            TextTokenLogProb(
                token=entry["token"],
                logprob=entry["logprob"],
                top_logprobs=alternates,
            )
        )
    return result

@staticmethod
def _tool_to_openai_function_dict(tool: Tool) -> Dict[str, Any]:
openai_function_dict: Dict[str, Any] = {
Expand Down Expand Up @@ -151,6 +179,9 @@ def _convert_generation_params(
kwargs["stop"] = generation_config.stop
if generation_config.frequency_penalty is not None:
kwargs["frequency_penalty"] = generation_config.frequency_penalty
if generation_config.top_logprobs is not None:
kwargs["logprobs"] = True
kwargs["top_logprobs"] = generation_config.top_logprobs
if generation_config.extra_args:
kwargs.update(generation_config.extra_args)
return kwargs
Expand All @@ -174,9 +205,17 @@ def _convert_openai_response_into_message(self, response: Any) -> "Message":
# content might be empty when certain models (like gemini) decide
# to finish the conversation
content = extracted_message.get("content", "")

logprobs = None
choice_logprobs = response["choices"][0].get("logprobs")
if choice_logprobs and choice_logprobs.get("content") is not None:
logprobs = self._convert_openai_logprobs_into_text_logprobs(
choice_logprobs["content"]
)

message = Message(
role="assistant",
contents=[TextContent(content=content)],
contents=[TextContent(content=content, logprobs=logprobs)],
_extra_content=extracted_message.get("extra_content"),
)
return message
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@
Message,
MessageContent,
TextContent,
TextTokenLogProb,
TextTokenTopLogProb,
_ReasoningContent,
)
from wayflowcore.tokenusage import TokenUsage
Expand All @@ -28,6 +30,28 @@


class _ResponsesAPIProcessor(_APIProcessor):
@staticmethod
def _convert_openai_logprobs_into_text_logprobs(logprobs: Any) -> List[TextTokenLogProb]:
    """Convert raw OpenAI logprob entries into ``TextTokenLogProb`` objects.

    ``logprobs`` is an iterable of mappings with ``token``, ``logprob`` and
    an optional ``top_logprobs`` list of alternate-candidate mappings.
    Returns an empty list for a falsy input (``None`` or empty).
    """
    if not logprobs:
        return []

    converted_tokens: List[TextTokenLogProb] = []
    for entry in logprobs:
        raw_alternates = entry.get("top_logprobs")
        if raw_alternates is None:
            # No alternates reported by the provider for this position.
            alternates = None
        else:
            alternates = [
                TextTokenTopLogProb(token=alt["token"], logprob=alt["logprob"])
                for alt in raw_alternates
            ]
        converted_tokens.append(
            TextTokenLogProb(
                token=entry["token"],
                logprob=entry["logprob"],
                top_logprobs=alternates,
            )
        )
    return converted_tokens

@staticmethod
def _tool_to_openai_function_dict(tool: Tool) -> Dict[str, Any]:
openai_function_dict: Dict[str, Any] = {
Expand Down Expand Up @@ -218,11 +242,18 @@ def _convert_generation_params(
kwargs["temperature"] = generation_config.temperature
if generation_config.max_tokens is not None:
kwargs["max_output_tokens"] = generation_config.max_tokens
if generation_config.top_logprobs is not None:
kwargs["top_logprobs"] = generation_config.top_logprobs
kwargs.setdefault("include", [])
if "message.output_text.logprobs" not in kwargs["include"]:
kwargs["include"].append("message.output_text.logprobs")
if generation_config.extra_args:
if "reasoning" in generation_config.extra_args:
kwargs["include"] = [
"reasoning.encrypted_content"
] # Pass reasoning traces if user has configured the reasoning parameter
kwargs.setdefault("include", [])
if "reasoning.encrypted_content" not in kwargs["include"]:
kwargs["include"].append(
"reasoning.encrypted_content"
) # Pass reasoning traces if user has configured the reasoning parameter

if "summary" not in generation_config.extra_args["reasoning"]:
generation_config.extra_args["reasoning"]["summary"] = "auto"
Expand Down Expand Up @@ -261,7 +292,17 @@ def _convert_openai_response_into_message(self, response: Any) -> "Message":
)

if item["content"][0]["type"] == "output_text":
output_contents.append(TextContent(item["content"][0]["text"]))
logprobs = None
if (
"logprobs" in item["content"][0]
and item["content"][0]["logprobs"] is not None
):
logprobs = self._convert_openai_logprobs_into_text_logprobs(
item["content"][0]["logprobs"]
)
output_contents.append(
TextContent(item["content"][0]["text"], logprobs=logprobs)
)

elif item["type"] == "image_generation_call" and item["result"]:
output_contents.append(ImageContent(base64_content=item["result"]))
Expand Down
19 changes: 18 additions & 1 deletion wayflowcore/src/wayflowcore/models/llmgenerationconfig.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,13 +80,21 @@ class LlmGenerationConfig(SerializableDataclass):
top_p: Optional[float] = None
stop: Optional[List[str]] = None
frequency_penalty: Optional[float] = None
top_logprobs: Optional[int] = None
extra_args: Dict[str, Any] = field(default_factory=dict)
# might be an issue in the future if you want not to pass some parameter and there is a default config
# if needed, defaults would be `Empty` and then None means `not specified` rather than `don't use`

def __post_init__(self) -> None:
# We check if among the extra args there are known fields
known_fields: Set[str] = {"max_tokens", "temperature", "top_p", "stop", "frequency_penalty"}
known_fields: Set[str] = {
"max_tokens",
"temperature",
"top_p",
"stop",
"frequency_penalty",
"top_logprobs",
}
for extra_arg_key in known_fields:
# If we find one, we remove it from extra args, and we set the proper field with the value (if it's not set)
# If the field is already set (i.e., not None), we just ignore the value in extra_args and raise a warning
Expand All @@ -101,6 +109,8 @@ def __post_init__(self) -> None:
)
if not (self.frequency_penalty is None or (-2 <= self.frequency_penalty <= 2)):
raise ValueError("The frequency penalty should be between -2 and 2")
if self.top_logprobs is not None and self.top_logprobs < 0:
raise ValueError("top_logprobs should be non-negative")

def to_dict(self) -> Dict[str, Any]:
config_dict: Dict[str, Union[int, float, List[str], Dict[str, Any]]] = {}
Expand All @@ -114,6 +124,8 @@ def to_dict(self) -> Dict[str, Any]:
config_dict["stop"] = self.stop
if self.frequency_penalty is not None:
config_dict["frequency_penalty"] = self.frequency_penalty
if self.top_logprobs is not None:
config_dict["top_logprobs"] = self.top_logprobs
if self.extra_args:
for extra_arg_key, extra_arg_value in self.extra_args.items():
config_dict[extra_arg_key] = serialize_any_to_dict(
Expand Down Expand Up @@ -142,6 +154,10 @@ def from_dict(config: Dict[str, Any]) -> "LlmGenerationConfig":

stop = config.pop("stop", None)

top_logprobs = config.pop("top_logprobs", None)
if top_logprobs is not None:
top_logprobs = int(top_logprobs)

extra_args: Dict[str, Any] = {}
for extra_arg_key, extra_arg_value in config.items():
extra_args[extra_arg_key] = autodeserialize_any_from_dict(
Expand All @@ -154,6 +170,7 @@ def from_dict(config: Dict[str, Any]) -> "LlmGenerationConfig":
top_p=top_p,
stop=stop,
frequency_penalty=frequency_penalty,
top_logprobs=top_logprobs,
extra_args=extra_args,
)

Expand Down
Loading