Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 0 additions & 26 deletions agentlightning/litagent.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,32 +94,6 @@ def runner(self) -> AgentRunner:
raise ValueError("Runner reference is no longer valid (object has been garbage collected).")
return runner

def on_rollout_start(self, task: Task, runner: AgentRunner, tracer: BaseTracer) -> None:
"""Hook called immediately before a rollout begins.

Args:
task: The :class:`Task` object that will be processed.
runner: The :class:`AgentRunner` managing the rollout.
tracer: The tracer instance associated with the runner.

Subclasses can override this method to implement custom logic such as
logging, metric collection, or resource setup. By default, this is a
no-op.
"""

def on_rollout_end(self, task: Task, rollout: Rollout, runner: AgentRunner, tracer: BaseTracer) -> None:
"""Hook called after a rollout completes.

Args:
task: The :class:`Task` object that was processed.
rollout: The resulting :class:`Rollout` object.
runner: The :class:`AgentRunner` managing the rollout.
tracer: The tracer instance associated with the runner.

Subclasses can override this method for cleanup or additional
logging. By default, this is a no-op.
"""

def training_rollout(self, task: TaskInput, rollout_id: str, resources: NamedResources) -> RolloutRawResult:
"""Defines the agent's behavior for a single training task.

Expand Down
201 changes: 201 additions & 0 deletions agentlightning/logging.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,16 @@
from __future__ import annotations

import json
import logging
from typing import Optional, TYPE_CHECKING
import numpy as np

from .types import Hook

if TYPE_CHECKING:
from .types import Task, Rollout
from .tracer import BaseTracer
from .runner import AgentRunner


def configure_logger(level: int = logging.INFO, name: str = "agentlightning") -> logging.Logger:
Expand All @@ -14,3 +26,192 @@ def configure_logger(level: int = logging.INFO, name: str = "agentlightning") ->
logger.setLevel(level)
logger.propagate = False # prevent double logging
return logger


logger = logging.getLogger(__name__)


class LightningLogger(Hook):
"""Agent-lightning logger that supports tracing events and metrics throughout the training."""

def log_event(self, event: str, data: dict):
"""
Log an event with its associated data when something happens.
"""

def log_metric(self, metric: str, value: float, step: Optional[int] = None):
"""
Log a metric with its value and an optional step.
"""

def log_message(self, level: int, message: str):
"""
Log a message at a specific logging level.
"""

def on_rollout_end(self, task: Task, rollout: Rollout, runner: AgentRunner, tracer: BaseTracer):
"""
By default, each logger automatically logs the rollout event at the end of each rollout.
"""
self.log_event("rollout", {"task": task.model_dump(), "rollout": rollout.model_dump()})


class ConsoleLogger(LightningLogger):
"""A simple logger that logs messages to the console using Python's logging module."""

def __init__(self, level: int = logging.INFO):
self.logger = configure_logger(level, name="agentlightning.ConsoleLogger")
self.default_level = level
self.worker_id: Optional[int] = None

def init_worker(self, worker_id: int):
super().init_worker(worker_id)
self.worker_id = worker_id

def teardown_worker(self, worker_id: int):
super().teardown_worker(worker_id)
self.worker_id = None

def log_event(self, event: str, data: dict):
data_str = str(data)
if len(data_str) > 512:
data_str = f"{data_str[:512]}... (truncated)"
message = f"Event: {event}, Data: {data_str}"
self.log_message(self.default_level, message)

def log_metric(self, metric: str, value: float, step: Optional[int] = None):
step_str = f" at step {step}" if step is not None else ""
message = f"Metric: {metric} = {value}{step_str}"
self.log_message(self.default_level, message)

def log_message(self, level: int, message: str):
if level >= self.default_level:
if self.worker_id is not None:
message = f"(Worker-{self.worker_id}) {message}"
else:
message = f"(Main) {message}"
self.logger.log(level, message)
# else skip logging if below default level


class WandbLogger(LightningLogger):

def __init__(
self,
project: str,
entity: Optional[str] = None,
name: Optional[str] = None,
config: Optional[dict] = None,
*,
flush_every_n_events: int = 128,
aggregate_every_n_metrics: int = 128,
):
import wandb
from wandb.sdk.wandb_run import Run

self.wandb_run: Optional[Run] = None
self.wandb_run_id: Optional[str] = None

self.project = project
self.entity = entity
self.name = name or wandb.util.generate_id()
self.config = config or {}

self.event_table: Optional[wandb.Table] = None
self.flush_every_n_events = flush_every_n_events
self.aggregate_every_n_metrics = aggregate_every_n_metrics

self.metrics_buffer: dict[str, list[float]] = {}

def init_worker(self, worker_id: int):
import wandb

super().init_worker(worker_id)
self.wandb_run = wandb.init(
project=self.project,
entity=self.entity,
group=self.name,
job_type=f"worker_{worker_id}",
config=self.config,
)
logger.info(f"Wandb run initialized: {self.name} (Worker {worker_id})")
if self.wandb_run is None:
raise RuntimeError("Failed to initialize Wandb run.")
self.wandb_run_id = self.wandb_run.id

def teardown_worker(self, worker_id: int):
import wandb

super().teardown_worker(worker_id)

for metric in self.metrics_buffer:
if len(self.metrics_buffer[metric]) > 0:
self._log_aggregated_metrics(metric)

if len(self.event_table.data) > 0:
logger.info(f"Flushing {len(self.event_table.data)} events to Wandb before finishing...")
wandb.log({"client/events": self.event_table})
self.event_table = None

wandb.finish(exit_code=0)

def teardown(self):
import wandb

