Skip to content

Commit 9864b8f

Browse files
authored
New example: AGL Simulation (#367)
1 parent 82d8535 commit 9864b8f

15 files changed

Lines changed: 2412 additions & 0 deletions

File tree

contrib/.gitignore

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1 +1,5 @@
11
# Put contrib-related gitignore files here.
2+
3+
# recipes/envs related
4+
recipes/envs/agl_envs/
5+
recipes/envs/wandb/
Lines changed: 134 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,134 @@
1+
# Copyright (c) Microsoft. All rights reserved.
2+
3+
from __future__ import annotations
4+
5+
from typing import Dict, List, Optional, Sequence
6+
7+
from agentlightning.adapter.triplet import TracerTraceToTriplet
8+
from agentlightning.types import Span, Triplet
9+
10+
11+
class TracerTraceToTripletGroup(TracerTraceToTriplet):
    """Build triplets by grouping spans that share the same ``step_count``.

    Spans are bucketed per step as follows:

    * ``openai.chat.completion`` spans become the group's ``call_span``. Their
      step is read from the ``step_count`` attribute of the span whose
      ``span_id`` equals the call span's ``parent_id``.
    * ``agentlightning.object`` and ``agentlightning.annotation`` spans become
      the group's ``object_span`` / ``annotation_span``. Their step is read
      from the ``agentlightning.link.0.*`` attributes.

    Constructor arguments are forwarded unchanged to ``TracerTraceToTriplet``;
    the grouping logic in this class does not read any instance attribute.
    """

    # Span names whose step is carried in link attributes, mapped to the slot
    # they occupy inside a step group.
    _LINKED_SPAN_SLOTS = {
        "agentlightning.object": "object_span",
        "agentlightning.annotation": "annotation_span",
    }

    @staticmethod
    def _resolve_step_count(span, next_span, spans, index):
        """Return the ``step_count`` attribute of the span that parents ``span``.

        Args:
            span: The LLM call span being resolved.
            next_span: The span immediately following ``span``, or ``None``.
            spans: The full span sequence being processed.
            index: Position of ``span`` inside ``spans``.

        Returns:
            The parent's ``step_count`` attribute, or ``None`` when no parent is
            found or the parent carries no such attribute.
        """
        # Case A: the very next span is the parent; use it even if the
        # attribute is missing (no fallback search in that case).
        if next_span and span.parent_id == next_span.span_id:
            return next_span.attributes.get("step_count")

        # Case B: otherwise look further ahead for the parent operation span.
        for candidate in spans[index + 1 :]:
            if candidate.name == "agentlightning.operation" and span.parent_id == candidate.span_id:
                return candidate.attributes.get("step_count")

        return None

    @staticmethod
    def _step_count_from_links(span):
        """Return the linked ``step_count`` value of ``span``, or ``None``.

        Only the first link (``agentlightning.link.0``) is inspected.
        """
        if span.attributes.get("agentlightning.link.0.key_match") == "step_count":
            return span.attributes.get("agentlightning.link.0.value_match")
        return None

    def _extract_span_groups(self, spans: Sequence[Span]) -> Dict[str, Dict[str, Span]]:
        """Bucket spans by step.

        Args:
            spans: Spans of a single trace, in the order they should be scanned.

        Returns:
            A dict keyed by ``str(step_count)`` whose values map the slot names
            ``"call_span"``, ``"object_span"`` and ``"annotation_span"`` to the
            last span seen for that slot. Spans with other names, or whose step
            cannot be determined, are skipped.
        """
        span_groups: Dict[str, Dict[str, Span]] = {}

        for index, span in enumerate(spans):
            if span.name == "openai.chat.completion":
                slot = "call_span"
                next_span = spans[index + 1] if index + 1 < len(spans) else None
                step_count = self._resolve_step_count(span, next_span, spans, index)
            elif span.name in self._LINKED_SPAN_SLOTS:
                slot = self._LINKED_SPAN_SLOTS[span.name]
                step_count = self._step_count_from_links(span)
            else:
                continue

            if step_count is None:
                continue

            # Keys are normalised to str so call spans and linked spans of the
            # same step end up in the same group.
            span_groups.setdefault(str(step_count), {})[slot] = span

        return span_groups

    @staticmethod
    def _token_ids(span: Optional[Span], key: str) -> list:
        """Return the token-id attribute ``key`` of ``span`` (``[]`` if absent)."""
        if span is None:
            return []
        return span.attributes.get(key, [])

    @staticmethod
    def _reward_value(span: Optional[Span], index: int) -> float:
        """Return ``agentlightning.reward.<index>.value`` of ``span`` as a float.

        A missing span, a missing attribute, or a ``None`` attribute yields 0.0.
        """
        if span is None:
            return 0.0
        value = span.attributes.get(f"agentlightning.reward.{index}.value")
        return 0.0 if value is None else float(value)

    @staticmethod
    def _message(span: Optional[Span]) -> str:
        """Return the ``agentlightning.object.literal`` attribute (``""`` if absent)."""
        if span is None:
            return ""
        return span.attributes.get("agentlightning.object.literal", "")

    def adapt_group(self, source: Sequence[Span], /) -> List[Triplet]:
        """Convert spans into one triplet per step group.

        Args:
            source: Spans to convert.

        Returns:
            Triplets in the order their step was first encountered. Groups whose
            call span is missing or has neither prompt nor response token ids
            are dropped.
        """
        triplets: List[Triplet] = []

        for group in self._extract_span_groups(source).values():
            call_span = group.get("call_span")
            prompt_ids = self._token_ids(call_span, "prompt_token_ids")
            response_ids = self._token_ids(call_span, "response_token_ids")
            if not prompt_ids and not response_ids:
                continue

            annotation_span = group.get("annotation_span")
            # NOTE(review): `_extract_span_groups` never stores a "request_id"
            # entry, so this is always None today. Kept for output compatibility.
            request_id = group.get("request_id")

            triplets.append(
                Triplet(
                    prompt={"token_ids": prompt_ids},
                    response={"token_ids": response_ids},
                    reward=self._reward_value(annotation_span, 0),
                    metadata={
                        "response_id": request_id,
                        "intrinsic_reward": self._reward_value(annotation_span, 1),
                        "message": self._message(group.get("object_span")),
                    },
                )
            )

        return triplets
Lines changed: 161 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,161 @@
1+
# Copyright (c) Microsoft. All rights reserved.
2+
3+
from __future__ import annotations
4+
5+
import logging
6+
import os
7+
from typing import Any, Dict
8+
9+
import numpy as np
10+
from add_instruction import add_chat_instruction, add_single_instruction
11+
from agl_envs import make_env_manager
12+
from autogen_agentchat.agents import AssistantAgent
13+
from autogen_core.models import ModelFamily
14+
from autogen_ext.models.openai import OpenAIChatCompletionClient
15+
16+
from agentlightning import LLM, LitAgent, NamedResources, Rollout, configure_logger, emit_object, emit_reward, operation
17+
from agentlightning.utils.otel import make_link_attributes
18+
from contrib.recipes.envs.prompt_builder import HistoryPromptBuilder
19+
20+
# Module-level logger for this recipe, created through agentlightning's
# `configure_logger` with level=logging.ERROR.
logger = configure_logger(name=__name__, level=logging.ERROR)
21+
22+
23+
class EnvAgent(LitAgent):
    """Agent that rolls out an LLM policy inside an ``agl_envs`` environment.

    Each rollout builds a fresh model client and environment, then repeatedly
    prompts the model, steps the environment with the model output and emits a
    reward linked to the current ``step_count``.
    """

    # Sampling temperatures passed to the model client for training rollouts
    # and for every other rollout mode.
    _TRAIN_TEMPERATURE = 1.0
    _EVAL_TEMPERATURE = 0.4

    def __init__(self, config, trained_agents: str | None = None) -> None:
        """Store the recipe configuration.

        Args:
            config: Recipe configuration; read with both attribute access
                (``config.captioner.type``) and item access
                (``config["format_penalty"]``).
            trained_agents: Forwarded to ``LitAgent.__init__``.
        """
        super().__init__(trained_agents=trained_agents)
        self.config = config
        # Most recently created environment; set by `rollout_async`.
        self.env = None

    def _build_agent(self, llm: LLM, temperature: float):
        """Create an ``AssistantAgent`` backed by an OpenAI-compatible client.

        Args:
            llm: Resource providing the model name and endpoint.
            temperature: Sampling temperature for the model client.

        Returns:
            An ``AssistantAgent`` named ``"envs"``.
        """
        model_client = OpenAIChatCompletionClient(
            model=llm.model,
            base_url=llm.endpoint,
            # Falls back to a placeholder key when OPENAI_API_KEY is unset.
            api_key=os.environ.get("OPENAI_API_KEY", "token-abc123"),
            model_info={
                "vision": False,
                "function_calling": True,
                "json_output": False,
                "family": ModelFamily.UNKNOWN,
                "structured_output": False,
            },
            temperature=temperature,
        )

        return AssistantAgent(
            name="envs",
            model_client=model_client,
        )

    def _get_instructed_prompt(self, prompt, sep="\n\n"):
        """Return ``prompt`` with instructions added for the configured captioner.

        Args:
            prompt: Prompt produced by the prompt builder.
            sep: Separator forwarded to the instruction helper.

        Raises:
            ValueError: If the ``prompt_type`` / captioner ``type`` combination
                is not one of chat|single x cot|naive.
        """
        prompt_type = self.config.captioner.prompt_type
        cap_type = self.config.captioner.type

        if prompt_type == "chat":
            if cap_type == "cot":
                return add_chat_instruction(prompt, "cot", sep, self.config.env_name)
            if cap_type == "naive":
                # NOTE(review): unlike the other three combinations, env_name is
                # not passed here; confirm this is intentional.
                return add_chat_instruction(prompt, "naive", sep)
        elif prompt_type == "single":
            if cap_type in ("cot", "naive"):
                return add_single_instruction(prompt, cap_type, sep, self.config.env_name)

        raise ValueError(f"Unsupported prompt_type={prompt_type}, type={cap_type}")

    async def rollout_async(
        self,
        task: Dict[str, Any],
        resources: NamedResources,
        rollout: Rollout,
    ) -> float | None:
        """Run one episode and return its accumulated reward.

        Args:
            task: Task description forwarded to ``make_env_manager``.
            resources: Named resources; ``"main_llm"`` supplies the model.
            rollout: Rollout metadata (``rollout_id`` and ``mode`` are read).

        Returns:
            The sum of step rewards (scaled by the reward scale when
            ``rollout.mode == "train"``). If the model call raises, the error
            is logged and the reward accumulated so far is returned.
        """
        rollout_id = rollout.rollout_id
        logger.info("[Rollout %s] Task: %s", rollout_id, task)

        format_penalty = float(self.config["format_penalty"])
        # Prefer the correctly spelled key; fall back to the legacy misspelling
        # so existing configs keep working.
        scale_key = "reward_scale" if "reward_scale" in self.config else "reawrd_scale"
        reward_scale = float(self.config[scale_key])

        # Setup agent
        is_training = rollout.mode == "train"
        llm: LLM = resources.get("main_llm")
        print("Training with model:", llm.model, "on endpoint:", llm.endpoint)
        self.agent = self._build_agent(
            llm, self._TRAIN_TEMPERATURE if is_training else self._EVAL_TEMPERATURE
        )
        if "max_tokens" in self.config and self.config["max_tokens"] > -1:
            self.agent._model_client.max_tokens = self.config["max_tokens"]

        # Track the environment locally so the cleanup below only ever closes
        # the one created by this call, not a leftover from an earlier rollout.
        env = None
        try:
            # Setup environment
            prompt_builder = HistoryPromptBuilder(
                max_history=self.config.captioner.max_history,
                prompt_type=self.config.captioner.prompt_type,
            )

            env = make_env_manager(self.config.env_name, task, self.config)
            self.env = env
            env_obs, infos, available_actions_hint = env.reset()

            prompt_builder.init(env)
            prompt_builder.update_observation(env_obs)
            prompt_builder.update_admissible_actions(available_actions_hint)

            prompt = prompt_builder.get_prompt()

            episode_reward, done = 0.0, False
            step_count = 0

            while not done:
                try:
                    instructed_prompt = self._get_instructed_prompt(prompt)

                    # Main agent step; the operation span carries step_count.
                    with operation(step_count=step_count):
                        result = await self.agent._model_client.create(instructed_prompt)
                    output = result.content
                    logger.info("[LLM output]: %s", output)

                except Exception as e:
                    # Best effort: stop the episode and keep the reward so far.
                    logger.error(
                        "[Rollout %s] Error during training rollout: %s", rollout_id, e, exc_info=True
                    )
                    break

                if self.config.log_env_obs:
                    # Logs the observation the model just acted on.
                    emit_object(env_obs, attributes=make_link_attributes({"step_count": str(step_count)}))

                (
                    env_obs,
                    executed_action,
                    is_valid,
                    step_reward,
                    terminated,
                    truncated,
                    info,
                    available_actions_hint,
                ) = env.step(
                    output,
                    use_reasoning=self.config.captioner.type == "cot",
                    use_success_rate=self.config.use_success_rate,
                )

                prompt_builder.update_step_count()
                prompt_builder.update_action(executed_action)
                prompt_builder.update_observation(env_obs)
                prompt_builder.update_admissible_actions(available_actions_hint)

                prompt = prompt_builder.get_prompt()

                if is_training:
                    step_reward *= reward_scale

                if format_penalty != 0.0:
                    emit_reward(
                        {
                            "extrinsic_reward": step_reward,
                            "intrinsic_reward": 0.0 if is_valid else -1.0 * format_penalty,
                        },
                        primary_key="extrinsic_reward",
                        attributes=make_link_attributes({"step_count": str(step_count)}),
                    )
                else:
                    emit_reward(step_reward, attributes=make_link_attributes({"step_count": str(step_count)}))

                episode_reward += float(step_reward)
                done = np.logical_or(terminated, truncated)

                step_count += 1

            return episode_reward

        finally:
            if env is not None:
                env.close()

0 commit comments

Comments
 (0)