Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
120 changes: 73 additions & 47 deletions deepeval/tracing/tracing.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,7 @@ def __del__(self):

class TraceManager:
def __init__(self):
self._state_lock = threading.RLock()
self.traces: List[Trace] = []
self.active_traces: Dict[str, Trace] = {} # Map of trace_uuid to Trace
self.active_spans: Dict[str, BaseSpan] = (
Expand Down Expand Up @@ -270,10 +271,11 @@ def start_new_trace(
metric_collection=metric_collection,
confident_api_key=self.confident_api_key,
)
self.active_traces[trace_uuid] = new_trace
self.traces.append(new_trace)
if self.evaluation_loop:
self.traces_to_evaluate_order.append(trace_uuid)
with self._state_lock:
self.active_traces[trace_uuid] = new_trace
self.traces.append(new_trace)
if self.evaluation_loop:
self.traces_to_evaluate_order.append(trace_uuid)
# Associate the current Golden with this trace so we can
# later evaluate traces against the correct golden, even if more traces
# are created than goldens or the order interleaves.
Expand All @@ -291,8 +293,10 @@ def start_new_trace(

def end_trace(self, trace_uuid: str):
"""End a specific trace by its UUID."""
with self._state_lock:
if trace_uuid not in self.active_traces:
return

if trace_uuid in self.active_traces:
trace = self.active_traces[trace_uuid]
trace.end_time = (
perf_counter() if trace.end_time is None else trace.end_time
Expand Down Expand Up @@ -354,74 +358,96 @@ def end_trace(self, trace_uuid: str):

def set_trace_status(self, trace_uuid: str, status: TraceSpanStatus):
"""Manually set the status of a trace."""
if trace_uuid in self.active_traces:
trace = self.active_traces[trace_uuid]
trace.status = status
with self._state_lock:
if trace_uuid in self.active_traces:
trace = self.active_traces[trace_uuid]
trace.status = status

def add_span(self, span: BaseSpan):
"""Add a span to the active spans dictionary."""
self.active_spans[span.uuid] = span
with self._state_lock:
self.active_spans[span.uuid] = span

def remove_span(self, span_uuid: str):
"""Remove a span from the active spans dictionary."""
if span_uuid in self.active_spans:
del self.active_spans[span_uuid]
with self._state_lock:
if span_uuid in self.active_spans:
del self.active_spans[span_uuid]

def add_span_to_trace(self, span: BaseSpan):
"""Add a span to its trace."""
trace_uuid = span.trace_uuid
if trace_uuid not in self.active_traces:
raise ValueError(
f"Trace with UUID {trace_uuid} does not exist. A span must have a valid trace."
)
with self._state_lock:
trace_uuid = span.trace_uuid
if trace_uuid not in self.active_traces:
raise ValueError(
f"Trace with UUID {trace_uuid} does not exist. A span must have a valid trace."
)

trace = self.active_traces[trace_uuid]
trace = self.active_traces[trace_uuid]

# If this is a root span (no parent), add it to the trace's root_spans
if not span.parent_uuid:
trace.root_spans.append(span)
else:
# This is a child span, find its parent and add it to the parent's children
parent_span = self.get_span_by_uuid(span.parent_uuid)
if parent_span:
# If this is a root span (no parent), add it to the trace's root_spans
if not span.parent_uuid:
trace.root_spans.append(span)
else:
# This is a child span, find its parent and add it to the parent's children
parent_span = self.get_span_by_uuid(span.parent_uuid)
if parent_span:

if (
parent_span.name == EVAL_DUMMY_SPAN_NAME
): # ignored span for evaluation
span.parent_uuid = None
trace.root_spans.remove(parent_span)
if (
parent_span.name == EVAL_DUMMY_SPAN_NAME
): # ignored span for evaluation
span.parent_uuid = None
trace.root_spans.remove(parent_span)
trace.root_spans.append(span)
return

parent_span.children.append(span)
else:
trace.root_spans.append(span)
return

parent_span.children.append(span)
else:
trace.root_spans.append(span)

def get_trace_by_uuid(self, trace_uuid: str) -> Optional[Trace]:
"""Get a trace by its UUID."""
return self.active_traces.get(trace_uuid)
with self._state_lock:
return self.active_traces.get(trace_uuid)

def get_span_by_uuid(self, span_uuid: str) -> Optional[BaseSpan]:
"""Get a span by its UUID."""
return self.active_spans.get(span_uuid)
with self._state_lock:
return self.active_spans.get(span_uuid)

def is_trace_active(self, trace_uuid: str) -> bool:
with self._state_lock:
return trace_uuid in self.active_traces

def get_active_spans_for_trace(self, trace_uuid: str) -> List[BaseSpan]:
with self._state_lock:
return [
span
for span in self.active_spans.values()
if span.trace_uuid == trace_uuid
]

def get_all_traces(self) -> List[Trace]:
"""Get all traces."""
return self.traces
with self._state_lock:
return list(self.traces)

def clear_traces(self):
"""Clear all traces."""
self.traces = []
self.active_traces = {}
self.active_spans = {}
with self._state_lock:
self.traces = []
self.active_traces = {}
self.active_spans = {}

def get_trace_dict(self, trace: Trace) -> Dict:
"""Convert a trace to a dictionary."""
return dataclass_to_dict(trace)

def get_all_traces_dict(self) -> List[Dict]:
"""Get all traces as dictionaries."""
return [self.get_trace_dict(trace) for trace in self.traces]
with self._state_lock:
traces_snapshot = list(self.traces)
return [self.get_trace_dict(trace) for trace in traces_snapshot]

def _print_trace_status(
self,
Expand Down Expand Up @@ -959,7 +985,7 @@ def __enter__(self):
# (a previous failed async operation might leave a dead trace in context)
if (
current_trace
and current_trace.uuid in trace_manager.active_traces
and trace_manager.is_trace_active(current_trace.uuid)
):
self.trace_uuid = current_trace.uuid
else:
Expand Down Expand Up @@ -1095,11 +1121,11 @@ def __exit__(self, exc_type, exc_val, exc_tb):
if current_span.status == TraceSpanStatus.ERRORED:
current_trace.status = TraceSpanStatus.ERRORED
if current_trace.uuid == current_span.trace_uuid:
other_active_spans = [
span
for span in trace_manager.active_spans.values()
if span.trace_uuid == current_span.trace_uuid
]
other_active_spans = (
trace_manager.get_active_spans_for_trace(
current_span.trace_uuid
)
)

if not other_active_spans:
trace_manager.end_trace(current_span.trace_uuid)
Expand Down
Loading