super().teardown()
if self.wandb_run is not None:
wandb.finish(exit_code=0)
self.wandb_run = None
self.wandb_run_id = None

def log_event(self, event: str, data: dict):
import wandb

if self.event_table is None:
self.event_table = wandb.Table(columns=["event", "data"])

try:
data_str = json.dumps(data) # Ensure data is JSON serializable
except (TypeError, ValueError):
data_str = str(data)
self.event_table.add_data(event, data_str)

if len(self.event_table.data) % self.flush_every_n_events == 0:
logger.info(f"Flushing {len(self.event_table.data)} events to Wandb...")
wandb.log({"client/events": self.event_table})

def log_metric(self, metric: str, value: float, step: Optional[int] = None):
import wandb

if step is not None:
wandb.log({"client_metric/" + metric: value}, step=step)
else:
wandb.log({"client_metric/" + metric: value})

if metric not in self.metrics_buffer:
self.metrics_buffer[metric] = []
self.metrics_buffer[metric].append(value)
if len(self.metrics_buffer[metric]) >= self.aggregate_every_n_metrics:
self._log_aggregated_metrics(metric, step)
self.metrics_buffer[metric] = []

def log_message(self, level: int, message: str):
pass # Wandb handles logging internally, so we don't need to implement this

def _log_aggregated_metrics(self, metric, step: Optional[int] = None):
import wandb

arr = np.array(self.metrics_buffer[metric])
aggregated_value = {
"mean": float(np.mean(arr)),
"max": float(np.max(arr)),
"min": float(np.min(arr)),
"std": float(np.std(arr)),
"count": int((~np.isnan(arr)).sum()),
}
for key, value in aggregated_value.items():
if value is not None:
if step is not None:
wandb.log({"client_agg/" + metric + "/" + key: value}, step=step)
else:
wandb.log({"client_agg/" + metric + "/" + key: value})
43 changes: 24 additions & 19 deletions agentlightning/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,11 @@
from typing import List, Optional, Union, Dict, Any

import agentops

from opentelemetry.sdk.trace import ReadableSpan

from .client import AgentLightningClient
from .litagent import LitAgent
from .types import Rollout, Task, Triplet, RolloutRawResult
from .types import ParallelWorkerBase
from .types import Rollout, Task, Triplet, RolloutRawResult, ParallelWorkerBase, Hook
from .tracer.base import BaseTracer
from .tracer import TripletExporter

Expand Down Expand Up @@ -43,12 +42,14 @@ def __init__(
triplet_exporter: TripletExporter,
worker_id: Optional[int] = None,
max_tasks: Optional[int] = None,
hooks: Optional[List[Hook]] = None,
):
super().__init__()
self.agent = agent
self.client = client
self.tracer = tracer
self.triplet_exporter = triplet_exporter
self.hooks = hooks or []

# Worker-specific attributes
self.worker_id = worker_id
Expand Down Expand Up @@ -158,10 +159,11 @@ def run(self) -> bool:
rollout_obj = Rollout(rollout_id=task.rollout_id) # Default empty rollout

try:
try:
self.agent.on_rollout_start(task, self, self.tracer)
except Exception:
logger.exception(f"{self._log_prefix(rollout_id)} Exception during on_rollout_start hook.")
for hook in self.hooks:
try:
hook.on_rollout_start(task, self, self.tracer)
except Exception:
logger.exception(f"{self._log_prefix(rollout_id)} Exception during on_rollout_start hook: {hook}.")

with self.tracer.trace_context(name=f"rollout_{rollout_id}"):
start_time = time.time()
Expand All @@ -180,10 +182,11 @@ def run(self) -> bool:
except Exception:
logger.exception(f"{self._log_prefix(rollout_id)} Exception during rollout.")
finally:
try:
self.agent.on_rollout_end(task, rollout_obj, self, self.tracer)
except Exception:
logger.exception(f"{self._log_prefix(rollout_id)} Exception during on_rollout_end hook.")
for hook in self.hooks:
try:
hook.on_rollout_end(task, rollout_obj, self, self.tracer)
except Exception:
logger.exception(f"{self._log_prefix(rollout_id)} Exception during on_rollout_end hook: {hook}.")
self.client.post_rollout(rollout_obj)

return True
Expand Down Expand Up @@ -227,10 +230,11 @@ async def run_async(self) -> bool:
rollout_obj = Rollout(rollout_id=task.rollout_id) # Default empty rollout

try:
try:
self.agent.on_rollout_start(task, self, self.tracer)
except Exception:
logger.exception(f"{self._log_prefix(rollout_id)} Exception during on_rollout_start hook.")
for hook in self.hooks:
try:
hook.on_rollout_start(task, self, self.tracer)
except Exception:
logger.exception(f"{self._log_prefix(rollout_id)} Exception during on_rollout_start hook: {hook}.")

with self.tracer.trace_context(name=f"rollout_{rollout_id}"):
start_time = time.time()
Expand All @@ -248,10 +252,11 @@ async def run_async(self) -> bool:
except Exception:
logger.exception(f"{self._log_prefix(rollout_id)} Exception during rollout.")
finally:
try:
self.agent.on_rollout_end(task, rollout_obj, self, self.tracer)
except Exception:
logger.exception(f"{self._log_prefix(rollout_id)} Exception during on_rollout_end hook.")
for hook in self.hooks:
try:
hook.on_rollout_end(task, rollout_obj, self, self.tracer)
except Exception:
logger.exception(f"{self._log_prefix(rollout_id)} Exception during on_rollout_end hook: {hook}.")
await self.client.post_rollout_async(rollout_obj)

return True
Expand Down
Loading
Loading