Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 7 additions & 5 deletions agentlightning/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@ def _to_rollout_object(
self,
result: RolloutRawResult,
rollout_id: str,
attempt_id: str,
) -> Rollout:
"""Standardizes the agent's return value into a Rollout object.

Expand Down Expand Up @@ -122,6 +123,7 @@ def _to_rollout_object(
# Create the Rollout object with standardized fields
result_dict: Dict[str, Any] = {
"rollout_id": rollout_id,
"attempt_id": attempt_id,
}
if final_reward is not None:
result_dict["final_reward"] = final_reward
Expand Down Expand Up @@ -155,7 +157,7 @@ def run(self) -> bool:
logger.error(f"{self._log_prefix(rollout_id)} Failed to fetch resources. Skipping.")
return False

rollout_obj = Rollout(rollout_id=task.rollout_id) # Default empty rollout
rollout_obj = Rollout(rollout_id=task.rollout_id, attempt_id=task.attempt_id) # Default empty rollout

try:
try:
Expand All @@ -168,7 +170,7 @@ def run(self) -> bool:
rollout_method = self.agent.training_rollout if task.mode == "train" else self.agent.validation_rollout
# Pass the task input, not the whole task object
result = rollout_method(task.input, task.rollout_id, resources_update.resources)
rollout_obj = self._to_rollout_object(result, task.rollout_id)
rollout_obj = self._to_rollout_object(result, task.rollout_id, task.attempt_id)
end_time = time.time()
logger.info(
f"{self._log_prefix(rollout_id)} Completed in "
Expand Down Expand Up @@ -224,8 +226,8 @@ async def run_async(self) -> bool:
logger.error(f"{self._log_prefix(rollout_id)} Failed to fetch resources. Skipping.")
return False

rollout_obj = Rollout(rollout_id=task.rollout_id) # Default empty rollout

rollout_obj = Rollout(rollout_id=task.rollout_id, attempt_id=task.attempt_id) # Default empty rollout
Comment on lines +229 to +230
Copy link

Copilot AI Sep 5, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There's an extra trailing space on line 230 that should be removed.

Suggested change
rollout_obj = Rollout(rollout_id=task.rollout_id, attempt_id=task.attempt_id) # Default empty rollout

Copilot uses AI. Check for mistakes.
try:
try:
self.agent.on_rollout_start(task, self, self.tracer)
Expand All @@ -239,7 +241,7 @@ async def run_async(self) -> bool:
)
# Pass the task input, not the whole task object
result = await rollout_method(task.input, task.rollout_id, resources_update.resources)
rollout_obj = self._to_rollout_object(result, task.rollout_id)
rollout_obj = self._to_rollout_object(result, task.rollout_id, task.attempt_id)
end_time = time.time()
logger.info(
f"{self._log_prefix(rollout_id)} Completed in "
Expand Down
11 changes: 11 additions & 0 deletions agentlightning/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,7 @@ async def get_next_task(self) -> Optional[Task]:
update={
"last_claim_time": time.time(),
"num_claims": (task.num_claims or 0) + 1,
"attempt_id": str(uuid.uuid4()),
}
)
self._processing_tasks[task.rollout_id] = task
Expand Down Expand Up @@ -121,6 +122,16 @@ async def store_rollout(self, rollout: Rollout):
Safely stores a completed rollout from a client.
"""
async with self._results_lock:
current_task = self._processing_tasks.get(rollout.rollout_id)
if not current_task:
logger.warning(f"Ignoring rollout {rollout.rollout_id}: task not in processing anymore")
return # drop stale result

if getattr(rollout, "attempt_id", None) != current_task.attempt_id:
Copy link

Copilot AI Sep 5, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Using getattr with a default value is unnecessary here: attempt_id is a declared field on the Rollout model (Optional[str], defaulting to None), so the attribute always exists on a Rollout instance. Consider using direct attribute access instead: rollout.attempt_id != current_task.attempt_id.

Suggested change
if getattr(rollout, "attempt_id", None) != current_task.attempt_id:
if rollout.attempt_id != current_task.attempt_id:

Copilot uses AI. Check for mistakes.
logger.warning(f"Ignoring stale rollout {rollout.rollout_id}: attempt_id mismatch: {rollout.attempt_id} != {current_task.attempt_id}")
logger.warning(f"The rollout: {rollout}")
return # drop stale result

self._processing_tasks.pop(rollout.rollout_id, None)
self._completed_rollouts[rollout.rollout_id] = rollout
logger.info(f"Rollout received and stored: {rollout.rollout_id}")
Expand Down
6 changes: 6 additions & 0 deletions agentlightning/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,9 @@ class Rollout(BaseModel):
)
logs: Optional[List[str]] = None

# Optional fields for tracking task lifecycle
attempt_id: Optional[str] = None

# A bucket for any other relevant information
metadata: Dict[str, Any] = Field(default_factory=dict)

Expand All @@ -71,6 +74,9 @@ class Task(BaseModel):
last_claim_time: Optional[float] = None
num_claims: Optional[int] = None

# Optional fields for tracking task lifecycle
attempt_id: Optional[str] = None

# Allow additional metadata fields
metadata: Dict[str, Any] = Field(default_factory=dict)

Expand Down