
Commit 0d2dece

Merge branch 'dev/support_mask' of github.com:microsoft/agent-lightning into dev/support_mask

2 parents dfb6323 + 47f5186


54 files changed (+3518, -240 lines)

.github/workflows/docs.yml

Lines changed: 2 additions & 1 deletion

@@ -57,4 +57,5 @@ jobs:
       if: github.ref == 'refs/heads/main'
       run: |
         mike deploy --push latest
-        mike set-default --push latest
+        # Always set stable to default
+        mike set-default --push stable

.github/workflows/examples.yml

Lines changed: 34 additions & 0 deletions

@@ -41,6 +41,26 @@ jobs:
           name: dependencies-${{ matrix.setup }}
           path: requirements-freeze.txt
           compression-level: 0
+
+      - name: APO example (legacy)
+        run: |
+          set -ex
+          . .venv/bin/activate
+          cd examples/apo
+          python legacy_apo_client.py &
+          sleep 3  # Wait for the client to be up
+          python legacy_apo_server.py
+          pkill -f legacy_apo_client.py && echo "SIGTERM sent to legacy_apo_client.py" || echo "No legacy_apo_client.py process found"
+          while pgrep -f legacy_apo_client.py; do
+            echo "Waiting for legacy_apo_client.py to finish..."
+            sleep 5
+          done
+          echo "legacy_apo_client.py has finished."
+          sleep 10
+        env:
+          OPENAI_API_BASE: ${{ secrets.OPENAI_API_BASE }}
+          OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
+
       - name: Prepare Spider dataset
         run: |
           set -ex
@@ -133,6 +153,20 @@ jobs:
           WANDB_API_KEY: ${{ secrets.WANDB_API_KEY }}
         id: calc_x_train_v0_2
 
+      - name: Calc-X training v0.2 LLM Proxy
+        run: |
+          set -ex
+          source .venv/bin/activate
+          cd examples/calc_x
+          ../../scripts/restart_ray.sh
+          sleep 5
+          PYTHONUNBUFFERED=1 python calc_agent_v0_2_llm_proxy.py
+          sleep 10
+        shell: bash
+        env:
+          WANDB_API_KEY: ${{ secrets.WANDB_API_KEY }}
+        id: calc_x_train_v0_2_llm_proxy
+
       - name: Spider training
         run: |
           set -ex

README.md

Lines changed: 4 additions & 1 deletion

@@ -1,4 +1,6 @@
-![Agent-lightning-banner](docs/assets/readme-banner.png)
+<div style="text-align:center; margin-bottom:20px;">
+  <img src="docs/assets/readme-banner.png" alt="Agent-lightning-banner" style="max-width:600px"/>
+</div>
 
 # Agent Lightning⚡
 
@@ -31,6 +33,7 @@ Join our [Discord community](https://discord.gg/RYk7CdvDR7) to connect with othe
 ## ⚡ Community Projects
 
 - [DeepWerewolf](https://github.com/af-74413592/DeepWerewolf) — A case study of agent RL training for the Chinese Werewolf game built with AgentScope and Agent Lightning.
+- [AgentFlow](https://agentflow.stanford.edu/) — A modular multi-agent framework that combines planner, executor, verifier, and generator agents with the Flow-GRPO algorithm to tackle long-horizon, sparse-reward tasks.
 
 ## ⚡ Installation

agentlightning/adapter/__init__.py

Lines changed: 2 additions & 2 deletions

@@ -1,6 +1,6 @@
 # Copyright (c) Microsoft. All rights reserved.
 
 from .base import Adapter, TraceAdapter
-from .triplet import TraceTripletAdapter
+from .triplet import BaseTraceTripletAdapter, LlmProxyTripletAdapter, TraceTripletAdapter
 
-__all__ = ["TraceAdapter", "Adapter", "TraceTripletAdapter"]
+__all__ = ["TraceAdapter", "Adapter", "BaseTraceTripletAdapter", "TraceTripletAdapter", "LlmProxyTripletAdapter"]

agentlightning/adapter/triplet.py

Lines changed: 194 additions & 1 deletion

@@ -521,7 +521,13 @@ def __repr__(self):
         )
 
 
-class TraceTripletAdapter(TraceAdapter[List[Triplet]]):
+class BaseTraceTripletAdapter(TraceAdapter[List[Triplet]]):
+    """
+    Base class for trace triplet adapters.
+    """
+
+
+class TraceTripletAdapter(BaseTraceTripletAdapter):
     """
     An adapter to convert OpenTelemetry spans to triplet data.
 
@@ -592,3 +598,190 @@ def adapt(self, source: Union[List[Span], List[ReadableSpan]], /) -> List[Triple
             reward_match=self.reward_match,
         )
         return trajectory
+
+
+class LlmProxyTripletAdapter(BaseTraceTripletAdapter):
+    """
+    Converts telemetry data emitted by the LLM Proxy to triplet data.
+    This adapter is very experimental and should only be used when the
+    TraceTripletAdapter does not work at all.
+
+    IMPORTANT: Do NOT rely on timestamps here. Proxy spans can be emitted from different
+    machines with unsynchronized clocks. We therefore treat `sequence_id` as the only
+    reliable ordering primitive and perform "first occurrence" reward matching using
+    sequence order only.
+
+    Strategy:
+
+    1) Sort spans by (sequence_id, start_time).
+    2) Extract LLM calls that expose prompt/response token IDs from either:
+       - litellm_request (sometimes only metadata; ignored if no token ids)
+       - raw_gen_ai_request (llm.hosted_vllm.* stringified fields)
+    3) Extract rewards from spans whose attributes contain an AgentOps-style
+       reward payload or an explicit REWARD span.
+    4) For each reward with sequence R, assign it to the most recent *unmatched* LLM call
+       with sequence < R. Ignore timestamps completely.
+    """
+
+    def _literal_eval_maybe(self, v: Any) -> Any:
+        import ast
+
+        if isinstance(v, str):
+            try:
+                return ast.literal_eval(v)
+            except Exception:
+                return v
+        return v
+
+    def _extract_tokens_from_raw(self, attrs: Dict[str, Any]) -> Tuple[List[int], List[int]]:
+        """Extract token ids from raw_gen_ai_request attributes.
+
+        - llm.hosted_vllm.prompt_token_ids: string -> List[int]
+        - llm.hosted_vllm.response_token_ids: string -> List[List[int]] -> take first
+        - llm.hosted_vllm.choices: string -> [{'token_ids': [...]}] -> take first
+        """
+        prompt_ids: List[int] = []
+        resp_ids: List[int] = []
+
+        # prompt
+        p = attrs.get("llm.hosted_vllm.prompt_token_ids")
+        p = self._literal_eval_maybe(p)
+        if isinstance(p, list) and all(isinstance(x, int) for x in p):  # type: ignore
+            prompt_ids = cast(List[int], p)
+
+        # response preferred path
+        r = attrs.get("llm.hosted_vllm.response_token_ids")
+        r = self._literal_eval_maybe(r)
+        if isinstance(r, list) and len(r) > 0 and isinstance(r[0], list):  # type: ignore
+            first = cast(List[Any], r[0])
+            if all(isinstance(x, int) for x in first):
+                resp_ids = cast(List[int], first)
+
+        # fallback via choices
+        if not resp_ids:
+            choices = attrs.get("llm.hosted_vllm.choices")
+            choices = self._literal_eval_maybe(choices)
+            if isinstance(choices, list) and choices:
+                cand = cast(Any, choices[0])
+                if isinstance(cand, dict):
+                    tids = cast(Dict[str, Any], cand).get("token_ids")
+                    if isinstance(tids, list) and all(isinstance(x, int) for x in tids):  # type: ignore
+                        resp_ids = cast(List[int], tids)
+
+        return prompt_ids, resp_ids
+
+    def _extract_tokens_from_openai(self, attrs: Dict[str, Any]) -> Tuple[List[int], List[int]]:
+        prompt_ids = cast(Any, attrs.get("prompt_token_ids") or [])
+        resp_ids = cast(Any, attrs.get("response_token_ids") or [])
+        prompt_ids = self._literal_eval_maybe(prompt_ids)
+        resp_ids = self._literal_eval_maybe(resp_ids)
+        if not (isinstance(prompt_ids, list) and all(isinstance(x, int) for x in prompt_ids)):  # type: ignore
+            prompt_ids = []
+        if not (isinstance(resp_ids, list) and all(isinstance(x, int) for x in resp_ids)):  # type: ignore
+            resp_ids = []
+        return cast(List[int], prompt_ids), cast(List[int], resp_ids)
+
+    def _maybe_reward_value(self, span: Span) -> Optional[float]:
+        """
+        Parse a reward from a typical AgentOps payload or an explicit REWARD span.
+        """
+        attrs = span.attributes or {}
+
+        # AgentOps new/old keys
+        for k in ("agentops.task.output", "agentops.entity.output"):
+            v = attrs.get(k)
+            v = self._literal_eval_maybe(v)
+            if isinstance(v, dict) and cast(Dict[str, Any], v).get("type") == "reward":
+                rv = cast(Dict[str, Any], v).get("value", None)
+                if rv is None or isinstance(rv, (int, float)):
+                    return None if rv is None else float(rv)
+
+        # Explicit reward span
+        if span.name == SpanNames.REWARD.value:
+            rv = attrs.get("reward", None)
+            if rv is None or isinstance(rv, (int, float)):
+                return None if rv is None else float(rv)
+
+        return None
+
+    def _request_id_from_attrs(self, attrs: Dict[str, Any]) -> Optional[str]:
+        # Prefer an OpenAI-like id if present, else the proxy's raw id.
+        rid = attrs.get("gen_ai.response.id") or attrs.get("llm.hosted_vllm.id")
+        return str(rid) if isinstance(rid, str) and rid else None
+
+    def adapt(self, source: List[Span], /) -> List[Triplet]:  # type: ignore
+        # 1) Sort deterministically by (sequence_id, start_time).
+        spans = sorted(
+            source,
+            key=lambda s: (s.sequence_id, s.start_time),
+        )
+
+        # 2) Collect LLM calls with token IDs.
+        llm_items: List[Dict[str, Any]] = []
+        seen_request_ids: set[str] = set()
+        for s in spans:
+            attrs = s.attributes or {}
+            prompt_ids: List[int] = []
+            resp_ids: List[int] = []
+
+            if s.name == "raw_gen_ai_request":
+                prompt_ids, resp_ids = self._extract_tokens_from_raw(attrs)
+            elif s.name == "litellm_request":
+                # Some proxies never include token ids here. Ignore unless present.
+                prompt_ids, resp_ids = self._extract_tokens_from_openai(attrs)
+
+            if prompt_ids and resp_ids:
+                rid = self._request_id_from_attrs(attrs)
+                if rid:
+                    # Duplicated request ID. This request has already been handled.
+                    if rid in seen_request_ids:
+                        continue
+                    seen_request_ids.add(rid)
+                llm_items.append(
+                    dict(
+                        span=s,
+                        seq=s.sequence_id,
+                        response_ids=resp_ids,
+                        prompt_ids=prompt_ids,
+                        request_id=rid,
+                    )
+                )
+
+        # Order LLM items by sequence only.
+        llm_items.sort(key=lambda x: x["seq"])
+
+        # Collect rewards by sequence only.
+        rewards: List[Tuple[int, Optional[float]]] = []
+        for s in spans:
+            val = self._maybe_reward_value(s)
+            if val is not None:
+                rewards.append((s.sequence_id, val))
+
+        # First-occurrence matching by sequence_id only:
+        # for a reward at sequence R, assign it to the most recent unmatched LLM call with seq < R.
+        assigned: Dict[str, Optional[float]] = {}
+        for r_seq, r_val in sorted(rewards, key=lambda x: x[0]):
+            for item in reversed(llm_items):
+                sid = item["span"].span_id
+                if sid in assigned:
+                    continue
+                if item["seq"] < r_seq:
+                    assigned[sid] = r_val
+                    break
+
+        # Build triplets in LLM sequence order.
+        triplets: List[Triplet] = []
+        for item in llm_items:
+            s = item["span"]
+            triplets.append(
+                Triplet(
+                    prompt={"token_ids": item["prompt_ids"]},
+                    response={"token_ids": item["response_ids"]},
+                    reward=assigned.get(s.span_id, None),
+                    metadata=dict(
+                        # This is called response_id to align with the other adapters.
+                        response_id=item["request_id"],
+                    ),
+                )
+            )
+
+        return triplets
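
Because the proxy stringifies structured attributes, the adapter routes every candidate field through `ast.literal_eval` before type-checking it. A small standalone illustration of the `llm.hosted_vllm.*` formats the docstring lists (the attribute values here are made up):

```python
import ast

# Hypothetical raw_gen_ai_request attributes, stringified as the proxy emits them.
attrs = {
    "llm.hosted_vllm.prompt_token_ids": "[101, 2054, 102]",
    "llm.hosted_vllm.response_token_ids": "[[7592, 2088]]",  # list of lists; take the first
}

prompt_ids = ast.literal_eval(attrs["llm.hosted_vllm.prompt_token_ids"])
response_ids = ast.literal_eval(attrs["llm.hosted_vllm.response_token_ids"])[0]
print(prompt_ids, response_ids)  # [101, 2054, 102] [7592, 2088]
```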

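The only non-obvious step is 4, the sequence-only reward matching. A self-contained sketch of that rule on toy data, using plain dicts in place of real Span objects:

```python
# Toy LLM calls and rewards, identified only by sequence_id.
llm_calls = [{"id": "a", "seq": 1}, {"id": "b", "seq": 3}, {"id": "c", "seq": 6}]
rewards = [(4, 0.5), (7, 1.0)]  # (sequence_id, reward value)

assigned = {}
for r_seq, r_val in sorted(rewards):
    # Walk calls from the latest sequence backwards and take the most
    # recent unmatched call that precedes the reward.
    for call in reversed(llm_calls):
        if call["id"] in assigned:
            continue
        if call["seq"] < r_seq:
            assigned[call["id"]] = r_val
            break

print(assigned)  # {'b': 0.5, 'c': 1.0}
```

Note that call `a` never receives a reward: each reward binds to exactly one call, and earlier calls are only reached once later ones are taken, so its triplet ends up with `reward=None`.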